
Commit 9102a7c

add consistency checks
1 parent ebac831 commit 9102a7c

File tree

1 file changed: +21 −2 lines changed


src/llama-batch.cpp

Lines changed: 21 additions & 2 deletions
@@ -252,8 +252,27 @@ bool llama_batch_allocr::init(
     // consistency checks
     //
 
-    // TODO @ngxson : we currently can't check M-RoPE positions, as the position is increased based on image size
-    if (n_pos_per_embd == 1) {
+    if (n_pos_per_embd > 1) {
+        // M-RoPE case: allow position to "jump" forward only (non-continuous positions are allowed)
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
+            if (seq_pos[s].empty()) {
+                continue;
+            }
+
+            const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
+
+            if (p0 >= 0 && p0 >= seq_pos_min(s)) {
+                LLAMA_LOG_ERROR(
+                        "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
+                        " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
+                        " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
+                        " for M-RoPE, it is required that the position satisfies: X < Y\n",
+                        __func__, s, s, p0, s, seq_pos_min(s));
+
+                return false;
+            }
+        }
+    } else {
         for (uint32_t s = 0; s < n_seq_max; ++s) {
             if (seq_pos[s].empty()) {
                 continue;
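
For illustration, below is a minimal, self-contained sketch of the invariant the new M-RoPE branch enforces: given the last cached position X for a sequence and the smallest batch position Y, the check only requires X < Y, so forward jumps are allowed while overlaps with the KV cache are rejected. The helper mrope_positions_consistent and its interface are hypothetical and not part of llama.cpp.

// Hypothetical standalone sketch of the M-RoPE position check above;
// mrope_positions_consistent is not a llama.cpp function.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

using llama_pos = int32_t;

// p0: last position stored in the KV cache for the sequence (-1 if the cache is empty).
// batch_pos: positions of the sequence's tokens in the incoming batch.
// For M-RoPE the batch must start strictly after p0 (X < Y); gaps are fine,
// since image tokens advance the position by the image size rather than by 1.
static bool mrope_positions_consistent(llama_pos p0, const std::vector<llama_pos> & batch_pos) {
    if (batch_pos.empty()) {
        return true; // nothing to check for this sequence
    }
    const llama_pos y = *std::min_element(batch_pos.begin(), batch_pos.end());
    return p0 < 0 || p0 < y;
}

int main() {
    std::printf("%d\n", mrope_positions_consistent( 9, {20, 21, 22})); // 1: jump forward is allowed
    std::printf("%d\n", mrope_positions_consistent( 9, { 5,  6,  7})); // 0: overlaps the cached range
    std::printf("%d\n", mrope_positions_consistent(-1, { 0,  1,  2})); // 1: empty cache, any start is fine
    return 0;
}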
