@@ -66,7 +66,7 @@ struct decode_embd_batch {
6666 std::vector<int8_t > logits;
6767 llama_batch batch;
6868 decode_embd_batch (float * embd, int32_t n_tokens, int n_pos_per_embd, int n_mmproj_embd) : n_pos_per_embd(n_pos_per_embd), n_mmproj_embd(n_mmproj_embd) {
69- pos .resize (n_tokens * n_pos_per_embd);
69+ pos .resize (n_tokens * ( n_pos_per_embd + 1 ) );
7070 n_seq_id.resize (n_tokens);
7171 seq_ids .resize (n_tokens + 1 );
7272 logits .resize (n_tokens);
@@ -100,13 +100,14 @@ struct decode_embd_batch {
100100 for (int y = 0 ; y < ny; y++) {
101101 for (int x = 0 ; x < nx; x++) {
102102 int i = y * nx + x;
103- pos[i ] = pos_0;
104- pos[i + batch.n_tokens ] = pos_0 + y;
105- pos[i + batch.n_tokens * 2 ] = pos_0 + x;
106- pos[i + batch.n_tokens * 3 ] = 0 ; // last pos dim is unused
103+ pos[i + batch. n_tokens ] = pos_0;
104+ pos[i + batch.n_tokens * 2 ] = pos_0 + y;
105+ pos[i + batch.n_tokens * 3 ] = pos_0 + x;
106+ pos[i + batch.n_tokens * 4 ] = 0 ; // last pos dim is unused
107107 }
108108 }
109109 for (int i = 0 ; i < batch.n_tokens ; i++) {
110+ batch.pos [i] = pos_0 + i;
110111 batch.n_seq_id [i] = 1 ;
111112 batch.seq_id [i] = seq_id_0.data ();
112113 batch.logits [i] = false ;
@@ -118,12 +119,13 @@ struct decode_embd_batch {
118119 GGML_ASSERT (n_pos_per_embd == 4 );
119120 seq_id_0[0 ] = seq_id;
120121 for (int i = 0 ; i < batch.n_tokens ; i++) {
121- pos[i ] = pos_0 + i;
122122 pos[i + batch.n_tokens ] = pos_0 + i;
123123 pos[i + batch.n_tokens * 2 ] = pos_0 + i;
124- pos[i + batch.n_tokens * 3 ] = 0 ; // last pos dim is unused
124+ pos[i + batch.n_tokens * 3 ] = pos_0 + i;
125+ pos[i + batch.n_tokens * 4 ] = 0 ; // last pos dim is unused
125126 }
126127 for (int i = 0 ; i < batch.n_tokens ; i++) {
128+ batch.pos [i] = pos_0 + i;
127129 batch.n_seq_id [i] = 1 ;
128130 batch.seq_id [i] = seq_id_0.data ();
129131 batch.logits [i] = false ;
@@ -133,12 +135,12 @@ struct decode_embd_batch {
133135 llama_batch get_view (int offset, int n_tokens) {
134136 llama_pos * pos_ptr;
135137 pos_view.clear ();
136- pos_view.reserve (n_tokens * n_pos_per_embd);
138+ pos_view.reserve (n_tokens * ( n_pos_per_embd + 1 ) );
137139 if (n_pos_per_embd > 1 ) {
138140 // mrope
139141 // for example, with layout of src: 1234...1234...1234...1234...
140142 // offset 2 will give us dst: 34...34...34...34...
141- for (int i = 0 ; i < n_pos_per_embd; i++) {
143+ for (int i = 0 ; i <= n_pos_per_embd; i++) {
142144 // assume n_tokens is less than or equal to batch.n_tokens
143145 // batch.n_tokens is number of **total** tokens
144146 // n_tokens is number of viewed tokens
0 commit comments