@@ -357,7 +357,7 @@ def quantize(self, quant_dtype, custom_annotations=()):
         ).module()
         fx_graph_module = prepare_pt2e(fx_graph_module, quantizer)
         print("Quantizing the model...")
-
+        example_inputs = self.get_example_inputs(self.llama_meta["get_use_kv_cache"])
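+        # example_inputs is kept around so the post-quantization decode check
+        # below can reuse its attention mask and KV caches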
         calibrate(
             self.get_example_inputs(self.llama_meta["get_use_kv_cache"]),
             args.prompt,
@@ -368,6 +368,50 @@ def quantize(self, quant_dtype, custom_annotations=()):
 
         self.llama_model = convert_pt2e(fx_graph_module)
 
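+        # Sanity check: run the quantized graph eagerly and decode a short
+        # continuation to confirm generation still works after convert_pt2e.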
+        sp_model = SentencePieceProcessor(model_file=args.tokenizer_model)
+        _, atten_mask, _, k_caches, v_caches = example_inputs
+
+        # TODO: change criteria & support batch inputs if necessary
+        pos = torch.tensor(0, dtype=torch.int32)
+        token_list = [sp_model.bos_id()]
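+        # encode the prompt one whitespace-separated word at a time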
+        for prompt in args.prompt.split():
+            token_list += sp_model.encode(prompt)
+
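+        # top-p (nucleus) sampling: restrict sampling to the smallest set of
+        # tokens whose cumulative probability exceeds top_p, then renormalize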
+        def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor:
+            probs_sort, probs_indices = torch.sort(probs, dim=-1, descending=True)
+            probs_sum = torch.cumsum(probs_sort, dim=-1)
+            mask = probs_sum - probs_sort > top_p
+            probs_sort[mask] = 0
+            probs_sort /= probs_sort.sum(dim=-1, keepdim=True)
+            next_token = torch.multinomial(probs_sort, num_samples=1)
+            return probs_indices.gather(dim=-1, index=next_token)
+
+        with torch.no_grad():
+            while token_list[-1] != sp_model.eos_id() and pos < args.seq_len - 1:
+                logits, new_k_caches, new_v_caches = self.llama_model(
+                    torch.full((1, 1), token_list[pos]),
+                    atten_mask,
+                    torch.full((1, 1), pos),
+                    *k_caches,
+                    *v_caches,
+                )
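+                # slide each KV cache window: drop the oldest entry and append
+                # the key/value produced for the current position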
+                k_caches = [
+                    torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
+                    for i, k_cache in enumerate(k_caches)
+                ]
+                v_caches = [
+                    torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
+                    for i, v_cache in enumerate(v_caches)
+                ]
+
+                pos += 1
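+                # unmask one more attention-mask slot now that the
+                # corresponding cache entry is valid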
+                atten_mask[0][-pos - 1] = 0
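+                # force-feed prompt tokens; only sample once the prompt is exhausted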
+                if pos >= len(token_list):
+                    probs = torch.softmax(logits[:, -1] / 0.8, dim=-1)
+                    token_list.append(sample_top_p(probs, 0.9).item())
+        print("-----")
+        print(f"convert_pt2e data:\n{sp_model.decode(token_list)}")
+
     def lowering_modules(
         self, work_space, kv_type=torch.uint8, soc_model=QcomChipset.SM8650
     ):
@@ -495,17 +539,18 @@ def inference(args, pre_gen_pte=""):
     runner_args = " ".join(
         [
             f"--model_path {pte_filename}.pte",
-            "--output_folder_path outputs",
+            "--output_path outputs/outputs.txt",
             f"--tokenizer_path {os.path.basename(args.tokenizer_bin)}",
             f'--prompt "{args.prompt}"',
             f"--seq_len {args.seq_len}",
             f"--temperature {args.temperature}",
+            "--eval_mode 1",
         ]
     )
     runner_cmd = " ".join(
         [
             f"cd {workspace} &&",
-            f"./qnn_llama_runner {runner_args}",
+            f"./qnn_llama3_2_runner {runner_args}",
         ]
     )
 
@@ -523,7 +568,7 @@ def inference(args, pre_gen_pte=""):
         host_id=args.host,
         soc_model=args.model,
         shared_buffer=args.shared_buffer,
-        runner="examples/qualcomm/oss_scripts/llama2/qnn_llama_runner",
+        runner="examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner",
     )
     # No pregen inputs, input_list is not required
     adb.push(inputs=[], input_list="", files=[args.tokenizer_bin])
@@ -535,16 +580,8 @@ def inference(args, pre_gen_pte=""):
     outputs = []
 
     def post_process():
-        for f in sorted(
-            os.listdir(output_data_folder), key=lambda f: int(f.split("_")[1])
-        ):
-            with codecs.open(
-                os.path.join(output_data_folder, f),
-                "r",
-                encoding="utf-8",
-                errors="replace",
-            ) as fdata:
-                outputs.append(fdata.read())
+        with open(f"{args.artifact}/outputs/outputs.txt", "r") as f:
+            outputs.append(f.read())
 
     adb.pull(output_path=args.artifact, callback=post_process)
 