@@ -110,7 +110,9 @@ def is_nhwc_node(node: torch.fx.Node) -> bool:
             if len(quantize_node.all_input_nodes) > 0:
                 actual_node = quantize_node.args[0]
                 if actual_node.op == "placeholder":
-                    return not actual_node.meta["val"][0].is_contiguous()
+                    return ChannelsLastTaggedReshapePass._is_nhwc_tensor(
+                        actual_node.meta["val"][0]
+                    )
                 else:
                     return actual_node.meta.get(
                         ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False
@@ -125,14 +127,36 @@ def is_nchw_node(node: torch.fx.Node) -> bool:
             if len(quantize_node.all_input_nodes) > 0:
                 actual_node = quantize_node.args[0]
                 if actual_node.op == "placeholder":
-                    return actual_node.meta["val"][0].is_contiguous()
+                    return not ChannelsLastTaggedReshapePass._is_nhwc_tensor(
+                        actual_node.meta["val"][0]
+                    )
                 else:
                     return not actual_node.meta.get(
                         ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False
                     )
 
         return not ChannelsLastTaggedReshapePass.is_nhwc_node(node)
 
+    @staticmethod
+    def _is_nhwc_tensor(tensor: torch.Tensor) -> bool:
+        nhwc = tensor.is_contiguous(memory_format=torch.channels_last)
+        nchw = tensor.is_contiguous()
+        # If the tensor is contiguous in both memory formats (e.g. a size-1
+        # channel or spatial dim makes the stride check ambiguous), treat it
+        # as NCHW and return False.
+        # If only one of the two checks is true, nhwc alone gives the answer.
+        # Both being false (e.g. channels_last_3d) should not happen here.
+        # Equivalent one-liner:
+        #     return (not (nchw and nhwc)) and nhwc
+        # Readable version:
+        if nchw and nhwc:
+            return False
+        else:
+            return nhwc
+
+    def _is_nhwc(self, tensor: torch.Tensor) -> bool:
+        return ChannelsLastTaggedReshapePass._is_nhwc_tensor(tensor)
+
     def requires_nhwc_input(self, node: torch.fx.Node) -> bool:
         return node.target in self.memory_sensitive_ops_nhwc
 
@@ -315,11 +339,8 @@ def input_dim_order(
         self, input_node: torch.fx.Node, input_order: InputDimOrder
     ) -> bool:
         if input_node.op == "placeholder":
-            return (
-                input_node.meta["val"].is_contiguous()
-                if input_order == InputDimOrder.NCHW
-                else not input_node.meta["val"].is_contiguous()
-            )
+            is_nhwc = self._is_nhwc(input_node.meta["val"])
+            return not is_nhwc if input_order == InputDimOrder.NCHW else is_nhwc
         else:
             return (
                 ChannelsLastTaggedReshapePass.is_nchw_node(input_node)
@@ -348,7 +369,7 @@ def input_to_nhwc(
         self.mark_as_nhwc_node(input_node)
 
         if input_node.op == "placeholder":
-            if not input_node.meta["val"][0].is_contiguous():
+            if self._is_nhwc(input_node.meta["val"][0]):
                 return
         elif ChannelsLastTaggedReshapePass.is_nhwc_node(input_node):
             return
@@ -420,7 +441,7 @@ def input_to_nchw(
         self.mark_as_nchw_node(input_node)
 
         if input_node.op == "placeholder":
-            if input_node.meta["val"].is_contiguous():
+            if not self._is_nhwc(input_node.meta["val"]):
                 return
         elif ChannelsLastTaggedReshapePass.is_nchw_node(input_node):
             return
@@ -462,17 +483,17 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
                 and isinstance(node.meta["val"], torch.Tensor)
                 and len(node.meta["val"].shape) == 4
             ):
-                if node.meta["val"].is_contiguous():
-                    self.mark_as_nchw_node(node)
-                else:
+                if self._is_nhwc(node.meta["val"]):
                     self.mark_as_nhwc_node(node)
+                else:
+                    self.mark_as_nchw_node(node)
                 continue
 
             # Need a special case for the output node because it can have multiple output dim orders, as we can output a tuple of multiple nodes
             if node.op == "output":
                 out_tuple = node.args[0]
                 for out_node in out_tuple:
-                    if out_node.meta["val"].is_contiguous():
+                    if not self._is_nhwc(out_node.meta["val"]):
                         self.input_to_nchw(graph_module, out_node, node)
                     else:
                         self.input_to_nhwc(graph_module, out_node, node)
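
Note on why `_is_nhwc_tensor` checks both memory formats instead of a single `is_contiguous()` call: for some 4-D shapes (for example a size-1 channel dimension) a tensor is contiguous under both NCHW and channels_last, so testing only one format can misclassify the dim order. Below is a minimal standalone sketch of that ambiguity and of the helper's tie-breaking rule (ambiguous tensors are treated as NCHW); the shapes are illustrative, not taken from the pass.

```python
import torch


def is_nhwc_tensor(t: torch.Tensor) -> bool:
    # Same tie-breaking rule as ChannelsLastTaggedReshapePass._is_nhwc_tensor:
    # a tensor contiguous in both formats is treated as NCHW.
    nhwc = t.is_contiguous(memory_format=torch.channels_last)
    nchw = t.is_contiguous()
    if nchw and nhwc:
        return False
    return nhwc


x = torch.randn(2, 3, 4, 4)                  # plain contiguous -> NCHW
y = x.to(memory_format=torch.channels_last)  # channels_last strides -> NHWC
z = torch.randn(2, 1, 4, 4)                  # C == 1: contiguous in *both* formats

for t in (x, y, z):
    print(
        t.is_contiguous(),
        t.is_contiguous(memory_format=torch.channels_last),
        is_nhwc_tensor(t),
    )
# x: True  False False
# y: False True  True
# z: True  True  False  <- ambiguous case, resolved to NCHW
```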