
Commit b02db12

[Executorch][target recipes] Add target based recipes for lowering models to a target device (#13983)
Co-authored-by: Abhinay Kukkadapu <[email protected]>
1 parent fcdcd8e commit b02db12

17 files changed: +540 −70 lines changed

backends/apple/coreml/TARGETS

Lines changed: 24 additions & 5 deletions
@@ -61,16 +61,21 @@ runtime.python_library(
 )
 
 runtime.python_library(
-    name = "recipes",
-    srcs = glob([
-        "recipes/*.py",
-    ]),
+    name = "coreml_recipes",
+    srcs = [
+        "recipes/__init__.py",
+        "recipes/coreml_recipe_provider.py"
+    ],
     visibility = [
         "@EXECUTORCH_CLIENTS",
+        "//executorch/export/...",
     ],
     deps = [
         "fbsource//third-party/pypi/coremltools:coremltools",
+        ":coreml_recipe_types",
         ":backend",
+        ":partitioner",
+        ":quantizer",
         "//caffe2:torch",
         "//executorch/exir:lib",
         "//executorch/exir/backend:compile_spec_schema",
@@ -80,6 +85,20 @@ runtime.python_library(
     ],
 )
 
+runtime.python_library(
+    name = "coreml_recipe_types",
+    srcs = [
+        "recipes/coreml_recipe_types.py",
+    ],
+    visibility = [
+        "@EXECUTORCH_CLIENTS",
+        "//executorch/export/...",
+    ],
+    deps = [
+        "//executorch/export:recipe",
+    ],
+)
+
 runtime.cxx_python_extension(
     name = "executorchcoreml",
     srcs = [
@@ -124,7 +143,7 @@ runtime.python_test(
         "fbsource//third-party/pypi/pytest:pytest",
         ":partitioner",
         ":quantizer",
-        ":recipes",
+        ":coreml_recipes",
         "//caffe2:torch",
         "//pytorch/vision:torchvision",
         "fbsource//third-party/pypi/scikit-learn:scikit-learn",

backends/apple/coreml/recipes/coreml_recipe_provider.py

Lines changed: 21 additions & 3 deletions
@@ -3,6 +3,7 @@
 # Please refer to the license found in the LICENSE file in the root directory of the source tree.
 
 
+import logging
 from typing import Any, Optional, Sequence
 
 import coremltools as ct
@@ -111,8 +112,9 @@ def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None:
 
         unexpected = set(kwargs.keys()) - expected_keys
         if unexpected:
-            raise ValueError(
-                f"Recipe '{recipe_type.value}' received unexpected parameters: {list(unexpected)}"
+            logging.warning(
+                f"CoreML recipe '{recipe_type.value}' ignoring unexpected parameters: {list(unexpected)}. "
+                f"Expected parameters: {list(expected_keys)}"
             )
 
         self._validate_base_parameters(kwargs)
@@ -121,7 +123,13 @@ def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None:
 
     def _get_expected_keys(self, recipe_type: RecipeType) -> set:
         """Get expected parameter keys for a recipe type"""
-        common_keys = {"minimum_deployment_target", "compute_unit"}
+        common_keys = {
+            "minimum_deployment_target",
+            "compute_unit",
+            "skip_ops_for_coreml_delegation",
+            "lower_full_graph",
+            "take_over_constant_data",
+        }
 
         if recipe_type in [
             CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP,
@@ -377,9 +385,19 @@ def _get_coreml_lowering_recipe(
         if minimum_deployment_target and minimum_deployment_target < ct.target.iOS18:
             take_over_mutable_buffer = False
 
+        # Extract additional partitioner parameters
+        skip_ops_for_coreml_delegation = kwargs.get(
+            "skip_ops_for_coreml_delegation", None
+        )
+        lower_full_graph = kwargs.get("lower_full_graph", False)
+        take_over_constant_data = kwargs.get("take_over_constant_data", True)
+
         partitioner = CoreMLPartitioner(
             compile_specs=compile_specs,
             take_over_mutable_buffer=take_over_mutable_buffer,
+            skip_ops_for_coreml_delegation=skip_ops_for_coreml_delegation,
+            lower_full_graph=lower_full_graph,
+            take_over_constant_data=take_over_constant_data,
         )
 
         edge_compile_config = EdgeCompileConfig(
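
The net effect of these hunks: three additional partitioner options (skip_ops_for_coreml_delegation, lower_full_graph, take_over_constant_data) now flow from recipe kwargs into CoreMLPartitioner, and unrecognized kwargs are logged instead of rejected. A minimal usage sketch follows, assuming the provider class exported from this module is CoreMLRecipeProvider, that the import paths mirror the file layout in this commit, and with a purely illustrative op name for skip_ops_for_coreml_delegation:

    # Hedged sketch (not part of this commit).
    from executorch.backends.apple.coreml.recipes.coreml_recipe_provider import (
        CoreMLRecipeProvider,
    )
    from executorch.backends.apple.coreml.recipes.coreml_recipe_types import (
        CoreMLRecipeType,
    )

    provider = CoreMLRecipeProvider()
    recipe = provider.create_recipe(
        CoreMLRecipeType.PT2E_INT8_STATIC,  # recipe type referenced in the tests below
        lower_full_graph=False,             # defaults mirror the kwargs.get(...) calls above
        take_over_constant_data=True,
        skip_ops_for_coreml_delegation=["aten.mm.default"],  # illustrative op list
    )
    # Any kwarg outside _get_expected_keys() is now reported via logging.warning
    # rather than raising ValueError.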

backends/apple/coreml/test/test_coreml_recipes.py

Lines changed: 0 additions & 25 deletions
@@ -185,14 +185,6 @@ def test_int4_weight_only_per_group_validation(self):
             )
         self.assertIn("must be positive", str(cm.exception))
 
-        # Test unexpected parameter
-        with self.assertRaises(ValueError) as cm:
-            self.provider.create_recipe(
-                CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_CHANNEL,
-                group_size=32,  # group_size not valid for per-channel
-            )
-        self.assertIn("unexpected parameters", str(cm.exception))
-
     def test_int8_weight_only_per_channel(self):
         """Test INT8 weight-only per-channel quantization"""
         model = TestHelperModules.TwoLinearModule().eval()
@@ -385,23 +377,6 @@ def forward(self, x):
         self._compare_eager_quantized_model_outputs(session, example_inputs, atol=1e-2)
         self._compare_eager_unquantized_model_outputs(session, model, example_inputs)
 
-    def test_pt2e_recipes_parameter_rejection(self):
-        """Test that PT2E recipes reject TorchAO-specific parameters"""
-        # PT2E recipes should reject TorchAO-specific parameters
-        pt2e_recipes = [
-            CoreMLRecipeType.PT2E_INT8_STATIC,
-            CoreMLRecipeType.PT2E_INT8_WEIGHT_ONLY,
-        ]
-        torchao_params = ["filter_fn", "group_size", "bits", "block_size"]
-
-        for recipe_type in pt2e_recipes:
-            for param in torchao_params:
-                with self.subTest(recipe=recipe_type.value, param=param):
-                    kwargs = {param: "dummy_value"}
-                    with self.assertRaises(ValueError) as cm:
-                        self.provider.create_recipe(recipe_type, **kwargs)
-                    self.assertIn("unexpected parameters", str(cm.exception).lower())
-
     def test_filter_fn_comprehensive(self):
         """Comprehensive test for filter_fn parameter functionality"""
backends/xnnpack/TARGETS

Lines changed: 0 additions & 3 deletions
@@ -36,10 +36,7 @@ runtime.python_library(
     ],
     deps = [
         ":xnnpack_preprocess",
-        "//executorch/export:lib",
         "//executorch/backends/xnnpack/partition:xnnpack_partitioner",
         "//executorch/backends/xnnpack/utils:xnnpack_utils",
-        "//executorch/backends/xnnpack/recipes:xnnpack_recipe_provider",
-        "//executorch/backends/xnnpack/recipes:xnnpack_recipe_types",
     ],
 )

backends/xnnpack/__init__.py

Lines changed: 0 additions & 8 deletions
@@ -4,18 +4,11 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from executorch.export import recipe_registry
-
 # Exposed Partitioners in XNNPACK Package
 from .partition.xnnpack_partitioner import (
     XnnpackDynamicallyQuantizedPartitioner,
     XnnpackPartitioner,
 )
-from .recipes.xnnpack_recipe_provider import XNNPACKRecipeProvider
-from .recipes.xnnpack_recipe_types import XNNPackRecipeType
-
-# Auto-register XNNPACK recipe provider
-recipe_registry.register_backend_recipe_provider(XNNPACKRecipeProvider())
 
 # Exposed Configs in XNNPACK Package
 from .utils.configs import (
@@ -34,7 +27,6 @@
     "XnnpackDynamicallyQuantizedPartitioner",
     "XnnpackPartitioner",
     "XnnpackBackend",
-    "XNNPackRecipeType",
     "capture_graph_for_xnnpack",
     "get_xnnpack_capture_config",
     "get_xnnpack_edge_compile_config",

backends/xnnpack/recipes/TARGETS

Lines changed: 17 additions & 1 deletion
@@ -2,6 +2,22 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 
 oncall("executorch")
 
+runtime.python_library(
+    name = "xnnpack_recipes",
+    srcs = [
+        "__init__.py",
+    ],
+    visibility = [
+        "//executorch/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
+    deps = [
+        "//executorch/export:recipe_registry",
+        ":xnnpack_recipe_provider",
+        ":xnnpack_recipe_types",
+    ],
+)
+
 runtime.python_library(
     name = "xnnpack_recipe_provider",
     srcs = [
@@ -30,6 +46,6 @@ runtime.python_library(
         "@EXECUTORCH_CLIENTS",
     ],
     deps = [
-        "//executorch/export:lib",
+        "//executorch/export:recipe",
     ],
 )

backends/xnnpack/recipes/__init__.py

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from executorch.export import recipe_registry
+
+from .xnnpack_recipe_provider import XNNPACKRecipeProvider
+from .xnnpack_recipe_types import XNNPackRecipeType
+
+# Auto-register XNNPACK recipe provider
+recipe_registry.register_backend_recipe_provider(XNNPACKRecipeProvider())
+
+
+__all__ = [
+    "XNNPACKRecipeProvider",
+    "XNNPackRecipeType",
+]
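
With this new package __init__, provider registration now happens as an import side effect of the recipes package (it previously ran from backends/xnnpack/__init__.py, removed above). A hedged usage sketch, assuming the import path mirrors the repository layout and that the provider exposes the create_recipe entry point seen in the CoreML tests:

    # Hedged sketch: importing the recipes package triggers
    # register_backend_recipe_provider() above; nothing else is required.
    import executorch.backends.xnnpack.recipes as xnnpack_recipes

    # The names re-exported via __all__ are then available directly.
    provider = xnnpack_recipes.XNNPACKRecipeProvider()
    recipe = provider.create_recipe(
        xnnpack_recipes.XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
        group_size=32,  # the one kwarg this recipe family accepts (see provider diff below)
    )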

backends/xnnpack/recipes/xnnpack_recipe_provider.py

Lines changed: 7 additions & 6 deletions
@@ -6,6 +6,7 @@
 
 # pyre-strict
 
+import logging
 from typing import Any, Optional, Sequence
 
 import torch
@@ -180,9 +181,9 @@ def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None:
             expected_keys = {"group_size"}
             unexpected = set(kwargs.keys()) - expected_keys
             if unexpected:
-                raise ValueError(
-                    f"Recipe '{recipe_type.value}' only accepts 'group_size' parameter. "
-                    f"Unexpected parameters: {list(unexpected)}"
+                logging.warning(
+                    f"XNNPACK recipe '{recipe_type.value}' ignoring unexpected parameters: {list(unexpected)}. "
+                    f"Only 'group_size' is supported for this recipe."
                 )
             if "group_size" in kwargs:
                 group_size = kwargs["group_size"]
@@ -193,7 +194,7 @@ def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None:
         elif kwargs:
             # All other recipes don't expect any kwargs
             unexpected = list(kwargs.keys())
-            raise ValueError(
-                f"Recipe '{recipe_type.value}' does not accept any parameters. "
-                f"Unexpected parameters: {unexpected}"
+            logging.warning(
+                f"XNNPACK recipe '{recipe_type.value}' ignoring unexpected parameters: {unexpected}. "
+                f"This recipe does not accept any parameters."
             )
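
As with the CoreML provider, kwargs a recipe does not understand no longer abort recipe creation; they are logged and ignored. A hedged sketch of the change in behavior, assuming the import path and create_recipe surface shown elsewhere in this commit:

    # Hedged sketch: before this commit the call below raised ValueError; after it,
    # a warning such as
    #   "XNNPACK recipe 'xnnpack_fp32' ignoring unexpected parameters: ['group_size']. ..."
    # is logged and the recipe is still created from the supported parameters.
    import logging

    from executorch.backends.xnnpack.recipes import (  # assumed import path
        XNNPACKRecipeProvider,
        XNNPackRecipeType,
    )

    logging.basicConfig(level=logging.WARNING)
    provider = XNNPACKRecipeProvider()
    recipe = provider.create_recipe(XNNPackRecipeType.FP32, group_size=32)  # group_size ignored for FP32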

backends/xnnpack/recipes/xnnpack_recipe_types.py

Lines changed: 8 additions & 6 deletions
@@ -12,23 +12,25 @@
 class XNNPackRecipeType(RecipeType):
     """XNNPACK-specific recipe types"""
 
-    FP32 = "fp32"
+    FP32 = "xnnpack_fp32"
 
     ## PT2E-based quantization recipes
     # INT8 Dynamic Quantization
-    PT2E_INT8_DYNAMIC_PER_CHANNEL = "pt2e_int8_dynamic_per_channel"
+    PT2E_INT8_DYNAMIC_PER_CHANNEL = "xnnpack_pt2e_int8_dynamic_per_channel"
     # INT8 Static Quantization, needs calibration dataset
-    PT2E_INT8_STATIC_PER_CHANNEL = "pt2e_int8_static_per_channel"
-    PT2E_INT8_STATIC_PER_TENSOR = "pt2e_int8_static_per_tensor"
+    PT2E_INT8_STATIC_PER_CHANNEL = "xnnpack_pt2e_int8_static_per_channel"
+    PT2E_INT8_STATIC_PER_TENSOR = "xnnpack_pt2e_int8_static_per_tensor"
 
     ## TorchAO-based quantization recipes
     # INT8 Dynamic Activations INT4 Weight Quantization, Axis = 0
     TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL = (
-        "torchao_int8da_int4w_per_channel"
+        "xnnpack_torchao_int8da_int4w_per_channel"
     )
     # INT8 Dynamic Activations INT4 Weight Quantization, default group_size = 32
     # can be overriden by group_size kwarg
-    TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR = "torchao_int8da_int4w_per_tensor"
+    TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR = (
+        "xnnpack_torchao_int8da_int4w_per_tensor"
+    )
 
     @classmethod
     def get_backend_name(cls) -> str:
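
All values now carry an xnnpack_ prefix, so a recipe string is unambiguous about which backend it belongs to. A small sketch of what the rename means for string-based lookups; the import path is assumed from the file location:

    # Hedged sketch: the enum values after this change are backend-prefixed strings.
    from executorch.backends.xnnpack.recipes.xnnpack_recipe_types import XNNPackRecipeType

    assert XNNPackRecipeType.FP32.value == "xnnpack_fp32"
    assert (
        XNNPackRecipeType.PT2E_INT8_STATIC_PER_CHANNEL.value
        == "xnnpack_pt2e_int8_static_per_channel"
    )
    # get_backend_name() is declared above; its body is outside this hunk, so its
    # return value is not asserted here.
    print(XNNPackRecipeType.get_backend_name())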

backends/xnnpack/test/TARGETS

Lines changed: 1 addition & 1 deletion
@@ -105,7 +105,7 @@ runtime.python_test(
         "HTTPS_PROXY": "http://fwdproxy:8080",
     },
     deps = [
-        "//executorch/backends/xnnpack:xnnpack_delegate",
+        "//executorch/backends/xnnpack/recipes:xnnpack_recipes",
        "//executorch/export:lib",
        "//pytorch/vision:torchvision",  # @manual
        "//executorch/backends/xnnpack/test/tester:tester",
