 # Copyright 2024-2025 Arm Limited and/or its affiliates.
-# All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-# pyre-unsafe
+from math import prod
 
 import torch
 from executorch.backends.arm._passes import ArmPass
@@ -28,42 +27,111 @@ def get_meandim_decomposition(op) -> tuple:
     raise RuntimeError(f"Can't get meandim decomposition for op {op}")
 
 
+def get_avgpool(op):
+    if op == exir_ops.edge.aten.mean.dim:
+        return exir_ops.edge.aten.avg_pool2d.default
+    if op == torch.ops.aten.mean.dim:
+        return torch.ops.aten.avg_pool2d.default
+    raise RuntimeError(f"Can't get meandim decomposition for op {op}")
+
+
+def get_view(op):
+    if op == exir_ops.edge.aten.mean.dim:
+        return exir_ops.edge.aten.view_copy.default
+    if op == torch.ops.aten.mean.dim:
+        return torch.ops.aten.view_copy.default
+    raise RuntimeError(f"Can't get meandim decomposition for op {op}")
+
+
 class DecomposeMeanDimPass(ArmPass):
     """
-    This pass decomposes meandim into a sum and mul node.
+    Decomposes a meandim op into avg_pool and/or sum + mul(1/N), depending on which dims the mean is taken over:
+        h,w -> avg_pool
+        n,c -> sum + mul(1/N)
+    For rank < 4, the input is first reshaped to 4D by padding with dim=1 from the left.
 
     Example:
-        y = mean_dim(x, dim, keepdim)
+        x = mean_dim(x, (0,2), keepdim=False) # x = (c,h,w)
     Becomes:
-        sum = sum.dim_IntList(x, dim, keepdim)
-        y = mul(sum, 1/N)
+        x = view_copy.default(x, new_shape=(1,c,h,w)) # Reshape to work with avg_pool
+        x = avg_pool2d.default(x, kernel=(1,w), stride=(1,1)) # Reduce w with avg_pool
+        x = sum.dim_IntList(x, dim=1, keepdim=True) # Reduce c with sum
+        x = mul.Tensor(x, 1/c) # Divide by the number of channels to get the mean
+        x = view_copy.default(x, new_shape=(h,)) # Squeeze dims since keepdim=False
     """
 
     def call_operator(self, op, args, kwargs, meta):
         if op not in (exir_ops.edge.aten.mean.dim, torch.ops.aten.mean.dim):
             return super().call_operator(op, args, kwargs, meta)
 
         x = get_node_arg(args, 0)
-        dim = get_node_arg(args, 1)
-        keepdim = get_node_arg(args, 2, False)
-
-        # if dim == [-1, -2], mean.dim can be
-        # decomposed to avg_pool2d. This is handled by ConvertMeanDimToAveragePool.
-        if dim == [-1, -2]:
-            # Simply return the mean.dim operator for future decomposition.
-            return super().call_operator(op, args, kwargs, meta)
+        input_shape = x.data.size()
+        output_shape = meta["val"].size()
+        dims_to_reduce = get_node_arg(args, 1)
+        dims_to_reduce = [dim % len(input_shape) for dim in dims_to_reduce]
 
-        shape = meta["val"].size()
         dtype = meta["val"].dtype
-        input_shape = x.data.size()
-        N = 1
-        for d in dim:
-            N *= input_shape[d]
+        view_op = get_view(op)
 
+        if len(input_shape) > 4:
+            raise NotImplementedError(
+                f"{op} with rank > 4 is currently not supported for the TOSA backend."
+            )
+
+        # Unsqueeze to 4D
+        if len(input_shape) < 4:
+            pad_n = 4 - len(input_shape)
+            new_shape = [1] * pad_n + list(input_shape)
+            dims_to_reduce = [dim + pad_n for dim in dims_to_reduce]
+
+            x = super().call_operator(view_op, (x, new_shape), {}, meta, True)
+
+        # Reduce (h,w) by avg pool
+        dims_to_reduce_by_avgpool = [dim for dim in dims_to_reduce if dim >= 2]
+        x = self._reduce_by_average_pool(op, x, dims_to_reduce_by_avgpool, meta)
+
+        # Reduce (n, c) by reduce sum
+        dims_to_reduce_by_sum = [dim for dim in dims_to_reduce if dim < 2]
+        x = self._reduce_by_sum(op, x, dims_to_reduce_by_sum, meta, dtype)
+
+        # Reshape to correct output shape if necessary
+        if x.data.size() != output_shape:
+            x = super().call_operator(view_op, (x, output_shape), {}, meta, True)
+
+        return x
+
+    def _reduce_by_sum(self, op, input_node, dims, meta, dtype):
+        if len(dims) == 0:
+            return input_node
+
+        input_shape = input_node.data.size()
+        output_shape = meta["val"].size()
+        N = prod((n for i, n in enumerate(input_shape) if i in dims))
         sum_op, full_op, mul_op = get_meandim_decomposition(op)
 
-        sum = super().call_operator(sum_op, (x, dim, keepdim), {}, meta, True)
+        sum = super().call_operator(sum_op, (input_node, dims, True), {}, meta, True)
         full = super().call_operator(
-            full_op, ([1] * len(shape), 1 / N), {"dtype": dtype}, meta, True
+            full_op, ([1] * len(output_shape), 1 / N), {"dtype": dtype}, meta, True
         )
         return super().call_operator(mul_op, (sum, full), {}, meta, True)
+
+    def _reduce_by_average_pool(self, op, input_node, dims, meta):
+        if len(dims) == 0:
+            return input_node
+
+        avgpool_op = get_avgpool(op)
+        input_shape = input_node.data.size()
+
+        stride = [1, 1]
+        if dims in ([2, 3], [3, 2]):
+            kernel_size = [input_shape[2], input_shape[3]]
+        elif dims == [3]:
+            kernel_size = [1, input_shape[3]]
+        elif dims == [2]:
+            kernel_size = [input_shape[2], 1]
+        else:
+            raise RuntimeError(f"Bad dims {dims} for {op} decomposition of mean_dim.")
+
+        return super().call_operator(
+            avgpool_op, (input_node, kernel_size, stride), {}, meta, True
+        )
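
For reference, and not part of the patch: a minimal PyTorch sketch (illustrative shapes, assuming a recent torch install) that numerically checks the decomposition described in the docstring for a mean over dims (0, 2) of a (c, h, w) input: unsqueeze to 4D, avg_pool2d over w, sum over c, mul by 1/c, then reshape.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes; c, h, w are arbitrary.
c, h, w = 3, 5, 7
x = torch.randn(c, h, w)
expected = x.mean(dim=(0, 2))  # reference result, shape (h,)

# Mirror the pass: pad to rank 4, reduce w with avg_pool2d,
# reduce c with sum + mul(1/c), then reshape to the output shape.
y = x.view(1, c, h, w)                                   # unsqueeze to 4D
y = F.avg_pool2d(y, kernel_size=(1, w), stride=(1, 1))   # mean over w -> (1, c, h, 1)
y = y.sum(dim=1, keepdim=True)                           # sum over c -> (1, 1, h, 1)
y = y * (1.0 / c)                                        # divide by N for the summed dims
y = y.view(h)                                            # squeeze, since keepdim=False

assert torch.allclose(y, expected, atol=1e-6)
```

This only sanity-checks the arithmetic; the pass itself emits the equivalent edge-dialect ops through call_operator rather than eager torch calls.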