 import executorch.exir as exir

 import torch
+from coremltools.optimize.torch.quantization.quantization_config import (
+    LinearQuantizerConfig,
+    QuantizationScheme,
+)

 from executorch.backends.apple.coreml.compiler import CoreMLBackend

 from executorch.backends.apple.coreml.partition import CoreMLPartitioner
+from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
 from executorch.devtools.etrecord import generate_etrecord
 from executorch.exir import to_edge

 from executorch.exir.backend.backend_api import to_backend
+from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

 from torch.export import export

@@ -75,6 +81,13 @@ def parse_args() -> argparse.ArgumentParser:
     parser.add_argument("--generate_etrecord", action=argparse.BooleanOptionalAction)
     parser.add_argument("--save_processed_bytes", action=argparse.BooleanOptionalAction)

+    parser.add_argument(
+        "--quantize",
+        action=argparse.BooleanOptionalAction,
+        required=False,
+        help="Quantize CoreML model",
+    )
+
     args = parser.parse_args()
     return args

@@ -110,9 +123,10 @@ def export_lowered_module_to_executorch_program(lowered_module, example_inputs):
     return exec_prog


-def save_executorch_program(exec_prog, model_name, compute_unit):
+def save_executorch_program(exec_prog, model_name, compute_unit, quantize):
     buffer = exec_prog.buffer
-    filename = f"{model_name}_coreml_{compute_unit}.pte"
+    data_type = "quantize" if quantize else "fp"
+    filename = f"{model_name}_coreml_{compute_unit}_{data_type}.pte"
     print(f"Saving exported program to {filename}")
     with open(filename, "wb") as file:
         file.write(buffer)
@@ -168,6 +182,22 @@ def main():
     if args.use_partitioner:
         model.eval()
         exir_program_aten = torch.export.export(model, example_inputs)
+        if args.quantize:
+            quantization_config = LinearQuantizerConfig.from_dict(
+                {
+                    "global_config": {
+                        "quantization_scheme": QuantizationScheme.affine,
+                        "activation_dtype": torch.quint8,
+                        "weight_dtype": torch.qint8,
+                        "weight_per_channel": True,
+                    }
+                }
+            )
+
+            quantizer = CoreMLQuantizer(quantization_config)
+            model = prepare_pt2e(model, quantizer)  # pyre-fixme[6]
+            model(*example_inputs)
+            exir_program_aten = convert_pt2e(model)

         edge_program_manager = exir.to_edge(exir_program_aten)
         edge_copy = copy.deepcopy(edge_program_manager)
@@ -189,7 +219,9 @@ def main():
             example_inputs,
         )

-    save_executorch_program(exec_program, args.model_name, args.compute_unit)
+    save_executorch_program(
+        exec_program, args.model_name, args.compute_unit, args.quantize
+    )
     generate_etrecord(f"{args.model_name}_coreml_etrecord.bin", edge_copy, exec_program)

     if args.save_processed_bytes and lowered_module is not None:
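
For readers coming to this diff cold, the new branch in main() follows the standard PT2E post-training quantization flow: build a CoreMLQuantizer from a coremltools LinearQuantizerConfig, insert observers with prepare_pt2e, run calibration data through the prepared module, and fold the observed ranges into quant/dequant ops with convert_pt2e. Below is a minimal standalone sketch of that flow; the toy module, the input shape, and the torch._export.capture_pre_autograd_graph capture step are illustrative assumptions, not code from this PR (the hunk above passes the eager model straight to prepare_pt2e, hence the pyre-fixme).

import torch
from torch._export import capture_pre_autograd_graph  # assumption: pre-autograd capture API of this PyTorch era
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from coremltools.optimize.torch.quantization.quantization_config import (
    LinearQuantizerConfig,
    QuantizationScheme,
)
from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer

# Toy model and calibration input, purely for illustration.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 3, 32, 32),)

# Same 8-bit affine, per-channel-weight config used in the diff above.
quantizer = CoreMLQuantizer(
    LinearQuantizerConfig.from_dict(
        {
            "global_config": {
                "quantization_scheme": QuantizationScheme.affine,
                "activation_dtype": torch.quint8,
                "weight_dtype": torch.qint8,
                "weight_per_channel": True,
            }
        }
    )
)

# Capture the graph, insert observers, calibrate, then fold observed ranges into q/dq ops.
captured = capture_pre_autograd_graph(model, example_inputs)
prepared = prepare_pt2e(captured, quantizer)
prepared(*example_inputs)  # calibration pass; real calibration data would go here
quantized = convert_pt2e(prepared)

In the script itself, the result of convert_pt2e is handed to exir.to_edge, and the saved program picks up a "_quantize" suffix via the updated save_executorch_program.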