
Commit 9be3aaa

Arm backend: Support min/max with unset dim. (#14884)
The dim is defined as optional, but before this change the pass required it to be set. When dim is not set, the operation should be performed over all dims. Signed-off-by: Erik Lundell <[email protected]>
1 parent 45bf018 · commit 9be3aaa
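For context: torch.amin and torch.amax define dim as optional, and when it is omitted the reduction runs over every dimension. A minimal sketch of the semantics this commit teaches the pass to handle (shapes chosen arbitrarily):

import torch

x = torch.rand(1, 2, 5, 5)

# dim given: reduce only over the listed dimensions.
partial = torch.amin(x, dim=(0, 1))  # shape: (5, 5)

# dim omitted: reduce over all dimensions; this is the case the pass
# previously rejected.
full = torch.amin(x)  # 0-dim scalar tensor
assert torch.equal(full, x.min())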

File tree: 2 files changed, +23 -10 lines changed

backends/arm/_passes/convert_minmax_pass.py

Lines changed: 16 additions & 7 deletions
@@ -3,9 +3,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Set, Type
+from typing import cast, Set, Type
 
 import torch
+from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
 from executorch.backends.arm._passes.convert_squeezes_to_view import (
     ConvertSqueezesToViewPass,
 )
@@ -101,20 +102,28 @@ def call(self, graph_module: torch.fx.GraphModule):
             replace_node, op, squeeze_op = self.get_variables(node)
 
             # Unwrap args
-            if len(node.args) == 2:
+            if len(node.args) == 1:
+                # If dims is unspecified, min/max over all dims.
+                input_node = cast(torch.fx.Node, node.args[0])
+                input_shape = get_first_fake_tensor(input_node).shape
+                dims = range(len(input_shape))
+                keepdims = False
+            elif len(node.args) == 2:
                 input_node, dims = node.args
                 keepdims = False
             elif len(node.args) == 3:
                 input_node, dims, keepdims = node.args
             else:
-                raise RuntimeError(f"Unexpected arg size in {node.name}")
+                raise RuntimeError(
+                    f"Unexpected arg size {len(node.args)} in {node.name}"
+                )
 
             try:
-                iter(dims)
-            except:
-                dims = [dims]
+                iter(dims)  # type:ignore[assignment]
+            except Exception:
+                dims = [dims]  # type:ignore[assignment]
             else:
-                dims = list(dims)
+                dims = list(dims)  # type:ignore[assignment]
 
             # Unroll multi-dimensional reduction and keep-dims arg
             with graph_module.graph.inserting_before(node):

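The hunk above normalizes the one-arg case to dims = range(rank) and then, per the comment at its end, unrolls the multi-dim reduction into a chain of single-dim reductions. A standalone sketch of why that decomposition is sound (unrolled_amin is a hypothetical helper for illustration, not code from the commit; the keepdim-then-squeeze strategy is an assumption):

import torch

def unrolled_amin(x: torch.Tensor, dims, keepdims: bool = False) -> torch.Tensor:
    # Reduce one dim at a time with keepdim=True, then drop the kept
    # dims at the end if keepdims is False.
    out = x
    for d in dims:
        out = torch.amin(out, dim=d, keepdim=True)
    if not keepdims:
        for d in sorted(dims, reverse=True):
            out = out.squeeze(d)
    return out

x = torch.rand(1, 2, 5, 5)
# With dim unset, the pass reduces over range(rank); that must agree
# with a single amin over all dims.
assert torch.equal(unrolled_amin(x, range(x.dim())), torch.amin(x))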
backends/arm/test/ops/test_amin.py

Lines changed: 7 additions & 3 deletions
@@ -29,12 +29,16 @@ def __init__(self, dim, keep_dims):
         super().__init__()
 
     def forward(self, x):
-        return torch.amin(x, self.dim, self.keep_dims)
+        if self.dim is None:
+            return torch.amin(x, keepdim=self.keep_dims)
+        else:
+            return torch.amin(x, self.dim, self.keep_dims)
 
-    test_data: Dict[str, input_t] = {
+    test_data: Dict = {
         "rank_1_dim_0": lambda: ((torch.rand([10]),), 0, False),
         "rank_2_dim_1_keep_dims": lambda: ((torch.rand([2, 2]),), (1,), True),
         "rank_4_all_dim": lambda: ((torch.rand([1, 2, 5, 5]),), (0, 1, 2, 3), False),
+        "rank_4_no_dim": lambda: ((torch.rand([1, 2, 5, 5]),), None, False),
         "rank_4_0,3_keep_dims": lambda: ((torch.rand([1, 2, 2, 2]),), (0, 3), True),
         "rank_4_mult_batches": lambda: ((torch.rand([2, 2, 2, 2]),), (0), True),
     }
@@ -52,7 +56,7 @@ def forward(self, x):
         x = torch.min(x, self.dim)
         return x[0]
 
-    test_data: Dict[str, input_t] = {
+    test_data: Dict = {
         "rank_1_dim_0": lambda: ((torch.rand([10]),), 0),
         "rank_2_dim_1": lambda: ((torch.rand([2, 2]),), 1),
         "rank_4_dim_2": lambda: ((torch.rand([2, 2, 2, 2]),), 2),

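The new rank_4_no_dim case feeds dim=None through the forward above, exercising the pass's new one-arg path end to end. Roughly what that case exercises (Amin here is a cut-down stand-in for the test module; its attribute assignments are assumed, since the __init__ body is not shown in the hunk):

import torch

class Amin(torch.nn.Module):
    # Cut-down stand-in for the module under test in test_amin.py.
    def __init__(self, dim, keep_dims):
        super().__init__()
        self.dim = dim
        self.keep_dims = keep_dims

    def forward(self, x):
        if self.dim is None:
            return torch.amin(x, keepdim=self.keep_dims)
        return torch.amin(x, self.dim, self.keep_dims)

out = Amin(dim=None, keep_dims=False)(torch.rand(1, 2, 5, 5))
assert out.dim() == 0  # reduced over all dims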