|
3 | 3 | # This source code is licensed under the BSD-style license found in the
|
4 | 4 | # LICENSE file in the root directory of this source tree.
|
5 | 5 |
|
| 6 | +# Create flows for Arm Backends used to test operator and model suits |
6 | 7 |
|
7 |
| -from executorch.backends.arm.quantizer import ( |
8 |
| - get_symmetric_quantization_config, |
9 |
| - TOSAQuantizer, |
10 |
| -) |
| 8 | +from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec |
| 9 | +from executorch.backends.arm.quantizer import get_symmetric_quantization_config |
11 | 10 | from executorch.backends.arm.test import common
|
12 | 11 | from executorch.backends.arm.test.tester.arm_tester import ArmTester
|
| 12 | +from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec |
| 13 | +from executorch.backends.arm.util._factory import create_quantizer |
13 | 14 | from executorch.backends.test.suite.flow import TestFlow
|
14 | 15 | from executorch.backends.xnnpack.test.tester.tester import Quantize
|
15 | 16 |
|
16 | 17 |
|
17 |
| -def _create_tosa_flow( |
| 18 | +def _create_arm_flow( |
18 | 19 | name,
|
19 |
| - compile_spec, |
20 |
| - quantize: bool = False, |
| 20 | + compile_spec: ArmCompileSpec, |
21 | 21 | symmetric_io_quantization: bool = False,
|
22 | 22 | per_channel_quantization: bool = True,
|
23 | 23 | ) -> TestFlow:
|
24 | 24 |
|
25 | 25 | def _create_arm_tester(*args, **kwargs) -> ArmTester:
|
26 | 26 | kwargs["compile_spec"] = compile_spec
|
| 27 | + return ArmTester(*args, **kwargs) |
| 28 | + |
| 29 | + support_serialize = not isinstance(compile_spec, TosaCompileSpec) |
| 30 | + quantize = compile_spec.tosa_spec.support_integer() |
27 | 31 |
|
28 |
| - return ArmTester( |
29 |
| - *args, |
30 |
| - **kwargs, |
31 |
| - ) |
| 32 | + if quantize is True: |
32 | 33 |
|
33 |
| - # Create and configure quantizer to use in the flow |
34 |
| - def create_quantize_stage() -> Quantize: |
35 |
| - quantizer = TOSAQuantizer(compile_spec) |
36 |
| - quantization_config = get_symmetric_quantization_config( |
37 |
| - is_per_channel=per_channel_quantization |
38 |
| - ) |
39 |
| - if symmetric_io_quantization: |
40 |
| - quantizer.set_io(quantization_config) |
41 |
| - return Quantize(quantizer, quantization_config) |
| 34 | + def create_quantize_stage() -> Quantize: |
| 35 | + quantizer = create_quantizer(compile_spec) |
| 36 | + quantization_config = get_symmetric_quantization_config( |
| 37 | + is_per_channel=per_channel_quantization |
| 38 | + ) |
| 39 | + if symmetric_io_quantization: |
| 40 | + quantizer.set_io(quantization_config) |
| 41 | + return Quantize(quantizer, quantization_config) |
42 | 42 |
|
43 | 43 | return TestFlow(
|
44 | 44 | name,
|
45 | 45 | backend="arm",
|
46 | 46 | tester_factory=_create_arm_tester,
|
47 |
| - supports_serialize=False, |
| 47 | + supports_serialize=support_serialize, |
48 | 48 | quantize=quantize,
|
49 |
| - quantize_stage_factory=create_quantize_stage if quantize else None, |
| 49 | + quantize_stage_factory=(create_quantize_stage if quantize is True else False), |
50 | 50 | )
|
51 | 51 |
|
52 | 52 |
|
53 |
| -ARM_TOSA_FP_FLOW = _create_tosa_flow( |
54 |
| - "arm_tosa_fp", common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP") |
| 53 | +ARM_TOSA_FP_FLOW = _create_arm_flow( |
| 54 | + "arm_tosa_fp", |
| 55 | + common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"), |
55 | 56 | )
|
56 |
| -ARM_TOSA_INT_FLOW = _create_tosa_flow( |
| 57 | +ARM_TOSA_INT_FLOW = _create_arm_flow( |
57 | 58 | "arm_tosa_int",
|
58 | 59 | common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+INT"),
|
59 |
| - quantize=True, |
60 | 60 | )
|
61 |
| - |
62 |
| -ARM_ETHOS_U55_FLOW = _create_tosa_flow( |
| 61 | +ARM_ETHOS_U55_FLOW = _create_arm_flow( |
63 | 62 | "arm_ethos_u55",
|
64 | 63 | common.get_u55_compile_spec(),
|
65 |
| - quantize=True, |
66 | 64 | )
|
67 |
| - |
68 |
| -ARM_ETHOS_U85_FLOW = _create_tosa_flow( |
| 65 | +ARM_ETHOS_U85_FLOW = _create_arm_flow( |
69 | 66 | "arm_ethos_u85",
|
70 | 67 | common.get_u85_compile_spec(),
|
71 |
| - quantize=True, |
72 | 68 | )
|
0 commit comments