Skip to content

Commit 3464bda

Browse files
authored
llama: fix ASAN error with M-RoPE (#16848)
1 parent e3af556 commit 3464bda

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

src/llama-batch.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -669,10 +669,8 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
669669

670670
auto udata = std::make_shared<llama_ubatch::data_t>();
671671

672-
const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1;
673-
674672
const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
675-
const int64_t n_pos_all = (int64_t) n_tokens*n_pos_cur;
673+
const int64_t n_pos_all = (int64_t) n_tokens*n_pos_per_embd;
676674

677675
udata->token .resize(n_tokens);
678676
udata->embd .resize(n_embd_all);
@@ -694,8 +692,13 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
694692
memcpy(udata->embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
695693
}
696694

697-
for (int j = 0; j < n_pos_cur; ++j) {
698-
udata->pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
695+
for (size_t j = 0; j < (size_t)n_pos_per_embd; ++j) {
696+
// if we are using M-RoPE:
697+
// if the current batch is text, we need to broadcast the same position across all RoPE sections
698+
// otherwise, the input batch contains image embeddings, so we copy the positions as-is
699+
// if we are not using M-RoPE, there is only one position per token (this loop runs only once)
700+
size_t src_off = batch.token ? 0 : j*batch.n_tokens;
701+
udata->pos[j*n_tokens + i] = batch.pos[src_off + idxs[i]];
699702
}
700703

701704
udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];

0 commit comments

Comments
 (0)