Commit 4d7e3ee

Arm backend: Improve quantizer configuration in tests (pytorch#16073)
Updating the quantizer in test pipelines has been rather cumbersome. As we anticipate more tests with different quantization schemes, we want to make this easy. The quantizer can now be accessed and modified using pipeline.quantizer.set_[...]; see the patch for examples.

Signed-off-by: Erik Lundell <[email protected]>
1 parent 4860984 commit 4d7e3ee
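
In practice, the change replaces hand-built Quantize stages with in-place tweaks to the pipeline's own quantizer. A minimal before/after sketch, assuming pipeline is an already-constructed test pipeline such as TosaPipelineINT (the sketch itself is illustrative, not part of the patch):

```python
from executorch.backends.arm.quantizer import (
    get_symmetric_quantization_config,
    TOSAQuantizer,
)
from executorch.backends.arm.tosa import TosaSpecification
from executorch.backends.xnnpack.test.tester import Quantize

# Before: build a TOSAQuantizer and a Quantize stage by hand, then swap it
# into the pipeline.
quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))
quantizer.set_global(get_symmetric_quantization_config())
pipeline.change_args(
    "quantize", Quantize(quantizer, get_symmetric_quantization_config())
)

# After: configure the pipeline's built-in quantizer directly before running.
pipeline.quantizer.set_global(get_symmetric_quantization_config())
pipeline.run()
```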

6 files changed (+67, -163 lines)

backends/arm/test/misc/test_quant_custom_meta.py
Lines changed: 4 additions & 19 deletions

```diff
@@ -5,13 +5,7 @@

 import pytest
 import torch
-from executorch.backends.arm.quantizer import (
-    get_symmetric_quantization_config,
-    TOSAQuantizer,
-)
 from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineINT
-from executorch.backends.arm.tosa import TosaSpecification
-from executorch.backends.xnnpack.test.tester import Quantize


 class AddSigmoidMul(torch.nn.Module):
@@ -23,15 +17,6 @@ def forward(self, x, y):
         return self.sigmoid(x + y) * x


-def get_selective_quantizer(modules):
-    quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))
-    quantizer.set_global(get_symmetric_quantization_config())
-    for module in modules:
-        quantizer.set_module_type(module, None)
-
-    return Quantize(quantizer, get_symmetric_quantization_config())
-
-
 @pytest.mark.parametrize("fp_extension", [True, False])
 def test_qdq_squeezed_fp_op(fp_extension: bool):
     """Test that a float operation surrounded by quantize-dequantize pairs
@@ -52,7 +37,7 @@ def test_qdq_squeezed_fp_op(fp_extension: bool):
         exir_op=exir_op,
         tosa_extensions=["FP"] if fp_extension else None,
     )
-    pipeline.change_args("quantize", get_selective_quantizer([torch.nn.Sigmoid]))
+    pipeline.quantizer.set_module_type(torch.nn.Sigmoid, None)  # type: ignore

     if not fp_extension:
         # In case we don't have the FP extension, the unquantized part of the
@@ -114,7 +99,7 @@ def test_quantized_to_float_transition(fp_extension: bool):
            "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2,
         },
     )
-    pipeline.change_args(
-        "quantize", get_selective_quantizer([torch.nn.Sigmoid, torch.nn.Conv1d])
-    )
+    pipeline.quantizer.set_module_type(torch.nn.Sigmoid, None)  # type: ignore
+    pipeline.quantizer.set_module_type(torch.nn.Conv1d, None)  # type: ignore
+
     pipeline.run()
```
backends/arm/test/models/test_lstm_arm.py
Lines changed: 12 additions & 30 deletions

```diff
@@ -9,10 +9,9 @@
 import torch
 from executorch.backends.arm.quantizer.arm_quantizer import (
     get_symmetric_a16w8_quantization_config,
-    TOSAQuantizer,
 )

-from executorch.backends.arm.test import common, conftest
+from executorch.backends.arm.test import common
 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineINT,
     EthosU85PipelineINT,
@@ -21,9 +20,6 @@
     VgfPipeline,
 )

-from executorch.backends.arm.tosa import TosaSpecification
-from executorch.backends.xnnpack.test.tester import Quantize
-
 from torch.nn.quantizable.modules import rnn

 input_t = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]  # (h0, c0)
@@ -144,27 +140,6 @@ def test_lstm_vgf_FP():
     pipeline.run()


-def get_symmetric_a16w8_lstm_quantizer(per_channel_quantization=False):
-    tosa_version = conftest.get_option("tosa_version")
-    tosa_profiles = {
-        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
-    }
-
-    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
-    quantizer.set_global(
-        get_symmetric_a16w8_quantization_config(
-            is_per_channel=per_channel_quantization, epsilon=2**-16
-        )
-    )
-
-    return Quantize(
-        quantizer,
-        get_symmetric_a16w8_quantization_config(
-            is_per_channel=per_channel_quantization, epsilon=2**-16
-        ),
-    )
-
-
 def test_lstm_16a8w_tosa_INT():
     """Test LSTM model with 16A8W quantization (16-bit activations, 8-bit weights)"""

@@ -177,8 +152,9 @@ def test_lstm_16a8w_tosa_INT():
         use_to_edge_transform_and_lower=True,
         tosa_extensions=["int16"],
     )
-
-    pipeline.change_args("quantize", get_symmetric_a16w8_lstm_quantizer())
+    pipeline.quantizer.set_global(
+        get_symmetric_a16w8_quantization_config(is_per_channel=False, epsilon=2**-16)
+    )
     pipeline.run()


@@ -195,7 +171,10 @@ def test_lstm_16a8w_u55_INT():
         use_to_edge_transform_and_lower=True,
     )

-    pipeline.change_args("quantize", get_symmetric_a16w8_lstm_quantizer())
+    pipeline.quantizer.set_global(
+        get_symmetric_a16w8_quantization_config(is_per_channel=False, epsilon=2**-16)
+    )
+
     pipeline.run()


@@ -208,5 +187,8 @@ def test_lstm_16a8w_u85_INT():
         exir_ops=[],
         use_to_edge_transform_and_lower=True,
     )
-    pipeline.change_args("quantize", get_symmetric_a16w8_lstm_quantizer())
+    pipeline.quantizer.set_global(
+        get_symmetric_a16w8_quantization_config(is_per_channel=False, epsilon=2**-16)
+    )
+
     pipeline.run()
```

backends/arm/test/ops/test_add.py
Lines changed: 10 additions & 49 deletions

```diff
@@ -5,24 +5,21 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import cast, Tuple
+from typing import Tuple

 import torch
 from executorch.backends.arm.quantizer import arm_quantizer
 from executorch.backends.arm.quantizer.arm_quantizer import (
     get_symmetric_a16w8_quantization_config,
-    TOSAQuantizer,
 )
-from executorch.backends.arm.test import common, conftest
+from executorch.backends.arm.test import common
 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineINT,
     EthosU85PipelineINT,
     TosaPipelineFP,
     TosaPipelineINT,
     VgfPipeline,
 )
-from executorch.backends.arm.tosa import TosaSpecification
-from executorch.backends.xnnpack.test.tester import Quantize
 from torchao.quantization.pt2e import HistogramObserver
 from torchao.quantization.pt2e.quantizer import QuantizationSpec

@@ -101,14 +98,8 @@ def test_add_tensor_tosa_INT(test_data: input_t1):
 @common.parametrize("test_data", Add.test_data)
 def test_add_tensor_tosa_INT_i32(test_data: input_t1):
     pipeline = TosaPipelineINT[input_t1](Add(), test_data(), aten_op, exir_op)
-    tosa_version = cast(str, conftest.get_option("tosa_version"))
-    tosa_profiles = {
-        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT"),
-    }
-    # Create a quantizer with int8 quantization on the input and output but int32 on everything else.
-    quantizer = arm_quantizer.TOSAQuantizer(tosa_profiles[tosa_version])

-    quantizer.set_io(arm_quantizer.get_symmetric_quantization_config())
+    pipeline.quantizer.set_io(arm_quantizer.get_symmetric_quantization_config())
     observer_options = {"eps": 2**-16}
     observer = HistogramObserver.with_args(**observer_options)
     input_act_qspec = QuantizationSpec(
@@ -125,12 +116,10 @@ def test_add_tensor_tosa_INT_i32(test_data: input_t1):
         quant_max=2**31 - 1,
         quant_min=-(2**31),
     )
-    # This quantization_config will be set as global config.
     quantization_config = arm_quantizer.QuantizationConfig(
         input_act_qspec, output_act_qspec, None, None
     )
-    quantize_stage = Quantize(quantizer, quantization_config)
-    pipeline.change_args("quantize", quantize_stage)
+    pipeline.quantizer.set_global(quantization_config)

     # Check that we get the additional (dq -> q
     pipeline.add_stage_after(
@@ -239,25 +228,6 @@ def test_add_tensor_vgf_INT(test_data: input_t1):
     pipeline.run()


-def get_symmetric_a16w8_add_quantizer(per_channel_quantization=False):
-    tosa_version = conftest.get_option("tosa_version")
-    tosa_profiles = {
-        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
-    }
-
-    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
-    quantizer.set_global(
-        get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
-    )
-
-    return Quantize(
-        quantizer,
-        get_symmetric_a16w8_quantization_config(
-            is_per_channel=per_channel_quantization
-        ),
-    )
-
-
 @common.parametrize("test_data", Add.test_data)
 def test_add_tensor_16a8w_tosa_INT(test_data: input_t1):
     """Test add operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
@@ -273,11 +243,8 @@ def test_add_tensor_16a8w_tosa_INT(test_data: input_t1):
         tosa_extensions=["int16"],
     )

-    pipeline.change_args(
-        "quantize",
-        get_symmetric_a16w8_add_quantizer(
-            per_channel_quantization=per_channel_quantization
-        ),
+    pipeline.quantizer.set_global(
+        get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
     )
     pipeline.run()

@@ -297,11 +264,8 @@ def test_add_tensor_16a8w_u55_INT16(test_data: input_t1):
         use_to_edge_transform_and_lower=True,
     )

-    pipeline.change_args(
-        "quantize",
-        get_symmetric_a16w8_add_quantizer(
-            per_channel_quantization=per_channel_quantization
-        ),
+    pipeline.quantizer.set_global(
+        get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
     )
     pipeline.run()

@@ -321,10 +285,7 @@ def test_add_tensor_16a8w_u85_INT16(test_data: input_t1):
         use_to_edge_transform_and_lower=True,
     )

-    pipeline.change_args(
-        "quantize",
-        get_symmetric_a16w8_add_quantizer(
-            per_channel_quantization=per_channel_quantization
-        ),
+    pipeline.quantizer.set_global(
+        get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
     )
     pipeline.run()
```

backends/arm/test/ops/test_cat.py
Lines changed: 8 additions & 41 deletions

```diff
@@ -11,9 +11,8 @@
 import torch
 from executorch.backends.arm.quantizer.arm_quantizer import (
     get_symmetric_a16w8_quantization_config,
-    TOSAQuantizer,
 )
-from executorch.backends.arm.test import common, conftest
+from executorch.backends.arm.test import common

 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineINT,
@@ -22,8 +21,6 @@
     TosaPipelineINT,
     VgfPipeline,
 )
-from executorch.backends.arm.tosa.specification import TosaSpecification
-from executorch.backends.xnnpack.test.tester import Quantize

 input_t1 = Tuple[torch.Tensor]  # Input x

@@ -157,25 +154,6 @@ def test_cat_vgf_INT(test_data: Tuple):
     pipeline.run()


-def get_symmetric_a16w8_cat_quantizer(per_channel_quantization=False):
-    tosa_version = conftest.get_option("tosa_version")
-    tosa_profiles = {
-        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
-    }
-
-    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
-    quantizer.set_global(
-        get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
-    )
-
-    return Quantize(
-        quantizer,
-        get_symmetric_a16w8_quantization_config(
-            is_per_channel=per_channel_quantization
-        ),
-    )
-
-
 @common.parametrize("test_data", Cat.test_parameters)
 def test_cat_16a8w_tosa_INT(test_data: Tuple):
     """Test cat operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
@@ -190,12 +168,8 @@ def test_cat_16a8w_tosa_INT(test_data: Tuple):
         use_to_edge_transform_and_lower=True,
         tosa_extensions=["int16"],
     )
-
-    pipeline.change_args(
-        "quantize",
-        get_symmetric_a16w8_cat_quantizer(
-            per_channel_quantization=per_channel_quantization
-        ),
+    pipeline.quantizer.set_global(
+        get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
     )
     pipeline.run()

@@ -214,13 +188,10 @@ def test_cat_16a8w_u55_INT16(test_data: Tuple):
         per_channel_quantization=per_channel_quantization,
         use_to_edge_transform_and_lower=True,
     )
-
-    pipeline.change_args(
-        "quantize",
-        get_symmetric_a16w8_cat_quantizer(
-            per_channel_quantization=per_channel_quantization
-        ),
+    pipeline.quantizer.set_global(
+        get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
     )
+
     pipeline.run()


@@ -238,11 +209,7 @@ def test_cat_16a8w_u85_INT16(test_data: Tuple):
         per_channel_quantization=per_channel_quantization,
         use_to_edge_transform_and_lower=True,
     )
-
-    pipeline.change_args(
-        "quantize",
-        get_symmetric_a16w8_cat_quantizer(
-            per_channel_quantization=per_channel_quantization
-        ),
+    pipeline.quantizer.set_global(
+        get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
     )
     pipeline.run()
```
