From d789fccce6a0a04e21a32e11b6ecc190fd22bd9d Mon Sep 17 00:00:00 2001
From: Elena Zhelezina
Date: Wed, 30 Jul 2025 16:23:02 +0100
Subject: [PATCH] Arm backend: Add tests for TOSA and VGF for extract_io_params_tosa

Fix error with "is_U55_subset" by changing TOSAQuantizer to accept compiled spec

Signed-off-by: Elena Zhelezina
Change-Id: I164d68190e6fb9edd2efb9c8bd2da8c1bfc4e406
---
Review notes (placed after the three-dash line, so they are not part of the
commit message or of the diff):

  backends/arm/quantizer/arm_quantizer.py
    * TOSAQuantizer.__init__ now takes one argument that is either a
      TosaSpecification or a list.
        - TosaSpecification: stored as self.tosa_spec; self.compile_spec is
          set to None.
        - list: stored as self.compile_spec. The first entry whose .key equals
          "tosa_spec" is used; its .value is decoded when it is bytes and then
          passed to TosaSpecification.create_from_string(), and the result is
          stored as self.tosa_spec. The for/else raises ValueError when no
          entry matches, which includes an empty list.
        - any other type: TypeError naming the received type.
    * NOTE(review): the parameter is renamed from tosa_spec to
      compile_spec_or_tosa_spec. Positional callers are unaffected; a caller
      passing tosa_spec= by keyword would fail - check call sites and
      subclasses.
    * NOTE(review): the new annotation references CompileSpec, but this patch
      only adds Union to the typing import. Presumably CompileSpec is already
      imported in this module - confirm.

  backends/arm/test/misc/test_extract_io_params_tosa.py (new file)
    * test_roundtrip_extracts_io_params is parametrized over two flows:
      ("tosa_compile_spec", TOSAQuantizer, TOSAPartitioner) and
      ("vgf_compile_spec", VgfQuantizer, VgfPartitioner); the second carries
      the SkipIfNoModelConverter mark and the id "VGF".
    * Flow under test: export_for_training -> prepare_pt2e -> one forward call
      on the example inputs -> convert_pt2e -> export_for_training ->
      to_edge_transform_and_lower -> extract_io_quant_params with
      input_idxs=(0, 1) and output_idxs=(0,).
    * The assertions check only the number of entries (2 inputs, 1 output) and
      the types of name / scale / zero_point, not the numeric values.
    * NOTE(review): the config is created with is_qat=True but the model is
      prepared with prepare_pt2e - confirm this combination is intended.

 backends/arm/quantizer/arm_quantizer.py       | 32 ++++++-
 .../test/misc/test_extract_io_params_tosa.py  | 96 +++++++++++++++++++
 2 files changed, 125 insertions(+), 3 deletions(-)
 create mode 100644 backends/arm/test/misc/test_extract_io_params_tosa.py

diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py
index 734ddec4359..28bb70be2b1 100644
--- a/backends/arm/quantizer/arm_quantizer.py
+++ b/backends/arm/quantizer/arm_quantizer.py
@@ -14,7 +14,7 @@
 from __future__ import annotations
 
 import functools
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 from executorch.backends.arm._passes import ArmPassManager
@@ -218,9 +218,35 @@ def not_module_type_or_name_filter(n: Node) -> bool:
 
 
 class TOSAQuantizer(Quantizer):
-    def __init__(self, tosa_spec: TosaSpecification) -> None:
+    def __init__(
+        self, compile_spec_or_tosa_spec: Union[TosaSpecification, List[CompileSpec]]
+    ) -> None:
+
         super().__init__()
-        self.tosa_spec = tosa_spec
+        if isinstance(compile_spec_or_tosa_spec, TosaSpecification):
+            self.tosa_spec = compile_spec_or_tosa_spec
+            self.compile_spec = None
+        elif isinstance(compile_spec_or_tosa_spec, list):
+            self.compile_spec = compile_spec_or_tosa_spec
+            # find entry that is 'tosa_spec'
+            for cs in compile_spec_or_tosa_spec:
+                if cs.key == "tosa_spec":
+                    spec_val = (
+                        cs.value.decode() if isinstance(cs.value, bytes) else cs.value
+                    )
+                    self.tosa_spec = TosaSpecification.create_from_string(spec_val)
+                    break
+            else:
+                raise ValueError(
+                    "compile_spec list did not contain a 'tosa_spec' entry"
+                )
+        else:
+            raise TypeError(
+                f"TOSAQuantizer constructor expects "
+                f"a TosaSpecification or compile_spec list, "
+                f"got {type(compile_spec_or_tosa_spec)}"
+            )
+
         self.global_config: Optional[QuantizationConfig] = None
         self.io_config: Optional[QuantizationConfig] = None
         self.module_type_config: Dict[Callable, Optional[QuantizationConfig]] = {}
diff --git a/backends/arm/test/misc/test_extract_io_params_tosa.py b/backends/arm/test/misc/test_extract_io_params_tosa.py
new file mode 100644
index 00000000000..8483de63656
--- /dev/null
+++ b/backends/arm/test/misc/test_extract_io_params_tosa.py
@@ -0,0 +1,96 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+
+import pytest
+import torch
+from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
+from executorch.backends.arm.quantizer import VgfQuantizer
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    get_symmetric_quantization_config,
+    TOSAQuantizer,
+)
+
+from executorch.backends.arm.test.common import SkipIfNoModelConverter
+from executorch.backends.arm.tosa_partitioner import TOSAPartitioner
+from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.backends.arm.vgf_partitioner import VgfPartitioner
+from executorch.exir import to_edge_transform_and_lower
+from executorch.exir.passes.quantize_io_pass import extract_io_quant_params
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+
+
+class SimpleAdd(torch.nn.Module):
+    def forward(self, x, y):
+        return x + y
+
+
+@pytest.mark.parametrize(
+    "builder_method, quantizer_cls, partitioner_cls",
+    [
+        ("tosa_compile_spec", TOSAQuantizer, TOSAPartitioner),
+        pytest.param(
+            "vgf_compile_spec",
+            VgfQuantizer,
+            VgfPartitioner,
+            marks=SkipIfNoModelConverter,
+            id="VGF",
+        ),
+    ],
+)
+def test_roundtrip_extracts_io_params(builder_method, quantizer_cls, partitioner_cls):
+    """
+    Validates that IO quantization parameters round-trip for both flows.
+    """
+    example_inputs = (
+        torch.ones(1, 5),
+        torch.full((1, 5), 2.0),
+    )
+    mod = SimpleAdd().eval()
+
+    base_spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
+    compile_spec = getattr(ArmCompileSpecBuilder(), builder_method)(
+        tosa_spec=base_spec
+    ).build()
+
+    quantizer = quantizer_cls(compile_spec)
+    operator_config = get_symmetric_quantization_config(is_qat=True)
+    quantizer.set_global(operator_config)
+
+    exported = torch.export.export_for_training(
+        mod, copy.deepcopy(example_inputs), strict=True
+    )
+    prepared = prepare_pt2e(exported.module(), quantizer)
+    _ = prepared(*example_inputs)
+
+    converted = convert_pt2e(prepared)
+    final_export = torch.export.export_for_training(
+        converted, example_inputs, strict=True
+    )
+    partitioner = partitioner_cls(compile_spec)
+    edge_prog = to_edge_transform_and_lower(final_export, partitioner=[partitioner])
+
+    # Extract IO quantization parameters
+    q = extract_io_quant_params(
+        edge_prog,
+        input_idxs=(0, 1),
+        output_idxs=(0,),
+    )
+
+    assert "inputs" in q
+    assert "outputs" in q
+    assert len(q["inputs"]) == 2
+    assert len(q["outputs"]) == 1
+
+    for name, params in q["inputs"].items():
+        assert isinstance(name, str)
+        assert isinstance(params["scale"], float)
+        assert isinstance(params["zero_point"], int)
+
+    out_name, out_params = next(iter(q["outputs"].items()))
+    assert isinstance(out_name, str)
+    assert isinstance(out_params["scale"], float)
+    assert isinstance(out_params["zero_point"], int)