| 
 | 1 | +# Copyright 2025 Arm Limited and/or its affiliates.  | 
 | 2 | +#  | 
 | 3 | +# This source code is licensed under the BSD-style license found in the  | 
 | 4 | +# LICENSE file in the root directory of this source tree.  | 
 | 5 | + | 
 | 6 | + | 
 | 7 | +from executorch.backends.arm.quantizer import (  | 
 | 8 | +    get_symmetric_quantization_config,  | 
 | 9 | +    TOSAQuantizer,  | 
 | 10 | +)  | 
1 | 11 | from executorch.backends.arm.test import common  | 
2 | 12 | from executorch.backends.arm.test.tester.arm_tester import ArmTester  | 
3 | 13 | from executorch.backends.test.suite.flow import TestFlow  | 
 | 14 | +from executorch.backends.xnnpack.test.tester.tester import Quantize  | 
4 | 15 | 
 
  | 
5 | 16 | 
 
  | 
6 |  | -def _create_arm_tester_tosa_fp(*args, **kwargs) -> ArmTester:  | 
7 |  | -    kwargs["compile_spec"] = common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP")  | 
 | 17 | +def _create_tosa_flow(  | 
 | 18 | +    name,  | 
 | 19 | +    compile_spec,  | 
 | 20 | +    quantize: bool = False,  | 
 | 21 | +    symmetric_io_quantization: bool = False,  | 
 | 22 | +    per_channel_quantization: bool = True,  | 
 | 23 | +) -> TestFlow:  | 
8 | 24 | 
 
  | 
9 |  | -    return ArmTester(  | 
10 |  | -        *args,  | 
11 |  | -        **kwargs,  | 
12 |  | -    )  | 
 | 25 | +    def _create_arm_tester(*args, **kwargs) -> ArmTester:  | 
 | 26 | +        kwargs["compile_spec"] = compile_spec  | 
 | 27 | + | 
 | 28 | +        return ArmTester(  | 
 | 29 | +            *args,  | 
 | 30 | +            **kwargs,  | 
 | 31 | +        )  | 
13 | 32 | 
 
  | 
 | 33 | +    # Create and configure quantizer to use in the flow  | 
 | 34 | +    def create_quantize_stage() -> Quantize:  | 
 | 35 | +        quantizer = TOSAQuantizer(compile_spec)  | 
 | 36 | +        quantization_config = get_symmetric_quantization_config(  | 
 | 37 | +            is_per_channel=per_channel_quantization  | 
 | 38 | +        )  | 
 | 39 | +        if symmetric_io_quantization:  | 
 | 40 | +            quantizer.set_io(quantization_config)  | 
 | 41 | +        return Quantize(quantizer, quantization_config)  | 
14 | 42 | 
 
  | 
15 |  | -def _create_tosa_flow() -> TestFlow:  | 
16 | 43 |     return TestFlow(  | 
17 |  | -        "arm_tosa",  | 
 | 44 | +        name,  | 
18 | 45 |         backend="arm",  | 
19 |  | -        tester_factory=_create_arm_tester_tosa_fp,  | 
 | 46 | +        tester_factory=_create_arm_tester,  | 
20 | 47 |         supports_serialize=False,  | 
 | 48 | +        quantize=quantize,  | 
 | 49 | +        quantize_stage_factory=create_quantize_stage if quantize else None,  | 
21 | 50 |     )  | 
22 | 51 | 
 
  | 
23 | 52 | 
 
  | 
24 |  | -ARM_TOSA_FLOW = _create_tosa_flow()  | 
 | 53 | +ARM_TOSA_FP_FLOW = _create_tosa_flow(  | 
 | 54 | +    "arm_tosa_fp", common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP")  | 
 | 55 | +)  | 
 | 56 | +ARM_TOSA_INT_FLOW = _create_tosa_flow(  | 
 | 57 | +    "arm_tosa_int",  | 
 | 58 | +    common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+INT"),  | 
 | 59 | +    quantize=True,  | 
 | 60 | +)  | 
0 commit comments