Commit 922b5b2

Merge branch 'main' into server-embedding
2 parents e783f1c + 2c45255

File tree

5 files changed: +239 -68 lines changed


llama_cpp/llama.py

Lines changed: 58 additions & 24 deletions

```diff
@@ -127,7 +127,6 @@ def __init__(
 
         self.params = llama_cpp.llama_context_default_params()
         self.params.n_ctx = n_ctx
-        self.params.n_parts = n_parts
         self.params.n_gpu_layers = n_gpu_layers
         self.params.seed = seed
         self.params.f16_kv = f16_kv
```

```diff
@@ -149,6 +148,10 @@ def __init__(
         self.lora_base = lora_base
         self.lora_path = lora_path
 
+        ### DEPRECATED ###
+        self.n_parts = n_parts
+        ### DEPRECATED ###
+
         if not os.path.exists(model_path):
             raise ValueError(f"Model path does not exist: {model_path}")
 
```

```diff
@@ -173,6 +176,30 @@ def __init__(
 
         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
+
+
+        n_vocab = self.n_vocab()
+        n_ctx = self.n_ctx()
+        data = (llama_cpp.llama_token_data * n_vocab)(
+            *[
+                llama_cpp.llama_token_data(
+                    id=llama_cpp.llama_token(i),
+                    logit=llama_cpp.c_float(0.0),
+                    p=llama_cpp.c_float(0.0),
+                )
+                for i in range(n_vocab)
+            ]
+        )
+        size = llama_cpp.c_size_t(n_vocab)
+        sorted = False
+        candidates = llama_cpp.llama_token_data_array(
+            data=data,
+            size=size,
+            sorted=sorted,
+        )
+        self._candidates = candidates
+        self._token_nl = Llama.token_nl()
+        self._token_eos = Llama.token_eos()
 
     def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
         """Tokenize a string.
```

```diff
@@ -293,33 +320,23 @@ def _sample(
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
-        n_vocab = int(llama_cpp.llama_n_vocab(self.ctx))
-        n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
+        n_vocab = self.n_vocab()
+        n_ctx = self.n_ctx()
         top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
         last_n_tokens_size = (
             llama_cpp.c_int(n_ctx)
             if last_n_tokens_size.value < 0
             else last_n_tokens_size
         )
         logits = self.eval_logits[-1]
-        nl_logit = logits[int(Llama.token_nl())]
-        data = (llama_cpp.llama_token_data * n_vocab)(
-            *[
-                llama_cpp.llama_token_data(
-                    id=llama_cpp.llama_token(i),
-                    logit=logits[i],
-                    p=llama_cpp.c_float(0.0),
-                )
-                for i in range(n_vocab)
-            ]
-        )
-        size = llama_cpp.c_size_t(n_vocab)
-        sorted = False
-        candidates = llama_cpp.llama_token_data_array(
-            data=data,
-            size=size,
-            sorted=sorted,
-        )
+        nl_logit = logits[self._token_nl]
+        candidates = self._candidates
+        for i, logit in enumerate(logits):
+            candidates.data[i].id = llama_cpp.llama_token(i)
+            candidates.data[i].logit = llama_cpp.c_float(logit)
+            candidates.data[i].p = llama_cpp.c_float(0.0)
+        candidates.sorted = llama_cpp.c_bool(False)
+        candidates.size = llama_cpp.c_size_t(n_vocab)
         llama_cpp.llama_sample_repetition_penalty(
             ctx=self.ctx,
             last_tokens_data=last_n_tokens_data,
```

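Reusing the cached array is only correct because the loop rewrites every slot and then resets `sorted` and `size`: llama.cpp's sampling helpers sort `data` in place and record that in the `sorted` flag, so a stale `True` would make the next call skip its own sort. A toy plain-Python model of that failure mode (not the `llama_cpp` API):

```python
from dataclasses import dataclass, field

@dataclass
class Candidates:
    logits: list = field(default_factory=list)
    sorted: bool = False  # mirrors llama_token_data_array.sorted

def sample_top_k(cand: Candidates, k: int) -> list:
    # Like the C helpers: sort in place once, remember that we did.
    if not cand.sorted:
        cand.logits.sort(reverse=True)
        cand.sorted = True
    return cand.logits[:k]

cand = Candidates(logits=[0.1, 2.0, 1.5])
print(sample_top_k(cand, 2))  # [2.0, 1.5]

# Refill the same buffer for the next step but forget to reset the flag:
cand.logits[:] = [3.0, 0.2, 5.0]
print(sample_top_k(cand, 2))  # [3.0, 0.2] -- stale flag, wrong result

# What the diff does on every call: mark the reused buffer unsorted.
cand.logits[:] = [3.0, 0.2, 5.0]
cand.sorted = False
print(sample_top_k(cand, 2))  # [5.0, 3.0] -- correct
```
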
```diff
@@ -336,7 +353,7 @@ def _sample(
             alpha_presence=presence_penalty,
         )
         if not penalize_nl:
-            candidates.data[int(Llama.token_nl())].logit = nl_logit
+            candidates.data[self._token_nl].logit = llama_cpp.c_float(nl_logit)
         if temp.value == 0.0:
             return llama_cpp.llama_sample_token_greedy(
                 ctx=self.ctx,
```

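The `penalize_nl` handling saves the newline logit before the repetition penalties run and writes it back afterwards, so newlines stay samplable even when recently emitted; the new version also wraps the restored value in `c_float` to make the ctypes conversion explicit. A toy sketch of the save/restore pattern, independent of `llama_cpp` (the token id and penalty rule are assumptions):

```python
NL = 13  # assumed newline token id for LLaMA

def apply_penalties(logits, recent, penalty=1.1, penalize_nl=False):
    nl_logit = logits[NL]          # stash before penalizing
    for tok in set(recent):
        logits[tok] /= penalty     # crude repetition penalty (positive logits)
    if not penalize_nl:
        logits[NL] = nl_logit      # restore: newlines keep their odds
    return logits

logits = [0.0] * 32
logits[NL], logits[7] = 2.0, 1.0
print(apply_penalties(logits, recent=[7, NL])[NL])  # 2.0 -- untouched
```
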
```diff
@@ -685,7 +702,7 @@ def _create_completion(
             presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
         ):
-            if token == Llama.token_eos():
+            if token == self._token_eos:
                 text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
                 break
```

```diff
@@ -1237,7 +1254,6 @@ def __getstate__(self):
             verbose=self.verbose,
             model_path=self.model_path,
             n_ctx=self.params.n_ctx,
-            n_parts=self.params.n_parts,
             n_gpu_layers=self.params.n_gpu_layers,
             seed=self.params.seed,
             f16_kv=self.params.f16_kv,
```

```diff
@@ -1251,6 +1267,9 @@ def __getstate__(self):
             n_threads=self.n_threads,
             lora_base=self.lora_base,
             lora_path=self.lora_path,
+            ### DEPRECATED ###
+            n_parts=self.n_parts,
+            ### DEPRECATED ###
         )
 
     def __setstate__(self, state):
```

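Note that `n_parts` moves out of `self.params` but stays in the pickled state, so instances serialized before the deprecation keep round-tripping. A minimal sketch of the pattern with a hypothetical `Model` class (not the real `Llama`):

```python
import pickle

class Model:
    def __init__(self, n_ctx: int = 512, n_parts: int = -1):
        self.n_ctx = n_ctx
        ### DEPRECATED ###
        # No longer forwarded to the C API; kept only for compatibility.
        self.n_parts = n_parts

    def __getstate__(self):
        # Round-trip every constructor argument, deprecated ones included.
        return dict(n_ctx=self.n_ctx, n_parts=self.n_parts)

    def __setstate__(self, state):
        self.__init__(**state)

m = pickle.loads(pickle.dumps(Model(n_ctx=2048)))
assert m.n_ctx == 2048 and m.n_parts == -1
```
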
```diff
@@ -1303,6 +1322,21 @@ def load_state(self, state: LlamaState) -> None:
         if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
             raise RuntimeError("Failed to set llama state data")
 
+    def n_ctx(self) -> int:
+        """Return the context window size."""
+        assert self.ctx is not None
+        return llama_cpp.llama_n_ctx(self.ctx)
+
+    def n_embd(self) -> int:
+        """Return the embedding size."""
+        assert self.ctx is not None
+        return llama_cpp.llama_n_embd(self.ctx)
+
+    def n_vocab(self) -> int:
+        """Return the vocabulary size."""
+        assert self.ctx is not None
+        return llama_cpp.llama_n_vocab(self.ctx)
+
     @staticmethod
     def token_eos() -> int:
         """Return the end-of-sequence token."""
```

