 import executorch.exir as exir

 import torch
+from coremltools.optimize.torch.quantization.quantization_config import (
+    LinearQuantizerConfig,
+    QuantizationScheme,
+)

 from executorch.backends.apple.coreml.compiler import CoreMLBackend

 from executorch.backends.apple.coreml.partition import CoreMLPartitioner
+from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
 from executorch.devtools.etrecord import generate_etrecord
 from executorch.exir import to_edge

 from executorch.exir.backend.backend_api import to_backend
+from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+
 from torch.export import export

 REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent.parent
@@ -74,6 +81,13 @@ def parse_args() -> argparse.ArgumentParser:
     parser.add_argument("--generate_etrecord", action=argparse.BooleanOptionalAction)
     parser.add_argument("--save_processed_bytes", action=argparse.BooleanOptionalAction)

+    parser.add_argument(
+        "--quantize",
+        action=argparse.BooleanOptionalAction,
+        required=False,
+        help="Quantize CoreML model",
+    )
+
     args = parser.parse_args()
     return args

@@ -109,9 +123,10 @@ def export_lowered_module_to_executorch_program(lowered_module, example_inputs):
     return exec_prog


-def save_executorch_program(exec_prog, model_name, compute_unit):
+def save_executorch_program(exec_prog, model_name, compute_unit, quantize):
     buffer = exec_prog.buffer
-    filename = f"{model_name}_coreml_{compute_unit}.pte"
+    data_type = "quantize" if quantize else "fp"
+    filename = f"{model_name}_coreml_{compute_unit}_{data_type}.pte"
     print(f"Saving exported program to {filename}")
     with open(filename, "wb") as file:
         file.write(buffer)
@@ -167,6 +182,23 @@ def generate_compile_specs_from_args(args):
     if args.use_partitioner:
         model.eval()
         exir_program_aten = torch.export.export(model, example_inputs)
+        if args.quantize:
+            quantization_config = LinearQuantizerConfig.from_dict(
+                {
+                    "global_config": {
+                        "quantization_scheme": QuantizationScheme.affine,
+                        "activation_dtype": torch.quint8,
+                        "weight_dtype": torch.qint8,
+                        "weight_per_channel": True,
+                    }
+                }
+            )
+
+            quantizer = CoreMLQuantizer(quantization_config)
+            model = prepare_pt2e(model, quantizer)  # pyre-fixme[6]
+            model(*example_inputs)
+            exir_program_aten = convert_pt2e(model)
+
         edge_program_manager = exir.to_edge(exir_program_aten)
         edge_copy = copy.deepcopy(edge_program_manager)
         partitioner = CoreMLPartitioner(
@@ -186,7 +218,9 @@ def generate_compile_specs_from_args(args):
         example_inputs,
     )

-    save_executorch_program(exec_program, args.model_name, args.compute_unit)
+    save_executorch_program(
+        exec_program, args.model_name, args.compute_unit, args.quantize
+    )
     generate_etrecord(f"{args.model_name}_coreml_etrecord.bin", edge_copy, exec_program)

     if args.save_processed_bytes and lowered_module is not None: