Skip to content

Commit 85d274a

Browse files
Arm backend: Size adjust conv2d pass improvements (#7646)
[Arm backend] Improve the documentation of the size adjust conv2d pass and remove duplicated code. Also add more tests to conv1d and conv2d that need to go through the pass.
1 parent 3ef100d commit 85d274a

File tree

3 files changed

+200
-53
lines changed

3 files changed

+200
-53
lines changed

backends/arm/_passes/size_adjust_conv2d_pass.py

Lines changed: 51 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,73 +1,74 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
# All rights reserved.
33
#
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

77
# pyre-unsafe
88

9-
from typing import cast, Optional
9+
from typing import cast
1010

1111
import torch.fx
12+
from executorch.backends.arm._passes.arm_pass_utils import create_node
1213
from executorch.exir.dialects._ops import ops as exir_ops
1314
from executorch.exir.pass_base import ExportPass, PassResult
14-
from torch._ops import OpOverload
1515

1616

1717
def conv_remainder(input_length, pad, dilation, weight, stride):
1818
"""
19-
Returns the size
19+
Returns the remainder of input_length; given the padding, dilation, stride,
20+
and kernel size.
2021
"""
2122
return (input_length + 2 * pad - dilation * (weight - 1) - 1) % stride
2223

2324

24-
def insert_q_dq_pair(
25-
graph: torch.fx.Graph,
26-
anchor: torch.fx.Node,
27-
q_params: tuple,
28-
):
29-
with graph.inserting_after(anchor):
30-
q = create_node(
31-
graph=graph,
32-
op_target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
33-
args=(), # We add the argument last
34-
)
35-
q.meta = anchor.meta
36-
37-
with graph.inserting_after(q):
38-
dq = create_node(
39-
graph=graph,
40-
op_target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
41-
args=(q,) + q_params,
42-
)
43-
dq.meta = q.meta
44-
45-
anchor.replace_all_uses_with(dq)
46-
# We add this last so the replace all uses above does not replace the quantized
47-
# node's first use
48-
q.args = (anchor,) + q_params
49-
return dq
50-
51-
52-
def create_node(
53-
graph: torch.fx.Graph,
54-
op_target: OpOverload,
55-
args: tuple = (),
56-
kwargs: Optional[dict] = None,
57-
):
58-
return graph.create_node(
59-
"call_function",
60-
op_target,
61-
args=args,
62-
kwargs=kwargs or {},
63-
)
64-
65-
6625
class SizeAdjustConv2DPass(ExportPass):
6726
"""
68-
Adjust the convolution input size to match perfectly with the
69-
weight size, padding, stride and dilation parameters.
70-
This is done by inserting a slice op to remove the uneven end of the input.
27+
Adjust the convolution input size to match the kernel size, padding, stride,
28+
and dilation parameters. Pytorch allows the input and kernel shape to not
29+
"match", in which case the remaining rows/columns are truncated. However,
30+
matching the size is a requirement in the TOSA specification. In case the
31+
input and kernel shape do not match, the following is done to meet the
32+
specification:
33+
34+
1) The padding is truncated (done in the node visitor)
35+
2) (if neccessary) The input is truncated (done in this pass)."
36+
37+
A simple example would be a 2x2 kernel (no padding, stride=2) and a 5x5
38+
input:
39+
40+
┌───┬───┬───┬───┬───┐ ┌───┬───┬───┬───┬───┐ ┌───┬───┬───┬───┬───┐
41+
│ X │ X │ │ │ │ │ │ │ X │ X │ │ │ │ │ │ │ - │
42+
├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤
43+
│ X │ X │ │ │ │ │ │ │ X │ X │ │ │ │ │ │ │ - │
44+
├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤
45+
│ │ │ │ │ │ -> │ │ │ │ │ │ -> │ X │ X │ │ │ │ ->
46+
├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤
47+
│ │ │ │ │ │ │ │ │ │ │ │ │ X │ X │ │ │ │
48+
├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤
49+
│ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │
50+
└───┴───┴───┴───┴───┘ └───┴───┴───┴───┴───┘ └───┴───┴───┴───┴───┘
51+
First pass second pass third pass
52+
53+
┌───┬───┬───┬───┬───┐ ┌───┬───┬───┬───┬───┐
54+
│ │ │ │ │ │ │ │ │ │ │ - │
55+
├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤
56+
│ │ │ │ │ │ │ │ │ │ │ - │
57+
├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤
58+
│ │ │ X │ X │ │ -> │ │ │ │ │ - │
59+
├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤
60+
│ │ │ X │ X │ │ │ │ │ │ │ - │
61+
├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤
62+
│ │ │ │ │ │ │ - │ - │ - │ - │ - │
63+
└───┴───┴───┴───┴───┘ └───┴───┴───┴───┴───┘
64+
Fourth pass Unvisited cells
65+
66+
Cells that are never visited are marked with `-` and are never considered
67+
when the kernel traverses over the input, hence they can be removed.
68+
69+
To match the shape of the kernel (and all parameters) with the input, a
70+
slice op is inserted to remove the remaining edges (rows and columns) of the
71+
input.
7172
"""
7273

7374
conv2d_op = exir_ops.edge.aten.convolution.default
@@ -109,9 +110,7 @@ def call(self, graph_module: torch.fx.GraphModule):
109110
with graph_module.graph.inserting_before(node):
110111
last_node = cast(torch.fx.Node, input_node)
111112
for args in slice_args:
112-
slice_node = graph.create_node(
113-
"call_function", self.slice_op, (last_node,) + args
114-
)
113+
slice_node = create_node(graph, self.slice_op, (last_node,) + args)
115114
last_node = slice_node
116115
conv_node.replace_input_with(cast(torch.fx.Node, input_node), last_node)
117116
modified_graph = True

backends/arm/test/ops/test_conv1d.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,47 @@ def forward(self, x):
180180
batches=1,
181181
)
182182

