32 changes: 29 additions & 3 deletions backends/arm/quantizer/arm_quantizer.py
@@ -14,7 +14,7 @@
 from __future__ import annotations

 import functools
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union

 import torch
 from executorch.backends.arm._passes import ArmPassManager
@@ -218,9 +218,35 @@ def not_module_type_or_name_filter(n: Node) -> bool:

 class TOSAQuantizer(Quantizer):

-    def __init__(self, tosa_spec: TosaSpecification) -> None:
+    def __init__(
+        self, compile_spec_or_tosa_spec: Union[TosaSpecification, List[CompileSpec]]
Contributor review comment:
Nit: do you want to split this into two constructors instead of overloading?

def __init__(self, tosa_spec):
[...]

and

@classmethod
def from_compilespec(cls, compilespec):
[...]
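A minimal sketch of what this suggested split could look like, reusing the parsing logic from the diff itself; the import paths for CompileSpec and the Quantizer base class are assumptions, and the remaining quantizer attributes are omitted:

from typing import List, Optional

from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.backend.compile_spec_schema import CompileSpec  # assumed import path
from torchao.quantization.pt2e.quantizer import Quantizer  # assumed base-class path


class TOSAQuantizer(Quantizer):
    def __init__(self, tosa_spec: TosaSpecification) -> None:
        super().__init__()
        self.tosa_spec = tosa_spec
        self.compile_spec: Optional[List[CompileSpec]] = None
        # global_config, io_config, module_type_config, ... omitted for brevity.

    @classmethod
    def from_compilespec(cls, compile_spec: List[CompileSpec]) -> "TOSAQuantizer":
        # Locate the 'tosa_spec' entry and delegate to the plain constructor,
        # mirroring the parsing done in the overloaded __init__ below.
        for cs in compile_spec:
            if cs.key == "tosa_spec":
                spec_val = (
                    cs.value.decode() if isinstance(cs.value, bytes) else cs.value
                )
                quantizer = cls(TosaSpecification.create_from_string(spec_val))
                quantizer.compile_spec = compile_spec
                return quantizer
        raise ValueError("compile_spec list did not contain a 'tosa_spec' entry")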

+    ) -> None:
+
         super().__init__()
-        self.tosa_spec = tosa_spec
+        if isinstance(compile_spec_or_tosa_spec, TosaSpecification):
+            self.tosa_spec = compile_spec_or_tosa_spec
+            self.compile_spec = None
+        elif isinstance(compile_spec_or_tosa_spec, list):
+            self.compile_spec = compile_spec_or_tosa_spec
+            # find entry that is 'tosa_spec'
+            for cs in compile_spec_or_tosa_spec:
+                if cs.key == "tosa_spec":
+                    spec_val = (
+                        cs.value.decode() if isinstance(cs.value, bytes) else cs.value
+                    )
+                    self.tosa_spec = TosaSpecification.create_from_string(spec_val)
+                    break
+            else:
+                raise ValueError(
+                    "compile_spec list did not contain a 'tosa_spec' entry"
+                )
+        else:
+            raise TypeError(
+                f"TOSAQuantizer constructor expects "
+                f"a TosaSpecification or compile_spec list, "
+                f"got {type(compile_spec_or_tosa_spec)}"
+            )
+
         self.global_config: Optional[QuantizationConfig] = None
         self.io_config: Optional[QuantizationConfig] = None
         self.module_type_config: Dict[Callable, Optional[QuantizationConfig]] = {}
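With this change, TOSAQuantizer can be constructed either from a TosaSpecification directly or from the compile-spec list produced by ArmCompileSpecBuilder. A short usage sketch, using the same builder call as the test below:

from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
from executorch.backends.arm.quantizer.arm_quantizer import TOSAQuantizer
from executorch.backends.arm.tosa_specification import TosaSpecification

tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT")

# Existing path: construct directly from a TosaSpecification.
quantizer_from_spec = TOSAQuantizer(tosa_spec)

# New path: construct from the compile-spec list produced by ArmCompileSpecBuilder;
# __init__ pulls the 'tosa_spec' entry out of that list.
compile_spec = ArmCompileSpecBuilder().tosa_compile_spec(tosa_spec=tosa_spec).build()
quantizer_from_compile_spec = TOSAQuantizer(compile_spec)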
96 changes: 96 additions & 0 deletions backends/arm/test/misc/test_extract_io_params_tosa.py
@@ -0,0 +1,96 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy

import pytest
import torch
from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
from executorch.backends.arm.quantizer import VgfQuantizer
from executorch.backends.arm.quantizer.arm_quantizer import (
    get_symmetric_quantization_config,
    TOSAQuantizer,
)

from executorch.backends.arm.test.common import SkipIfNoModelConverter
from executorch.backends.arm.tosa_partitioner import TOSAPartitioner
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.arm.vgf_partitioner import VgfPartitioner
from executorch.exir import to_edge_transform_and_lower
from executorch.exir.passes.quantize_io_pass import extract_io_quant_params
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


class SimpleAdd(torch.nn.Module):
    def forward(self, x, y):
        return x + y


@pytest.mark.parametrize(
"builder_method, quantizer_cls, partitioner_cls",
[
("tosa_compile_spec", TOSAQuantizer, TOSAPartitioner),
pytest.param(
"vgf_compile_spec",
VgfQuantizer,
VgfPartitioner,
marks=SkipIfNoModelConverter,
id="VGF",
),
],
)
def test_roundtrip_extracts_io_params(builder_method, quantizer_cls, partitioner_cls):
"""
Validates that IO quantization parameters round-trip for both flows.
"""
example_inputs = (
torch.ones(1, 5),
torch.full((1, 5), 2.0),
)
mod = SimpleAdd().eval()

base_spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
compile_spec = getattr(ArmCompileSpecBuilder(), builder_method)(
tosa_spec=base_spec
).build()

quantizer = quantizer_cls(compile_spec)
operator_config = get_symmetric_quantization_config(is_qat=True)
quantizer.set_global(operator_config)

exported = torch.export.export_for_training(
mod, copy.deepcopy(example_inputs), strict=True
)
prepared = prepare_pt2e(exported.module(), quantizer)
_ = prepared(*example_inputs)

converted = convert_pt2e(prepared)
final_export = torch.export.export_for_training(
converted, example_inputs, strict=True
)
partitioner = partitioner_cls(compile_spec)
edge_prog = to_edge_transform_and_lower(final_export, partitioner=[partitioner])

# Extract IO quantization parameters
q = extract_io_quant_params(
edge_prog,
input_idxs=(0, 1),
output_idxs=(0,),
)

assert "inputs" in q
assert "outputs" in q
assert len(q["inputs"]) == 2
assert len(q["outputs"]) == 1

for name, params in q["inputs"].items():
assert isinstance(name, str)
assert isinstance(params["scale"], float)
assert isinstance(params["zero_point"], int)

out_name, out_params = next(iter(q["outputs"].items()))
assert isinstance(out_name, str)
assert isinstance(out_params["scale"], float)
assert isinstance(out_params["zero_point"], int)
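extract_io_quant_params returns one scale/zero_point pair per selected input and output. A minimal sketch of how a caller might apply those values when passing raw float tensors to the quantized model; the helper names are illustrative and not part of this PR, and int8 with the usual [-128, 127] range is assumed:

import torch


def quantize_input(x: torch.Tensor, scale: float, zero_point: int) -> torch.Tensor:
    # Affine quantization: q = round(x / scale) + zero_point, clamped to int8 range.
    q = torch.round(x / scale) + zero_point
    return q.clamp(-128, 127).to(torch.int8)


def dequantize_output(q: torch.Tensor, scale: float, zero_point: int) -> torch.Tensor:
    # Inverse mapping back to float: x = (q - zero_point) * scale.
    return (q.to(torch.float32) - zero_point) * scale


# Example with the params extracted in the test above:
#   in_params = next(iter(q["inputs"].values()))
#   x_q = quantize_input(torch.ones(1, 5), in_params["scale"], in_params["zero_point"])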