@@ -176,6 +176,28 @@ def __init__(
 
         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
+
+
+        n_vocab = self.n_vocab()
+        n_ctx = self.n_ctx()
+        data = (llama_cpp.llama_token_data * n_vocab)(
+            *[
+                llama_cpp.llama_token_data(
+                    id=llama_cpp.llama_token(i),
+                    logit=llama_cpp.c_float(0.0),
+                    p=llama_cpp.c_float(0.0),
+                )
+                for i in range(n_vocab)
+            ]
+        )
+        size = llama_cpp.c_size_t(n_vocab)
+        sorted = False
+        candidates = llama_cpp.llama_token_data_array(
+            data=data,
+            size=size,
+            sorted=sorted,
+        )
+        self._candidates = candidates
 
     def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
         """Tokenize a string.
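This hunk moves construction of the ctypes candidates array out of the per-token sampling path: the `llama_token_data` buffer is built once in `__init__` and cached on `self._candidates`. The pattern itself is plain ctypes; below is a minimal, self-contained sketch using stand-in `TokenData`/`TokenDataArray` structs (hypothetical names and field layout, standing in for the real `llama_cpp` bindings):

```python
import ctypes

# Stand-ins for llama_cpp.llama_token_data / llama_token_data_array
# (assumed layout, for illustration only).
class TokenData(ctypes.Structure):
    _fields_ = [
        ("id", ctypes.c_int),
        ("logit", ctypes.c_float),
        ("p", ctypes.c_float),
    ]

class TokenDataArray(ctypes.Structure):
    _fields_ = [
        ("data", ctypes.POINTER(TokenData)),
        ("size", ctypes.c_size_t),
        ("sorted", ctypes.c_bool),
    ]

n_vocab = 32_000  # e.g. the LLaMA vocabulary size
# Allocate the backing buffer once; this is the expensive step the
# commit hoists out of _sample and into __init__.
data = (TokenData * n_vocab)()
candidates = TokenDataArray(data=data, size=n_vocab, sorted=False)
assert candidates.size == n_vocab and not candidates.sorted
```

Each later sampling call then only has to overwrite fields in this buffer instead of re-running an n_vocab-element list comprehension of struct constructors.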
@@ -296,33 +318,23 @@ def _sample(
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
-        n_vocab = int(llama_cpp.llama_n_vocab(self.ctx))
-        n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
+        n_vocab = self.n_vocab()
+        n_ctx = self.n_ctx()
         top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
         last_n_tokens_size = (
             llama_cpp.c_int(n_ctx)
             if last_n_tokens_size.value < 0
             else last_n_tokens_size
         )
         logits = self.eval_logits[-1]
-        nl_logit = logits[int(Llama.token_nl())]
-        data = (llama_cpp.llama_token_data * n_vocab)(
-            *[
-                llama_cpp.llama_token_data(
-                    id=llama_cpp.llama_token(i),
-                    logit=logits[i],
-                    p=llama_cpp.c_float(0.0),
-                )
-                for i in range(n_vocab)
-            ]
-        )
-        size = llama_cpp.c_size_t(n_vocab)
-        sorted = False
-        candidates = llama_cpp.llama_token_data_array(
-            data=data,
-            size=size,
-            sorted=sorted,
-        )
+        nl_logit = logits[Llama.token_nl()]
+        candidates = self._candidates
+        for i, logit in enumerate(logits):
+            candidates.data[i].id = llama_cpp.llama_token(i)
+            candidates.data[i].logit = llama_cpp.c_float(logit)
+            candidates.data[i].p = llama_cpp.c_float(0.0)
+        candidates.sorted = llama_cpp.c_bool(False)
+        candidates.size = llama_cpp.c_size_t(n_vocab)
         llama_cpp.llama_sample_repetition_penalty(
             ctx=self.ctx,
             last_tokens_data=last_n_tokens_data,
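The rewrite works because ctypes struct arrays are mutable C buffers: assigning to `candidates.data[i].logit` writes directly into the memory the `llama_sample_*` functions will read, so the array cached in `__init__` can be refilled in place on every call instead of rebuilt. A self-contained sketch of that refill, again with a hypothetical `TokenData` stand-in:

```python
import ctypes

class TokenData(ctypes.Structure):
    _fields_ = [("id", ctypes.c_int), ("logit", ctypes.c_float), ("p", ctypes.c_float)]

n_vocab = 4
data = (TokenData * n_vocab)()      # allocated once, as in __init__ above
logits = [0.1, 0.5, -0.2, 1.3]      # fake logits for one decode step

# Per sampling step: overwrite the fields in place; no new Python
# objects or C buffers are created.
for i, logit in enumerate(logits):
    data[i].id = i
    data[i].logit = logit
    data[i].p = 0.0

assert abs(data[1].logit - 0.5) < 1e-6  # the shared buffer now holds fresh logits
```

The loop is still O(n_vocab) per sampled token, but it avoids constructing n_vocab `llama_token_data` Python objects plus a fresh array on every call, which is presumably the allocation cost this commit targets.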
@@ -339,7 +351,7 @@ def _sample(
             alpha_presence=presence_penalty,
         )
         if not penalize_nl:
-            candidates.data[int(Llama.token_nl())].logit = nl_logit
+            candidates.data[Llama.token_nl()].logit = nl_logit
         if temp.value == 0.0:
             return llama_cpp.llama_sample_token_greedy(
                 ctx=self.ctx,
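The `penalize_nl` handling is a save/restore around the penalty calls: `nl_logit` is captured before `llama_sample_repetition_penalty` and the frequency/presence penalties run, then written back so the newline token is never down-weighted. A plain-Python sketch of the idea (the penalty formula and the newline token id 13 are simplified assumptions here, not the C implementation):

```python
NL_TOKEN = 13  # llama's newline token id (assumed, for illustration)

def apply_repetition_penalty(logits, recent_tokens, penalty=1.1, penalize_nl=False):
    nl_logit = logits[NL_TOKEN]                  # save before penalizing
    for tok in set(recent_tokens):
        # simplified llama.cpp-style penalty: shrink positive logits,
        # push negative ones further down
        logits[tok] = logits[tok] / penalty if logits[tok] > 0 else logits[tok] * penalty
    if not penalize_nl:
        logits[NL_TOKEN] = nl_logit              # restore: "\n" escapes the penalty
    return logits

logits = [0.0] * 32
logits[NL_TOKEN] = 2.0
logits[7] = 2.0
out = apply_repetition_penalty(logits, recent_tokens=[7, NL_TOKEN])
assert out[7] < 2.0 and out[NL_TOKEN] == 2.0     # only the non-newline token shrank
```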