Skip to content

Commit 98bbd1c

Browse files
committed
Fix eval logits type
1 parent b5f3e74 commit 98bbd1c

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

llama_cpp/llama.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def __init__(
127127
self.last_n_tokens_size = last_n_tokens_size
128128
self.n_batch = min(n_ctx, n_batch)
129129
self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
130-
self.eval_logits: Deque[List[llama_cpp.c_float]] = deque(
130+
self.eval_logits: Deque[List[float]] = deque(
131131
maxlen=n_ctx if logits_all else 1
132132
)
133133

@@ -245,7 +245,7 @@ def eval(self, tokens: Sequence[llama_cpp.llama_token]):
245245
n_vocab = llama_cpp.llama_n_vocab(self.ctx)
246246
cols = int(n_vocab)
247247
logits_view = llama_cpp.llama_get_logits(self.ctx)
248-
logits: List[List[llama_cpp.c_float]] = [
248+
logits: List[List[float]] = [
249249
[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)
250250
]
251251
self.eval_logits.extend(logits)
@@ -287,7 +287,7 @@ def _sample_top_p_top_k(
287287
candidates=llama_cpp.ctypes.pointer(candidates),
288288
penalty=repeat_penalty,
289289
)
290-
if float(temp) == 0.0:
290+
if float(temp.value) == 0.0:
291291
return llama_cpp.llama_sample_token_greedy(
292292
ctx=self.ctx,
293293
candidates=llama_cpp.ctypes.pointer(candidates),

0 commit comments

Comments
 (0)