from executorch.examples.apple.coreml.llama.utils import (
    replace_linear_with_split_linear,
)
-from executorch.examples.models.llama.source_transformation.quantize import (
-    EmbeddingQuantHandler,
-)
+from executorch.exir import to_edge_transform_and_lower

from executorch.exir.backend.utils import format_delegated_graph
from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
-from executorch.exir.program._program import to_edge
from executorch.extension.export_util.utils import save_pte_program

+from torchao.quantization.granularity import PerAxis, PerGroup
+from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
+from torchao.utils import unwrap_tensor_subclass
+

def main() -> None:
    parser = argparse.ArgumentParser()
@@ -115,19 +116,8 @@ def main() -> None:
        export_args.dtype
    ]  # dtype for model/inputs

-    if export_args.embedding_quantize:
-        bitwidth, group_size = export_args.embedding_quantize.split(",")
-        if group_size == "none" or group_size == "None" or group_size == "0":
-            group_size = None
-        else:
-            group_size = int(group_size)
-        bitwidth = int(bitwidth)
-        model = EmbeddingQuantHandler(
-            model,
-            bitwidth=bitwidth,
-            group_size=group_size,
-            packed=(bitwidth in [2, 4]),
-        ).quantized_model()
+    model.eval()
+    model.to(float_dtype)

    if export_args.target_split_size is not None:
        replace_linear_with_split_linear(
@@ -140,24 +130,40 @@ def main() -> None:
            in_max_splits=1,
        )

-    model.eval()
-    model.to(float_dtype)
+    # Quantization
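+    # embedding_quantize is "<bitwidth>,<group_size>"; group_size 0 selects per-channel
+    # (PerAxis) quantization, any other value selects groupwise (PerGroup) quantization.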
+    if export_args.embedding_quantize:
+        bitwidth, group_size = export_args.embedding_quantize.split(",")
+        bitwidth = int(bitwidth)
+        assert bitwidth in [4, 8], "CoreML only supports 4-bit and 8-bit quantization"
+        group_size = int(group_size)
+        if group_size == 0:
+            granularity = PerAxis(0)
+        else:
+            granularity = PerGroup(group_size)
+        weight_dtype = getattr(torch, f"int{bitwidth}")
+
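+        # quantize_ rewrites the weights in place; the filter function limits
+        # weight-only quantization to nn.Embedding modules.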
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(weight_dtype=weight_dtype, granularity=granularity),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding),
+        )

-    op_linear_quantizer_config = None
    if export_args.coreml_quantize == "b4w":
-        op_linear_quantizer_config = {
-            "mode": "linear_symmetric",
-            "dtype": "int4",
-            "granularity": "per_block",
-            "block_size": 32,
-            "weight_threshold": 512,
-        }
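+        # b4w: 4-bit weights, quantized groupwise with group size 32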
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(
+                weight_dtype=torch.int4,
+                granularity=PerGroup(32),
+            ),
+        )
    elif export_args.coreml_quantize == "c4w":
-        op_linear_quantizer_config = {
-            "mode": "linear_symmetric",
-            "dtype": "int4",
-            "granularity": "per_channel",
-        }
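+        # c4w: 4-bit weights, quantized per output channel (axis 0)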
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(
+                weight_dtype=torch.int4,
+                granularity=PerAxis(0),
+            ),
+        )

    compile_specs = CoreMLBackend.generate_compile_specs(  # pyre-fixme[16]
        minimum_deployment_target=ct.target.iOS18,
@@ -167,15 +173,11 @@ def main() -> None:
        }[float_dtype],
        compute_unit=ct.ComputeUnit.CPU_AND_NE,
        model_type=CoreMLBackend.MODEL_TYPE.MODEL,  # pyre-fixme[16]
-        op_linear_quantizer_config=op_linear_quantizer_config,
    )
    partitioner = CoreMLPartitioner(  # pyre-fixme[16]
        compile_specs=compile_specs,
        take_over_mutable_buffer=False,
-        skip_ops_for_coreml_delegation=[
-            "quantized_decomposed.embedding_4bit.dtype",
-            "aten.embedding.default",
-        ],
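+        # embedding ops are no longer skipped; the torchao-quantized embeddings are delegated to CoreML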
+        skip_ops_for_coreml_delegation=[],
    )

    input_manager = InputManager(
@@ -192,33 +194,22 @@ def main() -> None:
    )
    example_inputs = input_manager.get_inputs(tokens=[0])

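+    # torchao stores quantized weights as tensor subclasses; unwrap them so torch.export can trace the model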
+    model = unwrap_tensor_subclass(model)
+
    ep = torch.export.export(model, example_inputs, strict=True)
    print("Exported program")
    print(ep)

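+    # to_edge_transform_and_lower converts the exported program to the edge dialect and
+    # delegates supported ops to the CoreML partitioner in a single call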
-    edge_manager = to_edge(
+    edge_manager = to_edge_transform_and_lower(
        ep,
+        partitioner=[partitioner],
        compile_config=EdgeCompileConfig(
-            _check_ir_validity=False,
+            # TODO: fix lowering when dim_order is enabled
            _skip_dim_order=True,
-            preserve_ops=[
-                torch.ops.aten.scaled_dot_product_attention.default,
-                # preserve norm op for numerical stability
-                torch.ops.aten.linalg_vector_norm.default,
-                torch.ops.aten.reciprocal.default,
-            ],
        ),
    )
-    print("Edge program")
-    print(edge_manager.exported_program())
-
-    for node in edge_manager.exported_program().graph_module.graph.nodes:
-        print(node.name, node.target, node.args, node.kwargs)
-
-    edge_manager = edge_manager.to_backend(partitioner)

    print("Delegated program")
-
    print(format_delegated_graph(edge_manager.exported_program().graph_module))

    executorch_program = edge_manager.to_executorch(