@@ -919,13 +919,13 @@ struct server_context {
         default_generation_settings_for_props = get_formated_generation(slots.front());
         default_generation_settings_for_props["seed"] = -1;

-        // the update_slots() logic will always submit a maximum of n_batch tokens
+        // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
         // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
         {
             const int32_t n_batch = llama_n_batch(ctx);

             // only a single seq_id per token is needed
-            batch = llama_batch_init(n_batch, 0, 1);
+            batch = llama_batch_init(std::max(n_batch, params.n_parallel), 0, 1);
         }

         metrics.init();
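
The first hunk grows the batch allocation from n_batch entries to the larger of n_batch and n_parallel. The reasoning follows the updated comment: prompt processing submits up to n_batch tokens per llama_decode() call, while parallel generation can submit one token per active slot, so with n_parallel slots the shared batch needs at least that many entries. A minimal sketch of the sizing rule, with a helper name invented here purely for illustration (it is not part of the server code):

#include <algorithm>
#include <cstdint>

// hypothetical helper, for illustration only: how many token entries the
// shared llama_batch must be able to hold
static int32_t batch_alloc_size(int32_t n_batch, int32_t n_parallel) {
    // prompt chunks use up to n_batch entries; generation can add one token
    // per active slot, i.e. up to n_parallel entries in a single decode step
    return std::max(n_batch, n_parallel);
}
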
@@ -1345,28 +1345,19 @@ struct server_context {
         if (!system_prompt.empty()) {
             system_tokens = ::llama_tokenize(ctx, system_prompt, true);

-            llama_batch_clear(batch);
+            const int32_t n_batch = llama_n_batch(ctx);
+            const int32_t n_tokens_prompt = system_tokens.size();

-            for (int i = 0; i < (int) system_tokens.size(); ++i) {
-                llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
-            }
+            for (int32_t i = 0; i < n_tokens_prompt; i += n_batch) {
+                const int32_t n_tokens = std::min(n_batch, n_tokens_prompt - i);

-            const int32_t n_batch = llama_n_batch(ctx);
+                llama_batch_clear(batch);

-            for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
-                const int32_t n_tokens = std::min(params.n_batch, batch.n_tokens - i);
-                llama_batch batch_view = {
-                    n_tokens,
-                    batch.token    + i,
-                    nullptr,
-                    batch.pos      + i,
-                    batch.n_seq_id + i,
-                    batch.seq_id   + i,
-                    batch.logits   + i,
-                    0, 0, 0, // unused
-                };
+                for (int32_t j = 0; j < n_tokens; ++j) {
+                    llama_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false);
+                }

-                if (llama_decode(ctx, batch_view) != 0) {
+                if (llama_decode(ctx, batch) != 0) {
                     LOG_ERROR("llama_decode() failed", {});
                     return;
                 }
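
The second hunk replaces the manual llama_batch batch_view slicing with a loop that refills and decodes the shared batch chunk by chunk, so a system prompt longer than n_batch tokens no longer has to fit into the allocation all at once. A self-contained sketch of that pattern follows; it assumes the llama.cpp C API (llama_n_batch, llama_decode) plus the common-library helpers llama_batch_clear/llama_batch_add used in the diff, and the function name decode_prompt_chunked is invented here for illustration:

#include <algorithm>
#include <vector>

#include "common.h" // llama_batch_clear / llama_batch_add helpers
#include "llama.h"

// sketch only: decode `tokens` into sequence 0 in chunks of at most n_batch
// tokens per llama_decode() call, reusing a pre-allocated `batch`
static bool decode_prompt_chunked(llama_context * ctx, llama_batch & batch, const std::vector<llama_token> & tokens) {
    const int32_t n_batch         = llama_n_batch(ctx);
    const int32_t n_tokens_prompt = (int32_t) tokens.size();

    for (int32_t i = 0; i < n_tokens_prompt; i += n_batch) {
        const int32_t n_tokens = std::min(n_batch, n_tokens_prompt - i);

        llama_batch_clear(batch);

        for (int32_t j = 0; j < n_tokens; ++j) {
            // absolute position i + j, sequence id 0, no logits for prompt tokens
            llama_batch_add(batch, tokens[i + j], i + j, { 0 }, false);
        }

        if (llama_decode(ctx, batch) != 0) {
            return false; // decoding failed; the caller decides how to report it
        }
    }

    return true;
}

Reusing one batch and clearing it per chunk keeps the loop simple and avoids constructing a pointer-offset batch_view by hand, at the cost of re-adding each token through llama_batch_add.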