     LlamaModel,
     ModelArgs,
 )
-
-from executorch.examples.qualcomm.utils import make_quantizer
+from executorch.examples.qualcomm.oss_scripts.llama.range_setting_pt2e import (
+    reverse_quantize_module_swap,
+    WrappedLlamaModel,
+    compute_scales,
+    set_scales,
+    make_custom_quantizer,
+)

 from lm_eval.evaluator import simple_evaluate

 from pytorch_tokenizers import get_tokenizer

-from torchao.quantization.pt2e import MinMaxObserver
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torchao.quantization.pt2e.quantizer import QuantizationSpec

@@ -87,7 +91,6 @@ def forward(
         )
         return self.model.forward(tokens, self.atten_mask)

-
 def add_mse_weight_observer(quant_dtype, quantizer):
     weight_dtype = (
         torch.int4
@@ -118,21 +121,15 @@ def add_mse_weight_observer(quant_dtype, quantizer):
 def gen_eval_wrapper(model_name, args):
     tokenizer = get_tokenizer(args.tokenizer_path)
     with open(args.params) as f:
-        kv_config = ModelArgs(**json.load(f))
+        prefill_config = ModelArgs(**json.load(f))
     # TODO: support batch inputs if necessary
-    kv_config.max_batch_size = 1
-    kv_config.max_seq_len = args.max_seq_length
-    kv_config.use_kv_cache = True
-
-    prefill_config = copy.copy(kv_config)
+    prefill_config.max_batch_size = 1
     prefill_config.max_seq_len = args.max_seq_length
-    prefill_config.use_kv_cache = (
-        False if args.max_seq_length == args.prefill_ar_len else True
-    )
-    config = prefill_config
+    prefill_config.use_kv_cache = False
+    print(prefill_config.hidden_dim)
     use_i64_token = args.embedding_quantize is not None
     model = LlamaModel(
-        config,
+        prefill_config,
         ar_len=args.prefill_ar_len,
         output_new_cache_only=True,
         output_cache=False,
@@ -173,20 +170,30 @@ def permute(w, heads):
     if "model" in state_dict:
         state_dict = state_dict["model"]

+    tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
+    tokens = tokens.to(device=args.device)
+    atten_mask = atten_mask.to(device=args.device)
+    atten_mask = atten_mask.to(dtype=torch.float)
+    inputs = (tokens, atten_mask)
+
+    model = model.to(dtype=torch.float)
+    model = model.to(device=args.device)
+
+    scales_state_dict = dict()
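+    # For MSE-with-activation-loss range setting, weight scales are computed on the
+    # float model here and applied to the prepared (observer) model after calibration.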
+    if args.range_setting == "mse_with_act_loss":
+        wrapped_model = WrappedLlamaModel(
+            model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
+        )
+        scales_state_dict = compute_scales(wrapped_model, tokens, 1600)  # want to use different tokens for calibration!
+        reverse_quantize_module_swap(wrapped_model)
+
     for layer in model.layers:
         if getattr(layer.attention, "prepare_sha", None):
             layer.attention.prepare_sha()
         if getattr(layer.feed_forward, "prepare_feedfoward_conv", None):
             layer.feed_forward.prepare_feedfoward_conv()

-    model.to(dtype=torch.float)
-    model.to(device=args.device)
-
-    tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
-    tokens = tokens.to(device=args.device)
-    atten_mask = atten_mask.to(device=args.device)
-    atten_mask = atten_mask.to(dtype=torch.float)
-    inputs = (tokens, atten_mask)
+    model = model.to(dtype=torch.float)

     if args.embedding_quantize:
         model = get_quant_embedding_transform(
@@ -195,35 +202,28 @@ def permute(w, heads):

     model = convert_linear_to_conv2d(model)

-    if args.ptq:
+    if args.ptq is not None:
         quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")

         custom_annotations = (annotate_matmul_16a8w,)
         if args.llama_model == "stories110m":
             custom_annotations = custom_annotations + (
                 annotate_linear_16a8w_in_affine_layer,
             )
-        quantizer = make_quantizer(
-            quant_dtype=quant_dtype,
-            per_channel_conv=True,
-            per_channel_linear=True,
-            act_observer=MinMaxObserver,
-        )
-        quantizer.add_custom_quant_annotations(custom_annotations)

-        if args.range_setting == "mse_weight":
-            add_mse_weight_observer(quant_dtype, quantizer)
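+        # make_custom_quantizer builds the PT2E quantizer configured with the chosen
+        # range-setting method, the custom annotations above, and (optionally) linear-only quantization.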
+        quantizer = make_custom_quantizer(quant_dtype, args.range_setting, custom_annotations, args.quant_linear_only)

         with torch.no_grad():
+            logging.info("Starting export...")
             model = torch.export.export(model, inputs, strict=True).module()
             if quant_dtype == QuantDtype.use_16a4w_block:
                 conv_nodes = [n for n in model.graph.nodes if "conv" in n.name]
                 block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
                 quantizer.set_block_size_map(block_size_map)
-
+            logging.info("Finished export, adding observers (prepare_pt2e)...")
             model = prepare_pt2e(model, quantizer)

-            logging.info("Quantizing the model...")
+            logging.info("Observers added, starting calibration...")

             calibrate(
                 inputs,
@@ -236,7 +236,24 @@ def permute(w, heads):
                 use_i64_token=use_i64_token,
             )

+            if args.range_setting == "mse_with_act_loss":
+                # scales_state_dict = torch.load("scales_state_dict.pth")
+                set_scales(model, scales_state_dict)
+
+            logging.info("Quantizing the model...")
             model = convert_pt2e(model)
+            logging.info("Quantization complete! Here is some sample generated text:")
+
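+            # Reuse calibrate() with a fixed prompt as a quick generation sanity check
+            # on the freshly quantized model.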
+            calibrate(
+                inputs,
+                "Could you tell me about Facebook?",
+                model,
+                tokenizer=tokenizer,
+                ar_len=args.prefill_ar_len,
+                max_seq_len=args.max_seq_len,
+                kv_updater=None,
+                use_i64_token=use_i64_token,
+            )

     model = WrappedLlamaModel(
         model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
@@ -248,7 +265,7 @@ def permute(w, heads):
         max_seq_length=args.calibration_seq_length,
         use_kv_cache=args.use_kv_cache,
         generate_full_logits=args.generate_full_logits,
-        enable_dynamic_shape=args.enable_dynamic_shape,
+        enable_dynamic_shape=False,
     )


@@ -271,7 +288,7 @@ def eval_llama(
         model=eval_wrapper,
         tasks=args.tasks,
         num_fewshot=args.num_fewshot,
-        limit=args.limit,
+        limit=args.fraction,
     )

     for task, res in eval_results["results"].items():
@@ -291,13 +308,18 @@ def main() -> None:
     )
     parser.add_argument(
         "--range_setting",
-        help="Choose which range setting method (e.g. mse_weight). If not specified, will do minmax for weights and activations",
+        help="Choose which range setting method for weight quantization (e.g. mse_weight_only or mse_with_act_loss). If not specified, defaults to minmax",
         type=str,
     )
     parser.add_argument(
-        "--limit",
-        help="the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples",
-        type=str,
+        "--fraction",
+        help="the fraction of examples per task (only use this for testing)",
+        type=float,
+    )
+    parser.add_argument(
+        "--quant_linear_only",
+        help="If selected, quantize linear layers only. If the ptq arg is not specified, defaults to 16a4w",
+        action="store_true",
     )

     args = parser.parse_args()