@@ -91,18 +91,10 @@ def is_nchw_node(self, node: torch.fx.Node) -> bool:
9191 return not self .is_nhwc_node (node )
9292
9393 def requires_nhwc_input (self , node : torch .fx .Node ) -> bool :
94- return (
95- node .target in self .memory_sensitive_ops_nhwc
96- or node .name == "output"
97- and not node .args [0 ][0 ].meta ["val" ].is_contiguous ()
98- )
94+ return node .target in self .memory_sensitive_ops_nhwc
9995
10096 def requires_nchw_inputs (self , node : torch .fx .Node ) -> bool :
101- return (
102- node .target in self .memory_sensitive_ops_nchw
103- or node .name == "output"
104- and node .args [0 ][0 ].meta ["val" ].is_contiguous ()
105- )
97+ return node .target in self .memory_sensitive_ops_nchw
10698
10799 def can_be_converted_to_nhwc (self , node : torch .fx .Node ) -> bool :
108100 # There are two conditions that must be met for a node to be able to
@@ -380,18 +372,21 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
380372 # This node has no inputs so we don't need to change anything
381373 continue
382374
383- if self .requires_nhwc_input (node ):
375+ # Need a special case for the output node because it can have multiple output dim orders, as we can output a tuple of multiple nodes
376+ if node .op == "output" :
377+ out_tuple = node .args [0 ]
378+ for out_node in out_tuple :
379+ if out_node .meta ["val" ].is_contiguous ():
380+ self .input_to_nchw (graph_module , out_node , node )
381+ else :
382+ self .input_to_nhwc (graph_module , out_node , node )
383+ elif self .requires_nhwc_input (node ):
384384 # Nodes which enter this branch are ones that require their
385385 # first input to be nhwc. This makes this node's output nhwc too
386- # Currently, all nodes like this should have all of their other
387- # inputs as nchw, so fail if this is not true
388- if node .name == "output" :
389- self .input_to_nhwc (graph_module , node .args [0 ][0 ], node )
390- else :
391- self .input_to_nhwc (graph_module , node .args [0 ], node )
392-
393- for input_node in node .all_input_nodes [1 :]:
394- if self .is_nhwc_node (input_node ):
386+
387+ self .input_to_nhwc (graph_module , node .args [0 ], node )
388+ for input_node in node .all_input_nodes :
389+ if input_node .op == "placeholder" and self .is_nhwc_node (input_node ):
395390 raise AssertionError (
396391 f"Expected { input_node } to be NCHW in channels last reshape pass"
397392 )
0 commit comments