Skip to content

Commit 3ef100d

Browse files
authored
Use ArmQuantizer to quantize bias (#7649)
Remove the 'manual' quantization of the bias parameter and let the quantizer handle it instead. Signed-off-by: Per Åstrand <[email protected]>
1 parent 24f0d34 commit 3ef100d

File tree

5 files changed

+51
-78
lines changed

5 files changed

+51
-78
lines changed

backends/arm/process_node.py

Lines changed: 8 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -11,11 +11,6 @@
1111
import serializer.tosa_serializer as ts
1212
import torch
1313
import torch.fx
14-
15-
# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
16-
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
17-
get_input_qparams,
18-
)
1914
from executorch.backends.arm.operators.node_visitor import NodeVisitor
2015
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
2116
from executorch.backends.arm.tosa_quant_utils import (
@@ -24,11 +19,7 @@
2419
is_node_quantized,
2520
)
2621
from executorch.backends.arm.tosa_specification import TosaSpecification
27-
from executorch.backends.arm.tosa_utils import (
28-
getNodeArgs,
29-
is_bias_node_for_quantized_conv,
30-
tosa_shape,
31-
)
22+
from executorch.backends.arm.tosa_utils import getNodeArgs, tosa_shape
3223
from torch.export.exported_program import ExportedProgram
3324

3425

