
Commit 80e3672

server : remove pos state from server_tokens
1 parent: 610cab4

File tree

2 files changed: +21 −40 lines


tools/server/server.cpp

Lines changed: 5 additions & 5 deletions
@@ -3623,7 +3623,7 @@ struct server_context {
 
             slot.i_batch = batch.n_tokens;
 
-            common_batch_add(batch, slot.sampled, slot.prompt.n_tokens(), { slot.id }, true);
+            common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
 
             slot.prompt.tokens.push_back(slot.sampled);
 
@@ -3927,7 +3927,7 @@ struct server_context {
             if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
                 // process the image
                 size_t n_tokens_out = 0;
-                int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.id, n_tokens_out);
+                int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
                 if (res != 0) {
                     SLT_ERR(slot, "failed to process image, res = %d\n", res);
                     send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
@@ -3994,7 +3994,7 @@ struct server_context {
                 // embedding requires all tokens in the batch to be output
                 common_batch_add(batch,
                     cur_tok,
-                    input_tokens.get_pos(slot.prompt.n_tokens()),
+                    slot.prompt.tokens.pos_next(),
                     { slot.id },
                     slot.need_embd());
                 slot.prompt.tokens.push_back(cur_tok);
@@ -4291,10 +4291,10 @@ struct server_context {
 
             // construct the speculation batch
             common_batch_clear(slot.batch_spec);
-            common_batch_add (slot.batch_spec, id, slot.prompt.n_tokens(), { slot.id }, true);
+            common_batch_add (slot.batch_spec, id, slot.prompt.tokens.pos_next(), { slot.id }, true);
 
             for (size_t i = 0; i < draft.size(); ++i) {
-                common_batch_add(slot.batch_spec, draft[i], slot.prompt.n_tokens() + 1 + i, { slot.id }, true);
+                common_batch_add(slot.batch_spec, draft[i], slot.prompt.tokens.pos_next() + 1 + i, { slot.id }, true);
             }
 
             SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
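Every server.cpp call site makes the same substitution: the next position now comes from the token list itself via pos_next(), rather than from the token count or the removed per-token pos array. The two values coincide for text-only prompts but diverge once media chunks occupy a different number of positions than tokens (M-RoPE). Below is a toy illustration, not part of the commit, of how the speculation batch above lays out positions; the pos_next value (9, taken from the worked example further down) and the draft token ids are made up:

#include <cstdio>
#include <vector>

int main() {
    const int pos_next = 9;                       // stands in for slot.prompt.tokens.pos_next()
    const std::vector<int> draft = {42, 43, 44};  // hypothetical draft tokens

    // mirrors: common_batch_add(slot.batch_spec, id, pos_next, ...)
    printf("target token at pos %d\n", pos_next);
    for (size_t i = 0; i < draft.size(); ++i) {
        // mirrors: common_batch_add(slot.batch_spec, draft[i], pos_next + 1 + i, ...)
        printf("draft[%zu] at pos %d\n", i, pos_next + 1 + (int) i);
    }
    return 0;
}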

tools/server/utils.hpp

Lines changed: 16 additions & 35 deletions
@@ -1081,7 +1081,7 @@ struct server_tokens {
 private: // disallow accessing these members directly, risking out-of-sync
 
     // map a **start** index in tokens to the image chunk
-    // note: the order need to be in-sync with tokens and pos
+    // note: the order need to be in-sync with tokens
     std::map<size_t, mtmd::input_chunk_ptr> map_idx_to_media;
 
     // list of tokens
@@ -1090,10 +1090,6 @@ struct server_tokens {
     // note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list
     llama_tokens tokens;
 
-    // the position per-token (llama_pos) in the overall input
-    // useful for M-RoPE, where the position is different from the index in tokens
-    std::vector<llama_pos> pos;
-
     // for ex. with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos):
     //     [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1]
     // idx  0   1   2   3   4   5      6      7      8      9      10
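With the pos vector gone, nothing stores these positions explicitly; pos_next() recomputes the next position from the token count plus a per-media correction of n_pos − n_tokens. A worked check of the layout above, with the values assumed from the comment (5 text tokens, 2 images at 3 tokens / 2 positions each):

#include <cassert>

int main() {
    const int n_tokens   = 5 + 2 * 3;    // 11 entries in the token list
    const int correction = 2 * (2 - 3);  // per image: n_pos - n_tokens = -1
    const int pos_next   = n_tokens + correction;
    assert(pos_next == 9);               // text at pos 0..4, img0 at 5..6, img1 at 7..8
    return 0;
}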
@@ -1124,28 +1120,21 @@ struct server_tokens {
     }
 
     server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {
-        for (llama_pos i = 0; i < (llama_pos)tokens.size(); ++i) {
-            pos.push_back(i);
-        }
     }
 
-    llama_pos next_pos() const {
-        if (tokens.empty()) {
-            return 0;
-        } else if (tokens.back() != LLAMA_TOKEN_NULL) {
-            return pos.back() + 1;
-        } else {
-            // find the last media chunk
-            GGML_ASSERT(has_mtmd);
-            GGML_ASSERT(!map_idx_to_media.empty());
-            const auto & chunk = map_idx_to_media.rbegin()->second;
-            return pos.back() + mtmd_input_chunk_get_n_pos(chunk.get());
+    llama_pos pos_next() const {
+        if (!has_mtmd) {
+            return tokens.size();
         }
-    }
 
-    llama_pos get_pos(size_t idx) const {
-        GGML_ASSERT(idx < pos.size());
-        return pos[idx];
+        llama_pos res = tokens.size();
+
+        for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) {
+            const auto & chunk = it->second;
+            res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get());
+        }
+
+        return res;
     }
 
     // for debugging
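For readers without the surrounding class, here is a minimal, self-contained sketch of the new recomputation. The toy_tokens type and the plain-int media metadata are simplifications for illustration, not the commit's types (the real code uses llama_tokens and mtmd::input_chunk_ptr with mtmd_input_chunk_get_n_tokens / mtmd_input_chunk_get_n_pos):

#include <cstdio>
#include <map>
#include <utility>
#include <vector>

struct toy_tokens {
    std::vector<int> tokens;                      // -1 stands in for LLAMA_TOKEN_NULL
    std::map<size_t, std::pair<int, int>> media;  // start idx -> {n_tokens, n_pos}
    bool has_mtmd = true;

    int pos_next() const {
        if (!has_mtmd) {
            return (int) tokens.size();           // text-only: position == index
        }
        int res = (int) tokens.size();
        for (const auto & [idx, chunk] : media) {
            res += chunk.second - chunk.first;    // n_pos - n_tokens per media chunk
        }
        return res;
    }
};

int main() {
    toy_tokens t;
    t.tokens = {10, 11, 12, 13, 14};              // 5 text tokens
    for (int img = 0; img < 2; ++img) {           // 2 images: 3 tokens, 2 positions each
        t.media[t.tokens.size()] = {3, 2};
        t.tokens.insert(t.tokens.end(), 3, -1);
    }
    printf("pos_next = %d\n", t.pos_next());      // prints 9 (not 11)
    return 0;
}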
@@ -1154,12 +1143,11 @@ struct server_tokens {
         oss << "tokens: ";
         for (size_t idx = 0; idx < tokens.size(); ++idx) {
             llama_token t = tokens[idx];
-            llama_pos p = pos[idx];
             oss << "idx:" << idx << " ";
             if (t == LLAMA_TOKEN_NULL) {
-                oss << "<embd>(" << p << ")\n";
+                oss << "<embd> ";
             } else {
-                oss << t << "(" << p << ")\n";
+                oss << t << " ";
             }
         }
         oss << "\n";
@@ -1182,7 +1170,6 @@ struct server_tokens {
         if (tok == LLAMA_TOKEN_NULL) {
             throw std::runtime_error("Invalid token");
         }
-        pos.emplace_back(next_pos());
         tokens.emplace_back(tok);
     }
 
@@ -1192,10 +1179,8 @@ struct server_tokens {
         if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
             GGML_ASSERT(has_mtmd);
             const size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk);
-            const llama_pos cur_pos = next_pos();
             size_t start_idx = tokens.size();
             for (size_t i = 0; i < n_tokens; ++i) {
-                pos.emplace_back(cur_pos);
                 tokens.emplace_back(LLAMA_TOKEN_NULL);
             }
             mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
@@ -1233,11 +1218,6 @@ struct server_tokens {
     void insert(const llama_tokens & inp_tokens) {
         GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
         tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end());
-        // rebuild the pos vector
-        pos.clear();
-        for (llama_pos i = 0; i < (llama_pos)tokens.size(); ++i) {
-            pos.emplace_back(i);
-        }
     }
 
     // for compatibility with speculative decoding, ctx shift, slot save/load
@@ -1386,6 +1366,7 @@ struct server_tokens {
             llama_context * ctx,
             mtmd_context * mctx,
             size_t idx,
+            llama_pos n_past,
             int32_t seq_id,
             size_t & n_tokens_out) const {
         const auto & chunk = find_chunk(idx);
@@ -1397,7 +1378,7 @@ struct server_tokens {
         llama_pos new_n_past; // unused for now
         int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx,
                 chunk.get(),
-                pos[idx], // position
+                n_past,
                 seq_id,
                 n_batch,
                 true, // logits last
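The key distinction in the new process_chunk signature is that idx indexes into the token list while n_past is a position; the caller computes the latter via pos_next() instead of process_chunk reading the removed pos[idx]. A small illustration using the example layout from the comments above (values derived from that layout, not from the commit: img1 starts at token index 8 but at position 7):

#include <cstdio>

int main() {
    const size_t idx    = 5 + 3;  // start index of img1: 5 text tokens + 3 img0 token slots
    const int    n_past = 5 + 2;  // its position: 5 text positions + 2 img0 positions
    printf("chunk at idx %zu is evaluated at pos %d\n", idx, n_past);
    return 0;
}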
