Commit 1e1b4af

add server_token::pos
1 parent 7e60d1c commit 1e1b4af

2 files changed: +93 -48 lines changed

tools/server/server.cpp

Lines changed: 8 additions & 4 deletions
@@ -3926,16 +3926,16 @@ struct server_context {
             // check if we should process the image
             if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
                 // process the image
-                int32_t new_n_past;
-                int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.id, new_n_past);
+                size_t n_tokens_out = 0;
+                int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.id, n_tokens_out);
                 if (res != 0) {
                     SLT_ERR(slot, "failed to process image, res = %d\n", res);
                     send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
                     slot.release();
                     continue;
                 }

-                slot.n_prompt_tokens_processed += new_n_past - slot.prompt.n_tokens();
+                slot.n_prompt_tokens_processed += n_tokens_out;

                 // add the image chunk to cache
                 {
@@ -3992,7 +3992,11 @@ struct server_context {
             }

             // embedding requires all tokens in the batch to be output
-            common_batch_add(batch, cur_tok, slot.prompt.n_tokens(), { slot.id }, slot.need_embd());
+            common_batch_add(batch,
+                             cur_tok,
+                             input_tokens.get_pos(slot.prompt.n_tokens()),
+                             { slot.id },
+                             slot.need_embd());
             slot.prompt.tokens.push_back(cur_tok);

             slot.n_prompt_tokens_processed++;
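The position argument changes because, for M-RoPE models, a media chunk can occupy more token cells than positions, so the cell index slot.prompt.n_tokens() no longer doubles as a llama_pos. A minimal standalone sketch of the divergence (hypothetical code, not part of the commit; the values follow the example documented in utils.hpp below):

    #include <cstdio>
    #include <vector>

    int main() {
        // 5 text tokens followed by two images; each image spans 3 token
        // cells but advances the position counter by only 2 (M-RoPE style)
        std::vector<int> pos = {0, 1, 2, 3, 4, /*img0*/ 5, 5, 5, /*img1*/ 7, 7, 7};
        for (size_t idx = 0; idx < pos.size(); ++idx) {
            // before this commit the batch was filled with `idx` as the position;
            // the two values diverge from the first image cell onwards
            std::printf("idx %zu -> pos %d%s\n", idx, pos[idx],
                        (int)idx != pos[idx] ? "  (diverged)" : "");
        }
        return 0;
    }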

tools/server/utils.hpp

Lines changed: 85 additions & 44 deletions
@@ -1080,19 +1080,25 @@ struct server_tokens {

 private: // disallow accessing these members directly, risking out-of-sync

-    // map a **start** position in tokens to the image chunk
-    std::unordered_map<llama_pos, mtmd::input_chunk_ptr> map_pos_to_media;
+    // map a **start** index in tokens to the image chunk
+    // note: the order need to be in-sync with tokens and pos
+    std::map<size_t, mtmd::input_chunk_ptr> map_idx_to_media;

     // list of tokens
-    // it can include LLAMA_TOKEN_NULL, which is used to indicate a token that is not a text token
-    // a mtmd_input_chunk can occupy multiple tokens, one llama_token per **position**
-    // important: for models using mrope, an image can contain multiple tokens but will use only one **position**
+    // if the token is LLAMA_TOKEN_NULL, it indicates that this position is occupied by media chunk
+    // otherwise, it is a normal text token
+    // note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list
     llama_tokens tokens;

-    // for ex. with input of 5 text tokens and 2 images:
-    // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1]
-    // pos 0   1   2   3   4    5      6      7      8      9
-    // map_pos_to_media will contain: {5, img0}, {8, img1}
+    // the position per-token (llama_pos) in the overall input
+    // useful for M-RoPE, where the position is different from the index in tokens
+    std::vector<llama_pos> pos;
+
+    // for ex. with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos):
+    // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1]
+    // idx 0   1   2   3   4    5      6      7      8      9     10
+    // pos 0   1   2   3   4    5      5      5      7      7      7
+    // map_idx_to_media will contain: {5, img0}, {8, img1}

 public:
     server_tokens() = default;
@@ -1117,30 +1123,57 @@ struct server_tokens {
         }
     }

-    server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {}
+    server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {
+        for (llama_pos i = 0; i < (llama_pos)tokens.size(); ++i) {
+            pos.push_back(i);
+        }
+    }
+
+    llama_pos next_pos() const {
+        if (tokens.empty()) {
+            return 0;
+        } else if (tokens.back() != LLAMA_TOKEN_NULL) {
+            return pos.back() + 1;
+        } else {
+            printf("%s", str().c_str());
+            // find the last media chunk
+            GGML_ASSERT(has_mtmd);
+            GGML_ASSERT(!map_idx_to_media.empty());
+            const auto & chunk = map_idx_to_media.rbegin()->second;
+            return pos.back() + mtmd_input_chunk_get_n_pos(chunk.get());
+        }
+    }
+
+    llama_pos get_pos(size_t idx) const {
+        GGML_ASSERT(idx < pos.size());
+        return pos[idx];
+    }

     // for debugging
     std::string str() const {
         std::ostringstream oss;
         oss << "tokens: ";
-        for (const auto & t : tokens) {
+        for (size_t idx = 0; idx < tokens.size(); ++idx) {
+            llama_token t = tokens[idx];
+            llama_pos p = pos[idx];
+            oss << "idx:" << idx << " ";
             if (t == LLAMA_TOKEN_NULL) {
-                oss << "<embd> ";
+                oss << "<embd>(" << p << ")\n";
             } else {
-                oss << t << " ";
+                oss << t << "(" << p << ")\n";
             }
         }
         oss << "\n";
-        oss << "image pos: ";
-        for (const auto & it : map_pos_to_media) {
+        oss << "image idx: ";
+        for (const auto & it : map_idx_to_media) {
            oss << it.first << ", ";
         }
         return oss.str();
     }

-    const mtmd::input_chunk_ptr & find_chunk(llama_pos pos) const {
-        auto it = map_pos_to_media.find(pos);
-        if (it != map_pos_to_media.end()) {
+    const mtmd::input_chunk_ptr & find_chunk(size_t idx) const {
+        auto it = map_idx_to_media.find(idx);
+        if (it != map_idx_to_media.end()) {
             return it->second;
         }
         throw std::runtime_error("Chunk not found");
@@ -1150,6 +1183,7 @@ struct server_tokens {
         if (tok == LLAMA_TOKEN_NULL) {
             throw std::runtime_error("Invalid token");
         }
+        pos.emplace_back(next_pos());
         tokens.emplace_back(tok);
     }

@@ -1158,13 +1192,15 @@ struct server_tokens {
         auto type = mtmd_input_chunk_get_type(chunk);
         if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
             GGML_ASSERT(has_mtmd);
-            const int n_pos = mtmd_input_chunk_get_n_pos(chunk);
-            llama_pos start_pos = tokens.size();
-            for (int i = 0; i < n_pos; ++i) {
+            const size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk);
+            const llama_pos cur_pos = next_pos();
+            size_t start_idx = tokens.size();
+            for (size_t i = 0; i < n_tokens; ++i) {
+                pos.emplace_back(cur_pos);
                 tokens.emplace_back(LLAMA_TOKEN_NULL);
             }
             mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
-            map_pos_to_media[start_pos] = std::move(new_chunk);
+            map_idx_to_media[start_idx] = std::move(new_chunk);
         } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
             size_t n_tokens;
             const auto * text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
@@ -1178,18 +1214,18 @@ struct server_tokens {

     // appends server tokens, updates the media map. copies media chunks.
     void push_back(server_tokens & tokens) {
-        size_t start_pos = size();
+        size_t start_idx = size();
         for (size_t i = 0; i < tokens.size(); i++) {
             push_back(tokens[i]);
         }
         if (tokens.has_mtmd) {
             // Assert if we are copying MTMD chunks to a server_tokens that does not have mtmd.
             // We could also just check, but this will prevent silently dropping MTMD data.
             GGML_ASSERT(has_mtmd);
-            for (auto it = tokens.map_pos_to_media.begin(); it != tokens.map_pos_to_media.end(); ) {
-                auto * chunk = tokens.map_pos_to_media[it->first].get();
+            for (auto it = tokens.map_idx_to_media.begin(); it != tokens.map_idx_to_media.end(); ) {
+                auto * chunk = tokens.map_idx_to_media[it->first].get();
                 mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
-                map_pos_to_media[start_pos+it->first] = std::move(new_chunk);
+                map_idx_to_media[start_idx+it->first] = std::move(new_chunk);
             }
         }
     }
@@ -1198,6 +1234,11 @@ struct server_tokens {
     void insert(const llama_tokens & inp_tokens) {
         GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
         tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end());
+        // rebuild the pos vector
+        pos.clear();
+        for (llama_pos i = 0; i < (llama_pos)tokens.size(); ++i) {
+            pos.emplace_back(i);
+        }
     }

     // for compatibility with speculative decoding, ctx shift, slot save/load
@@ -1245,10 +1286,10 @@ struct server_tokens {
         }
     }
     // remove all image chunks that are not used anymore
-    for (auto it = map_pos_to_media.begin(); it != map_pos_to_media.end(); ) {
-        llama_pos pos = it->first;
-        if (pos >= (llama_pos)n) {
-            it = map_pos_to_media.erase(it);
+    for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ) {
+        size_t idx = it->first;
+        if (idx >= n) {
+            it = map_idx_to_media.erase(it);
         } else {
             ++it;
         }
@@ -1296,12 +1337,12 @@ struct server_tokens {
         const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get());
         const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get());

-        const size_t pos_a = mtmd_input_chunk_get_n_pos(a_chunk.get());
-        const size_t pos_b = mtmd_input_chunk_get_n_pos(b_chunk.get());
+        const size_t n_tok_a = mtmd_input_chunk_get_n_tokens(a_chunk.get());
+        const size_t n_tok_b = mtmd_input_chunk_get_n_tokens(b_chunk.get());

-        if (id_ai == id_bi && pos_a == pos_b) {
-            GGML_ASSERT(pos_a > 0 && "Invalid media chunk"); // should never happen
-            i += pos_a - 1; // will be +1 by the for loop
+        if (id_ai == id_bi && n_tok_a == n_tok_b) {
+            GGML_ASSERT(n_tok_a > 0 && "Invalid media chunk"); // should never happen
+            i += n_tok_a - 1; // will be +1 by the for loop
             continue;
         }

@@ -1329,8 +1370,8 @@ struct server_tokens {
         if (t == LLAMA_TOKEN_NULL) {
             try {
                 const auto & chunk = find_chunk(i);
-                size_t n_pos = mtmd_input_chunk_get_n_pos(chunk.get());
-                i += n_pos - 1; // will be +1 by the for loop
+                size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk.get());
+                i += n_tokens - 1; // will be +1 by the for loop
             } catch (const std::exception & e) {
                 return false;
             }
@@ -1345,30 +1386,30 @@ struct server_tokens {
     int32_t process_chunk(
             llama_context * ctx,
             mtmd_context * mctx,
-            llama_pos n_past,
+            size_t idx,
             int32_t seq_id,
-            llama_pos & n_pos_out) const {
-        const auto & chunk = find_chunk(n_past);
+            size_t & n_tokens_out) const {
+        const auto & chunk = find_chunk(idx);
         const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
             ? "image" : "audio";
         SRV_INF("processing %s...\n", name);
         int32_t n_batch = llama_n_batch(ctx);
         int64_t t0 = ggml_time_ms();
-        llama_pos new_n_past = n_past;
+        llama_pos new_n_past; // unused for now
         int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx,
                 chunk.get(),
-                n_past,
+                pos[idx], // position
                 seq_id,
                 n_batch,
                 true, // logits last
                 &new_n_past);
         SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0);
         if (result != 0) {
             LOG_ERR("mtmd_helper_eval failed with status %d", result);
-            n_pos_out = n_past;
+            n_tokens_out = 0;
             return result;
         }
-        n_pos_out = new_n_past;
+        n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get());
         return 0;
     }
 };
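The invariant the commit establishes: tokens, pos, and map_idx_to_media stay index-aligned; a text token gets pos = next_pos(), while a media chunk stamps one shared start position into every cell it occupies and advances future positions by its n_pos rather than its n_tokens. A minimal standalone sketch of that bookkeeping, using toy types instead of the llama.cpp/mtmd API, with assumed chunk sizes of 3 tokens / 2 positions per image:

    #include <cassert>
    #include <cstdio>
    #include <map>
    #include <utility>
    #include <vector>

    using llama_token = int;
    using llama_pos   = int;
    constexpr llama_token TOKEN_NULL = -1; // stands in for LLAMA_TOKEN_NULL

    struct toy_server_tokens {
        std::vector<llama_token> tokens;
        std::vector<llama_pos>   pos;
        // start index -> (n_tokens, n_pos) of a media chunk, like map_idx_to_media
        std::map<size_t, std::pair<size_t, llama_pos>> media;

        llama_pos next_pos() const {
            if (tokens.empty())              return 0;
            if (tokens.back() != TOKEN_NULL) return pos.back() + 1;
            // trailing media chunk: advance by its position count, not its token count
            assert(!media.empty());
            return pos.back() + media.rbegin()->second.second;
        }

        void push_text(llama_token tok) {
            pos.push_back(next_pos());
            tokens.push_back(tok);
        }

        void push_media(size_t n_tokens, llama_pos n_pos) {
            const llama_pos cur_pos = next_pos();
            media[tokens.size()] = {n_tokens, n_pos};
            for (size_t i = 0; i < n_tokens; ++i) {
                pos.push_back(cur_pos); // every cell of the chunk shares the start position
                tokens.push_back(TOKEN_NULL);
            }
        }
    };

    int main() {
        toy_server_tokens t;
        for (int i = 0; i < 5; ++i) t.push_text(100 + i); // 5 text tokens
        t.push_media(3, 2);                               // img0
        t.push_media(3, 2);                               // img1
        for (size_t idx = 0; idx < t.tokens.size(); ++idx) {
            std::printf("idx %zu  pos %d\n", idx, t.pos[idx]);
        }
        // prints pos 0 1 2 3 4 5 5 5 7 7 7, matching the example in the diff
        return 0;
    }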
