# LICENSE file in the root directory of this source tree.

import argparse
-import copy
import json

import logging
import sys
-
-from typing import List, Tuple
+import types

import torch
-import torch.nn as nn
+
from executorch.backends.qualcomm.quantizer.custom_annotation import (
    annotate_linear_16a8w_in_affine_layer,
    annotate_matmul_16a8w,
    LlamaModel,
    ModelArgs,
)
-
-from executorch.examples.qualcomm.utils import make_quantizer
+from executorch.examples.qualcomm.oss_scripts.llama.range_setting_pt2e import (
+    compute_scales,
+    make_custom_quantizer,
+    reverse_quantize_module_swap,
+    set_scales,
+    WrappedLlamaModel,
+)

from lm_eval.evaluator import simple_evaluate

from pytorch_tokenizers import get_tokenizer
+from torchao.prototype.spinquant import apply_spinquant

-from torchao.quantization.pt2e import MinMaxObserver
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantizer import QuantizationSpec

logging.getLogger().setLevel(logging.INFO)


-class WrappedLlamaModel(nn.Module):
-    def __init__(
-        self, model, atten_mask, use_kv_cache=False, max_seq_len=512, device="cuda"
-    ):
-        super(WrappedLlamaModel, self).__init__()
-        self.model = model
-        self.max_seq_len = max_seq_len
-        self.use_kv_cache = use_kv_cache
-        self.device = device
-        self.atten_mask = atten_mask
-
-    def forward(
-        self,
-        tokens: torch.Tensor,
-        *args,
-    ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
-        # Pad input if necessary, since LlamaModel requires static shape
-        if tokens.shape[1] != self.max_seq_len:
-            tokens = torch.nn.functional.pad(
-                tokens, (0, self.max_seq_len - tokens.shape[1])
-            )
-        return self.model.forward(tokens, self.atten_mask)
-
-
def add_mse_weight_observer(quant_dtype, quantizer):
    weight_dtype = (
        torch.int4
@@ -115,24 +94,16 @@ def add_mse_weight_observer(quant_dtype, quantizer):
    )


-def gen_eval_wrapper(model_name, args):
-    tokenizer = get_tokenizer(args.tokenizer_path)
+def prepare_model(model_name, args):
    with open(args.params) as f:
-        kv_config = ModelArgs(**json.load(f))
+        prefill_config = ModelArgs(**json.load(f))
    # TODO: support batch inputs if necessary
-    kv_config.max_batch_size = 1
-    kv_config.max_seq_len = args.max_seq_length
-    kv_config.use_kv_cache = True
-
-    prefill_config = copy.copy(kv_config)
+    prefill_config.max_batch_size = 1
    prefill_config.max_seq_len = args.max_seq_length
-    prefill_config.use_kv_cache = (
-        False if args.max_seq_length == args.prefill_ar_len else True
-    )
-    config = prefill_config
+    prefill_config.use_kv_cache = False
    use_i64_token = args.embedding_quantize is not None
    model = LlamaModel(
-        config,
+        prefill_config,
        ar_len=args.prefill_ar_len,
        output_new_cache_only=True,
        output_cache=False,
@@ -173,57 +144,90 @@ def permute(w, heads):
    if "model" in state_dict:
        state_dict = state_dict["model"]

+    # TODO: use dtype of model checkpoint
+    model = model.to(device=args.device, dtype=torch.float)
+    inputs = model.get_example_inputs(use_kv_cache=False)
+    tokens, atten_mask = inputs
+
+    scales_state_dict = {}
+    if args.spinquant:
+        config = types.SimpleNamespace(
+            dim=prefill_config.dim,
+            head_dim=prefill_config.dim // prefill_config.n_heads,
+            n_local_heads=prefill_config.n_heads,
+            intermediate_size=4 * prefill_config.dim,
+        )
+        model.config = config
+        apply_spinquant(
+            model,
+            use_r1=True,
+            use_r2=True,
+            use_r4=False,
+            pretrained_rotation_path=None,
+            qkv_split=True,
+        )
+        logging.info("Applied SpinQuant to the model")
+
+    if args.range_setting == "mse_with_act_loss":
+        wrapped_model = WrappedLlamaModel(
+            model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
+        )
+        act_bits, weight_bits = {
+            "8a8w": (8, 8),
+            "16a4w": (16, 4),
+            "16a4w_block": (16, 4),
+        }[args.ptq]
+        scales_state_dict = compute_scales(
+            wrapped_model, tokens, weight_bits, act_bits, 1600
+        )
+        torch.save(scales_state_dict, "scales_state_dict.pth")
+        logging.info("Saved scales to scales_state_dict.pth!")
+        reverse_quantize_module_swap(wrapped_model)
+
    for layer in model.layers:
        if getattr(layer.attention, "prepare_sha", None):
            layer.attention.prepare_sha()
        if getattr(layer.feed_forward, "prepare_feedfoward_conv", None):
            layer.feed_forward.prepare_feedfoward_conv()
-
-    model.to(dtype=torch.float)
-    model.to(device=args.device)
-
-    tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
-    tokens = tokens.to(device=args.device)
-    atten_mask = atten_mask.to(device=args.device)
-    atten_mask = atten_mask.to(dtype=torch.float)
-    inputs = (tokens, atten_mask)
-
    if args.embedding_quantize:
        model = get_quant_embedding_transform(
            embedding_quantize=args.embedding_quantize
        )(model)

    model = convert_linear_to_conv2d(model)
+    return model, prefill_config, inputs, scales_state_dict
+
+
+def gen_eval_wrapper(model_name, args):
+    tokenizer = get_tokenizer(args.tokenizer_path)
+    model, config, inputs, scales_state_dict = prepare_model(model_name, args)
+    tokens, atten_mask = inputs
+    use_i64_token = args.embedding_quantize is not None

-    if args.ptq:
+    if args.ptq is not None:
        quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")

        custom_annotations = (annotate_matmul_16a8w,)
        if args.llama_model == "stories110m":
            custom_annotations = custom_annotations + (
                annotate_linear_16a8w_in_affine_layer,
            )
-        quantizer = make_quantizer(
-            quant_dtype=quant_dtype,
-            per_channel_conv=True,
-            per_channel_linear=True,
-            act_observer=MinMaxObserver,
-        )
-        quantizer.add_custom_quant_annotations(custom_annotations)

-        if args.range_setting == "mse_weight":
-            add_mse_weight_observer(quant_dtype, quantizer)
+        quantizer = make_custom_quantizer(
+            quant_dtype, args.range_setting, custom_annotations, args.quant_linear_only
+        )

        with torch.no_grad():
+            logging.info("Starting export...")
            model = torch.export.export(model, inputs, strict=True).module()
            if quant_dtype == QuantDtype.use_16a4w_block:
                conv_nodes = [n for n in model.graph.nodes if "conv" in n.name]
                block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
                quantizer.set_block_size_map(block_size_map)
-
+            logging.info("Finished export, adding observers (prepare_pt2e)...")
            model = prepare_pt2e(model, quantizer)

-        logging.info("Quantizing the model...")
+        logging.info("Observers added, starting calibration...")

        calibrate(
            inputs,
@@ -236,7 +240,24 @@ def permute(w, heads):
            use_i64_token=use_i64_token,
        )

+        if args.range_setting == "mse_with_act_loss":
+            # scales_state_dict = torch.load("scales_state_dict.pth")
+            set_scales(model, scales_state_dict, config.head_dim)
+
+        logging.info("Quantizing the model...")
        model = convert_pt2e(model)
+        logging.info("Quantization complete! Here is some sample generated text:")
+
+        calibrate(
+            inputs,
+            "Could you tell me about Facebook?",
+            model,
+            tokenizer=tokenizer,
+            ar_len=args.prefill_ar_len,
+            max_seq_len=args.max_seq_len,
+            kv_updater=None,
+            use_i64_token=use_i64_token,
+        )

    model = WrappedLlamaModel(
        model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
@@ -248,7 +269,7 @@ def permute(w, heads):
        max_seq_length=args.calibration_seq_length,
        use_kv_cache=args.use_kv_cache,
        generate_full_logits=args.generate_full_logits,
-        enable_dynamic_shape=args.enable_dynamic_shape,
+        enable_dynamic_shape=False,
    )


@@ -271,6 +292,7 @@ def eval_llama(
            model=eval_wrapper,
            tasks=args.tasks,
            num_fewshot=args.num_fewshot,
+            limit=args.fraction,
        )

    for task, res in eval_results["results"].items():
@@ -290,9 +312,24 @@ def main() -> None:
    )
    parser.add_argument(
        "--range_setting",
-        help="Choose which range setting method (e.g. mse_weight). If not specified, will do minmax for weights and activations",
+        help="Choose which range setting method for weight quantization (e.g. mse_weight_only or mse_with_act_loss). If not specified, defaults to minmax",
        type=str,
    )
+    parser.add_argument(
+        "--spinquant",
+        help="Apply SpinQuant (R1+R2) to the model. Uses random Hadamard matrices for rotations",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--fraction",
+        help="the fraction of examples per task (only use this for testing)",
+        type=float,
+    )
+    parser.add_argument(
+        "--quant_linear_only",
+        help="if you select this option we quantize linear layers only",
+        action="store_true",
+    )

    args = parser.parse_args()
    args.llama_model = "llama3_2"