Skip to content

Commit 4860984

Browse files
Arm backend: Make sure DW-conv weights are reshaped once (pytorch#16071)
Make sure RewriteConv2DPass only reshapes shared weights once. ### Test plan Functionality is tested in backends/arm/test/misc/test_dw_convs_with_shared_weights.py. Signed-off-by: Oscar Andersson <[email protected]>
1 parent d536d18 commit 4860984

File tree

2 files changed

+68
-2
lines changed

2 files changed

+68
-2
lines changed

backends/arm/_passes/rewrite_conv2d_pass.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def _is_depthwise_conv2d(self, node: torch.fx.Node) -> bool:
9090
return False
9191
groups = node.args[-1]
9292
in_channels = get_first_fake_tensor(node.all_input_nodes[0]).shape[1]
93-
out_channels = get_first_fake_tensor(node.all_input_nodes[1]).shape[0]
93+
out_channels = get_first_fake_tensor(node).shape[1]
9494
return (in_channels == groups) and (out_channels % in_channels) == 0
9595

9696
def _reshape_weights(self, weight_node: torch.fx.Node, in_channels: int) -> None:
@@ -103,6 +103,7 @@ def _reshape_weights(self, weight_node: torch.fx.Node, in_channels: int) -> None
103103
raise RuntimeError(
104104
f"Weight node {weight_node.name} is not a parameter or buffer"
105105
)
106+
106107
reshaped_weight_tensor = (
107108
weight_tensor.permute(HWCM_ORDER)
108109
.reshape(
@@ -118,14 +119,19 @@ def _reshape_weights(self, weight_node: torch.fx.Node, in_channels: int) -> None
118119
param_name = self.exported_program.graph_signature.inputs_to_buffers[
119120
weight_node.name
120121
]
122+
reshaped_weight_tensor = torch.nn.Buffer(reshaped_weight_tensor)
121123
elif is_param(self.exported_program, weight_node):
122124
param_name = self.exported_program.graph_signature.inputs_to_parameters[
123125
weight_node.name
124126
]
127+
reshaped_weight_tensor = torch.nn.Parameter(
128+
reshaped_weight_tensor, requires_grad=False
129+
)
125130
else:
126131
raise RuntimeError(
127132
f"Weight node {weight_node.name} is neither a parameter nor a buffer"
128133
)
134+
129135
self.exported_program.state_dict[param_name] = reshaped_weight_tensor
130136
weight_node.meta["val"] = weight_node.meta["val"].reshape(
131137
weight_tensor.shape[2],
@@ -243,7 +249,9 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
243249

244250
if self._is_depthwise_conv2d(node):
245251
target_op = exir_ops.backend.tosa.DEPTHWISE_CONV2D.default
246-
self._reshape_weights(weight, input_fake_tensor.shape[1])
252+
# If there are any TOSA.DEPTHWISE_CONV2D nodes using the weights, we've already reshaped them.
253+
if all(user.target != target_op for user in weight.users):
254+
self._reshape_weights(weight, input_fake_tensor.shape[1])
247255
weight_fake_tensor = get_first_fake_tensor(weight)
248256
else:
249257
target_op = exir_ops.backend.tosa.CONV2D.default
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Any, Tuple
7+
8+
import torch
9+
from executorch.backends.arm._passes.rewrite_conv2d_pass import RewriteConv2dPass
10+
from executorch.backends.arm.test.tester.test_pipeline import (
11+
PassPipeline,
12+
TosaPipelineFP,
13+
TosaPipelineINT,
14+
)
15+
16+
input_t = Tuple[torch.Tensor]
17+
18+
19+
class DWConvsModule(torch.nn.Module):
20+
def __init__(self, *args: Any, **kwargs: Any) -> None:
21+
super().__init__(*args, **kwargs)
22+
conv = torch.nn.Conv2d(6, 6, kernel_size=(2, 2), groups=6)
23+
relu = torch.nn.ReLU()
24+
self.sequential = torch.nn.ModuleList([conv, relu, conv])
25+
26+
def forward(self, x) -> torch.Tensor:
27+
for m in self.sequential:
28+
x = m(x)
29+
return x
30+
31+
def get_inputs(self) -> input_t:
32+
return (torch.randn(1, 6, 24, 24),)
33+
34+
35+
def test_convs_tosa_fp():
36+
module = DWConvsModule()
37+
pipeline = TosaPipelineFP[input_t](
38+
module, module.get_inputs(), aten_op=[], exir_op=[]
39+
)
40+
pipeline.run()
41+
42+
43+
def test_convs_tosa_int():
44+
module = DWConvsModule()
45+
pipeline = TosaPipelineINT[input_t](
46+
module, module.get_inputs(), aten_op=[], exir_op=[]
47+
)
48+
pipeline.run()
49+
50+
51+
def test_rewrite_conv_pass():
52+
module = DWConvsModule()
53+
pipeline = PassPipeline(
54+
module, module.get_inputs(), passes_with_exported_program=[RewriteConv2dPass]
55+
)
56+
# We can't run TOSA backend dialect operators in eager mode
57+
pipeline.pop_stage("run_method_and_compare_outputs")
58+
pipeline.run()

0 commit comments

Comments
 (0)