|
3 | 3 | # This source code is licensed under the BSD-style license found in the |
4 | 4 | # LICENSE file in the root directory of this source tree. |
5 | 5 |
|
| 6 | +# Create flows for Arm Backends used to test operator and model suits |
6 | 7 |
|
7 | | -from executorch.backends.arm.quantizer import ( |
8 | | - get_symmetric_quantization_config, |
9 | | - TOSAQuantizer, |
10 | | -) |
| 8 | +from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec |
| 9 | +from executorch.backends.arm.quantizer import get_symmetric_quantization_config |
11 | 10 | from executorch.backends.arm.test import common |
12 | 11 | from executorch.backends.arm.test.tester.arm_tester import ArmTester |
| 12 | +from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec |
| 13 | +from executorch.backends.arm.util._factory import create_quantizer |
13 | 14 | from executorch.backends.test.suite.flow import TestFlow |
14 | 15 | from executorch.backends.xnnpack.test.tester.tester import Quantize |
15 | 16 |
|
16 | 17 |
|
17 | | -def _create_tosa_flow( |
| 18 | +def _create_arm_flow( |
18 | 19 | name, |
19 | | - compile_spec, |
20 | | - quantize: bool = False, |
| 20 | + compile_spec: ArmCompileSpec, |
21 | 21 | symmetric_io_quantization: bool = False, |
22 | 22 | per_channel_quantization: bool = True, |
23 | 23 | ) -> TestFlow: |
24 | 24 |
|
25 | 25 | def _create_arm_tester(*args, **kwargs) -> ArmTester: |
26 | 26 | kwargs["compile_spec"] = compile_spec |
| 27 | + return ArmTester(*args, **kwargs) |
| 28 | + |
| 29 | + support_serialize = not isinstance(compile_spec, TosaCompileSpec) |
| 30 | + quantize = compile_spec.tosa_spec.support_integer() |
27 | 31 |
|
28 | | - return ArmTester( |
29 | | - *args, |
30 | | - **kwargs, |
31 | | - ) |
| 32 | + if quantize is True: |
32 | 33 |
|
33 | | - # Create and configure quantizer to use in the flow |
34 | | - def create_quantize_stage() -> Quantize: |
35 | | - quantizer = TOSAQuantizer(compile_spec) |
36 | | - quantization_config = get_symmetric_quantization_config( |
37 | | - is_per_channel=per_channel_quantization |
38 | | - ) |
39 | | - if symmetric_io_quantization: |
40 | | - quantizer.set_io(quantization_config) |
41 | | - return Quantize(quantizer, quantization_config) |
| 34 | + def create_quantize_stage() -> Quantize: |
| 35 | + quantizer = create_quantizer(compile_spec) |
| 36 | + quantization_config = get_symmetric_quantization_config( |
| 37 | + is_per_channel=per_channel_quantization |
| 38 | + ) |
| 39 | + if symmetric_io_quantization: |
| 40 | + quantizer.set_io(quantization_config) |
| 41 | + return Quantize(quantizer, quantization_config) |
42 | 42 |
|
43 | 43 | return TestFlow( |
44 | 44 | name, |
45 | 45 | backend="arm", |
46 | 46 | tester_factory=_create_arm_tester, |
47 | | - supports_serialize=False, |
| 47 | + supports_serialize=support_serialize, |
48 | 48 | quantize=quantize, |
49 | | - quantize_stage_factory=create_quantize_stage if quantize else None, |
| 49 | + quantize_stage_factory=(create_quantize_stage if quantize is True else False), |
50 | 50 | ) |
51 | 51 |
|
52 | 52 |
|
53 | | -ARM_TOSA_FP_FLOW = _create_tosa_flow( |
54 | | - "arm_tosa_fp", common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP") |
| 53 | +ARM_TOSA_FP_FLOW = _create_arm_flow( |
| 54 | + "arm_tosa_fp", |
| 55 | + common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"), |
55 | 56 | ) |
56 | | -ARM_TOSA_INT_FLOW = _create_tosa_flow( |
| 57 | +ARM_TOSA_INT_FLOW = _create_arm_flow( |
57 | 58 | "arm_tosa_int", |
58 | 59 | common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+INT"), |
59 | | - quantize=True, |
60 | 60 | ) |
61 | | - |
62 | | -ARM_ETHOS_U55_FLOW = _create_tosa_flow( |
| 61 | +ARM_ETHOS_U55_FLOW = _create_arm_flow( |
63 | 62 | "arm_ethos_u55", |
64 | 63 | common.get_u55_compile_spec(), |
65 | | - quantize=True, |
66 | 64 | ) |
67 | | - |
68 | | -ARM_ETHOS_U85_FLOW = _create_tosa_flow( |
| 65 | +ARM_ETHOS_U85_FLOW = _create_arm_flow( |
69 | 66 | "arm_ethos_u85", |
70 | 67 | common.get_u85_compile_spec(), |
71 | | - quantize=True, |
72 | 68 | ) |
0 commit comments