62 changes: 55 additions & 7 deletions backends/apple/coreml/test/tester.py
@@ -4,23 +4,64 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, List, Optional, Tuple
from typing import Any, List, Optional, Sequence, Tuple

import coremltools as ct
import executorch
import executorch.backends.test.harness.stages as BaseStages

import functools
import torch

from executorch.backends.apple.coreml.compiler import CoreMLBackend
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
from executorch.backends.test.harness import Tester as TesterBase
from executorch.backends.test.harness.stages import StageType
from executorch.exir import EdgeCompileConfig
from executorch.exir.backend.partitioner import Partitioner


def _get_static_int8_qconfig():
return ct.optimize.torch.quantization.LinearQuantizerConfig(
global_config=ct.optimize.torch.quantization.ModuleLinearQuantizerConfig(
quantization_scheme="symmetric",
activation_dtype=torch.quint8,
weight_dtype=torch.qint8,
weight_per_channel=True,
)
)


class Quantize(BaseStages.Quantize):
def __init__(
self,
quantizer: Optional[CoreMLQuantizer] = None,
quantization_config: Optional[Any] = None,
calibrate: bool = True,
calibration_samples: Optional[Sequence[Any]] = None,
is_qat: Optional[bool] = False,
):
super().__init__(
quantizer=quantizer or CoreMLQuantizer(quantization_config or _get_static_int8_qconfig()),
calibrate=calibrate,
calibration_samples=calibration_samples,
is_qat=is_qat,
)



class Partition(BaseStages.Partition):
def __init__(self, partitioner: Optional[Partitioner] = None):
def __init__(
self,
partitioner: Optional[Partitioner] = None,
minimum_deployment_target: Optional[Any] = ct.target.iOS15,
):
super().__init__(
partitioner=partitioner or CoreMLPartitioner,
partitioner=partitioner or CoreMLPartitioner(
compile_specs=CoreMLBackend.generate_compile_specs(
minimum_deployment_target=minimum_deployment_target
)
),
)


@@ -29,9 +70,14 @@ def __init__(
self,
partitioners: Optional[List[Partitioner]] = None,
edge_compile_config: Optional[EdgeCompileConfig] = None,
minimum_deployment_target: Optional[Any] = ct.target.iOS15,
):
super().__init__(
default_partitioner_cls=CoreMLPartitioner,
default_partitioner_cls=lambda: CoreMLPartitioner(
compile_specs=CoreMLBackend.generate_compile_specs(
minimum_deployment_target=minimum_deployment_target
)
),
partitioners=partitioners,
edge_compile_config=edge_compile_config,
)
@@ -43,13 +89,15 @@ def __init__(
module: torch.nn.Module,
example_inputs: Tuple[torch.Tensor],
dynamic_shapes: Optional[Tuple[Any]] = None,
minimum_deployment_target: Optional[Any] = ct.target.iOS15,
):
        # Specialize for Core ML
stage_classes = (
executorch.backends.test.harness.Tester.default_stage_classes()
| {
StageType.PARTITION: Partition,
StageType.TO_EDGE_TRANSFORM_AND_LOWER: ToEdgeTransformAndLower,
StageType.QUANTIZE: Quantize,
StageType.PARTITION: functools.partial(Partition, minimum_deployment_target=minimum_deployment_target),
StageType.TO_EDGE_TRANSFORM_AND_LOWER: functools.partial(ToEdgeTransformAndLower, minimum_deployment_target=minimum_deployment_target),
}
)

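With this change the deployment target is set once on the tester and forwarded to both the partition and lowering stages through the CoreML compile specs. A minimal usage sketch, assuming the class defined in this file is named CoreMLTester and that the module and inputs below are stand-ins:

import coremltools as ct
import torch

from executorch.backends.apple.coreml.test.tester import CoreMLTester


class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1


# The target flows into CoreMLPartitioner's compile specs in both the
# Partition and ToEdgeTransformAndLower stages via functools.partial.
tester = CoreMLTester(
    AddOne().eval(),
    (torch.randn(4),),
    minimum_deployment_target=ct.target.iOS17,
)
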
3 changes: 2 additions & 1 deletion backends/test/harness/stages/quantize.py
@@ -31,7 +31,8 @@ def __init__(
self.calibrate = calibrate
self.calibration_samples = calibration_samples

self.quantizer.set_global(self.quantization_config)
if self.quantization_config is not None:
self.quantizer.set_global(self.quantization_config)

self.converted_graph = None
self.is_qat = is_qat
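The None check matters for quantizers that carry their configuration at construction time, as the CoreML Quantize stage above does. A small sketch of that path, assuming the base stage is importable as shown and accepts a quantizer without a separate config:

import coremltools as ct
import torch

from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
from executorch.backends.test.harness.stages import Quantize

# The quantizer already owns its LinearQuantizerConfig, so no separate
# quantization_config is passed to the stage...
qconfig = ct.optimize.torch.quantization.LinearQuantizerConfig(
    global_config=ct.optimize.torch.quantization.ModuleLinearQuantizerConfig(
        quantization_scheme="symmetric",
        activation_dtype=torch.quint8,
        weight_dtype=torch.qint8,
        weight_per_channel=True,
    )
)

# ...and with quantization_config left as None, set_global() is now
# skipped instead of being called with None.
stage = Quantize(quantizer=CoreMLQuantizer(qconfig))
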
6 changes: 3 additions & 3 deletions backends/test/harness/tester.py
@@ -1,6 +1,6 @@
import random
from collections import Counter, OrderedDict
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch

@@ -33,7 +33,7 @@ def __init__(
self,
module: torch.nn.Module,
example_inputs: Tuple[torch.Tensor],
stage_classes: Dict[StageType, Type],
stage_classes: Dict[StageType, Callable],
dynamic_shapes: Optional[Tuple[Any]] = None,
):
module.eval()
@@ -81,7 +81,7 @@ def __init__(
self.stage_output = None

@staticmethod
def default_stage_classes() -> Dict[StageType, Type]:
def default_stage_classes() -> Dict[StageType, Callable]:
"""
Returns a map of StageType to default Stage implementation.
"""
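Widening the stage map values from Type to Callable is what allows the CoreML tester above to register functools.partial factories. A sketch of the pattern, assuming the Partition stage from the CoreML tester is importable as shown:

import functools

import coremltools as ct

from executorch.backends.apple.coreml.test.tester import Partition
from executorch.backends.test.harness import Tester
from executorch.backends.test.harness.stages import StageType

# Any zero-argument callable that returns a stage instance is accepted,
# so defaults can be specialized without defining a subclass per variant.
stage_classes = Tester.default_stage_classes() | {
    StageType.PARTITION: functools.partial(
        Partition, minimum_deployment_target=ct.target.iOS16
    ),
}

The resulting dict is then passed as the stage_classes argument of the Tester constructor, much as the CoreML tester above does.
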
140 changes: 1 addition & 139 deletions backends/test/suite/__init__.py
@@ -9,18 +9,11 @@

import logging
import os
import unittest

from enum import Enum
from typing import Callable

import executorch.backends.test.suite.flow

import torch
from executorch.backends.test.suite.context import get_active_test_context, TestContext
from executorch.backends.test.suite.flow import TestFlow
from executorch.backends.test.suite.reporting import log_test_summary
from executorch.backends.test.suite.runner import run_test, runner_main
from executorch.backends.test.suite.runner import runner_main

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -62,109 +55,6 @@ def get_test_flows() -> dict[str, TestFlow]:
return _ALL_TEST_FLOWS


DTYPES = [
# torch.int8,
# torch.uint8,
# torch.int16,
# torch.uint16,
# torch.int32,
# torch.uint32,
# torch.int64,
# torch.uint64,
# torch.float16,
torch.float32,
# torch.float64,
]

FLOAT_DTYPES = [
torch.float16,
torch.float32,
torch.float64,
]


# The type of test function. This controls the test generation and expected signature.
# Standard tests are run, as is. Dtype tests get a variant generated for each dtype and
# take an additional dtype parameter.
class TestType(Enum):
STANDARD = 1
DTYPE = 2


# Function annotation for dtype tests. This instructs the test framework to run the test
# for each supported dtype and to pass dtype as a test parameter.
def dtype_test(func):
func.test_type = TestType.DTYPE
return func


# Class annotation for operator tests. This triggers the test framework to register
# the tests.
def operator_test(cls):
_create_tests(cls)
return cls


# Generate test cases for each backend flow.
def _create_tests(cls):
for key in dir(cls):
if key.startswith("test_"):
_expand_test(cls, key)


# Expand a test into variants for each registered flow.
def _expand_test(cls, test_name: str):
test_func = getattr(cls, test_name)
for flow in get_test_flows().values():
_create_test_for_backend(cls, test_func, flow)
delattr(cls, test_name)


def _make_wrapped_test(
test_func: Callable,
test_name: str,
flow: TestFlow,
params: dict | None = None,
):
def wrapped_test(self):
with TestContext(test_name, flow.name, params):
test_kwargs = params or {}
test_kwargs["tester_factory"] = flow.tester_factory

test_func(self, **test_kwargs)

wrapped_test._name = test_name
wrapped_test._flow = flow

return wrapped_test


def _create_test_for_backend(
cls,
test_func: Callable,
flow: TestFlow,
):
test_type = getattr(test_func, "test_type", TestType.STANDARD)

if test_type == TestType.STANDARD:
wrapped_test = _make_wrapped_test(test_func, test_func.__name__, flow)
test_name = f"{test_func.__name__}_{flow.name}"
setattr(cls, test_name, wrapped_test)
elif test_type == TestType.DTYPE:
for dtype in DTYPES:
wrapped_test = _make_wrapped_test(
test_func,
test_func.__name__,
flow,
{"dtype": dtype},
)
dtype_name = str(dtype)[6:] # strip "torch."
test_name = f"{test_func.__name__}_{dtype_name}_{flow.name}"
setattr(cls, test_name, wrapped_test)
else:
raise NotImplementedError(f"Unknown test type {test_type}.")


def load_tests(loader, suite, pattern):
package_dir = os.path.dirname(__file__)
discovered_suite = loader.discover(
@@ -174,33 +64,5 @@ def load_tests(loader, suite, pattern):
return suite


class OperatorTest(unittest.TestCase):
def _test_op(self, model, inputs, tester_factory):
context = get_active_test_context()

# This should be set in the wrapped test. See _make_wrapped_test above.
assert context is not None, "Missing test context."

run_summary = run_test(
model,
inputs,
tester_factory,
context.test_name,
context.flow_name,
context.params,
)

log_test_summary(run_summary)

if not run_summary.result.is_success():
if run_summary.result.is_backend_failure():
raise RuntimeError("Test failure.") from run_summary.error
else:
# Non-backend failure indicates a bad test. Mark as skipped.
raise unittest.SkipTest(
f"Test failed for reasons other than backend failure. Error: {run_summary.error}"
)


if __name__ == "__main__":
runner_main()
51 changes: 40 additions & 11 deletions backends/test/suite/discovery.py
@@ -9,7 +9,9 @@
import os
import unittest

from dataclasses import dataclass
from types import ModuleType
from typing import Pattern

from executorch.backends.test.suite.flow import TestFlow

@@ -18,8 +20,19 @@
#


@dataclass
class TestFilter:
"""A set of filters for test discovery."""

backends: set[str] | None
""" The set of backends to include. If None, all backends are included. """

name_regex: Pattern[str] | None
""" A regular expression to filter test names. If None, all tests are included. """


def discover_tests(
root_module: ModuleType, backends: set[str] | None
root_module: ModuleType, test_filter: TestFilter
) -> unittest.TestSuite:
# Collect all tests using the unittest discovery mechanism then filter down.

@@ -32,32 +45,48 @@ def discover_tests(
module_dir = os.path.dirname(module_file)
suite = loader.discover(module_dir)

return _filter_tests(suite, backends)
return _filter_tests(suite, test_filter)


def _filter_tests(
suite: unittest.TestSuite, backends: set[str] | None
suite: unittest.TestSuite, test_filter: TestFilter
) -> unittest.TestSuite:
# Recursively traverse the test suite and add them to the filtered set.
filtered_suite = unittest.TestSuite()

for child in suite:
if isinstance(child, unittest.TestSuite):
filtered_suite.addTest(_filter_tests(child, backends))
filtered_suite.addTest(_filter_tests(child, test_filter))
elif isinstance(child, unittest.TestCase):
if _is_test_enabled(child, backends):
if _is_test_enabled(child, test_filter):
filtered_suite.addTest(child)
else:
raise RuntimeError(f"Unexpected test type: {type(child)}")

return filtered_suite


def _is_test_enabled(test_case: unittest.TestCase, backends: set[str] | None) -> bool:
def _is_test_enabled(test_case: unittest.TestCase, test_filter: TestFilter) -> bool:
test_method = getattr(test_case, test_case._testMethodName)

if backends is not None:
flow: TestFlow = test_method._flow
return flow.backend in backends
else:

# Handle import / discovery failures - leave them enabled to report nicely at the
# top level. There might be a better way to do this. Internally, unittest seems to
# replace it with a stub method to report the failure.
if "testFailure" in str(test_method):
print(f"Warning: Test {test_case._testMethodName} failed to import.")
return True

if not hasattr(test_method, "_flow"):
raise RuntimeError(f"Test missing flow: {test_case._testMethodName} {test_method}")

flow: TestFlow = test_method._flow

if test_filter.backends is not None and flow.backend not in test_filter.backends:
return False

if test_filter.name_regex is not None and not test_filter.name_regex.search(
test_case.id()
):
return False

return True
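
The TestFilter dataclass bundles both filter dimensions into a single argument for discover_tests. A minimal sketch of how a runner might build and use one, assuming the test suite package is the discovery root and that a flow's backend is registered under the name "coreml":

import re
import unittest

import executorch.backends.test.suite as test_suite

from executorch.backends.test.suite.discovery import discover_tests, TestFilter

# Keep only CoreML-backend tests whose ids mention "add"; pass None for a
# field to leave that dimension unfiltered.
test_filter = TestFilter(backends={"coreml"}, name_regex=re.compile("add"))

suite = discover_tests(test_suite, test_filter)
unittest.TextTestRunner().run(suite)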