Skip to content

Commit 8178183

Browse files
AdrianLundellYIWENX14
authored andcommitted
Add pass changing (un-)squeeze ops to view in TOSA lowering (#7784)
Since squeeze ops are special cases of the view op it is enough to handle only view ops, removing the need for the squeeze/unsqueeze node visitors.
1 parent 32c343a commit 8178183

File tree

6 files changed

+37
-145
lines changed

6 files changed

+37
-145
lines changed

backends/arm/_passes/annotate_channels_last_dim_order_pass.py

Lines changed: 2 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from executorch.backends.arm._passes.arm_pass_utils import (
1313
create_node,
1414
get_first_fake_tensor,
15-
get_node_arg,
1615
insert_q_dq_pair,
1716
)
1817
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
@@ -26,9 +25,8 @@
2625
# when lowering to TOSA, e.g. a passthrough_to_tosa._transpose will not affect
2726
# the edge IR graph but will be lowered to a TOSA-TRANSPOSE.
2827
lib = Library("passthrough_to_tosa", "DEF")
29-
# For operators that change the rank of the input, such as unsqueeze and squeeze, we may need
30-
# to switch dim_order before the opertation. Changing tosa_dim_order is not sufficient
31-
# as we also need transpose the data into the correct data format.
28+
# For certain operators we need the data in a specific data format. Changing tosa_dim_order
29+
# is not sufficient as we also need transpose the data.
3230
# By utilizing an edge IR passthrough operator we can keep the edge program in
3331
# channels-first/contiguous and get the desired behavior in the TOSA lowering.
3432
lib.define("_transpose(Tensor self, int[] dim_order) -> Tensor")
@@ -153,27 +151,6 @@ def insert_output_transpose(node, graph_module):
153151
q_params = node.args[0].args[1:]
154152
insert_q_dq_pair(graph_module.graph, node, q_params)
155153

156-
@staticmethod
157-
def _insert_squeeze_transpose(
158-
input_shape, output_shape, node, input_node, graph_module
159-
):
160-
nhwc_to_nhwc = len(input_shape) == 4 and len(output_shape) <= 3
161-
162-
if nhwc_to_nhwc and AnnotateChannelsLastDimOrder.memory_format_differs(
163-
input_shape
164-
):
165-
AnnotateChannelsLastDimOrder.insert_input_transpose(
166-
node, input_node, graph_module
167-
)
168-
169-
@staticmethod
170-
def _insert_unsqueeze_transpose(input_shape, output_shape, node, graph_module):
171-
nchw_to_nhwc = len(input_shape) == 3 and len(output_shape) == 4
172-
if nchw_to_nhwc and AnnotateChannelsLastDimOrder.memory_format_differs(
173-
output_shape
174-
):
175-
AnnotateChannelsLastDimOrder.insert_output_transpose(node, graph_module)
176-
177154
@staticmethod
178155
def _insert_view_transpose(
179156
input_shape, output_shape, node, input_node, graph_module
@@ -199,8 +176,6 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
199176
"""
200177
Transposes are needed for operators transforming the input to a different rank, as 4D-tensors are assumed to be in NHWC-format, whereas all other are in NCHW format.
201178
This is relevant for the following cases:
202-
- squeeze: 4D -> <4D
203-
- unsqueeze: 3D -> 4D
204179
- view: <4D -> 4D
205180
- view: 4D -> <4D
206181
Additionally, a 4D->4D view operation acting on the channel dimension currently needs to be performed in NCHW format, leadning to one extra input and output transpose for this case.
@@ -214,27 +189,6 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
214189
if node.op != "call_function":
215190
continue
216191

217-
if node.target == exir_ops.edge.aten.squeeze_copy.dims:
218-
input_node = node.args[0]
219-
input_shape = input_node.meta["val"].shape
220-
output_shape = node.meta["val"].shape
221-
222-
self._insert_squeeze_transpose(
223-
input_shape, output_shape, node, input_node, graph_module
224-
)
225-
226-
elif node.target == exir_ops.edge.aten.unsqueeze_copy.default:
227-
input_node = get_node_arg(node.args, 0, default_value=False)
228-
if input_node:
229-
input_shape = input_node.meta["val"].shape
230-
else:
231-
input_shape = ()
232-
output_shape = node.meta["val"].shape
233-
234-
self._insert_unsqueeze_transpose(
235-
input_shape, output_shape, node, graph_module
236-
)
237-
238192
elif node.target == exir_ops.edge.aten.view_copy.default:
239193
input_node = node.args[0]
240194
input_shape = input_node.meta["val"].shape

backends/arm/_passes/arm_pass_manager.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
from executorch.backends.arm._passes.convert_split_to_slice import (
2222
ConvertSplitToSlicePass,
2323
)
24+
from executorch.backends.arm._passes.convert_squeezes_to_view import (
25+
ConvertSqueezesToViewPass,
26+
)
2427
from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass
2528
from executorch.backends.arm._passes.decompose_layernorm_pass import (
2629
DecomposeLayerNormPass,
@@ -100,6 +103,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
100103
self.add_pass(KeepDimsFalseToSqueezePass())
101104
self.add_pass(Conv1dUnsqueezePass(exported_program))
102105
self.add_pass(DecomposeSelectPass())
106+
self.add_pass(ConvertSqueezesToViewPass())
103107

104108
self.add_pass(AnnotateChannelsLastDimOrder())
105109

@@ -135,6 +139,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
135139
self.add_pass(KeepDimsFalseToSqueezePass())
136140
self.add_pass(Conv1dUnsqueezePass(exported_program))
137141
self.add_pass(DecomposeSelectPass())
142+
self.add_pass(ConvertSqueezesToViewPass())
138143

139144
self.add_pass(AnnotateChannelsLastDimOrder())
140145

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
from executorch.exir.dialects._ops import ops as exir_ops
10+
from executorch.exir.pass_base import ExportPass
11+
12+
13+
class ConvertSqueezesToViewPass(ExportPass):
14+
"""
15+
Replaces squeeze/unsqueeze operators with view. These are simply special cases of the view op, so removing them gives us less cases to handle in the node visitiors.
16+
"""
17+
18+
def call_operator(self, op, args, kwargs, meta):
19+
if op not in [
20+
exir_ops.edge.aten.squeeze_copy.dims,
21+
exir_ops.edge.aten.unsqueeze_copy.default,
22+
]:
23+
return super().call_operator(op, args, kwargs, meta)
24+
25+
x = args[0]
26+
shape = meta["val"].size()
27+
view_args = (x, list(shape))
28+
return super().call_operator(
29+
exir_ops.edge.aten.view_copy.default, view_args, kwargs, meta
30+
)

backends/arm/operators/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,12 @@
3030
op_rsqrt,
3131
op_sigmoid,
3232
op_slice,
33-
op_squeeze,
3433
op_sub,
3534
op_sum,
3635
op_table,
3736
op_tanh,
3837
op_to_copy,
3938
op_transpose,
40-
op_unsqueeze,
4139
op_upsample_nearest2d,
4240
op_view,
4341
)

backends/arm/operators/op_squeeze.py

Lines changed: 0 additions & 43 deletions
This file was deleted.

backends/arm/operators/op_unsqueeze.py

Lines changed: 0 additions & 52 deletions
This file was deleted.

0 commit comments

Comments
 (0)