@@ -1860,7 +1860,8 @@ struct server_context {
 
     llama_context_params cparams_dft;
 
-    server_batch batch;
+    llama_batch batch;
+    server_batch_embd batch_embd;
 
     bool clean_kv_cache = true;
     bool add_bos_token = true;
@@ -1898,6 +1899,8 @@ struct server_context {
 
             llama_batch_free(slot.batch_spec);
         }
+
+        llama_batch_free(batch);
     }
 
     bool load_model(const common_params & params) {
@@ -2034,7 +2037,8 @@ struct server_context {
         // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
         {
             const int32_t n_batch = llama_n_batch(ctx);
-            batch = server_batch(std::max(n_batch, params_base.n_parallel));
+            batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
+            batch_embd = server_batch_embd(std::max(n_batch, params_base.n_parallel));
         }
 
         metrics.init();
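Aside, not part of the patch: with the `server_batch` wrapper gone, the raw `llama_batch` above is allocated with `llama_batch_init` and must be released with `llama_batch_free` (as the destructor change earlier does). A minimal sketch of that lifecycle; `batch_guard` and the capacity value are hypothetical, only the two `llama_batch_*` calls are real API:

```cpp
#include "llama.h"

// Illustrative only: a tiny RAII holder pairing llama_batch_init with
// llama_batch_free, mirroring what the constructor/destructor changes do by hand.
struct batch_guard {
    llama_batch batch;

    explicit batch_guard(int32_t n_tokens_max)
        // 0 -> token ids only, no raw embeddings
        // 1 -> at most one sequence id per token
        : batch(llama_batch_init(n_tokens_max, 0, 1)) {}

    ~batch_guard() {
        llama_batch_free(batch);
    }

    // non-copyable: freeing the same buffers twice would be undefined behavior
    batch_guard(const batch_guard &) = delete;
    batch_guard & operator=(const batch_guard &) = delete;
};
```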
@@ -2931,7 +2935,7 @@ struct server_context {
         }*/
 
         // start populating the batch for this iteration
-        batch.clear();
+        common_batch_clear(batch);
 
         // track if given slot can be batched with slots already in the batch
         server_slot * slot_batched = nullptr;
@@ -2953,9 +2957,9 @@ struct server_context {
                 continue;
             }
 
-            slot.i_batch = batch.n_tokens();
+            slot.i_batch = batch.n_tokens;
 
-            common_batch_add(batch.batch, slot.sampled, slot.n_past, { slot.id }, true);
+            common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
 
             slot.n_past += 1;
 
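Aside, not part of the patch: the `common_batch_clear` / `common_batch_add` helpers used above come from `common/common.h`. A hedged sketch of the per-iteration pattern; `slot_view` and `queue_sampled_tokens` are hypothetical stand-ins for the relevant `server_slot` fields, not code from the PR:

```cpp
#include <vector>

#include "common.h" // common_batch_clear, common_batch_add
#include "llama.h"

// Hypothetical, minimal stand-in for the server_slot fields used here.
struct slot_view {
    llama_token  sampled; // last sampled token
    llama_pos    n_past;  // current position in the sequence
    llama_seq_id id;      // sequence id of this slot
    int32_t      i_batch; // index of this slot's token inside the batch
};

// Rebuild the batch for one generation step: one sampled token per active slot,
// each requesting logits so the next token can be sampled from its output.
static void queue_sampled_tokens(llama_batch & batch, std::vector<slot_view> & slots) {
    common_batch_clear(batch); // n_tokens back to 0, buffers are reused

    for (auto & slot : slots) {
        slot.i_batch = batch.n_tokens; // remember where this slot's logits will land
        common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
        slot.n_past += 1;
    }
}
```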
@@ -2972,7 +2976,7 @@ struct server_context {
         int32_t n_ubatch = llama_n_ubatch(ctx);
 
         // next, batch any pending prompts without exceeding n_batch
-        if (params_base.cont_batching || batch.n_tokens() == 0) {
+        if (params_base.cont_batching || batch.n_tokens == 0) {
             for (auto & slot : slots) {
                 // check if we can batch this slot with the previous one
                 if (slot.is_processing()) {
@@ -3140,7 +3144,7 @@ struct server_context {
                     // non-causal tasks require to fit the entire prompt in the physical batch
                     if (slot.is_non_causal()) {
                         // cannot fit the prompt in the current batch - will try next iter
-                        if (batch.n_tokens() + slot.n_prompt_tokens > n_batch) {
+                        if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
                             continue;
                         }
                     }
@@ -3163,28 +3167,26 @@ struct server_context {
 
                     // check if we should process the image
                     if (curr_chunk.tok_image) {
-                        if (batch.has_text()) {
-                            continue; // we cannot have both text batch and image batch
+                        // process the image
+                        int32_t res = server_img_process(ctx, mctx, curr_chunk, batch_embd, slot.n_past, slot.id);
+                        if (res != 0) {
+                            SLT_ERR(slot, "failed to process image, res = %d\n", res);
+                            slot.release();
+                            send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
+                            continue;
                         }
 
-                        // encode the image
-                        server_encode_image(slot.mctx, batch, curr_chunk, slot.n_past, slot.id);
-                        GGML_ASSERT(batch.has_embd());
-                        SLT_INF(slot, "image encoded, n_past = %d, n_embd_tokens = %d\n", slot.n_past, batch.n_tokens());
-
                         if (slot.params.cache_prompt) {
                             slot.cache_tokens.add_image_tokens(curr_chunk.tok_image);
                         }
 
-                        slot.n_past += batch.n_tokens();
-                        slot.n_prompt_tokens_processed += batch.n_tokens();
-
-                        break; // currently, we can only process one image at a time, so we skip ALL other slots
+                        slot.n_past += curr_chunk.n_tokens;
+                        slot.n_prompt_tokens_processed += curr_chunk.n_tokens;
                     }
 
                     // add prompt tokens for processing in the current batch
-                    while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens() < n_batch) {
-                        GGML_ASSERT(!batch.has_embd());
+                    while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
+                        // get next token to process
                         auto & curr_chunk = slot.prompt_tokens.get_chunk(slot.n_past);
                         if (curr_chunk.tok_text == LLAMA_TOKEN_NULL) {
                             break; // end of text chunk
@@ -3193,7 +3195,7 @@ struct server_context {
                         // without pooling, we want to output the embeddings for all the tokens in the batch
                         const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
 
-                        common_batch_add(batch.batch, curr_chunk.tok_text, slot.n_past, { slot.id }, need_embd);
+                        common_batch_add(batch, curr_chunk.tok_text, slot.n_past, { slot.id }, need_embd);
                         if (slot.params.cache_prompt) {
                             slot.cache_tokens.add_text_token(curr_chunk.tok_text);
                         }
@@ -3204,47 +3206,47 @@ struct server_context {
 
                     // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str());
 
-                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens(), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
+                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
 
                     // entire prompt has been processed
                     if (slot.n_past == slot.n_prompt_tokens) {
                         slot.state = SLOT_STATE_DONE_PROMPT;
 
-                        GGML_ASSERT(batch.n_tokens() > 0);
+                        GGML_ASSERT(batch.n_tokens > 0);
 
                         common_sampler_reset(slot.smpl);
 
                         // Process all prompt tokens through sampler system
                         for (size_t i = 0; i < slot.cache_tokens.n_tokens(); ++i) {
-                            auto & curr_chunk = slot.cache_tokens.get_chunk(i);
+                            auto & curr_chunk = slot.prompt_tokens.get_chunk(i);
                             if (curr_chunk.tok_text != LLAMA_TOKEN_NULL) {
                                 common_sampler_accept(slot.smpl, curr_chunk.tok_text, false);
                             }
                         }
 
                         // extract the logits only for the last token
-                        batch.logits[batch.n_tokens() - 1] = true;
+                        batch.logits[batch.n_tokens - 1] = true;
 
                         slot.n_decoded = 0;
-                        slot.i_batch = batch.n_tokens() - 1;
+                        slot.i_batch = batch.n_tokens - 1;
 
-                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens());
+                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
                     }
                 }
 
-                if (batch.n_tokens() >= n_batch) {
+                if (batch.n_tokens >= n_batch) {
                     break;
                 }
             }
         }
 
-        if (batch.n_tokens() == 0) {
+        if (batch.n_tokens == 0) {
             SRV_WRN("%s", "no tokens to decode\n");
             return;
         }
 
         // debug
-        SRV_DBG("decoding %s batch, n_tokens = %d\n", batch.has_embd() ? "embd" : "text", batch.n_tokens());
+        SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
 
         if (slot_batched) {
             // make sure we're in the right embedding mode
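Aside, not part of the patch: once the entire prompt is batched, the hunk above rebuilds the sampler state by replaying the prompt tokens (the fix also switches the replay source from `cache_tokens` to `prompt_tokens`). A simplified sketch using the `common_sampler_*` helpers from `common/sampling.h`; `replay_prompt` and the plain token vector are hypothetical stand-ins for the chunked prompt container:

```cpp
#include <vector>

#include "llama.h"
#include "sampling.h" // common_sampler_reset, common_sampler_accept

// Replay the prompt through the sampler so repetition penalties and similar
// state see the full context before generation starts. The token vector is a
// stand-in for slot.prompt_tokens.
static void replay_prompt(common_sampler * smpl, const std::vector<llama_token> & prompt) {
    common_sampler_reset(smpl); // drop state left over from the previous request

    for (llama_token tok : prompt) {
        // false: do not advance grammar state while replaying the prompt
        common_sampler_accept(smpl, tok, false);
    }
}
```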
@@ -3254,32 +3256,22 @@ struct server_context {
         }
 
         // process the created batch of tokens
-        for (int32_t i = 0; i < batch.n_tokens(); i += n_batch) {
-            const int32_t n_tokens = std::min(n_batch, batch.n_tokens() - i);
+        for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
+            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
 
-            // TODO @ngxson : hacky here, we don't want to split the embd batch
-            llama_batch batch_view = batch.has_embd() ? batch.batch : llama_batch{
+            llama_batch batch_view = llama_batch{
                 n_tokens,
-                batch.batch.token + i,
+                batch.token + i,
                 nullptr,
-                batch.batch.pos + i,
-                batch.batch.n_seq_id + i,
-                batch.batch.seq_id + i,
-                batch.batch.logits + i,
+                batch.pos + i,
+                batch.n_seq_id + i,
+                batch.seq_id + i,
+                batch.logits + i,
             };
 
-            // TODO @ngxson : maybe move this to llama_batch_ext
-            if (batch.has_embd() && mtmd_decode_use_non_causal(mctx)) {
-                llama_set_causal_attn(ctx, false);
-            }
-
             const int ret = llama_decode(ctx, batch_view);
             metrics.on_decoded(slots);
 
-            if (batch.has_embd() && mtmd_decode_use_non_causal(mctx)) {
-                llama_set_causal_attn(ctx, true);
-            }
-
             if (ret != 0) {
                 if (n_batch == 1 || ret < 0) {
                     // if you get here, it means the KV cache is full - try increasing it via the context size
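Aside, not part of the patch: the `batch_view` construction above is pure pointer arithmetic over the parent batch's arrays, so nothing is copied when the batch is split. A self-contained sketch of the same slicing loop, with error handling reduced to returning the `llama_decode` result; `decode_in_chunks` is a hypothetical name, not a helper from the PR:

```cpp
#include <algorithm>

#include "llama.h"

// Decode a (possibly large) token batch in windows of at most n_batch tokens.
// Each view borrows the parent batch's arrays at offset i; nothing is copied.
static int decode_in_chunks(llama_context * ctx, const llama_batch & batch, int32_t n_batch) {
    for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
        const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);

        llama_batch batch_view = {
            n_tokens,
            batch.token    + i, // token ids for this window
            nullptr,            // no raw embeddings in the text batch
            batch.pos      + i,
            batch.n_seq_id + i,
            batch.seq_id   + i,
            batch.logits   + i,
        };

        const int ret = llama_decode(ctx, batch_view);
        if (ret != 0) {
            return ret; // < 0: hard error, > 0: e.g. no KV cache slot for the batch
        }
    }

    return 0;
}
```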