@@ -1090,6 +1090,10 @@ struct server_slot {
10901090 return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
10911091 }
10921092
1093+ bool can_batch_with (server_slot & other_slot) {
1094+ return is_non_causal () == other_slot.is_non_causal ();
1095+ }
1096+
10931097 bool has_budget (const common_params & global_params) {
10941098 if (params.n_predict == -1 && global_params.n_predict == -1 ) {
10951099 return true ; // limitless
@@ -2564,11 +2568,8 @@ struct server_context {
25642568 int32_t n_batch = llama_n_batch (ctx);
25652569 int32_t n_ubatch = llama_n_ubatch (ctx);
25662570
2567- // track if this is an embedding or non-embedding batch
2568- // if we've added sampled tokens above, we are in non-embedding mode
2569- // -1: none, 0: non-embedding, 1: embedding
2570- // TODO: make enum
2571- int32_t batch_type = batch.n_tokens > 0 ? 0 : -1 ;
2571+ // track if given slot can be batched with slots already in the batch
2572+ server_slot * slot_batched = nullptr ;
25722573
25732574 // next, batch any pending prompts without exceeding n_batch
25742575 if (params_base.cont_batching || batch.n_tokens == 0 ) {
@@ -2733,11 +2734,10 @@ struct server_context {
27332734 }
27342735 }
27352736
2736- // check that we are in the right batch_type, if not defer the slot
2737- int slot_type = slot.is_non_causal ();
2738- if (batch_type == -1 ) {
2739- batch_type = slot_type;
2740- } else if (batch_type != slot_type) {
2737+ // check if we can batch this slot with the previous one
2738+ if (!slot_batched) {
2739+ slot_batched = &slot;
2740+ } else if (slot_batched && !slot_batched->can_batch_with (slot)) {
27412741 continue ;
27422742 }
27432743
@@ -2809,7 +2809,7 @@ struct server_context {
28092809 SRV_DBG (" decoding batch, n_tokens = %d\n " , batch.n_tokens );
28102810
28112811 // make sure we're in the right embedding mode
2812- llama_set_embeddings (ctx, batch_type == 1 );
2812+ llama_set_embeddings (ctx, slot_batched && slot_batched-> is_non_causal () );
28132813
28142814 // process the created batch of tokens
28152815 for (int32_t i = 0 ; i < batch.n_tokens ; i += n_batch) {