44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ from enum import Enum
78from typing import Optional , Tuple
89
910import torch
1920from executorch .exir .pass_base import PassResult
2021
2122
23+ class InputDimOrder (Enum ):
24+ NCHW = 1
25+ NHWC = 2
26+
27+
2228# TODO(T151254305) use subgraph_rewriter
2329class ChannelsLastTaggedReshapePass (XNNPACKPass ):
2430 """
@@ -83,17 +89,49 @@ class ChannelsLastTaggedReshapePass(XNNPACKPass):
8389 # is done
8490 PARTNER_NODE = "XNN_CHANNELS_LAST_TAGGED_RESHAPE_PARTNER_NODE"
8591
86- def mark_as_nhwc_node (self , node : torch .fx .Node ) -> None :
92+ @staticmethod
93+ def mark_as_nhwc_node (node : torch .fx .Node ) -> None :
8794 node .meta [ChannelsLastTaggedReshapePass .XNN_NHWC_NODE ] = True
8895
89- def mark_as_nchw_node (self , node : torch .fx .Node ) -> None :
96+ @staticmethod
97+ def mark_as_nchw_node (node : torch .fx .Node ) -> None :
9098 node .meta [ChannelsLastTaggedReshapePass .XNN_NHWC_NODE ] = False
9199
92- def is_nhwc_node (self , node : torch .fx .Node ) -> bool :
100+ def tag_node (self , node : torch .fx .Node ) -> None :
101+ if node .kwargs ["memory_format" ] == torch .channels_last :
102+ self .mark_as_nhwc_node (node )
103+ else :
104+ self .mark_as_nchw_node (node )
105+
106+ @staticmethod
107+ def is_nhwc_node (node : torch .fx .Node ) -> bool :
108+ if is_dequant (node ) and len (node .all_input_nodes ) > 0 :
109+ quantize_node = node .args [0 ]
110+ if len (quantize_node .all_input_nodes ) > 0 :
111+ actual_node = quantize_node .args [0 ]
112+ if actual_node .op == "placeholder" :
113+ return not actual_node .meta ["val" ][0 ].is_contiguous ()
114+ else :
115+ return actual_node .meta .get (
116+ ChannelsLastTaggedReshapePass .XNN_NHWC_NODE , False
117+ )
118+
93119 return node .meta .get (ChannelsLastTaggedReshapePass .XNN_NHWC_NODE , False )
94120
95- def is_nchw_node (self , node : torch .fx .Node ) -> bool :
96- return not self .is_nhwc_node (node )
121+ @staticmethod
122+ def is_nchw_node (node : torch .fx .Node ) -> bool :
123+ if is_dequant (node ) and len (node .all_input_nodes ) > 0 :
124+ quantize_node = node .args [0 ]
125+ if len (quantize_node .all_input_nodes ) > 0 :
126+ actual_node = quantize_node .args [0 ]
127+ if actual_node .op == "placeholder" :
128+ return actual_node .meta ["val" ][0 ].is_contiguous ()
129+ else :
130+ return not actual_node .meta .get (
131+ ChannelsLastTaggedReshapePass .XNN_NHWC_NODE , False
132+ )
133+
134+ return not ChannelsLastTaggedReshapePass .is_nhwc_node (node )
97135
98136 def requires_nhwc_input (self , node : torch .fx .Node ) -> bool :
99137 return node .target in self .memory_sensitive_ops_nhwc
@@ -111,7 +149,7 @@ def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
111149 is_nchw_constant = (
112150 is_param_node (self .exported_program , node )
113151 and (ChannelsLastTaggedReshapePass .XNN_NHWC_NODE in node .meta )
114- and (self .is_nchw_node (node ))
152+ and (ChannelsLastTaggedReshapePass .is_nchw_node (node ))
115153 )
116154 return is_4d and not is_nchw_constant
117155
@@ -273,6 +311,22 @@ def insert_copy_and_assign_partner_nodes_quantization_sensitive(
273311 # in that case
274312 self .make_partners (original_input , copy_node )
275313
314+ def input_dim_order (
315+ self , input_node : torch .fx .Node , input_order : InputDimOrder
316+ ) -> bool :
317+ if input_node .op == "placeholder" :
318+ return (
319+ input_node .meta ["val" ].is_contiguous ()
320+ if input_order == InputDimOrder .NCHW
321+ else not input_node .meta ["val" ].is_contiguous ()
322+ )
323+ else :
324+ return (
325+ ChannelsLastTaggedReshapePass .is_nchw_node (input_node )
326+ if input_order == InputDimOrder .NCHW
327+ else ChannelsLastTaggedReshapePass .is_nhwc_node (input_node )
328+ )
329+
276330 def input_to_nhwc (
277331 self ,
278332 graph_module : torch .fx .GraphModule ,
@@ -282,7 +336,7 @@ def input_to_nhwc(
282336 if is_param_node (self .exported_program , input_node ):
283337 if (
284338 ChannelsLastTaggedReshapePass .XNN_NHWC_NODE in input_node .meta
285- and self .is_nchw_node (input_node )
339+ and ChannelsLastTaggedReshapePass .is_nchw_node (input_node )
286340 ):
287341 # This constant data tensor has been used somewhere else
288342 # in NCHW format so we can't use it here in NHWC format
@@ -296,7 +350,10 @@ def input_to_nhwc(
296350 if input_node .op == "placeholder" :
297351 if not input_node .meta ["val" ][0 ].is_contiguous ():
298352 return
299- elif self .is_nhwc_node (input_node ):
353+ elif ChannelsLastTaggedReshapePass .is_nhwc_node (input_node ):
354+ return
355+
356+ if self .input_dim_order (input_node , InputDimOrder .NHWC ):
300357 return
301358
302359 if not self .can_be_converted_to_nhwc (input_node ):
@@ -326,6 +383,8 @@ def input_to_nhwc(
326383 args = (input_node ,),
327384 memory_format = torch .channels_last ,
328385 )
386+ # Use static method for consistency
387+ ChannelsLastTaggedReshapePass .mark_as_nhwc_node (input_node_nhwc )
329388
330389 if is_dynamic_input :
331390 # Replace downstream input_nodes with NHWC node
@@ -348,7 +407,7 @@ def input_to_nchw(
348407 if is_param_node (self .exported_program , input_node ):
349408 if (
350409 ChannelsLastTaggedReshapePass .XNN_NHWC_NODE in input_node .meta
351- and self .is_nhwc_node (input_node )
410+ and ChannelsLastTaggedReshapePass .is_nhwc_node (input_node )
352411 ):
353412 # This constant data tensor has been used somewhere else
354413 # in NHWC format so we can't use it here in NCHW format
@@ -363,7 +422,10 @@ def input_to_nchw(
363422 if input_node .op == "placeholder" :
364423 if input_node .meta ["val" ].is_contiguous ():
365424 return
366- elif self .is_nchw_node (input_node ):
425+ elif ChannelsLastTaggedReshapePass .is_nchw_node (input_node ):
426+ return
427+
428+ if self .input_dim_order (input_node , InputDimOrder .NCHW ):
367429 return
368430
369431 if ChannelsLastTaggedReshapePass .PARTNER_NODE in input_node .meta :
@@ -380,6 +442,7 @@ def input_to_nchw(
380442 args = (input_node ,),
381443 memory_format = torch .contiguous_format ,
382444 )
445+ ChannelsLastTaggedReshapePass .mark_as_nchw_node (input_node_nchw )
383446
384447 self .insert_copy_and_assign_partner_nodes_quantization_sensitive (
385448 graph_module = graph_module ,
@@ -393,7 +456,12 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
393456 original_nodes = list (graph .nodes )
394457 for node in original_nodes :
395458 if len (node .all_input_nodes ) == 0 :
396- # This node has no inputs so we don't need to change anything
459+ # This node has no inputs so we don't need to change anything, but still need to tag input nodes
460+ if "val" in node .meta and isinstance (node .meta ["val" ], torch .Tensor ):
461+ if node .meta ["val" ].is_contiguous ():
462+ self .mark_as_nchw_node (node )
463+ else :
464+ self .mark_as_nhwc_node (node )
397465 continue
398466
399467 # Need special case for output node because it can have multiple output dim orders as we can output a tuple multiple nodes
@@ -407,10 +475,12 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
407475 elif self .requires_nhwc_input (node ):
408476 # Nodes which enter this branch are ones that require their
409477 # first input to be nhwc. This makes this node's output nhwc too
410-
411478 self .input_to_nhwc (graph_module , node .args [0 ], node )
412- for input_node in node .all_input_nodes :
413- if input_node .op == "placeholder" and self .is_nhwc_node (input_node ):
479+ for input_node in node .all_input_nodes [1 :]:
480+ if (
481+ input_node .op == "placeholder"
482+ and ChannelsLastTaggedReshapePass .is_nhwc_node (input_node )
483+ ):
414484 raise AssertionError (
415485 f"Expected { input_node } to be NCHW in channels last reshape pass"
416486 )
@@ -419,11 +489,14 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
419489 # The node requires nchw inputs
420490 for input_node in node .all_input_nodes :
421491 self .input_to_nchw (graph_module , input_node , node )
492+ elif node .target == exir_ops .edge .aten ._to_copy .default :
493+ self .tag_node (node )
422494 else :
423495 # The node can have inputs in any format (but all must be the
424496 # same format)
425497 is_or_isnt_nhwc_node = [
426- self .is_nhwc_node (input_node ) for input_node in node .all_input_nodes
498+ ChannelsLastTaggedReshapePass .is_nhwc_node (input_node )
499+ for input_node in node .all_input_nodes
427500 ]
428501 if all (is_or_isnt_nhwc_node ):
429502 # All inputs are nhwc so this node's output is nhwc too
0 commit comments