@@ -69,6 +69,7 @@ def get_coreml_partitioner(
69
69
embedding_quantize : Optional [str ] = None ,
70
70
pt2e_quantize : Optional [str ] = None ,
71
71
coreml_quantize : Optional [str ] = None ,
72
+ coreml_compute_units : Optional [str ] = None ,
72
73
):
73
74
try :
74
75
import coremltools as ct
@@ -119,6 +120,19 @@ def _validate_ios_version() -> None:
119
120
17 : ct .target .iOS17 ,
120
121
18 : ct .target .iOS18 ,
121
122
}[ios ]
123
+
124
+ if coreml_compute_units is None :
125
+ # using `ComputeUnit.ALL` can increase the model load time
126
+ # On iPhone 15 Pro, CPU decode model is over 8x faster than GPU for stories110M,
127
+ # so default to CPU_ONLY
128
+ coreml_compute_units = "cpu_only"
129
+ coreml_compute_units = {
130
+ "cpu_only" : ct .ComputeUnit .CPU_ONLY ,
131
+ "cpu_and_ne" : ct .ComputeUnit .CPU_AND_NE ,
132
+ "cpu_and_gpu" : ct .ComputeUnit .CPU_AND_GPU ,
133
+ "all" : ct .ComputeUnit .ALL ,
134
+ }[coreml_compute_units .lower ()]
135
+
122
136
op_linear_quantizer_config = None
123
137
if coreml_quantize == "b4w" :
124
138
op_linear_quantizer_config = {
@@ -128,17 +142,22 @@ def _validate_ios_version() -> None:
128
142
"block_size" : 32 ,
129
143
"weight_threshold" : 512 ,
130
144
}
145
+ elif coreml_quantize == "c4w" :
146
+ op_linear_quantizer_config = {
147
+ "mode" : "linear_symmetric" ,
148
+ "dtype" : "int4" ,
149
+ "granularity" : "per_channel" ,
150
+ }
151
+
131
152
compile_specs = CoreMLBackend .generate_compile_specs ( # pyre-fixme[16]
132
153
minimum_deployment_target = minimum_deployment_target ,
133
154
compute_precision = ct .precision (ct .precision .FLOAT16 .value ),
134
- # using `ComputeUnit.ALL` can increase the model load time, default to `ComputeUnit.CPU_AND_GPU`
135
- compute_unit = ct .ComputeUnit [ct .ComputeUnit .CPU_AND_GPU .name .upper ()],
155
+ compute_unit = coreml_compute_units ,
136
156
model_type = CoreMLBackend .MODEL_TYPE .MODEL , # pyre-fixme[16]
137
157
op_linear_quantizer_config = op_linear_quantizer_config ,
138
158
)
139
159
140
160
take_over_mutable_buffer = minimum_deployment_target >= ct .target .iOS18
141
-
142
161
return CoreMLPartitioner ( # pyre-fixme[16]
143
162
compile_specs = compile_specs ,
144
163
take_over_mutable_buffer = take_over_mutable_buffer ,
0 commit comments