44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ from enum import Enum
78from typing import Optional , Tuple
89
910import torch
1011from executorch .backends .xnnpack ._passes .xnnpack_pass import XNNPACKPass
11- from executorch .backends .xnnpack .utils .quant_utils import is_dynamic_qdq
12+ from executorch .backends .xnnpack .utils .quant_utils import is_dequant , is_dynamic_qdq
1213from executorch .backends .xnnpack .utils .utils import is_param_node
1314from executorch .exir .dialects ._ops import ops as exir_ops
1415from executorch .exir .pass_base import PassResult
1516
1617
class InputDimOrder(Enum):
    """Memory layout of an input tensor.

    NCHW corresponds to ``torch.contiguous_format``; NHWC corresponds to
    ``torch.channels_last``.
    """

    NCHW = 1
    NHWC = 2
21+
22+
1723# TODO(T151254305) use subgraph_rewriter
1824class ChannelsLastTaggedReshapePass (XNNPACKPass ):
1925 """
@@ -78,17 +84,49 @@ class ChannelsLastTaggedReshapePass(XNNPACKPass):
7884 # is done
7985 PARTNER_NODE = "XNN_CHANNELS_LAST_TAGGED_RESHAPE_PARTNER_NODE"
8086
81- def mark_as_nhwc_node (self , node : torch .fx .Node ) -> None :
87+ @staticmethod
88+ def mark_as_nhwc_node (node : torch .fx .Node ) -> None :
8289 node .meta [ChannelsLastTaggedReshapePass .XNN_NHWC_NODE ] = True
8390
84- def mark_as_nchw_node (self , node : torch .fx .Node ) -> None :
91+ @staticmethod
92+ def mark_as_nchw_node (node : torch .fx .Node ) -> None :
8593 node .meta [ChannelsLastTaggedReshapePass .XNN_NHWC_NODE ] = False
8694
87- def is_nhwc_node (self , node : torch .fx .Node ) -> bool :
95+ def tag_node (self , node : torch .fx .Node ) -> None :
96+ if node .kwargs ["memory_format" ] == torch .channels_last :
97+ self .mark_as_nhwc_node (node )
98+ else :
99+ self .mark_as_nchw_node (node )
100+
101+ @staticmethod
102+ def is_nhwc_node (node : torch .fx .Node ) -> bool :
103+ if is_dequant (node ) and len (node .all_input_nodes ) > 0 :
104+ quantize_node = node .args [0 ]
105+ if len (quantize_node .all_input_nodes ) > 0 :
106+ actual_node = quantize_node .args [0 ]
107+ if actual_node .op == "placeholder" :
108+ return not actual_node .meta ["val" ][0 ].is_contiguous ()
109+ else :
110+ return actual_node .meta .get (
111+ ChannelsLastTaggedReshapePass .XNN_NHWC_NODE , False
112+ )
113+
88114 return node .meta .get (ChannelsLastTaggedReshapePass .XNN_NHWC_NODE , False )
89115
90- def is_nchw_node (self , node : torch .fx .Node ) -> bool :
91- return not self .is_nhwc_node (node )
116+ @staticmethod
117+ def is_nchw_node (node : torch .fx .Node ) -> bool :
118+ if is_dequant (node ) and len (node .all_input_nodes ) > 0 :
119+ quantize_node = node .args [0 ]
120+ if len (quantize_node .all_input_nodes ) > 0 :
121+ actual_node = quantize_node .args [0 ]
122+ if actual_node .op == "placeholder" :
123+ return actual_node .meta ["val" ][0 ].is_contiguous ()
124+ else :
125+ return not actual_node .meta .get (
126+ ChannelsLastTaggedReshapePass .XNN_NHWC_NODE , False
127+ )
128+
129+ return not ChannelsLastTaggedReshapePass .is_nhwc_node (node )
92130
93131 def requires_nhwc_input (self , node : torch .fx .Node ) -> bool :
94132 return node .target in self .memory_sensitive_ops_nhwc
@@ -106,7 +144,7 @@ def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
106144 is_nchw_constant = (
107145 is_param_node (self .exported_program , node )
108146 and (ChannelsLastTaggedReshapePass .XNN_NHWC_NODE in node .meta )
109- and (self .is_nchw_node (node ))
147+ and (ChannelsLastTaggedReshapePass .is_nchw_node (node ))
110148 )
111149 return is_4d and not is_nchw_constant
112150
@@ -249,6 +287,22 @@ def insert_copy_and_assign_partner_nodes_quantization_sensitive(
249287 # in that case
250288 self .make_partners (original_input , copy_node )
251289
290+ def input_dim_order (
291+ self , input_node : torch .fx .Node , input_order : InputDimOrder
292+ ) -> bool :
293+ if input_node .op == "placeholder" :
294+ return (
295+ input_node .meta ["val" ].is_contiguous ()
296+ if input_order == InputDimOrder .NCHW
297+ else not input_node .meta ["val" ].is_contiguous ()
298+ )
299+ else :
300+ return (
301+ ChannelsLastTaggedReshapePass .is_nchw_node (input_node )
302+ if input_order == InputDimOrder .NCHW
303+ else ChannelsLastTaggedReshapePass .is_nhwc_node (input_node )
304+ )
305+
252306 def input_to_nhwc (
253307 self ,
254308 graph_module : torch .fx .GraphModule ,
@@ -258,7 +312,7 @@ def input_to_nhwc(
258312 if is_param_node (self .exported_program , input_node ):
259313 if (
260314 ChannelsLastTaggedReshapePass .XNN_NHWC_NODE in input_node .meta
261- and self .is_nchw_node (input_node )
315+ and ChannelsLastTaggedReshapePass .is_nchw_node (input_node )
262316 ):
263317 # This constant data tensor has been used somewhere else
264318 # in NCHW format so we can't use it here in NHWC format
@@ -272,7 +326,10 @@ def input_to_nhwc(
272326 if input_node .op == "placeholder" :
273327 if not input_node .meta ["val" ][0 ].is_contiguous ():
274328 return
275- elif self .is_nhwc_node (input_node ):
329+ elif ChannelsLastTaggedReshapePass .is_nhwc_node (input_node ):
330+ return
331+
332+ if self .input_dim_order (input_node , InputDimOrder .NHWC ):
276333 return
277334
278335 if not self .can_be_converted_to_nhwc (input_node ):
@@ -302,6 +359,8 @@ def input_to_nhwc(
302359 args = (input_node ,),
303360 memory_format = torch .channels_last ,
304361 )
362+ # Tag the inserted to-channels-last copy as NHWC so downstream dim-order checks see it correctly
363+ ChannelsLastTaggedReshapePass .mark_as_nhwc_node (input_node_nhwc )
305364
306365 if is_dynamic_input :
307366 # Replace downstream input_nodes with NHWC node
@@ -324,7 +383,7 @@ def input_to_nchw(
324383 if is_param_node (self .exported_program , input_node ):
325384 if (
326385 ChannelsLastTaggedReshapePass .XNN_NHWC_NODE in input_node .meta
327- and self .is_nhwc_node (input_node )
386+ and ChannelsLastTaggedReshapePass .is_nhwc_node (input_node )
328387 ):
329388 # This constant data tensor has been used somewhere else
330389 # in NHWC format so we can't use it here in NCHW format
@@ -339,7 +398,10 @@ def input_to_nchw(
339398 if input_node .op == "placeholder" :
340399 if input_node .meta ["val" ].is_contiguous ():
341400 return
342- elif self .is_nchw_node (input_node ):
401+ elif ChannelsLastTaggedReshapePass .is_nchw_node (input_node ):
402+ return
403+
404+ if self .input_dim_order (input_node , InputDimOrder .NCHW ):
343405 return
344406
345407 if ChannelsLastTaggedReshapePass .PARTNER_NODE in input_node .meta :
@@ -356,6 +418,7 @@ def input_to_nchw(
356418 args = (input_node ,),
357419 memory_format = torch .contiguous_format ,
358420 )
421+ ChannelsLastTaggedReshapePass .mark_as_nchw_node (input_node_nchw )
359422
360423 self .insert_copy_and_assign_partner_nodes_quantization_sensitive (
361424 graph_module = graph_module ,
@@ -369,7 +432,12 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
369432 original_nodes = list (graph .nodes )
370433 for node in original_nodes :
371434 if len (node .all_input_nodes ) == 0 :
372- # This node has no inputs so we don't need to change anything
435+ # This node has no inputs so we don't need to change anything, but still need to tag input nodes
436+ if "val" in node .meta and isinstance (node .meta ["val" ], torch .Tensor ):
437+ if node .meta ["val" ].is_contiguous ():
438+ self .mark_as_nchw_node (node )
439+ else :
440+ self .mark_as_nhwc_node (node )
373441 continue
374442
375443 # Need special case for output node because it can have multiple output dim orders as we can output a tuple multiple nodes
@@ -383,10 +451,12 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
383451 elif self .requires_nhwc_input (node ):
384452 # Nodes which enter this branch are ones that require their
385453 # first input to be nhwc. This makes this node's output nhwc too
386-
387454 self .input_to_nhwc (graph_module , node .args [0 ], node )
388- for input_node in node .all_input_nodes :
389- if input_node .op == "placeholder" and self .is_nhwc_node (input_node ):
455+ for input_node in node .all_input_nodes [1 :]:
456+ if (
457+ input_node .op == "placeholder"
458+ and ChannelsLastTaggedReshapePass .is_nhwc_node (input_node )
459+ ):
390460 raise AssertionError (
391461 f"Expected { input_node } to be NCHW in channels last reshape pass"
392462 )
@@ -395,11 +465,14 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
395465 # The node requires nchw inputs
396466 for input_node in node .all_input_nodes :
397467 self .input_to_nchw (graph_module , input_node , node )
468+ elif node .target == exir_ops .edge .aten ._to_copy .default :
469+ self .tag_node (node )
398470 else :
399471 # The node can have inputs in any format (but all must be the
400472 # same format)
401473 is_or_isnt_nhwc_node = [
402- self .is_nhwc_node (input_node ) for input_node in node .all_input_nodes
474+ ChannelsLastTaggedReshapePass .is_nhwc_node (input_node )
475+ for input_node in node .all_input_nodes
403476 ]
404477 if all (is_or_isnt_nhwc_node ):
405478 # All inputs are nhwc so this node's output is nhwc too