Commit 619e98c

fixed bug with multiple outputs having multiple dim orders
1 parent 1af16cd commit 619e98c
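
For context, the bug this commit fixes shows up when a model returns a tuple whose elements end up with different dim orders after tagging. Below is a minimal sketch of such a model (hypothetical, written to mirror the ConvAddConvOutput test added further down):

import torch

# Hypothetical model mirroring the ConvAddConvOutput test below: after the
# channels-last tagging pass, the two returned tensors are not guaranteed to
# share a dim order, which is the situation the output node now handles per-output.
class MixedDimOrderOutputs(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 16, 3)
        self.conv2 = torch.nn.Conv2d(16, 16, 3)

    def forward(self, x):
        y = self.conv1(x)
        z = torch.add(y, 1.0)
        return self.conv2(z), z  # two outputs that can carry different dim orders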

File tree

6 files changed: +66 −24 lines changed


backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 15 additions & 20 deletions
@@ -91,18 +91,10 @@ def is_nchw_node(self, node: torch.fx.Node) -> bool:
         return not self.is_nhwc_node(node)
 
     def requires_nhwc_input(self, node: torch.fx.Node) -> bool:
-        return (
-            node.target in self.memory_sensitive_ops_nhwc
-            or node.name == "output"
-            and not node.args[0][0].meta["val"].is_contiguous()
-        )
+        return node.target in self.memory_sensitive_ops_nhwc
 
     def requires_nchw_inputs(self, node: torch.fx.Node) -> bool:
-        return (
-            node.target in self.memory_sensitive_ops_nchw
-            or node.name == "output"
-            and node.args[0][0].meta["val"].is_contiguous()
-        )
+        return node.target in self.memory_sensitive_ops_nchw
 
     def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
         # There are two conditions that must be met for a node to be able to
@@ -380,18 +372,21 @@ def call(self, graph_module: torch.fx.GraphModule):  # noqa: C901
                 # This node has no inputs so we don't need to change anything
                 continue
 
-            if self.requires_nhwc_input(node):
+            # Need a special case for the output node because it can have multiple output dim orders, as we can output a tuple of multiple nodes
+            if node.op == "output":
+                out_tuple = node.args[0]
+                for out_node in out_tuple:
+                    if out_node.meta["val"].is_contiguous():
+                        self.input_to_nchw(graph_module, out_node, node)
+                    else:
+                        self.input_to_nhwc(graph_module, out_node, node)
+            elif self.requires_nhwc_input(node):
                 # Nodes which enter this branch are ones that require their
                 # first input to be nhwc. This makes this node's output nhwc too
-                # Currently, all nodes like this should have all of their other
-                # inputs as nchw, so fail if this is not true
-                if node.name == "output":
-                    self.input_to_nhwc(graph_module, node.args[0][0], node)
-                else:
-                    self.input_to_nhwc(graph_module, node.args[0], node)
-
-                for input_node in node.all_input_nodes[1:]:
-                    if self.is_nhwc_node(input_node):
+
+                self.input_to_nhwc(graph_module, node.args[0], node)
+                for input_node in node.all_input_nodes:
+                    if input_node.op == "placeholder" and self.is_nhwc_node(input_node):
                         raise AssertionError(
                             f"Expected {input_node} to be NCHW in channels last reshape pass"
                         )
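
The per-output branch above keys off meta["val"].is_contiguous() to pick a conversion for each returned tensor. As a quick illustrative check in plain PyTorch (not part of the pass itself), a channels-last tensor reports False here while a default-layout tensor reports True:

import torch

# Illustrative check of the contiguity test used by the output branch above
# (plain PyTorch, not part of the pass): a channels-last 4D tensor is not
# contiguous in the default NCHW sense, so it takes the NHWC path.
t = torch.randn(1, 3, 8, 8)
print(t.is_contiguous())  # True -> converted via input_to_nchw
print(t.to(memory_format=torch.channels_last).is_contiguous())  # False -> converted via input_to_nhwc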

backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py

Lines changed: 47 additions & 0 deletions
@@ -335,3 +335,50 @@ def test_dq_conv2d_channels_last_tagged_reshape_pass(self) -> None:
             )
             .run_method_and_compare_outputs()
         )
+
+    class ConvAddConvOutput(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 16, 3)
+            self.conv2 = torch.nn.Conv2d(16, 16, 3)
+
+        def forward(self, x):
+            y = self.conv1(x)
+            z = torch.add(y, 1.0)
+            out1 = self.conv2(z)
+            out2 = z
+            return out1, out2
+
+    ConvAddConvOutputModule = ConvAddConvOutput()
+
+    def test_conv_add_conv_output(self):
+        x = torch.randn(1, 3, 8, 8)
+
+        self.run_tester(self.ConvAddConvOutput().eval(), (x,))
+
+        x_cl = x.to(memory_format=torch.channels_last)
+        self.run_tester(self.ConvAddConvOutput().eval(), (x_cl,))
+
+    class ThreeOutputsModel(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 3, 3)
+            self.conv2 = torch.nn.Conv2d(3, 3, 3)
+            self.linear = torch.nn.Linear(6, 6)
+
+        def forward(self, x):
+            conv1_out = self.conv1(x)
+            conv2_out = self.conv2(x)
+            linear_out = self.linear(x)
+
+            return linear_out, conv1_out, conv2_out
+
+    ThreeOutputsModelModule = ThreeOutputsModel()
+
+    def test_three_outputs_model(self):
+        x = torch.randn(1, 3, 6, 6)
+
+        self.run_tester(self.ThreeOutputsModelModule.eval(), (x,))
+
+        x_cl = x.to(memory_format=torch.channels_last)
+        self.run_tester(self.ThreeOutputsModelModule.eval(), (x_cl,))

extension/llm/tokenizers

Submodule eigen updated from 7294434 to a39ade4

third-party/ao

Submodule ao updated 100 files
