@@ -3341,6 +3341,37 @@ struct server_context {
             common_set_adapter_lora(ctx, slot_batched->lora);
         }
 
+        const bool do_encode = (params_base.embedding || params_base.reranking);
+
+        // pad the batch so that batch.n_tokens >= n_slots
+        // TODO: temporary workaround for https://github.com/ggml-org/llama.cpp/issues/13689
+        if (do_encode) {
+            const int n_slots = slots.size();
+
+            if (batch.n_tokens < n_slots) {
+                std::set<llama_seq_id> seq_ids;
+                for (int j = 0; j < batch.n_tokens; ++j) {
+                    seq_ids.insert(batch.seq_id[j][0]);
+                }
+
+                // find unused sequence id
+                llama_seq_id seq_id = -1;
+                for (int i = 0; i < n_slots; ++i) {
+                    if (seq_ids.find(i) == seq_ids.end()) {
+                        seq_id = i;
+                    }
+                }
+
+                const int n_add = n_slots - batch.n_tokens;
+
+                SRV_WRN("adding %d dummy tokens to the batch, seq_id = %d\n", n_add, seq_id);
+
+                for (int j = 0; j < n_add; ++j) {
+                    common_batch_add(batch, 0, j, { seq_id }, false);
+                }
+            }
+        }
+
         // process the created batch of tokens
         for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
             const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
@@ -3357,7 +3388,7 @@ struct server_context {
 
             int ret = 0;
 
-            if (params_base.embedding || params_base.reranking) {
+            if (do_encode) {
                 ret = llama_encode(ctx, batch_view);
             } else {
                 ret = llama_decode(ctx, batch_view);
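
For reference, the padding logic introduced above can be exercised in isolation. The following is a minimal sketch, assuming a simplified stand-in for llama.cpp's llama_batch; SimpleBatch, pad_batch and the token values are hypothetical and only illustrate how an unused seq_id is picked and dummy tokens are appended until the batch covers all slots:

    #include <cstdio>
    #include <set>
    #include <vector>

    // hypothetical stand-in for llama_batch: one token id and one seq_id per entry
    struct SimpleBatch {
        std::vector<int> token;
        std::vector<int> seq_id;
    };

    // pad the batch with dummy tokens on an unused sequence id until it has n_slots entries
    static void pad_batch(SimpleBatch & batch, int n_slots) {
        if ((int) batch.token.size() >= n_slots) {
            return;
        }

        // collect the sequence ids already present in the batch
        std::set<int> used;
        for (int id : batch.seq_id) {
            used.insert(id);
        }

        // find a slot index that is not yet used as a sequence id
        int free_seq_id = -1;
        for (int i = 0; i < n_slots; ++i) {
            if (used.find(i) == used.end()) {
                free_seq_id = i;
            }
        }

        // append dummy tokens (token id 0) until the batch size reaches n_slots
        const int n_add = n_slots - (int) batch.token.size();
        for (int j = 0; j < n_add; ++j) {
            batch.token.push_back(0);
            batch.seq_id.push_back(free_seq_id);
        }

        printf("added %d dummy tokens on seq_id %d\n", n_add, free_seq_id);
    }

    int main() {
        SimpleBatch batch;
        batch.token  = { 101, 102 };   // two real tokens
        batch.seq_id = { 0, 0 };       // both belong to sequence 0

        pad_batch(batch, /*n_slots =*/ 4); // prints: added 2 dummy tokens on seq_id 3
        return 0;
    }

In the actual change, the dummy tokens are appended with common_batch_add() and passed to llama_encode() together with the real ones; the final `false` argument leaves their logits unrequested, so they only serve to satisfy the batch-size requirement from the linked issue.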