diff --git a/backends/arm/_passes/decompose_meandim_pass.py b/backends/arm/_passes/decompose_meandim_pass.py index 4d4c0ee75b1..135e2830d5d 100644 --- a/backends/arm/_passes/decompose_meandim_pass.py +++ b/backends/arm/_passes/decompose_meandim_pass.py @@ -19,13 +19,13 @@ def get_meandim_decomposition(op) -> tuple: - if op == exir_ops.edge.aten.mean.dim: + if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default): return ( exir_ops.edge.aten.sum.dim_IntList, exir_ops.edge.aten.full.default, exir_ops.edge.aten.mul.Tensor, ) - if op == torch.ops.aten.mean.dim: + if op in (torch.ops.aten.mean.dim, torch.ops.aten.mean.default): return ( torch.ops.aten.sum.dim_IntList, torch.ops.aten.full.default, @@ -35,17 +35,17 @@ def get_meandim_decomposition(op) -> tuple: def get_avgpool(op): - if op == exir_ops.edge.aten.mean.dim: + if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default): return exir_ops.edge.aten.avg_pool2d.default - if op == torch.ops.aten.mean.dim: + if op in (torch.ops.aten.mean.dim, torch.ops.aten.mean.default): return torch.ops.aten.avg_pool2d.default raise RuntimeError(f"Can't get meandim decomposition for op {op}") def get_view(op): - if op == exir_ops.edge.aten.mean.dim: + if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default): return exir_ops.edge.aten.view_copy.default - if op == torch.ops.aten.mean.dim: + if op in (torch.ops.aten.mean.dim, torch.ops.aten.mean.default): return torch.ops.aten.view_copy.default raise RuntimeError(f"Can't get meandim decomposition for op {op}") @@ -87,13 +87,18 @@ def __init__(self, graph_module, tosa_spec): ) def call_operator(self, op, args, kwargs, meta): - if op not in (exir_ops.edge.aten.mean.dim, torch.ops.aten.mean.dim): + if op not in ( + exir_ops.edge.aten.mean.dim, + torch.ops.aten.mean.dim, + exir_ops.edge.aten.mean.default, + torch.ops.aten.mean.default, + ): return super().call_operator(op, args, kwargs, meta) x = get_node_arg(args, 0) input_shape = list(x.data.shape) output_shape = list(meta["val"].shape) - dims_to_reduce = get_node_arg(args, 1) + dims_to_reduce = get_node_arg(args, 1, range(len(input_shape))) if dims_to_reduce is None: dims_to_reduce = range(len(input_shape)) dims_to_reduce = [dim % len(input_shape) for dim in dims_to_reduce] diff --git a/backends/arm/operator_support/tosa_profile_supported_op_lists.py b/backends/arm/operator_support/tosa_profile_supported_op_lists.py index ee61aa4cce6..b91ed4fb130 100644 --- a/backends/arm/operator_support/tosa_profile_supported_op_lists.py +++ b/backends/arm/operator_support/tosa_profile_supported_op_lists.py @@ -178,6 +178,7 @@ exir_ops.edge.aten.native_group_norm.default, exir_ops.edge.aten.sigmoid.default, exir_ops.edge.aten.mean.dim, + exir_ops.edge.aten.mean.default, exir_ops.edge.aten.mm.default, exir_ops.edge.aten.minimum.default, exir_ops.edge.aten.maximum.default, diff --git a/backends/arm/scripts/parse_test_names.py b/backends/arm/scripts/parse_test_names.py index 54f8aa7421d..56a2a9c6890 100644 --- a/backends/arm/scripts/parse_test_names.py +++ b/backends/arm/scripts/parse_test_names.py @@ -14,6 +14,7 @@ "hardswish.default", "linear.default", "maximum.default", + "mean.default", "multihead_attention.default", "adaptive_avg_pool2d.default", "bitwise_right_shift.Tensor", diff --git a/backends/arm/test/ops/test_mean_dim.py b/backends/arm/test/ops/test_mean_dim.py index 31797e72e78..babfb7d10da 100644 --- a/backends/arm/test/ops/test_mean_dim.py +++ b/backends/arm/test/ops/test_mean_dim.py @@ -4,6 +4,8 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Callable + import torch from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( @@ -344,3 +346,43 @@ def test_mean_dim_vgf_INT(test_data): tosa_version="TOSA-1.0+INT", ) pipeline.run() + + +mean_input_t = tuple[torch.Tensor, bool] + + +class MeanDefault(torch.nn.Module): + def forward(self, tensor: torch.Tensor, keepdim: bool): + return tensor.mean() + + test_data_suite: dict[str, Callable[[], mean_input_t]] = { + "rank1": lambda: ( + torch.rand( + 1, + ), + False, + ), + "rank2": lambda: (torch.rand(5, 5), True), + "rank4": lambda: (torch.rand(5, 1, 10, 1), False), + } + + +@common.parametrize("test_data", MeanDefault.test_data_suite) +def test_mean_tosa_FP(test_data): + pipeline = TosaPipelineFP[mean_input_t]( + MeanDefault(), + test_data(), + [], # Might be sum, avgpool, or both + ) + pipeline.run() + + +@common.parametrize("test_data", MeanDefault.test_data_suite) +def test_mean_tosa_INT(test_data): + pipeline = TosaPipelineINT[mean_input_t]( + MeanDefault(), + test_data(), + [], # Might be sum, avgpool, or both + symmetric_io_quantization=True, + ) + pipeline.run()