44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ from enum import Enum
78from typing import Optional , Tuple
89
910import torch
1415from executorch .exir .pass_base import PassResult
1516
1617
18+ class InputDimOrder (Enum ):
19+ NCHW = 1
20+ NHWC = 2
21+
22+
1723# TODO(T151254305) use subgraph_rewriter
1824class ChannelsLastTaggedReshapePass (XNNPACKPass ):
1925 """
@@ -84,11 +90,13 @@ def mark_as_nhwc_node(self, node: torch.fx.Node) -> None:
8490 def mark_as_nchw_node (self , node : torch .fx .Node ) -> None :
8591 node .meta [ChannelsLastTaggedReshapePass .XNN_NHWC_NODE ] = False
8692
87- def is_nhwc_node (self , node : torch .fx .Node ) -> bool :
93+ @staticmethod
94+ def is_nhwc_node (node : torch .fx .Node ) -> bool :
8895 return node .meta .get (ChannelsLastTaggedReshapePass .XNN_NHWC_NODE , False )
8996
90- def is_nchw_node (self , node : torch .fx .Node ) -> bool :
91- return not self .is_nhwc_node (node )
97+ @staticmethod
98+ def is_nchw_node (node : torch .fx .Node ) -> bool :
99+ return not ChannelsLastTaggedReshapePass .is_nhwc_node (node )
92100
93101 def requires_nhwc_input (self , node : torch .fx .Node ) -> bool :
94102 return (
@@ -114,7 +122,7 @@ def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
114122 is_nchw_constant = (
115123 is_param_node (self .exported_program , node )
116124 and (ChannelsLastTaggedReshapePass .XNN_NHWC_NODE in node .meta )
117- and (self .is_nchw_node (node ))
125+ and (ChannelsLastTaggedReshapePass .is_nchw_node (node ))
118126 )
119127 return is_4d and not is_nchw_constant
120128
@@ -257,6 +265,22 @@ def insert_copy_and_assign_partner_nodes_quantization_sensitive(
257265 # in that case
258266 self .make_partners (original_input , copy_node )
259267
268+ def input_dim_order (
269+ self , input_node : torch .fx .Node , input_order : InputDimOrder
270+ ) -> bool :
271+ if input_node .name == "x" :
272+ return (
273+ input_node .meta ["val" ].is_contiguous ()
274+ if input_order == InputDimOrder .NCHW
275+ else not input_node .meta ["val" ].is_contiguous ()
276+ )
277+ else :
278+ return (
279+ ChannelsLastTaggedReshapePass .is_nchw_node (input_node )
280+ if input_order == InputDimOrder .NCHW
281+ else ChannelsLastTaggedReshapePass .is_nhwc_node (input_node )
282+ )
283+
260284 def input_to_nhwc (
261285 self ,
262286 graph_module : torch .fx .GraphModule ,
@@ -266,7 +290,7 @@ def input_to_nhwc(
266290 if is_param_node (self .exported_program , input_node ):
267291 if (
268292 ChannelsLastTaggedReshapePass .XNN_NHWC_NODE in input_node .meta
269- and self .is_nchw_node (input_node )
293+ and ChannelsLastTaggedReshapePass .is_nchw_node (input_node )
270294 ):
271295 # This constant data tensor has been used somewhere else
272296 # in NCHW format so we can't use it here in NHWC format
@@ -282,6 +306,8 @@ def input_to_nhwc(
282306 return
283307 elif self .is_nhwc_node (input_node ):
284308 return
309+ elif self .input_dim_order (input_node , InputDimOrder .NHWC ):
310+ return
285311
286312 if not self .can_be_converted_to_nhwc (input_node ):
287313 raise AssertionError (
@@ -332,7 +358,7 @@ def input_to_nchw(
332358 if is_param_node (self .exported_program , input_node ):
333359 if (
334360 ChannelsLastTaggedReshapePass .XNN_NHWC_NODE in input_node .meta
335- and self .is_nhwc_node (input_node )
361+ and ChannelsLastTaggedReshapePass .is_nhwc_node (input_node )
336362 ):
337363 # This constant data tensor has been used somewhere else
338364 # in NHWC format so we can't use it here in NCHW format
@@ -349,6 +375,8 @@ def input_to_nchw(
349375 return
350376 elif self .is_nchw_node (input_node ):
351377 return
378+ elif self .input_dim_order (input_node , InputDimOrder .NCHW ):
379+ return
352380
353381 if ChannelsLastTaggedReshapePass .PARTNER_NODE in input_node .meta :
354382 # Already has an associated NCHW node
@@ -391,7 +419,7 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
391419 self .input_to_nhwc (graph_module , node .args [0 ], node )
392420
393421 for input_node in node .all_input_nodes [1 :]:
394- if self .is_nhwc_node (input_node ):
422+ if ChannelsLastTaggedReshapePass .is_nhwc_node (input_node ):
395423 raise AssertionError (
396424 f"Expected { input_node } to be NCHW in channels last reshape pass"
397425 )
@@ -409,7 +437,8 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
409437 # The node can have inputs in any format (but all must be the
410438 # same format)
411439 is_or_isnt_nhwc_node = [
412- self .is_nhwc_node (input_node ) for input_node in node .all_input_nodes
440+ ChannelsLastTaggedReshapePass .is_nhwc_node (input_node )
441+ for input_node in node .all_input_nodes
413442 ]
414443 if all (is_or_isnt_nhwc_node ):
415444 # All inputs are nhwc so this node's output is nhwc too
0 commit comments