Skip to content

Commit e55bf54

Browse files
authored
Arm backend: Add tests for TOSA and VGF for extract_io_params_tosa + fix for bug with is_U55_subset (#13065)
Fix an error with "is_U55_subset" by changing TOSAQuantizer to also accept a compile-spec list. Signed-off-by: Elena Zhelezina <[email protected]>
1 parent 6bc312a commit e55bf54

File tree

2 files changed

+125
-3
lines changed

2 files changed

+125
-3
lines changed

backends/arm/quantizer/arm_quantizer.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import annotations
1515

1616
import functools
17-
from typing import Any, Callable, Dict, List, Optional
17+
from typing import Any, Callable, Dict, List, Optional, Union
1818

1919
import torch
2020
from executorch.backends.arm._passes import ArmPassManager
@@ -218,9 +218,35 @@ def not_module_type_or_name_filter(n: Node) -> bool:
218218

219219
class TOSAQuantizer(Quantizer):
220220

221-
def __init__(self, tosa_spec: TosaSpecification) -> None:
221+
def __init__(
222+
self, compile_spec_or_tosa_spec: Union[TosaSpecification, List[CompileSpec]]
223+
) -> None:
224+
222225
super().__init__()
223-
self.tosa_spec = tosa_spec
226+
if isinstance(compile_spec_or_tosa_spec, TosaSpecification):
227+
self.tosa_spec = compile_spec_or_tosa_spec
228+
self.compile_spec = None
229+
elif isinstance(compile_spec_or_tosa_spec, list):
230+
self.compile_spec = compile_spec_or_tosa_spec
231+
# find entry that is 'tosa_spec'
232+
for cs in compile_spec_or_tosa_spec:
233+
if cs.key == "tosa_spec":
234+
spec_val = (
235+
cs.value.decode() if isinstance(cs.value, bytes) else cs.value
236+
)
237+
self.tosa_spec = TosaSpecification.create_from_string(spec_val)
238+
break
239+
else:
240+
raise ValueError(
241+
"compile_spec list did not contain a 'tosa_spec' entry"
242+
)
243+
else:
244+
raise TypeError(
245+
f"TOSAQuantizer constructor expects "
246+
f"a TosaSpecification or compile_spec list, "
247+
f"got {type(compile_spec_or_tosa_spec)}"
248+
)
249+
224250
self.global_config: Optional[QuantizationConfig] = None
225251
self.io_config: Optional[QuantizationConfig] = None
226252
self.module_type_config: Dict[Callable, Optional[QuantizationConfig]] = {}
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import copy
7+
8+
import pytest
9+
import torch
10+
from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
11+
from executorch.backends.arm.quantizer import VgfQuantizer
12+
from executorch.backends.arm.quantizer.arm_quantizer import (
13+
get_symmetric_quantization_config,
14+
TOSAQuantizer,
15+
)
16+
17+
from executorch.backends.arm.test.common import SkipIfNoModelConverter
18+
from executorch.backends.arm.tosa_partitioner import TOSAPartitioner
19+
from executorch.backends.arm.tosa_specification import TosaSpecification
20+
from executorch.backends.arm.vgf_partitioner import VgfPartitioner
21+
from executorch.exir import to_edge_transform_and_lower
22+
from executorch.exir.passes.quantize_io_pass import extract_io_quant_params
23+
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
24+
25+
26+
class SimpleAdd(torch.nn.Module):
    """Minimal two-input module used to exercise IO quantization."""

    def forward(self, x, y):
        # Elementwise tensor addition of the two inputs.
        result = torch.add(x, y)
        return result
29+
30+
31+
@pytest.mark.parametrize(
    "builder_method, quantizer_cls, partitioner_cls",
    [
        # TOSA flow: plain TOSA compile spec, quantizer, and partitioner.
        ("tosa_compile_spec", TOSAQuantizer, TOSAPartitioner),
        # VGF flow: skipped when the model-converter tool is unavailable.
        pytest.param(
            "vgf_compile_spec",
            VgfQuantizer,
            VgfPartitioner,
            marks=SkipIfNoModelConverter,
            id="VGF",
        ),
    ],
)
def test_roundtrip_extracts_io_params(builder_method, quantizer_cls, partitioner_cls):
    """
    Validates that IO quantization parameters round-trip for both flows.

    The quantizer is constructed directly from the compile-spec list (not a
    TosaSpecification), exercising the extended quantizer constructor; the
    model is then prepared with QAT observers, converted, lowered, and the
    input/output quantization parameters are read back with
    extract_io_quant_params.
    """
    example_inputs = (
        torch.ones(1, 5),
        torch.full((1, 5), 2.0),
    )
    mod = SimpleAdd().eval()

    # Build the compile spec for the parametrized flow (TOSA or VGF).
    base_spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
    compile_spec = getattr(ArmCompileSpecBuilder(), builder_method)(
        tosa_spec=base_spec
    ).build()

    # The quantizer accepts the compile-spec list directly.
    quantizer = quantizer_cls(compile_spec)
    operator_config = get_symmetric_quantization_config(is_qat=True)
    quantizer.set_global(operator_config)

    # Export, insert observers, and run once so observers record ranges.
    exported = torch.export.export_for_training(
        mod, copy.deepcopy(example_inputs), strict=True
    )
    prepared = prepare_pt2e(exported.module(), quantizer)
    _ = prepared(*example_inputs)

    # Convert observers to quant/dequant nodes, re-export, and lower.
    converted = convert_pt2e(prepared)
    final_export = torch.export.export_for_training(
        converted, example_inputs, strict=True
    )
    partitioner = partitioner_cls(compile_spec)
    edge_prog = to_edge_transform_and_lower(final_export, partitioner=[partitioner])

    # Extract IO quantization parameters
    q = extract_io_quant_params(
        edge_prog,
        input_idxs=(0, 1),
        output_idxs=(0,),
    )

    # Both inputs and the single output must be reported.
    assert "inputs" in q
    assert "outputs" in q
    assert len(q["inputs"]) == 2
    assert len(q["outputs"]) == 1

    # Each entry maps a tensor name to its scale/zero-point values.
    for name, params in q["inputs"].items():
        assert isinstance(name, str)
        assert isinstance(params["scale"], float)
        assert isinstance(params["zero_point"], int)

    out_name, out_params = next(iter(q["outputs"].items()))
    assert isinstance(out_name, str)
    assert isinstance(out_params["scale"], float)
    assert isinstance(out_params["zero_point"], int)

0 commit comments

Comments
 (0)