@@ -5,16 +5,14 @@
# LICENSE file in the root directory of this source tree.

import argparse
-import copy
import json

import logging
import sys
-
-from typing import List, Tuple
+import types

import torch
-import torch.nn as nn
+
from executorch.backends.qualcomm.quantizer.custom_annotation import (
    annotate_linear_16a8w_in_affine_layer,
    annotate_matmul_16a8w,
@@ -46,14 +44,19 @@
    LlamaModel,
    ModelArgs,
)
-
-from executorch.examples.qualcomm.utils import make_quantizer
+from executorch.examples.qualcomm.oss_scripts.llama.range_setting_pt2e import (
+    compute_scales,
+    make_custom_quantizer,
+    reverse_quantize_module_swap,
+    set_scales,
+    WrappedLlamaModel,
+)

from lm_eval.evaluator import simple_evaluate

from pytorch_tokenizers import get_tokenizer
+from torchao.prototype.spinquant import apply_spinquant

-from torchao.quantization.pt2e import MinMaxObserver
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantizer import QuantizationSpec

@@ -64,30 +67,6 @@
logging.getLogger().setLevel(logging.INFO)


-class WrappedLlamaModel(nn.Module):
-    def __init__(
-        self, model, atten_mask, use_kv_cache=False, max_seq_len=512, device="cuda"
-    ):
-        super(WrappedLlamaModel, self).__init__()
-        self.model = model
-        self.max_seq_len = max_seq_len
-        self.use_kv_cache = use_kv_cache
-        self.device = device
-        self.atten_mask = atten_mask
-
-    def forward(
-        self,
-        tokens: torch.Tensor,
-        *args,
-    ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
-        # Pad input if necessary, since LlamaModel requires static shape
-        if tokens.shape[1] != self.max_seq_len:
-            tokens = torch.nn.functional.pad(
-                tokens, (0, self.max_seq_len - tokens.shape[1])
-            )
-        return self.model.forward(tokens, self.atten_mask)
-
-
def add_mse_weight_observer(quant_dtype, quantizer):
    weight_dtype = (
        torch.int4
@@ -115,24 +94,16 @@ def add_mse_weight_observer(quant_dtype, quantizer):
    )


-def gen_eval_wrapper(model_name, args):
-    tokenizer = get_tokenizer(args.tokenizer_path)
+def prepare_model(model_name, args):
    with open(args.params) as f:
-        kv_config = ModelArgs(**json.load(f))
+        prefill_config = ModelArgs(**json.load(f))
    # TODO: support batch inputs if necessary
-    kv_config.max_batch_size = 1
-    kv_config.max_seq_len = args.max_seq_length
-    kv_config.use_kv_cache = True
-
-    prefill_config = copy.copy(kv_config)
+    prefill_config.max_batch_size = 1
    prefill_config.max_seq_len = args.max_seq_length
-    prefill_config.use_kv_cache = (
-        False if args.max_seq_length == args.prefill_ar_len else True
-    )
-    config = prefill_config
+    prefill_config.use_kv_cache = False
    use_i64_token = args.embedding_quantize is not None
    model = LlamaModel(
-        config,
+        prefill_config,
        ar_len=args.prefill_ar_len,
        output_new_cache_only=True,
        output_cache=False,
@@ -173,57 +144,90 @@ def permute(w, heads):
    if "model" in state_dict:
        state_dict = state_dict["model"]

+    # TODO: use dtype of model checkpoint
+    model = model.to(device=args.device, dtype=torch.float)
+    inputs = model.get_example_inputs(use_kv_cache=False)
+    tokens, atten_mask = inputs
+
+    scales_state_dict = {}
+    if args.spinquant:
+        config = types.SimpleNamespace(
+            dim=prefill_config.dim,
+            head_dim=prefill_config.dim // prefill_config.n_heads,
+            n_local_heads=prefill_config.n_heads,
+            intermediate_size=4 * prefill_config.dim,
+        )
+        model.config = config
+        apply_spinquant(
+            model,
+            use_r1=True,
+            use_r2=True,
+            use_r4=False,
+            pretrained_rotation_path=None,
+            qkv_split=True,
+        )
+        logging.info("Applied SpinQuant to the model")
+
+    if args.range_setting == "mse_with_act_loss":
+        wrapped_model = WrappedLlamaModel(
+            model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
+        )
+        act_bits, weight_bits = {
+            "8a8w": (8, 8),
+            "16a4w": (16, 4),
+            "16a4w_block": (16, 4),
+        }[args.ptq]
+        scales_state_dict = compute_scales(
+            wrapped_model, tokens, weight_bits, act_bits, 1600
+        )
+        torch.save(scales_state_dict, "scales_state_dict.pth")
+        logging.info("Saved scales to scales_state_dict.pth!")
+        reverse_quantize_module_swap(wrapped_model)
+
    for layer in model.layers:
        if getattr(layer.attention, "prepare_sha", None):
            layer.attention.prepare_sha()
        if getattr(layer.feed_forward, "prepare_feedfoward_conv", None):
            layer.feed_forward.prepare_feedfoward_conv()
-
-    model.to(dtype=torch.float)
-    model.to(device=args.device)
-
-    tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
-    tokens = tokens.to(device=args.device)
-    atten_mask = atten_mask.to(device=args.device)
-    atten_mask = atten_mask.to(dtype=torch.float)
-    inputs = (tokens, atten_mask)
-
    if args.embedding_quantize:
        model = get_quant_embedding_transform(
            embedding_quantize=args.embedding_quantize
        )(model)

    model = convert_linear_to_conv2d(model)
+    return model, prefill_config, inputs, scales_state_dict
+
+
+def gen_eval_wrapper(model_name, args):
+    tokenizer = get_tokenizer(args.tokenizer_path)
+    model, config, inputs, scales_state_dict = prepare_model(model_name, args)
+    tokens, atten_mask = inputs
+    use_i64_token = args.embedding_quantize is not None

-    if args.ptq:
+    if args.ptq is not None:
        quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")

        custom_annotations = (annotate_matmul_16a8w,)
        if args.llama_model == "stories110m":
            custom_annotations = custom_annotations + (
                annotate_linear_16a8w_in_affine_layer,
            )
-        quantizer = make_quantizer(
-            quant_dtype=quant_dtype,
-            per_channel_conv=True,
-            per_channel_linear=True,
-            act_observer=MinMaxObserver,
-        )
-        quantizer.add_custom_quant_annotations(custom_annotations)

-        if args.range_setting == "mse_weight":
-            add_mse_weight_observer(quant_dtype, quantizer)
+        quantizer = make_custom_quantizer(
+            quant_dtype, args.range_setting, custom_annotations, args.quant_linear_only
+        )

        with torch.no_grad():
+            logging.info("Starting export...")
            model = torch.export.export(model, inputs, strict=True).module()
            if quant_dtype == QuantDtype.use_16a4w_block:
                conv_nodes = [n for n in model.graph.nodes if "conv" in n.name]
                block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
                quantizer.set_block_size_map(block_size_map)
-
+            logging.info("Finished export, adding observers (prepare_pt2e)...")
            model = prepare_pt2e(model, quantizer)

-            logging.info("Quantizing the model...")
+            logging.info("Observers added, starting calibration...")

            calibrate(
                inputs,
@@ -236,7 +240,24 @@ def permute(w, heads):
                use_i64_token=use_i64_token,
            )

+            if args.range_setting == "mse_with_act_loss":
+                # scales_state_dict = torch.load("scales_state_dict.pth")
+                set_scales(model, scales_state_dict, config.head_dim)
+
+            logging.info("Quantizing the model...")
            model = convert_pt2e(model)
+            logging.info("Quantization complete! Here is some sample generated text:")
+
+            calibrate(
+                inputs,
+                "Could you tell me about Facebook?",
+                model,
+                tokenizer=tokenizer,
+                ar_len=args.prefill_ar_len,
+                max_seq_len=args.max_seq_len,
+                kv_updater=None,
+                use_i64_token=use_i64_token,
+            )

    model = WrappedLlamaModel(
        model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
@@ -248,7 +269,7 @@ def permute(w, heads):
        max_seq_length=args.calibration_seq_length,
        use_kv_cache=args.use_kv_cache,
        generate_full_logits=args.generate_full_logits,
-        enable_dynamic_shape=args.enable_dynamic_shape,
+        enable_dynamic_shape=False,
    )

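For reference, the quantization path in `gen_eval_wrapper` above follows the standard PT2E export → observe → calibrate → convert flow. Below is a minimal, hedged sketch of that flow on a toy module, not code from this diff: the `QuantDtype`/`make_quantizer` import paths and the reliance on `make_quantizer` defaults are assumptions based on the helpers this script already imports.

```python
# Minimal sketch (not part of the diff) of the pt2e flow used by gen_eval_wrapper:
# export -> prepare_pt2e (insert observers) -> run calibration data -> convert_pt2e.
# Assumes executorch and torchao are installed; the import paths below are
# assumptions based on the imports this script uses.
import torch

from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.examples.qualcomm.utils import make_quantizer
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x):
        return torch.relu(self.linear(x))


def quantize_tiny_model():
    model = TinyModel().eval()
    example_inputs = (torch.randn(1, 16),)
    # Relies on make_quantizer's default settings for everything but the dtype
    # (an assumption); the script above builds its quantizer via make_custom_quantizer.
    quantizer = make_quantizer(quant_dtype=QuantDtype.use_8a8w)

    with torch.no_grad():
        # Export to a static-shape ATen graph, mirroring the call in this script.
        exported = torch.export.export(model, example_inputs, strict=True).module()
        # Insert observers chosen by the quantizer.
        prepared = prepare_pt2e(exported, quantizer)
        # Calibrate: forward passes over representative data populate the observers.
        prepared(*example_inputs)
        # Replace observers with quantize/dequantize ops.
        return convert_pt2e(prepared)
```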
@@ -271,6 +292,7 @@ def eval_llama(
        model=eval_wrapper,
        tasks=args.tasks,
        num_fewshot=args.num_fewshot,
+        limit=args.fraction,
    )

    for task, res in eval_results["results"].items():
@@ -290,9 +312,24 @@ def main() -> None:
    )
    parser.add_argument(
        "--range_setting",
-        help="Choose which range setting method (e.g. mse_weight). If not specified, will do minmax for weights and activations",
+        help="Choose which range setting method for weight quantization (e.g. mse_weight_only or mse_with_act_loss). If not specified, defaults to minmax",
        type=str,
    )
+    parser.add_argument(
+        "--spinquant",
+        help="Apply SpinQuant (R1+R2) to the model. Uses random Hadamard matrices for rotations",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--fraction",
+        help="The fraction of examples per task (only use this for testing)",
+        type=float,
+    )
+    parser.add_argument(
+        "--quant_linear_only",
+        help="If set, quantize only linear layers",
+        action="store_true",
+    )

    args = parser.parse_args()
    args.llama_model = "llama3_2"
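One detail worth noting: the commented-out `torch.load("scales_state_dict.pth")` above suggests the MSE-with-activation-loss scales can be computed once and reused in a later run. The sketch below is a hedged illustration of that two-step use, reusing only helpers and call signatures already visible in this diff (`prepare_model`, `WrappedLlamaModel`, `compute_scales`, `set_scales`); the function names `save_scales`/`apply_saved_scales` and the parsed `args` namespace are illustrative assumptions, not part of the change.

```python
# Illustrative sketch only: decoupled computation and reuse of MSE scales,
# following the calls already present in this file. `prepare_model` is the
# function defined above; `args` is assumed to be the namespace from main().
import torch

from executorch.examples.qualcomm.oss_scripts.llama.range_setting_pt2e import (
    compute_scales,
    set_scales,
    WrappedLlamaModel,
)


def save_scales(args, model_name="llama3_2"):
    # Step 1 (offline): compute scales on the float model and save them to disk.
    model, config, inputs, _ = prepare_model(model_name, args)
    tokens, atten_mask = inputs
    wrapped_model = WrappedLlamaModel(
        model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
    )
    # Same argument order as above: weight bits, then activation bits (16a4w here).
    scales_state_dict = compute_scales(wrapped_model, tokens, 4, 16, 1600)
    torch.save(scales_state_dict, "scales_state_dict.pth")
    return config


def apply_saved_scales(prepared_model, config):
    # Step 2 (later run): load the saved scales and push them into the observers
    # of a prepare_pt2e'd model before convert_pt2e, as done in gen_eval_wrapper.
    scales_state_dict = torch.load("scales_state_dict.pth")
    set_scales(prepared_model, scales_state_dict, config.head_dim)
    return prepared_model
```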