Skip to content

Commit 2a6952b

Browse files
committed
cont : fix rerank
ggml-ci
1 parent e8ddfa3 commit 2a6952b

File tree

2 files changed

+3
-5
lines changed

2 files changed

+3
-5
lines changed

src/llama-memory.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ struct llama_memory_i {
7070
// split the input batch into a set of ubatches and verify that they can fit into the cache
7171
// return a state object containing the ubatches and KV cache state required to process them
7272
// check the llama_memory_state_i::get_status() for the result
73-
// TODO: remove embd_all argument
7473
virtual llama_memory_state_ptr init_batch(
7574
const llama_batch & batch,
7675
uint32_t n_ubatch,

tools/server/server.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1901,8 +1901,8 @@ struct server_context {
19011901
llama_batch_free(batch);
19021902
}
19031903

1904-
// if the context does not have a memory module then all inputs have to be processed within a single ubatch
1905-
// also we cannot split if the input requires any past tokens
1904+
// if the context does not have a memory module then all embeddings have to be computed within a single ubatch
1905+
// also we cannot split if the pooling requires any past tokens
19061906
bool can_split() const {
19071907
return
19081908
!llama_get_embeddings(ctx) ||
@@ -3238,7 +3238,6 @@ struct server_context {
32383238
slot.n_prompt_tokens_processed = 0;
32393239
}
32403240

3241-
// non-causal tasks require to fit the entire prompt in the physical batch
32423241
if (!can_split()) {
32433242
// cannot fit the prompt in the current batch - will try next iter
32443243
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
@@ -3293,7 +3292,7 @@ struct server_context {
32933292
}
32943293

32953294
// embedding requires all tokens in the batch to be output
3296-
const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING;
3295+
const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING || slot.task_type == SERVER_TASK_TYPE_RERANK;
32973296

32983297
common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd);
32993298
slot.cache_tokens.push_back(cur_tok);

0 commit comments

Comments
 (0)