# LICENSE file in the root directory of this source tree.

import argparse
-import copy
import json

import logging
import sys

-from typing import List, Tuple
-
import torch
-import torch.nn as nn
+
from executorch.backends.qualcomm.quantizer.custom_annotation import (
    annotate_linear_16a8w_in_affine_layer,
    annotate_matmul_16a8w,
    LlamaModel,
    ModelArgs,
)
-
-from executorch.examples.qualcomm.utils import make_quantizer
+from executorch.examples.qualcomm.oss_scripts.llama.range_setting_pt2e import (
+    compute_scales,
+    make_custom_quantizer,
+    reverse_quantize_module_swap,
+    set_scales,
+    WrappedLlamaModel,
+)

from lm_eval.evaluator import simple_evaluate

from pytorch_tokenizers import get_tokenizer

-from torchao.quantization.pt2e import MinMaxObserver
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantizer import QuantizationSpec

logging.getLogger().setLevel(logging.INFO)


-class WrappedLlamaModel(nn.Module):
-    def __init__(
-        self, model, atten_mask, use_kv_cache=False, max_seq_len=512, device="cuda"
-    ):
-        super(WrappedLlamaModel, self).__init__()
-        self.model = model
-        self.max_seq_len = max_seq_len
-        self.use_kv_cache = use_kv_cache
-        self.device = device
-        self.atten_mask = atten_mask
-
-    def forward(
-        self,
-        tokens: torch.Tensor,
-        *args,
-    ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
-        # Pad input if necessary, since LlamaModel requires static shape
-        if tokens.shape[1] != self.max_seq_len:
-            tokens = torch.nn.functional.pad(
-                tokens, (0, self.max_seq_len - tokens.shape[1])
-            )
-        return self.model.forward(tokens, self.atten_mask)
-
-
def add_mse_weight_observer(quant_dtype, quantizer):
    weight_dtype = (
        torch.int4
@@ -115,24 +92,16 @@ def add_mse_weight_observer(quant_dtype, quantizer):
    )


-def gen_eval_wrapper(model_name, args):
-    tokenizer = get_tokenizer(args.tokenizer_path)
+def prepare_model(model_name, args):
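+    """Build the prefill-only LlamaModel, optionally precompute quantization scales, and return (model, prefill_config, inputs, scales_state_dict)."""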
    with open(args.params) as f:
-        kv_config = ModelArgs(**json.load(f))
+        prefill_config = ModelArgs(**json.load(f))
    # TODO: support batch inputs if necessary
-    kv_config.max_batch_size = 1
-    kv_config.max_seq_len = args.max_seq_length
-    kv_config.use_kv_cache = True
-
-    prefill_config = copy.copy(kv_config)
+    prefill_config.max_batch_size = 1
    prefill_config.max_seq_len = args.max_seq_length
-    prefill_config.use_kv_cache = (
-        False if args.max_seq_length == args.prefill_ar_len else True
-    )
-    config = prefill_config
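+    # Evaluation runs prefill only, so the KV cache is disabled.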
+    prefill_config.use_kv_cache = False
    use_i64_token = args.embedding_quantize is not None
    model = LlamaModel(
-        config,
+        prefill_config,
        ar_len=args.prefill_ar_len,
        output_new_cache_only=True,
        output_cache=False,
@@ -173,57 +142,72 @@ def permute(w, heads):
173142 if "model" in state_dict :
174143 state_dict = state_dict ["model" ]
175144
145+ # TODO: use dtype of model checkpoint
146+ model = model .to (device = args .device , dtype = torch .float )
147+ inputs = model .get_example_inputs (use_kv_cache = False )
148+ tokens , atten_mask = inputs
149+
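+    # Optionally precompute quantization scales on the float model via the mse_with_act_loss range-setting search.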
+    scales_state_dict = {}
+    if args.range_setting == "mse_with_act_loss":
+        wrapped_model = WrappedLlamaModel(
+            model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
+        )
+        act_bits, weight_bits = {
+            "8a8w": (8, 8),
+            "16a4w": (16, 4),
+            "16a4w_block": (16, 4),
+        }[args.ptq]
+        scales_state_dict = compute_scales(
+            wrapped_model, tokens, weight_bits, act_bits, 1600
+        )
+        torch.save(scales_state_dict, "scales_state_dict.pth")
+        logging.info("Saved scales to scales_state_dict.pth!")
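+        # reverse_quantize_module_swap is assumed to undo the module swap used during the scale search.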
+        reverse_quantize_module_swap(wrapped_model)
+
    for layer in model.layers:
        if getattr(layer.attention, "prepare_sha", None):
            layer.attention.prepare_sha()
        if getattr(layer.feed_forward, "prepare_feedfoward_conv", None):
            layer.feed_forward.prepare_feedfoward_conv()
-
-    model.to(dtype=torch.float)
-    model.to(device=args.device)
-
-    tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
-    tokens = tokens.to(device=args.device)
-    atten_mask = atten_mask.to(device=args.device)
-    atten_mask = atten_mask.to(dtype=torch.float)
-    inputs = (tokens, atten_mask)
-
    if args.embedding_quantize:
        model = get_quant_embedding_transform(
            embedding_quantize=args.embedding_quantize
        )(model)

    model = convert_linear_to_conv2d(model)
+    return model, prefill_config, inputs, scales_state_dict
+

-    if args.ptq:
+def gen_eval_wrapper(model_name, args):
+    tokenizer = get_tokenizer(args.tokenizer_path)
+    model, config, inputs, scales_state_dict = prepare_model(model_name, args)
+    tokens, atten_mask = inputs
+    use_i64_token = args.embedding_quantize is not None
+
+    if args.ptq is not None:
        quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")

        custom_annotations = (annotate_matmul_16a8w,)
        if args.llama_model == "stories110m":
            custom_annotations = custom_annotations + (
                annotate_linear_16a8w_in_affine_layer,
            )
-        quantizer = make_quantizer(
-            quant_dtype=quant_dtype,
-            per_channel_conv=True,
-            per_channel_linear=True,
-            act_observer=MinMaxObserver,
-        )
-        quantizer.add_custom_quant_annotations(custom_annotations)

-        if args.range_setting == "mse_weight":
-            add_mse_weight_observer(quant_dtype, quantizer)
+        quantizer = make_custom_quantizer(
+            quant_dtype, args.range_setting, custom_annotations, args.quant_linear_only
+        )

        with torch.no_grad():
+            logging.info("Starting export...")
            model = torch.export.export(model, inputs, strict=True).module()
            if quant_dtype == QuantDtype.use_16a4w_block:
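+                # Blockwise 16a4w: each conv weight is quantized in blocks of 64 along dim 1 (the input-channel axis, assuming OIHW weights).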
                conv_nodes = [n for n in model.graph.nodes if "conv" in n.name]
                block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
                quantizer.set_block_size_map(block_size_map)
-
+            logging.info("Finished export, adding observers (prepare_pt2e)...")
            model = prepare_pt2e(model, quantizer)

-        logging.info("Quantizing the model...")
+        logging.info("Observers added, starting calibration...")

        calibrate(
            inputs,
@@ -236,7 +220,24 @@ def permute(w, heads):
            use_i64_token=use_i64_token,
        )

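+        # Apply the precomputed scales on top of the calibrated observers before conversion.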
+        if args.range_setting == "mse_with_act_loss":
+            # scales_state_dict = torch.load("scales_state_dict.pth")
+            set_scales(model, scales_state_dict, config.head_dim)
+
+        logging.info("Quantizing the model...")
        model = convert_pt2e(model)
+        logging.info("Quantization complete! Here is some sample generated text:")
+
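+        # Run a sample prompt through the quantized model as a quick sanity check of generation quality.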
+        calibrate(
+            inputs,
+            "Could you tell me about Facebook?",
+            model,
+            tokenizer=tokenizer,
+            ar_len=args.prefill_ar_len,
+            max_seq_len=args.max_seq_len,
+            kv_updater=None,
+            use_i64_token=use_i64_token,
+        )

    model = WrappedLlamaModel(
        model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
@@ -248,7 +249,7 @@ def permute(w, heads):
        max_seq_length=args.calibration_seq_length,
        use_kv_cache=args.use_kv_cache,
        generate_full_logits=args.generate_full_logits,
-        enable_dynamic_shape=args.enable_dynamic_shape,
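+        # The exported model uses a static sequence length, so dynamic shapes stay disabled.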
+        enable_dynamic_shape=False,
    )


@@ -271,6 +272,7 @@ def eval_llama(
        model=eval_wrapper,
        tasks=args.tasks,
        num_fewshot=args.num_fewshot,
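+        # lm_eval treats a float limit as the fraction of each task's examples.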
+        limit=args.fraction,
    )

    for task, res in eval_results["results"].items():
@@ -290,9 +292,19 @@ def main() -> None:
    )
    parser.add_argument(
        "--range_setting",
-        help="Choose which range setting method (e.g. mse_weight). If not specified, will do minmax for weights and activations",
+        help="Choose which range setting method for weight quantization (e.g. mse_weight_only or mse_with_act_loss). If not specified, defaults to minmax",
        type=str,
    )
+    parser.add_argument(
+        "--fraction",
+        help="Fraction of examples per task to evaluate (use only for quick testing)",
+        type=float,
+    )
+    parser.add_argument(
+        "--quant_linear_only",
+        help="Quantize only the linear layers",
+        action="store_true",
+    )

    args = parser.parse_args()
    args.llama_model = "llama3_2"