Skip to content

Commit beb5c03

Browse files
ggerganov and ngxson
committed
server : add server_tokens::pos_next()
Co-authored-by: Xuan-Son Nguyen <[email protected]>
1 parent 278fbdd commit beb5c03

File tree

2 files changed

+72
-47
lines changed

2 files changed

+72
-47
lines changed

tools/server/server.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3623,7 +3623,7 @@ struct server_context {
36233623

36243624
slot.i_batch = batch.n_tokens;
36253625

3626-
common_batch_add(batch, slot.sampled, slot.prompt.n_tokens(), { slot.id }, true);
3626+
common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
36273627

36283628
slot.prompt.tokens.push_back(slot.sampled);
36293629

@@ -3926,16 +3926,16 @@ struct server_context {
39263926
// check if we should process the image
39273927
if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
39283928
// process the image
3929-
int32_t new_n_past;
3930-
int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.id, new_n_past);
3929+
size_t n_tokens_out = 0;
3930+
int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
39313931
if (res != 0) {
39323932
SLT_ERR(slot, "failed to process image, res = %d\n", res);
39333933
send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
39343934
slot.release();
39353935
continue;
39363936
}
39373937

3938-
slot.n_prompt_tokens_processed += new_n_past - slot.prompt.n_tokens();
3938+
slot.n_prompt_tokens_processed += n_tokens_out;
39393939

39403940
// add the image chunk to cache
39413941
{
@@ -3992,7 +3992,11 @@ struct server_context {
39923992
}
39933993

39943994
// embedding requires all tokens in the batch to be output
3995-
common_batch_add(batch, cur_tok, slot.prompt.n_tokens(), { slot.id }, slot.need_embd());
3995+
common_batch_add(batch,
3996+
cur_tok,
3997+
slot.prompt.tokens.pos_next(),
3998+
{ slot.id },
3999+
slot.need_embd());
39964000
slot.prompt.tokens.push_back(cur_tok);
39974001

39984002
slot.n_prompt_tokens_processed++;
@@ -4287,10 +4291,10 @@ struct server_context {
42874291

42884292
// construct the speculation batch
42894293
common_batch_clear(slot.batch_spec);
4290-
common_batch_add (slot.batch_spec, id, slot.prompt.n_tokens(), { slot.id }, true);
4294+
common_batch_add (slot.batch_spec, id, slot.prompt.tokens.pos_next(), { slot.id }, true);
42914295

42924296
for (size_t i = 0; i < draft.size(); ++i) {
4293-
common_batch_add(slot.batch_spec, draft[i], slot.prompt.n_tokens() + 1 + i, { slot.id }, true);
4297+
common_batch_add(slot.batch_spec, draft[i], slot.prompt.tokens.pos_next() + 1 + i, { slot.id }, true);
42944298
}
42954299

42964300
SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);

tools/server/utils.hpp

Lines changed: 61 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,19 +1080,21 @@ struct server_tokens {
10801080

10811081
private: // disallow accessing these members directly, risking out-of-sync
10821082

1083-
// map a **start** position in tokens to the image chunk
1084-
std::unordered_map<llama_pos, mtmd::input_chunk_ptr> map_pos_to_media;
1083+
// map a **start** index in tokens to the image chunk
1084+
// note: the order need to be in-sync with tokens
1085+
std::map<size_t, mtmd::input_chunk_ptr> map_idx_to_media;
10851086

10861087
// list of tokens
1087-
// it can include LLAMA_TOKEN_NULL, which is used to indicate a token that is not a text token
1088-
// a mtmd_input_chunk can occupy multiple tokens, one llama_token per **position**
1089-
// important: for models using mrope, an image can contain multiple tokens but will use only one **position**
1088+
// if the token is LLAMA_TOKEN_NULL, it indicates that this position is occupied by media chunk
1089+
// otherwise, it is a normal text token
1090+
// note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list
10901091
llama_tokens tokens;
10911092

1092-
// for ex. with input of 5 text tokens and 2 images:
1093-
// [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1]
1094-
// pos 0 1 2 3 4 5 6 7 8 9
1095-
// map_pos_to_media will contain: {5, img0}, {8, img1}
1093+
// for ex. with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos):
1094+
// [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1]
1095+
// idx 0 1 2 3 4 5 6 7 8 9 10
1096+
// pos 0 1 2 3 4 5 5 5 7 7 7
1097+
// map_idx_to_media will contain: {5, img0}, {8, img1}
10961098

10971099
public:
10981100
server_tokens() = default;
@@ -1117,30 +1119,48 @@ struct server_tokens {
11171119
}
11181120
}
11191121

1120-
server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {}
1122+
// construct from a plain token list; media map stays empty
server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {}
1124+
1125+
// next position to be used by the KV cache
// without mtmd this is simply the token count; with mtmd, a media chunk can
// occupy a different number of positions than tokens (e.g. mrope), so the
// per-chunk delta (n_pos - n_tokens) is added on top of the token count
llama_pos pos_next() const {
    llama_pos next = tokens.size();

    if (!has_mtmd) {
        return next;
    }

    for (const auto & entry : map_idx_to_media) {
        const auto * chunk = entry.second.get();
        next += mtmd_input_chunk_get_n_pos(chunk) - mtmd_input_chunk_get_n_tokens(chunk);
    }

    return next;
}
11211139

11221140
// for debugging
11231141
std::string str() const {
11241142
std::ostringstream oss;
11251143
oss << "tokens: ";
1126-
for (const auto & t : tokens) {
1144+
for (size_t idx = 0; idx < tokens.size(); ++idx) {
1145+
llama_token t = tokens[idx];
1146+
oss << "idx:" << idx << " ";
11271147
if (t == LLAMA_TOKEN_NULL) {
11281148
oss << "<embd> ";
11291149
} else {
11301150
oss << t << " ";
11311151
}
11321152
}
11331153
oss << "\n";
1134-
oss << "image pos: ";
1135-
for (const auto & it : map_pos_to_media) {
1154+
oss << "image idx: ";
1155+
for (const auto & it : map_idx_to_media) {
11361156
oss << it.first << ", ";
11371157
}
11381158
return oss.str();
11391159
}
11401160

1141-
const mtmd::input_chunk_ptr & find_chunk(llama_pos pos) const {
1142-
auto it = map_pos_to_media.find(pos);
1143-
if (it != map_pos_to_media.end()) {
1161+
const mtmd::input_chunk_ptr & find_chunk(size_t idx) const {
1162+
auto it = map_idx_to_media.find(idx);
1163+
if (it != map_idx_to_media.end()) {
11441164
return it->second;
11451165
}
11461166
throw std::runtime_error("Chunk not found");
@@ -1158,13 +1178,13 @@ struct server_tokens {
11581178
auto type = mtmd_input_chunk_get_type(chunk);
11591179
if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
11601180
GGML_ASSERT(has_mtmd);
1161-
const int n_pos = mtmd_input_chunk_get_n_pos(chunk);
1162-
llama_pos start_pos = tokens.size();
1163-
for (int i = 0; i < n_pos; ++i) {
1181+
const size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk);
1182+
size_t start_idx = tokens.size();
1183+
for (size_t i = 0; i < n_tokens; ++i) {
11641184
tokens.emplace_back(LLAMA_TOKEN_NULL);
11651185
}
11661186
mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
1167-
map_pos_to_media[start_pos] = std::move(new_chunk);
1187+
map_idx_to_media[start_idx] = std::move(new_chunk);
11681188
} else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
11691189
size_t n_tokens;
11701190
const auto * text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
@@ -1178,18 +1198,18 @@ struct server_tokens {
11781198

11791199
// appends server tokens, updates the media map. copies media chunks.
11801200
void push_back(server_tokens & tokens) {
    // remember where the appended region starts, so the copied media
    // chunks can be re-keyed relative to this container
    size_t start_idx = size();
    for (size_t i = 0; i < tokens.size(); i++) {
        push_back(tokens[i]);
    }
    if (tokens.has_mtmd) {
        // Assert if we are copying MTMD chunks to a server_tokens that does not have mtmd.
        // We could also just check, but this will prevent silently dropping MTMD data.
        GGML_ASSERT(has_mtmd);
        // NOTE(review): the previous iterator loop never advanced `it`, which
        // spins forever as soon as the source map holds any media chunk;
        // a range-for cannot repeat that mistake
        for (const auto & entry : tokens.map_idx_to_media) {
            mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(entry.second.get()));
            map_idx_to_media[start_idx + entry.first] = std::move(new_chunk);
        }
    }
}
@@ -1245,10 +1265,10 @@ struct server_tokens {
12451265
}
12461266
}
12471267
// remove all image chunks that are not used anymore
1248-
for (auto it = map_pos_to_media.begin(); it != map_pos_to_media.end(); ) {
1249-
llama_pos pos = it->first;
1250-
if (pos >= (llama_pos)n) {
1251-
it = map_pos_to_media.erase(it);
1268+
for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ) {
1269+
size_t idx = it->first;
1270+
if (idx >= n) {
1271+
it = map_idx_to_media.erase(it);
12521272
} else {
12531273
++it;
12541274
}
@@ -1296,12 +1316,12 @@ struct server_tokens {
12961316
const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get());
12971317
const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get());
12981318

1299-
const size_t pos_a = mtmd_input_chunk_get_n_pos(a_chunk.get());
1300-
const size_t pos_b = mtmd_input_chunk_get_n_pos(b_chunk.get());
1319+
const size_t n_tok_a = mtmd_input_chunk_get_n_tokens(a_chunk.get());
1320+
const size_t n_tok_b = mtmd_input_chunk_get_n_tokens(b_chunk.get());
13011321

1302-
if (id_ai == id_bi && pos_a == pos_b) {
1303-
GGML_ASSERT(pos_a > 0 && "Invalid media chunk"); // should never happen
1304-
i += pos_a - 1; // will be +1 by the for loop
1322+
if (id_ai == id_bi && n_tok_a == n_tok_b) {
1323+
GGML_ASSERT(n_tok_a > 0 && "Invalid media chunk"); // should never happen
1324+
i += n_tok_a - 1; // will be +1 by the for loop
13051325
continue;
13061326
}
13071327

@@ -1329,8 +1349,8 @@ struct server_tokens {
13291349
if (t == LLAMA_TOKEN_NULL) {
13301350
try {
13311351
const auto & chunk = find_chunk(i);
1332-
size_t n_pos = mtmd_input_chunk_get_n_pos(chunk.get());
1333-
i += n_pos - 1; // will be +1 by the for loop
1352+
size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk.get());
1353+
i += n_tokens - 1; // will be +1 by the for loop
13341354
} catch (const std::exception & e) {
13351355
return false;
13361356
}
@@ -1345,16 +1365,17 @@ struct server_tokens {
13451365
int32_t process_chunk(
13461366
llama_context * ctx,
13471367
mtmd_context * mctx,
1368+
size_t idx,
13481369
llama_pos n_past,
13491370
int32_t seq_id,
1350-
llama_pos & n_pos_out) const {
1351-
const auto & chunk = find_chunk(n_past);
1371+
size_t & n_tokens_out) const {
1372+
const auto & chunk = find_chunk(idx);
13521373
const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
13531374
? "image" : "audio";
13541375
SRV_INF("processing %s...\n", name);
13551376
int32_t n_batch = llama_n_batch(ctx);
13561377
int64_t t0 = ggml_time_ms();
1357-
llama_pos new_n_past = n_past;
1378+
llama_pos new_n_past; // unused for now
13581379
int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx,
13591380
chunk.get(),
13601381
n_past,
@@ -1365,10 +1386,10 @@ struct server_tokens {
13651386
SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0);
13661387
if (result != 0) {
13671388
LOG_ERR("mtmd_helper_eval failed with status %d", result);
1368-
n_pos_out = n_past;
1389+
n_tokens_out = 0;
13691390
return result;
13701391
}
1371-
n_pos_out = new_n_past;
1392+
n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get());
13721393
return 0;
13731394
}
13741395
};

0 commit comments

Comments
 (0)