@@ -754,13 +754,13 @@ struct server_context {
         default_generation_settings_for_props = get_formated_generation(slots.front());
         default_generation_settings_for_props["seed"] = -1;

-        // the update_slots() logic will always submit a maximum of n_batch tokens
+        // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
         // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
         {
             const int32_t n_batch = llama_n_batch(ctx);

             // only a single seq_id per token is needed
-            batch = llama_batch_init(n_batch, 0, 1);
+            batch = llama_batch_init(std::max(n_batch, params.n_parallel), 0, 1);
         }

         metrics.init();
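Why the batch is sized to std::max(n_batch, params.n_parallel): prompt processing submits at most n_batch tokens per llama_decode() call, but during generation every active slot contributes one token to the same batch, so with n_parallel slots the batch may need up to n_parallel entries even when n_batch is configured smaller. A minimal sketch of that generation-phase fill, assuming the common.h batch helpers; the function and the per-slot containers are hypothetical stand-ins, not the server's actual members:

    #include <cstdint>
    #include <vector>

    #include "common.h" // llama_batch_clear() / llama_batch_add() helpers
    #include "llama.h"

    // One decode step during generation: each parallel sequence appends exactly
    // one sampled token, so the batch needs capacity for at least n_parallel tokens.
    static void fill_generation_batch(llama_batch & batch,
                                      const std::vector<llama_token> & sampled, // hypothetical: last sampled token per slot
                                      const std::vector<llama_pos>   & pos) {   // hypothetical: current position per slot
        llama_batch_clear(batch);

        for (int32_t id = 0; id < (int32_t) sampled.size(); ++id) {
            // seq_id == slot id; request logits so the next token can be sampled
            llama_batch_add(batch, sampled[id], pos[id], { id }, true);
        }
    }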
@@ -1137,28 +1137,19 @@ struct server_context {
         if (!system_prompt.empty()) {
             system_tokens = ::llama_tokenize(ctx, system_prompt, true);

-            llama_batch_clear(batch);
+            const int32_t n_batch = llama_n_batch(ctx);
+            const int32_t n_tokens_prompt = system_tokens.size();

-            for (int i = 0; i < (int)system_tokens.size(); ++i) {
-                llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
-            }
+            for (int32_t i = 0; i < n_tokens_prompt; i += n_batch) {
+                const int32_t n_tokens = std::min(n_batch, n_tokens_prompt - i);

-            const int32_t n_batch = llama_n_batch(ctx);
+                llama_batch_clear(batch);

-            for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
-                const int32_t n_tokens = std::min(params.n_batch, batch.n_tokens - i);
-                llama_batch batch_view = {
-                    n_tokens,
-                    batch.token    + i,
-                    nullptr,
-                    batch.pos      + i,
-                    batch.n_seq_id + i,
-                    batch.seq_id   + i,
-                    batch.logits   + i,
-                    0, 0, 0, // unused
-                };
+                for (int32_t j = 0; j < n_tokens; ++j) {
+                    llama_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false);
+                }

-                if (llama_decode(ctx, batch_view) != 0) {
+                if (llama_decode(ctx, batch) != 0) {
                     LOG_ERROR("llama_decode() failed", {});
                     return;
                 }
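The new system-prompt path follows the usual chunked-decode pattern: fill the reusable batch with at most n_batch tokens, decode, and continue from absolute position i so the KV cache stays contiguous across chunks. A self-contained sketch of the same pattern, assuming a loaded llama_context and the common.h batch helpers; decode_in_chunks is an illustrative name, not part of the server:

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    #include "common.h" // llama_batch_clear() / llama_batch_add() helpers
    #include "llama.h"

    // Decode an arbitrarily long token sequence in n_batch-sized chunks, reusing
    // a single llama_batch (sized with llama_batch_init as in the first hunk).
    static bool decode_in_chunks(llama_context * ctx, llama_batch & batch,
                                 const std::vector<llama_token> & tokens) {
        const int32_t n_batch        = llama_n_batch(ctx);
        const int32_t n_tokens_total = (int32_t) tokens.size();

        for (int32_t i = 0; i < n_tokens_total; i += n_batch) {
            const int32_t n_tokens = std::min(n_batch, n_tokens_total - i);

            llama_batch_clear(batch);
            for (int32_t j = 0; j < n_tokens; ++j) {
                // absolute position i + j keeps the KV cache consistent across
                // chunks; logits are not needed while prefilling the prompt
                llama_batch_add(batch, tokens[i + j], i + j, { 0 }, false);
            }

            if (llama_decode(ctx, batch) != 0) {
                return false; // decode failed
            }
        }

        return true;
    }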