|
| 1 | +# Copyright 2025 Arm Limited and/or its affiliates. |
| 2 | +# |
| 3 | +# This source code is licensed under the BSD-style license found in the |
| 4 | +# LICENSE file in the root directory of this source tree. |
| 5 | + |
| 6 | + |
| 7 | +from executorch.backends.arm.quantizer import ( |
| 8 | + get_symmetric_quantization_config, |
| 9 | + TOSAQuantizer, |
| 10 | +) |
1 | 11 | from executorch.backends.arm.test import common |
2 | 12 | from executorch.backends.arm.test.tester.arm_tester import ArmTester |
3 | 13 | from executorch.backends.test.suite.flow import TestFlow |
| 14 | +from executorch.backends.xnnpack.test.tester.tester import Quantize |
4 | 15 |
|
5 | 16 |
|
6 | | -def _create_arm_tester_tosa_fp(*args, **kwargs) -> ArmTester: |
7 | | - kwargs["compile_spec"] = common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP") |
| 17 | +def _create_tosa_flow( |
| 18 | + name, |
| 19 | + compile_spec, |
| 20 | + quantize: bool = False, |
| 21 | + symmetric_io_quantization: bool = False, |
| 22 | + per_channel_quantization: bool = True, |
| 23 | +) -> TestFlow: |
8 | 24 |
|
9 | | - return ArmTester( |
10 | | - *args, |
11 | | - **kwargs, |
12 | | - ) |
| 25 | + def _create_arm_tester(*args, **kwargs) -> ArmTester: |
| 26 | + kwargs["compile_spec"] = compile_spec |
13 | 27 |
|
| 28 | + return ArmTester( |
| 29 | + *args, |
| 30 | + **kwargs, |
| 31 | + ) |
| 32 | + |
| 33 | + # Create and configure quantizer to use in the flow |
| 34 | + def create_quantize_stage() -> Quantize: |
| 35 | + quantizer = TOSAQuantizer(compile_spec) |
| 36 | + quantization_config = get_symmetric_quantization_config( |
| 37 | + is_per_channel=per_channel_quantization |
| 38 | + ) |
| 39 | + if symmetric_io_quantization: |
| 40 | + quantizer.set_io(quantization_config) |
| 41 | + return Quantize(quantizer, quantization_config) |
14 | 42 |
|
15 | | -def _create_tosa_flow() -> TestFlow: |
16 | 43 | return TestFlow( |
17 | | - "arm_tosa", |
| 44 | + name, |
18 | 45 | backend="arm", |
19 | | - tester_factory=_create_arm_tester_tosa_fp, |
| 46 | + tester_factory=_create_arm_tester, |
20 | 47 | supports_serialize=False, |
| 48 | + quantize=quantize, |
| 49 | + quantize_stage_factory=create_quantize_stage if quantize else None, |
21 | 50 | ) |
22 | 51 |
|
23 | 52 |
|
24 | | -ARM_TOSA_FLOW = _create_tosa_flow() |
| 53 | +ARM_TOSA_FP_FLOW = _create_tosa_flow( |
| 54 | + "arm_tosa_fp", common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP") |
| 55 | +) |
| 56 | +ARM_TOSA_INT_FLOW = _create_tosa_flow( |
| 57 | + "arm_tosa_int", |
| 58 | + common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+INT"), |
| 59 | + quantize=True, |
| 60 | +) |
| 61 | + |
| 62 | +ARM_ETHOS_U55_FLOW = _create_tosa_flow( |
| 63 | + "arm_ethos_u55", |
| 64 | + common.get_u55_compile_spec(), |
| 65 | + quantize=True, |
| 66 | +) |
| 67 | + |
| 68 | +ARM_ETHOS_U85_FLOW = _create_tosa_flow( |
| 69 | + "arm_ethos_u85", |
| 70 | + common.get_u85_compile_spec(), |
| 71 | + quantize=True, |
| 72 | +) |
0 commit comments