Skip to content

Commit 2ac76c2

Browse files
committed
Arm backend: fix meandim when dim = None
This is a valid argument, but the pass did not support it. Signed-off-by: Erik Lundell <[email protected]> Change-Id: I27d0cf1732b07b9c4b0aa100e730bd1580716dc6
1 parent 0b748bf commit 2ac76c2

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

backends/arm/_passes/decompose_meandim_pass.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ def call_operator(self, op, args, kwargs, meta):
9494
input_shape = list(x.data.shape)
9595
output_shape = list(meta["val"].shape)
9696
dims_to_reduce = get_node_arg(args, 1)
97+
if dims_to_reduce is None:
98+
dims_to_reduce = range(len(input_shape))
9799
dims_to_reduce = [dim % len(input_shape) for dim in dims_to_reduce]
98100
dims_to_reduce = [dim for dim in dims_to_reduce if input_shape[dim] != 1]
99101

backends/arm/test/ops/test_mean_dim.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ class MeanDim(torch.nn.Module):
115115
test_data_suite: dict[str, tuple] = {
116116
"rank_1_keepdim": lambda: (
117117
torch.rand(7),
118-
(0),
118+
0,
119119
True,
120120
),
121121
"rank_2_keepdim": lambda: (
@@ -168,6 +168,11 @@ class MeanDim(torch.nn.Module):
168168
(0, 1, 2, 3),
169169
True,
170170
),
171+
"rand_none_keepdim": lambda: (
172+
torch.rand(1, 5, 7, 3),
173+
None,
174+
True,
175+
),
171176
"rank_1": lambda: (
172177
torch.rand(7),
173178
(-1),
@@ -280,7 +285,6 @@ def test_mean_dim_tosa_INT(test_data):
280285
(test_data,),
281286
[], # Might be sum, avgpool, or both
282287
symmetric_io_quantization=True,
283-
custom_path="MEANDIM",
284288
)
285289
pipeline.run()
286290

0 commit comments

Comments
 (0)