
Commit 55ea36d

Arm backend: Add check to not partition float inputs for BI (#8681)
Add check to not partition float inputs for BI

Floats are not supported in the TOSA BI profile. Some supported operators are only quantized if the previous node was quantized, so if an unsupported operator precedes such an operator, that operator will not be quantized and its input will be a float. This will likely lead to an assertion error or an invalid TOSA graph. This patch detects such nodes and rejects them.

Signed-off-by: Oscar Andersson <[email protected]>
1 parent 5848cc3 commit 55ea36d
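To make the rejection rule concrete, here is a minimal standalone sketch of the idea only, not the ExecuTorch implementation: nodes and dtypes are modeled with a hypothetical FakeNode class, and a tagged node loses its delegation tag whenever a non-delegated input produces a floating-point value and the target TOSA profile does not support floats.

```python
from dataclasses import dataclass, field
from typing import Optional

# Hypothetical stand-ins for torch.fx nodes and their fake tensors; the names
# below are illustrative only and are not part of the ExecuTorch codebase.
@dataclass
class FakeNode:
    name: str
    is_floating_point: bool                      # dtype of the value this node produces
    inputs: list["FakeNode"] = field(default_factory=list)
    delegation_tag: Optional[str] = None


def reject_float_inputs(nodes: list[FakeNode], profile_supports_float: bool) -> None:
    """Drop the delegation tag from tagged nodes fed by non-tagged float producers."""
    if profile_supports_float:                   # MI profile: floats are allowed, nothing to do
        return
    for node in nodes:
        if node.delegation_tag is None:          # node is not partitioned anyway
            continue
        for inp in node.inputs:
            if inp.delegation_tag is not None:   # input lives inside the partition
                continue
            if inp.is_floating_point:            # a float would cross into the BI graph
                node.delegation_tag = None
                break


# GELU is unsupported (untagged), so the tagged Dropout consuming its float
# output is rejected as well.
gelu = FakeNode("gelu", is_floating_point=True)
dropout = FakeNode("dropout", is_floating_point=True, inputs=[gelu], delegation_tag="tag0")
reject_float_inputs([gelu, dropout], profile_supports_float=False)
print(dropout.delegation_tag)  # -> None
```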

File tree

3 files changed: +132, -12 lines


backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 4 additions & 4 deletions
@@ -310,11 +310,11 @@ def is_node_supported(
         if not input_quantized:
             return False
 
-        output_quantized = output_quantized or all(
-            (output_node.target == self.q_op)
-            or (not get_first_fake_tensor(output_node).dtype.is_floating_point)
-            for output_node in node.users
+        all_q_users = all(
+            (output_node.target == self.q_op) for output_node in node.users
         )
+        is_floating_point = get_first_fake_tensor(node).dtype.is_floating_point
+        output_quantized = output_quantized or all_q_users or not is_floating_point
 
         if not output_quantized:
             return False
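Restated in plain Python, the condition moves from a per-user test to a node-level one: the output now counts as quantized only if every user is an explicit quantize op, or if the node itself already produces a non-floating-point dtype. A rough sketch, with `node`, `node.users`, `node.dtype`, and `q_op` as stand-ins for the real FX objects used above:

```python
def output_is_quantized(node, q_op, already_quantized: bool) -> bool:
    # Either every consumer re-quantizes the output ...
    all_q_users = all(user.target == q_op for user in node.users)
    # ... or the node itself already carries an integer dtype.
    is_floating_point = node.dtype.is_floating_point
    return already_quantized or all_q_users or not is_floating_point
```

For example, a float-producing node whose users are a mix of quantize ops and integer-output ops would previously have been accepted but is now rejected.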

backends/arm/test/misc/test_partition_decomposed_quantized_ops.py

Lines changed: 112 additions & 7 deletions
@@ -19,21 +19,39 @@
 )
 
 input_t1 = Tuple[torch.Tensor]
-aten_op: list[str] = ["torch.ops.aten.add.Tensor", "torch.ops.aten.softplus.default"]
-exir_op: list[str] = [
+softplus_aten_op: list[str] = [
+    "torch.ops.aten.add.Tensor",
+    "torch.ops.aten.softplus.default",
+]
+softplus_exir_op: list[str] = [
     "executorch_exir_dialects_edge__ops_aten_add_Tensor",
     "executorch_exir_dialects_edge__ops_aten_mul_Tensor",
     "executorch_exir_dialects_edge__ops_aten_exp_default",
     "executorch_exir_dialects_edge__ops_aten_div_Tensor",
 ]
 
+linear_residual_aten_op: list[str] = [
+    "torch.ops.aten.linear.default",
+    "torch.ops.aten.gelu.default",
+    "torch.ops.aten.dropout.default",
+    "torch.ops.aten.add.Tensor",
+]
+linear_residual_exir_op: list[str] = [
+    "executorch_exir_dialects_edge__ops_aten_gelu_default",
+    "executorch_exir_dialects_edge__ops_aten_clone_default",
+    "executorch_exir_dialects_edge__ops_aten_linear_default",
+    "executorch_exir_dialects_edge__ops_aten_add_Tensor",
+]
+
 
 test_data: dict[input_t1] = {
     "3d_rand": (torch.rand(1, 5, 5),),
 }
 
 
-class Module(torch.nn.Module):
+class SoftplusModule(torch.nn.Module):
+    """Module containing an addition followed by a Softplus. Softplus is currently not supported by TosaBackend."""
+
     def __init__(self):
         super().__init__()
         self.softplus = torch.nn.Softplus()
@@ -42,10 +60,35 @@ def forward(self, x: torch.Tensor):
         return self.softplus(x + x)
 
 
+class LinearResidualModule(torch.nn.Module):
+    """Module containing a residual and a linear layer followed by GELU and a Dropout.
+    GELU is currently not supported by TosaBackend nor TosaQuantizer.
+    """
+
+    def __init__(
+        self,
+    ):
+        super().__init__()
+        self.linear = torch.nn.Linear(in_features=5, out_features=3)
+        self.gelu = torch.nn.GELU()
+        self.dropout = torch.nn.Dropout(0.5)
+
+    def forward(self, x: torch.Tensor):
+        x1 = self.linear(x)
+        x2 = self.gelu(x1)
+        x3 = self.dropout(x2)
+        return x1 + x3
+
+
+# Softplus is decomposed which messes up the quantization. This test tests that CheckProperQuantization does not
+# partition nodes where quantization is not as expected.
 @common.parametrize("test_data", test_data)
 def test_softplus_tosa_MI(test_data: input_t1):
     pipeline = TosaPipelineMI[input_t1](
-        Module(), test_data=test_data, aten_op=aten_op, exir_op=exir_op
+        SoftplusModule(),
+        test_data=test_data,
+        aten_op=softplus_aten_op,
+        exir_op=softplus_exir_op,
     )
     # remove check_count.exir as there will be more than one delegate
     pipeline.pop_stage("check_count.exir")
@@ -55,14 +98,76 @@ def test_softplus_tosa_MI(test_data: input_t1):
 @common.parametrize("test_data", test_data)
 def test_softplus_tosa_BI(test_data: input_t1):
     pipeline = TosaPipelineBI[input_t1](
-        Module(), test_data=test_data, aten_op=aten_op, exir_op=exir_op
+        SoftplusModule(),
+        test_data=test_data,
+        aten_op=softplus_aten_op,
+        exir_op=softplus_exir_op,
+    )
+    pipeline.pop_stage("check_not.exir")
+    # check that all ops in softplus_exir_op except add are rejected
+    pipeline.add_stage_after(
+        "to_edge_transform_and_lower",
+        pipeline.tester.check,
+        softplus_exir_op[1:],
+        suffix="exir_post_partition",
+    )
+    pipeline.run()
+
+
+# Since GELU will not be quantized by TosaQuantizer, the Dropout's input will not be quantized either.
+# If so, the Dropout should not be partitioned by TosaPartitioner for TOSA BI profile. This test tests that the
+# partitioner indeed does not partition the Dropout (clone) for TOSA BI.
+@common.parametrize("test_data", test_data)
+def test_linear_residual_tosa_MI(test_data: input_t1):
+    pipeline = TosaPipelineMI[input_t1](
+        LinearResidualModule(),
+        test_data=test_data,
+        aten_op=linear_residual_aten_op,
+        exir_op=linear_residual_exir_op,
+        use_to_edge_transform_and_lower=True,
+    )
+    # remove check_count.exir as there will be more than one delegate
+    pipeline.pop_stage("check_count.exir")
+    pipeline.pop_stage("check_not.exir")
+    # check that all ops in linear_residual_exir_op except GELU are partitioned
+    pipeline.add_stage_after(
+        "to_edge_transform_and_lower",
+        pipeline.tester.check_not,
+        linear_residual_exir_op[1:],
+        suffix="exir_post_partition",
+    )
+    pipeline.add_stage_after(
+        "to_edge_transform_and_lower",
+        pipeline.tester.check,
+        linear_residual_exir_op[:1],
+        suffix="exir_post_partition",
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data)
+def test_linear_residual_tosa_BI(test_data: input_t1):
+    pipeline = TosaPipelineBI[input_t1](
+        LinearResidualModule(),
+        test_data=test_data,
+        aten_op=linear_residual_aten_op,
+        exir_op=linear_residual_exir_op,
+        use_to_edge_transform_and_lower=True,
     )
+    # remove check_count.exir as there will be more than one delegate
+    pipeline.pop_stage("check_count.exir")
     pipeline.pop_stage("check_not.exir")
-    # check that all ops in exir_op except add are rejected
+    # check that all ops in linear_residual_exir_op except GELU and Dropout are partitioned
+    pipeline.add_stage_after(
+        "to_edge_transform_and_lower",
+        pipeline.tester.check_not,
+        linear_residual_exir_op[2:],
+        suffix="exir_post_partition",
+    )
     pipeline.add_stage_after(
         "to_edge_transform_and_lower",
         pipeline.tester.check,
-        exir_op[1:],
+        linear_residual_exir_op[:2],
         suffix="exir_post_partition",
     )
     pipeline.run()
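To see outside the test pipeline why the BI variant expects the Dropout to be rejected, one can export a module like LinearResidualModule and inspect the fake-tensor dtypes recorded on the graph nodes. The snippet below is only an illustrative sketch: it redefines the module locally and uses torch.export directly rather than the Arm test tester.

```python
import torch


class LinearResidualModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(in_features=5, out_features=3)
        self.gelu = torch.nn.GELU()
        self.dropout = torch.nn.Dropout(0.5)

    def forward(self, x):
        x1 = self.linear(x)
        x2 = self.gelu(x1)
        x3 = self.dropout(x2)
        return x1 + x3


ep = torch.export.export(LinearResidualModule().eval(), (torch.rand(1, 5, 5),))
for node in ep.graph_module.graph.nodes:
    val = node.meta.get("val")
    # Without a quantizer in the loop every intermediate stays torch.float32,
    # which is exactly what the BI partitioner check has to reject at the
    # gelu -> dropout boundary.
    print(node.name, getattr(val, "dtype", None))
```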

backends/arm/tosa_partitioner.py

Lines changed: 16 additions & 1 deletion
@@ -14,6 +14,7 @@
     get_tosa_spec,
     is_tosa,
 ) # usort: skip
+from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
 from executorch.backends.arm.operator_support.tosa_supported_operators import (
     tosa_support_factory,
 )
@@ -66,7 +67,7 @@ def __init__(
         self.delegation_spec = DelegationSpec(TOSABackend.__name__, compile_spec)
         self.additional_checks = additional_checks
 
-    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
+    def partition(self, exported_program: ExportedProgram) -> PartitionResult:  # noqa
         # Run the CapabilityBasedPartitioner to return the largest possible
         # subgraphs containing the nodes with the tags
 
@@ -110,6 +111,20 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool:
                         del node.meta["delegation_tag"]
                         break
 
+            if tosa_spec.support_float():
+                continue
+
+            if is_partitioned(node):
+                for input in node.all_input_nodes:
+                    if is_partitioned(input):
+                        continue
+                    if get_first_fake_tensor(input).dtype.is_floating_point:
+                        logger.info(
+                            f"Not partitioning {node.name} because input {input.name} has floating point dtype."
+                        )
+                        del node.meta["delegation_tag"]
+                        break
+
         tag_constant_data(exported_program)
 
         return PartitionResult(
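The check runs only when `tosa_spec.support_float()` is false, so MI lowerings are unaffected, and the rejection itself is reported only at INFO level. A small example of surfacing these messages when lowering a model, assuming the partitioner's logger follows the usual `logging.getLogger(__name__)` convention and therefore lives under the `executorch.backends.arm` namespace:

```python
import logging

# Assumption: the Arm partitioner creates its logger via logging.getLogger(__name__),
# so raising the level for the package namespace surfaces rejection messages such as
# "Not partitioning <node> because input <input> has floating point dtype."
logging.basicConfig(level=logging.INFO)
logging.getLogger("executorch.backends.arm").setLevel(logging.INFO)
```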
