Commit e1a3d85

llama.cpp: Fix Qwen2.5 VL cache causal masking (PR #16745)
Applied changes from llama.cpp PR #16745 to fix cache causal masking issues for Qwen2.5 VL models.

Key changes:
- Disabled consecutive position validation in llama-batch.cpp (allows position jumps for vision embeddings)
- Added kv_position_of_token field to track KV cache positions for proper causal masking
- Modified causal masking logic to use batch positions instead of temporal positions
- Updated M-RoPE position calculation to use max(nx, ny) for images

This fix allows Qwen VL models to handle non-consecutive positions in embeddings, which is required for proper vision processing.

Ref: ggml-org/llama.cpp#16745
1 parent 15eab62 commit e1a3d85
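
The core of the fix is the masking rule added to llama-kv-cache.cpp below: each token records the KV cache slot it was written to, the mask code inverts that map, and a token is then only prevented from attending to slots filled by *later* tokens of the same ubatch, regardless of temporal position (which may repeat or jump for vision embeddings). The following is a minimal, self-contained sketch of that rule; the names `kv_slot`, `kv_to_batch` and `allowed` are illustrative and are not the llama.cpp API.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int32_t n_kv = 8; // illustrative cache size

    // hypothetical batch of 4 tokens; kv_slot[i] is the cache slot token i was written to
    const std::vector<int32_t> kv_slot = {3, 4, 5, 6};

    // invert the map: for every cache slot, which batch token (if any) occupies it
    std::vector<int32_t> kv_to_batch(n_kv, -1);
    for (int32_t i = 0; i < (int32_t) kv_slot.size(); ++i) {
        kv_to_batch[kv_slot[i]] = i;
    }

    // causal rule used by the patched mask: token i may attend to slot j unless that
    // slot holds a token that comes after i in the current batch
    auto allowed = [&](int32_t i, int32_t j) {
        return kv_to_batch[j] == -1 || kv_to_batch[j] <= i;
    };

    // print the resulting mask: '1' = attend, '.' = masked
    for (int32_t i = 0; i < (int32_t) kv_slot.size(); ++i) {
        for (int32_t j = 0; j < n_kv; ++j) {
            printf("%c", allowed(i, j) ? '1' : '.');
        }
        printf("\n");
    }
    return 0;
}
```

Compared with the previous `p0 > p1` temporal comparison, this ordering depends only on batch order, so repeated or non-monotonic positions inside an image embedding no longer break causality.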

File tree (5 files changed, +111 −95 lines):
- llama/llama.cpp/src/llama-batch.cpp
- llama/llama.cpp/src/llama-batch.h
- llama/llama.cpp/src/llama-kv-cache.cpp
- llama/llama.cpp/tools/mtmd/mtmd.cpp
- llama/llama.cpp/tools/mtmd/mtmd.h

llama/llama.cpp/src/llama-batch.cpp

Lines changed: 87 additions & 82 deletions
@@ -221,12 +221,11 @@ bool llama_batch_allocr::init(
             /*.n_seq_id   =*/ batch.n_seq_id,
             /*.seq_id     =*/ batch.seq_id,
             /*.seq_id_unq =*/ this->seq_id_unq.data(),
-            /*.seq_idx    =*/ this->seq_idx.data(),
-            /*.output     =*/ batch.logits,
-            /*.data       =*/ {},
-        };
-
-        ubatch_print(ubatch, debug);
+            /*.seq_idx    =*/ this->seq_idx.data(),
+            /*.output     =*/ batch.logits,
+            /*.kv_position_of_token=*/ {},
+            /*.data       =*/ {},
+        }; ubatch_print(ubatch, debug);

         LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
         for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
@@ -256,36 +255,38 @@ bool llama_batch_allocr::init(
             continue;
         }

-        const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
-
-        if (p0 >= 0) {
-            bool ok = true;
-
-            if (batch.token) {
-                if (seq_pos_min(s) != p0 + 1) {
-                    ok = false;
-                }
-            } else {
-                assert(batch.embd);
-
-                // for embeddings (typically used as vision input), we allow them to have repeating positions
-                // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
-                if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
-                    ok = false;
-                }
-            }
-
-            if (!ok) {
-                LLAMA_LOG_ERROR(
-                        "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
-                        " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
-                        " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
-                        " it is required that the sequence positions remain consecutive: Y = X + 1\n",
-                        __func__, s, s, p0, s, seq_pos_min(s));
-
-                return false;
-            }
-        }
+        //@fmayran: these checks don't make sense with models using position encoding such as Qwen VL, because the position stored in the KV cache can jump around (it is not even always increasing).
+        //it is not enough to let them be repeating. Within an image embedding, arbitrary jumps are expected.
+        //const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
+        //
+        //if (p0 >= 0) {
+        //    bool ok = true;
+        //
+        //    if (batch.token) {
+        //        if (seq_pos_min(s) != p0 + 1) {
+        //            ok = false;
+        //        }
+        //    } else {
+        //        assert(batch.embd);
+        //
+        //        // for embeddings (typically used as vision input), we allow them to have repeating positions
+        //        // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
+        //        if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
+        //            ok = false;
+        //        }
+        //    }
+        //
+        //    if (!ok) {
+        //        LLAMA_LOG_ERROR(
+        //                "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
+        //                " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
+        //                " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
+        //                " it is required that the sequence positions remain consecutive: Y = X + 1\n",
+        //                __func__, s, s, p0, s, seq_pos_min(s));
+        //
+        //        return false;
+        //    }
+        //}

         if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
             LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
@@ -369,36 +370,38 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t

     auto udata = std::make_shared<llama_ubatch::data_t>();

-    udata->token     .resize(n_tokens);
-    udata->embd      .clear();
-    udata->pos       .resize(n_tokens);
-    udata->n_seq_id  .resize(n_tokens);
-    udata->seq_id    .resize(n_tokens);
-    udata->seq_id_unq.resize(0);
-    udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
-    udata->output    .resize(n_tokens);
+    udata->token      .resize(n_tokens);
+    udata->embd       .clear();
+    udata->pos        .resize(n_tokens);
+    udata->n_seq_id   .resize(n_tokens);
+    udata->seq_id     .resize(n_tokens);
+    udata->seq_id_unq .resize(0);
+    udata->seq_idx    .resize(LLAMA_MAX_SEQ, -1);
+    udata->output     .resize(n_tokens);
+    udata->kv_position_of_token.resize(n_tokens, -1);

     for (uint32_t s = 0; s < n_seqs; ++s) {
         udata->seq_idx[s] = s;
         udata->seq_id_unq.push_back(s);
     }

     llama_ubatch res {
-        /*.b_equal_seqs =*/ true,
-        /*.n_tokens     =*/ n_tokens,
-        /*.n_seq_tokens =*/ n_seq_tokens,
-        /*.n_seqs       =*/ n_seqs,
-        /*.n_seqs_unq   =*/ n_seqs,
-
-        /*.token        =*/ udata->token.data(),
-        /*.embd         =*/ nullptr,
-        /*.pos          =*/ udata->pos.data(),
-        /*.n_seq_id     =*/ udata->n_seq_id.data(),
-        /*.seq_id       =*/ udata->seq_id.data(),
-        /*.seq_id_unq   =*/ udata->seq_id_unq.data(),
-        /*.seq_idx      =*/ udata->seq_idx.data(),
-        /*.output       =*/ udata->output.data(),
-        /*.data         =*/ std::move(udata),
+        /*.b_equal_seqs =*/ true,
+        /*.n_tokens     =*/ n_tokens,
+        /*.n_seq_tokens =*/ n_seq_tokens,
+        /*.n_seqs       =*/ n_seqs,
+        /*.n_seqs_unq   =*/ n_seqs,
+
+        /*.token        =*/ udata->token.data(),
+        /*.embd         =*/ nullptr,
+        /*.pos          =*/ udata->pos.data(),
+        /*.n_seq_id     =*/ udata->n_seq_id.data(),
+        /*.seq_id       =*/ udata->seq_id.data(),
+        /*.seq_id_unq   =*/ udata->seq_id_unq.data(),
+        /*.seq_idx      =*/ udata->seq_idx.data(),
+        /*.output       =*/ udata->output.data(),
+        /*.kv_position_of_token=*/ udata->kv_position_of_token.data(),
+        /*.data         =*/ std::move(udata),
     };

     return res;
@@ -660,14 +663,15 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
     const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
     const int64_t n_pos_all  =              (int64_t) n_tokens*n_pos_cur;

-    udata->token     .resize(n_tokens);
-    udata->embd      .resize(n_embd_all);
-    udata->pos       .resize(n_pos_all);
-    udata->n_seq_id  .resize(n_tokens);
-    udata->seq_id    .resize(n_tokens);
-    udata->seq_id_unq.resize(0);
-    udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
-    udata->output    .resize(n_tokens);
+    udata->token      .resize(n_tokens);
+    udata->embd       .resize(n_embd_all);
+    udata->pos        .resize(n_pos_all);
+    udata->n_seq_id   .resize(n_tokens);
+    udata->seq_id     .resize(n_tokens);
+    udata->seq_id_unq .resize(0);
+    udata->seq_idx    .resize(LLAMA_MAX_SEQ, -1);
+    udata->output     .resize(n_tokens);
+    udata->kv_position_of_token.resize(n_tokens, -1);

     seq_set_t seq_set_unq;

@@ -705,21 +709,22 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
     }

     llama_ubatch res {
-        /*.b_equal_seqs =*/ equal_seqs,
-        /*.n_tokens     =*/ n_tokens,
-        /*.n_seq_tokens =*/ n_tokens/n_seqs,
-        /*.n_seqs       =*/ n_seqs,
-        /*.n_seqs_unq   =*/ (uint32_t) udata->seq_id_unq.size(),
-
-        /*.token        =*/ batch.token ? udata->token.data() : nullptr,
-        /*.embd         =*/ batch.embd ? udata->embd.data() : nullptr,
-        /*.pos          =*/ udata->pos.data(),
-        /*.n_seq_id     =*/ udata->n_seq_id.data(),
-        /*.seq_id       =*/ udata->seq_id.data(),
-        /*.seq_id_unq   =*/ udata->seq_id_unq.data(),
-        /*.seq_idx      =*/ udata->seq_idx.data(),
-        /*.output       =*/ udata->output.data(),
-        /*.data         =*/ std::move(udata),
+        /*.b_equal_seqs =*/ equal_seqs,
+        /*.n_tokens     =*/ n_tokens,
+        /*.n_seq_tokens =*/ n_tokens/n_seqs,
+        /*.n_seqs       =*/ n_seqs,
+        /*.n_seqs_unq   =*/ (uint32_t) udata->seq_id_unq.size(),
+
+        /*.token        =*/ batch.token ? udata->token.data() : nullptr,
+        /*.embd         =*/ batch.embd ? udata->embd.data() : nullptr,
+        /*.pos          =*/ udata->pos.data(),
+        /*.n_seq_id     =*/ udata->n_seq_id.data(),
+        /*.seq_id       =*/ udata->seq_id.data(),
+        /*.seq_id_unq   =*/ udata->seq_id_unq.data(),
+        /*.seq_idx      =*/ udata->seq_idx.data(),
+        /*.output       =*/ udata->output.data(),
+        /*.kv_position_of_token=*/ udata->kv_position_of_token.data(),
+        /*.data         =*/ std::move(udata),
     };

     if (debug > 0) {
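
For context on why the consecutive-position check above was disabled: with M-RoPE-style vision embeddings the positions written to the KV cache can jump and even decrease within one image, so a batch can legitimately start below the cached maximum. A small illustration of the old "Y = X + 1" rule rejecting such a batch (the numbers are made up for illustration, not the real M-RoPE layout):

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// old rule: the smallest position in the incoming batch must be last_cached_pos + 1
static bool old_check_passes(int last_cached_pos, const std::vector<int> & batch_pos) {
    const int batch_min = *std::min_element(batch_pos.begin(), batch_pos.end());
    return batch_min == last_cached_pos + 1;
}

int main() {
    // suppose the cache already holds part of an image whose (hypothetical) positions
    // reached 9, and the next chunk of the same image carries positions below that maximum
    const int last_cached_pos = 9;
    const std::vector<int> next_chunk_pos = {6, 7, 8, 10};

    // the old check rejects this batch even though it is a legitimate continuation
    printf("old check passes: %s\n",
           old_check_passes(last_cached_pos, next_chunk_pos) ? "yes" : "no"); // prints "no"
    return 0;
}
```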

llama/llama.cpp/src/llama-batch.h

Lines changed: 11 additions & 9 deletions
@@ -30,15 +30,16 @@ struct llama_ubatch {
     // seq_idx: indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
     //          used for extracting sequence pooled embeddings

-    //                          // size               | idx | val
-    llama_token  *  token;      // [n_tokens]         | i   | id, token
-    float        *  embd;       // [n_embd, n_tokens] | i   | embd
-    llama_pos    *  pos;        // [n_tokens]         | i   | pos
-    int32_t      *  n_seq_id;   // [n_tokens]         | i   | -
-    llama_seq_id ** seq_id;     // [n_tokens]         | s   | s0, s1, seq_id
-    llama_seq_id *  seq_id_unq; // [n_seqs_unq]       | s   | seq_id
-    int32_t      *  seq_idx;    // [LLAMA_MAX_SEQ]    | -   | seq_idx
-    int8_t       *  output;     // [n_tokens]         | i   | -
+    //                          // size               | idx | val
+    llama_token  *  token;      // [n_tokens]         | i   | id, token
+    float        *  embd;       // [n_embd, n_tokens] | i   | embd
+    llama_pos    *  pos;        // [n_tokens]         | i   | pos
+    int32_t      *  n_seq_id;   // [n_tokens]         | i   | -
+    llama_seq_id ** seq_id;     // [n_tokens]         | s   | s0, s1, seq_id
+    llama_seq_id *  seq_id_unq; // [n_seqs_unq]       | s   | seq_id
+    int32_t      *  seq_idx;    // [LLAMA_MAX_SEQ]    | -   | seq_idx
+    int8_t       *  output;     // [n_tokens]         | i   | -
+    int32_t      *  kv_position_of_token; // [n_tokens] | i | kv position where the token was inserted

     struct data_t {
         std::vector<llama_token> token;
@@ -49,6 +50,7 @@ struct llama_ubatch {
         std::vector<llama_seq_id> seq_id_unq;
         std::vector<int32_t>      seq_idx;
         std::vector<int8_t>       output;
+        std::vector<int32_t>      kv_position_of_token; //when pushed to the kv cache, where is the token pushed (used for causal masking)
     };

     // the llama_ubatch pointers above point to this data if set. otherwise - points to non-owning data

llama/llama.cpp/src/llama-kv-cache.cpp

Lines changed: 11 additions & 2 deletions
@@ -895,6 +895,7 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &
         }

         cells.pos_set(idx, ubatch.pos[i]);
+        ubatch.kv_position_of_token[i] = (int32_t) idx; //set the position in the kv cache as a property for this token (needed for proper causal masking)

         for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
             cells.seq_add(idx, ubatch.seq_id[i][s]);
@@ -1215,6 +1216,12 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u

     std::fill(data, data + ggml_nelements(dst), -INFINITY);

+    std::vector<int32_t> map_kv_to_batch(n_kv, -1); //for each token in the cache, either (-1) or the position in the current ubatch
+    for (uint32_t i = 0; i < n_tokens; ++i) //invert the batch -> kv position map into a kv -> batch position map
+    {
+        if (ubatch->kv_position_of_token[i] != -1)
+            map_kv_to_batch[ubatch->kv_position_of_token[i]] = i;
+    }
     // Use only the previous KV cells of the correct sequence for each token of the ubatch.
     // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
     // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
@@ -1254,8 +1261,10 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
                 const llama_pos p0 = cells.pos_get(j);

                 // mask future tokens
-                if (causal_attn && p0 > p1) {
-                    continue;
+                if (causal_attn)
+                {
+                    if (map_kv_to_batch[j] != -1 && map_kv_to_batch[j] > (int32_t) i) //if the kv cache token is in the current batch AND its position in the batch is higher than i
+                        continue;
                 }

                 // apply SWA if any

llama/llama.cpp/tools/mtmd/mtmd.cpp

Lines changed: 1 addition & 1 deletion
@@ -1036,7 +1036,7 @@ const char * mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {

 llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) {
     if (image_tokens->use_mrope_pos) {
-        return 1; // for M-RoPE, the whole image is 1 in temporal dimension
+        return (::std::max)(image_tokens->nx, image_tokens->ny); //assuming image, not video // for M-RoPE, the whole image is 1 in temporal dimension
     }
     return image_tokens->n_tokens();
 }
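
To see what the mtmd_image_tokens_get_n_pos change means in practice, here is a toy calculation of the temporal span an image now occupies in the sequence; fake_image_tokens and n_pos_mrope are stand-ins for illustration, not the real mtmd types:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

struct fake_image_tokens {
    uint32_t nx; // patches along x
    uint32_t ny; // patches along y
};

// after this commit: an image advances the sequence by max(nx, ny) temporal positions
static uint32_t n_pos_mrope(const fake_image_tokens & img) {
    return std::max(img.nx, img.ny);
}

int main() {
    const fake_image_tokens img = {16, 24}; // e.g. a hypothetical 16x24 patch grid
    printf("old n_pos: 1, new n_pos: %u\n", n_pos_mrope(img)); // prints 24
    return 0;
}
```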

llama/llama.cpp/tools/mtmd/mtmd.h

Lines changed: 1 addition & 1 deletion
@@ -156,7 +156,7 @@ MTMD_API const mtmd_image_tokens * mtmd_input_chunk_get_tokens_image(const mtmd
 MTMD_API size_t       mtmd_input_chunk_get_n_tokens (const mtmd_input_chunk * chunk);
 // returns nullptr for ID on text chunk
 MTMD_API const char * mtmd_input_chunk_get_id       (const mtmd_input_chunk * chunk);
-// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
+// number of temporal positions (always max(ntok_x, ntok_y, ntok_t) for M-RoPE, n_tokens otherwise)
 MTMD_API llama_pos    mtmd_input_chunk_get_n_pos    (const mtmd_input_chunk * chunk);

 // in case you want to use custom logic to handle the chunk (i.e. KV cache management)
