File tree Expand file tree Collapse file tree 1 file changed +3
-6
lines changed
examples/qualcomm/oss_scripts/llama Expand file tree Collapse file tree 1 file changed +3
-6
lines changed Original file line number Diff line number Diff line change @@ -83,7 +83,6 @@ def _model_call(self, inps):
8383 inps ,
8484 self ._model ,
8585 self ._tokenizer ,
86- self .ar_len ,
8786 self .max_seq_length ,
8887 use_i64_token = self .use_i64_token ,
8988 collect_logits = True ,
@@ -458,15 +457,13 @@ def prefill_inference(
458457 logits , new_k_caches , new_v_caches = results
459458 elif len (results ) == 1 :
460459 logits = results
461- logits = torch .argmax (logits [:, pos - 1 ], dim = - 1 ).item ()
462- token_list .append (logits )
460+ token = torch .argmax (logits [:, pos - 1 ], dim = - 1 ).item ()
461+ token_list .append (token )
463462 if collect_logits :
464- result_logits . append ( logits )
463+ result_logits = logits [:, : pos ]
465464 pos += 1
466465
467466 logging .info (f"prefill inference result:\n { tokenizer .decode (token_list )} " )
468- if collect_logits :
469- result_logits = torch .cat (result_logits , dim = 1 )
470467 return result_logits
471468
472469
You can’t perform that action at this time.
0 commit comments