@@ -99,41 +90,6 @@ def process_inputs(
9990
tosa_graph.addInputTensor(tensor)
10091

10192

102-
def process_quantized_bias(
103-
node: torch.fx.Node,
104-
tosa_graph: ts.TosaSerializer,
105-
parameter_values,
106-
):
107-
"""
108-
Serialize bias node that needs to be quantized.
109-
"""
110-
consumer_node = list(node.users)[0]
111-
(
112-
input_node,
113-
weight_node,
114-
_,
115-
) = consumer_node.all_input_nodes
116-
117-
input_qargs = get_input_qparams( # pyre-ignore[16]: Module `executorch.backends.arm` has no attribute `_passes`.
118-
consumer_node
119-
)
120-
121-
input_node_scale = input_qargs[0].scale
122-
weight_node_scale = input_qargs[1].scale
123-
bias_values_quantized = (
124-
(parameter_values / (input_node_scale * weight_node_scale))
125-
.round()
126-
.astype(np.int32)
127-
)
128-
129-
tosa_graph.addConst(
130-
bias_values_quantized.shape,
131-
ts.DType.INT32,
132-
bias_values_quantized,
133-
name=node.name,
134-
)
135-
136-
13793
def process_inputs_to_parameters(
13894
node: torch.fx.Node,
13995
tosa_graph: ts.TosaSerializer,
@@ -148,20 +104,14 @@ def process_inputs_to_parameters(
148104
assert isinstance(parameter_data, torch.Tensor), "Expect Attr to be tensor"
149105
parameter_values = parameter_data.detach().numpy()
150106

151-
if is_bias_node_for_quantized_conv(node):
152-
# BI bias
153-
assert tosa_spec.support_integer(), f"{tosa_spec} doesnt't support integer"
154-
process_quantized_bias(node, tosa_graph, parameter_values)
155-
else:
156-
# MI weights or bias
157-
if inputs[0].dtype == torch.float32:
158-
assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float"
107+
if inputs[0].dtype == torch.float32:
108+
assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float"
159109

160-
parameter_values = np.transpose(parameter_values, inputs[0].dim_order)
110+
parameter_values = np.transpose(parameter_values, inputs[0].dim_order)
161111

162-
tosa_graph.addConst(
163-
parameter_values.shape, inputs[0].dtype, parameter_values, name=node.name
164-
)
112+
tosa_graph.addConst(
113+
parameter_values.shape, inputs[0].dtype, parameter_values, name=node.name
114+
)
165115

166116

167117
def process_inputs_to_buffers(

backends/arm/quantizer/quantization_annotator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -196,7 +196,7 @@ def get_quant_properties( # noqa: C901
196196
input_act_qspec = quantization_config.get_input_act_qspec()
197197
weight_qspec = quantization_config.get_weight_qspec()
198198
output_act_qspec = quantization_config.get_output_act_qspec()
199-
bias_qspec = quantization_config.get_bias_qspec()
199+
bias_qspec = quantization_config.get_bias_qspec(node)
200200

201201
quant_properties = _OpQuantProperties()
202202

backends/arm/quantizer/quantization_config.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
2-
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
33
#
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
@@ -9,8 +9,10 @@
99
from dataclasses import dataclass
1010

1111
import torch
12+
from torch.ao.quantization import ObserverOrFakeQuantize
1213

1314
from torch.ao.quantization.quantizer import (
15+
DerivedQuantizationSpec,
1416
FixedQParamsQuantizationSpec,
1517
QuantizationSpec,
1618
)
@@ -53,8 +55,42 @@ def get_weight_qspec(self) -> QuantizationSpec | None:
5355
], f"Unsupported quantization_spec {self.weight} for weight"
5456
return self.weight
5557

56-
def get_bias_qspec(self) -> QuantizationSpec | None:
58+
def get_bias_qspec(self, node: torch.fx.Node) -> QuantizationSpec | None:
5759
"""Returns QuantizationSpec 'bias' after asserting that bias.dtype is torch.float."""
60+
61+
def _derive_qparams_fn(
62+
obs_or_fqs: list[ObserverOrFakeQuantize],
63+
) -> tuple[torch.Tensor, torch.Tensor]:
64+
assert (
65+
len(obs_or_fqs) == 2
66+
), "Expecting two obs/fqs, one for activation and one for weight, got: {}".format(
67+
len(obs_or_fqs)
68+
)
69+
act_obs_or_fq = obs_or_fqs[0]
70+
weight_obs_or_fq = obs_or_fqs[1]
71+
act_scale, act_zp = act_obs_or_fq.calculate_qparams()
72+
weight_scale, weight_zp = weight_obs_or_fq.calculate_qparams()
73+
return torch.tensor([act_scale * weight_scale]).to(
74+
torch.float32
75+
), torch.tensor([0]).to(torch.int32)
76+
77+
if node.target in [
78+
torch.ops.aten.conv1d.default,
79+
torch.ops.aten.conv2d.default,
80+
torch.ops.aten.linear.default,
81+
]:
82+
input_act = node.args[0]
83+
weight = node.args[1]
84+
quantization_spec = DerivedQuantizationSpec(
85+
derived_from=[(input_act, node), (weight, node)],
86+
derive_qparams_fn=_derive_qparams_fn,
87+
dtype=torch.int32,
88+
quant_min=torch.iinfo(torch.int32).min,
89+
quant_max=torch.iinfo(torch.int32).max - 1,
90+
qscheme=torch.per_tensor_symmetric,
91+
)
92+
return quantization_spec
93+
5894
if self.bias is None:
5995
return None
6096
assert (

backends/arm/test/misc/test_debug_feats.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,10 +197,10 @@ def test_collate_tosa_BI_tests(self):
197197
"test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests"
198198
)
199199
assert os.path.exists(
200-
"test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/output_tag5.tosa"
200+
"test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/output_tag6.tosa"
201201
)
202202
assert os.path.exists(
203-
"test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/desc_tag5.json"
203+
"test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/desc_tag6.json"
204204
)
205205

206206
os.environ.pop("TOSA_TESTCASES_BASE_PATH")

backends/arm/tosa_utils.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023-2024 Arm Limited and/or its affiliates.
1+
# Copyright 2023-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -133,19 +133,6 @@ def build_reshape(tosa_fb, input_name, new_shape, output_name):
133133
tosa_fb.addOperator(TosaOp.Op().RESHAPE, [input_name], [output_name], attr)
134134

135135

136-
def is_bias_node_for_quantized_conv(node):
137-
consumer_node = list(node.users)[0]
138-
139-
if (
140-
consumer_node.target == exir_ops.edge.aten.convolution.default
141-
and consumer_node.args[2] == node
142-
and consumer_node.meta["val"].dtype == torch.int8
143-
):
144-
return True
145-
146-
return False
147-
148-
149136
def is_consumer_node_depthwise_conv2d(node):
150137
consumer_node = list(node.users)[0]
151138
if consumer_node.target == exir_ops.edge.aten.convolution.default:

0 commit comments

Comments
 (0)