     annotate_matmul_16a8w,
 )
 
+from executorch.backends.qualcomm.quantizer.observers.per_channel_param_observer import (
+    PerChannelParamObserver,
+)
+from executorch.backends.qualcomm.quantizer.qconfig import (
+    _derived_bias_quant_spec,
+    QuantizationConfig,
+)
+
 from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
 from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d
 
@@ -47,6 +55,8 @@
 
 from torchao.quantization.pt2e import MinMaxObserver
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torchao.quantization.pt2e.quantizer import QuantizationSpec
+
 
 sys.setrecursionlimit(4096)
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -78,6 +88,37 @@ def forward(
         return self.model.forward(tokens, self.atten_mask)
 
 
+def add_mse_weight_observer(quant_dtype, quantizer):
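+    # Replace the default min/max weight observer with an MSE-driven range
+    # search (PerChannelParamObserver with use_mse=True); activations are untouched.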
+    weight_dtype = (
+        torch.int4
+        if quant_dtype in (QuantDtype.use_16a4w, QuantDtype.use_16a4w_block)
+        else torch.int8
+    )
+    per_channel_q_config = quantizer.default_quant_config.quant_config
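+    # Symmetric per-channel ranges: [-7, 7] for int4, [-127, 127] for int8.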
+    weight_qspec = QuantizationSpec(
+        dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
+        quant_min=(
+            -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1
+        ),
+        quant_max=(7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max),
+        qscheme=torch.per_channel_symmetric,
+        ch_axis=0,
+        observer_or_fake_quant_ctr=PerChannelParamObserver.with_args(
+            **{"steps": 200, "use_mse": True}
+        ),
+    )
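+    # Reuse the activation specs; bias remains derived from input/weight scales.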
+    quantizer.default_quant_config.per_channel_quant_config = QuantizationConfig(
+        input_activation=per_channel_q_config.input_activation,
+        output_activation=per_channel_q_config.output_activation,
+        weight=weight_qspec,
+        bias=_derived_bias_quant_spec,
+    )
+
+
 def gen_eval_wrapper(model_name, args):
     tokenizer = get_tokenizer(args.tokenizer_path)
     with open(args.params) as f:
@@ -142,13 +179,13 @@ def permute(w, heads):
         if getattr(layer.feed_forward, "prepare_feedfoward_conv", None):
             layer.feed_forward.prepare_feedfoward_conv()
 
-    model.to(dtype=torch.bfloat16)
+    model.to(dtype=torch.float)
     model.to(device=args.device)
 
     tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
     tokens = tokens.to(device=args.device)
     atten_mask = atten_mask.to(device=args.device)
-    atten_mask = atten_mask.to(dtype=torch.bfloat16)
+    atten_mask = atten_mask.to(dtype=torch.float)
     inputs = (tokens, atten_mask)
 
     if args.embedding_quantize:
@@ -174,7 +211,9 @@ def permute(w, heads):
         )
         quantizer.add_custom_quant_annotations(custom_annotations)
 
-        model.has_quant_io = True
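+        # Configure the MSE weight observer before prepare_pt2e consumes the quantizer.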
+        if args.range_setting == "mse_weight":
+            add_mse_weight_observer(quant_dtype, quantizer)
 
         with torch.no_grad():
             model = torch.export.export(model, inputs, strict=True).module()
@@ -245,6 +283,24 @@ def main() -> None:
     torch.manual_seed(seed)
     modelname = "llama2"
     parser = build_args_parser()
+    parser.add_argument(
+        "-P",
+        "--ptq",
+        help="If specified, run PTQ quantization. Default is 16-bit activations and 4-bit weights. Supports 8a8w, 16a4w, and 16a4w_block.",
+        type=str,
+    )
+    parser.add_argument(
+        "--range_setting",
+        help="Range setting method to use (e.g. mse_weight). If not specified, min/max is used for weights and activations.",
+        type=str,
+    )
+    parser.add_argument(
+        "--limit",
+        help="Number of examples per task (only use this for testing). If <1, limit is a fraction of the total number of examples.",
+        type=float,
+    )
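+    # Example: --ptq 16a4w --range_setting mse_weight --limit 0.1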
+
     args = parser.parse_args()
     args.llama_model = "llama3_2"
     # Overrides this arg, because evaluation requires full logits.
@@ -257,15 +312,9 @@ def main() -> None:
     args.use_kv_cache = False
     args.prefill_ar_len = args.max_seq_length
 
-    # To do fewer samples for faster evaluation
-    args.limit = 0.1
-    # args.samples = {'wikitext': list(range(1))}
-
     args.device = "cuda" if torch.cuda.is_available() else "cpu"
     torch.set_default_device(args.device)
 
-    args.ptq = "8a8w"
-
     eval_llama(modelname, args)
 
 