@@ -1859,7 +1859,7 @@ struct server_context {
 
     llama_context_params cparams_dft;
 
-    llama_batch batch = {};
+    server_batch batch;
 
     bool clean_kv_cache = true;
     bool add_bos_token  = true;
@@ -1897,8 +1897,6 @@ struct server_context {
 
             llama_batch_free(slot.batch_spec);
         }
-
-        llama_batch_free(batch);
     }
 
     bool load_model(const common_params & params) {
@@ -2035,9 +2033,7 @@ struct server_context {
         // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
         {
             const int32_t n_batch = llama_n_batch(ctx);
-
-            // only a single seq_id per token is needed
-            batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
+            batch = server_batch(std::max(n_batch, params_base.n_parallel));
         }
 
         metrics.init();
@@ -2934,7 +2930,7 @@ struct server_context {
         }*/
 
         // start populating the batch for this iteration
-        common_batch_clear(batch);
+        batch.clear();
 
         // track if given slot can be batched with slots already in the batch
         server_slot * slot_batched = nullptr;
@@ -2956,9 +2952,9 @@ struct server_context {
                 continue;
             }
 
-            slot.i_batch = batch.n_tokens;
+            slot.i_batch = batch.n_tokens();
 
-            common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
+            common_batch_add(batch.batch, slot.sampled, slot.n_past, { slot.id }, true);
 
             slot.n_past += 1;
 
@@ -2974,12 +2970,8 @@ struct server_context {
         int32_t n_batch  = llama_n_batch(ctx);
         int32_t n_ubatch = llama_n_ubatch(ctx);
 
-        // for multimodal
-        bool is_decoding_embd = false;
-        server_embd_batch batch_embd;
-
         // next, batch any pending prompts without exceeding n_batch
-        if (params_base.cont_batching || batch.n_tokens == 0) {
+        if (params_base.cont_batching || batch.n_tokens() == 0) {
             for (auto & slot : slots) {
                 // check if we can batch this slot with the previous one
                 if (slot.is_processing()) {
@@ -3147,7 +3139,7 @@ struct server_context {
                     // non-causal tasks require to fit the entire prompt in the physical batch
                     if (slot.is_non_causal()) {
                         // cannot fit the prompt in the current batch - will try next iter
-                        if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
+                        if (batch.n_tokens() + slot.n_prompt_tokens > n_batch) {
                             continue;
                         }
                     }
@@ -3167,36 +3159,55 @@ struct server_context {
                     slot.cache_tokens.keep_until(slot.n_past);
 
                     // add prompt tokens for processing in the current batch
-                    while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
+                    while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens() < n_batch) {
                         // without pooling, we want to output the embeddings for all the tokens in the batch
                         const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
 
                         auto & curr_chunk = slot.prompt_tokens.get_chunk(slot.n_past);
                         if (curr_chunk.tok_image) {
-                            // decode image
-                            server_encode_image(slot.mctx, batch_embd, curr_chunk, slot.n_past, slot.id);
-                            is_decoding_embd = true;
-                            SLT_INF(slot, "decoding image, n_past = %d, n_tokens = %d\n", slot.n_past, batch_embd.batch.n_tokens);
-                            slot.n_past += batch_embd.batch.n_tokens;
-                            break; // do not process any other slots
+                            // if there are already TEXT tokens in the batch, we need to process them first
+                            if (batch.batch.n_tokens > 0) {
+                                break;
+                            }
+                            // encode the image
+                            server_encode_image(slot.mctx, batch, curr_chunk, slot.n_past, slot.id);
+                            GGML_ASSERT(batch.has_embd());
+                            SLT_INF(slot, "image encoded, n_past = %d, n_embd_tokens = %d\n", slot.n_past, batch.n_tokens());
+
+                            if (slot.params.cache_prompt) {
+                                slot.cache_tokens.add_image_tokens(curr_chunk.tok_image);
+                            }
+
+                            slot.n_past += batch.n_tokens();
+                            slot.n_prompt_tokens_processed += batch.n_tokens();
+                            break; // we cannot have both text batch and image batch
+
                         } else {
-                            common_batch_add(batch, curr_chunk.tok_text, slot.n_past, { slot.id }, need_embd);
+                            GGML_ASSERT(!batch.has_embd());
+                            common_batch_add(batch.batch, curr_chunk.tok_text, slot.n_past, { slot.id }, need_embd);
                             if (slot.params.cache_prompt) {
                                 slot.cache_tokens.add_text_token(curr_chunk.tok_text);
                             }
+
+                            slot.n_prompt_tokens_processed++;
+                            slot.n_past++;
                         }
+                    }
+
+                    SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str());
 
-                        slot.n_prompt_tokens_processed++;
-                        slot.n_past++;
+                    if (batch.has_embd()) {
+                        // currently, we can only process one image at a time, so we skip other slots
+                        break;
                     }
 
-                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
+                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens(), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
 
                     // entire prompt has been processed
                     if (slot.n_past == slot.n_prompt_tokens) {
                         slot.state = SLOT_STATE_DONE_PROMPT;
 
-                        GGML_ASSERT(batch.n_tokens > 0);
+                        GGML_ASSERT(batch.n_tokens() > 0);
 
                         common_sampler_reset(slot.smpl);
 
@@ -3209,27 +3220,32 @@ struct server_context {
                         }
 
                         // extract the logits only for the last token
-                        batch.logits[batch.n_tokens - 1] = true;
+                        batch.logits[batch.n_tokens() - 1] = true;
 
                         slot.n_decoded = 0;
-                        slot.i_batch   = batch.n_tokens - 1;
+                        slot.i_batch   = batch.n_tokens() - 1;
 
-                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
+                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens());
                     }
                 }
 
-                if (batch.n_tokens >= n_batch) {
+                if (batch.n_tokens() >= n_batch) {
                     break;
                 }
             }
         }
 
-        if (batch.n_tokens == 0) {
+        if (batch.n_tokens() == 0) {
             SRV_WRN("%s", "no tokens to decode\n");
             return;
         }
 
-        SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
+        // debug
+        if (batch.has_embd()) {
+            SRV_INF("decoding embd batch, n_tokens = %d\n", batch.n_tokens());
+        } else {
+            SRV_INF("decoding batch, n_tokens = %d\n", batch.n_tokens());
+        }
 
         if (slot_batched) {
             // make sure we're in the right embedding mode
@@ -3239,28 +3255,29 @@ struct server_context {
         }
 
         // process the created batch of tokens
-        for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
-            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
+        for (int32_t i = 0; i < batch.n_tokens(); i += n_batch) {
+            const int32_t n_tokens = std::min(n_batch, batch.n_tokens() - i);
 
-            llama_batch batch_view = is_decoding_embd ? batch_embd.batch : llama_batch{
+            // TODO @ngxson : hacky here, we don't want to split the embd batch
+            llama_batch batch_view = batch.has_embd() ? batch.batch : llama_batch{
                 n_tokens,
-                batch.token    + i,
+                batch.batch.token    + i,
                 nullptr,
-                batch.pos      + i,
-                batch.n_seq_id + i,
-                batch.seq_id   + i,
-                batch.logits   + i,
+                batch.batch.pos      + i,
+                batch.batch.n_seq_id + i,
+                batch.batch.seq_id   + i,
+                batch.batch.logits   + i,
             };
 
             // TODO @ngxson : maybe move this to llama_batch_ext
-            if (is_decoding_embd && mtmd_decode_use_non_causal(mctx)) {
+            if (batch.has_embd() && mtmd_decode_use_non_causal(mctx)) {
                 llama_set_causal_attn(ctx, false);
             }
 
             const int ret = llama_decode(ctx, batch_view);
             metrics.on_decoded(slots);
 
-            if (is_decoding_embd && mtmd_decode_use_non_causal(mctx)) {
+            if (batch.has_embd() && mtmd_decode_use_non_causal(mctx)) {
                 llama_set_causal_attn(ctx, true);
             }
 
@@ -4006,13 +4023,13 @@ int main(int argc, char ** argv) {
                     /* add_special */   true,
                     /* parse_special */ true,
                 };
-                mtmd_input_chunks * tokenized = mtmd_tokenize(ctx_server.mctx, inp_txt, bitmaps);
-                if (!tokenized) {
+                mtmd_input_chunks chunks;
+                int32_t tokenized = mtmd_tokenize(ctx_server.mctx, chunks, inp_txt, bitmaps);
+                if (tokenized != 0) {
                     throw std::runtime_error("Failed to tokenize prompt");
                 }
-                server_inputs tmp(tokenized);
+                server_inputs tmp(chunks);
                 inputs.push_back(std::move(tmp));
-                mtmd_input_chunks_free(tokenized, false); // only free the container, not the images
             }
         } else {
             // non-multimodal version
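
Note: the server_batch type referenced throughout this diff is not defined in these hunks. Based only on the call sites above (clear(), n_tokens(), has_embd(), and the inner batch member passed to common_batch_add and llama_decode), a minimal sketch could look roughly like the code below. The member names, the embedding storage, and the ownership details are assumptions, not the commit's actual implementation; the real type may also expose more (for example the logits array accessed directly in the hunk around new line 3223).

// Hypothetical sketch of the server_batch wrapper assumed by the diff above.
#include <cstdint>
#include <vector>

#include "llama.h"

struct server_batch {
    llama_batch batch = {};     // underlying batch passed to common_batch_add()/llama_decode()
    std::vector<float> embd;    // image embeddings; non-empty only after an image was encoded

    server_batch() = default;

    // mirrors the old llama_batch_init(n, 0, 1) call: only a single seq_id per token is needed
    explicit server_batch(int32_t n_tokens_max) {
        batch = llama_batch_init(n_tokens_max, 0, 1);
    }

    bool has_embd() const {
        return !embd.empty();
    }

    // assumes server_encode_image() also sets batch.n_tokens for the embedding case
    int32_t n_tokens() const {
        return batch.n_tokens;
    }

    void clear() {
        batch.n_tokens = 0;
        embd.clear();
    }

    // NOTE: the commit removes the explicit llama_batch_free(batch) from ~server_context,
    // so the real wrapper presumably also owns and frees the llama_batch; ownership and
    // move semantics are omitted from this sketch.
};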