Skip to content

Commit 51e77fa

Browse files
committed
fix missing logits_all params
1 parent: dbfc1cf · commit: 51e77fa

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

llama_cpp/llama.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1065,6 +1065,7 @@ def embed(
10651065

10661066
# get pooling information
10671067
pooling_type = self.pooling_type()
1068+
logits_all = pooling_type == llama_cpp.LLAMA_POOLING_TYPE_NONE
10681069

10691070
if self.context_params.embeddings is False:
10701071
raise RuntimeError(
@@ -1142,7 +1143,7 @@ def decode_batch(seq_sizes: List[int]):
11421143
p_batch = 0
11431144

11441145
# add to batch
1145-
self._batch.add_sequence(tokens, p_batch)
1146+
self._batch.add_sequence(tokens, p_batch, logits_all)
11461147

11471148
# update batch stats
11481149
s_batch.append(n_tokens)

tests/test_llama.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def test_real_model(llama_cpp_model_path):
8181
cparams.n_ubatch = 16
8282
cparams.n_threads = multiprocessing.cpu_count()
8383
cparams.n_threads_batch = multiprocessing.cpu_count()
84+
cparams.logits_all = False
8485
cparams.flash_attn = True
8586
cparams.swa_full = True
8687

@@ -103,7 +104,7 @@ def test_real_model(llama_cpp_model_path):
103104
result = tokens
104105
n_eval = 0
105106
for _ in range(4):
106-
batch.set_batch(tokens, n_past=n_eval)
107+
batch.set_batch(tokens, n_past=n_eval, logits_all=False)
107108
context.decode(batch)
108109
n_eval += len(tokens)
109110
token_id = sampler.sample(context, -1)
@@ -122,8 +123,8 @@ def test_real_llama(llama_cpp_model_path):
122123
n_ubatch=32,
123124
n_threads=multiprocessing.cpu_count(),
124125
n_threads_batch=multiprocessing.cpu_count(),
125-
flash_attn=True,
126126
logits_all=False,
127+
flash_attn=True,
127128
swa_full = True
128129
)
129130

0 commit comments

Comments (0)