2020from . import llama_cpp
2121from .llama_types import *
2222
23+ import numpy as np
24+ import numpy .typing as npt
25+
2326
2427class LlamaCache :
2528 """Cache for a llama.cpp model."""
@@ -73,11 +76,15 @@ def __init__(
7376 self ,
7477 eval_tokens : Deque [int ],
7578 eval_logits : Deque [List [float ]],
79+ input_ids : npt .NDArray [np .intc ],
80+ scores : npt .NDArray [np .single ],
7681 llama_state , # type: llama_cpp.Array[llama_cpp.c_uint8]
7782 llama_state_size : int ,
7883 ):
7984 self .eval_tokens = eval_tokens
8085 self .eval_logits = eval_logits
86+ self .input_ids = input_ids
87+ self .scores = scores
8188 self .llama_state = llama_state
8289 self .llama_state_size = llama_state_size
8390
@@ -207,27 +214,24 @@ def __init__(
207214
208215 self ._n_vocab = self .n_vocab ()
209216 self ._n_ctx = self .n_ctx ()
210- data = (llama_cpp .llama_token_data * self ._n_vocab )(
211- * [
212- llama_cpp .llama_token_data (
213- id = llama_cpp .llama_token (i ),
214- logit = llama_cpp .c_float (0.0 ),
215- p = llama_cpp .c_float (0.0 ),
216- )
217- for i in range (self ._n_vocab )
218- ]
219- )
220217 size = llama_cpp .c_size_t (self ._n_vocab )
221- sorted = False
218+ sorted = llama_cpp .c_bool (False )
219+ self ._candidates_data = np .array (
220+ [], dtype = [("id" , np .intc ), ("logit" , np .single ), ("p" , np .single )]
221+ )
222+ self ._candidates_data .resize (3 , self ._n_vocab )
222223 candidates = llama_cpp .llama_token_data_array (
223- data = data ,
224+ data = self . _candidates_data . ctypes . data_as ( llama_cpp . llama_token_data_p ) ,
224225 size = size ,
225226 sorted = sorted ,
226227 )
227228 self ._candidates = candidates
228229 self ._token_nl = Llama .token_nl ()
229230 self ._token_eos = Llama .token_eos ()
230231
232+ self ._input_ids = np .array ([], dtype = np .intc )
233+ self ._scores = np .ndarray ((0 , self ._n_vocab ), dtype = np .single )
234+
231235 def tokenize (self , text : bytes , add_bos : bool = True ) -> List [int ]:
232236 """Tokenize a string.
233237
@@ -319,13 +323,19 @@ def eval(self, tokens: Sequence[int]):
319323 raise RuntimeError (f"llama_eval returned { return_code } " )
320324 # Save tokens
321325 self .eval_tokens .extend (batch )
326+ self ._input_ids : npt .NDArray [np .intc ] = np .concatenate (
327+ (self ._input_ids , np .array (batch , dtype = np .intc )), axis = 0
328+ )
322329 # Save logits
323330 rows = n_tokens if self .params .logits_all else 1
324331 n_vocab = self ._n_vocab
325332 cols = n_vocab
326333 logits_view = llama_cpp .llama_get_logits (self .ctx )
327334 logits = [logits_view [i * cols : (i + 1 ) * cols ] for i in range (rows )]
328335 self .eval_logits .extend (logits )
336+ self ._scores : npt .NDArray [np .single ] = np .concatenate (
337+ (self ._scores , np .array (logits , dtype = np .single )), axis = 0
338+ )
329339
330340 def _sample (
331341 self ,
@@ -354,18 +364,23 @@ def _sample(
354364 if last_n_tokens_size .value < 0
355365 else last_n_tokens_size
356366 )
357- logits = self .eval_logits [- 1 ]
367+ logits : npt . NDArray [ np . single ] = self ._scores [- 1 , : ]
358368
359369 if logits_processor is not None :
360- logits = logits_processor (list (self .eval_tokens ), logits )
361- self .eval_logits [- 1 ] = logits
370+ logits = np .array (
371+ logits_processor (list (self .eval_tokens ), logits .tolist ()),
372+ dtype = np .single ,
373+ )
374+ self ._scores [- 1 , :] = logits
375+ self .eval_logits [- 1 ] = logits .tolist ()
362376
363377 nl_logit = logits [self ._token_nl ]
364378 candidates = self ._candidates
365- for i , logit in enumerate (logits ):
366- candidates .data [i ].id = llama_cpp .llama_token (i )
367- candidates .data [i ].logit = llama_cpp .c_float (logit )
368- candidates .data [i ].p = llama_cpp .c_float (0.0 )
379+ candidates_data = self ._candidates_data
380+ candidates_data ["id" ] = np .arange (n_vocab , dtype = np .intc ) # type: ignore
381+ candidates_data ["logit" ] = logits
382+ candidates_data ["p" ] = np .zeros (n_vocab , dtype = np .single )
383+ candidates .data = candidates_data .ctypes .data_as (llama_cpp .llama_token_data_p )
369384 candidates .sorted = llama_cpp .c_bool (False )
370385 candidates .size = llama_cpp .c_size_t (n_vocab )
371386 llama_cpp .llama_sample_repetition_penalty (
@@ -1371,6 +1386,8 @@ def save_state(self) -> LlamaState:
13711386 return LlamaState (
13721387 eval_tokens = self .eval_tokens .copy (),
13731388 eval_logits = self .eval_logits .copy (),
1389+ scores = self ._scores .copy (),
1390+ input_ids = self ._input_ids .copy (),
13741391 llama_state = llama_state_compact ,
13751392 llama_state_size = n_bytes ,
13761393 )
@@ -1379,6 +1396,8 @@ def load_state(self, state: LlamaState) -> None:
13791396 assert self .ctx is not None
13801397 self .eval_tokens = state .eval_tokens .copy ()
13811398 self .eval_logits = state .eval_logits .copy ()
1399+ self ._scores = state .scores .copy ()
1400+ self ._input_ids = state .input_ids .copy ()
13821401 state_size = state .llama_state_size
13831402 if llama_cpp .llama_set_state_data (self .ctx , state .llama_state ) != state_size :
13841403 raise RuntimeError ("Failed to set llama state data" )
0 commit comments