@@ -69,6 +69,7 @@ def get_coreml_partitioner(
6969 embedding_quantize : Optional [str ] = None ,
7070 pt2e_quantize : Optional [str ] = None ,
7171 coreml_quantize : Optional [str ] = None ,
72+ coreml_compute_units : Optional [str ] = None ,
7273):
7374 try :
7475 import coremltools as ct
@@ -119,6 +120,19 @@ def _validate_ios_version() -> None:
119120 17 : ct .target .iOS17 ,
120121 18 : ct .target .iOS18 ,
121122 }[ios ]
123+
124+ if coreml_compute_units is None :
125+ # using `ComputeUnit.ALL` can increase the model load time
126+ # On iPhone 15 Pro, CPU decode model is over 8x faster than GPU for stories110M,
127+ # so default to CPU_ONLY
128+ coreml_compute_units = "cpu_only"
129+ coreml_compute_units = {
130+ "cpu_only" : ct .ComputeUnit .CPU_ONLY ,
131+ "cpu_and_ne" : ct .ComputeUnit .CPU_AND_NE ,
132+ "cpu_and_gpu" : ct .ComputeUnit .CPU_AND_GPU ,
133+ "all" : ct .ComputeUnit .ALL ,
134+ }[coreml_compute_units .lower ()]
135+
122136 op_linear_quantizer_config = None
123137 if coreml_quantize == "b4w" :
124138 op_linear_quantizer_config = {
@@ -128,17 +142,22 @@ def _validate_ios_version() -> None:
128142 "block_size" : 32 ,
129143 "weight_threshold" : 512 ,
130144 }
145+ elif coreml_quantize == "c4w" :
146+ op_linear_quantizer_config = {
147+ "mode" : "linear_symmetric" ,
148+ "dtype" : "int4" ,
149+ "granularity" : "per_channel" ,
150+ }
151+
131152 compile_specs = CoreMLBackend .generate_compile_specs ( # pyre-fixme[16]
132153 minimum_deployment_target = minimum_deployment_target ,
133154 compute_precision = ct .precision (ct .precision .FLOAT16 .value ),
134- # using `ComputeUnit.ALL` can increase the model load time, default to `ComputeUnit.CPU_AND_GPU`
135- compute_unit = ct .ComputeUnit [ct .ComputeUnit .CPU_AND_GPU .name .upper ()],
155+ compute_units = coreml_compute_units ,
136156 model_type = CoreMLBackend .MODEL_TYPE .MODEL , # pyre-fixme[16]
137157 op_linear_quantizer_config = op_linear_quantizer_config ,
138158 )
139159
140160 take_over_mutable_buffer = minimum_deployment_target >= ct .target .iOS18
141-
142161 return CoreMLPartitioner ( # pyre-fixme[16]
143162 compile_specs = compile_specs ,
144163 take_over_mutable_buffer = take_over_mutable_buffer ,
0 commit comments