Skip to content

Commit 3abc049

Browse files
Merge branch 'llama-cpp-logits-fix'
2 parents ed14ec3 + e3b7766 commit 3abc049

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/lmql/models/lmtp/backends/llama_cpp_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@ def score(self, input_ids, attention_mask, **model_kwargs):
             self.llm.n_tokens = longest_prefix

         self.llm.eval(tokens)
-        scores = np.array([self.llm.scores[j][i] for j,i in enumerate(input_ids[0])])
-        scores = nputil.log_softmax(scores, axis=-1)
-        # print("llama_cpp_model: score() took", time.time() - s, "seconds", file=sys.stderr)
+        logits = np.array(self.llm.scores)
+        logits = nputil.log_softmax(logits, axis=-1)
+        scores = np.array([logits[j][i] for j,i in enumerate(input_ids[0])])

         return scores.reshape(1, -1)

0 commit comments

Comments (0)