@@ -108,7 +108,7 @@ def prepare_tokenizer(args):
             args.tokenizer_bin is not None
         ), "Please provide tokenizer_bin for stories."
         runtime_tokenizer_path = args.tokenizer_bin
-    elif args.decoder_model == "llama3_2":
+    elif "llama3_2" in args.decoder_model:
         tokenizer = get_tokenizer(args.tokenizer_model)
         assert isinstance(
             tokenizer, TiktokenTokenizer
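The tokenizer branch now keys off a substring instead of an exact name, so any decoder model identifier containing "llama3_2" takes the Tiktoken path. A minimal sketch of the difference, using hypothetical variant names that are not taken from the script:

```python
decoder_models = ["llama3_2", "llama3_2_instruct", "stories110m"]  # hypothetical names

# Exact match only selects the bare "llama3_2" entry.
print([m for m in decoder_models if m == "llama3_2"])      # ['llama3_2']

# Substring match also catches variants whose names contain "llama3_2".
print([m for m in decoder_models if "llama3_2" in m])      # ['llama3_2', 'llama3_2_instruct']
```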
@@ -240,7 +240,7 @@ def prequant_algorithm(model, prefill_config, args):

     if args.range_setting == "mse_with_act_loss":
         wrapped_model = WrappedLlamaModel(
-            model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
+            model, *atten_mask, args.use_kv_cache, args.max_seq_length, args.device
         )
         act_bits, weight_bits = {
             "8a8w": (8, 8),
@@ -355,20 +355,20 @@ def eval_llm(args):

     logging.info("Quantizing the model...")
     model = convert_pt2e(model)
-    logging.info("Quantization complete! Here is some sample generated text:")
-
-    graph_module_inference(
-        use_kv_cache=False,
-        get_example_inputs=lambda use_kv_cache=False: inputs,
-        module=model,
-        tokenizer=tokenizer,
-        ar_len=args.max_seq_len,
-        max_seq_len=args.max_seq_len,
-        kv_updater=args.kv_updater,
-        prompt="Can you tell me about Facebook?",
-        use_i64_token=use_i64_token,
-        event_name="convert_pt2e_prompt",
-    )
+    # logging.info("Quantization complete! Here is some sample generated text:")
+
+    # graph_module_inference(
+    #     use_kv_cache=False,
+    #     get_example_inputs=lambda use_kv_cache=False: inputs,
+    #     module=model,
+    #     tokenizer=tokenizer,
+    #     ar_len=args.max_seq_len,
+    #     max_seq_len=args.max_seq_len,
+    #     kv_updater=args.kv_updater,
+    #     prompt="Can you tell me about Facebook?",
+    #     use_i64_token=use_i64_token,
+    #     event_name="convert_pt2e_prompt",
+    # )

     logging.info("Evaluation of QDQ model:")
     graph_module_inference(
@@ -380,6 +380,7 @@ def eval_llm(args):
         max_seq_len=args.max_seq_len,
         kv_updater=args.kv_updater,
         tasks=["wikitext"],
+        tasks_limit=0.1,
         use_i64_token=use_i64_token,
         event_name="convert_pt2e_prompt",
     )
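`tasks_limit=0.1` appears to cap the wikitext evaluation at a fraction of the dataset. Assuming it follows the usual lm-eval-style `limit` semantics (a value below 1.0 is read as a fraction of the examples; this is an assumption, the exact behavior lives inside `graph_module_inference`), a sketch of that interpretation:

```python
def resolve_limit(limit, num_examples):
    # Hypothetical helper: mirrors the common "fraction if < 1.0, count otherwise" rule.
    if limit is None:
        return num_examples
    return int(limit * num_examples) if limit < 1.0 else int(limit)

print(resolve_limit(0.1, 1000))  # 100 -> roughly 10% of the task's examples
print(resolve_limit(200, 1000))  # 200 -> an absolute example count
```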
@@ -424,9 +425,7 @@ def main() -> None:
     )
     parser.add_argument(
         "--decoder_model",
-        choices=["stories260k", "stories110m", "llama3_2"]
-        + list(SUPPORTED_LLM_MODELS.keys()),
-        help=f"The Llama model to export. Current available options are: [stories260k, stories110m, llama3_2] + {SUPPORTED_LLM_MODELS.keys()}",
+        help=f"The Llama model to export. Current available options are: {SUPPORTED_LLM_MODELS.keys()}",
         required=True,
     )
     parser.add_argument(
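With `choices` removed, argparse no longer rejects unknown model names up front; any validation now has to happen against `SUPPORTED_LLM_MODELS` later in the flow. A hedged sketch of the resulting parser behavior, with a stand-in registry and a hypothetical post-parse check that is not part of this diff:

```python
import argparse

SUPPORTED_LLM_MODELS = {"stories110m": None, "llama3_2": None}  # stand-in registry

parser = argparse.ArgumentParser()
parser.add_argument(
    "--decoder_model",
    help=f"The Llama model to export. Current available options are: {SUPPORTED_LLM_MODELS.keys()}",
    required=True,
)
args = parser.parse_args(["--decoder_model", "llama3_2_custom"])  # accepted by argparse

# Hypothetical follow-up check, not present in the diff:
if not any(name in args.decoder_model for name in SUPPORTED_LLM_MODELS):
    raise SystemExit(f"Unsupported decoder_model: {args.decoder_model}")
```

Note that the f-string interpolates the `dict_keys` view directly, so the rendered help text reads `dict_keys([...])` rather than a plain list of names.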