
[Backend Tester] Add quantized test flows for XNNPACK and Core ML #12733


Merged · 30 commits · Jul 23, 2025
Changes from 27 commits
62 changes: 55 additions & 7 deletions backends/apple/coreml/test/tester.py
@@ -4,23 +4,64 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, List, Optional, Tuple
from typing import Any, List, Optional, Sequence, Tuple

import coremltools as ct
Contributor: come on :p

Suggested change
import coremltools as ct
import coremltools

import executorch
import executorch.backends.test.harness.stages as BaseStages

import functools
import torch

from executorch.backends.apple.coreml.compiler import CoreMLBackend
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
from executorch.backends.test.harness import Tester as TesterBase
from executorch.backends.test.harness.stages import StageType
from executorch.exir import EdgeCompileConfig
from executorch.exir.backend.partitioner import Partitioner


def _get_static_int8_qconfig():
return ct.optimize.torch.quantization.LinearQuantizerConfig(
global_config=ct.optimize.torch.quantization.ModuleLinearQuantizerConfig(
quantization_scheme="symmetric",
Contributor @digantdesai (Jul 23, 2025): is this the main int8 schema we should be testing for Linear @metascroy

Member (author): FYI This is pulled directly from our docs at https://docs.pytorch.org/executorch/main/backends-coreml.html#bit-quantization-using-the-pt2e-flow. Would be good to sanity check with Scott, though.

activation_dtype=torch.quint8,
weight_dtype=torch.qint8,
weight_per_channel=True,
)
)


class Quantize(BaseStages.Quantize):
def __init__(
self,
quantizer: Optional[CoreMLQuantizer] = None,
quantization_config: Optional[Any] = None,
calibrate: bool = True,
calibration_samples: Optional[Sequence[Any]] = None,
is_qat: Optional[bool] = False,
):
super().__init__(
quantizer=quantizer or CoreMLQuantizer(quantization_config or _get_static_int8_qconfig()),
calibrate=calibrate,
calibration_samples=calibration_samples,
is_qat=is_qat,
)



class Partition(BaseStages.Partition):
def __init__(self, partitioner: Optional[Partitioner] = None):
def __init__(
self,
partitioner: Optional[Partitioner] = None,
minimum_deployment_target: Optional[Any] = ct.target.iOS15,
):
super().__init__(
partitioner=partitioner or CoreMLPartitioner,
partitioner=partitioner or CoreMLPartitioner(
compile_specs=CoreMLBackend.generate_compile_specs(
minimum_deployment_target=minimum_deployment_target
)
),
)


@@ -29,9 +70,14 @@ def __init__(
self,
partitioners: Optional[List[Partitioner]] = None,
edge_compile_config: Optional[EdgeCompileConfig] = None,
minimum_deployment_target: Optional[Any] = ct.target.iOS15,
):
super().__init__(
default_partitioner_cls=CoreMLPartitioner,
default_partitioner_cls=lambda: CoreMLPartitioner(
compile_specs=CoreMLBackend.generate_compile_specs(
minimum_deployment_target=minimum_deployment_target
)
),
partitioners=partitioners,
edge_compile_config=edge_compile_config,
)
@@ -43,13 +89,15 @@ def __init__(
module: torch.nn.Module,
example_inputs: Tuple[torch.Tensor],
dynamic_shapes: Optional[Tuple[Any]] = None,
minimum_deployment_target: Optional[Any] = ct.target.iOS15,
):
# Specialize for Core ML
stage_classes = (
executorch.backends.test.harness.Tester.default_stage_classes()
| {
StageType.PARTITION: Partition,
StageType.TO_EDGE_TRANSFORM_AND_LOWER: ToEdgeTransformAndLower,
StageType.QUANTIZE: Quantize,
StageType.PARTITION: functools.partial(Partition, minimum_deployment_target=minimum_deployment_target),
StageType.TO_EDGE_TRANSFORM_AND_LOWER: functools.partial(ToEdgeTransformAndLower, minimum_deployment_target=minimum_deployment_target),
}
)

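As context for the review thread above: a minimal sketch of the PT2E flow the linked Core ML doc describes, reusing the same static int8 config as `_get_static_int8_qconfig()`. The toy model, example inputs, and the surrounding export/prepare/convert calls are illustrative assumptions, not code from this PR.

```python
import coremltools as ct
import torch

from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

model = torch.nn.Linear(16, 8).eval()
example_inputs = (torch.randn(1, 16),)

# Same static int8 config as _get_static_int8_qconfig() in the tester above.
qconfig = ct.optimize.torch.quantization.LinearQuantizerConfig(
    global_config=ct.optimize.torch.quantization.ModuleLinearQuantizerConfig(
        quantization_scheme="symmetric",
        activation_dtype=torch.quint8,
        weight_dtype=torch.qint8,
        weight_per_channel=True,
    )
)
quantizer = CoreMLQuantizer(qconfig)

# Standard PT2E steps: export, insert observers, calibrate, convert.
exported = torch.export.export_for_training(model, example_inputs).module()
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # calibration pass over sample inputs
quantized = convert_pt2e(prepared)
```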
3 changes: 2 additions & 1 deletion backends/test/harness/stages/quantize.py
@@ -31,7 +31,8 @@ def __init__(
self.calibrate = calibrate
self.calibration_samples = calibration_samples

self.quantizer.set_global(self.quantization_config)
if self.quantization_config is not None:
self.quantizer.set_global(self.quantization_config)

self.converted_graph = None
self.is_qat = is_qat
6 changes: 3 additions & 3 deletions backends/test/harness/tester.py
@@ -1,6 +1,6 @@
import random
from collections import Counter, OrderedDict
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch

@@ -33,7 +33,7 @@ def __init__(
self,
module: torch.nn.Module,
example_inputs: Tuple[torch.Tensor],
stage_classes: Dict[StageType, Type],
stage_classes: Dict[StageType, Callable],
dynamic_shapes: Optional[Tuple[Any]] = None,
):
module.eval()
@@ -81,7 +81,7 @@ def __init__(
self.stage_output = None

@staticmethod
def default_stage_classes() -> Dict[StageType, Type]:
def default_stage_classes() -> Dict[StageType, Callable]:
"""
Returns a map of StageType to default Stage implementation.
"""
7 changes: 3 additions & 4 deletions backends/test/suite/__init__.py
@@ -129,7 +129,7 @@ def _make_wrapped_test(
def wrapped_test(self):
with TestContext(test_name, flow.name, params):
test_kwargs = params or {}
test_kwargs["tester_factory"] = flow.tester_factory
test_kwargs["flow"] = flow

test_func(self, **test_kwargs)

@@ -175,7 +175,7 @@ def load_tests(loader, suite, pattern):


class OperatorTest(unittest.TestCase):
def _test_op(self, model, inputs, tester_factory):
def _test_op(self, model, inputs, flow: TestFlow):
context = get_active_test_context()

# This should be set in the wrapped test. See _make_wrapped_test above.
@@ -184,9 +184,8 @@ def _test_op(self, model, inputs, tester_factory):
run_summary = run_test(
model,
inputs,
tester_factory,
flow,
context.test_name,
context.flow_name,
context.params,
)

4 changes: 4 additions & 0 deletions backends/test/suite/discovery.py
@@ -68,6 +68,10 @@ def _filter_tests(

def _is_test_enabled(test_case: unittest.TestCase, test_filter: TestFilter) -> bool:
test_method = getattr(test_case, test_case._testMethodName)

if not hasattr(test_method, "_flow"):
print(f"Test missing flow: {test_method}")

flow: TestFlow = test_method._flow

if test_filter.backends is not None and flow.backend not in test_filter.backends:
58 changes: 26 additions & 32 deletions backends/test/suite/flow.py
@@ -1,9 +1,10 @@
import logging

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Callable

from executorch.backends.test.harness import Tester
from executorch.backends.test.harness.stages import Quantize

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -21,42 +22,35 @@ class TestFlow:

backend: str
""" The name of the target backend. """

tester_factory: Callable[[], Tester]
tester_factory: Callable[..., Tester]
""" A factory function that returns a Tester instance for this lowering flow. """

quantize: bool = field(default=False)
""" Whether to tester should run the quantize stage on the model. """

quantize_stage_factory: Callable[..., Quantize] | None = None
""" A factory function which instantiates a Quantize stage. Can be None to use the tester's default. """
Contributor: why an extra flag?

Member (author): The specific reason is that if quantize_stage_factory isn't provided, it will use the default Quantize stage from the tester. I could maybe just always require the caller to provide quantize_stage_factory.
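To make that fallback concrete, a small hypothetical sketch (the helper name is an assumption; Tester, StageType, and the TestFlow fields are from this PR): a runner would prefer the flow's quantize_stage_factory and otherwise construct the tester's default Quantize stage.

```python
from executorch.backends.test.harness import Tester
from executorch.backends.test.harness.stages import StageType


def make_quantize_stage(flow):
    # Hypothetical helper, for illustration only.
    if flow.quantize_stage_factory is not None:
        # Flow-specific stage, e.g. an XNNPACK Quantize stage with a symmetric int8 config.
        return flow.quantize_stage_factory()
    # Otherwise fall back to the default Quantize stage registered by the tester harness.
    return Tester.default_stage_classes()[StageType.QUANTIZE]()
```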


def create_xnnpack_flow() -> TestFlow | None:
def all_flows() -> dict[str, TestFlow]:
flows = []

try:
from executorch.backends.xnnpack.test.tester import Tester as XnnpackTester

return TestFlow(
name="xnnpack",
backend="xnnpack",
tester_factory=XnnpackTester,
)
except Exception:
logger.info("Skipping XNNPACK flow registration due to import failure.")
return None

from executorch.backends.test.suite.flows.xnnpack import XNNPACK_TEST_FLOW, XNNPACK_STATIC_INT8_TEST_FLOW
Contributor: from ... import *?

Member (author): I'm going to clean up flow registration slightly in the stack, so I'll take this as a follow-up there.

flows += [
XNNPACK_TEST_FLOW,
XNNPACK_STATIC_INT8_TEST_FLOW,
]
except Exception as e:
logger.info(f"Skipping XNNPACK flow registration: {e}")

def create_coreml_flow() -> TestFlow | None:
try:
from executorch.backends.apple.coreml.test.tester import CoreMLTester
from executorch.backends.test.suite.flows.coreml import COREML_TEST_FLOW, COREML_STATIC_INT8_TEST_FLOW
flows += [
COREML_TEST_FLOW,
COREML_STATIC_INT8_TEST_FLOW,
]
except Exception as e:
logger.info(f"Skipping Core ML flow registration: {e}")

return TestFlow(
name="coreml",
backend="coreml",
tester_factory=CoreMLTester,
)
except Exception:
logger.info("Skipping Core ML flow registration due to import failure.")
return None


def all_flows() -> dict[str, TestFlow]:
flows = [
create_xnnpack_flow(),
create_coreml_flow(),
]
return {f.name: f for f in flows if f is not None}
7 changes: 7 additions & 0 deletions backends/test/suite/flows/__init__.py
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
24 changes: 24 additions & 0 deletions backends/test/suite/flows/coreml.py
@@ -0,0 +1,24 @@
import coremltools
import functools

from executorch.backends.apple.coreml.test.tester import CoreMLTester
from executorch.backends.test.suite.flow import TestFlow
from typing import Any

def _create_coreml_flow(
name: str,
quantize: bool = False,
minimum_deployment_target: Any = coremltools.target.iOS15
) -> TestFlow:
return TestFlow(
name,
backend="coreml",
tester_factory=functools.partial(CoreMLTester, minimum_deployment_target=minimum_deployment_target),
quantize=quantize,
)

COREML_TEST_FLOW = _create_coreml_flow("coreml")
COREML_STATIC_INT8_TEST_FLOW = _create_coreml_flow(
"coreml_static_int8",
quantize=True,
minimum_deployment_target=coremltools.target.iOS17)
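
For reference, flows defined here surface through all_flows() keyed by name, and registration is skipped when the Core ML imports fail; a quick lookup sketch based on the flow.py change above:

```python
from executorch.backends.test.suite.flow import all_flows

flows = all_flows()
# Present only when coremltools and the Core ML tester import successfully.
coreml_int8 = flows.get("coreml_static_int8")
```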
36 changes: 36 additions & 0 deletions backends/test/suite/flows/xnnpack.py
@@ -0,0 +1,36 @@
from executorch.backends.test.harness.stages import Quantize
from executorch.backends.test.suite.flow import TestFlow
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import get_symmetric_quantization_config
from executorch.backends.xnnpack.test.tester import (
Quantize as XnnpackQuantize,
Tester as XnnpackTester
)
from typing import Callable

import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

def _create_xnnpack_flow_base(name: str, quantize_stage_factory: Callable[..., Quantize] | None = None) -> TestFlow:
return TestFlow(
name,
backend="xnnpack",
tester_factory=XnnpackTester,
quantize=quantize_stage_factory is not None,
quantize_stage_factory=quantize_stage_factory,
)

def _create_xnnpack_flow() -> TestFlow:
return _create_xnnpack_flow_base("xnnpack")

def _create_xnnpack_static_int8_flow() -> TestFlow:
def create_quantize_stage() -> Quantize:
qparams = get_symmetric_quantization_config(is_per_channel=True)
return XnnpackQuantize(
quantization_config=qparams,
)
return _create_xnnpack_flow_base("xnnpack_static_int8", create_quantize_stage)

XNNPACK_TEST_FLOW = _create_xnnpack_flow()
XNNPACK_STATIC_INT8_TEST_FLOW = _create_xnnpack_static_int8_flow()
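
A hedged end-to-end sketch of driving one of these flows through the tester; the stage chain and method names mirror typical XnnpackTester usage and are assumptions here, not part of this diff.

```python
import torch

from executorch.backends.test.suite.flows.xnnpack import XNNPACK_STATIC_INT8_TEST_FLOW


class SmallModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 8)

    def forward(self, x):
        return self.linear(x)


flow = XNNPACK_STATIC_INT8_TEST_FLOW
inputs = (torch.randn(1, 16),)

# Build the backend tester for this flow and run a quantized lowering pipeline.
tester = flow.tester_factory(SmallModel(), inputs)
quantize_stage = flow.quantize_stage_factory() if flow.quantize_stage_factory else None
(
    tester.quantize(quantize_stage)
    .export()
    .to_edge_transform_and_lower()
    .to_executorch()
    .serialize()
    .run_method_and_compare_outputs()
)
```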