@@ -1205,14 +1205,55 @@ struct server_task_result_apply_lora : server_task_result {
     }
 };
 
+struct server_batch {
+    llama_batch_ext_ptr batch;
+    struct batch_token {
+        llama_token token;
+        llama_seq_id seq_id;
+        bool logits;
+    };
+    std::vector<batch_token> tokens;
+    server_batch() = default;
+    server_batch(int32_t n_tokens, int32_t n_seq_max) {
+        batch.reset(llama_batch_ext_init(n_tokens, n_seq_max));
+        tokens.reserve(n_tokens);
+    }
+    void clear() {
+        llama_batch_ext_clear(batch.get());
+        tokens.clear();
+    }
+    void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) {
+        llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits);
+        tokens.push_back({token, seq_id, logits});
+    }
+    void set_logits_last() {
+        if (!tokens.empty()) {
+            llama_batch_ext_set_logits_last(batch.get());
+            tokens.back().logits = true;
+        }
+    }
+    int32_t get_n_tokens() const {
+        return (int32_t) tokens.size();
+    }
+    server_batch get_view(int32_t offset, int32_t n_tokens) {
+        server_batch view;
+        view.batch = llama_batch_ext_ptr(llama_batch_ext_get_view(batch.get(), offset, n_tokens));
+        view.tokens.reserve(n_tokens);
+        for (int32_t i = 0; i < n_tokens; i++) {
+            view.tokens.push_back(tokens[offset + i]);
+        }
+        return view;
+    }
+};
+
 struct server_slot {
     int id;
     int id_task = -1;
 
     // only used for completion/embedding/infill/rerank
     server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
 
-    llama_batch_ext_ptr batch_spec;
+    server_batch batch_spec;
 
     llama_context * ctx = nullptr;
     llama_context * ctx_dft = nullptr;
@@ -1784,7 +1825,7 @@ struct server_context {
 
     llama_context_params cparams_dft;
 
-    llama_batch_ext_ptr batch;
+    server_batch batch;
 
     bool clean_kv_cache = true;
     bool add_bos_token = true;
@@ -1909,7 +1950,7 @@ struct server_context {
             slot.n_predict = params_base.n_predict;
 
             if (model_dft) {
-                slot.batch_spec.reset(llama_batch_ext_init(params_base.speculative.n_max + 1, 1));
+                slot.batch_spec = server_batch(params_base.speculative.n_max + 1, 1);
 
                 slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
                 if (slot.ctx_dft == nullptr) {
@@ -1945,7 +1986,7 @@ struct server_context {
             const int32_t n_batch = llama_n_batch(ctx);
 
             // only a single seq_id per token is needed
-            batch.reset(llama_batch_ext_init(std::max(n_batch, params_base.n_parallel), 1));
+            batch = server_batch(std::max(n_batch, params_base.n_parallel), 1);
         }
 
         metrics.init();
@@ -2063,7 +2104,7 @@ struct server_context {
         }
 
         if (slot.ctx_dft) {
-            slot.batch_spec.reset(llama_batch_ext_init(slot.params.speculative.n_max + 1, 1));
+            slot.batch_spec = server_batch(slot.params.speculative.n_max + 1, 1);
         }
 
         slot.state = SLOT_STATE_STARTED;
@@ -2371,7 +2412,7 @@ struct server_context {
         queue_results.send(std::move(res));
     }
 
-    void send_embedding(const server_slot & slot, llama_batch_ext_ptr & batch) {
+    void send_embedding(const server_slot & slot, server_batch & batch) {
         auto res = std::make_unique<server_task_result_embd>();
         res->id = slot.id_task;
         res->index = slot.index;
@@ -2382,19 +2423,19 @@ struct server_context {
 
         std::vector<float> embd_res(n_embd, 0.0f);
 
-        for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); ++i) {
-            llama_batch_ext_token_info tok = llama_batch_ext_get_token_info(batch.get(), i);
-            if (!tok.logits || tok.seq_id[0] != slot.id) {
+        for (int i = 0; i < batch.get_n_tokens(); ++i) {
+            auto tok = batch.tokens[i];
+            if (!tok.logits || tok.seq_id != slot.id) {
                 continue;
             }
 
-            const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id[0]);
+            const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id);
             if (embd == NULL) {
                 embd = llama_get_embeddings_ith(ctx, i);
             }
 
             if (embd == NULL) {
-                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id[0]);
+                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id);
 
                 res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
                 continue;
@@ -2415,25 +2456,25 @@ struct server_context {
         queue_results.send(std::move(res));
     }
 
-    void send_rerank(const server_slot & slot, llama_batch_ext_ptr & batch) {
+    void send_rerank(const server_slot & slot, server_batch & batch) {
         auto res = std::make_unique<server_task_result_rerank>();
         res->id = slot.id_task;
         res->index = slot.index;
         res->n_tokens = slot.n_prompt_tokens;
 
-        for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); ++i) {
-            llama_batch_ext_token_info tok = llama_batch_ext_get_token_info(batch.get(), i);
-            if (!tok.logits || tok.seq_id[0] != slot.id) {
+        for (int i = 0; i < batch.get_n_tokens(); ++i) {
+            auto tok = batch.tokens[i];
+            if (!tok.logits || tok.seq_id != slot.id) {
                 continue;
             }
 
-            const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id[0]);
+            const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id);
             if (embd == NULL) {
                 embd = llama_get_embeddings_ith(ctx, i);
             }
 
             if (embd == NULL) {
-                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id[0]);
+                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id);
 
                 res->score = -1e6;
                 continue;
@@ -2824,7 +2865,7 @@ struct server_context {
         }
 
         // start populating the batch for this iteration
-        llama_batch_ext_clear(batch.get());
+        batch.clear();
 
         // track if given slot can be batched with slots already in the batch
         server_slot * slot_batched = nullptr;
@@ -2846,10 +2887,9 @@ struct server_context {
                 continue;
             }
 
-            slot.i_batch = llama_batch_ext_get_n_tokens(batch.get());
+            slot.i_batch = batch.get_n_tokens();
 
-            std::array<llama_token, 1> seq_id = { slot.id };
-            llama_batch_ext_add_text(batch.get(), slot.sampled, slot.n_past, seq_id.data(), seq_id.size(), true);
+            batch.add_text(slot.sampled, slot.n_past, slot.id, true);
 
             slot.n_past += 1;
 
@@ -2866,7 +2906,7 @@ struct server_context {
         int32_t n_ubatch = llama_n_ubatch(ctx);
 
         // next, batch any pending prompts without exceeding n_batch
-        if (params_base.cont_batching || llama_batch_ext_get_n_tokens(batch.get()) == 0) {
+        if (params_base.cont_batching || batch.get_n_tokens() == 0) {
             for (auto & slot : slots) {
                 // check if we can batch this slot with the previous one
                 if (slot.is_processing()) {
@@ -3032,7 +3072,7 @@ struct server_context {
                     // non-causal tasks require to fit the entire prompt in the physical batch
                     if (slot.is_non_causal()) {
                         // cannot fit the prompt in the current batch - will try next iter
-                        if (llama_batch_ext_get_n_tokens(batch.get()) + slot.n_prompt_tokens > n_batch) {
+                        if (batch.get_n_tokens() + slot.n_prompt_tokens > n_batch) {
                             continue;
                         }
                     }
@@ -3052,12 +3092,11 @@ struct server_context {
                     slot.cache_tokens.resize(slot.n_past);
 
                     // add prompt tokens for processing in the current batch
-                    while (slot.n_past < slot.n_prompt_tokens && llama_batch_ext_get_n_tokens(batch.get()) < n_batch) {
+                    while (slot.n_past < slot.n_prompt_tokens && batch.get_n_tokens() < n_batch) {
                         // without pooling, we want to output the embeddings for all the tokens in the batch
                         const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
 
-                        std::array<llama_token, 1> seq_id = { slot.id };
-                        llama_batch_ext_add_text(batch.get(), prompt_tokens[slot.n_past], slot.n_past, seq_id.data(), seq_id.size(), need_embd);
+                        batch.add_text(prompt_tokens[slot.n_past], slot.n_past, slot.id, need_embd);
 
                         if (slot.params.cache_prompt) {
                             slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -3067,13 +3106,13 @@ struct server_context {
                         slot.n_past++;
                     }
 
-                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, llama_batch_ext_get_n_tokens(batch.get()), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
+                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.get_n_tokens(), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
 
                     // entire prompt has been processed
                     if (slot.n_past == slot.n_prompt_tokens) {
                         slot.state = SLOT_STATE_DONE_PROMPT;
 
-                        GGML_ASSERT(llama_batch_ext_get_n_tokens(batch.get()) > 0);
+                        GGML_ASSERT(batch.get_n_tokens() > 0);
 
                         common_sampler_reset(slot.smpl);
 
@@ -3083,27 +3122,27 @@ struct server_context {
                         }
 
                         // extract the logits only for the last token
-                        llama_batch_ext_set_logits_last(batch.get());
+                        batch.set_logits_last();
 
                         slot.n_decoded = 0;
-                        slot.i_batch = llama_batch_ext_get_n_tokens(batch.get()) - 1;
+                        slot.i_batch = batch.get_n_tokens() - 1;
 
-                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, llama_batch_ext_get_n_tokens(batch.get()));
+                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.get_n_tokens());
                     }
                 }
 
-                if (llama_batch_ext_get_n_tokens(batch.get()) >= n_batch) {
+                if (batch.get_n_tokens() >= n_batch) {
                     break;
                 }
             }
         }
 
-        if (llama_batch_ext_get_n_tokens(batch.get()) == 0) {
+        if (batch.get_n_tokens() == 0) {
            SRV_WRN("%s", "no tokens to decode\n");
            return;
        }
 
-        SRV_DBG("decoding batch, n_tokens = %d\n", llama_batch_ext_get_n_tokens(batch.get()));
+        SRV_DBG("decoding batch, n_tokens = %d\n", batch.get_n_tokens());
 
         if (slot_batched) {
             // make sure we're in the right embedding mode
@@ -3113,12 +3152,12 @@ struct server_context {
         }
 
         // process the created batch of tokens
-        for (int32_t i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i += n_batch) {
-            const int32_t n_tokens = std::min(n_batch, llama_batch_ext_get_n_tokens(batch.get()) - i);
+        for (int32_t i = 0; i < batch.get_n_tokens(); i += n_batch) {
+            const int32_t n_tokens = std::min(n_batch, batch.get_n_tokens() - i);
 
-            llama_batch_ext_ptr batch_view(llama_batch_ext_get_view(batch.get(), i, n_tokens));
+            server_batch batch_view = batch.get_view(i, n_tokens);
 
-            const int ret = llama_decode_ext(ctx, batch_view.get());
+            const int ret = llama_decode_ext(ctx, batch_view.batch.get());
             metrics.on_decoded(slots);
 
             if (ret != 0) {
@@ -3253,17 +3292,16 @@ struct server_context {
            }
 
            // construct the speculation batch
-           llama_batch_ext_clear(slot.batch_spec.get());
-           std::array<llama_token, 1> seq_id = { slot.id };
-           llama_batch_ext_add_text(slot.batch_spec.get(), id, slot.n_past, seq_id.data(), seq_id.size(), true);
+           slot.batch_spec.clear();
+           slot.batch_spec.add_text(id, slot.n_past, slot.id, true);
 
            for (size_t i = 0; i < draft.size(); ++i) {
-               llama_batch_ext_add_text(slot.batch_spec.get(), draft[i], slot.n_past + 1, seq_id.data(), seq_id.size(), true);
+               slot.batch_spec.add_text(draft[i], slot.n_past + 1 + i, slot.id, true);
            }
 
-           SLT_DBG(slot, "decoding speculative batch, size = %d\n", llama_batch_ext_get_n_tokens(slot.batch_spec.get()));
+           SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.get_n_tokens());
 
-           llama_decode_ext(ctx, slot.batch_spec.get());
+           llama_decode_ext(ctx, slot.batch_spec.batch.get());
 
            // the accepted tokens from the speculation
            const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);