@@ -299,6 +299,8 @@ def reset(self):
299299 """Reset the model state."""
300300 self .eval_tokens .clear ()
301301 self .eval_logits .clear ()
302+ self ._input_ids = np .array ([], dtype = np .intc )
303+ self ._scores = np .ndarray ((0 , self ._n_vocab ), dtype = np .single )
302304
303305 def eval (self , tokens : Sequence [int ]):
304306 """Evaluate a list of tokens.
@@ -310,7 +312,7 @@ def eval(self, tokens: Sequence[int]):
         n_ctx = self._n_ctx
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
-            n_past = min(n_ctx - len(batch), len(self.eval_tokens))
+            n_past = min(n_ctx - len(batch), len(self._input_ids))
             n_tokens = len(batch)
             return_code = llama_cpp.llama_eval(
                 ctx=self.ctx,
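Note: n_past is now derived from the numpy cache instead of the deque; the arithmetic is unchanged. A sketch with illustrative numbers (context size, batch size, and token count are made up):

n_ctx = 512        # context window
tokens_seen = 510  # len(self._input_ids) after earlier eval() calls
batch_len = 8      # tokens in the current batch

# n_past counts tokens already in the KV cache, clamped so that
# n_past + batch_len never exceeds the context window.
n_past = min(n_ctx - batch_len, tokens_seen)  # -> 504
assert n_past + batch_len <= n_ctx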
@@ -356,6 +358,7 @@ def _sample(
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
+        assert self._scores.shape[0] > 0
         n_vocab = self._n_vocab
         n_ctx = self._n_ctx
         top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
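Note: the added assert restates the eval_logits precondition against the numpy cache, since _sample reads the last row of _scores. A sketch of what it guards (array shape illustrative):

import numpy as np

scores = np.ndarray((0, 32000), dtype=np.single)  # state right after reset()

# _sample() reads scores[-1, :]; on a zero-row array that is an opaque
# IndexError, so the new assert fails fast with a clearer signal.
try:
    assert scores.shape[0] > 0
except AssertionError:
    print("eval() must run before sampling")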
@@ -368,7 +371,7 @@ def _sample(
 
         if logits_processor is not None:
             logits = np.array(
-                logits_processor(list(self.eval_tokens), logits.tolist()),
+                logits_processor(self._input_ids.tolist(), logits.tolist()),
                 dtype=np.single,
             )
             self._scores[-1, :] = logits
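Note: the processor now receives the token history as a plain list built from the numpy cache; its contract is otherwise unchanged. A hypothetical processor matching the call shape above (the name and banned id are made up):

from typing import List

def ban_token_processor(input_ids: List[int], logits: List[float]) -> List[float]:
    # Receives the token history and one row of logits; returns adjusted logits.
    banned_id = 13  # illustrative token id
    out = list(logits)
    out[banned_id] = float("-inf")  # make the token unsampleable
    return out

print(ban_token_processor([1, 2], [0.0] * 32)[13])  # -inf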
@@ -498,8 +501,8 @@ def sample(
498501 """
499502 assert self .ctx is not None
500503 last_n_tokens_data = [llama_cpp .llama_token (0 )] * max (
501- 0 , self .last_n_tokens_size - len (self .eval_tokens )
502- ) + list ( self .eval_tokens ) [- self .last_n_tokens_size :]
504+ 0 , self .last_n_tokens_size - len (self ._input_ids )
505+ ) + self ._input_ids [- self .last_n_tokens_size :]. tolist ()
503506 return self ._sample (
504507 last_n_tokens_data = (llama_cpp .llama_token * self .last_n_tokens_size )(
505508 * last_n_tokens_data
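Note: the repetition-penalty window is built the same way as before: left-pad with token id 0 when fewer than last_n_tokens_size tokens exist, then take the tail of the history, now sliced directly from the numpy array. A sketch of that arithmetic (sizes illustrative):

import numpy as np

last_n_tokens_size = 64
input_ids = np.arange(10, dtype=np.intc)  # only 10 tokens seen so far

pad = [0] * max(0, last_n_tokens_size - len(input_ids))
window = pad + input_ids[-last_n_tokens_size:].tolist()
assert len(window) == last_n_tokens_size  # always a fixed-size window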
@@ -557,9 +560,9 @@ def generate(
557560 """
558561 assert self .ctx is not None
559562
560- if reset and len (self .eval_tokens ) > 0 :
563+ if reset and len (self ._input_ids ) > 0 :
561564 longest_prefix = 0
562- for a , b in zip (self .eval_tokens , tokens [:- 1 ]):
565+ for a , b in zip (self ._input_ids , tokens [:- 1 ]):
563566 if a == b :
564567 longest_prefix += 1
565568 else :
@@ -569,6 +572,8 @@ def generate(
                     print("Llama.generate: prefix-match hit", file=sys.stderr)
                 reset = False
                 tokens = tokens[longest_prefix:]
+                self._input_ids = self._input_ids[:longest_prefix]
+                self._scores = self._scores[:longest_prefix, :]
                 for _ in range(len(self.eval_tokens) - longest_prefix):
                     self.eval_tokens.pop()
         try:
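Note: on a prefix-match hit, only the unseen suffix is re-evaluated; the two added lines truncate the numpy caches to the shared prefix so they stay in lockstep with the deques. A self-contained sketch of the whole path (token values illustrative):

import numpy as np

cached = np.array([1, 2, 3, 4, 5], dtype=np.intc)  # self._input_ids
scores = np.ndarray((5, 32000), dtype=np.single)   # self._scores
prompt = [1, 2, 3, 9, 9]                           # new request

# Longest shared prefix, excluding the prompt's last token so at
# least one token is always re-evaluated (as in the loop above).
longest_prefix = 0
for a, b in zip(cached, prompt[:-1]):
    if a != b:
        break
    longest_prefix += 1  # -> 3

tokens = prompt[longest_prefix:]     # only [9, 9] still needs eval()
cached = cached[:longest_prefix]     # truncate the token cache
scores = scores[:longest_prefix, :]  # keep logit rows in lockstep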
@@ -595,7 +600,7 @@ def generate(
                 logits_processor=logits_processor,
             )
             if stopping_criteria is not None and stopping_criteria(
-                list(self.eval_tokens), self.eval_logits[-1]
+                self._input_ids.tolist(), self._scores[-1, :].tolist()
             ):
                 return
             tokens_or_none = yield token
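Note: stopping_criteria now gets the history and the final logits row as plain lists. A hypothetical criteria matching that call shape (the factory name and cutoff are made up):

from typing import List

def make_length_criteria(max_tokens: int):
    # Returns True when generation should stop.
    def criteria(input_ids: List[int], logits: List[float]) -> bool:
        return len(input_ids) >= max_tokens
    return criteria

stop = make_length_criteria(256)
print(stop(list(range(300)), [0.0] * 32))  # True -> generate() returns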
@@ -820,7 +825,7 @@ def _create_completion(
                     self.detokenize(completion_tokens[:returned_tokens])
                 )
                 token_offset = len(prompt_tokens) + returned_tokens
-                logits = self.eval_logits[token_offset - 1]
+                logits = self._scores[token_offset - 1, :].tolist()
                 current_logprobs = Llama.logits_to_logprobs(logits)
                 sorted_logprobs = list(
                     sorted(
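Note: _scores is indexed by absolute token position, so the streaming logprobs path addresses the row one behind token_offset. A sketch of the offset arithmetic (counts illustrative):

prompt_len = 12       # len(prompt_tokens)
returned_tokens = 3

# Row i holds the logits produced after evaluating token i, so the row
# for the most recently returned token sits one behind the offset.
token_offset = prompt_len + returned_tokens  # 15
row_index = token_offset - 1                 # -> self._scores[14, :]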
@@ -869,7 +874,7 @@ def _create_completion(
                     break
 
             if stopping_criteria is not None and stopping_criteria(
-                list(self.eval_tokens), self.eval_logits[-1]
+                self._input_ids.tolist(), self._scores[-1, :].tolist()
             ):
                 text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
@@ -899,7 +904,7 @@ def _create_completion(
                     self.detokenize(completion_tokens[:returned_tokens])
                 )
                 token_offset = len(prompt_tokens) + returned_tokens - 1
-                logits = self.eval_logits[token_offset]
+                logits = self._scores[token_offset, :].tolist()
                 current_logprobs = Llama.logits_to_logprobs(logits)
                 sorted_logprobs = list(
                     sorted(
@@ -1001,8 +1006,7 @@ def _create_completion(
             for token in all_tokens
         ]
         all_logprobs = [
-            Llama.logits_to_logprobs(list(map(float, row)))
-            for row in self.eval_logits
+            Llama.logits_to_logprobs(row.tolist()) for row in self._scores
         ][token_offset:]
         for token, token_str, logprobs_token in zip(
             all_tokens, all_token_strs, all_logprobs
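Note: the final hunk feeds each cached row straight into Llama.logits_to_logprobs via row.tolist(), dropping the list(map(float, ...)) round-trip. That helper is not part of this diff; a log-softmax sketch of the transform it performs, stabilized by subtracting the max, is:

import numpy as np

def logits_to_logprobs(logits):
    x = np.asarray(logits, dtype=np.single)
    x = x - x.max()  # numerical stability
    return (x - np.log(np.exp(x).sum())).tolist()

row = np.array([2.0, 1.0, 0.1], dtype=np.single)  # one cached _scores row
print(logits_to_logprobs(row.tolist()))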