183+
conv1d_7_1x3x16_st2_pd1_dl2 = Conv1d(
184+
in_channels=3,
185+
out_channels=3,
186+
kernel_size=7,
187+
stride=2,
188+
padding=1,
189+
dilation=2,
190+
length=16,
191+
batches=1,
192+
)
193+
conv1d_7_1x3x15_st1_pd0_dl1 = Conv1d(
194+
in_channels=3,
195+
out_channels=3,
196+
kernel_size=7,
197+
stride=1,
198+
padding=0,
199+
dilation=1,
200+
length=15,
201+
batches=1,
202+
)
203+
conv1d_5_1x3x14_st5_pd0_dl1 = Conv1d(
204+
in_channels=3,
205+
out_channels=3,
206+
kernel_size=5,
207+
stride=5,
208+
padding=0,
209+
dilation=1,
210+
length=14,
211+
batches=1,
212+
)
213+
conv1d_5_1x3x9_st5_pd0_dl1 = Conv1d(
214+
in_channels=3,
215+
out_channels=3,
216+
kernel_size=5,
217+
stride=5,
218+
padding=0,
219+
dilation=1,
220+
length=9,
221+
batches=1,
222+
)
223+
183224
two_conv1d_nobias = Conv1d(
184225
nbr_conv=2,
185226
length=256,
@@ -214,6 +255,10 @@ def forward(self, x):
214255
("2_1x2x14_st2", conv1d_2_1x2x14_st2),
215256
("5_3x2x128_st1", conv1d_5_3x2x128_st1),
216257
("3_1x3x224_st2_pd1", conv1d_3_1x3x224_st2_pd1),
258+
("7_1x3x16_st2_pd1_dl2_needs_adjust_pass", conv1d_7_1x3x16_st2_pd1_dl2),
259+
("7_1x3x15_st1_pd0_dl1_needs_adjust_pass", conv1d_7_1x3x15_st1_pd0_dl1),
260+
("5_1x3x14_st5_pd0_dl1_needs_adjust_pass", conv1d_5_1x3x14_st5_pd0_dl1),
261+
("5_1x3x9_st5_pd0_dl1_needs_adjust_pass", conv1d_5_1x3x9_st5_pd0_dl1),
217262
("two_conv1d_nobias", two_conv1d_nobias),
218263
("two_conv1d", two_conv1d),
219264
]

backends/arm/test/ops/test_conv2d.py

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,101 @@ def forward(self, x):
201201
batches=1,
202202
)
203203

