Skip to content

Commit 2ba6efc

Browse files
committed
slot.can_batch_with
1 parent d79d8f3 commit 2ba6efc

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

examples/server/server.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,6 +1090,10 @@ struct server_slot {
10901090
return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
10911091
}
10921092

1093+
// Reports whether this slot may share a batch with other_slot.
// The two slots are considered compatible when is_non_causal() gives the
// same result for both of them; no other slot state is compared here.
// The batching loop uses this to defer (skip for now) any slot that does not
// match the slot already recorded for the current batch.
bool can_batch_with(server_slot & other_slot) {
1094+
return is_non_causal() == other_slot.is_non_causal();
1095+
}
1096+
10931097
bool has_budget(const common_params & global_params) {
10941098
if (params.n_predict == -1 && global_params.n_predict == -1) {
10951099
return true; // limitless
@@ -2564,11 +2568,8 @@ struct server_context {
25642568
int32_t n_batch = llama_n_batch(ctx);
25652569
int32_t n_ubatch = llama_n_ubatch(ctx);
25662570

2567-
// track if this is an embedding or non-embedding batch
2568-
// if we've added sampled tokens above, we are in non-embedding mode
2569-
// -1: none, 0: non-embedding, 1: embedding
2570-
// TODO: make enum
2571-
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
2571+
// reference slot for this batch: the first slot added; later slots must pass can_batch_with() against it
2572+
server_slot * slot_batched = nullptr;
25722573

25732574
// next, batch any pending prompts without exceeding n_batch
25742575
if (params_base.cont_batching || batch.n_tokens == 0) {
@@ -2733,11 +2734,10 @@ struct server_context {
27332734
}
27342735
}
27352736

2736-
// check that we are in the right batch_type, if not defer the slot
2737-
int slot_type = slot.is_non_causal();
2738-
if (batch_type == -1) {
2739-
batch_type = slot_type;
2740-
} else if (batch_type != slot_type) {
2737+
// check if we can batch this slot with the first slot already added to the batch
2738+
if (!slot_batched) {
2739+
slot_batched = &slot;
2740+
} else if (slot_batched && !slot_batched->can_batch_with(slot)) {
27412741
continue;
27422742
}
27432743

@@ -2809,7 +2809,7 @@ struct server_context {
28092809
SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
28102810

28112811
// make sure we're in the right embedding mode
2812-
llama_set_embeddings(ctx, batch_type == 1);
2812+
llama_set_embeddings(ctx, slot_batched && slot_batched->is_non_causal());
28132813

28142814
// process the created batch of tokens
28152815
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {

0 commit comments

Comments
 (0)