@@ -55,15 +55,11 @@ llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks) {
 
 // helper struct to make working with embd batch easier
 // note: this will be removed after llama_batch_ext refactoring
-// notes2: Normally, batch's `pos` stores linearly increasing position
-// However, some multi-modal models requires special position embedding (e.g. M-Rope in qwen2vl and qwen2.5vl)
-// But linearly increasing position is still needed for proper causal attention masking
-// So we store both of them: the first n_tokens elements are not changed, while model-specific positions are appended after that.
-// So `pos` has `n_tokens * (n_pos_per_embd + 1)` elements
 struct decode_embd_batch {
     int n_pos_per_embd;
     int n_mmproj_embd;
-    std::vector<llama_pos>      pos;
+    std::vector<llama_pos>      pos;      // for M-RoPE, this will have (1+n_pos_per_embd)*n_tokens elements
+                                          // the extra n_tokens are for linearly increasing positions
     std::vector<llama_pos>      pos_view; // used by mrope
     std::vector<int32_t>        n_seq_id;
     std::vector<llama_seq_id>   seq_id_0;
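As a quick sanity check on the sizing rule stated in the new comment, here is a small standalone sketch; the concrete values of n_pos_per_embd and n_tokens are illustrative assumptions, not taken from the code above.

    // Illustrative sizing only; the values below are assumptions for the example.
    int n_pos_per_embd = 4;   // e.g. M-RoPE models such as qwen2vl / qwen2.5vl
    int n_tokens       = 256; // number of embeddings in the batch
    // M-RoPE: n_pos_per_embd model-specific sections plus one extra linear section
    size_t pos_mrope = (size_t) (1 + n_pos_per_embd) * n_tokens; // 1280 elements
    // non-M-RoPE models: one linearly increasing position per token
    size_t pos_plain = (size_t) n_tokens;                        // 256 elements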
@@ -171,6 +167,59 @@ struct decode_embd_batch {
     }
 };
 
+// helper struct to make working with text batch easier
+struct decode_text_batch {
+    std::vector<llama_token>    tokens;
+    std::vector<int32_t>        n_seq_id;
+    std::vector<llama_seq_id>   seq_id_0;
+    std::vector<llama_seq_id *> seq_ids;
+    std::vector<int8_t>         logits;
+    llama_seq_id                seq_id;
+    llama_batch batch;
+    decode_text_batch(int32_t n_tokens, llama_seq_id seq_id) : seq_id(seq_id) {
+        tokens  .resize(n_tokens);
+        n_seq_id.resize(n_tokens);
+        seq_ids .resize(n_tokens + 1);
+        logits  .resize(n_tokens);
+        seq_ids[n_tokens] = nullptr;
+        for (int32_t i = 0; i < n_tokens; i++) {
+            n_seq_id[i] = 1;
+            seq_ids [i] = &this->seq_id;
+        }
+        batch = {
+            /*n_tokens       =*/ 0,
+            /*tokens         =*/ tokens.data(),
+            /*embd           =*/ nullptr,
+            /*pos            =*/ nullptr, // position is tracked automatically
+            /*n_seq_id       =*/ n_seq_id.data(),
+            /*seq_id         =*/ seq_ids.data(),
+            /*logits         =*/ logits.data(),
+        };
+    }
+
+    void clear() {
+        batch.n_tokens = 0;
+    }
+
+    bool is_full() const {
+        return batch.n_tokens >= (int32_t) tokens.size();
+    }
+
+    void add_token(llama_token tok, bool output) {
+        GGML_ASSERT(!is_full());
+        int32_t j = batch.n_tokens;
+        batch.token [j] = tok;
+        batch.logits[j] = output;
+        batch.n_tokens++;
+    }
+
+    void set_logits_last() {
+        if (batch.n_tokens > 0) {
+            batch.logits[batch.n_tokens - 1] = true;
+        }
+    }
+};
+
 // Helper function for decoding an image whose embeddings have already been calculated
 int32_t mtmd_helper_decode_image_chunk(
         mtmd_context * ctx,
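For orientation, the fill/decode cycle this struct supports (and that the later hunks switch to) looks roughly like the following standalone sketch; `lctx` and the token values are placeholders, while the member functions are exactly the ones added above.

    // Minimal usage sketch of decode_text_batch; lctx is assumed to be an
    // initialized llama_context and the token values are made up.
    decode_text_batch text_batch(/*n_tokens*/ 512, /*seq_id*/ 0);
    text_batch.clear();                         // start with an empty batch
    for (llama_token tok : {101, 102, 103}) {
        if (text_batch.is_full()) {
            break;                              // a caller would decode and clear here
        }
        text_batch.add_token(tok, /*output*/ false);
    }
    text_batch.set_logits_last();               // request logits for the final token only
    int32_t ret = llama_decode(lctx, text_batch.batch); // positions tracked automatically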
@@ -259,7 +308,7 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
         bool logits_last,
         llama_pos * new_n_past) {
     int32_t ret;
-    llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
+    decode_text_batch text_batch(n_batch, seq_id);
     auto chunk_type = mtmd_input_chunk_get_type(chunk);
 
     if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
@@ -268,28 +317,20 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
         // LOG_INF("decoding text chunk, n_tokens = %zu\n", n_tokens);
         size_t i = 0;
         while (i < n_tokens) { // split into batches
-            text_batch.n_tokens = 0; // clear the batch
-            for (; i < n_tokens && text_batch.n_tokens < n_batch; i++) {
-                int32_t j = text_batch.n_tokens;
-                text_batch.token   [j]    = tokens[i];
-                text_batch.pos     [j]    = n_past++;
-                text_batch.n_seq_id[j]    = 1;
-                text_batch.seq_id  [j][0] = seq_id;
-                text_batch.logits  [j]    = false;
-
-                text_batch.n_tokens++;
+            text_batch.clear();
+            for (; i < n_tokens && !text_batch.is_full(); i++) {
+                text_batch.add_token(tokens[i], false);
             }
             bool is_last_token = (i == n_tokens);
             if (logits_last && is_last_token) {
-                text_batch.logits[text_batch.n_tokens - 1] = true;
+                text_batch.set_logits_last();
             }
-            ret = llama_decode(lctx, text_batch);
+            ret = llama_decode(lctx, text_batch.batch);
             if (ret != 0) {
                 LOG_ERR("failed to decode text\n");
-                llama_batch_free(text_batch);
                 return ret;
             }
-            *new_n_past += text_batch.n_tokens;
+            *new_n_past += text_batch.batch.n_tokens;
         }
 
     } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE || chunk_type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
@@ -301,7 +342,6 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
         ret = mtmd_encode_chunk(ctx, chunk);
         if (ret != 0) {
             LOG_ERR("failed to encode %s slice\n", name);
-            llama_batch_free(text_batch);
             return ret;
         }
 
@@ -311,14 +351,12 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
         ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past);
         if (ret != 0) {
             LOG_ERR("failed to decode %s\n", name);
-            llama_batch_free(text_batch);
             return ret;
         }
     } else {
         GGML_ABORT("chunk type not supported");
     }
 
-    llama_batch_free(text_batch);
     return 0;
 }
 
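Finally, a hypothetical caller of mtmd_helper_eval_chunk_single, with the parameter order inferred from the identifiers visible in this diff; treat it as a sketch, not canonical API usage.

    // Hypothetical caller sketch; ctx, lctx and chunk are assumed to already exist.
    llama_pos n_past     = 0;
    llama_pos new_n_past = 0;
    int32_t ret = mtmd_helper_eval_chunk_single(
            ctx, lctx, chunk,
            n_past,
            /*seq_id*/      0,
            /*n_batch*/     512,
            /*logits_last*/ true,
            &new_n_past);
    if (ret == 0) {
        n_past = new_n_past; // advance past the decoded chunk
    }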