1 change: 1 addition & 0 deletions include/llama.h
@@ -965,6 +965,7 @@ extern "C" {
     LLAMA_API int32_t llama_n_threads_batch(struct llama_context * ctx);
 
     // Set whether the context outputs embeddings or not
+    // TODO: rename to avoid confusion with llama_get_embeddings()
     LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
 
     // Set whether to use causal attention or not
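The TODO added above flags a naming collision: `llama_set_embeddings()` toggles whether subsequent decode calls produce embedding output at all, while `llama_get_embeddings()` reads the resulting buffer back after a decode. A minimal sketch of the two roles, assuming an already initialized `llama_context * ctx` and a populated `llama_batch batch` (illustrative only, not the server's code path):

```cpp
// Sketch only: ctx and batch are assumed to be set up elsewhere.
llama_set_embeddings(ctx, true);                      // ask future decodes to produce embeddings
if (llama_decode(ctx, batch) == 0) {
    const float * embd = llama_get_embeddings(ctx);   // read back what this decode produced
    // the layout of embd depends on the pooling type configured for ctx
}
```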
20 changes: 10 additions & 10 deletions tools/server/server.cpp
@@ -1358,6 +1358,14 @@ struct server_slot {
         return server_task_type_need_logits(task_type);
     }
 
+    // if the context does not have a memory module then all embeddings have to be computed within a single ubatch
+    // also we cannot split if the pooling would require any past tokens
+    bool can_split() const {
+        return
+            !need_embd() ||
+            (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
+    }
+
     bool can_batch_with(server_slot & other_slot) const {
         return task_type == other_slot.task_type && are_lora_equal(lora, other_slot.lora);
     }
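The new predicate is per-slot: `need_embd()` depends on the slot's task type, whereas the context-level version it replaces (removed from `server_context` in the next hunk) tested `llama_get_embeddings(ctx)`, i.e. a property of the whole context rather than of the task at hand. Restated as a standalone helper for clarity, this is the rule being encoded (a hypothetical sketch, not code from the patch):

```cpp
#include "llama.h"

// Hypothetical restatement; the names are illustrative, not from the patch.
static bool prompt_can_split(bool slot_needs_embd, bool has_memory, enum llama_pooling_type pooling) {
    if (!slot_needs_embd) {
        return true;  // plain completion: chunking the prompt across ubatches is always allowed
    }
    // embeddings requested: past tokens must remain available (a memory/KV cache exists)
    // and the pooling must depend only on the last token
    return has_memory && pooling == LLAMA_POOLING_TYPE_LAST;
}
```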
@@ -1929,14 +1937,6 @@ struct server_context {
         llama_batch_free(batch);
     }
 
-    // if the context does not have a memory module then all embeddings have to be computed within a single ubatch
-    // also we cannot split if the pooling would require any past tokens
-    bool can_split() const {
-        return
-            !llama_get_embeddings(ctx) ||
-            (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
-    }
-
     bool load_model(const common_params & params) {
         SRV_INF("loading model '%s'\n", params.model.path.c_str());
 
@@ -3130,7 +3130,7 @@ struct server_context {
                     continue;
                 }
 
-                if (!can_split()) {
+                if (!slot.can_split()) {
                     if (slot.n_prompt_tokens > n_ubatch) {
                         slot.release();
                         send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
@@ -3273,7 +3273,7 @@ struct server_context {
                     slot.n_prompt_tokens_processed = 0;
                 }
 
-                if (!can_split()) {
+                if (!slot.can_split()) {
                     // cannot fit the prompt in the current batch - will try next iter
                     if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
                         continue;