26 changes: 20 additions & 6 deletions backends/arm/_passes/convert_minmax_pass.py
@@ -3,7 +3,13 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import cast, Set, Type
+
 import torch
+from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
+from executorch.backends.arm._passes.convert_squeezes_to_view import (
+    ConvertSqueezesToViewPass,
+)
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
 
@@ -94,20 +100,28 @@ def call(self, graph_module: torch.fx.GraphModule):
             replace_node, op, squeeze_op = self.get_variables(node)
 
             # Unwrap args
-            if len(node.args) == 2:
+            if len(node.args) == 1:
+                # If dims is unspecified, min/max over all dims.
+                input_node = cast(torch.fx.Node, node.args[0])
+                input_shape = get_first_fake_tensor(input_node).shape
+                dims = range(len(input_shape))
+                keepdims = False
+            elif len(node.args) == 2:
                 input_node, dims = node.args
                 keepdims = False
             elif len(node.args) == 3:
                 input_node, dims, keepdims = node.args
             else:
-                raise RuntimeError(f"Unexpected arg size in {node.name}")
+                raise RuntimeError(
+                    f"Unexpected arg size {len(node.args)} in {node.name}"
+                )
 
             try:
-                iter(dims)
-            except:
-                dims = [dims]
+                iter(dims) # type:ignore[assignment]
+            except Exception:
+                dims = [dims] # type:ignore[assignment]
             else:
-                dims = list(dims)
+                dims = list(dims) # type:ignore[assignment]
 
             # Unroll multi-dimensional reduction and keep-dims arg
             with graph_module.graph.inserting_before(node):
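A note on the new unwrapping logic: with one argument the op is treated as a reduction over every dimension of the input, with two it gets explicit dims, and with three an explicit keepdims as well. The following is a minimal standalone sketch of that normalization; normalize_reduction_args is a hypothetical helper used only for illustration here, not part of the patch.

# Hypothetical helper mirroring the arg-unwrapping above (illustration only).
from typing import List, Optional, Sequence, Tuple, Union

def normalize_reduction_args(
    rank: int,
    dims: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tuple[List[int], bool]:
    # No dims given: min/max over all dims of a rank-`rank` tensor.
    if dims is None:
        return list(range(rank)), keepdims
    # A single int dim becomes a one-element list; an iterable is used as-is.
    try:
        return list(iter(dims)), keepdims
    except TypeError:
        return [dims], keepdims

# normalize_reduction_args(4)          -> ([0, 1, 2, 3], False)
# normalize_reduction_args(4, 2, True) -> ([2], True)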
10 changes: 7 additions & 3 deletions backends/arm/test/ops/test_amin.py
@@ -29,12 +29,16 @@ def __init__(self, dim, keep_dims):
         super().__init__()
 
     def forward(self, x):
-        return torch.amin(x, self.dim, self.keep_dims)
+        if self.dim is None:
+            return torch.amin(x, keepdim=self.keep_dims)
+        else:
+            return torch.amin(x, self.dim, self.keep_dims)
 
-    test_data: Dict[str, input_t] = {
+    test_data: Dict = {
         "rank_1_dim_0": lambda: ((torch.rand([10]),), 0, False),
         "rank_2_dim_1_keep_dims": lambda: ((torch.rand([2, 2]),), (1,), True),
         "rank_4_all_dim": lambda: ((torch.rand([1, 2, 5, 5]),), (0, 1, 2, 3), False),
+        "rank_4_no_dim": lambda: ((torch.rand([1, 2, 5, 5]),), None, False),
         "rank_4_0,3_keep_dims": lambda: ((torch.rand([1, 2, 2, 2]),), (0, 3), True),
         "rank_4_mult_batches": lambda: ((torch.rand([2, 2, 2, 2]),), (0), True),
     }
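The new rank_4_no_dim entry passes None as dim, exercising the keepdim-only branch added to forward above; calling torch.amin without a dim reduces over every dimension, just like listing all dims explicitly. A small shape check of that equivalence, in plain PyTorch outside the test harness:

import torch

x = torch.rand([1, 2, 5, 5])
print(torch.amin(x, keepdim=False).shape)                     # torch.Size([]): reduced over all dims
print(torch.amin(x, dim=(0, 1, 2, 3), keepdim=False).shape)   # torch.Size([]): same reduction spelled out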
@@ -52,7 +56,7 @@ def forward(self, x):
         x = torch.min(x, self.dim)
         return x[0]
 
-    test_data: Dict[str, input_t] = {
+    test_data: Dict = {
         "rank_1_dim_0": lambda: ((torch.rand([10]),), 0),
         "rank_2_dim_1": lambda: ((torch.rand([2, 2]),), 1),
         "rank_4_dim_2": lambda: ((torch.rand([2, 2, 2, 2]),), 2),