diff --git a/backends/apple/coreml/TARGETS b/backends/apple/coreml/TARGETS index 188d2b63b53..6993b699427 100644 --- a/backends/apple/coreml/TARGETS +++ b/backends/apple/coreml/TARGETS @@ -60,6 +60,26 @@ runtime.python_library( ], ) +runtime.python_library( + name = "recipes", + srcs = glob([ + "recipes/*.py", + ]), + visibility = [ + "@EXECUTORCH_CLIENTS", + ], + deps = [ + "fbsource//third-party/pypi/coremltools:coremltools", + ":backend", + "//caffe2:torch", + "//executorch/exir:lib", + "//executorch/exir/backend:compile_spec_schema", + "//executorch/exir/backend:partitioner", + "//executorch/exir/backend:utils", + "//executorch/export:lib", + ], +) + runtime.cxx_python_extension( name = "executorchcoreml", srcs = [ @@ -103,6 +123,7 @@ runtime.python_test( "fbsource//third-party/pypi/pytest:pytest", ":partitioner", ":quantizer", + ":recipes", "//caffe2:torch", "//pytorch/vision:torchvision", ], diff --git a/backends/apple/coreml/recipes/__init__.py b/backends/apple/coreml/recipes/__init__.py new file mode 100644 index 00000000000..8bcd1c254a8 --- /dev/null +++ b/backends/apple/coreml/recipes/__init__.py @@ -0,0 +1,17 @@ +# Copyright © 2025 Apple Inc. All rights reserved. +# +# Please refer to the license found in the LICENSE file in the root directory of the source tree. + + +from executorch.export import recipe_registry + +from .coreml_recipe_provider import CoreMLRecipeProvider +from .coreml_recipe_types import CoreMLRecipeType + +# Auto-register CoreML backend recipe provider +recipe_registry.register_backend_recipe_provider(CoreMLRecipeProvider()) + +__all__ = [ + "CoreMLRecipeProvider", + "CoreMLRecipeType", +] diff --git a/backends/apple/coreml/recipes/coreml_recipe_provider.py b/backends/apple/coreml/recipes/coreml_recipe_provider.py new file mode 100644 index 00000000000..75c937027bb --- /dev/null +++ b/backends/apple/coreml/recipes/coreml_recipe_provider.py @@ -0,0 +1,132 @@ +# Copyright © 2025 Apple Inc. All rights reserved. 
+# +# Please refer to the license found in the LICENSE file in the root directory of the source tree. + + +from typing import Any, Optional, Sequence + +import coremltools as ct + +from executorch.backends.apple.coreml.compiler import CoreMLBackend +from executorch.backends.apple.coreml.partition.coreml_partitioner import ( + CoreMLPartitioner, +) +from executorch.backends.apple.coreml.recipes.coreml_recipe_types import ( + COREML_BACKEND, + CoreMLRecipeType, +) + +from executorch.exir import EdgeCompileConfig +from executorch.export import ( + BackendRecipeProvider, + ExportRecipe, + LoweringRecipe, + RecipeType, +) + + +class CoreMLRecipeProvider(BackendRecipeProvider): + @property + def backend_name(self) -> str: + return COREML_BACKEND + + def get_supported_recipes(self) -> Sequence[RecipeType]: + return list(CoreMLRecipeType) + + def create_recipe( + self, recipe_type: RecipeType, **kwargs: Any + ) -> Optional[ExportRecipe]: + """Create CoreML recipe with precision and compute unit combinations""" + + if recipe_type not in self.get_supported_recipes(): + return None + + if ct is None: + raise ImportError( + "coremltools is required for CoreML recipes. 
" + "Install it with: pip install coremltools" + ) + + # Validate kwargs + self._validate_recipe_kwargs(recipe_type, **kwargs) + + # Parse recipe type to get precision and compute unit + precision = None + if recipe_type == CoreMLRecipeType.FP32: + precision = ct.precision.FLOAT32 + elif recipe_type == CoreMLRecipeType.FP16: + precision = ct.precision.FLOAT16 + + if precision is None: + raise ValueError(f"Unknown precision for recipe: {recipe_type.value}") + + return self._build_recipe(recipe_type, precision, **kwargs) + + def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None: + if not kwargs: + return + expected_keys = {"minimum_deployment_target", "compute_unit"} + unexpected = set(kwargs.keys()) - expected_keys + if unexpected: + raise ValueError( + f"CoreML Recipes only accept 'minimum_deployment_target' or 'compute_unit' as parameter. " + f"Unexpected parameters: {list(unexpected)}" + ) + if "minimum_deployment_target" in kwargs: + minimum_deployment_target = kwargs["minimum_deployment_target"] + if not isinstance(minimum_deployment_target, ct.target): + raise ValueError( + f"Parameter 'minimum_deployment_target' must be an enum of type ct.target, got {type(minimum_deployment_target)}" + ) + if "compute_unit" in kwargs: + compute_unit = kwargs["compute_unit"] + if not isinstance(compute_unit, ct.ComputeUnit): + raise ValueError( + f"Parameter 'compute_unit' must be an enum of type ct.ComputeUnit, got {type(compute_unit)}" + ) + + def _build_recipe( + self, + recipe_type: RecipeType, + precision: ct.precision, + **kwargs: Any, + ) -> ExportRecipe: + lowering_recipe = self._get_coreml_lowering_recipe( + compute_precision=precision, + **kwargs, + ) + + return ExportRecipe( + name=recipe_type.value, + quantization_recipe=None, # TODO - add quantization recipe + lowering_recipe=lowering_recipe, + ) + + def _get_coreml_lowering_recipe( + self, + compute_precision: ct.precision, + **kwargs: Any, + ) -> LoweringRecipe: + compile_specs = 
CoreMLBackend.generate_compile_specs( + compute_precision=compute_precision, + **kwargs, + ) + + minimum_deployment_target = kwargs.get("minimum_deployment_target", None) + take_over_mutable_buffer = True + if minimum_deployment_target and minimum_deployment_target < ct.target.iOS18: + take_over_mutable_buffer = False + + partitioner = CoreMLPartitioner( + compile_specs=compile_specs, + take_over_mutable_buffer=take_over_mutable_buffer, + ) + + edge_compile_config = EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=False, + ) + + return LoweringRecipe( + partitioners=[partitioner], edge_compile_config=edge_compile_config + ) diff --git a/backends/apple/coreml/recipes/coreml_recipe_types.py b/backends/apple/coreml/recipes/coreml_recipe_types.py new file mode 100644 index 00000000000..77f808bd982 --- /dev/null +++ b/backends/apple/coreml/recipes/coreml_recipe_types.py @@ -0,0 +1,25 @@ +# Copyright © 2025 Apple Inc. All rights reserved. +# +# Please refer to the license found in the LICENSE file in the root directory of the source tree. + + +from executorch.export import RecipeType + + +COREML_BACKEND: str = "coreml" + + +class CoreMLRecipeType(RecipeType): + """CoreML-specific generic recipe types""" + + # FP32 generic recipe, defaults to values published by the CoreML backend and partitioner + # Precision = FP32, Default compute_unit = All (can be overridden by kwargs) + FP32 = "coreml_fp32" + + # FP16 generic recipe, defaults to values published by the CoreML backend and partitioner + # Precision = FP16, Default compute_unit = All (can be overridden by kwargs) + FP16 = "coreml_fp16" + + @classmethod + def get_backend_name(cls) -> str: + return COREML_BACKEND diff --git a/backends/apple/coreml/test/test_coreml_recipes.py b/backends/apple/coreml/test/test_coreml_recipes.py new file mode 100644 index 00000000000..ca5c6c30c9c --- /dev/null +++ b/backends/apple/coreml/test/test_coreml_recipes.py @@ -0,0 +1,238 @@ +# Copyright © 2025 Apple Inc. 
All rights reserved. +# +# Please refer to the license found in the LICENSE file in the root directory of the source tree. + + +import unittest +from typing import List + +import coremltools as ct + +import torch +from executorch.backends.apple.coreml.recipes import ( + CoreMLRecipeProvider, + CoreMLRecipeType, +) + +from executorch.backends.apple.coreml.test.test_coreml_utils import ( + IS_VALID_TEST_RUNTIME, +) +from executorch.exir.schema import DelegateCall, Program +from executorch.export import export, ExportRecipe, recipe_registry +from torch import nn +from torch.testing._internal.common_quantization import TestHelperModules + + +class TestCoreMLRecipes(unittest.TestCase): + fp32_recipes: List[CoreMLRecipeType] = [ + CoreMLRecipeType.FP32, + ] + fp16_recipes: List[CoreMLRecipeType] = [ + CoreMLRecipeType.FP16, + ] + + def setUp(self): + torch._dynamo.reset() + super().setUp() + self.provider = CoreMLRecipeProvider() + # Register the provider for recipe registry tests + recipe_registry.register_backend_recipe_provider(CoreMLRecipeProvider()) + + def tearDown(self): + super().tearDown() + + def check_fully_delegated(self, program: Program) -> None: + instructions = program.execution_plan[0].chains[0].instructions + assert instructions is not None + self.assertEqual(len(instructions), 1) + self.assertIsInstance(instructions[0].instr_args, DelegateCall) + + def test_all_fp32_recipes_with_simple_model(self): + """Test all FP32 recipes with a simple linear model""" + for recipe_type in self.fp32_recipes: + with self.subTest(recipe=recipe_type.value): + m_eager = TestHelperModules.TwoLinearModule().eval() + example_inputs = [(torch.randn(9, 8),)] + + session = export( + model=m_eager, + example_inputs=example_inputs, + export_recipe=ExportRecipe.get_recipe(recipe_type), + ) + self.check_fully_delegated(session.get_executorch_program()) + + # Verify outputs match + if IS_VALID_TEST_RUNTIME: + self.assertTrue( + torch.allclose( + session.run_method("forward", 
example_inputs[0])[0], + m_eager(*example_inputs[0]), + atol=1e-3, + ) + ) + + def test_all_fp16_recipes_with_simple_model(self): + """Test all FP16 recipes with a simple linear model""" + + for recipe_type in self.fp16_recipes: + with self.subTest(recipe=recipe_type.value): + m_eager = TestHelperModules.TwoLinearModule().eval() + example_inputs = [(torch.randn(9, 8),)] + + session = export( + model=m_eager, + example_inputs=example_inputs, + export_recipe=ExportRecipe.get_recipe(recipe_type), + ) + + self.check_fully_delegated(session.get_executorch_program()) + + # Verify outputs match (NOTE: atol is the same 1e-3 used by the FP32 test) + if IS_VALID_TEST_RUNTIME: + self.assertTrue( + torch.allclose( + session.run_method("forward", example_inputs[0])[0], + m_eager(*example_inputs[0]), + atol=1e-3, + ) + ) + + def test_custom_simple_model(self): + """Test with a custom simple model""" + + class CustomTestModel(nn.Module): + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(10, 20) + self.relu = nn.ReLU() + self.linear2 = nn.Linear(20, 1) + + def forward(self, x): + x = self.linear1(x) + x = self.relu(x) + x = self.linear2(x) + return x + + model = CustomTestModel().eval() + example_inputs = [(torch.randn(1, 10),)] + for recipe_type in self.fp32_recipes + self.fp16_recipes: + with self.subTest(recipe=recipe_type.value): + session = export( + model=model, + example_inputs=example_inputs, + export_recipe=ExportRecipe.get_recipe(recipe_type), + ) + session.print_delegation_info() + self.check_fully_delegated(session.get_executorch_program()) + + if IS_VALID_TEST_RUNTIME: + self.assertTrue( + torch.allclose( + session.run_method("forward", example_inputs[0])[0], + model(*example_inputs[0]), + atol=1e-3, + ) + ) + + def test_unsupported_recipe_type(self): + """Test that unsupported recipe types return None""" + from executorch.export import RecipeType + + class UnsupportedRecipeType(RecipeType): + UNSUPPORTED = "unsupported" + + @classmethod + def 
get_backend_name(cls) -> str: + return "dummy" + + recipe = self.provider.create_recipe(UnsupportedRecipeType.UNSUPPORTED) + self.assertIsNone(recipe) + + def test_recipe_registry_integration(self): + """Test that recipes work with the global recipe registry""" + for recipe_type in self.fp32_recipes + self.fp16_recipes: + with self.subTest(recipe=recipe_type.value): + recipe = ExportRecipe.get_recipe(recipe_type) + self.assertIsNotNone(recipe) + self.assertEqual(recipe.name, recipe_type.value) + + def test_invalid_recipe_kwargs(self): + """Test detailed error messages for invalid kwargs""" + provider = CoreMLRecipeProvider() + + # Test single invalid parameter + with self.assertRaises(ValueError) as cm: + provider.create_recipe(CoreMLRecipeType.FP16, invalid_param=123) + + error_msg = str(cm.exception) + self.assertIn("Unexpected parameters", error_msg) + + # Test multiple invalid parameters + with self.assertRaises(ValueError) as cm: + provider.create_recipe( + CoreMLRecipeType.FP32, param1="value1", param2="value2" + ) + + error_msg = str(cm.exception) + self.assertIn("Unexpected parameters", error_msg) + + # Test mix of valid and invalid parameters + with self.assertRaises(ValueError) as cm: + provider.create_recipe( + CoreMLRecipeType.FP32, + minimum_deployment_target=ct.target.iOS16, # valid + invalid_param="invalid", # invalid + ) + + error_msg = str(cm.exception) + self.assertIn("Unexpected parameters", error_msg) + + def test_valid_kwargs(self): + """Test valid kwargs""" + recipe = self.provider.create_recipe( + CoreMLRecipeType.FP32, + minimum_deployment_target=ct.target.iOS16, + compute_unit=ct.ComputeUnit.CPU_AND_GPU, + ) + self.assertIsNotNone(recipe) + self.assertEqual(recipe.name, "coreml_fp32") + + # Verify partitioners are properly configured + partitioners = recipe.lowering_recipe.partitioners + self.assertEqual(len(partitioners), 1, "Expected exactly one partitioner") + + # Verify delegation spec and compile specs + delegation_spec = 
partitioners[0].delegation_spec + self.assertIsNotNone(delegation_spec, "Delegation spec should not be None") + + compile_specs = delegation_spec.compile_specs + self.assertIsNotNone(compile_specs, "Compile specs should not be None") + + spec_dict = {spec.key: spec.value for spec in compile_specs} + + # Assert that all expected specs are present with correct values + self.assertIn( + "min_deployment_target", + spec_dict, + "minimum_deployment_target should be in compile specs", + ) + min_target_value = spec_dict["min_deployment_target"] + if isinstance(min_target_value, bytes): + min_target_value = min_target_value.decode("utf-8") + self.assertEqual( + str(min_target_value), + str(ct.target.iOS16.value), + "minimum_deployment_target should match the provided value", + ) + + self.assertIn( + "compute_units", spec_dict, "compute_unit should be in compile specs" + ) + compute_unit_value = spec_dict["compute_units"] + if isinstance(compute_unit_value, bytes): + compute_unit_value = compute_unit_value.decode("utf-8") + self.assertEqual( + str(compute_unit_value), + ct.ComputeUnit.CPU_AND_GPU.name.lower(), + "compute_unit should match the provided value", + ) diff --git a/backends/apple/coreml/test/test_coreml_utils.py b/backends/apple/coreml/test/test_coreml_utils.py new file mode 100644 index 00000000000..7d9ac7ba5a5 --- /dev/null +++ b/backends/apple/coreml/test/test_coreml_utils.py @@ -0,0 +1,19 @@ +# Copyright © 2025 Apple Inc. All rights reserved. +# +# Please refer to the license found in the LICENSE file in the root directory of the source tree. 
+ +import platform +import sys + +import torch + + +def is_fbcode(): + return not hasattr(torch.version, "git_version") + + +IS_VALID_TEST_RUNTIME: bool = ( + (sys.platform == "darwin") + and not is_fbcode() + and tuple(map(int, platform.mac_ver()[0].split("."))) >= (15, 0) +) diff --git a/backends/apple/coreml/test/test_torch_ops.py b/backends/apple/coreml/test/test_torch_ops.py index 67bc8be197d..0d6b581ee72 100644 --- a/backends/apple/coreml/test/test_torch_ops.py +++ b/backends/apple/coreml/test/test_torch_ops.py @@ -2,8 +2,6 @@ # # Please refer to the license found in the LICENSE file in the root directory of the source tree. -import platform -import sys import unittest import coremltools as ct @@ -14,22 +12,15 @@ from executorch.backends.apple.coreml.compiler import CoreMLBackend from executorch.backends.apple.coreml.partition import CoreMLPartitioner +from executorch.backends.apple.coreml.test.test_coreml_utils import ( + IS_VALID_TEST_RUNTIME, +) from executorch.exir.backend.utils import format_delegated_graph from torchao.prototype.quantization.codebook_coreml import CodebookWeightOnlyConfig from torchao.quantization import IntxWeightOnlyConfig, PerAxis, PerGroup, quantize_ - -def is_fbcode(): - return not hasattr(torch.version, "git_version") - - -_TEST_RUNTIME = ( - (sys.platform == "darwin") - and not is_fbcode() - and tuple(map(int, platform.mac_ver()[0].split("."))) >= (15, 0) -) -if _TEST_RUNTIME: +if IS_VALID_TEST_RUNTIME: from executorch.runtime import Runtime @@ -50,7 +41,7 @@ def _get_test_model(self): return model, example_inputs def _compare_outputs(self, executorch_program, eager_program, example_inputs): - if not _TEST_RUNTIME: + if not IS_VALID_TEST_RUNTIME: return runtime = Runtime.get() program = runtime.load_program(executorch_program.buffer) diff --git a/export/export.py b/export/export.py index e5c3b793ccd..597ec28665b 100644 --- a/export/export.py +++ b/export/export.py @@ -446,10 +446,24 @@ def print_delegation_info(self) -> None: """ 
Print delegation information for the exported program. """ - delegation_info = self._run_context.get("delegation_info", None) + lowering_stage = list( + set(self._pipeline_stages) + & {StageType.TO_EDGE_TRANSFORM_AND_LOWER, StageType.TO_BACKEND} + ) + if not lowering_stage: + raise RuntimeError( + "No delegation info available, atleast one of the lowering stages should be present" + ) + + stage_artifact = self._stage_to_artifacts.get(lowering_stage[0]) + if stage_artifact is None: + raise RuntimeError("No delegation info available, run the lowering stage first") + + # pyre-ignore + delegation_info = stage_artifact.get_context("delegation_info", None) if delegation_info: - logging.info(delegation_info.get_summary()) + print(delegation_info.get_summary()) df = delegation_info.get_operator_delegation_dataframe() - logging.info(tabulate(df, headers="keys", tablefmt="fancy_grid")) + print(tabulate(df, headers="keys", tablefmt="fancy_grid")) else: - logging.info("No delegation info available") + print("No delegation info available")