Skip to content

Commit 8646e36

Browse files
committed
decode_embd_batch::set_position_...
1 parent 8742f8a commit 8646e36

File tree

1 file changed

+20
-4
lines changed

1 file changed

+20
-4
lines changed

examples/llava/mtmd.cpp

Lines changed: 20 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -460,13 +460,12 @@ struct decode_embd_batch {
460460
std::vector<llama_seq_id *> seq_ids;
461461
std::vector<int8_t> logits;
462462
llama_batch batch;
463-
decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id, int n_pos_per_embd, int n_mmproj_embd) : n_pos_per_embd(n_pos_per_embd), n_mmproj_embd(n_mmproj_embd) {
463+
decode_embd_batch(float * embd, int32_t n_tokens, int n_pos_per_embd, int n_mmproj_embd) : n_pos_per_embd(n_pos_per_embd), n_mmproj_embd(n_mmproj_embd) {
464464
pos .resize(n_tokens * n_pos_per_embd);
465465
n_seq_id.resize(n_tokens);
466466
seq_ids .resize(n_tokens + 1);
467467
logits .resize(n_tokens);
468468
seq_id_0.resize(1);
469-
seq_id_0[0] = seq_id;
470469
seq_ids [n_tokens] = nullptr;
471470
batch = {
472471
/*n_tokens =*/ n_tokens,
@@ -477,14 +476,23 @@ struct decode_embd_batch {
477476
/*seq_id =*/ seq_ids.data(),
478477
/*logits =*/ logits.data(),
479478
};
480-
for (int i = 0; i < n_tokens; i++) {
479+
}
480+
481+
void set_position_normal(llama_pos pos_0, llama_seq_id seq_id) {
482+
seq_id_0[0] = seq_id;
483+
for (int i = 0; i < batch.n_tokens; i++) {
481484
batch.pos [i] = pos_0 + i;
482485
batch.n_seq_id[i] = 1;
483486
batch.seq_id [i] = seq_id_0.data();
484487
batch.logits [i] = false;
485488
}
486489
}
487490

491+
void set_position_mrope(llama_pos pos_0, int nx, int ny, llama_seq_id seq_id) {
492+
seq_id_0[0] = seq_id;
493+
GGML_ABORT("TODO");
494+
}
495+
488496
llama_batch get_view(int offset, int n_tokens) {
489497
return {
490498
/*n_tokens =*/ n_tokens,
@@ -556,7 +564,15 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
556564
int32_t i_batch = 0;
557565
int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
558566
float * embd = mtmd_get_output_embd(ctx);
559-
decode_embd_batch batch_embd(embd, n_tokens, n_past, seq_id, n_pos_per_embd, n_mmproj_embd);
567+
decode_embd_batch batch_embd(embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
568+
569+
if (mtmd_decode_use_mrope(ctx)) {
570+
int nx = mtmd_image_tokens_get_nx(chunk.tokens_image.get());
571+
int ny = mtmd_image_tokens_get_ny(chunk.tokens_image.get());
572+
batch_embd.set_position_mrope(pos0, nx, ny, seq_id);
573+
} else {
574+
batch_embd.set_position_normal(pos0, seq_id);
575+
}
560576

561577
if (mtmd_decode_use_non_causal(ctx)) {
562578
llama_set_causal_attn(lctx, false);

0 commit comments

Comments
 (0)