@@ -460,13 +460,12 @@ struct decode_embd_batch {
460460 std::vector<llama_seq_id *> seq_ids;
461461 std::vector<int8_t > logits;
462462 llama_batch batch;
463- decode_embd_batch (float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id, int n_pos_per_embd, int n_mmproj_embd) : n_pos_per_embd(n_pos_per_embd), n_mmproj_embd(n_mmproj_embd) {
463+ decode_embd_batch (float * embd, int32_t n_tokens, int n_pos_per_embd, int n_mmproj_embd) : n_pos_per_embd(n_pos_per_embd), n_mmproj_embd(n_mmproj_embd) {
464464 pos .resize (n_tokens * n_pos_per_embd);
465465 n_seq_id.resize (n_tokens);
466466 seq_ids .resize (n_tokens + 1 );
467467 logits .resize (n_tokens);
468468 seq_id_0.resize (1 );
469- seq_id_0[0 ] = seq_id;
470469 seq_ids [n_tokens] = nullptr ;
471470 batch = {
472471 /* n_tokens =*/ n_tokens,
@@ -477,14 +476,23 @@ struct decode_embd_batch {
477476 /* seq_id =*/ seq_ids.data (),
478477 /* logits =*/ logits.data (),
479478 };
480- for (int i = 0 ; i < n_tokens; i++) {
479+ }
480+
481+ void set_position_normal (llama_pos pos_0, llama_seq_id seq_id) {
482+ seq_id_0[0 ] = seq_id;
483+ for (int i = 0 ; i < batch.n_tokens ; i++) {
481484 batch.pos [i] = pos_0 + i;
482485 batch.n_seq_id [i] = 1 ;
483486 batch.seq_id [i] = seq_id_0.data ();
484487 batch.logits [i] = false ;
485488 }
486489 }
487490
491+ void set_position_mrope (llama_pos pos_0, int nx, int ny, llama_seq_id seq_id) {
492+ seq_id_0[0 ] = seq_id;
493+ GGML_ABORT (" TODO" );
494+ }
495+
488496 llama_batch get_view (int offset, int n_tokens) {
489497 return {
490498 /* n_tokens =*/ n_tokens,
@@ -556,7 +564,15 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
556564 int32_t i_batch = 0 ;
557565 int32_t n_img_batches = GGML_PAD (n_tokens, n_batch) / n_batch;
558566 float * embd = mtmd_get_output_embd (ctx);
559- decode_embd_batch batch_embd (embd, n_tokens, n_past, seq_id, n_pos_per_embd, n_mmproj_embd);
567+ decode_embd_batch batch_embd (embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
568+
569+ if (mtmd_decode_use_mrope (ctx)) {
570+ int nx = mtmd_image_tokens_get_nx (chunk.tokens_image .get ());
571+ int ny = mtmd_image_tokens_get_ny (chunk.tokens_image .get ());
572+ batch_embd.set_position_mrope (pos0, nx, ny, seq_id);
573+ } else {
574+ batch_embd.set_position_normal (pos0, seq_id);
575+ }
560576
561577 if (mtmd_decode_use_non_causal (ctx)) {
562578 llama_set_causal_attn (lctx, false );
0 commit comments