44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ from enum import Enum
78from typing import Optional , Tuple
89
910import torch
1415from executorch .exir .pass_base import PassResult
1516
1617
18+ class InputDimOrder (Enum ):
19+ NCHW = 1
20+ NHWC = 2
21+
22+
1723# TODO(T151254305) use subgraph_rewriter
1824class ChannelsLastTaggedReshapePass (XNNPACKPass ):
1925 """
@@ -84,11 +90,19 @@ def mark_as_nhwc_node(self, node: torch.fx.Node) -> None:
8490 def mark_as_nchw_node (self , node : torch .fx .Node ) -> None :
8591 node .meta [ChannelsLastTaggedReshapePass .XNN_NHWC_NODE ] = False
8692
87- def is_nhwc_node (self , node : torch .fx .Node ) -> bool :
93+ def tag_node (self , node : torch .fx .Node ) -> None :
94+ if node .kwargs ["memory_format" ] == torch .channels_last :
95+ self .mark_as_nhwc_node (node )
96+ else :
97+ self .mark_as_nchw_node (node )
98+
99+ @staticmethod
100+ def is_nhwc_node (node : torch .fx .Node ) -> bool :
88101 return node .meta .get (ChannelsLastTaggedReshapePass .XNN_NHWC_NODE , False )
89102
90- def is_nchw_node (self , node : torch .fx .Node ) -> bool :
91- return not self .is_nhwc_node (node )
103+ @staticmethod
104+ def is_nchw_node (node : torch .fx .Node ) -> bool :
105+ return not ChannelsLastTaggedReshapePass .is_nhwc_node (node )
92106
93107 def requires_nhwc_input (self , node : torch .fx .Node ) -> bool :
94108 return node .target in self .memory_sensitive_ops_nhwc
@@ -106,7 +120,7 @@ def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
106120 is_nchw_constant = (
107121 is_param_node (self .exported_program , node )
108122 and (ChannelsLastTaggedReshapePass .XNN_NHWC_NODE in node .meta )
109- and (self .is_nchw_node (node ))
123+ and (ChannelsLastTaggedReshapePass .is_nchw_node (node ))
110124 )
111125 return is_4d and not is_nchw_constant
112126
@@ -249,6 +263,22 @@ def insert_copy_and_assign_partner_nodes_quantization_sensitive(
249263 # in that case
250264 self .make_partners (original_input , copy_node )
251265
266+ def input_dim_order (
267+ self , input_node : torch .fx .Node , input_order : InputDimOrder
268+ ) -> bool :
269+ if input_node .op == "placeholder" :
270+ return (
271+ input_node .meta ["val" ].is_contiguous ()
272+ if input_order == InputDimOrder .NCHW
273+ else not input_node .meta ["val" ].is_contiguous ()
274+ )
275+ else :
276+ return (
277+ ChannelsLastTaggedReshapePass .is_nchw_node (input_node )
278+ if input_order == InputDimOrder .NCHW
279+ else ChannelsLastTaggedReshapePass .is_nhwc_node (input_node )
280+ )
281+
252282 def input_to_nhwc (
253283 self ,
254284 graph_module : torch .fx .GraphModule ,
@@ -258,7 +288,7 @@ def input_to_nhwc(
258288 if is_param_node (self .exported_program , input_node ):
259289 if (
260290 ChannelsLastTaggedReshapePass .XNN_NHWC_NODE in input_node .meta
261- and self .is_nchw_node (input_node )
291+ and ChannelsLastTaggedReshapePass .is_nchw_node (input_node )
262292 ):
263293 # This constant data tensor has been used somewhere else
264294 # in NCHW format so we can't use it here in NHWC format
@@ -275,6 +305,9 @@ def input_to_nhwc(
275305 elif self .is_nhwc_node (input_node ):
276306 return
277307
308+ if self .input_dim_order (input_node , InputDimOrder .NHWC ):
309+ return
310+
278311 if not self .can_be_converted_to_nhwc (input_node ):
279312 raise AssertionError (
280313 "Attempting to convert non-NHWC compatible node to NHWC"
@@ -302,6 +335,7 @@ def input_to_nhwc(
302335 args = (input_node ,),
303336 memory_format = torch .channels_last ,
304337 )
338+ self .mark_as_nhwc_node (input_node_nhwc )
305339
306340 if is_dynamic_input :
307341 # Replace downstream input_nodes with NHWC node
@@ -324,7 +358,7 @@ def input_to_nchw(
324358 if is_param_node (self .exported_program , input_node ):
325359 if (
326360 ChannelsLastTaggedReshapePass .XNN_NHWC_NODE in input_node .meta
327- and self .is_nhwc_node (input_node )
361+ and ChannelsLastTaggedReshapePass .is_nhwc_node (input_node )
328362 ):
329363 # This constant data tensor has been used somewhere else
330364 # in NHWC format so we can't use it here in NCHW format
@@ -342,6 +376,9 @@ def input_to_nchw(
342376 elif self .is_nchw_node (input_node ):
343377 return
344378
379+ if self .input_dim_order (input_node , InputDimOrder .NCHW ):
380+ return
381+
345382 if ChannelsLastTaggedReshapePass .PARTNER_NODE in input_node .meta :
346383 # Already has an associated NCHW node
347384 input_node_nchw = input_node .meta [
@@ -356,6 +393,7 @@ def input_to_nchw(
356393 args = (input_node ,),
357394 memory_format = torch .contiguous_format ,
358395 )
396+ self .mark_as_nchw_node (input_node_nchw )
359397
360398 self .insert_copy_and_assign_partner_nodes_quantization_sensitive (
361399 graph_module = graph_module ,
@@ -383,10 +421,12 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
383421 elif self .requires_nhwc_input (node ):
384422 # Nodes which enter this branch are ones that require their
385423 # first input to be nhwc. This makes this node's output nhwc too
386-
387424 self .input_to_nhwc (graph_module , node .args [0 ], node )
388- for input_node in node .all_input_nodes :
389- if input_node .op == "placeholder" and self .is_nhwc_node (input_node ):
425+ for input_node in node .all_input_nodes [1 :]:
426+ if (
427+ input_node .op == "placeholder"
428+ and ChannelsLastTaggedReshapePass .is_nhwc_node (input_node )
429+ ):
390430 raise AssertionError (
391431 f"Expected { input_node } to be NCHW in channels last reshape pass"
392432 )
@@ -395,11 +435,14 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
395435 # The node requires nchw inputs
396436 for input_node in node .all_input_nodes :
397437 self .input_to_nchw (graph_module , input_node , node )
438+ elif node .target == exir_ops .edge .aten ._to_copy .default :
439+ self .tag_node (node )
398440 else :
399441 # The node can have inputs in any format (but all must be the
400442 # same format)
401443 is_or_isnt_nhwc_node = [
402- self .is_nhwc_node (input_node ) for input_node in node .all_input_nodes
444+ ChannelsLastTaggedReshapePass .is_nhwc_node (input_node )
445+ for input_node in node .all_input_nodes
403446 ]
404447 if all (is_or_isnt_nhwc_node ):
405448 # All inputs are nhwc so this node's output is nhwc too
0 commit comments