 from executorch.examples.apple.coreml.llama.utils import (
     replace_linear_with_split_linear,
 )
-from executorch.examples.models.llama.source_transformation.quantize import (
-    EmbeddingQuantHandler,
-)
 
 from executorch.exir.backend.utils import format_delegated_graph
 from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
 from executorch.exir.passes import MemoryPlanningPass
 from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
-from executorch.exir.program._program import to_edge, to_edge_transform_and_lower
+from executorch.exir.program._program import to_edge_transform_and_lower
 from executorch.extension.export_util.utils import save_pte_program
 
+from torchao.quantization.granularity import PerAxis, PerGroup
+from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
+from torchao.utils import unwrap_tensor_subclass
+
 
 def main() -> None:
     parser = argparse.ArgumentParser()
@@ -115,19 +116,8 @@ def main() -> None:
         export_args.dtype
     ]  # dtype for model/inputs
 
-    if export_args.embedding_quantize:
-        bitwidth, group_size = export_args.embedding_quantize.split(",")
-        if group_size == "none" or group_size == "None" or group_size == "0":
-            group_size = None
-        else:
-            group_size = int(group_size)
-        bitwidth = int(bitwidth)
-        model = EmbeddingQuantHandler(
-            model,
-            bitwidth=bitwidth,
-            group_size=group_size,
-            packed=(bitwidth in [2, 4]),
-        ).quantized_model()
+    model.eval()
+    model.to(float_dtype)
 
     if export_args.target_split_size is not None:
         replace_linear_with_split_linear(
@@ -140,24 +130,49 @@ def main() -> None:
             in_max_splits=1,
         )
 
-    model.eval()
-    model.to(float_dtype)
+    # Quantization
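+    # embedding_quantize has the form "<bitwidth>,<group_size>", e.g. "4,32";
+    # a group_size of 0 selects per-channel (per-axis) quantization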
+    if export_args.embedding_quantize:
+        bitwidth, group_size = export_args.embedding_quantize.split(",")
+        bitwidth = int(bitwidth)
+        assert bitwidth in [4, 8], "CoreML only supports 4-bit and 8-bit quantization"
+        group_size = int(group_size)
+        if group_size == 0:
+            granularity = PerAxis(0)
+        else:
+            granularity = PerGroup(group_size)
+        weight_dtype = getattr(torch, f"int{bitwidth}")
 
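+        # quantize_ rewrites the model in place; the filter function restricts
+        # quantization to nn.Embedding modules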
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(weight_dtype=weight_dtype, granularity=granularity),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding),
+        )
+
+    # CoreML's op_linear_quantizer_config appears to have a bug where the quantization
+    # quality is subpar. We use torchao APIs instead, which are now supported by CoreML.
     op_linear_quantizer_config = None
+    # op_linear_quantizer_config = {
+    #     "mode": "linear_symmetric",
+    #     "dtype": "int4",
+    #     "granularity": "per_channel",
+    # }
+
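+    # "b4w" quantizes linear weights to 4 bits with group size 32; "c4w"
+    # quantizes them to 4 bits per output channel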
     if export_args.coreml_quantize == "b4w":
-        op_linear_quantizer_config = {
-            "mode": "linear_symmetric",
-            "dtype": "int4",
-            "granularity": "per_block",
-            "block_size": 32,
-            "weight_threshold": 512,
-        }
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(
+                weight_dtype=torch.int4,
+                granularity=PerGroup(32),
+            ),
+        )
     elif export_args.coreml_quantize == "c4w":
-        op_linear_quantizer_config = {
-            "mode": "linear_symmetric",
-            "dtype": "int4",
-            "granularity": "per_channel",
-        }
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(
+                weight_dtype=torch.int4,
+                granularity=PerAxis(0),
+            ),
+        )
 
     compile_specs = CoreMLBackend.generate_compile_specs(  # pyre-fixme[16]
         minimum_deployment_target=ct.target.iOS18,
@@ -172,10 +187,7 @@ def main() -> None:
     partitioner = CoreMLPartitioner(  # pyre-fixme[16]
         compile_specs=compile_specs,
         take_over_mutable_buffer=False,
-        skip_ops_for_coreml_delegation=[
-            "quantized_decomposed.embedding_4bit.dtype",
-            "aten.embedding.default",
-        ],
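+        # embedding ops are now quantized with torchao and can be delegated to
+        # CoreML, so they no longer need to be skipped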
+        skip_ops_for_coreml_delegation=[],
     )
 
     input_manager = InputManager(
@@ -192,31 +204,12 @@ def main() -> None:
     )
     example_inputs = input_manager.get_inputs(tokens=[0])
 
+    model = unwrap_tensor_subclass(model)
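+    # unwrap_tensor_subclass replaces torchao's tensor-subclass weights with
+    # plain tensors so that torch.export can trace the model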
+
     ep = torch.export.export(model, example_inputs, strict=True)
     print("Exported program")
     print(ep)
 
-    # edge_manager = to_edge(
-    #     ep,
-    #     compile_config=EdgeCompileConfig(
-    #         _check_ir_validity=False,
-    #         _skip_dim_order=True,
-    #         preserve_ops=[
-    #             torch.ops.aten.scaled_dot_product_attention.default,
-    #             # preserve norm op for numerical stability
-    #             torch.ops.aten.linalg_vector_norm.default,
-    #             torch.ops.aten.reciprocal.default,
-    #         ],
-    #     ),
-    # )
-    # print("Edge program")
-    # print(edge_manager.exported_program())
-
-    # for node in edge_manager.exported_program().graph_module.graph.nodes:
-    #     print(node.name, node.target, node.args, node.kwargs)
-
-    # edge_manager = edge_manager.to_backend(partitioner)
-
     edge_manager = to_edge_transform_and_lower(
         ep,
         partitioner=[partitioner],
@@ -227,7 +220,6 @@ def main() -> None:
     )
 
     print("Delegated program")
-
     print(format_delegated_graph(edge_manager.exported_program().graph_module))
 
     executorch_program = edge_manager.to_executorch(