diff --git a/backends/apple/coreml/recipes/coreml_recipe_provider.py b/backends/apple/coreml/recipes/coreml_recipe_provider.py index 77e15aeced3..ed29e57f26b 100644 --- a/backends/apple/coreml/recipes/coreml_recipe_provider.py +++ b/backends/apple/coreml/recipes/coreml_recipe_provider.py @@ -69,6 +69,7 @@ def create_recipe( recipe_type, activation_dtype=torch.float32, **kwargs ) elif recipe_type == CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_CHANNEL: + self._validate_and_set_deployment_target(kwargs, ct.target.iOS18, "torchao") return self._build_torchao_quantized_recipe( recipe_type, weight_dtype=torch.int4, @@ -77,6 +78,7 @@ def create_recipe( ) elif recipe_type == CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP: group_size = kwargs.pop("group_size", 32) + self._validate_and_set_deployment_target(kwargs, ct.target.iOS18, "torchao") return self._build_torchao_quantized_recipe( recipe_type, weight_dtype=torch.int4, @@ -85,11 +87,14 @@ def create_recipe( **kwargs, ) elif recipe_type == CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_CHANNEL: + self._validate_and_set_deployment_target(kwargs, ct.target.iOS16, "torchao") return self._build_torchao_quantized_recipe( recipe_type, weight_dtype=torch.int8, is_per_channel=True, **kwargs ) elif recipe_type == CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP: group_size = kwargs.pop("group_size", 32) + # override minimum_deployment_target to ios18 for torchao (GH issue #13122) + self._validate_and_set_deployment_target(kwargs, ct.target.iOS18, "torchao") return self._build_torchao_quantized_recipe( recipe_type, weight_dtype=torch.int8, @@ -312,8 +317,6 @@ def _build_torchao_quantized_recipe( ao_quantization_configs=[config], ) - # override minimum_deployment_target to ios18 for torchao (GH issue #13122) - self._validate_and_set_deployment_target(kwargs, ct.target.iOS18, "torchao") lowering_recipe = self._get_coreml_lowering_recipe(**kwargs) return ExportRecipe( diff --git a/backends/apple/coreml/test/test_coreml_recipes.py b/backends/apple/coreml/test/test_coreml_recipes.py index 78d5a30063c..f326a8879a4 100644 --- a/backends/apple/coreml/test/test_coreml_recipes.py +++ b/backends/apple/coreml/test/test_coreml_recipes.py @@ -501,7 +501,7 @@ def test_minimum_deployment_target_validation(self): (CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP, ct.target.iOS18, {}), ( CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_CHANNEL, - ct.target.iOS18, + ct.target.iOS16, {}, ), (CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP, ct.target.iOS18, {}),