44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ from enum import Enum
78from typing import Optional , Tuple
89
910import torch
1415from executorch .exir .pass_base import PassResult
1516
1617
18+ class InputDimOrder (Enum ):
19+ NCHW = 1
20+ NHWC = 2
21+
22+
1723# TODO(T151254305) use subgraph_rewriter
1824class ChannelsLastTaggedReshapePass (XNNPACKPass ):
1925 """
@@ -84,11 +90,13 @@ def mark_as_nhwc_node(self, node: torch.fx.Node) -> None:
8490 def mark_as_nchw_node (self , node : torch .fx .Node ) -> None :
8591 node .meta [ChannelsLastTaggedReshapePass .XNN_NHWC_NODE ] = False
8692
87- def is_nhwc_node (self , node : torch .fx .Node ) -> bool :
93+ @staticmethod
94+ def is_nhwc_node (node : torch .fx .Node ) -> bool :
8895 return node .meta .get (ChannelsLastTaggedReshapePass .XNN_NHWC_NODE , False )
8996
90- def is_nchw_node (self , node : torch .fx .Node ) -> bool :
91- return not self .is_nhwc_node (node )
97+ @staticmethod
98+ def is_nchw_node (node : torch .fx .Node ) -> bool :
99+ return not ChannelsLastTaggedReshapePass .is_nhwc_node (node )
92100
93101 def requires_nhwc_input (self , node : torch .fx .Node ) -> bool :
94102 return (
@@ -114,7 +122,7 @@ def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
114122 is_nchw_constant = (
115123 is_param_node (self .exported_program , node )
116124 and (ChannelsLastTaggedReshapePass .XNN_NHWC_NODE in node .meta )
117- and (self .is_nchw_node (node ))
125+ and (ChannelsLastTaggedReshapePass .is_nchw_node (node ))
118126 )
119127 return is_4d and not is_nchw_constant
120128
@@ -257,6 +265,22 @@ def insert_copy_and_assign_partner_nodes_quantization_sensitive(
257265 # in that case
258266 self .make_partners (original_input , copy_node )
259267
268+ def input_dim_order (
269+ self , input_node : torch .fx .Node , input_order : InputDimOrder
270+ ) -> bool :
271+ if input_node .name == "x" :
272+ return (
273+ input_node .meta ["val" ].is_contiguous ()
274+ if input_order == InputDimOrder .NCHW
275+ else not input_node .meta ["val" ].is_contiguous ()
276+ )
277+ else :
278+ return (
279+ ChannelsLastTaggedReshapePass .is_nchw_node (input_node )
280+ if input_order == InputDimOrder .NCHW
281+ else ChannelsLastTaggedReshapePass .is_nhwc_node (input_node )
282+ )
283+
260284 def input_to_nhwc (
261285 self ,
262286 graph_module : torch .fx .GraphModule ,
@@ -266,7 +290,7 @@ def input_to_nhwc(
266290 if is_param_node (self .exported_program , input_node ):
267291 if (
268292 ChannelsLastTaggedReshapePass .XNN_NHWC_NODE in input_node .meta
269- and self .is_nchw_node (input_node )
293+ and ChannelsLastTaggedReshapePass .is_nchw_node (input_node )
270294 ):
271295 # This constant data tensor has been used somewhere else
272296 # in NCHW format so we can't use it here in NHWC format
@@ -277,10 +301,7 @@ def input_to_nhwc(
277301 # serializing graph, but don't do anything else here
278302 self .mark_as_nhwc_node (input_node )
279303
280- if input_node .name == "x" :
281- if not input_node .meta ["val" ][0 ].is_contiguous ():
282- return
283- elif self .is_nhwc_node (input_node ):
304+ if self .input_dim_order (input_node , InputDimOrder .NHWC ):
284305 return
285306
286307 if not self .can_be_converted_to_nhwc (input_node ):
@@ -332,7 +353,7 @@ def input_to_nchw(
332353 if is_param_node (self .exported_program , input_node ):
333354 if (
334355 ChannelsLastTaggedReshapePass .XNN_NHWC_NODE in input_node .meta
335- and self .is_nhwc_node (input_node )
356+ and ChannelsLastTaggedReshapePass .is_nhwc_node (input_node )
336357 ):
337358 # This constant data tensor has been used somewhere else
338359 # in NHWC format so we can't use it here in NCHW format
@@ -344,10 +365,7 @@ def input_to_nchw(
344365 # do anything else here
345366 self .mark_as_nchw_node (input_node )
346367
347- if input_node .name == "x" :
348- if input_node .meta ["val" ].is_contiguous ():
349- return
350- elif self .is_nchw_node (input_node ):
368+ if self .input_dim_order (input_node , InputDimOrder .NCHW ):
351369 return
352370
353371 if ChannelsLastTaggedReshapePass .PARTNER_NODE in input_node .meta :
@@ -391,7 +409,7 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
391409 self .input_to_nhwc (graph_module , node .args [0 ], node )
392410
393411 for input_node in node .all_input_nodes [1 :]:
394- if self .is_nhwc_node (input_node ):
412+ if ChannelsLastTaggedReshapePass .is_nhwc_node (input_node ):
395413 raise AssertionError (
396414 f"Expected { input_node } to be NCHW in channels last reshape pass"
397415 )
@@ -409,7 +427,8 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
409427 # The node can have inputs in any format (but all must be the
410428 # same format)
411429 is_or_isnt_nhwc_node = [
412- self .is_nhwc_node (input_node ) for input_node in node .all_input_nodes
430+ ChannelsLastTaggedReshapePass .is_nhwc_node (input_node )
431+ for input_node in node .all_input_nodes
413432 ]
414433 if all (is_or_isnt_nhwc_node ):
415434 # All inputs are nhwc so this node's output is nhwc too
0 commit comments