@@ -55,6 +55,11 @@ llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks) {
 
 // helper struct to make working with embd batch easier
 // note: this will be removed after llama_batch_ext refactoring
+// note 2: normally, a batch's `pos` stores linearly increasing positions.
+// However, some multimodal models require a special position embedding (e.g. M-RoPE in qwen2vl and qwen2.5vl),
+// while linearly increasing positions are still needed for proper causal attention masking.
+// So we store both: the first n_tokens elements are unchanged, while the model-specific positions are appended after them.
+// As a result, `pos` has `n_tokens * (n_pos_per_embd + 1)` elements.
 struct decode_embd_batch {
     int n_pos_per_embd;
     int n_mmproj_embd;
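
To make the new layout concrete, here is a minimal standalone sketch (an illustration, not part of the patch; the M-RoPE dimension names are an assumption based on the image case below) of where each section of the enlarged `pos` buffer begins:

```cpp
#include <cstdio>

int main() {
    const int n_tokens       = 3;
    const int n_pos_per_embd = 4; // assumed M-RoPE dims: temporal, height, width, (unused)

    // total size matches the resize() in the next hunk: n_tokens * (n_pos_per_embd + 1)
    printf("pos has %d elements\n", n_tokens * (n_pos_per_embd + 1));

    // section 0 holds the linear positions used for causal masking;
    // sections 1..n_pos_per_embd hold the model-specific positions
    for (int s = 0; s <= n_pos_per_embd; s++) {
        printf("section %d starts at pos[%d]\n", s, s * n_tokens);
    }
    return 0;
}
```
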
@@ -66,7 +71,7 @@ struct decode_embd_batch {
     std::vector<int8_t>         logits;
     llama_batch batch;
     decode_embd_batch(float * embd, int32_t n_tokens, int n_pos_per_embd, int n_mmproj_embd) : n_pos_per_embd(n_pos_per_embd), n_mmproj_embd(n_mmproj_embd) {
-        pos     .resize(n_tokens * n_pos_per_embd);
+        pos     .resize(n_tokens * (n_pos_per_embd + 1));
         n_seq_id.resize(n_tokens);
         seq_ids .resize(n_tokens + 1);
         logits  .resize(n_tokens);
@@ -100,13 +105,14 @@ struct decode_embd_batch {
         for (int y = 0; y < ny; y++) {
             for (int x = 0; x < nx; x++) {
                 int i = y * nx + x;
-                pos[i                     ] = pos_0;
-                pos[i + batch.n_tokens    ] = pos_0 + y;
-                pos[i + batch.n_tokens * 2] = pos_0 + x;
-                pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused
+                pos[i + batch.n_tokens    ] = pos_0;
+                pos[i + batch.n_tokens * 2] = pos_0 + y;
+                pos[i + batch.n_tokens * 3] = pos_0 + x;
+                pos[i + batch.n_tokens * 4] = 0; // last pos dim is unused
             }
         }
         for (int i = 0; i < batch.n_tokens; i++) {
+            batch.pos      [i] = pos_0 + i;
             batch.n_seq_id [i] = 1;
             batch.seq_id   [i] = seq_id_0.data();
             batch.logits   [i] = false;
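
As a sanity check on the image case, a self-contained sketch (hypothetical values; it assumes `batch.pos` points at the first `n_tokens` entries of the same buffer, so the linear positions occupy section 0) that fills the positions for a 2x2 patch grid and prints each section:

```cpp
#include <cstdio>
#include <vector>

int main() {
    const int nx = 2, ny = 2;      // 2x2 grid of image patches
    const int n_tokens = nx * ny;
    const int pos_0    = 10;       // position of the first image token

    std::vector<int> pos(n_tokens * 5); // linear section + 4 M-RoPE dims
    for (int y = 0; y < ny; y++) {
        for (int x = 0; x < nx; x++) {
            int i = y * nx + x;
            pos[i               ] = pos_0 + i; // linear (what batch.pos would hold)
            pos[i + n_tokens    ] = pos_0;     // temporal: same for all patches
            pos[i + n_tokens * 2] = pos_0 + y; // height
            pos[i + n_tokens * 3] = pos_0 + x; // width
            pos[i + n_tokens * 4] = 0;         // unused dim
        }
    }
    for (int s = 0; s < 5; s++) {
        for (int i = 0; i < n_tokens; i++) {
            printf("%d ", pos[s * n_tokens + i]);
        }
        // prints: 10 11 12 13 / 10 10 10 10 / 10 10 11 11 / 10 11 10 11 / 0 0 0 0
        printf("\n");
    }
    return 0;
}
```
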
@@ -118,12 +124,13 @@ struct decode_embd_batch {
         GGML_ASSERT(n_pos_per_embd == 4);
         seq_id_0[0] = seq_id;
         for (int i = 0; i < batch.n_tokens; i++) {
-            pos[i                     ] = pos_0 + i;
             pos[i + batch.n_tokens    ] = pos_0 + i;
             pos[i + batch.n_tokens * 2] = pos_0 + i;
-            pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused
+            pos[i + batch.n_tokens * 3] = pos_0 + i;
+            pos[i + batch.n_tokens * 4] = 0; // last pos dim is unused
         }
         for (int i = 0; i < batch.n_tokens; i++) {
+            batch.pos      [i] = pos_0 + i;
             batch.n_seq_id [i] = 1;
             batch.seq_id   [i] = seq_id_0.data();
             batch.logits   [i] = false;
@@ -133,12 +140,12 @@ struct decode_embd_batch {
     llama_batch get_view(int offset, int n_tokens) {
         llama_pos * pos_ptr;
         pos_view.clear();
-        pos_view.reserve(n_tokens * n_pos_per_embd);
+        pos_view.reserve(n_tokens * (n_pos_per_embd + 1));
         if (n_pos_per_embd > 1) {
             // mrope
             // for example, with layout of src: 1234...1234...1234...1234...
             //       offset 2 will give us dst: 34...34...34...34...
-            for (int i = 0; i < n_pos_per_embd; i++) {
+            for (int i = 0; i <= n_pos_per_embd; i++) {
                 // assume n_tokens is less than or equal to batch.n_tokens
                 // batch.n_tokens is the number of **total** tokens
                 // n_tokens is the number of viewed tokens
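
And a small sketch of the slicing pattern the comment above describes (the `view_pos` helper name and signature are made up for illustration): each of the `n_pos_per_embd + 1` sections contributes `n_tokens` entries starting at `offset` within its own stride.

```cpp
#include <cassert>
#include <cstdio>
#include <vector>

// hypothetical helper, not part of the patch: copy n_view positions starting
// at `offset` out of each of the n_sections strides of `src`, where each
// section holds n_total positions
static std::vector<int> view_pos(const std::vector<int> & src, int n_total,
                                 int n_sections, int offset, int n_view) {
    assert(offset + n_view <= n_total);
    std::vector<int> dst;
    dst.reserve(n_view * n_sections);
    for (int s = 0; s < n_sections; s++) {
        for (int j = 0; j < n_view; j++) {
            dst.push_back(src[s * n_total + offset + j]);
        }
    }
    return dst;
}

int main() {
    // layout of src: 1234 1234 1234 (3 sections of 4 positions each)
    std::vector<int> src = {1,2,3,4, 1,2,3,4, 1,2,3,4};
    // offset 2 gives us dst: 34 34 34
    for (int v : view_pos(src, 4, 3, 2, 2)) {
        printf("%d ", v);
    }
    printf("\n");
    return 0;
}
```
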