
Commit 18498bf

Fix eval_llama_qnn (#14439)
Reviewed By: cccclai
Differential Revision: D82790290
1 parent 9e7a264 commit 18498bf

File tree: 2 files changed, +20 -21 lines

examples/qualcomm/oss_scripts/llama/decoder_utils.py

Lines changed: 2 additions & 2 deletions
@@ -494,8 +494,8 @@ def prefill_inference(
         if collect_logits:
             result_logits = logits[:, :pos]
         pos += 1
-
-    logging.info(f"prefill inference result:\n{tokenizer.decode(token_list)}")
+    if isinstance(prompt, str):
+        logging.info(f"prefill inference result:\n{tokenizer.decode(token_list)}")
     return result_logits
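Note: the new isinstance guard means the decoded prompt is only logged when prefill_inference was given a raw text prompt; callers that pass pre-tokenized input skip the log. A minimal sketch of the pattern, with hypothetical names and the assumption that prompt may be either a string or a list of token ids:

# Sketch (hypothetical names): decode and log only when a text prompt was supplied.
import logging
from typing import List, Union

def maybe_log_prefill(prompt: Union[str, List[int]], token_list: List[int], tokenizer) -> None:
    # tokenizer is assumed to expose decode(token_ids) -> str
    if isinstance(prompt, str):
        logging.info(f"prefill inference result:\n{tokenizer.decode(token_list)}")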

examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py

Lines changed: 18 additions & 19 deletions
@@ -108,7 +108,7 @@ def prepare_tokenizer(args):
             args.tokenizer_bin is not None
         ), "Please provide tokenizer_bin for stories."
         runtime_tokenizer_path = args.tokenizer_bin
-    elif args.decoder_model == "llama3_2":
+    elif "llama3_2" in args.decoder_model:
         tokenizer = get_tokenizer(args.tokenizer_model)
         assert isinstance(
             tokenizer, TiktokenTokenizer
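Switching from equality to a substring check lets size- or variant-suffixed model names take the same tokenizer path. A minimal sketch of the difference, using hypothetical model names:

# Sketch: substring match vs. exact equality (model names are hypothetical).
name = "llama3_2_3b"
print(name == "llama3_2")   # False: exact match rejects suffixed variants
print("llama3_2" in name)   # True: substring match accepts them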
@@ -240,7 +240,7 @@ def prequant_algorithm(model, prefill_config, args):
 
     if args.range_setting == "mse_with_act_loss":
         wrapped_model = WrappedLlamaModel(
-            model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
+            model, *atten_mask, args.use_kv_cache, args.max_seq_length, args.device
         )
         act_bits, weight_bits = {
             "8a8w": (8, 8),
@@ -355,20 +355,20 @@ def eval_llm(args):
 
     logging.info("Quantizing the model...")
     model = convert_pt2e(model)
-    logging.info("Quantization complete! Here is some sample generated text:")
-
-    graph_module_inference(
-        use_kv_cache=False,
-        get_example_inputs=lambda use_kv_cache=False: inputs,
-        module=model,
-        tokenizer=tokenizer,
-        ar_len=args.max_seq_len,
-        max_seq_len=args.max_seq_len,
-        kv_updater=args.kv_updater,
-        prompt="Can you tell me about Facebook?",
-        use_i64_token=use_i64_token,
-        event_name="convert_pt2e_prompt",
-    )
+    # logging.info("Quantization complete! Here is some sample generated text:")
+
+    # graph_module_inference(
+    #     use_kv_cache=False,
+    #     get_example_inputs=lambda use_kv_cache=False: inputs,
+    #     module=model,
+    #     tokenizer=tokenizer,
+    #     ar_len=args.max_seq_len,
+    #     max_seq_len=args.max_seq_len,
+    #     kv_updater=args.kv_updater,
+    #     prompt="Can you tell me about Facebook?",
+    #     use_i64_token=use_i64_token,
+    #     event_name="convert_pt2e_prompt",
+    # )
 
     logging.info("Evaluation of QDQ model:")
     graph_module_inference(
@@ -380,6 +380,7 @@ def eval_llm(args):
         max_seq_len=args.max_seq_len,
         kv_updater=args.kv_updater,
         tasks=["wikitext"],
+        tasks_limit=0.1,
         use_i64_token=use_i64_token,
         event_name="convert_pt2e_prompt",
     )
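The added tasks_limit=0.1 presumably caps the wikitext run to a fraction of the task's examples; in lm-eval-harness a float limit below 1.0 is interpreted as a fraction of the dataset rather than an absolute count (the exact plumbing from tasks_limit to that parameter is an assumption here). A rough sketch of the fraction-to-count mapping, with a hypothetical example count:

# Sketch: fractional vs. absolute limits (example count is hypothetical).
def effective_limit(limit, num_examples):
    if limit is None:
        return num_examples
    # a float below 1.0 is treated as a fraction of the available examples
    return int(num_examples * limit) if limit < 1.0 else int(limit)

print(effective_limit(0.1, 2000))   # -> 200 examples evaluated instead of 2000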
@@ -424,9 +425,7 @@ def main() -> None:
     )
     parser.add_argument(
         "--decoder_model",
-        choices=["stories260k", "stories110m", "llama3_2"]
-        + list(SUPPORTED_LLM_MODELS.keys()),
-        help=f"The Llama model to export. Current available options are: [stories260k, stories110m, llama3_2] + {SUPPORTED_LLM_MODELS.keys()}",
+        help=f"The Llama model to export. Current available options are: {SUPPORTED_LLM_MODELS.keys()}",
         required=True,
     )
     parser.add_argument(
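With the explicit choices list dropped, argparse no longer validates --decoder_model against the stories/llama3_2 names, and the help text is built from SUPPORTED_LLM_MODELS alone, leaving any validation of unsupported names to later lookups. A minimal sketch of the resulting pattern (the registry contents here are hypothetical):

# Sketch: help text derived from a registry dict, no argparse-level choices (contents hypothetical).
import argparse

SUPPORTED_LLM_MODELS = {"stories110m": None, "llama3_2": None}

parser = argparse.ArgumentParser()
parser.add_argument(
    "--decoder_model",
    help=f"The Llama model to export. Current available options are: {SUPPORTED_LLM_MODELS.keys()}",
    required=True,
)
args = parser.parse_args(["--decoder_model", "llama3_2"])
print(args.decoder_model)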
