Skip to content

Commit 756d5f8

Browse files
Arm backend: Fix bug in decompose_linear_pass (#13725)
The final reshape node of linear decomposition got incorrect quantization parameters which lead to incorrect outputs when that reshape could be fused at a later stage in the pass pipeline. Signed-off-by: Oscar Andersson <[email protected]>
1 parent b4e1145 commit 756d5f8

File tree

2 files changed

+55
-1
lines changed

2 files changed

+55
-1
lines changed

backends/arm/_passes/decompose_linear_pass.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,13 @@ def call(self, graph_module):
9090
kwargs={},
9191
from_node=node,
9292
)
93+
# Quantization parameters are inherited from original linear node, but
94+
# output reshape should use the linear node's output qparams for both input
95+
# and output.
96+
if "input_qparams" in output.meta:
97+
output.meta["input_qparams"] = output.meta.get(
98+
"output_qparams", None
99+
)
93100

94101
node.replace_all_uses_with(output)
95102
graph_module.graph.erase_node(node)

backends/arm/test/passes/test_fuse_constant_ops_pass.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@
1212
FuseConstantArgsPass,
1313
)
1414
from executorch.backends.arm.test import common
15-
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
15+
from executorch.backends.arm.test.tester.test_pipeline import (
16+
PassPipeline,
17+
TosaPipelineFP,
18+
TosaPipelineINT,
19+
)
1620

1721
input_t = Tuple[torch.Tensor] # Input x
1822
input_t2 = Tuple[torch.Tensor, torch.Tensor]
@@ -103,6 +107,22 @@ def forward(self, a, b):
103107
return torch.cat((a, b), dim=0)
104108

105109

110+
class LinearConst(torch.nn.Module):
111+
"""A linear layer that can be computed AOT"""
112+
113+
def __init__(self, in_out_features: int = 3, bias: bool = True):
114+
super().__init__()
115+
self.linear = torch.nn.Linear(in_out_features, in_out_features, bias=bias)
116+
self.example_input = torch.rand(in_out_features, in_out_features)
117+
118+
def forward(self, x: torch.Tensor):
119+
y = torch.full_like(x, 1.0)
120+
return self.linear(y) + x
121+
122+
def get_example_input(self):
123+
return self.example_input
124+
125+
106126
modules = {
107127
"fuse_parameter": FuseParameter(),
108128
"fuse_buffer": FuseBuffer(),
@@ -152,3 +172,30 @@ def test_fuse_const_ops_tosa_BI_cat(module: torch.nn.Module):
152172
passes_with_exported_program=[ComputeConstantOpsAOT, FuseConstantArgsPass],
153173
)
154174
pipeline.run()
175+
176+
177+
def test_linear_const_tosa_FP():
178+
model = LinearConst()
179+
example_input = model.get_example_input()
180+
pipeline = TosaPipelineFP[input_t](
181+
model,
182+
(example_input,),
183+
aten_op=[],
184+
exir_op=[],
185+
use_to_edge_transform_and_lower=True,
186+
)
187+
pipeline.run()
188+
189+
190+
def test_linear_const_tosa_INT():
191+
model = LinearConst()
192+
example_input = model.get_example_input()
193+
pipeline = TosaPipelineINT[input_t](
194+
model,
195+
(example_input,),
196+
aten_op=[],
197+
exir_op=[],
198+
per_channel_quantization=False,
199+
use_to_edge_transform_and_lower=True,
200+
)
201+
pipeline.run()

0 commit comments

Comments
 (0)