3
3
# Please refer to the license found in the LICENSE file in the root directory of the source tree.
4
4
5
5
6
+ import logging
6
7
from typing import Any , Optional , Sequence
7
8
8
9
import coremltools as ct
@@ -111,8 +112,9 @@ def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> Non
111
112
112
113
unexpected = set (kwargs .keys ()) - expected_keys
113
114
if unexpected :
114
- raise ValueError (
115
- f"Recipe '{ recipe_type .value } ' received unexpected parameters: { list (unexpected )} "
115
+ logging .warning (
116
+ f"CoreML recipe '{ recipe_type .value } ' ignoring unexpected parameters: { list (unexpected )} . "
117
+ f"Expected parameters: { list (expected_keys )} "
116
118
)
117
119
118
120
self ._validate_base_parameters (kwargs )
@@ -121,7 +123,13 @@ def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> Non
121
123
122
124
def _get_expected_keys (self , recipe_type : RecipeType ) -> set :
123
125
"""Get expected parameter keys for a recipe type"""
124
- common_keys = {"minimum_deployment_target" , "compute_unit" }
126
+ common_keys = {
127
+ "minimum_deployment_target" ,
128
+ "compute_unit" ,
129
+ "skip_ops_for_coreml_delegation" ,
130
+ "lower_full_graph" ,
131
+ "take_over_constant_data" ,
132
+ }
125
133
126
134
if recipe_type in [
127
135
CoreMLRecipeType .TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP ,
@@ -377,9 +385,19 @@ def _get_coreml_lowering_recipe(
377
385
if minimum_deployment_target and minimum_deployment_target < ct .target .iOS18 :
378
386
take_over_mutable_buffer = False
379
387
388
+ # Extract additional partitioner parameters
389
+ skip_ops_for_coreml_delegation = kwargs .get (
390
+ "skip_ops_for_coreml_delegation" , None
391
+ )
392
+ lower_full_graph = kwargs .get ("lower_full_graph" , False )
393
+ take_over_constant_data = kwargs .get ("take_over_constant_data" , True )
394
+
380
395
partitioner = CoreMLPartitioner (
381
396
compile_specs = compile_specs ,
382
397
take_over_mutable_buffer = take_over_mutable_buffer ,
398
+ skip_ops_for_coreml_delegation = skip_ops_for_coreml_delegation ,
399
+ lower_full_graph = lower_full_graph ,
400
+ take_over_constant_data = take_over_constant_data ,
383
401
)
384
402
385
403
edge_compile_config = EdgeCompileConfig (
0 commit comments