@@ -1080,19 +1080,21 @@ struct server_tokens {
10801080
10811081private:  //  disallow accessing these members directly, risking out-of-sync
10821082
1083-     //  map a **start** position in tokens to the image chunk
1084-     std::unordered_map<llama_pos, mtmd::input_chunk_ptr> map_pos_to_media;
1083+     //  map a **start** index in tokens to the image chunk
1084+     //  note: the order need to be in-sync with tokens
1085+     std::map<size_t , mtmd::input_chunk_ptr> map_idx_to_media;
10851086
10861087    //  list of tokens
1087-     //  it can include LLAMA_TOKEN_NULL, which  is used to indicate a token that  is not a text token 
1088-     //  a mtmd_input_chunk can occupy multiple tokens, one llama_token per **position** 
1089-     //  important: for models using mrope, an image  can contain  multiple tokens but will use only one **position** 
1088+     //    if the token  is LLAMA_TOKEN_NULL, it indicates that this position  is occupied by media chunk 
1089+     //    otherwise, it is a normal text token 
1090+     //  note: a non-text chunk  can occupy  multiple tokens (aka memory cells) in the token list 
10901091    llama_tokens tokens;
10911092
1092-     //  for ex. with input of 5 text tokens and 2 images:
1093-     //       [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1]
1094-     //  pos  0   1   2   3   4   5      6      7      8      9
1095-     //  map_pos_to_media will contain: {5, img0}, {8, img1}
1093+     //  for ex. with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos):
1094+     //       [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1]
1095+     //  idx  0   1   2   3   4   5      6      7      8      9      10
1096+     //  pos  0   1   2   3   4   5      5      5      7      7      7
1097+     //  map_idx_to_media will contain: {5, img0}, {8, img1}
10961098
10971099public: 
10981100    server_tokens () = default ;
@@ -1117,30 +1119,48 @@ struct server_tokens {
11171119        }
11181120    }
11191121
1120-     server_tokens (const  llama_tokens & tokens, bool  has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {}
1122+     server_tokens (const  llama_tokens & tokens, bool  has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {
1123+     }
1124+ 
1125+     llama_pos pos_next () const  {
1126+         if  (!has_mtmd) {
1127+             return  tokens.size ();
1128+         }
1129+ 
1130+         llama_pos res = tokens.size ();
1131+ 
1132+         for  (auto  it = map_idx_to_media.begin (); it != map_idx_to_media.end (); ++it) {
1133+             const  auto  & chunk = it->second ;
1134+             res += mtmd_input_chunk_get_n_pos (chunk.get ()) - mtmd_input_chunk_get_n_tokens (chunk.get ());
1135+         }
1136+ 
1137+         return  res;
1138+     }
11211139
11221140    //  for debugging
11231141    std::string str () const  {
11241142        std::ostringstream oss;
11251143        oss << " tokens: "  ;
1126-         for  (const  auto  & t : tokens) {
1144+         for  (size_t  idx = 0 ; idx < tokens.size (); ++idx) {
1145+             llama_token t = tokens[idx];
1146+             oss << " idx:"   << idx << "  "  ;
11271147            if  (t == LLAMA_TOKEN_NULL) {
11281148                oss << " <embd> "  ;
11291149            } else  {
11301150                oss << t << "  "  ;
11311151            }
11321152        }
11331153        oss << " \n "  ;
1134-         oss << " image pos : "  ;
1135-         for  (const  auto  & it : map_pos_to_media ) {
1154+         oss << " image idx : "  ;
1155+         for  (const  auto  & it : map_idx_to_media ) {
11361156            oss << it.first  << " , "  ;
11371157        }
11381158        return  oss.str ();
11391159    }
11401160
1141-     const  mtmd::input_chunk_ptr & find_chunk (llama_pos pos ) const  {
1142-         auto  it = map_pos_to_media .find (pos );
1143-         if  (it != map_pos_to_media .end ()) {
1161+     const  mtmd::input_chunk_ptr & find_chunk (size_t  idx ) const  {
1162+         auto  it = map_idx_to_media .find (idx );
1163+         if  (it != map_idx_to_media .end ()) {
11441164            return  it->second ;
11451165        }
11461166        throw  std::runtime_error (" Chunk not found"  );
@@ -1158,13 +1178,13 @@ struct server_tokens {
11581178        auto  type = mtmd_input_chunk_get_type (chunk);
11591179        if  (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
11601180            GGML_ASSERT (has_mtmd);
1161-             const  int  n_pos  = mtmd_input_chunk_get_n_pos (chunk);
1162-             llama_pos start_pos  = tokens.size ();
1163-             for  (int  i = 0 ; i < n_pos ; ++i) {
1181+             const  size_t  n_tokens  = mtmd_input_chunk_get_n_tokens (chunk);
1182+             size_t  start_idx  = tokens.size ();
1183+             for  (size_t  i = 0 ; i < n_tokens ; ++i) {
11641184                tokens.emplace_back (LLAMA_TOKEN_NULL);
11651185            }
11661186            mtmd::input_chunk_ptr new_chunk (mtmd_input_chunk_copy (chunk));
1167-             map_pos_to_media[start_pos ] = std::move (new_chunk);
1187+             map_idx_to_media[start_idx ] = std::move (new_chunk);
11681188        } else  if  (type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
11691189            size_t  n_tokens;
11701190            const  auto  * text_tokens = mtmd_input_chunk_get_tokens_text (chunk, &n_tokens);
@@ -1178,18 +1198,18 @@ struct server_tokens {
11781198
11791199    //  appends server tokens, updates the media map. copies media chunks.
11801200    void  push_back (server_tokens & tokens) {
1181-         size_t  start_pos  = size ();
1201+         size_t  start_idx  = size ();
11821202        for  (size_t  i = 0 ; i < tokens.size (); i++) {
11831203            push_back (tokens[i]);
11841204        }
11851205        if  (tokens.has_mtmd ) {
11861206            //  Assert if we are copying MTMD chunks to a server_tokens that does not have mtmd.
11871207            //  We could also just check, but this will prevent silently dropping MTMD data.
11881208            GGML_ASSERT (has_mtmd);
1189-             for  (auto  it = tokens.map_pos_to_media .begin (); it != tokens.map_pos_to_media .end (); ) {
1190-                 auto  * chunk = tokens.map_pos_to_media [it->first ].get ();
1209+             for  (auto  it = tokens.map_idx_to_media .begin (); it != tokens.map_idx_to_media .end (); ) {
1210+                 auto  * chunk = tokens.map_idx_to_media [it->first ].get ();
11911211                mtmd::input_chunk_ptr new_chunk (mtmd_input_chunk_copy (chunk));
1192-                 map_pos_to_media[start_pos +it->first ] = std::move (new_chunk);
1212+                 map_idx_to_media[start_idx +it->first ] = std::move (new_chunk);
11931213            }
11941214        }
11951215    }
@@ -1245,10 +1265,10 @@ struct server_tokens {
12451265                }
12461266            }
12471267            //  remove all image chunks that are not used anymore
1248-             for  (auto  it = map_pos_to_media .begin (); it != map_pos_to_media .end (); ) {
1249-                 llama_pos pos  = it->first ;
1250-                 if  (pos  >= (llama_pos) n) {
1251-                     it = map_pos_to_media .erase (it);
1268+             for  (auto  it = map_idx_to_media .begin (); it != map_idx_to_media .end (); ) {
1269+                 size_t  idx  = it->first ;
1270+                 if  (idx  >= n) {
1271+                     it = map_idx_to_media .erase (it);
12521272                } else  {
12531273                    ++it;
12541274                }
@@ -1296,12 +1316,12 @@ struct server_tokens {
12961316                const  std::string id_ai = mtmd_input_chunk_get_id (a_chunk.get ());
12971317                const  std::string id_bi = mtmd_input_chunk_get_id (b_chunk.get ());
12981318
1299-                 const  size_t  pos_a  = mtmd_input_chunk_get_n_pos (a_chunk.get ());
1300-                 const  size_t  pos_b  = mtmd_input_chunk_get_n_pos (b_chunk.get ());
1319+                 const  size_t  n_tok_a  = mtmd_input_chunk_get_n_tokens (a_chunk.get ());
1320+                 const  size_t  n_tok_b  = mtmd_input_chunk_get_n_tokens (b_chunk.get ());
13011321
1302-                 if  (id_ai == id_bi && pos_a  == pos_b ) {
1303-                     GGML_ASSERT (pos_a  > 0  && " Invalid media chunk"  ); //  should never happen
1304-                     i += pos_a  - 1 ; //  will be +1 by the for loop
1322+                 if  (id_ai == id_bi && n_tok_a  == n_tok_b ) {
1323+                     GGML_ASSERT (n_tok_a  > 0  && " Invalid media chunk"  ); //  should never happen
1324+                     i += n_tok_a  - 1 ; //  will be +1 by the for loop
13051325                    continue ;
13061326                }
13071327
@@ -1329,8 +1349,8 @@ struct server_tokens {
13291349            if  (t == LLAMA_TOKEN_NULL) {
13301350                try  {
13311351                    const  auto  & chunk = find_chunk (i);
1332-                     size_t  n_pos  = mtmd_input_chunk_get_n_pos (chunk.get ());
1333-                     i += n_pos  - 1 ; //  will be +1 by the for loop
1352+                     size_t  n_tokens  = mtmd_input_chunk_get_n_tokens (chunk.get ());
1353+                     i += n_tokens  - 1 ; //  will be +1 by the for loop
13341354                } catch  (const  std::exception & e) {
13351355                    return  false ;
13361356                }
@@ -1345,16 +1365,17 @@ struct server_tokens {
13451365    int32_t  process_chunk (
13461366                llama_context * ctx,
13471367                mtmd_context * mctx,
1368+                 size_t  idx,
13481369                llama_pos n_past,
13491370                int32_t  seq_id,
1350-                 llama_pos  & n_pos_out ) const  {
1351-         const  auto  & chunk = find_chunk (n_past );
1371+                 size_t  & n_tokens_out ) const  {
1372+         const  auto  & chunk = find_chunk (idx );
13521373        const  char  * name = mtmd_input_chunk_get_type (chunk.get ()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
13531374                            ? " image"   : " audio"  ;
13541375        SRV_INF (" processing %s...\n "  , name);
13551376        int32_t  n_batch = llama_n_batch (ctx);
13561377        int64_t  t0 = ggml_time_ms ();
1357-         llama_pos new_n_past = n_past; 
1378+         llama_pos new_n_past;  //  unused for now 
13581379        int32_t  result = mtmd_helper_eval_chunk_single (mctx, ctx,
13591380            chunk.get (),
13601381            n_past,
@@ -1365,10 +1386,10 @@ struct server_tokens {
13651386        SRV_INF (" %s processed in %"   PRId64 "  ms\n "  , name, ggml_time_ms () - t0);
13661387        if  (result != 0 ) {
13671388            LOG_ERR (" mtmd_helper_eval failed with status %d"  , result);
1368-             n_pos_out  = n_past ;
1389+             n_tokens_out  = 0 ;
13691390            return  result;
13701391        }
1371-         n_pos_out  = new_n_past ;
1392+         n_tokens_out  = mtmd_input_chunk_get_n_tokens (chunk. get ()) ;
13721393        return  0 ;
13731394    }
13741395};
0 commit comments