@@ -257,6 +257,8 @@ bool llama_batch_allocr::init(
             continue;
         }
 
+        // @fmayran: these checks don't make sense for models with position encodings such as Qwen VL's, because the positions stored in the KV cache can jump around (they are not even always increasing).
+        // It is not enough to allow them to repeat: within an image embedding, arbitrary jumps are expected.
         // const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
         //
         // if (p0 >= 0) {
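
For context, here is a minimal standalone sketch of why the removed check is too strict. The position values are illustrative, not taken from a real model: with M-RoPE-style multimodal position encodings (as in Qwen VL), the positions streamed into the KV cache for an image embedding can legitimately jump backwards in flattened token order, so any monotonicity assertion misfires.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

using llama_pos = int32_t;

// The removed check effectively required positions within a sequence to be
// non-decreasing (hypothetical reconstruction of the commented-out logic).
static bool positions_non_decreasing(const std::vector<llama_pos> & pos) {
    for (size_t i = 1; i < pos.size(); ++i) {
        if (pos[i] < pos[i - 1]) {
            return false;
        }
    }
    return true;
}

int main() {
    // Plain text: positions increase one by one, the check holds.
    const std::vector<llama_pos> text_pos  = { 0, 1, 2, 3 };

    // Illustrative image-embedding positions: per-axis components reset at
    // every patch row, so the flattened stream jumps backwards even though
    // the batch is perfectly valid.
    const std::vector<llama_pos> image_pos = { 4, 5, 6, 4, 5, 6, 4, 5, 6 };

    // The old check accepts the text batch but would wrongly reject the
    // image batch - exactly the false positive the removal avoids.
    return positions_non_decreasing(text_pos) &&
          !positions_non_decreasing(image_pos) ? 0 : 1;
}
```
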
@@ -370,37 +372,38 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t
 
     auto udata = std::make_shared<llama_ubatch::data_t>();
 
-    udata->token     .resize(n_tokens);
-    udata->embd      .clear();
-    udata->pos       .resize(n_tokens);
-    udata->n_seq_id  .resize(n_tokens);
-    udata->seq_id    .resize(n_tokens);
-    udata->seq_id_unq.resize(0);
-    udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
-    udata->output    .resize(n_tokens);
+    udata->token               .resize(n_tokens);
+    udata->embd                .clear();
+    udata->pos                 .resize(n_tokens);
+    udata->n_seq_id            .resize(n_tokens);
+    udata->seq_id              .resize(n_tokens);
+    udata->seq_id_unq          .resize(0);
+    udata->seq_idx             .resize(LLAMA_MAX_SEQ, -1);
+    udata->output              .resize(n_tokens);
+    udata->kv_position_of_token.resize(n_tokens, -1);
 
     for (uint32_t s = 0; s < n_seqs; ++s) {
         udata->seq_idx[s] = s;
         udata->seq_id_unq.push_back(s);
     }
 
     llama_ubatch res {
-        /*.b_equal_seqs         =*/ true,
-        /*.n_tokens             =*/ n_tokens,
-        /*.n_seq_tokens         =*/ n_seq_tokens,
-        /*.n_seqs               =*/ n_seqs,
-        /*.n_seqs_unq           =*/ n_seqs,
-
-        /*.token                =*/ udata->token.data(),
-        /*.embd                 =*/ nullptr,
-        /*.pos                  =*/ udata->pos.data(),
-        /*.n_seq_id             =*/ udata->n_seq_id.data(),
-        /*.seq_id               =*/ udata->seq_id.data(),
-        /*.seq_id_unq           =*/ udata->seq_id_unq.data(),
-        /*.seq_idx              =*/ udata->seq_idx.data(),
-        /*.output               =*/ udata->output.data(),
-        /*.data                 =*/ std::move(udata),
-        /*.kv_position_of_token =*/ {},
+        /*.b_equal_seqs         =*/ true,
+        /*.n_tokens             =*/ n_tokens,
+        /*.n_seq_tokens         =*/ n_seq_tokens,
+        /*.n_seqs               =*/ n_seqs,
+        /*.n_seqs_unq           =*/ n_seqs,
+
+        /*.token                =*/ udata->token.data(),
+        /*.embd                 =*/ nullptr,
+        /*.pos                  =*/ udata->pos.data(),
+        /*.n_seq_id             =*/ udata->n_seq_id.data(),
+        /*.seq_id               =*/ udata->seq_id.data(),
+        /*.seq_id_unq           =*/ udata->seq_id_unq.data(),
+        /*.seq_idx              =*/ udata->seq_idx.data(),
+        /*.output               =*/ udata->output.data(),
+        /*.kv_position_of_token =*/ udata->kv_position_of_token.data(),
+        /*.data                 =*/ std::move(udata),
     };
 
     return res;
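
Note on the reordering of the last two initializers: the clauses of a braced init list are evaluated left to right, and moving `udata` into `.data` empties the `shared_ptr`, so `udata->kv_position_of_token.data()` has to be read first. A minimal sketch of the pattern, using simplified stand-in types (the member layout below is an assumption implied by the diff, not the real `llama_ubatch` definition):

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <vector>

// Simplified stand-ins for the llama.cpp types; the member order (raw
// pointer declared before the owning shared_ptr) is assumed from the diff.
struct data_t {
    std::vector<int32_t> kv_position_of_token;
};

struct ubatch_t {
    int32_t               * kv_position_of_token;
    std::shared_ptr<data_t> data;
};

int main() {
    auto udata = std::make_shared<data_t>();
    udata->kv_position_of_token.resize(4, -1);

    // Braced-list initializers are evaluated left to right: the raw pointer
    // is taken while udata still owns the storage, and only then is udata
    // moved into .data. With the opposite order (.data first), the second
    // initializer would dereference a shared_ptr already emptied by the move.
    ubatch_t res {
        /*.kv_position_of_token =*/ udata->kv_position_of_token.data(),
        /*.data                 =*/ std::move(udata),
    };

    // Still valid: moving a shared_ptr transfers ownership without
    // relocating the pointed-to data_t, so the pointer is not dangling.
    assert(res.kv_position_of_token[0] == -1);
    assert(res.data != nullptr);
    return 0;
}
```
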
@@ -662,14 +665,15 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
     const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
     const int64_t n_pos_all  = (int64_t) n_tokens*n_pos_cur;
 
-    udata->token     .resize(n_tokens);
-    udata->embd      .resize(n_embd_all);
-    udata->pos       .resize(n_pos_all);
-    udata->n_seq_id  .resize(n_tokens);
-    udata->seq_id    .resize(n_tokens);
-    udata->seq_id_unq.resize(0);
-    udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
-    udata->output    .resize(n_tokens);
+    udata->token               .resize(n_tokens);
+    udata->embd                .resize(n_embd_all);
+    udata->pos                 .resize(n_pos_all);
+    udata->n_seq_id            .resize(n_tokens);
+    udata->seq_id              .resize(n_tokens);
+    udata->seq_id_unq          .resize(0);
+    udata->seq_idx             .resize(LLAMA_MAX_SEQ, -1);
+    udata->output              .resize(n_tokens);
+    udata->kv_position_of_token.resize(n_tokens, -1);
 
     seq_set_t seq_set_unq;
 
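Both allocation sites fill the new array with -1 rather than 0. Presumably (the consuming code is outside this hunk, so this reading is an assumption) -1 is a "no KV cell assigned yet" sentinel, which keeps unfilled entries distinguishable from the legitimate KV position 0:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const uint32_t n_tokens = 8;

    // Same pattern as the diff: every token starts out with no KV cell.
    std::vector<int32_t> kv_position_of_token(n_tokens, -1);

    // Pretend the first half of the batch has been placed in the cache
    // (hypothetical cell indices, purely for illustration).
    for (uint32_t i = 0; i < n_tokens/2; ++i) {
        kv_position_of_token[i] = (int32_t) (100 + i);
    }

    // Position 0 is a valid cell index, so -1 (not 0) must mark "unassigned".
    for (uint32_t i = 0; i < n_tokens; ++i) {
        if (kv_position_of_token[i] == -1) {
            printf("token %u has no KV cell yet\n", i);
        }
    }
    return 0;
}
```
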
@@ -707,22 +711,23 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
     }
 
     llama_ubatch res {
-        /*.b_equal_seqs         =*/ equal_seqs,
-        /*.n_tokens             =*/ n_tokens,
-        /*.n_seq_tokens         =*/ n_tokens/n_seqs,
-        /*.n_seqs               =*/ n_seqs,
-        /*.n_seqs_unq           =*/ (uint32_t) udata->seq_id_unq.size(),
-
-        /*.token                =*/ batch.token ? udata->token.data() : nullptr,
-        /*.embd                 =*/ batch.embd  ? udata->embd .data() : nullptr,
-        /*.pos                  =*/ udata->pos.data(),
-        /*.n_seq_id             =*/ udata->n_seq_id.data(),
-        /*.seq_id               =*/ udata->seq_id.data(),
-        /*.seq_id_unq           =*/ udata->seq_id_unq.data(),
-        /*.seq_idx              =*/ udata->seq_idx.data(),
-        /*.output               =*/ udata->output.data(),
-        /*.data                 =*/ std::move(udata),
-        /*.kv_position_of_token =*/ {},
+        /*.b_equal_seqs         =*/ equal_seqs,
+        /*.n_tokens             =*/ n_tokens,
+        /*.n_seq_tokens         =*/ n_tokens/n_seqs,
+        /*.n_seqs               =*/ n_seqs,
+        /*.n_seqs_unq           =*/ (uint32_t) udata->seq_id_unq.size(),
+
+        /*.token                =*/ batch.token ? udata->token.data() : nullptr,
+        /*.embd                 =*/ batch.embd  ? udata->embd .data() : nullptr,
+        /*.pos                  =*/ udata->pos.data(),
+        /*.n_seq_id             =*/ udata->n_seq_id.data(),
+        /*.seq_id               =*/ udata->seq_id.data(),
+        /*.seq_id_unq           =*/ udata->seq_id_unq.data(),
+        /*.seq_idx              =*/ udata->seq_idx.data(),
+        /*.output               =*/ udata->output.data(),
+        /*.kv_position_of_token =*/ udata->kv_position_of_token.data(),
+        /*.data                 =*/ std::move(udata),
+
     };
 
     if (debug > 0) {