Commit 72eb1f5

Qualcomm AI Engine Direct - Verify accuracy for Llama 3.2 1B using 110M Stories
1 parent 383aa70

File tree: 2 files changed (+75 −28 lines)
examples/qualcomm/oss_scripts/llama2/llama.py

Lines changed: 51 additions & 14 deletions
@@ -357,7 +357,7 @@ def quantize(self, quant_dtype, custom_annotations=()):
         ).module()
         fx_graph_module = prepare_pt2e(fx_graph_module, quantizer)
         print("Quantizing the model...")
-
+        example_inputs = self.get_example_inputs(self.llama_meta["get_use_kv_cache"])
         calibrate(
             self.get_example_inputs(self.llama_meta["get_use_kv_cache"]),
             args.prompt,
@@ -368,6 +368,50 @@ def quantize(self, quant_dtype, custom_annotations=()):
 
         self.llama_model = convert_pt2e(fx_graph_module)
 
+        sp_model = SentencePieceProcessor(model_file=args.tokenizer_model)
+        _, atten_mask, _, k_caches, v_caches = example_inputs
+
+        # TODO: change criteria & support batch inputs if necessary
+        pos = torch.tensor(0, dtype=torch.int32)
+        token_list = [sp_model.bos_id()]
+        for prompt in args.prompt.split():
+            token_list += sp_model.encode(prompt)
+
+        def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor:
+            probs_sort, probs_indices = torch.sort(probs, dim=-1, descending=True)
+            probs_sum = torch.cumsum(probs_sort, dim=-1)
+            mask = probs_sum - probs_sort > top_p
+            probs_sort[mask] = 0
+            probs_sort /= probs_sort.sum(dim=-1, keepdim=True)
+            next_token = torch.multinomial(probs_sort, num_samples=1)
+            return probs_indices.gather(dim=-1, index=next_token)
+
+        with torch.no_grad():
+            while token_list[-1] != sp_model.eos_id() and pos < args.seq_len - 1:
+                logits, new_k_caches, new_v_caches = self.llama_model(
+                    torch.full((1, 1), token_list[pos]),
+                    atten_mask,
+                    torch.full((1, 1), pos),
+                    *k_caches,
+                    *v_caches,
+                )
+                k_caches = [
+                    torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
+                    for i, k_cache in enumerate(k_caches)
+                ]
+                v_caches = [
+                    torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
+                    for i, v_cache in enumerate(v_caches)
+                ]
+
+                pos += 1
+                atten_mask[0][-pos - 1] = 0
+                if pos >= len(token_list):
+                    probs = torch.softmax(logits[:, -1] / 0.8, dim=-1)
+                    token_list.append(sample_top_p(probs, 0.9).item())
+        print("-----")
+        print(f"convert_pt2e data:\n{sp_model.decode(token_list)}")
+
     def lowering_modules(
         self, work_space, kv_type=torch.uint8, soc_model=QcomChipset.SM8650
     ):
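
For reference, the nucleus (top-p) sampling helper introduced above can be exercised on its own. The sketch below reuses the function body from the diff; the vocabulary size and the random logits are made up for illustration, while the temperature of 0.8 and top-p of 0.9 match the values used in the calibration decode loop.

```python
import torch


def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    # Sort probabilities in descending order and keep the smallest prefix
    # whose cumulative mass covers top_p; renormalize and sample from it.
    probs_sort, probs_indices = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > top_p  # tokens outside the nucleus
    probs_sort[mask] = 0
    probs_sort /= probs_sort.sum(dim=-1, keepdim=True)
    next_token = torch.multinomial(probs_sort, num_samples=1)
    # Map the sampled position back to the original vocabulary index.
    return probs_indices.gather(dim=-1, index=next_token)


# Illustrative usage with a made-up (1, vocab_size) logits row.
logits = torch.randn(1, 32000)
probs = torch.softmax(logits / 0.8, dim=-1)  # temperature 0.8, as in the diff
token = sample_top_p(probs, 0.9).item()      # top_p 0.9, as in the diff
print(token)
```

Because `probs_sum - probs_sort > top_p` always keeps the highest-probability token, the nucleus is never empty, so the renormalization and the `torch.multinomial` draw are well defined.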
@@ -495,17 +539,18 @@ def inference(args, pre_gen_pte=""):
     runner_args = " ".join(
         [
             f"--model_path {pte_filename}.pte",
-            "--output_folder_path outputs",
+            "--output_path outputs/outputs.txt",
             f"--tokenizer_path {os.path.basename(args.tokenizer_bin)}",
             f'--prompt "{args.prompt}"',
             f"--seq_len {args.seq_len}",
             f"--temperature {args.temperature}",
+            "--eval_mode 1",
         ]
     )
     runner_cmd = " ".join(
         [
             f"cd {workspace} &&",
-            f"./qnn_llama_runner {runner_args}",
+            f"./qnn_llama3_2_runner {runner_args}",
         ]
     )

@@ -523,7 +568,7 @@ def inference(args, pre_gen_pte=""):
         host_id=args.host,
         soc_model=args.model,
         shared_buffer=args.shared_buffer,
-        runner="examples/qualcomm/oss_scripts/llama2/qnn_llama_runner",
+        runner="examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner",
     )
     # No pregen inputs, input_list is not required
     adb.push(inputs=[], input_list="", files=[args.tokenizer_bin])
@@ -535,16 +580,8 @@ def inference(args, pre_gen_pte=""):
     outputs = []
 
     def post_process():
-        for f in sorted(
-            os.listdir(output_data_folder), key=lambda f: int(f.split("_")[1])
-        ):
-            with codecs.open(
-                os.path.join(output_data_folder, f),
-                "r",
-                encoding="utf-8",
-                errors="replace",
-            ) as fdata:
-                outputs.append(fdata.read())
+        with open(f"{args.artifact}/outputs/outputs.txt", "r") as f:
+            outputs.append(f.read())
 
     adb.pull(output_path=args.artifact, callback=post_process)
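
The calibration decode loop in the `quantize()` hunk above keeps the exported model's KV caches at a fixed size by dropping the oldest position and appending the newly produced one each step. Below is a minimal sketch of that shift-and-append pattern; the batch, head-dimension, and window sizes are assumptions for illustration, since the real tensors come from `self.get_example_inputs(...)`.

```python
import torch

# Assumed shapes: k caches carry the sequence on the last axis,
# v caches carry it on axis 1, mirroring the torch.cat calls in the diff.
batch, head_dim, window = 1, 64, 8

k_cache = torch.zeros(batch, head_dim, window)
v_cache = torch.zeros(batch, window, head_dim)

new_k = torch.randn(batch, head_dim, 1)  # one freshly computed position
new_v = torch.randn(batch, 1, head_dim)

# Drop the oldest entry and append the newest; the cache shape stays constant,
# which matches the static shapes of the exported graph.
k_cache = torch.cat([k_cache[:, :, 1:], new_k], dim=-1)
v_cache = torch.cat([v_cache[:, 1:, :], new_v], dim=1)

assert k_cache.shape == (batch, head_dim, window)
assert v_cache.shape == (batch, window, head_dim)
```

Each iteration of the loop also writes `atten_mask[0][-pos - 1] = 0`, which presumably unmasks one more position of the fixed-length attention mask so the freshly appended cache entry becomes visible to the next step.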

examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp

Lines changed: 24 additions & 14 deletions
@@ -13,6 +13,7 @@
 #include <executorch/examples/qualcomm/oss_scripts/llama3_2/runner/runner.h>
 #include <executorch/extension/evalue_util/print_evalue.h>
 #include <executorch/extension/llm/runner/util.h>
+#include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>
 #include <executorch/runtime/core/exec_aten/exec_aten.h>
 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
 #include <executorch/runtime/platform/log.h>
@@ -70,9 +71,17 @@ Runner::Runner(
   vocab_size_ = vocab_size;
   tokenizer_ = example::get_tiktoken_for_llama();
   Error err = tokenizer_->load(tokenizer_path_);
-  ET_CHECK_MSG(
-      err == Error::Ok, "failed to load tokenizer %s", tokenizer_path_.c_str());
-  eos_id_.insert(tokenizer_->encode("<|eot_id|>", 0, 0).get()[0]);
+  if (err == Error::InvalidArgument) {
+    ET_LOG(
+        Info,
+        "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
+        tokenizer_path_.c_str());
+    tokenizer_.reset();
+    tokenizer_ = std::make_unique<executorch::extension::llm::BPETokenizer>();
+    tokenizer_->load(tokenizer_path_);
+  } else {
+    eos_id_.insert(tokenizer_->encode("<|eot_id|>", 0, 0).get()[0]);
+  }
   bos_id_ = tokenizer_->bos_tok();
   eos_id_.insert(tokenizer_->eos_tok());
 

@@ -182,18 +191,19 @@ Error Runner::generate(
     stats_.model_load_end_ms = time_in_ms();
   }
   std::string post_process_prompt;
-
-  if (!system_prompt.empty()) {
-    post_process_prompt.append(
-        "<|start_header_id|>system<|end_header_id|>\n\n");
-    post_process_prompt.append(system_prompt);
-    post_process_prompt.append("<|eot_id|>\n");
-  }
-  post_process_prompt.append("<|start_header_id|>user<|end_header_id|>\n\n");
   post_process_prompt.append(prompt);
-  post_process_prompt.append(
-      "<|eot_id|><|start_header_id|>assistant<|end_header_id|>");
-  token_callback("<|begin_of_text|>");
+
+  // if (!system_prompt.empty()) {
+  //   post_process_prompt.append(
+  //       "<|start_header_id|>system<|end_header_id|>\n\n");
+  //   post_process_prompt.append(system_prompt);
+  //   post_process_prompt.append("<|eot_id|>\n");
+  // }
+  // post_process_prompt.append("<|start_header_id|>user<|end_header_id|>\n\n");
+  // post_process_prompt.append(prompt);
+  // post_process_prompt.append(
+  //     "<|eot_id|><|start_header_id|>assistant<|end_header_id|>");
+  // token_callback("<|begin_of_text|>");
 
   stats_.inference_start_ms = time_in_ms();
 