Skip to content

Commit a5db1ed

Browse files
Add TagUnquantizedNodesPass
For models with operations that are not quantized, this pass keeps unquantized operators on the CPU. For example, the deit-tiny-patch16-224 network has an unquantized scaled_dot_product_attention operation. When compiling to Vela, invalid argument errors occur because unquantized operations are offloaded to the NPU. This pass is designed to solve this problem.
1 parent 82763a9 commit a5db1ed

File tree

5 files changed

+186
-6
lines changed

5 files changed

+186
-6
lines changed
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
import torch
10+
from executorch.backends.arm.tosa_quant_utils import dq_q_ops, get_neighbour_quant_args
11+
from executorch.exir.pass_base import ExportPass, PassResult
12+
13+
14+
class TagUnquantizedNodesPass(ExportPass):
    """
    Pass run before partitioning that tags unquantized nodes so they are not
    greedily partitioned for the device. Unquantized operations must remain
    on the CPU.

    Tagging means writing ``meta["arm_override_partition"] = False`` on the
    unquantized node, on each of its input nodes and on each of its users.
    """

    def is_node_quantized(self, node: torch.fx.Node) -> bool:
        """Return True when ``node`` has quantized neighbours.

        A node counts as quantized when ``get_neighbour_quant_args`` reports
        at least one user-side entry and, if the node has any input nodes,
        at least one input-side entry as well.
        """
        user_q_args, input_q_args = get_neighbour_quant_args(node)

        # The input-side check only applies to nodes that actually have
        # inputs; nodes without inputs (e.g. constants) can only be
        # followed by a dequantization node, so only their users are checked.
        has_inputs = len(node.all_input_nodes) > 0
        lacks_input_quant = has_inputs and len(input_q_args) == 0
        lacks_user_quant = len(user_q_args) == 0

        return not (lacks_input_quant or lacks_user_quant)

    def call(self, graph_module: torch.fx.GraphModule):
        """Tag every unquantized call_function node and its direct neighbours.

        Quantize/dequantize ops themselves (targets in ``dq_q_ops``) are not
        examined. The graph module is recompiled and always reported as
        modified.
        """
        for node in graph_module.graph.nodes:
            # Only operations that are not quantization or dequantization
            # are candidates.
            if node.op != "call_function" or node.target in dq_q_ops:
                continue
            if self.is_node_quantized(node):
                continue

            # Tag the unquantized node first, then its inputs, then its users.
            neighbourhood = (node, *node.all_input_nodes, *node.users.keys())
            for tagged in neighbourhood:
                tagged.meta["arm_override_partition"] = False

        graph_module.recompile()
        return PassResult(graph_module, True)

backends/arm/arm_backend.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def __init__(self):
5252
# TODO MLETORCH-265 Remove permute_nhwc flag
5353
self.permute_nhwc = False
5454
self.quantize_io = False
55+
self.unquantized_nodes_to_cpu = False
5556
self.tosa_version = None
5657
self.input_order = None
5758

