
Commit 94ce147

Arm backend: Avoid not decomposing linears we reject (#15406)
If a linear is not quantized properly, we will reject it when partitioning. However, if we tell ExecuTorch to _not_ decompose an op, we are required to partition it. We thus need to figure out whether we will partition the linear in the ops_to_not_decompose filter function. Also turn off grad in the Arm tester to fix an error that popped up in the GRU model; since we only do inference, grad is never relevant. Signed-off-by: Erik Lundell <[email protected]>
1 parent ed91b6a commit 94ce147
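
For context on the contract this commit fixes: an ExecuTorch partitioner advertises ops it wants kept whole through ops_to_not_decompose, which returns an op list plus an optional per-node filter, and any node the filter accepts must later be partitioned. A minimal sketch of that shape (simplified signature, illustrative only, not the actual TOSA partitioner code):

from typing import Callable, List, Optional, Tuple

import torch


def ops_to_not_decompose() -> Tuple[
    List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]
]:
    # Ops we ask ExecuTorch to keep undecomposed.
    ops = [torch.ops.aten.linear.default]

    def filter_fn(node: torch.fx.Node) -> bool:
        # Must return True only for nodes the backend is certain to
        # partition; accepting a node here and then rejecting it at
        # partition time leaves an op that is neither decomposed nor
        # delegated, which is exactly the GRU failure described above.
        return True  # placeholder predicate

    return ops, filter_fn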

File tree: 3 files changed, +29 −15 lines

backends/arm/test/models/test_nn_modules.py

Lines changed: 0 additions & 2 deletions

@@ -137,8 +137,6 @@ def test_nn_Modules_FP(test_data):
     "test_data",
     test_parameters,
     xfails={
-        "GRUModule": "RuntimeError: Node aten_linear_default with op <EdgeOpOverload: aten.linear[...]> was not decomposed or delegated.",
-        "PReLUModule": "RuntimeError: mul(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.",
         "TransformerModule": "AssertionError: Output 0 does not match reference output.",
     },
 )
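
The two deleted entries mean GRUModule and PReLUModule are now expected to pass. For reference, the xfails convention maps a test id to a failure reason; a hypothetical sketch of how such a mapping could translate into pytest marks (illustrative only, not the actual Arm test infrastructure):

import pytest


def parametrize_with_xfails(test_parameters: dict, xfails: dict) -> list:
    # Cases whose id appears in xfails are marked expected-to-fail with
    # the recorded reason; all others run normally.
    return [
        pytest.param(
            value,
            id=name,
            marks=[pytest.mark.xfail(reason=xfails[name])] if name in xfails else [],
        )
        for name, value in test_parameters.items()
    ]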

backends/arm/test/tester/arm_tester.py

Lines changed: 1 addition & 0 deletions

@@ -266,6 +266,7 @@ def __init__(
             StageType.QUANTIZE,
             StageType.EXPORT,
         ]
+        self.original_module.requires_grad_(False)

         # Initial model needs to be set as a *possible* but not yet added Stage, therefore add None entry.
         self.stages[StageType.INITIAL_MODEL] = None
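
Since the tester only ever runs inference, freezing parameters up front keeps autograd out of the picture. A minimal standalone sketch of the effect of that one-line change (hypothetical module, not from the repo):

import torch

# Inference-only setup: parameters never need gradients, so freeze them
# before running the module, as the tester now does for original_module.
module = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.PReLU())
module.requires_grad_(False)

x = torch.randn(2, 4)
with torch.no_grad():  # additionally disable grad tracking for the call
    y = module(x)
assert not y.requires_grad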

backends/arm/tosa/partitioner.py

Lines changed: 28 additions & 13 deletions

@@ -338,6 +338,7 @@ def ops_to_not_decompose(
         ops_to_not_decompose_if_quant_op = [
             torch.ops.aten.hardsigmoid.default,
             torch.ops.aten.hardswish.default,
+            torch.ops.aten.linear.default,
         ]

         def filter_fn(node: torch.fx.Node) -> bool:
@@ -355,31 +356,45 @@ def filter_fn(node: torch.fx.Node) -> bool:
                 bool: True to keep the op intact; otherwise, False.

             """
-            dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default
-            q = torch.ops.quantized_decomposed.quantize_per_tensor.default
+            dq = (
+                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                torch.ops.quantized_decomposed.dequantize_per_channel.default,
+            )
+            q = (
+                torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                torch.ops.quantized_decomposed.quantize_per_channel.default,
+            )

             if node.target in ops_to_not_decompose_if_quant_op:
                 # Assume we should not decompose the operator (it is quantized)
-                should_not_decompose = True
+                correct_output_quant = True
+                correct_input_quant = True

                 input_nodes = node.all_input_nodes
-                ouput_nodes = node.users
+                output_nodes = node.users

                 for inp in input_nodes:
-                    if inp.target != dq:
-                        should_not_decompose = False
-
-                for out in ouput_nodes:
-                    if out.target != q:
-                        should_not_decompose = False
-
-                return should_not_decompose
+                    if inp.target not in dq:
+                        correct_input_quant = False
+
+                for out in output_nodes:
+                    if out.target not in q:
+                        correct_output_quant = False
+                # In some cases, a linear is quantized together with its activation.
+                if (
+                    node.target == torch.ops.aten.linear.default
+                    and len(output_nodes) == 1
+                    and list(output_nodes)[0].target
+                    in (torch.ops.aten.relu.default, torch.ops.aten.hardtanh.default)
+                ):
+                    correct_output_quant = True
+
+                return correct_input_quant and correct_output_quant

             # By default, do not decompose the operator
             return True

         ops_to_not_decompose = [
-            torch.ops.aten.linear.default,
             torch.ops.aten.eye.default,
             torch.ops.aten.linspace.default,
             torch.ops.aten.logit.default,
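
Concretely, the new predicate accepts a linear only when every input comes from a dequantize node and every user is a quantize node, with a carve-out for a single fused activation. A standalone restatement that can be applied to any torch.fx node (assumes an environment where the quantized_decomposed ops are registered, e.g. with ExecuTorch installed; this mirrors the diff rather than reproducing the exact partitioner code):

import torch

DQ_OPS = (
    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
    torch.ops.quantized_decomposed.dequantize_per_channel.default,
)
Q_OPS = (
    torch.ops.quantized_decomposed.quantize_per_tensor.default,
    torch.ops.quantized_decomposed.quantize_per_channel.default,
)
# Activations that may be quantized together with a preceding linear.
FUSED_ACTIVATIONS = (torch.ops.aten.relu.default, torch.ops.aten.hardtanh.default)


def is_partitionable_linear(node: torch.fx.Node) -> bool:
    # Restatement of filter_fn's linear branch: inputs must be dq nodes,
    # and users must be q nodes unless the sole user is a fusable activation.
    if node.target != torch.ops.aten.linear.default:
        return False
    inputs_ok = all(inp.target in DQ_OPS for inp in node.all_input_nodes)
    users = list(node.users)
    outputs_ok = all(u.target in Q_OPS for u in users) or (
        len(users) == 1 and users[0].target in FUSED_ACTIVATIONS
    )
    return inputs_ok and outputs_ok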
