Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions backends/arm/test/models/test_nn_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,6 @@ def test_nn_Modules_FP(test_data):
"test_data",
test_parameters,
xfails={
"GRUModule": "RuntimeError: Node aten_linear_default with op <EdgeOpOverload: aten.linear[...]> was not decomposed or delegated.",
"PReLUModule": "RuntimeError: mul(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.",
"TransformerModule": "AssertionError: Output 0 does not match reference output.",
},
)
Expand Down
1 change: 1 addition & 0 deletions backends/arm/test/tester/arm_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ def __init__(
StageType.QUANTIZE,
StageType.EXPORT,
]
self.original_module.requires_grad_(False)

# Initial model needs to be set as a *possible* but not yet added Stage, therefore add None entry.
self.stages[StageType.INITIAL_MODEL] = None
Expand Down
41 changes: 28 additions & 13 deletions backends/arm/tosa/partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ def ops_to_not_decompose(
ops_to_not_decompose_if_quant_op = [
torch.ops.aten.hardsigmoid.default,
torch.ops.aten.hardswish.default,
torch.ops.aten.linear.default,
]

def filter_fn(node: torch.fx.Node) -> bool:
Expand All @@ -355,31 +356,45 @@ def filter_fn(node: torch.fx.Node) -> bool:
bool: True to keep the op intact; otherwise, False.

"""
dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default
q = torch.ops.quantized_decomposed.quantize_per_tensor.default
dq = (
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_channel.default,
)
q = (
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_channel.default,
)

if node.target in ops_to_not_decompose_if_quant_op:
# Assume we should not decompose the operator (it is quantized)
should_not_decompose = True
correct_output_quant = True
correct_input_quant = True

input_nodes = node.all_input_nodes
ouput_nodes = node.users
output_nodes = node.users

for inp in input_nodes:
if inp.target != dq:
should_not_decompose = False

for out in ouput_nodes:
if out.target != q:
should_not_decompose = False

return should_not_decompose
if inp.target not in dq:
correct_input_quant = False

for out in output_nodes:
if out.target not in q:
correct_output_quant = False
# In some cases, a linear is quantized together with its activation.
if (
node.target == torch.ops.aten.linear.default
and len(output_nodes) == 1
and list(output_nodes)[0].target
in (torch.ops.aten.relu.default, torch.ops.aten.hardtanh.default)
):
correct_output_quant = True

return correct_input_quant and correct_output_quant

# By default, do not decompose the operator
return True

ops_to_not_decompose = [
torch.ops.aten.linear.default,
torch.ops.aten.eye.default,
torch.ops.aten.linspace.default,
torch.ops.aten.logit.default,
Expand Down
Loading