@@ -1080,19 +1080,25 @@ struct server_tokens {
 
 private: // disallow accessing these members directly, risking out-of-sync
 
-    // map a **start** position in tokens to the image chunk
-    std::unordered_map<llama_pos, mtmd::input_chunk_ptr> map_pos_to_media;
+    // map a **start** index in tokens to the image chunk
+    // note: the order needs to be kept in sync with tokens and pos
+    std::map<size_t, mtmd::input_chunk_ptr> map_idx_to_media;
 
     // list of tokens
-    // it can include LLAMA_TOKEN_NULL, which is used to indicate a token that is not a text token
-    // a mtmd_input_chunk can occupy multiple tokens, one llama_token per **position**
-    // important: for models using mrope, an image can contain multiple tokens but will use only one **position**
+    // if the token is LLAMA_TOKEN_NULL, it indicates that this position is occupied by a media chunk
+    // otherwise, it is a normal text token
+    // note: a non-text chunk can occupy multiple tokens (i.e. memory cells) in the token list
     llama_tokens tokens;
 
-    // for ex. with input of 5 text tokens and 2 images:
-    //     [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1]
-    // pos  0   1   2   3   4    5      6      7      8      9
-    // map_pos_to_media will contain: {5, img0}, {8, img1}
+    // the position per-token (llama_pos) in the overall input
+    // useful for M-RoPE, where the position is different from the index in tokens
+    std::vector<llama_pos> pos;
+
+    // for ex. with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos):
+    //     [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1]
+    // idx  0   1   2   3   4    5      6      7      8      9     10
+    // pos  0   1   2   3   4    5      5      5      7      7      7
+    // map_idx_to_media will contain: {5, img0}, {8, img1}
 
 public:
     server_tokens() = default;
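To make the new idx/pos bookkeeping concrete, here is a minimal standalone sketch (not part of the patch; the 3-token/2-position image is a hypothetical M-RoPE case, and a plain `int` stands in for `mtmd::input_chunk_ptr`):

```cpp
#include <cstdint>
#include <cstdio>
#include <map>
#include <vector>

using llama_token = int32_t;
using llama_pos   = int32_t;
constexpr llama_token LLAMA_TOKEN_NULL = -1; // same value as in llama.h

int main() {
    std::vector<llama_token> tokens;
    std::vector<llama_pos>   pos;
    std::map<size_t, int>    map_idx_to_media; // real code stores mtmd::input_chunk_ptr

    // 5 text tokens: one cell and one position each
    for (llama_token t = 100; t < 105; ++t) {
        pos.push_back(tokens.empty() ? 0 : pos.back() + 1);
        tokens.push_back(t);
    }

    // hypothetical image chunk: 3 token cells that all share one position (2 positions total)
    const size_t    img_n_tokens = 3;
    const llama_pos img_n_pos    = 2;
    const size_t    start_idx    = tokens.size();
    const llama_pos cur_pos      = pos.back() + 1;
    for (size_t i = 0; i < img_n_tokens; ++i) {
        pos.push_back(cur_pos);
        tokens.push_back(LLAMA_TOKEN_NULL);
    }
    map_idx_to_media[start_idx] = 0; // would map to img0

    // the next text token resumes at cur_pos + img_n_pos, not at pos.back() + 1
    pos.push_back(cur_pos + img_n_pos);
    tokens.push_back(105);

    for (size_t i = 0; i < tokens.size(); ++i) {
        printf("idx %2zu  tok %4d  pos %d\n", i, tokens[i], pos[i]);
    }
    return 0;
}
```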
@@ -1117,30 +1123,57 @@ struct server_tokens {
         }
     }
 
-    server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {}
+    server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {
+        for (llama_pos i = 0; i < (llama_pos)tokens.size(); ++i) {
+            pos.push_back(i);
+        }
+    }
+
+    llama_pos next_pos() const {
+        if (tokens.empty()) {
+            return 0;
+        } else if (tokens.back() != LLAMA_TOKEN_NULL) {
+            return pos.back() + 1;
+        } else {
+            // find the last media chunk
+            GGML_ASSERT(has_mtmd);
+            GGML_ASSERT(!map_idx_to_media.empty());
+            const auto & chunk = map_idx_to_media.rbegin()->second;
+            return pos.back() + mtmd_input_chunk_get_n_pos(chunk.get());
+        }
+    }
+
+    llama_pos get_pos(size_t idx) const {
+        GGML_ASSERT(idx < pos.size());
+        return pos[idx];
+    }
 
     // for debugging
     std::string str() const {
         std::ostringstream oss;
         oss << "tokens: ";
-        for (const auto & t : tokens) {
+        for (size_t idx = 0; idx < tokens.size(); ++idx) {
+            llama_token t = tokens[idx];
+            llama_pos   p = pos[idx];
+            oss << "idx:" << idx << " ";
             if (t == LLAMA_TOKEN_NULL) {
-                oss << "<embd> ";
+                oss << "<embd>(" << p << ")\n";
             } else {
-                oss << t << " ";
+                oss << t << "(" << p << ")\n";
             }
         }
         oss << "\n";
-        oss << "image pos: ";
-        for (const auto & it : map_pos_to_media) {
+        oss << "image idx: ";
+        for (const auto & it : map_idx_to_media) {
             oss << it.first << ", ";
         }
         return oss.str();
     }
 
-    const mtmd::input_chunk_ptr & find_chunk(llama_pos pos) const {
-        auto it = map_pos_to_media.find(pos);
-        if (it != map_pos_to_media.end()) {
+    const mtmd::input_chunk_ptr & find_chunk(size_t idx) const {
+        auto it = map_idx_to_media.find(idx);
+        if (it != map_idx_to_media.end()) {
             return it->second;
         }
         throw std::runtime_error("Chunk not found");
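Note that only the **start** index of each media chunk is stored, so `find_chunk` succeeds at a chunk's first cell and throws for any index inside it. A standalone illustration (a `std::string` stands in for `mtmd::input_chunk_ptr`; the {5, img0}, {8, img1} layout is the one from the comment above):

```cpp
#include <cstdio>
#include <map>
#include <stdexcept>
#include <string>

// simplified stand-in for server_tokens::find_chunk()
static const std::string & find_chunk(const std::map<size_t, std::string> & map_idx_to_media, size_t idx) {
    auto it = map_idx_to_media.find(idx);
    if (it != map_idx_to_media.end()) {
        return it->second;
    }
    throw std::runtime_error("Chunk not found");
}

int main() {
    // img0 starts at idx 5, img1 at idx 8 (see the layout comment above)
    const std::map<size_t, std::string> map_idx_to_media = {{5, "img0"}, {8, "img1"}};

    printf("idx 5 -> %s\n", find_chunk(map_idx_to_media, 5).c_str()); // start of img0: found
    try {
        find_chunk(map_idx_to_media, 6); // middle of img0: not a start index
    } catch (const std::exception & e) {
        printf("idx 6 -> %s\n", e.what());
    }
    return 0;
}
```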
@@ -1150,6 +1183,7 @@ struct server_tokens {
         if (tok == LLAMA_TOKEN_NULL) {
             throw std::runtime_error("Invalid token");
         }
+        pos.emplace_back(next_pos());
         tokens.emplace_back(tok);
     }
 
@@ -1158,13 +1192,15 @@ struct server_tokens {
         auto type = mtmd_input_chunk_get_type(chunk);
         if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
             GGML_ASSERT(has_mtmd);
-            const int n_pos = mtmd_input_chunk_get_n_pos(chunk);
-            llama_pos start_pos = tokens.size();
-            for (int i = 0; i < n_pos; ++i) {
+            const size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk);
+            const llama_pos cur_pos = next_pos();
+            size_t start_idx = tokens.size();
+            for (size_t i = 0; i < n_tokens; ++i) {
+                pos.emplace_back(cur_pos);
                 tokens.emplace_back(LLAMA_TOKEN_NULL);
             }
             mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
-            map_pos_to_media[start_pos] = std::move(new_chunk);
+            map_idx_to_media[start_idx] = std::move(new_chunk);
         } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
             size_t n_tokens;
             const auto * text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
@@ -1178,18 +1214,18 @@ struct server_tokens {
 
     // appends server tokens, updates the media map. copies media chunks.
     void push_back(server_tokens & tokens) {
-        size_t start_pos = size();
+        size_t start_idx = size();
         for (size_t i = 0; i < tokens.size(); i++) {
             push_back(tokens[i]);
         }
         if (tokens.has_mtmd) {
             // Assert if we are copying MTMD chunks to a server_tokens that does not have mtmd.
             // We could also just check, but this will prevent silently dropping MTMD data.
             GGML_ASSERT(has_mtmd);
-            for (auto it = tokens.map_pos_to_media.begin(); it != tokens.map_pos_to_media.end(); ) {
-                auto * chunk = tokens.map_pos_to_media[it->first].get();
+            for (auto it = tokens.map_idx_to_media.begin(); it != tokens.map_idx_to_media.end(); ++it) {
+                auto * chunk = tokens.map_idx_to_media[it->first].get();
                 mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
-                map_pos_to_media[start_pos+it->first] = std::move(new_chunk);
+                map_idx_to_media[start_idx+it->first] = std::move(new_chunk);
             }
         }
     }
@@ -1198,6 +1234,11 @@ struct server_tokens {
     void insert(const llama_tokens & inp_tokens) {
         GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
         tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end());
+        // rebuild the pos vector
+        pos.clear();
+        for (llama_pos i = 0; i < (llama_pos)tokens.size(); ++i) {
+            pos.emplace_back(i);
+        }
     }
 
     // for compatibility with speculative decoding, ctx shift, slot save/load
@@ -1245,10 +1286,10 @@ struct server_tokens {
                 }
             }
             // remove all image chunks that are not used anymore
-            for (auto it = map_pos_to_media.begin(); it != map_pos_to_media.end(); ) {
-                llama_pos pos = it->first;
-                if (pos >= (llama_pos) n) {
-                    it = map_pos_to_media.erase(it);
+            for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ) {
+                size_t idx = it->first;
+                if (idx >= n) {
+                    it = map_idx_to_media.erase(it);
                 } else {
                     ++it;
                 }
@@ -1296,12 +1337,12 @@ struct server_tokens {
                 const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get());
                 const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get());
 
-                const size_t pos_a = mtmd_input_chunk_get_n_pos(a_chunk.get());
-                const size_t pos_b = mtmd_input_chunk_get_n_pos(b_chunk.get());
+                const size_t n_tok_a = mtmd_input_chunk_get_n_tokens(a_chunk.get());
+                const size_t n_tok_b = mtmd_input_chunk_get_n_tokens(b_chunk.get());
 
-                if (id_ai == id_bi && pos_a == pos_b) {
-                    GGML_ASSERT(pos_a > 0 && "Invalid media chunk"); // should never happen
-                    i += pos_a - 1; // will be +1 by the for loop
+                if (id_ai == id_bi && n_tok_a == n_tok_b) {
+                    GGML_ASSERT(n_tok_a > 0 && "Invalid media chunk"); // should never happen
+                    i += n_tok_a - 1; // will be +1 by the for loop
                     continue;
                 }
 
@@ -1329,8 +1370,8 @@ struct server_tokens {
             if (t == LLAMA_TOKEN_NULL) {
                 try {
                     const auto & chunk = find_chunk(i);
-                    size_t n_pos = mtmd_input_chunk_get_n_pos(chunk.get());
-                    i += n_pos - 1; // will be +1 by the for loop
+                    size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk.get());
+                    i += n_tokens - 1; // will be +1 by the for loop
                 } catch (const std::exception & e) {
                     return false;
                 }
@@ -1345,30 +1386,30 @@ struct server_tokens {
     int32_t process_chunk(
             llama_context * ctx,
             mtmd_context * mctx,
-            llama_pos n_past,
+            size_t idx,
             int32_t seq_id,
-            llama_pos & n_pos_out) const {
-        const auto & chunk = find_chunk(n_past);
+            size_t & n_tokens_out) const {
+        const auto & chunk = find_chunk(idx);
         const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
             ? "image" : "audio";
         SRV_INF("processing %s...\n", name);
         int32_t n_batch = llama_n_batch(ctx);
         int64_t t0 = ggml_time_ms();
-        llama_pos new_n_past = n_past;
+        llama_pos new_n_past; // unused for now
         int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx,
             chunk.get(),
-            n_past,
+            pos[idx], // position
             seq_id,
             n_batch,
             true, // logits last
             &new_n_past);
         SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0);
         if (result != 0) {
             LOG_ERR("mtmd_helper_eval failed with status %d", result);
-            n_pos_out = n_past;
+            n_tokens_out = 0;
             return result;
         }
-        n_pos_out = new_n_past;
+        n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get());
         return 0;
     }
 };
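For context, a caller might drive the new interface along these lines (a hypothetical sketch, not the actual server code; it assumes the llama.cpp/mtmd headers and that `server_tokens` exposes `size()` and `operator[]` as used elsewhere in this file):

```cpp
// Walk a prompt, evaluating media chunks via process_chunk() and advancing by the
// number of token cells they occupy; text tokens would be batched at get_pos(i).
static int32_t eval_prompt(llama_context * ctx, mtmd_context * mctx,
                           server_tokens & prompt, int32_t seq_id) {
    for (size_t i = 0; i < prompt.size(); ) {
        if (prompt[i] == LLAMA_TOKEN_NULL) {
            // a media chunk starts at index i
            size_t n_tokens_out = 0;
            int32_t ret = prompt.process_chunk(ctx, mctx, i, seq_id, n_tokens_out);
            if (ret != 0) {
                return ret;
            }
            i += n_tokens_out; // skip every cell occupied by the chunk
        } else {
            // text token: would be added to a llama_batch at position prompt.get_pos(i)
            // (batch construction omitted in this sketch)
            i += 1;
        }
    }
    return 0;
}
```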