Skip to content

Commit 2198015

Browse files
committed
Add pass changing (un-)squeeze ops to view in TOSA lowering
Since squeeze ops are special cases of the view op it is enough to handle only view ops, removing the need for the squeeze/unsqueeze node visitors. Change-Id: Ib00d05692fa3a5e48c21f52eb2234c4495ac169e
1 parent f4e77c7 commit 2198015

File tree

6 files changed

+37
-145
lines changed

6 files changed

+37
-145
lines changed

backends/arm/_passes/annotate_channels_last_dim_order_pass.py

Lines changed: 2 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from executorch.backends.arm._passes.arm_pass_utils import (
1313
create_node,
1414
get_first_fake_tensor,
15-
get_node_arg,
1615
insert_q_dq_pair,
1716
)
1817
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
@@ -26,9 +25,8 @@
2625
# when lowering to TOSA, e.g. a passthrough_to_tosa._transpose will not affect
2726
# the edge IR graph but will be lowered to a TOSA-TRANSPOSE.
2827
lib = Library("passthrough_to_tosa", "DEF")
29-
# For operators that change the rank of the input, such as unsqueeze and squeeze, we may need
30-
# to switch dim_order before the opertation. Changing tosa_dim_order is not sufficient
31-
# as we also need transpose the data into the correct data format.
28+
# For certain operators we need the data in a specific data format. Changing tosa_dim_order
29+
# is not sufficient as we also need transpose the data.
3230
# By utilizing an edge IR passthrough operator we can keep the edge program in
3331
# channels-first/contiguous and get the desired behavior in the TOSA lowering.
3432
lib.define("_transpose(Tensor self, int[] dim_order) -> Tensor")
@@ -153,27 +151,6 @@ def insert_output_transpose(node, graph_module):
153151
q_params = node.args[0].args[1:]
154152
insert_q_dq_pair(graph_module.graph, node, q_params)
155153

156-
@staticmethod
157-
def _insert_squeeze_transpose(
158-
input_shape, output_shape, node, input_node, graph_module
159-
):
160-
nhwc_to_nhwc = len(input_shape) == 4 and len(output_shape) <= 3
161-
162-
if nhwc_to_nhwc and AnnotateChannelsLastDimOrder.memory_format_differs(
163-
input_shape
164-
):
165-
AnnotateChannelsLastDimOrder.insert_input_transpose(
166-
node, input_node, graph_module
167-
)
168-
169-
@staticmethod
170-
def _insert_unsqueeze_transpose(input_shape, output_shape, node, graph_module):
171-
nchw_to_nhwc = len(input_shape) == 3 and len(output_shape) == 4
172-
if nchw_to_nhwc and AnnotateChannelsLastDimOrder.memory_format_differs(
173-
output_shape
174-
):
175-
AnnotateChannelsLastDimOrder.insert_output_transpose(node, graph_module)
176-
177154
@staticmethod
178155
def _insert_view_transpose(
179156
input_shape, output_shape, node, input_node, graph_module
@@ -199,8 +176,6 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
199176
"""
200177
Transposes are needed for operators transforming the input to a different rank, as 4D-tensors are assumed to be in NHWC-format, whereas all other are in NCHW format.
201178
This is relevant for the following cases:
202-
- squeeze: 4D -> <4D
203-
- unsqueeze: 3D -> 4D
204179
- view: <4D -> 4D
205180
- view: 4D -> <4D
206181
Additionally, a 4D->4D view operation acting on the channel dimension currently needs to be performed in NCHW format, leadning to one extra input and output transpose for this case.
@@ -214,27 +189,6 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
214189
if node.op != "call_function":
215190
continue
216191

217-
if node.target == exir_ops.edge.aten.squeeze_copy.dims:
218-
input_node = node.args[0]
219-
input_shape = input_node.meta["val"].shape
220-
output_shape = node.meta["val"].shape
221-
222-
self._insert_squeeze_transpose(
223-
input_shape, output_shape, node, input_node, graph_module
224-
)
225-
226-
elif node.target == exir_ops.edge.aten.unsqueeze_copy.default:
227-
input_node = get_node_arg(node.args, 0, default_value=False)
228-
if input_node:
229-
input_shape = input_node.meta["val"].shape
230-
else:
231-
input_shape = ()
232-
output_shape = node.meta["val"].shape
233-
234-
self._insert_unsqueeze_transpose(
235-
input_shape, output_shape, node, graph_module
236-
)
237-
238192
elif node.target == exir_ops.edge.aten.view_copy.default:
239193
input_node = node.args[0]
240194
input_shape = input_node.meta["val"].shape

backends/arm/_passes/arm_pass_manager.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
from executorch.backends.arm._passes.convert_split_to_slice import (
2222
ConvertSplitToSlicePass,
2323
)
24+
from executorch.backends.arm._passes.convert_squeezes_to_view import (
25+
ConvertSqueezesToViewPass,
26+
)
2427
from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass
2528
from executorch.backends.arm._passes.decompose_layernorm_pass import (
2629
DecomposeLayerNormPass,
@@ -100,6 +103,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
100103
self.add_pass(KeepDimsFalseToSqueezePass())
101104
self.add_pass(Conv1dUnsqueezePass(exported_program))
102105
self.add_pass(DecomposeSelectPass())
106+
self.add_pass(ConvertSqueezesToViewPass())
103107

104108
self.add_pass(AnnotateChannelsLastDimOrder())
105109

@@ -135,6 +139,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
135139
self.add_pass(KeepDimsFalseToSqueezePass())
136140
self.add_pass(Conv1dUnsqueezePass(exported_program))
137141
self.add_pass(DecomposeSelectPass())
142+
self.add_pass(ConvertSqueezesToViewPass())
138143

139144
self.add_pass(AnnotateChannelsLastDimOrder())
140145

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
from executorch.exir.dialects._ops import ops as exir_ops
10+
from executorch.exir.pass_base import ExportPass
11+
12+
13+
class ConvertSqueezesToViewPass(ExportPass):
14+
"""
15+
Replaces squeeze/unsqueeze operators with view. These are simply special cases of the view op, so removing them gives us less cases to handle in the node visitiors.
16+
"""
17+
18+
def call_operator(self, op, args, kwargs, meta):
19+
if op not in [
20+
exir_ops.edge.aten.squeeze_copy.dims,
21+
exir_ops.edge.aten.unsqueeze_copy.default,
22+
]:
23+
return super().call_operator(op, args, kwargs, meta)
24+
25+
x = args[0]
26+
shape = meta["val"].size()
27+
view_args = (x, list(shape))
28+
return super().call_operator(
29+
exir_ops.edge.aten.view_copy.default, view_args, kwargs, meta
30+
)

backends/arm/operators/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,12 @@
3030
op_rsqrt,
3131
op_sigmoid,
3232
op_slice,
33-
op_squeeze,
3433
op_sub,
3534
op_sum,
3635
op_table,
3736
op_tanh,
3837
op_to_copy,
3938
op_transpose,
40-
op_unsqueeze,
4139
op_upsample_nearest2d,
4240
op_view,
4341
)

backends/arm/operators/op_squeeze.py

Lines changed: 0 additions & 43 deletions
This file was deleted.

backends/arm/operators/op_unsqueeze.py

Lines changed: 0 additions & 52 deletions
This file was deleted.

0 commit comments

Comments
 (0)