204+
conv2d_7x7_1x3x16x16_st2_pd1_dl2 = Conv2d(
205+
in_channels=3,
206+
out_channels=3,
207+
kernel_size=(7, 7),
208+
stride=2,
209+
padding=1,
210+
dilation=2,
211+
width=16,
212+
height=16,
213+
batches=1,
214+
)
215+
216+
conv2d_7x7_1x3x15x15_st1_pd0_dl1 = Conv2d(
217+
in_channels=3,
218+
out_channels=3,
219+
kernel_size=(7, 7),
220+
stride=1,
221+
padding=0,
222+
dilation=1,
223+
width=15,
224+
height=15,
225+
batches=1,
226+
)
227+
228+
conv2d_5x5_1x3x14x14_st5_pd0_dl1 = Conv2d(
229+
in_channels=3,
230+
out_channels=3,
231+
kernel_size=(5, 5),
232+
stride=5,
233+
padding=0,
234+
dilation=1,
235+
width=14,
236+
height=14,
237+
batches=1,
238+
)
239+
240+
conv2d_5x5_1x3x9x9_st5_pd0_dl1 = Conv2d(
241+
in_channels=3,
242+
out_channels=3,
243+
kernel_size=(5, 5),
244+
stride=5,
245+
padding=0,
246+
dilation=1,
247+
width=9,
248+
height=9,
249+
batches=1,
250+
)
251+
252+
conv2d_3x3_1x3x8x9_st3_pd0_dl1 = Conv2d(
253+
in_channels=3,
254+
out_channels=3,
255+
kernel_size=(3, 3),
256+
stride=3,
257+
padding=0,
258+
dilation=1,
259+
width=8,
260+
height=9,
261+
batches=1,
262+
)
263+
264+
conv2d_3x3_1x3x9x8_st3_pd0_dl1 = Conv2d(
265+
in_channels=3,
266+
out_channels=3,
267+
kernel_size=(3, 3),
268+
stride=3,
269+
padding=0,
270+
dilation=1,
271+
width=8,
272+
height=9,
273+
batches=1,
274+
)
275+
276+
conv2d_3x4_1x3x7x7_st3_pd0_dl1 = Conv2d(
277+
in_channels=3,
278+
out_channels=3,
279+
kernel_size=(3, 4),
280+
stride=3,
281+
padding=0,
282+
dilation=1,
283+
width=7,
284+
height=7,
285+
batches=1,
286+
)
287+
288+
conv2d_4x3_1x3x7x7_st3_pd0_dl1 = Conv2d(
289+
in_channels=3,
290+
out_channels=3,
291+
kernel_size=(4, 3),
292+
stride=3,
293+
padding=0,
294+
dilation=1,
295+
width=7,
296+
height=7,
297+
batches=1,
298+
)
204299

205300
two_conv2d_nobias = Conv2d(
206301
nbr_conv=2,
@@ -236,7 +331,15 @@ def forward(self, x):
236331
("3x3_1x3x12x12_st2_pd1", conv2d_3x3_1x3x12x12_st2_pd1),
237332
("1x1_1x2x128x128_st1", conv2d_1x1_1x2x128x128_st1),
238333
("2x2_1x1x14x13_st2_needs_adjust_pass", conv2d_2x2_1x1x14x13_st2),
239-
("conv2d_5x5_1x3x14x15_st3_pd1_needs_adjust_pass", conv2d_5x5_1x3x14x15_st3_pd1),
334+
("5x5_1x3x14x15_st3_pd1_needs_adjust_pass", conv2d_5x5_1x3x14x15_st3_pd1),
335+
("7x7_1x3x16x16_st2_pd1_dl2_needs_adjust_pass", conv2d_7x7_1x3x16x16_st2_pd1_dl2),
336+
("7x7_1x3x15x15_st1_pd0_dl1_needs_adjust_pass", conv2d_7x7_1x3x15x15_st1_pd0_dl1),
337+
("5x5_1x3x14x14_st5_pd0_dl1_needs_adjust_pass", conv2d_5x5_1x3x14x14_st5_pd0_dl1),
338+
("5x5_1x3x9x9_st5_pd0_dl1_needs_adjust_pass", conv2d_5x5_1x3x9x9_st5_pd0_dl1),
339+
("3x3_1x3x9x8_st3_pd0_dl1_needs_adjust_pass", conv2d_3x3_1x3x9x8_st3_pd0_dl1),
340+
("3x3_1x3x8x9_st3_pd0_dl1_needs_adjust_pass", conv2d_3x3_1x3x8x9_st3_pd0_dl1),
341+
("3x4_1x3x7x7_st3_pd0_dl1_needs_adjust_pass", conv2d_3x4_1x3x7x7_st3_pd0_dl1),
342+
("4x3_1x3x7x7_st3_pd0_dl1_needs_adjust_pass", conv2d_4x3_1x3x7x7_st3_pd0_dl1),
240343
("5x5_3x2x128x128_st1", conv2d_5x5_3x2x128x128_st1),
241344
("3x3_1x3x224x224_st2_pd1", conv2d_3x3_1x3x224x224_st2_pd1),
242345
("two_conv2d_nobias", two_conv2d_nobias),

0 commit comments

Comments
 (0)