Skip to content

Commit fd363e0

Browse files
Add meandim decomposition
Signed-off-by: Oscar Andersson <[email protected]> Change-Id: I6e9b3da1384d33903b039696d6c156b1a2f1841a
1 parent 8673567 commit fd363e0

File tree

6 files changed

+247
-78
lines changed

6 files changed

+247
-78
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
ConvertSplitToSlicePass,
2020
)
2121
from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass
22+
from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass
2223
from executorch.backends.arm._passes.insert_squeeze_after_sum_pass import (
2324
InsertSqueezeAfterSumPass,
2425
)
@@ -49,6 +50,7 @@ def transform_to_backend_pipeline(
4950
self.add_pass(RemoveClonePass())
5051
self.add_pass(ConvertExpandCopyToRepeatPass())
5152
self.add_pass(ConvertMeanDimToAveragePool())
53+
self.add_pass(DecomposeMeanDimPass())
5254
self.add_pass(DecomposeDivPass())
5355
self.add_pass(InsertSqueezeAfterSumPass())
5456
self.add_pass(ConvertSplitToSlicePass())
@@ -63,4 +65,5 @@ def transform_to_backend_pipeline(
6365
def transform_for_annotation_pipeline(self, graph_module: torch.fx.GraphModule):
6466
self.add_pass(DecomposeDivPass())
6567
self.add_pass(ScalarsToAttributePass())
68+
self.add_pass(DecomposeMeanDimPass())
6669
return self._transform(graph_module)
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
from executorch.exir.pass_base import ExportPass
10+
11+
12+
def get_meandim_decomposition(op) -> tuple:
13+
if op == exir_ops.edge.aten.mean.dim:
14+
return (
15+
exir_ops.edge.aten.sum.dim_IntList,
16+
exir_ops.edge.aten.full.default,
17+
exir_ops.edge.aten.mul.Tensor,
18+
)
19+
if op == torch.ops.aten.mean.dim:
20+
return (
21+
torch.ops.aten.sum.dim_IntList,
22+
torch.ops.aten.full.default,
23+
torch.ops.aten.mul.Tensor,
24+
)
25+
raise RuntimeError(f"Can't get meandim decomposition for op {op}")
26+
27+
28+
class DecomposeMeanDimPass(ExportPass):
29+
"""
30+
This pass decomposes meandim into a sum and mul node.
31+
32+
Example:
33+
y = mean_dim(x, dim, keepdim)
34+
Becomes:
35+
sum = sum.dim_IntList(x, dim, keepdim)
36+
y = mul(sum, 1/N)
37+
"""
38+
39+
def call_operator(self, op, args, kwargs, meta):
40+
if op not in (exir_ops.edge.aten.mean.dim, torch.ops.aten.mean.dim):
41+
return super().call_operator(op, args, kwargs, meta)
42+
43+
x = args[0]
44+
dim = args[1]
45+
keepdim = args[2] if len(args) > 2 else False
46+
if not keepdim:
47+
return super().call_operator(op, args, kwargs, meta)
48+
# if keepdim == True and dim == [-1, -2], mean.dim can be
49+
# decomposed to avg_pool2d. This is handled by ConvertMeanDimToAveragePool.
50+
if dim == [-1, -2]:
51+
# Simply return the mean.dim operator for future decomposition.
52+
return super().call_operator(op, args, kwargs, meta)
53+
shape = meta["val"].size()
54+
dtype = meta["val"].dtype
55+
input_shape = x.data.size()
56+
N = 1
57+
for d in dim:
58+
N *= input_shape[d]
59+
60+
sum_op, full_op, mul_op = get_meandim_decomposition(op)
61+
62+
sum = super().call_operator(sum_op, (x, dim, keepdim), {}, meta)
63+
full = super().call_operator(
64+
full_op, ([1] * len(shape), 1 / N), {"dtype": dtype}, meta
65+
)
66+
return super().call_operator(mul_op, (sum, full), {}, meta)

backends/arm/arm_partitioner.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,8 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
8585

8686
def is_node_supported_custom(self, node: torch.fx.Node) -> bool:
8787
if node.target == exir_ops.edge.aten.mean.dim:
88-
dim = node.args[1]
89-
keep_dim = node.args[2]
90-
if dim != [-1, -2] or keep_dim is False:
91-
return False
88+
keep_dim = node.args[2] if len(node.args) > 2 else False
89+
return keep_dim
9290
return True
9391

9492

backends/arm/operators/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
op_get_item,
2121
op_hardtanh,
2222
op_log,
23-
op_mean_dim,
2423
op_mm,
2524
op_mul,
2625
op_permute,

backends/arm/operators/op_mean_dim.py

Lines changed: 0 additions & 34 deletions
This file was deleted.

0 commit comments

Comments
 (0)