
Commit 9c520ff (parent: eb5b332)

[Backend Tester] Add quantized test flows for XNNPACK and Core ML

ghstack-source-id: 6cc40f2
ghstack-comment-id: 3105090683
Pull-Request: #12733

31 files changed: +443 -338 lines

backends/apple/coreml/test/tester.py

Lines changed: 55 additions & 7 deletions

@@ -4,23 +4,64 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, List, Optional, Tuple
+from typing import Any, List, Optional, Sequence, Tuple
 
+import coremltools as ct
 import executorch
 import executorch.backends.test.harness.stages as BaseStages
-
+import functools
 import torch
+
+from executorch.backends.apple.coreml.compiler import CoreMLBackend
 from executorch.backends.apple.coreml.partition import CoreMLPartitioner
+from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
 from executorch.backends.test.harness import Tester as TesterBase
 from executorch.backends.test.harness.stages import StageType
 from executorch.exir import EdgeCompileConfig
 from executorch.exir.backend.partitioner import Partitioner
 
 
+def _get_static_int8_qconfig():
+    return ct.optimize.torch.quantization.LinearQuantizerConfig(
+        global_config=ct.optimize.torch.quantization.ModuleLinearQuantizerConfig(
+            quantization_scheme="symmetric",
+            activation_dtype=torch.quint8,
+            weight_dtype=torch.qint8,
+            weight_per_channel=True,
+        )
+    )
+
+
+class Quantize(BaseStages.Quantize):
+    def __init__(
+        self,
+        quantizer: Optional[CoreMLQuantizer] = None,
+        quantization_config: Optional[Any] = None,
+        calibrate: bool = True,
+        calibration_samples: Optional[Sequence[Any]] = None,
+        is_qat: Optional[bool] = False,
+    ):
+        super().__init__(
+            quantizer=quantizer or CoreMLQuantizer(quantization_config or _get_static_int8_qconfig()),
+            calibrate=calibrate,
+            calibration_samples=calibration_samples,
+            is_qat=is_qat,
+        )
+
+
 class Partition(BaseStages.Partition):
-    def __init__(self, partitioner: Optional[Partitioner] = None):
+    def __init__(
+        self,
+        partitioner: Optional[Partitioner] = None,
+        minimum_deployment_target: Optional[Any] = ct.target.iOS15,
+    ):
         super().__init__(
-            partitioner=partitioner or CoreMLPartitioner,
+            partitioner=partitioner or CoreMLPartitioner(
+                compile_specs=CoreMLBackend.generate_compile_specs(
+                    minimum_deployment_target=minimum_deployment_target
+                )
+            ),
         )
 
 
@@ -29,9 +70,14 @@ def __init__(
         self,
         partitioners: Optional[List[Partitioner]] = None,
         edge_compile_config: Optional[EdgeCompileConfig] = None,
+        minimum_deployment_target: Optional[Any] = ct.target.iOS15,
    ):
         super().__init__(
-            default_partitioner_cls=CoreMLPartitioner,
+            default_partitioner_cls=lambda: CoreMLPartitioner(
+                compile_specs=CoreMLBackend.generate_compile_specs(
+                    minimum_deployment_target=minimum_deployment_target
+                )
+            ),
             partitioners=partitioners,
             edge_compile_config=edge_compile_config,
         )
@@ -43,13 +89,15 @@ def __init__(
         module: torch.nn.Module,
         example_inputs: Tuple[torch.Tensor],
         dynamic_shapes: Optional[Tuple[Any]] = None,
+        minimum_deployment_target: Optional[Any] = ct.target.iOS15,
     ):
         # Specialize for XNNPACK
         stage_classes = (
             executorch.backends.test.harness.Tester.default_stage_classes()
             | {
-                StageType.PARTITION: Partition,
-                StageType.TO_EDGE_TRANSFORM_AND_LOWER: ToEdgeTransformAndLower,
+                StageType.QUANTIZE: Quantize,
+                StageType.PARTITION: functools.partial(Partition, minimum_deployment_target=minimum_deployment_target),
+                StageType.TO_EDGE_TRANSFORM_AND_LOWER: functools.partial(ToEdgeTransformAndLower, minimum_deployment_target=minimum_deployment_target),
             }
         )

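For orientation, here is a minimal sketch of how the new quantized Core ML flow composes end to end. The SmallLinear module is a stand-in, and the exact chain of stage methods is assumed from the harness Tester API rather than something this commit adds:

import coremltools as ct
import torch

from executorch.backends.apple.coreml.test.tester import CoreMLTester


class SmallLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.linear(x)


# quantize() now runs the Core ML Quantize stage with the default static
# int8 config; the remaining stages lower and execute the quantized graph.
(
    CoreMLTester(
        SmallLinear(),
        (torch.randn(1, 16),),
        minimum_deployment_target=ct.target.iOS17,
    )
    .quantize()
    .export()
    .to_edge_transform_and_lower()
    .to_executorch()
    .serialize()
    .run_method_and_compare_outputs()
)
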
backends/test/harness/stages/quantize.py

Lines changed: 2 additions & 1 deletion

@@ -31,7 +31,8 @@ def __init__(
         self.calibrate = calibrate
         self.calibration_samples = calibration_samples
 
-        self.quantizer.set_global(self.quantization_config)
+        if self.quantization_config is not None:
+            self.quantizer.set_global(self.quantization_config)
 
         self.converted_graph = None
         self.is_qat = is_qat
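
The guard matters for quantizers like CoreMLQuantizer, which take their configuration at construction time and do not implement set_global(). A minimal sketch, with the config value left as a placeholder:

from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
from executorch.backends.test.harness.stages import Quantize

config = ...  # placeholder: a coremltools LinearQuantizerConfig

# The config is baked into the quantizer here, so quantization_config stays
# None and the stage now skips set_global() instead of raising.
stage = Quantize(quantizer=CoreMLQuantizer(config))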

backends/test/harness/tester.py

Lines changed: 3 additions & 3 deletions

@@ -1,6 +1,6 @@
 import random
 from collections import Counter, OrderedDict
-from typing import Any, Dict, List, Optional, Tuple, Type
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import torch
 
@@ -33,7 +33,7 @@ def __init__(
         self,
         module: torch.nn.Module,
         example_inputs: Tuple[torch.Tensor],
-        stage_classes: Dict[StageType, Type],
+        stage_classes: Dict[StageType, Callable],
         dynamic_shapes: Optional[Tuple[Any]] = None,
     ):
         module.eval()
@@ -81,7 +81,7 @@ def __init__(
         self.stage_output = None
 
     @staticmethod
-    def default_stage_classes() -> Dict[StageType, Type]:
+    def default_stage_classes() -> Dict[StageType, Callable]:
         """
         Returns a map of StageType to default Stage implementation.
         """

backends/test/suite/__init__.py

Lines changed: 3 additions & 4 deletions

@@ -129,7 +129,7 @@ def _make_wrapped_test(
     def wrapped_test(self):
         with TestContext(test_name, flow.name, params):
             test_kwargs = params or {}
-            test_kwargs["tester_factory"] = flow.tester_factory
+            test_kwargs["flow"] = flow
 
             test_func(self, **test_kwargs)
 
@@ -175,7 +175,7 @@ def load_tests(loader, suite, pattern):
 
 
 class OperatorTest(unittest.TestCase):
-    def _test_op(self, model, inputs, tester_factory):
+    def _test_op(self, model, inputs, flow: TestFlow):
         context = get_active_test_context()
 
         # This should be set in the wrapped test. See _make_wrapped_test above.
@@ -184,9 +184,8 @@ def _test_op(self, model, inputs, tester_factory):
         run_summary = run_test(
             model,
             inputs,
-            tester_factory,
+            flow,
             context.test_name,
-            context.flow_name,
             context.params,
         )

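With the whole flow threaded through as a kwarg, a generated operator test reduces to roughly the following sketch (the test and its model are hypothetical, not from this commit):

import torch

from executorch.backends.test.suite import OperatorTest
from executorch.backends.test.suite.flow import TestFlow


class AddTests(OperatorTest):
    # The wrapped test injects flow=..., so _test_op can honor flow.quantize
    # and flow.quantize_stage_factory when building the tester pipeline.
    def _test_add(self, flow: TestFlow) -> None:
        model = torch.nn.Linear(4, 4)  # stand-in operator model
        inputs = (torch.randn(1, 4),)
        self._test_op(model, inputs, flow)
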
backends/test/suite/flow.py

Lines changed: 26 additions & 32 deletions

@@ -1,9 +1,10 @@
 import logging
 
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Callable
 
 from executorch.backends.test.harness import Tester
+from executorch.backends.test.harness.stages import Quantize
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
@@ -21,42 +22,35 @@ class TestFlow:
 
     backend: str
     """ The name of the target backend. """
-
-    tester_factory: Callable[[], Tester]
+
+    tester_factory: Callable[..., Tester]
     """ A factory function that returns a Tester instance for this lowering flow. """
 
+    quantize: bool = field(default=False)
+    """ Whether the tester should run the quantize stage on the model. """
+
+    quantize_stage_factory: Callable[..., Quantize] | None = None
+    """ A factory function which instantiates a Quantize stage. Can be None to use the tester's default. """
 
-def create_xnnpack_flow() -> TestFlow | None:
+def all_flows() -> dict[str, TestFlow]:
+    flows = []
+
     try:
-        from executorch.backends.xnnpack.test.tester import Tester as XnnpackTester
-
-        return TestFlow(
-            name="xnnpack",
-            backend="xnnpack",
-            tester_factory=XnnpackTester,
-        )
-    except Exception:
-        logger.info("Skipping XNNPACK flow registration due to import failure.")
-        return None
-
+        from executorch.backends.test.suite.flows.xnnpack import XNNPACK_TEST_FLOW, XNNPACK_STATIC_INT8_TEST_FLOW
+        flows += [
+            XNNPACK_TEST_FLOW,
+            XNNPACK_STATIC_INT8_TEST_FLOW,
+        ]
+    except Exception as e:
+        logger.info(f"Skipping XNNPACK flow registration: {e}")
 
-def create_coreml_flow() -> TestFlow | None:
     try:
-        from executorch.backends.apple.coreml.test.tester import CoreMLTester
+        from executorch.backends.test.suite.flows.coreml import COREML_TEST_FLOW, COREML_STATIC_INT8_TEST_FLOW
+        flows += [
+            COREML_TEST_FLOW,
+            COREML_STATIC_INT8_TEST_FLOW,
+        ]
+    except Exception as e:
+        logger.info(f"Skipping Core ML flow registration: {e}")
 
-        return TestFlow(
-            name="coreml",
-            backend="coreml",
-            tester_factory=CoreMLTester,
-        )
-    except Exception:
-        logger.info("Skipping Core ML flow registration due to import failure.")
-        return None
-
-
-def all_flows() -> dict[str, TestFlow]:
-    flows = [
-        create_xnnpack_flow(),
-        create_coreml_flow(),
-    ]
     return {f.name: f for f in flows if f is not None}
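
Flow registration is now import-guarded per backend, so callers can simply enumerate whatever registered successfully. A quick sketch:

from executorch.backends.test.suite.flow import all_flows

# A missing coremltools install (or any backend import failure) just logs
# and drops those flows rather than failing the whole suite.
for name, flow in all_flows().items():
    print(f"{name}: backend={flow.backend}, quantize={flow.quantize}")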

backends/test/suite/flows/__init__.py

Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe

backends/test/suite/flows/coreml.py

Lines changed: 24 additions & 0 deletions

@@ -0,0 +1,24 @@
+import coremltools
+import functools
+
+from executorch.backends.apple.coreml.test.tester import CoreMLTester
+from executorch.backends.test.suite.flow import TestFlow
+from typing import Any
+
+def _create_coreml_flow(
+    name: str,
+    quantize: bool = False,
+    minimum_deployment_target: Any = coremltools.target.iOS15
+) -> TestFlow:
+    return TestFlow(
+        name,
+        backend="coreml",
+        tester_factory=functools.partial(CoreMLTester, minimum_deployment_target=minimum_deployment_target),
+        quantize=quantize,
+    )
+
+COREML_TEST_FLOW = _create_coreml_flow("coreml")
+COREML_STATIC_INT8_TEST_FLOW = _create_coreml_flow(
+    "coreml_static_int8",
+    quantize=True,
+    minimum_deployment_target=coremltools.target.iOS17)
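
Note that the quantized Core ML flow pins iOS17 while the default flow stays on iOS15. A sketch of picking up the prebound factory (the TinyLinear module is illustrative):

import torch

from executorch.backends.test.suite.flows.coreml import COREML_STATIC_INT8_TEST_FLOW


class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)


# tester_factory is the functools.partial above, so the iOS17 minimum
# deployment target is already bound; only the module and inputs remain.
tester = COREML_STATIC_INT8_TEST_FLOW.tester_factory(TinyLinear(), (torch.randn(1, 8),))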

backends/test/suite/flows/xnnpack.py

Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
+from executorch.backends.test.harness.stages import Quantize
+from executorch.backends.test.suite.flow import TestFlow
+from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import get_symmetric_quantization_config
+from executorch.backends.xnnpack.test.tester import (
+    Quantize as XnnpackQuantize,
+    Tester as XnnpackTester
+)
+from typing import Callable
+
+import logging
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+def _create_xnnpack_flow_base(name: str, quantize_stage_factory: Callable[..., Quantize] | None = None) -> TestFlow:
+    return TestFlow(
+        name,
+        backend="xnnpack",
+        tester_factory=XnnpackTester,
+        quantize=True,
+        quantize_stage_factory=quantize_stage_factory,
+    )
+
+def _create_xnnpack_flow() -> TestFlow:
+    return _create_xnnpack_flow_base("xnnpack")
+
+def _create_xnnpack_static_int8_flow() -> TestFlow:
+    def create_quantize_stage() -> Quantize:
+        qparams = get_symmetric_quantization_config(is_per_channel=True)
+        return XnnpackQuantize(
+            quantization_config=qparams,
+        )
+    return _create_xnnpack_flow_base("xnnpack_static_int8", create_quantize_stage)
+
+XNNPACK_TEST_FLOW = _create_xnnpack_flow()
+XNNPACK_STATIC_INT8_TEST_FLOW = _create_xnnpack_static_int8_flow()
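
A sketch of how a runner might consume quantize_stage_factory; the real wiring lives in the suite's run_test, so treat this as illustrative:

from executorch.backends.test.suite.flows.xnnpack import XNNPACK_STATIC_INT8_TEST_FLOW

flow = XNNPACK_STATIC_INT8_TEST_FLOW

# The static int8 flow carries its own Quantize stage, preconfigured for
# symmetric per-channel quantization; flows with a None factory fall back
# to the tester's default Quantize stage.
quantize_stage = (
    flow.quantize_stage_factory() if flow.quantize_stage_factory else None
)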

backends/test/suite/models/__init__.py

Lines changed: 3 additions & 4 deletions

@@ -49,7 +49,7 @@ def wrapped_test(self):
             "use_dynamic_shapes": use_dynamic_shapes,
         }
         with TestContext(test_name, flow.name, params):
-            test_func(self, dtype, use_dynamic_shapes, flow.tester_factory)
+            test_func(self, flow, dtype, use_dynamic_shapes)
 
     dtype_name = str(dtype)[6:]  # strip "torch."
     test_name = f"{test_func.__name__}_{flow.name}_{dtype_name}"
@@ -104,9 +104,9 @@ def inner_decorator(func: Callable) -> Callable:
 def run_model_test(
     model: torch.nn.Module,
     inputs: tuple[Any],
+    flow: TestFlow,
     dtype: torch.dtype,
     dynamic_shapes: Any | None,
-    tester_factory: Callable[[], Tester],
 ):
     model = model.to(dtype)
     context = get_active_test_context()
@@ -117,9 +117,8 @@ def run_model_test(
     run_summary = run_test(
         model,
         inputs,
-        tester_factory,
+        flow,
         context.test_name,
-        context.flow_name,
         context.params,
         dynamic_shapes=dynamic_shapes,
     )
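
Under the new signature, a model test takes the flow ahead of dtype and forwards it to run_model_test. A hedged sketch (the test class and model are stand-ins; the torchaudio tests below show the real call sites):

import unittest

import torch

from executorch.backends.test.suite.flow import TestFlow
from executorch.backends.test.suite.models import model_test_params, run_model_test


class MyModelTests(unittest.TestCase):
    @model_test_params(dtypes=[torch.float32])
    def test_small_conv(self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool):
        model = torch.nn.Conv2d(3, 8, 3)  # stand-in model
        inputs = (torch.randn(1, 3, 32, 32, dtype=dtype),)
        run_model_test(model, inputs, flow, dtype, None)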

backends/test/suite/models/test_torchaudio.py

Lines changed: 7 additions & 6 deletions

@@ -12,6 +12,7 @@
 import torch
 import torchaudio
 
+from executorch.backends.test.suite.flow import TestFlow
 from executorch.backends.test.suite.models import (
     model_test_cls,
     model_test_params,
@@ -48,7 +49,7 @@ def forward(
 class TorchAudio(unittest.TestCase):
     @model_test_params(dtypes=[torch.float32], supports_dynamic_shapes=False)
     def test_conformer(
-        self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable
+        self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool
     ):
         inner_model = torchaudio.models.Conformer(
             input_dim=80,
@@ -68,11 +69,11 @@ def test_conformer(
             encoder_padding_mask,
         )
 
-        run_model_test(model, inputs, dtype, None, tester_factory)
+        run_model_test(model, inputs, flow, dtype, None)
 
     @model_test_params(dtypes=[torch.float32])
     def test_wav2letter(
-        self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable
+        self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool
     ):
         model = torchaudio.models.Wav2Letter()
         inputs = (torch.randn(1, 1, 1024, dtype=dtype),)
@@ -85,11 +86,11 @@ def test_wav2letter(
             if use_dynamic_shapes
             else None
         )
-        run_model_test(model, inputs, dtype, dynamic_shapes, tester_factory)
+        run_model_test(model, inputs, flow, dtype, dynamic_shapes)
 
     @unittest.skip("This model times out on all backends.")
     def test_wavernn(
-        self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable
+        self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool,
     ):
         model = torchaudio.models.WaveRNN(
             upsample_scales=[5, 5, 8], n_classes=512, hop_length=200
@@ -101,4 +102,4 @@ def test_wavernn(
             torch.randn(1, 1, 128, 64),  # specgram
         )
 
-        run_model_test(model, inputs, dtype, None, tester_factory)
+        run_model_test(model, inputs, flow, dtype, None)
