@@ -1215,7 +1215,7 @@ struct server_slot {
     // only used for completion/embedding/infill/rerank
     server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
 
-    llama_batch_ptr batch_spec;
+    llama_batch_ext_ptr batch_spec;
 
     llama_context * ctx = nullptr;
     llama_context * ctx_dft = nullptr;
@@ -1787,7 +1787,7 @@ struct server_context {
 
     llama_context_params cparams_dft;
 
-    llama_batch_ptr batch;
+    llama_batch_ext_ptr batch;
 
     bool clean_kv_cache = true;
     bool add_bos_token = true;
@@ -1940,7 +1940,7 @@ struct server_context {
             slot.n_predict = params_base.n_predict;
 
             if (model_dft) {
-                slot.batch_spec.reset(llama_batch_init(params_base.speculative.n_max + 1, 1));
+                slot.batch_spec.reset(llama_batch_ext_init(params_base.speculative.n_max + 1, 1));
 
                 slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
                 if (slot.ctx_dft == nullptr) {
@@ -1976,7 +1976,7 @@ struct server_context {
             const int32_t n_batch = llama_n_batch(ctx);
 
             // only a single seq_id per token is needed
-            batch.reset(llama_batch_init(std::max(n_batch, params_base.n_parallel), 1));
+            batch.reset(llama_batch_ext_init(std::max(n_batch, params_base.n_parallel), 1));
         }
 
         metrics.init();
@@ -2094,7 +2094,7 @@ struct server_context {
         }
 
         if (slot.ctx_dft) {
-            slot.batch_spec.reset(llama_batch_init(slot.params.speculative.n_max + 1, 1));
+            slot.batch_spec.reset(llama_batch_ext_init(slot.params.speculative.n_max + 1, 1));
         }
 
         slot.state = SLOT_STATE_STARTED;
@@ -2402,7 +2402,7 @@ struct server_context {
         queue_results.send(std::move(res));
     }
 
-    void send_embedding(const server_slot & slot, llama_batch_ptr & batch) {
+    void send_embedding(const server_slot & slot, llama_batch_ext_ptr & batch) {
         auto res = std::make_unique<server_task_result_embd>();
         res->id = slot.id_task;
         res->index = slot.index;
@@ -2413,8 +2413,8 @@ struct server_context {
 
         std::vector<float> embd_res(n_embd, 0.0f);
 
-        for (int i = 0; i < llama_batch_get_n_tokens(batch.get()); ++i) {
-            llama_batch_token_info tok = llama_batch_get_token_info(batch.get(), i);
+        for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); ++i) {
+            llama_batch_ext_token_info tok = llama_batch_ext_get_token_info(batch.get(), i);
             if (!tok.logits || tok.seq_id[0] != slot.id) {
                 continue;
             }
@@ -2446,14 +2446,14 @@ struct server_context {
         queue_results.send(std::move(res));
     }
 
-    void send_rerank(const server_slot & slot, llama_batch_ptr & batch) {
+    void send_rerank(const server_slot & slot, llama_batch_ext_ptr & batch) {
         auto res = std::make_unique<server_task_result_rerank>();
         res->id = slot.id_task;
         res->index = slot.index;
         res->n_tokens = slot.n_prompt_tokens;
 
-        for (int i = 0; i < llama_batch_get_n_tokens(batch.get()); ++i) {
-            llama_batch_token_info tok = llama_batch_get_token_info(batch.get(), i);
+        for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); ++i) {
+            llama_batch_ext_token_info tok = llama_batch_ext_get_token_info(batch.get(), i);
             if (!tok.logits || tok.seq_id[0] != slot.id) {
                 continue;
             }
@@ -2855,7 +2855,7 @@ struct server_context {
         }
 
         // start populating the batch for this iteration
-        common_batch_clear(batch.get());
+        llama_batch_ext_clear(batch.get());
 
         // track if given slot can be batched with slots already in the batch
        server_slot * slot_batched = nullptr;
@@ -2877,9 +2877,10 @@ struct server_context {
                 continue;
             }
 
-            slot.i_batch = llama_batch_get_n_tokens(batch.get());
+            slot.i_batch = llama_batch_ext_get_n_tokens(batch.get());
 
-            common_batch_add(batch.get(), slot.sampled, slot.n_past, { slot.id }, true);
+            std::array<llama_token, 1> seq_id = { slot.id };
+            llama_batch_ext_add_text_token(batch.get(), slot.sampled, slot.n_past, seq_id.data(), seq_id.size(), true);
 
             slot.n_past += 1;
 
@@ -2896,7 +2897,7 @@ struct server_context {
         int32_t n_ubatch = llama_n_ubatch(ctx);
 
         // next, batch any pending prompts without exceeding n_batch
-        if (params_base.cont_batching || llama_batch_get_n_tokens(batch.get()) == 0) {
+        if (params_base.cont_batching || llama_batch_ext_get_n_tokens(batch.get()) == 0) {
             for (auto & slot : slots) {
                 // check if we can batch this slot with the previous one
                 if (slot.is_processing()) {
@@ -3062,7 +3063,7 @@ struct server_context {
                     // non-causal tasks require to fit the entire prompt in the physical batch
                     if (slot.is_non_causal()) {
                         // cannot fit the prompt in the current batch - will try next iter
-                        if (llama_batch_get_n_tokens(batch.get()) + slot.n_prompt_tokens > n_batch) {
+                        if (llama_batch_ext_get_n_tokens(batch.get()) + slot.n_prompt_tokens > n_batch) {
                             continue;
                         }
                     }
@@ -3082,11 +3083,12 @@ struct server_context {
                     slot.cache_tokens.resize(slot.n_past);
 
                     // add prompt tokens for processing in the current batch
-                    while (slot.n_past < slot.n_prompt_tokens && llama_batch_get_n_tokens(batch.get()) < n_batch) {
+                    while (slot.n_past < slot.n_prompt_tokens && llama_batch_ext_get_n_tokens(batch.get()) < n_batch) {
                         // without pooling, we want to output the embeddings for all the tokens in the batch
                         const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
 
-                        common_batch_add(batch.get(), prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
+                        std::array<llama_token, 1> seq_id = { slot.id };
+                        llama_batch_ext_add_text_token(batch.get(), prompt_tokens[slot.n_past], slot.n_past, seq_id.data(), seq_id.size(), need_embd);
 
                         if (slot.params.cache_prompt) {
                             slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -3096,13 +3098,13 @@ struct server_context {
                         slot.n_past++;
                     }
 
-                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, llama_batch_get_n_tokens(batch.get()), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
+                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, llama_batch_ext_get_n_tokens(batch.get()), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
 
                     // entire prompt has been processed
                     if (slot.n_past == slot.n_prompt_tokens) {
                         slot.state = SLOT_STATE_DONE_PROMPT;
 
-                        GGML_ASSERT(llama_batch_get_n_tokens(batch.get()) > 0);
+                        GGML_ASSERT(llama_batch_ext_get_n_tokens(batch.get()) > 0);
 
                         common_sampler_reset(slot.smpl);
 
@@ -3112,27 +3114,27 @@ struct server_context {
                         }
 
                         // extract the logits only for the last token
-                        llama_batch_set_logits_last(batch.get());
+                        llama_batch_ext_set_logits_last(batch.get());
 
                         slot.n_decoded = 0;
-                        slot.i_batch = llama_batch_get_n_tokens(batch.get()) - 1;
+                        slot.i_batch = llama_batch_ext_get_n_tokens(batch.get()) - 1;
 
-                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, llama_batch_get_n_tokens(batch.get()));
+                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, llama_batch_ext_get_n_tokens(batch.get()));
                     }
                 }
 
-                if (llama_batch_get_n_tokens(batch.get()) >= n_batch) {
+                if (llama_batch_ext_get_n_tokens(batch.get()) >= n_batch) {
                     break;
                 }
             }
         }
 
-        if (llama_batch_get_n_tokens(batch.get()) == 0) {
+        if (llama_batch_ext_get_n_tokens(batch.get()) == 0) {
             SRV_WRN("%s", "no tokens to decode\n");
             return;
         }
 
-        SRV_DBG("decoding batch, n_tokens = %d\n", llama_batch_get_n_tokens(batch.get()));
+        SRV_DBG("decoding batch, n_tokens = %d\n", llama_batch_ext_get_n_tokens(batch.get()));
 
         if (slot_batched) {
             // make sure we're in the right embedding mode
@@ -3142,12 +3144,12 @@ struct server_context {
         }
 
         // process the created batch of tokens
-        for (int32_t i = 0; i < llama_batch_get_n_tokens(batch.get()); i += n_batch) {
-            const int32_t n_tokens = std::min(n_batch, llama_batch_get_n_tokens(batch.get()) - i);
+        for (int32_t i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i += n_batch) {
+            const int32_t n_tokens = std::min(n_batch, llama_batch_ext_get_n_tokens(batch.get()) - i);
 
-            llama_batch_ptr batch_view(llama_batch_get_view(batch.get(), i, n_tokens));
+            llama_batch_ext_ptr batch_view(llama_batch_ext_get_view(batch.get(), i, n_tokens));
 
-            const int ret = llama_decode(ctx, batch_view.get());
+            const int ret = llama_text_decode(ctx, batch_view.get());
             metrics.on_decoded(slots);
 
             if (ret != 0) {
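Note: the hunk above splits the populated batch into n_batch-sized views before decoding. Below is a minimal sketch of that pattern, assuming the llama_batch_ext calls behave as used in this diff (llama_batch_ext_get_n_tokens, llama_batch_ext_get_view, llama_text_decode) and that llama_batch_ext_ptr is the RAII wrapper the server uses to release each view; the exact signatures may differ in the final headers.

    // Sketch only: chunked decode over an already-populated batch, mirroring the
    // loop above. API names and signatures are taken from this diff, not from the
    // final upstream headers.
    #include <algorithm>

    static int decode_in_chunks(llama_context * ctx, llama_batch_ext_ptr & batch, int32_t n_batch) {
        const int32_t n_total = llama_batch_ext_get_n_tokens(batch.get());

        for (int32_t i = 0; i < n_total; i += n_batch) {
            const int32_t n_tokens = std::min(n_batch, n_total - i);

            // view over tokens [i, i + n_tokens) of the full batch
            llama_batch_ext_ptr batch_view(llama_batch_ext_get_view(batch.get(), i, n_tokens));

            const int ret = llama_text_decode(ctx, batch_view.get());
            if (ret != 0) {
                return ret; // e.g. no KV cache slot found; the server retries with a smaller n_batch
            }
        }

        return 0;
    }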
@@ -3282,16 +3284,17 @@ struct server_context {
                 }
 
                 // construct the speculation batch
-                common_batch_clear(slot.batch_spec.get());
-                common_batch_add(slot.batch_spec.get(), id, slot.n_past, { slot.id }, true);
+                llama_batch_ext_clear(slot.batch_spec.get());
+                std::array<llama_token, 1> seq_id = { slot.id };
+                llama_batch_ext_add_text_token(slot.batch_spec.get(), id, slot.n_past, seq_id.data(), seq_id.size(), true);
 
                 for (size_t i = 0; i < draft.size(); ++i) {
-                    common_batch_add(slot.batch_spec.get(), draft[i], slot.n_past + 1 + i, { slot.id }, true);
+                    llama_batch_ext_add_text_token(slot.batch_spec.get(), draft[i], slot.n_past + 1 + i, seq_id.data(), seq_id.size(), true);
                 }
 
-                SLT_DBG(slot, "decoding speculative batch, size = %d\n", llama_batch_get_n_tokens(slot.batch_spec.get()));
+                SLT_DBG(slot, "decoding speculative batch, size = %d\n", llama_batch_ext_get_n_tokens(slot.batch_spec.get()));
 
-                llama_decode(ctx, slot.batch_spec.get());
+                llama_text_decode(ctx, slot.batch_spec.get());
 
                 // the accepted tokens from the speculation
                 const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
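For reference, the per-token call pattern used throughout this diff replaces the common_batch_clear/common_batch_add helpers: sequence ids are now passed as an explicit pointer/count pair and the logits flag is given per token. A minimal sketch, assuming the llama_batch_ext API behaves as used in the hunks above (llama_batch_ext_clear, llama_batch_ext_add_text_token, llama_batch_ext_set_logits_last, llama_text_decode); the exact signatures may differ in the final headers.

    // Sketch only: decode one sampled token for a single sequence and request its
    // logits. Mirrors the server usage above; the llama_batch_ext_* names come from
    // this diff.
    #include <array>

    static int decode_one(llama_context * ctx, llama_batch_ext_ptr & batch,
                          llama_token token, llama_pos pos, llama_seq_id seq) {
        // start from an empty batch
        llama_batch_ext_clear(batch.get());

        // a single seq_id per token, passed as data() + size()
        std::array<llama_seq_id, 1> seq_id = { seq };
        llama_batch_ext_add_text_token(batch.get(), token, pos, seq_id.data(), seq_id.size(), true);

        // request logits only for the last token in the batch
        llama_batch_ext_set_logits_last(batch.get());

        // llama_text_decode replaces llama_decode for text batches in this diff
        return llama_text_decode(ctx, batch.get());
    }

The same pattern appears above for sampled tokens, prompt tokens, and the speculative draft batch.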