Skip to content

Commit ce331a6

Browse files
committed
Arm backend: Support min/max with unset dim.
the dim is defined as optional, but before the pass requried it to be set. When it is not set, the operation should be done on all dims. Signed-off-by: Erik Lundell <[email protected]> Change-Id: Ifa5ae8c616bdaa422c96f71227d5594702cfd99a
1 parent 0b748bf commit ce331a6

File tree

2 files changed

+23
-10
lines changed

2 files changed

+23
-10
lines changed

backends/arm/_passes/convert_minmax_pass.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from typing import Set, Type
6+
from typing import cast, Set, Type
77

88
import torch
9+
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
910
from executorch.backends.arm._passes.convert_squeezes_to_view import (
1011
ConvertSqueezesToViewPass,
1112
)
@@ -101,20 +102,28 @@ def call(self, graph_module: torch.fx.GraphModule):
101102
replace_node, op, squeeze_op = self.get_variables(node)
102103

103104
# Unwrap args
104-
if len(node.args) == 2:
105+
if len(node.args) == 1:
106+
# If dims is unspecified, min/max over all dims.
107+
input_node = cast(torch.fx.Node, node.args[0])
108+
input_shape = get_first_fake_tensor(input_node).shape
109+
dims = range(len(input_shape))
110+
keepdims = False
111+
elif len(node.args) == 2:
105112
input_node, dims = node.args
106113
keepdims = False
107114
elif len(node.args) == 3:
108115
input_node, dims, keepdims = node.args
109116
else:
110-
raise RuntimeError(f"Unexpected arg size in {node.name}")
117+
raise RuntimeError(
118+
f"Unexpected arg size {len(node.args)} in {node.name}"
119+
)
111120

112121
try:
113-
iter(dims)
114-
except:
115-
dims = [dims]
122+
iter(dims) # type:ignore[assignment]
123+
except Exception:
124+
dims = [dims] # type:ignore[assignment]
116125
else:
117-
dims = list(dims)
126+
dims = list(dims) # type:ignore[assignment]
118127

119128
# Unroll multi-dimensional reduction and keep-dims arg
120129
with graph_module.graph.inserting_before(node):

backends/arm/test/ops/test_amin.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,16 @@ def __init__(self, dim, keep_dims):
2929
super().__init__()
3030

3131
def forward(self, x):
32-
return torch.amin(x, self.dim, self.keep_dims)
32+
if self.dim is None:
33+
return torch.amin(x, keepdim=self.keep_dims)
34+
else:
35+
return torch.amin(x, self.dim, self.keep_dims)
3336

34-
test_data: Dict[str, input_t] = {
37+
test_data: Dict = {
3538
"rank_1_dim_0": lambda: ((torch.rand([10]),), 0, False),
3639
"rank_2_dim_1_keep_dims": lambda: ((torch.rand([2, 2]),), (1,), True),
3740
"rank_4_all_dim": lambda: ((torch.rand([1, 2, 5, 5]),), (0, 1, 2, 3), False),
41+
"rank_4_no_dim": lambda: ((torch.rand([1, 2, 5, 5]),), None, False),
3842
"rank_4_0,3_keep_dims": lambda: ((torch.rand([1, 2, 2, 2]),), (0, 3), True),
3943
"rank_4_mult_batches": lambda: ((torch.rand([2, 2, 2, 2]),), (0), True),
4044
}
@@ -52,7 +56,7 @@ def forward(self, x):
5256
x = torch.min(x, self.dim)
5357
return x[0]
5458

55-
test_data: Dict[str, input_t] = {
59+
test_data: Dict = {
5660
"rank_1_dim_0": lambda: ((torch.rand([10]),), 0),
5761
"rank_2_dim_1": lambda: ((torch.rand([2, 2]),), 1),
5862
"rank_4_dim_2": lambda: ((torch.rand([2, 2, 2, 2]),), 2),

0 commit comments

Comments
 (0)