1818from executorch .examples .apple .coreml .llama .utils import (
1919 replace_linear_with_split_linear ,
2020)
21- from executorch .examples .models .llama .source_transformation .quantize import (
22- EmbeddingQuantHandler ,
23- )
2421
22+ from executorch .exir import to_edge_transform_and_lower
2523from executorch .exir .backend .utils import format_delegated_graph
2624from executorch .exir .capture ._config import EdgeCompileConfig , ExecutorchBackendConfig
2725from executorch .exir .passes import MemoryPlanningPass
2826from executorch .exir .passes .quant_fusion_pass import QuantFusionPass
2927from executorch .exir .passes .sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
30- from executorch .exir .program ._program import to_edge
3128from executorch .extension .export_util .utils import save_pte_program
3229
30+ from torchao .quantization .granularity import PerAxis , PerGroup
31+ from torchao .quantization .quant_api import IntxWeightOnlyConfig , quantize_
32+ from torchao .utils import unwrap_tensor_subclass
33+
3334
3435def main () -> None :
3536 parser = argparse .ArgumentParser ()
@@ -115,19 +116,8 @@ def main() -> None:
115116 export_args .dtype
116117 ] # dtype for model/inputs
117118
118- if export_args .embedding_quantize :
119- bitwidth , group_size = export_args .embedding_quantize .split ("," )
120- if group_size == "none" or group_size == "None" or group_size == "0" :
121- group_size = None
122- else :
123- group_size = int (group_size )
124- bitwidth = int (bitwidth )
125- model = EmbeddingQuantHandler (
126- model ,
127- bitwidth = bitwidth ,
128- group_size = group_size ,
129- packed = (bitwidth in [2 , 4 ]),
130- ).quantized_model ()
119+ model .eval ()
120+ model .to (float_dtype )
131121
132122 if export_args .target_split_size is not None :
133123 replace_linear_with_split_linear (
@@ -140,24 +130,40 @@ def main() -> None:
140130 in_max_splits = 1 ,
141131 )
142132
143- model .eval ()
144- model .to (float_dtype )
133+ # Quantization
134+ if export_args .embedding_quantize :
135+ bitwidth , group_size = export_args .embedding_quantize .split ("," )
136+ bitwidth = int (bitwidth )
137+ assert bitwidth in [4 , 8 ], "CoreML only supports 4-bit and 8-bit quantization"
138+ group_size = int (group_size )
139+ if group_size == 0 :
140+ granularity = PerAxis (0 )
141+ else :
142+ granularity = PerGroup (group_size )
143+ weight_dtype = getattr (torch , f"int{ bitwidth } " )
144+
145+ quantize_ (
146+ model ,
147+ IntxWeightOnlyConfig (weight_dtype = weight_dtype , granularity = granularity ),
148+ lambda m , fqn : isinstance (m , torch .nn .Embedding ),
149+ )
145150
146- op_linear_quantizer_config = None
147151 if export_args .coreml_quantize == "b4w" :
148- op_linear_quantizer_config = {
149- "mode" : "linear_symmetric" ,
150- "dtype" : "int4" ,
151- "granularity" : "per_block" ,
152- "block_size" : 32 ,
153- "weight_threshold" : 512 ,
154- }
152+ quantize_ (
153+ model ,
154+ IntxWeightOnlyConfig (
155+ weight_dtype = torch . int4 ,
156+ granularity = PerGroup ( 32 ) ,
157+ ) ,
158+ )
155159 elif export_args .coreml_quantize == "c4w" :
156- op_linear_quantizer_config = {
157- "mode" : "linear_symmetric" ,
158- "dtype" : "int4" ,
159- "granularity" : "per_channel" ,
160- }
160+ quantize_ (
161+ model ,
162+ IntxWeightOnlyConfig (
163+ weight_dtype = torch .int4 ,
164+ granularity = PerAxis (0 ),
165+ ),
166+ )
161167
162168 compile_specs = CoreMLBackend .generate_compile_specs ( # pyre-fixme[16]
163169 minimum_deployment_target = ct .target .iOS18 ,
@@ -167,15 +173,11 @@ def main() -> None:
167173 }[float_dtype ],
168174 compute_unit = ct .ComputeUnit .CPU_AND_NE ,
169175 model_type = CoreMLBackend .MODEL_TYPE .MODEL , # pyre-fixme[16]
170- op_linear_quantizer_config = op_linear_quantizer_config ,
171176 )
172177 partitioner = CoreMLPartitioner ( # pyre-fixme[16]
173178 compile_specs = compile_specs ,
174179 take_over_mutable_buffer = False ,
175- skip_ops_for_coreml_delegation = [
176- "quantized_decomposed.embedding_4bit.dtype" ,
177- "aten.embedding.default" ,
178- ],
180+ skip_ops_for_coreml_delegation = [],
179181 )
180182
181183 input_manager = InputManager (
@@ -192,33 +194,22 @@ def main() -> None:
192194 )
193195 example_inputs = input_manager .get_inputs (tokens = [0 ])
194196
197+ model = unwrap_tensor_subclass (model )
198+
195199 ep = torch .export .export (model , example_inputs , strict = True )
196200 print ("Exported program" )
197201 print (ep )
198202
199- edge_manager = to_edge (
203+ edge_manager = to_edge_transform_and_lower (
200204 ep ,
205+ partitioner = [partitioner ],
201206 compile_config = EdgeCompileConfig (
202- _check_ir_validity = False ,
207+ # TODO: fix lowering when dim_order is enabled
203208 _skip_dim_order = True ,
204- preserve_ops = [
205- torch .ops .aten .scaled_dot_product_attention .default ,
206- # preserve norm op for numerical stability
207- torch .ops .aten .linalg_vector_norm .default ,
208- torch .ops .aten .reciprocal .default ,
209- ],
210209 ),
211210 )
212- print ("Edge program" )
213- print (edge_manager .exported_program ())
214-
215- for node in edge_manager .exported_program ().graph_module .graph .nodes :
216- print (node .name , node .target , node .args , node .kwargs )
217-
218- edge_manager = edge_manager .to_backend (partitioner )
219211
220212 print ("Delegated program" )
221-
222213 print (format_delegated_graph (edge_manager .exported_program ().graph_module ))
223214
224215 executorch_program = edge_manager .to_executorch (
0 commit comments