@@ -1215,7 +1215,7 @@ struct server_slot {
     // only used for completion/embedding/infill/rerank
     server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
 
-    llama_batch batch_spec = {};
+    llama_batch_ptr batch_spec;
 
     llama_context * ctx = nullptr;
     llama_context * ctx_dft = nullptr;
@@ -1787,7 +1787,7 @@ struct server_context {
 
     llama_context_params cparams_dft;
 
-    llama_batch batch = {};
+    llama_batch_ptr batch;
 
     bool clean_kv_cache = true;
     bool add_bos_token = true;
@@ -1820,11 +1820,7 @@ struct server_context {
 
             common_speculative_free(slot.spec);
             slot.spec = nullptr;
-
-            llama_batch_free(slot.batch_spec);
         }
-
-        llama_batch_free(batch);
     }
 
     bool load_model(const common_params & params) {
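The explicit llama_batch_free calls above disappear because ownership moves into llama_batch_ptr. The diff does not show how llama_batch_ptr is declared; as a hedged sketch only, assuming it follows the std::unique_ptr-plus-deleter pattern llama.cpp already uses for other handles in llama-cpp.h, and assuming this refactor makes llama_batch_init return a heap-allocated llama_batch * that llama_batch_free releases, it could look roughly like this:

    #include <memory>

    struct llama_batch;                    // treated as an opaque handle here
    void llama_batch_free(llama_batch *);  // assumed pointer-taking variant of the free function

    struct llama_batch_deleter {
        void operator()(llama_batch * b) const { llama_batch_free(b); }
    };

    // RAII owner: batch.reset(...) and the slot/context destructors handle cleanup,
    // which is why the manual llama_batch_free calls are removed in this hunk.
    typedef std::unique_ptr<llama_batch, llama_batch_deleter> llama_batch_ptr;
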
@@ -1944,7 +1940,7 @@ struct server_context {
             slot.n_predict = params_base.n_predict;
 
             if (model_dft) {
-                slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
+                slot.batch_spec.reset(llama_batch_init(params_base.speculative.n_max + 1, 1));
 
                 slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
                 if (slot.ctx_dft == nullptr) {
@@ -1969,7 +1965,7 @@ struct server_context {
 
             slot.reset();
 
-            slots.push_back(slot);
+            slots.push_back(std::move(slot));
         }
 
         default_generation_settings_for_props = slots[0].to_json();
@@ -1980,7 +1976,7 @@ struct server_context {
             const int32_t n_batch = llama_n_batch(ctx);
 
             // only a single seq_id per token is needed
-            batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
+            batch.reset(llama_batch_init(std::max(n_batch, params_base.n_parallel), 1));
         }
 
         metrics.init();
@@ -2098,9 +2094,7 @@ struct server_context {
         }
 
         if (slot.ctx_dft) {
-            llama_batch_free(slot.batch_spec);
-
-            slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
+            slot.batch_spec.reset(llama_batch_init(slot.params.speculative.n_max + 1, 1));
         }
 
         slot.state = SLOT_STATE_STARTED;
@@ -2408,7 +2402,7 @@ struct server_context {
         queue_results.send(std::move(res));
     }
 
-    void send_embedding(const server_slot & slot, const llama_batch & batch) {
+    void send_embedding(const server_slot & slot, llama_batch_ptr & batch) {
         auto res = std::make_unique<server_task_result_embd>();
         res->id = slot.id_task;
         res->index = slot.index;
@@ -2419,18 +2413,19 @@ struct server_context {
 
         std::vector<float> embd_res(n_embd, 0.0f);
 
-        for (int i = 0; i < batch.n_tokens; ++i) {
-            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+        for (int i = 0; i < llama_batch_get_n_tokens(batch.get()); ++i) {
+            llama_batch_token_info tok = llama_batch_get_token_info(batch.get(), i);
+            if (!tok.logits || tok.seq_id[0] != slot.id) {
                 continue;
             }
 
-            const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+            const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id[0]);
             if (embd == NULL) {
                 embd = llama_get_embeddings_ith(ctx, i);
             }
 
             if (embd == NULL) {
-                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
+                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id[0]);
 
                 res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
                 continue;
@@ -2451,24 +2446,25 @@ struct server_context {
         queue_results.send(std::move(res));
     }
 
-    void send_rerank(const server_slot & slot, const llama_batch & batch) {
+    void send_rerank(const server_slot & slot, llama_batch_ptr & batch) {
         auto res = std::make_unique<server_task_result_rerank>();
         res->id = slot.id_task;
         res->index = slot.index;
         res->n_tokens = slot.n_prompt_tokens;
 
-        for (int i = 0; i < batch.n_tokens; ++i) {
-            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+        for (int i = 0; i < llama_batch_get_n_tokens(batch.get()); ++i) {
+            llama_batch_token_info tok = llama_batch_get_token_info(batch.get(), i);
+            if (!tok.logits || tok.seq_id[0] != slot.id) {
                 continue;
             }
 
-            const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+            const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id[0]);
             if (embd == NULL) {
                 embd = llama_get_embeddings_ith(ctx, i);
             }
 
             if (embd == NULL) {
-                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
+                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id[0]);
 
                 res->score = -1e6;
                 continue;
@@ -2859,7 +2855,7 @@ struct server_context {
         }
 
         // start populating the batch for this iteration
-        common_batch_clear(batch);
+        common_batch_clear(batch.get());
 
         // track if given slot can be batched with slots already in the batch
         server_slot * slot_batched = nullptr;
@@ -2881,9 +2877,9 @@ struct server_context {
                 continue;
             }
 
-            slot.i_batch = batch.n_tokens;
+            slot.i_batch = llama_batch_get_n_tokens(batch.get());
 
-            common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
+            common_batch_add(batch.get(), slot.sampled, slot.n_past, { slot.id }, true);
 
             slot.n_past += 1;
 
@@ -2900,7 +2896,7 @@ struct server_context {
         int32_t n_ubatch = llama_n_ubatch(ctx);
 
         // next, batch any pending prompts without exceeding n_batch
-        if (params_base.cont_batching || batch.n_tokens == 0) {
+        if (params_base.cont_batching || llama_batch_get_n_tokens(batch.get()) == 0) {
             for (auto & slot : slots) {
                 // check if we can batch this slot with the previous one
                 if (slot.is_processing()) {
@@ -3066,7 +3062,7 @@ struct server_context {
                     // non-causal tasks require to fit the entire prompt in the physical batch
                     if (slot.is_non_causal()) {
                         // cannot fit the prompt in the current batch - will try next iter
-                        if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
+                        if (llama_batch_get_n_tokens(batch.get()) + slot.n_prompt_tokens > n_batch) {
                             continue;
                         }
                     }
@@ -3086,11 +3082,11 @@ struct server_context {
                     slot.cache_tokens.resize(slot.n_past);
 
                     // add prompt tokens for processing in the current batch
-                    while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
+                    while (slot.n_past < slot.n_prompt_tokens && llama_batch_get_n_tokens(batch.get()) < n_batch) {
                         // without pooling, we want to output the embeddings for all the tokens in the batch
                         const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
 
-                        common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
+                        common_batch_add(batch.get(), prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
 
                         if (slot.params.cache_prompt) {
                             slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -3100,13 +3096,13 @@ struct server_context {
                         slot.n_past++;
                     }
 
-                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
+                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, llama_batch_get_n_tokens(batch.get()), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
 
                     // entire prompt has been processed
                     if (slot.n_past == slot.n_prompt_tokens) {
                         slot.state = SLOT_STATE_DONE_PROMPT;
 
-                        GGML_ASSERT(batch.n_tokens > 0);
+                        GGML_ASSERT(llama_batch_get_n_tokens(batch.get()) > 0);
 
                         common_sampler_reset(slot.smpl);
 
@@ -3116,27 +3112,27 @@ struct server_context {
                         }
 
                         // extract the logits only for the last token
-                        batch.logits[batch.n_tokens - 1] = true;
+                        llama_batch_set_logits_last(batch.get());
 
                         slot.n_decoded = 0;
-                        slot.i_batch = batch.n_tokens - 1;
+                        slot.i_batch = llama_batch_get_n_tokens(batch.get()) - 1;
 
-                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
+                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, llama_batch_get_n_tokens(batch.get()));
                     }
                 }
 
-                if (batch.n_tokens >= n_batch) {
+                if (llama_batch_get_n_tokens(batch.get()) >= n_batch) {
                     break;
                 }
             }
         }
 
-        if (batch.n_tokens == 0) {
+        if (llama_batch_get_n_tokens(batch.get()) == 0) {
             SRV_WRN("%s", "no tokens to decode\n");
             return;
         }
 
-        SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
+        SRV_DBG("decoding batch, n_tokens = %d\n", llama_batch_get_n_tokens(batch.get()));
 
         if (slot_batched) {
             // make sure we're in the right embedding mode
@@ -3146,20 +3142,12 @@ struct server_context {
         }
 
         // process the created batch of tokens
-        for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
-            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
-
-            llama_batch batch_view = {
-                n_tokens,
-                batch.token    + i,
-                nullptr,
-                batch.pos      + i,
-                batch.n_seq_id + i,
-                batch.seq_id   + i,
-                batch.logits   + i,
-            };
+        for (int32_t i = 0; i < llama_batch_get_n_tokens(batch.get()); i += n_batch) {
+            const int32_t n_tokens = std::min(n_batch, llama_batch_get_n_tokens(batch.get()) - i);
+
+            llama_batch_ptr batch_view(llama_batch_get_view(batch.get(), i, n_tokens));
 
-            const int ret = llama_decode(ctx, batch_view);
+            const int ret = llama_decode(ctx, batch_view.get());
             metrics.on_decoded(slots);
 
             if (ret != 0) {
@@ -3294,16 +3282,16 @@ struct server_context {
                 }
 
                 // construct the speculation batch
-                common_batch_clear(slot.batch_spec);
-                common_batch_add  (slot.batch_spec, id, slot.n_past, { slot.id }, true);
+                common_batch_clear(slot.batch_spec.get());
+                common_batch_add  (slot.batch_spec.get(), id, slot.n_past, { slot.id }, true);
 
                 for (size_t i = 0; i < draft.size(); ++i) {
-                    common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
+                    common_batch_add(slot.batch_spec.get(), draft[i], slot.n_past + 1 + i, { slot.id }, true);
                 }
 
-                SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
+                SLT_DBG(slot, "decoding speculative batch, size = %d\n", llama_batch_get_n_tokens(slot.batch_spec.get()));
 
-                llama_decode(ctx, slot.batch_spec);
+                llama_decode(ctx, slot.batch_spec.get());
 
                 // the accepted tokens from the speculation
                 const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
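
Taken together, the pattern throughout this diff is mechanical: batch.n_tokens becomes llama_batch_get_n_tokens(batch.get()), per-token reads of token / seq_id / logits go through a llama_batch_token_info lookup, the last-token logits flag is set via llama_batch_set_logits_last, and sub-range views come from llama_batch_get_view instead of hand-built llama_batch structs. A hedged sketch of the read side of that pattern as a standalone helper, with the accessor signatures inferred from the call sites above rather than from a published header:

    // Count how many tokens in a batch request logits for a given sequence.
    // Assumes the accessor API shown in this diff: llama_batch_get_n_tokens()
    // and llama_batch_get_token_info() returning a llama_batch_token_info with
    // token / seq_id / logits fields (inferred from the call sites, not an official header).
    static int32_t count_logit_tokens_for_seq(llama_batch_ptr & batch, llama_seq_id seq_id) {
        int32_t n = 0;
        for (int i = 0; i < llama_batch_get_n_tokens(batch.get()); ++i) {
            llama_batch_token_info tok = llama_batch_get_token_info(batch.get(), i);
            // replaces the old direct accesses batch.logits[i] and batch.seq_id[i][0]
            if (tok.logits && tok.seq_id[0] == seq_id) {
                n++;
            }
        }
        return n;
    }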