diff --git a/backends/arm/_passes/convert_minmax_pass.py b/backends/arm/_passes/convert_minmax_pass.py index 9f409632c20..4cfb259070d 100644 --- a/backends/arm/_passes/convert_minmax_pass.py +++ b/backends/arm/_passes/convert_minmax_pass.py @@ -3,7 +3,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import cast, Set, Type + import torch +from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm._passes.convert_squeezes_to_view import ( + ConvertSqueezesToViewPass, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -94,20 +100,28 @@ def call(self, graph_module: torch.fx.GraphModule): replace_node, op, squeeze_op = self.get_variables(node) # Unwrap args - if len(node.args) == 2: + if len(node.args) == 1: + # If dims is unspecified, min/max over all dims. + input_node = cast(torch.fx.Node, node.args[0]) + input_shape = get_first_fake_tensor(input_node).shape + dims = range(len(input_shape)) + keepdims = False + elif len(node.args) == 2: input_node, dims = node.args keepdims = False elif len(node.args) == 3: input_node, dims, keepdims = node.args else: - raise RuntimeError(f"Unexpected arg size in {node.name}") + raise RuntimeError( + f"Unexpected arg size {len(node.args)} in {node.name}" + ) try: - iter(dims) - except: - dims = [dims] + iter(dims) # type:ignore[assignment] + except Exception: + dims = [dims] # type:ignore[assignment] else: - dims = list(dims) + dims = list(dims) # type:ignore[assignment] # Unroll multi-dimensional reduction and keep-dims arg with graph_module.graph.inserting_before(node): diff --git a/backends/arm/test/ops/test_amin.py b/backends/arm/test/ops/test_amin.py index a24da9e1ba0..4d064e9f746 100644 --- a/backends/arm/test/ops/test_amin.py +++ b/backends/arm/test/ops/test_amin.py @@ -29,12 +29,16 @@ def __init__(self, dim, keep_dims): super().__init__() def forward(self, x): - return torch.amin(x, self.dim, self.keep_dims) + if self.dim is None: + return torch.amin(x, keepdim=self.keep_dims) + else: + return torch.amin(x, self.dim, self.keep_dims) - test_data: Dict[str, input_t] = { + test_data: Dict = { "rank_1_dim_0": lambda: ((torch.rand([10]),), 0, False), "rank_2_dim_1_keep_dims": lambda: ((torch.rand([2, 2]),), (1,), True), "rank_4_all_dim": lambda: ((torch.rand([1, 2, 5, 5]),), (0, 1, 2, 3), False), + "rank_4_no_dim": lambda: ((torch.rand([1, 2, 5, 5]),), None, False), "rank_4_0,3_keep_dims": lambda: ((torch.rand([1, 2, 2, 2]),), (0, 3), True), "rank_4_mult_batches": lambda: ((torch.rand([2, 2, 2, 2]),), (0), True), } @@ -52,7 +56,7 @@ def forward(self, x): x = torch.min(x, self.dim) return x[0] - test_data: Dict[str, input_t] = { + test_data: Dict = { "rank_1_dim_0": lambda: ((torch.rand([10]),), 0), "rank_2_dim_1": lambda: ((torch.rand([2, 2]),), 1), "rank_4_dim_2": lambda: ((torch.rand([2, 2, 2, 2]),), 2),