Skip to content

Commit 8da3c0e

Browse files
authored
batch : fix consistency checks for the input positions (ggml-org#16890)
1 parent c22473b commit 8da3c0e

File tree

1 file changed

+22
-8
lines changed

1 file changed

+22
-8
lines changed

src/llama-batch.cpp

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -261,15 +261,29 @@ bool llama_batch_allocr::init(
261261

262262
const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
263263

264-
if (p0 >= 0 && p0 >= seq_pos_min(s)) {
265-
LLAMA_LOG_ERROR(
266-
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
267-
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
268-
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
269-
" for M-RoPE, it is required that the position satisfies: X < Y\n",
270-
__func__, s, s, p0, s, seq_pos_min(s));
264+
if (batch.token) {
265+
if (p0 >= 0 && p0 >= seq_pos_min(s)) {
266+
LLAMA_LOG_ERROR(
267+
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
268+
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
269+
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
270+
" for M-RoPE, it is required that the position satisfies: X < Y\n",
271+
__func__, s, s, p0, s, seq_pos_min(s));
271272

272-
return false;
273+
return false;
274+
}
275+
} else {
276+
// embedding inputs can have overlapping positions
277+
if (p0 >= 0 && p0 > seq_pos_min(s)) {
278+
LLAMA_LOG_ERROR(
279+
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
280+
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
281+
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
282+
" for M-RoPE, it is required that the position satisfies: X <= Y\n",
283+
__func__, s, s, p0, s, seq_pos_min(s));
284+
285+
return false;
286+
}
273287
}
274288
}
275289
} else {

0 commit comments

Comments
 (0)