@@ -127,7 +127,6 @@ def __init__(
 
         self.params = llama_cpp.llama_context_default_params()
         self.params.n_ctx = n_ctx
-        self.params.n_parts = n_parts
         self.params.n_gpu_layers = n_gpu_layers
         self.params.seed = seed
         self.params.f16_kv = f16_kv
@@ -149,6 +148,10 @@ def __init__(
         self.lora_base = lora_base
         self.lora_path = lora_path
 
+        ### DEPRECATED ###
+        self.n_parts = n_parts
+        ### DEPRECATED ###
+
         if not os.path.exists(model_path):
             raise ValueError(f"Model path does not exist: {model_path}")
 
@@ -173,6 +176,30 @@ def __init__(
 
         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
+
+
+        n_vocab = self.n_vocab()
+        n_ctx = self.n_ctx()
+        data = (llama_cpp.llama_token_data * n_vocab)(
+            *[
+                llama_cpp.llama_token_data(
+                    id=llama_cpp.llama_token(i),
+                    logit=llama_cpp.c_float(0.0),
+                    p=llama_cpp.c_float(0.0),
+                )
+                for i in range(n_vocab)
+            ]
+        )
+        size = llama_cpp.c_size_t(n_vocab)
+        sorted = False
+        candidates = llama_cpp.llama_token_data_array(
+            data=data,
+            size=size,
+            sorted=sorted,
+        )
+        self._candidates = candidates
+        self._token_nl = Llama.token_nl()
+        self._token_eos = Llama.token_eos()
 
     def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
         """Tokenize a string.
@@ -293,33 +320,23 @@ def _sample(
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
-        n_vocab = int(llama_cpp.llama_n_vocab(self.ctx))
-        n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
+        n_vocab = self.n_vocab()
+        n_ctx = self.n_ctx()
         top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
         last_n_tokens_size = (
             llama_cpp.c_int(n_ctx)
             if last_n_tokens_size.value < 0
             else last_n_tokens_size
         )
         logits = self.eval_logits[-1]
-        nl_logit = logits[int(Llama.token_nl())]
-        data = (llama_cpp.llama_token_data * n_vocab)(
-            *[
-                llama_cpp.llama_token_data(
-                    id=llama_cpp.llama_token(i),
-                    logit=logits[i],
-                    p=llama_cpp.c_float(0.0),
-                )
-                for i in range(n_vocab)
-            ]
-        )
-        size = llama_cpp.c_size_t(n_vocab)
-        sorted = False
-        candidates = llama_cpp.llama_token_data_array(
-            data=data,
-            size=size,
-            sorted=sorted,
-        )
+        nl_logit = logits[self._token_nl]
+        candidates = self._candidates
+        for i, logit in enumerate(logits):
+            candidates.data[i].id = llama_cpp.llama_token(i)
+            candidates.data[i].logit = llama_cpp.c_float(logit)
+            candidates.data[i].p = llama_cpp.c_float(0.0)
+        candidates.sorted = llama_cpp.c_bool(False)
+        candidates.size = llama_cpp.c_size_t(n_vocab)
         llama_cpp.llama_sample_repetition_penalty(
             ctx=self.ctx,
             last_tokens_data=last_n_tokens_data,
@@ -336,7 +353,7 @@ def _sample(
             alpha_presence=presence_penalty,
         )
         if not penalize_nl:
-            candidates.data[int(Llama.token_nl())].logit = nl_logit
+            candidates.data[self._token_nl].logit = llama_cpp.c_float(nl_logit)
         if temp.value == 0.0:
             return llama_cpp.llama_sample_token_greedy(
                 ctx=self.ctx,
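With the buffer cached, the per-call work in `_sample` drops from constructing `n_vocab` ctypes structs to a field-wise refill. Continuing the stand-in structs from the earlier sketch, the new hot path looks roughly like:

def refill(candidates, buf, logits):
    # Rewrite the cached buffer in place; no ctypes objects are allocated here.
    for i, logit in enumerate(logits):
        buf[i].id = i
        buf[i].logit = logit
        buf[i].p = 0.0
    candidates.sorted = False
    candidates.size = len(logits)

The old code paid for a list comprehension, `n_vocab` temporary `llama_token_data` instances, and a fresh array on every sampling call; the refill touches the same memory each time.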
@@ -685,7 +702,7 @@ def _create_completion(
             presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
         ):
-            if token == Llama.token_eos():
+            if token == self._token_eos:
                 text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
                 break
@@ -1237,7 +1254,6 @@ def __getstate__(self):
             verbose=self.verbose,
             model_path=self.model_path,
             n_ctx=self.params.n_ctx,
-            n_parts=self.params.n_parts,
             n_gpu_layers=self.params.n_gpu_layers,
             seed=self.params.seed,
             f16_kv=self.params.f16_kv,
@@ -1251,6 +1267,9 @@ def __getstate__(self):
             n_threads=self.n_threads,
             lora_base=self.lora_base,
             lora_path=self.lora_path,
+            ### DEPRECATED ###
+            n_parts=self.n_parts,
+            ### DEPRECATED ###
         )
 
     def __setstate__(self, state):
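Because `__getstate__` still emits `n_parts` (now read from the deprecated attribute instead of the context params), previously pickled instances keep round-tripping. A hypothetical round-trip, assuming a local model file at an illustrative path:

from llama_cpp import Llama
import pickle

llm = Llama(model_path="./models/7B/ggml-model.bin")  # path is illustrative
blob = pickle.dumps(llm)   # __getstate__ captures the constructor kwargs, n_parts included
llm2 = pickle.loads(blob)  # __setstate__ re-runs __init__ with those kwargs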
@@ -1303,6 +1322,21 @@ def load_state(self, state: LlamaState) -> None:
         if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
             raise RuntimeError("Failed to set llama state data")
 
+    def n_ctx(self) -> int:
+        """Return the context window size."""
+        assert self.ctx is not None
+        return llama_cpp.llama_n_ctx(self.ctx)
+
+    def n_embd(self) -> int:
+        """Return the embedding size."""
+        assert self.ctx is not None
+        return llama_cpp.llama_n_embd(self.ctx)
+
+    def n_vocab(self) -> int:
+        """Return the vocabulary size."""
+        assert self.ctx is not None
+        return llama_cpp.llama_n_vocab(self.ctx)
+
     @staticmethod
     def token_eos() -> int:
         """Return the end-of-sequence token."""