Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions backends/arm/_passes/decompose_meandim_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@


def get_meandim_decomposition(op) -> tuple:
if op == exir_ops.edge.aten.mean.dim:
if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default):
return (
exir_ops.edge.aten.sum.dim_IntList,
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.mul.Tensor,
)
if op == torch.ops.aten.mean.dim:
if op in (torch.ops.aten.mean.dim, torch.ops.aten.mean.default):
return (
torch.ops.aten.sum.dim_IntList,
torch.ops.aten.full.default,
Expand All @@ -35,17 +35,17 @@ def get_meandim_decomposition(op) -> tuple:


def get_avgpool(op):
if op == exir_ops.edge.aten.mean.dim:
if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default):
return exir_ops.edge.aten.avg_pool2d.default
if op == torch.ops.aten.mean.dim:
if op in (torch.ops.aten.mean.dim, torch.ops.aten.mean.default):
return torch.ops.aten.avg_pool2d.default
raise RuntimeError(f"Can't get meandim decomposition for op {op}")


def get_view(op):
if op == exir_ops.edge.aten.mean.dim:
if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default):
return exir_ops.edge.aten.view_copy.default
if op == torch.ops.aten.mean.dim:
if op in (torch.ops.aten.mean.dim, torch.ops.aten.mean.default):
return torch.ops.aten.view_copy.default
raise RuntimeError(f"Can't get meandim decomposition for op {op}")

Expand Down Expand Up @@ -87,13 +87,18 @@ def __init__(self, graph_module, tosa_spec):
)

def call_operator(self, op, args, kwargs, meta):
if op not in (exir_ops.edge.aten.mean.dim, torch.ops.aten.mean.dim):
if op not in (
exir_ops.edge.aten.mean.dim,
torch.ops.aten.mean.dim,
exir_ops.edge.aten.mean.default,
torch.ops.aten.mean.default,
):
return super().call_operator(op, args, kwargs, meta)

x = get_node_arg(args, 0)
input_shape = list(x.data.shape)
output_shape = list(meta["val"].shape)
dims_to_reduce = get_node_arg(args, 1)
dims_to_reduce = get_node_arg(args, 1, range(len(input_shape)))
if dims_to_reduce is None:
dims_to_reduce = range(len(input_shape))
dims_to_reduce = [dim % len(input_shape) for dim in dims_to_reduce]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@
exir_ops.edge.aten.native_group_norm.default,
exir_ops.edge.aten.sigmoid.default,
exir_ops.edge.aten.mean.dim,
exir_ops.edge.aten.mean.default,
exir_ops.edge.aten.mm.default,
exir_ops.edge.aten.minimum.default,
exir_ops.edge.aten.maximum.default,
Expand Down
1 change: 1 addition & 0 deletions backends/arm/scripts/parse_test_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"hardswish.default",
"linear.default",
"maximum.default",
"mean.default",
"multihead_attention.default",
"adaptive_avg_pool2d.default",
"bitwise_right_shift.Tensor",
Expand Down
42 changes: 42 additions & 0 deletions backends/arm/test/ops/test_mean_dim.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
Expand Down Expand Up @@ -344,3 +346,43 @@ def test_mean_dim_vgf_INT(test_data):
tosa_version="TOSA-1.0+INT",
)
pipeline.run()


mean_input_t = tuple[torch.Tensor, bool]


class MeanDefault(torch.nn.Module):
def forward(self, tensor: torch.Tensor, keepdim: bool):
return tensor.mean()

test_data_suite: dict[str, Callable[[], mean_input_t]] = {
"rank1": lambda: (
torch.rand(
1,
),
False,
),
"rank2": lambda: (torch.rand(5, 5), True),
"rank4": lambda: (torch.rand(5, 1, 10, 1), False),
}


@common.parametrize("test_data", MeanDefault.test_data_suite)
def test_mean_tosa_FP(test_data):
pipeline = TosaPipelineFP[mean_input_t](
MeanDefault(),
test_data(),
[], # Might be sum, avgpool, or both
)
pipeline.run()


@common.parametrize("test_data", MeanDefault.test_data_suite)
def test_mean_tosa_INT(test_data):
pipeline = TosaPipelineINT[mean_input_t](
MeanDefault(),
test_data(),
[], # Might be sum, avgpool, or both
symmetric_io_quantization=True,
)
pipeline.run()
Loading