
Commit 50f6b0c

Authored by Martin Lindström (martinlsm) and co-author
Arm backend: Replace squeeze with view in DecomposeAnyPass (#15733)
Rename ConvertAnyDefaultDimDimsPass to DecomposeAnyPass because, in practice, it does the same thing for the any op as DecomposeSumPass does for sum.

When keepdim=False, create a single view_copy to the reduced shape instead of inserting squeeze_copy. This removes the ordering dependency on ConvertSqueezesToViewPass.

cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai

Signed-off-by: Martin Lindström <[email protected]>
Co-authored-by: Martin Lindström <[email protected]>
Parent: caeb9ff

4 files changed (+25, -25 lines)
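
As a rough illustration of the decomposition described in the commit message, the following is a minimal standalone sketch in plain PyTorch (decomposed_any is a hypothetical helper, not part of the pass): a multi-dimensional any is unrolled into a chain of single-dimension any reductions with keepdim=True, and when keepdim=False was requested the result is reshaped once to the reduced shape instead of squeezing each reduced dimension.

    import torch

    def decomposed_any(x: torch.Tensor, dims: list, keepdim: bool) -> torch.Tensor:
        # Unroll the multi-dim reduction into a chain of single-dim reductions.
        # keepdim=True keeps the remaining dim indices valid throughout the chain.
        out = x
        for d in dims:
            out = torch.any(out, dim=d, keepdim=True)
        if not keepdim:
            # A single reshape to the reduced shape replaces the per-dim squeeze.
            reduced_shape = [s for i, s in enumerate(x.shape) if i not in dims]
            out = out.reshape(reduced_shape)
        return out

    x = torch.rand(2, 3, 4) > 0.5
    reference = torch.any(torch.any(x, dim=2), dim=1)  # reduced shape: [2]
    assert torch.equal(decomposed_any(x, [1, 2], keepdim=False), reference)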

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,6 @@
 from .cast_int64_pass import CastInt64BuffersToInt32Pass # noqa
 from .cast_to_int32_pass import CastToInt32Pass # noqa
 from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass # noqa
-from .convert_any_default_dim_dims_pass import ConvertAnyDefaultDimDimsPass # noqa
 from .convert_elu_params import ConvertELUParamsPass # noqa
 from .convert_expand_copy_to_repeat import ConvertExpandCopyToRepeatPass # noqa
 from .convert_full_like_to_full_pass import ConvertFullLikeToFullPass # noqa
@@ -31,6 +30,7 @@
 from .decompose_adaptive_avg_pool2d_pass import DecomposeAdaptiveAvgPool2dPass # noqa
 from .decompose_add_sub_alpha_pass import DecomposeAddSubAlphaPass # noqa
 from .decompose_addmm_pass import DecomposeAddmmPass # noqa
+from .decompose_any_pass import DecomposeAnyPass # noqa
 from .decompose_asin_and_acos_pass import DecomposeAsinAndAcosPass # noqa
 from .decompose_asinh_pass import DecomposeAsinhPass # noqa
 from .decompose_atan_pass import DecomposeAtanPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 2 deletions
@@ -18,7 +18,6 @@
     CastToInt32Pass,
     ComputeConstantOpsAOT,
     Conv1dUnsqueezePass,
-    ConvertAnyDefaultDimDimsPass,
     ConvertELUParamsPass,
     ConvertExpandCopyToRepeatPass,
     ConvertFullLikeToFullPass,
@@ -35,6 +34,7 @@
     DecomposeAdaptiveAvgPool2dPass,
     DecomposeAddmmPass,
     DecomposeAddSubAlphaPass,
+    DecomposeAnyPass,
     DecomposeAsinAndAcosPass,
     DecomposeAsinhPass,
     DecomposeAtanhPass,
@@ -241,7 +241,7 @@ def _tosa_pipeline(
         self.add_pass(DecomposeDivPass())
         self.add_pass(DecomposeSoftmaxPass())
         self.add_pass(ConvertMinMaxPass())
-        self.add_pass(ConvertAnyDefaultDimDimsPass())
+        self.add_pass(DecomposeAnyPass())
         self.add_pass(DecomposeAdaptiveAvgPool2dPass())
         self.add_pass(DecomposeAvgPool2d())
         self.add_pass(

backends/arm/_passes/convert_any_default_dim_dims_pass.py renamed to backends/arm/_passes/decompose_any_pass.py

Lines changed: 21 additions & 21 deletions
@@ -7,9 +7,7 @@

 import torch
 from executorch.backends.arm._passes import ArmPass
-from executorch.backends.arm._passes.convert_squeezes_to_view import (
-    ConvertSqueezesToViewPass,
-)
+from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
 from executorch.exir.dialects._ops import ( # type: ignore[import-not-found]
     ops as exir_ops,
 )
@@ -19,38 +17,40 @@
 )


-class ConvertAnyDefaultDimDimsPass(ArmPass):
+class DecomposeAnyPass(ArmPass):
     """
-    Converts any.default, any.dim and any.dims to a sequence of any.dim by unrolling multi-dimensional reduction.
-    Please refer to KeepDimsFalseToSqueezePass for an explanation of this coversion.
+    Converts any.default, any.dim and any.dims to a sequence of any.dim by
+    unrolling multi-dimensional reductions with keepdim=True. If keepdim=False
+    was requested, the final shape adjustment is implemented with a
+    view_copy.default to the reduced shape.

     Example 1
         Original:
-            any() # x.shape: [dim1, dim2, ..., dimn]
+            any.dim() # x.shape: [dim1, dim2, ..., dimn]
         After pass:
             any.dim(dim1, keepdim = True)
             any.dim(dim2, keepdim = True)
             ...
             any.dim(dimn, keepdim = True)
-            squeeze(dim = [dim1, dim2, ...., dimn])
+            view_copy(shape = squeezed_shape)

     Example 2
         Original:
             any.dim(dim1, keepdim = False)
         After pass:
             any.dim(dim1, keepdim = True)
-            squeeze(dim = [dim1])
+            view_copy(shape = squeezed_shape)

     Example 3
         Original:
             any.dims([dim1, dim2], keepdim = False)
         After pass:
             any.dim(dim1, keepdim = True)
             any.dim(dim2, keepdim = True)
-            squeeze(dim = [dim1, dim2])
+            view_copy(shape = squeezed_shape)
     """

-    _passes_required_after: Set[Type[ExportPass]] = {ConvertSqueezesToViewPass}
+    _passes_required_after: Set[Type[ExportPass]] = set()

     def call(self, graph_module: torch.fx.GraphModule):
         modified = False
@@ -67,40 +67,40 @@ def call(self, graph_module: torch.fx.GraphModule):
             if len(node.args) == 1:
                 # any.default(input)
                 input_node = (node.args)[0]
-                dims = range(len(input_node.meta["val"].shape))
+                dims_to_reduce = range(len(input_node.meta["val"].shape))
                 keepdim = False
             elif len(node.args) == 2:
                 # any.dim/dims(input, dims=dims)
-                input_node, dims = node.args
+                input_node, dims_to_reduce = node.args
                 keepdim = False
             elif len(node.args) == 3:
                 # any.dim/dims(input, dims=dims, keepdim=keepdim)
-                input_node, dims, keepdim = node.args
+                input_node, dims_to_reduce, keepdim = node.args
             else:
                 raise RuntimeError(
                     f"Unexpected arg size {len(node.args)} in {node.name}"
                 )
             try:
-                iter(dims)
+                iter(dims_to_reduce)
             except:
-                dims = [dims] # type: ignore[assignment]
+                dims_to_reduce = [dims_to_reduce] # type: ignore[assignment]
             else:
-                dims = list(dims) # type: ignore[assignment]
+                dims_to_reduce = list(dims_to_reduce) # type: ignore[assignment]

             # Unroll multi-dimensional reduction and keep-dims arg
             with graph_module.graph.inserting_before(node):
-                for dim in dims:
+                for dim in dims_to_reduce:
                     args = (input_node, dim, True)
                     input_node = graph_module.graph.create_node(
                         "call_function", exir_ops.edge.aten.any.dim, args, node.kwargs
                     )

                 if not keepdim:
-                    args = (input_node, dims) # type: ignore[assignment]
+                    output_shape = list(get_first_fake_tensor(node).shape)
                     input_node = graph_module.graph.create_node(
                         "call_function",
-                        exir_ops.edge.aten.squeeze_copy.dims,
-                        args,
+                        exir_ops.edge.aten.view_copy.default,
+                        (input_node, output_shape),
                     )

             node.replace_all_uses_with(input_node)
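
For reference, the try/iter/except logic in the hunk above normalizes the dims argument across the three overloads (no dims for any.default, a single int for any.dim, an iterable for any.dims). A minimal standalone sketch of that normalization, using a hypothetical helper name:

    def normalize_dims(rank: int, dims=None) -> list:
        # any.default carries no dims argument: reduce over every dimension.
        if dims is None:
            return list(range(rank))
        try:
            # any.dims passes an iterable of dims.
            return list(iter(dims))
        except TypeError:
            # any.dim passes a single int.
            return [dims]

    assert normalize_dims(3) == [0, 1, 2]
    assert normalize_dims(3, [0, 2]) == [0, 2]
    assert normalize_dims(3, 1) == [1]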

backends/arm/operators/op_any.py

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ def define_node(
         ) # process the negative index
         keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False)
         if not keep_dim:
-            raise ValueError("This case should be handled by ConvertAnyDimDimsPass")
+            raise ValueError("This case should be handled by DecomposeAnyPass")

         attr = ts.TosaSerializerAttribute()
         attr.ReduceAnyAttribute(inputs[0].dim_order.index(dim))
