@@ -56,8 +56,7 @@ def get_mps_partitioner(use_kv_cache: bool = False):
 
 
 def get_coreml_partitioner(
-    enable_state: bool = False,
-    preserve_sdpa: bool = True,
+    ios: int = 15,
     embedding_quantize: Optional[str] = None,
     pt2e_quantize: Optional[str] = None,
     coreml_quantize: Optional[str] = None,
@@ -75,29 +74,42 @@ def get_coreml_partitioner(
7574 "Please install the CoreML backend follwing https://pytorch.org/executorch/main/build-run-coreml.html"
         )
 
-    minimum_deployment_target = ct.target.iOS15
-    # In Core ML, stateful execution is introduced in iOS 18
-    if enable_state:
-        minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
-    # In Core ML, sdpa op is introduced in iOS 18
-    if preserve_sdpa:
-        minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
-    # In Core ML, quantization is introduced in iOS 16
-    if embedding_quantize is not None or pt2e_quantize is not None:
-        minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS16)
-    # In Core ML, 8-bit activation quantization is introduced in iOS 17
-    if (
-        embedding_quantize is not None and int(embedding_quantize.split(",")[0]) == 8
-    ) or pt2e_quantize in ("coreml_8a_c8w", "coreml_baseline_8a_c8w"):
-        minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS17)
-    # In Core ML, 4-bit weight compression is introduced in iOS 18
-    if (
-        (embedding_quantize is not None and int(embedding_quantize.split(",")[0]) == 4)
-        or pt2e_quantize in ("coreml_c4w", "coreml_8a_c4w", "coreml_baseline_8a_c4w")
-        or coreml_quantize == "b4w"
-    ):
-        minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
+    def _validate_ios_version() -> None:
+        assert ios in (15, 16, 17, 18)
 
+        if embedding_quantize is not None and ios < 18:
+            raise ValueError(
+                "In Core ML, per-block quantization is introduced in iOS 18"
+            )
+
+        use_quantization = pt2e_quantize is not None or coreml_quantize is not None
+        if use_quantization and ios < 16:
+            raise ValueError("In Core ML, quantization is introduced in iOS 16")
+
+        use_8a = (pt2e_quantize is not None and "8a" in pt2e_quantize) or (
+            coreml_quantize is not None and "8a" in coreml_quantize
+        )
+        if use_8a and ios < 17:
+            raise ValueError(
+                "In Core ML, 8-bit activation quantization is introduced in iOS 17"
+            )
+
+        use_4w = (pt2e_quantize is not None and "4w" in pt2e_quantize) or (
+            coreml_quantize is not None and "4w" in coreml_quantize
+        )
+        if use_4w and ios < 18:
+            raise ValueError(
+                "In Core ML, 4-bit weight compression is introduced in iOS 18"
+            )
+
+    _validate_ios_version()
+
+    minimum_deployment_target = {
+        15: ct.target.iOS15,
+        16: ct.target.iOS16,
+        17: ct.target.iOS17,
+        18: ct.target.iOS18,
+    }[ios]
     op_linear_quantizer_config = None
     if coreml_quantize == "b4w":
         op_linear_quantizer_config = {
@@ -107,7 +119,6 @@ def get_coreml_partitioner(
107119 "block_size" : 32 ,
108120 "weight_threshold" : 512 ,
109121 }
110-
111122 compile_specs = CoreMLBackend .generate_compile_specs ( # pyre-fixme[16]
112123 minimum_deployment_target = minimum_deployment_target ,
113124 compute_precision = ct .precision (ct .precision .FLOAT16 .value ),
@@ -116,9 +127,12 @@ def get_coreml_partitioner(
         model_type=CoreMLBackend.MODEL_TYPE.MODEL,  # pyre-fixme[16]
         op_linear_quantizer_config=op_linear_quantizer_config,
     )
+
+    take_over_mutable_buffer = minimum_deployment_target >= ct.target.iOS18
+
     return CoreMLPartitioner(  # pyre-fixme[16]
         compile_specs=compile_specs,
-        take_over_mutable_buffer=enable_state,
+        take_over_mutable_buffer=take_over_mutable_buffer,
     )
 
 
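For reference, a minimal usage sketch of the updated API after this change. The import path is assumed from the file's location (extension/llm/export/partitioner_lib.py), coremltools and the CoreML backend must be installed, and the argument values below are illustrative, not prescriptive:

# Hypothetical usage of get_coreml_partitioner after this change.
# Assumption: the module is importable at this path.
from executorch.extension.llm.export.partitioner_lib import get_coreml_partitioner

# iOS 18 permits 4-bit weight compression ("b4w"); because the deployment
# target is >= iOS18, the partitioner also takes over mutable buffers.
partitioner = get_coreml_partitioner(ios=18, coreml_quantize="b4w")

# Unsupported combinations now fail fast in _validate_ios_version(): 8-bit
# activation quantization requires iOS 17, so this raises a ValueError.
try:
    get_coreml_partitioner(ios=16, pt2e_quantize="coreml_8a_c8w")
except ValueError as err:
    print(err)  # In Core ML, 8-bit activation quantization is introduced in iOS 17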