
Commit 1392c1b

cccclai authored and facebook-github-bot committed
Fix aten.amax lowering issue (#13381)
Summary: Lowering `aten.amax` failed around the line `input_tensor = self.get_tensor(input_node, node)`. The root cause is a rank mismatch when permuting the tensor inside node_visitors: `op_node.meta[QCOM_AXIS_ORDER]` is (0, 1), but `tensor.shape` is (1, 980, 49), so a rank-2 axis order is applied to a rank-3 tensor. The fix adds `exir_ops.edge.aten.amax.default` to the layout-agnostic op set in the layout-transform pass, alongside the other dim-reduction ops.

Rollback Plan:

Differential Revision: D80187368
1 parent ffb47e3 commit 1392c1b
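To make the failure concrete, here is a minimal standalone sketch (not the backend code itself; the tensor shape and axis order are taken verbatim from the summary above) showing why permuting a rank-3 tensor with a rank-2 axis order fails:

import torch

# Shape and axis order from the bug report: the tensor is rank 3, but the
# axis order recorded in op_node.meta[QCOM_AXIS_ORDER] is only rank 2.
tensor = torch.randn(1, 980, 49)
axis_order = (0, 1)  # too short for a rank-3 tensor

try:
    # Mirrors what the layout transform effectively attempts on this node.
    tensor.permute(axis_order)
except RuntimeError as err:
    print(f"permute failed: {err}")  # number of dims don't match in permute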

File tree

3 files changed: 33 additions, 1 deletion

backends/qualcomm/_passes/layout_transform.py

Lines changed: 1 addition & 0 deletions

@@ -175,6 +175,7 @@ def is_layout_agnostic(self, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.mean.dim,
             exir_ops.edge.aten.min.dim,
             exir_ops.edge.aten.sum.dim_IntList,
+            exir_ops.edge.aten.amax.default,
         }:
             # if dimension is not kept, we'll have no clue how to do layout transform
             if len(node.args) < 3 or not node.args[2]:
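Why the keepdim guard matters: when the reduced dimension is kept, the output rank matches the input rank and the recorded axis order still applies; when it is dropped, the ranks diverge and no layout transform can be inferred. A small plain-PyTorch illustration (not backend code):

import torch

x = torch.randn(2, 16, 64, 64)                  # NCHW input
kept = torch.amax(x, dim=-1, keepdim=True)      # rank preserved: (2, 16, 64, 1)
dropped = torch.amax(x, dim=-1, keepdim=False)  # rank reduced: (2, 16, 64)

# An NCHW -> NHWC axis order (0, 2, 3, 1) still applies to `kept`,
# but cannot be applied to `dropped`, whose rank no longer matches.
print(kept.permute(0, 2, 3, 1).shape)  # torch.Size([2, 64, 1, 16])
print(dropped.shape)                   # torch.Size([2, 16, 64])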

backends/qualcomm/tests/models.py

Lines changed: 18 additions & 0 deletions

@@ -102,6 +102,24 @@ def forward(self, x):
         return torch.amax(x, dim=self.dim, keepdim=self.keepdim)
 
 
+class AMaxFollowingConv2D(torch.nn.Module):
+    def __init__(
+        self, in_channels, out_channels, kernel_size=3, dim=None, keepdim=False
+    ):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(
+            in_channels, out_channels, kernel_size, padding=kernel_size // 2
+        )
+        self.dim = dim
+        self.keepdim = keepdim
+
+    def forward(self, x):
+        x = self.conv(
+            x
+        )  # Apply convolution (output shape: [batch, out_channels, H, W])
+        return torch.amax(x, dim=self.dim, keepdim=self.keepdim)
+
+
 class AMin(torch.nn.Module):
     def __init__(self, dim=None, keepdim=False):
         super().__init__()
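As a quick eager-mode sanity check of the new test model (shapes chosen to match the tests below; this driver is illustrative and not part of the commit, and the import path is an assumption about how models.py resolves in a checkout):

import torch

# Assumed import path; adjust to wherever models.py resolves for you.
from executorch.backends.qualcomm.tests.models import AMaxFollowingConv2D

model = AMaxFollowingConv2D(
    in_channels=3, out_channels=16, kernel_size=3, dim=-1, keepdim=False
)
out = model(torch.randn(2, 3, 64, 64))
# padding=kernel_size // 2 keeps H and W at 64; amax then reduces the last dim.
print(out.shape)  # torch.Size([2, 16, 64])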

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 14 additions & 1 deletion

@@ -134,6 +134,13 @@ def test_qnn_backend_amax(self):
             with self.subTest(i=i):
                 self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_amax_conv(self):
+        sample_input = (torch.randn(2, 3, 64, 64),)  # [batch, channels, height, width]
+        module = AMaxFollowingConv2D(  # noqa: F405
+            in_channels=3, out_channels=16, kernel_size=3, dim=-1, keepdim=False
+        )
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_amin(self):
         modules = [AMin(dim=1, keepdim=False), AMin(dim=1, keepdim=True)]  # noqa: F405
         sample_input = (torch.randn(4, 4),)

@@ -1435,6 +1442,13 @@ def test_qnn_backend_amax(self):
             module = self.get_qdq_module(module, sample_input)
             self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_amax_conv(self):
+        sample_input = (torch.randn(2, 3, 64, 64),)  # [batch, channels, height, width]
+        module = AMaxFollowingConv2D(  # noqa: F405
+            in_channels=3, out_channels=16, kernel_size=3, dim=-1, keepdim=False
+        )
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_amin(self):
         modules = [AMin(dim=1, keepdim=False), AMin(dim=1, keepdim=True)]  # noqa: F405
         sample_input = (torch.randn(4, 4),)

@@ -3418,7 +3432,6 @@ def test_qnn_backend_generate_optrace(self):
 
         for compiler_spec in compiler_specs:
             with tempfile.TemporaryDirectory() as tmp_dir:
-
                 edge_prog_mgr = to_edge_transform_and_lower_to_qnn(
                     module, sample_input, compiler_spec
                 ).to_executorch()
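Note: the same test is added in two places because test_qnn_delegate.py groups float and quantized test classes separately; the second hunk sits in the quantized suite (its neighboring test wraps modules with `self.get_qdq_module`), so the new conv-then-amax pattern is covered on both paths. The final hunk is unrelated cleanup that removes a stray blank line in `test_qnn_backend_generate_optrace`.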