@@ -146,6 +147,16 @@ def set_input_order(
146147
self.input_order = input_order
147148
return self
148149

150+
def set_unquantized_nodes_to_cpu(
    self,
    unquantized_nodes_to_cpu: bool = False,
) -> "ArmCompileSpecBuilder":
    """
    Choose whether unquantized operators are kept on the CPU.

    Intended for models that contain operations which are not quantized.
    The flag is stored on the builder, and the builder itself is returned
    so that calls can be chained.
    """
    self.unquantized_nodes_to_cpu = unquantized_nodes_to_cpu
    return self
159+
149160
def build(self) -> List[CompileSpec]:
150161
"""
151162
Generate a list of compile spec objects from the builder
@@ -185,6 +196,11 @@ def build(self) -> List[CompileSpec]:
185196
if self.quantize_io:
186197
self.compile_spec.append(CompileSpec("quantize_io", "True".encode()))
187198

199+
if self.unquantized_nodes_to_cpu:
200+
self.compile_spec.append(
201+
CompileSpec("unquantized_nodes_to_cpu", "True".encode())
202+
)
203+
188204
return self.compile_spec
189205

190206

backends/arm/arm_partitioner.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
import torch
1313
from executorch.backends.arm.arm_backend import ArmBackend # usort: skip
1414
from executorch.backends.arm._passes.tag_io_quant_pass import TagIOQuantPass
15+
from executorch.backends.arm._passes.tag_unquantized_nodes_pass import (
16+
TagUnquantizedNodesPass,
17+
)
1518
from executorch.backends.arm.operator_support.tosa_supported_operators import (
1619
TOSASupportedOperators,
1720
)
@@ -52,15 +55,17 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
5255

5356
logger.info(f"Partitioning for {tosa_spec}")
5457

58+
passes = []
5559
for spec in self.delegation_spec.compile_specs:
5660
if spec.key == "quantize_io" and spec.value.decode() == "True":
5761
# Exclude IO quantization from the partition
58-
passes = PassManager(
59-
passes=[
60-
TagIOQuantPass(),
61-
]
62-
)
63-
passes(exported_program.graph_module)
62+
passes.append(TagIOQuantPass())
63+
if spec.key == "unquantized_nodes_to_cpu" and spec.value.decode() == "True":
64+
# Exclude unquantized nodes from the partition
65+
passes.append(TagUnquantizedNodesPass())
66+
67+
passes = PassManager(passes=passes)
68+
passes(exported_program.graph_module)
6469

6570
capability_partitioner = CapabilityBasedPartitioner(
6671
exported_program.graph_module,

backends/arm/test/common.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def get_u55_compile_spec(
9393
quantize_io=False,
9494
custom_path=None,
9595
reorder_inputs=None,
96+
unquantized_nodes_to_cpu=False,
9697
) -> list[CompileSpec]:
9798
"""
9899
Default compile spec for Ethos-U55 tests.
@@ -102,6 +103,7 @@ def get_u55_compile_spec(
102103
quantize_io=quantize_io,
103104
custom_path=custom_path,
104105
reorder_inputs=reorder_inputs,
106+
unquantized_nodes_to_cpu=unquantized_nodes_to_cpu,
105107
).build()
106108

107109

@@ -110,6 +112,7 @@ def get_u85_compile_spec(
110112
quantize_io=False,
111113
custom_path=None,
112114
reorder_inputs=None,
115+
unquantized_nodes_to_cpu=False,
113116
) -> list[CompileSpec]:
114117
"""
115118
Default compile spec for Ethos-U85 tests.
@@ -119,6 +122,7 @@ def get_u85_compile_spec(
119122
quantize_io=quantize_io,
120123
custom_path=custom_path,
121124
reorder_inputs=reorder_inputs,
125+
unquantized_nodes_to_cpu=unquantized_nodes_to_cpu,
122126
).build()
123127

124128

@@ -127,6 +131,7 @@ def get_u55_compile_spec_unbuilt(
127131
quantize_io=False,
128132
custom_path=None,
129133
reorder_inputs=None,
134+
unquantized_nodes_to_cpu=False,
130135
) -> ArmCompileSpecBuilder:
131136
"""Get the ArmCompileSpecBuilder for the Ethos-U55 tests, to modify
132137
the compile spec before calling .build() to finalize it.
@@ -143,6 +148,7 @@ def get_u55_compile_spec_unbuilt(
143148
extra_flags="--debug-force-regor --output-format=raw",
144149
)
145150
.set_quantize_io(is_option_enabled("quantize_io") or quantize_io)
151+
.set_unquantized_nodes_to_cpu(unquantized_nodes_to_cpu)
146152
.set_permute_memory_format(permute_memory_to_nhwc)
147153
.dump_intermediate_artifacts_to(artifact_path)
148154
.set_input_order(reorder_inputs)
@@ -155,6 +161,7 @@ def get_u85_compile_spec_unbuilt(
155161
quantize_io=False,
156162
custom_path=None,
157163
reorder_inputs=None,
164+
unquantized_nodes_to_cpu=False,
158165
) -> list[CompileSpec]:
159166
"""Get the ArmCompileSpecBuilder for the Ethos-U85 tests, to modify
160167
the compile spec before calling .build() to finalize it.
@@ -169,6 +176,7 @@ def get_u85_compile_spec_unbuilt(
169176
extra_flags="--output-format=raw",
170177
)
171178
.set_quantize_io(is_option_enabled("quantize_io") or quantize_io)
179+
.set_unquantized_nodes_to_cpu(unquantized_nodes_to_cpu)
172180
.set_permute_memory_format(permute_memory_to_nhwc)
173181
.dump_intermediate_artifacts_to(artifact_path)
174182
.set_input_order(reorder_inputs)
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from executorch.backends.arm.quantizer.arm_quantizer import (
11+
ArmQuantizer,
12+
get_symmetric_quantization_config,
13+
)
14+
15+
from executorch.backends.arm.test import common
16+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
17+
from executorch.backends.xnnpack.test.tester.tester import Quantize
18+
from executorch.exir.backend.compile_spec_schema import CompileSpec
19+
20+
21+
class TestModel(torch.nn.Module):
    """Small two-input module computing ``((x + y) * y) * x - y``."""

    def get_inputs(self):
        """Return a pair of random tensors, each of shape (1, 10, 10, 10)."""
        first = torch.rand(1, 10, 10, 10)
        second = torch.rand(1, 10, 10, 10)
        return (first, second)

    def forward(self, x, y):
        """Apply add, mul, mul, sub in sequence to the two inputs."""
        summed = x + y
        scaled_by_y = summed * y
        scaled_by_x = scaled_by_y * x
        return scaled_by_x - y
32+
33+
34+
class TestTagUnquantizedNodesPass(unittest.TestCase):
    """
    Tests the TagUnquantizedNodesPass, which tags unquantized nodes in a model
    so that they are not included in our partitions.
    """

    def _tosa_BI_pipeline(
        self, module: torch.nn.Module, compile_spec: list[CompileSpec]
    ):
        """Quantize, export, lower and partition ``module``, checking op counts.

        Before partitioning the edge program must contain 5 quantize and
        6 dequantize ops; after partitioning it must contain 3 quantize ops,
        2 delegate calls and 2 dequantize ops.
        """
        quantize_op = "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default"
        dequantize_op = "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default"
        delegate_op = "torch.ops.higher_order.executorch_call_delegate"

        quantizer = ArmQuantizer()
        # Quantize only add and sub nodes
        quantizer.STATIC_ANNOTATION_ORDER = ["add", "sub"]

        tester = ArmTester(
            module,
            example_inputs=module.get_inputs(),
            compile_spec=compile_spec,
        )
        quantize_stage = Quantize(
            quantizer,
            get_symmetric_quantization_config(is_per_channel=False),
        )

        # Each stage is applied to the value returned by the previous one,
        # exactly as in a fluent call chain.
        tester = tester.quantize(quantize_stage)
        tester = tester.export()
        tester = tester.to_edge()

        # Counts before partitioning.
        tester = tester.check_count({quantize_op: 5})
        tester = tester.check_count({dequantize_op: 6})

        tester = tester.partition()

        # Counts after partitioning.
        tester = tester.check_count({quantize_op: 3})
        tester = tester.check_count({delegate_op: 2})
        tester = tester.check_count({dequantize_op: 2})

    def test_BI_u55_artifact(self):
        """Run the pipeline with the Ethos-U55 compile spec."""
        spec = common.get_u55_compile_spec(
            quantize_io=True, unquantized_nodes_to_cpu=True
        )
        self._tosa_BI_pipeline(TestModel(), spec)

    def test_BI_u85_artifact(self):
        """Run the pipeline with the Ethos-U85 compile spec."""
        spec = common.get_u85_compile_spec(
            quantize_io=True, unquantized_nodes_to_cpu=True
        )
        self._tosa_BI_pipeline(TestModel(), spec)

0 commit comments

Comments
 (0)