
Commit 58100b3

refactor server_inp to server_tokens
1 parent 0f39770 commit 58100b3
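
For orientation, the renamed call sites in the diff below imply roughly the following shape for the new server_tokens wrapper. This is a sketch inferred only from usage in this commit: the member names txt, img and has_mtmd and the method names come from the changed lines, while the cell type, the image pointer type and all function bodies are assumptions that may differ from the committed header.

    // Sketch only -- reconstructed from call sites in this commit, not the committed header.
    #include <algorithm>
    #include <cstddef>
    #include <memory>
    #include <vector>

    #include "llama.h" // llama_token, LLAMA_TOKEN_NULL
    #include "mtmd.h"  // mtmd_image_tokens (opaque group of image tokens)

    // one position of the prompt/cache: either a text token or part of an image
    struct server_token_cell {
        llama_token txt = LLAMA_TOKEN_NULL;     // valid for text positions
        std::shared_ptr<mtmd_image_tokens> img; // set for image positions (pointer type assumed)
    };

    struct server_tokens {
        bool has_mtmd = false; // set when a multimodal (mtmd) context is attached
        std::vector<server_token_cell> cells;

        size_t size()  const { return cells.size();  }
        bool   empty() const { return cells.empty(); }
        void   clear()              { cells.clear();   }
        void   resize(size_t n)     { cells.resize(n); }
        void   keep_until(size_t n) { cells.resize(std::min(n, cells.size())); } // drop the non-common suffix

        server_token_cell       & operator[](size_t i)       { return cells[i]; }
        const server_token_cell & operator[](size_t i) const { return cells[i]; }

        void add_text_token(llama_token tok)      { cells.push_back({tok, nullptr}); }
        void add_token(server_token_cell && cell) { cells.push_back(std::move(cell)); }

        // length of the common prefix with another sequence (used for slot selection);
        // the real helper also matches image chunks, e.g. by bitmap hash -- approximated
        // here by comparing the stored image pointers
        size_t get_common_prefix(const server_tokens & other) const {
            size_t n = 0;
            while (n < size() && n < other.size() &&
                   cells[n].txt == other.cells[n].txt &&
                   cells[n].img == other.cells[n].img) {
                n++;
            }
            return n;
        }
    };

With this shape, most of the mechanical changes in the diff are .n_tokens() -> .size(), .chunks[i] -> [i], .tok_text -> .txt and .tok_image -> .img.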

2 files changed: 162 additions & 205 deletions


examples/server/server.cpp

Lines changed: 74 additions & 65 deletions
@@ -198,7 +198,7 @@ struct server_task {
 
     // used by SERVER_TASK_TYPE_INFERENCE
     slot_params params;
-    server_inputs prompt_tokens;
+    server_tokens prompt_tokens;
     int id_selected_slot = -1;
 
     // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
@@ -1277,14 +1277,14 @@ struct server_slot {
     int32_t n_prompt_tokens_processed = 0;
 
     // input prompt tokens
-    server_inputs prompt_tokens;
+    server_tokens prompt_tokens;
 
     size_t last_nl_pos = 0;
 
     std::string generated_text;
     llama_tokens generated_tokens;
 
-    server_inputs cache_tokens;
+    server_tokens cache_tokens;
 
     std::vector<completion_token_output> generated_token_probs;
 
@@ -2020,6 +2020,7 @@ struct server_context {
             slot.n_ctx = n_ctx_slot;
             slot.n_predict = params_base.n_predict;
             slot.mctx = mctx;
+            slot.cache_tokens.has_mtmd = mctx != nullptr;
 
             if (model_dft) {
                 slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
@@ -2096,7 +2097,7 @@ struct server_context {
                 int cur_lcs_len = slot.cache_tokens.get_common_prefix(task.prompt_tokens);
 
                 // fraction of the common subsequence length compared to the current slot's prompt length
-                float cur_similarity = static_cast<float>(cur_lcs_len) / static_cast<int>(slot.cache_tokens.n_tokens());
+                float cur_similarity = static_cast<float>(cur_lcs_len) / static_cast<int>(slot.cache_tokens.size());
 
                 // select the current slot if the criteria match
                 if (cur_lcs_len > lcs_len && cur_similarity > slot_prompt_similarity) {
@@ -2135,7 +2136,7 @@ struct server_context {
         return ret;
     }
 
-    bool can_be_detokenized(const struct llama_context * ctx, const server_inputs & inp) {
+    bool can_be_detokenized(const struct llama_context * ctx, const server_tokens & inp) {
         const llama_model * model = llama_get_model(ctx);
         const llama_vocab * vocab = llama_model_get_vocab(model);
         const int32_t n_vocab = llama_vocab_n_tokens(vocab);
@@ -2786,7 +2787,7 @@ struct server_context {
                     break;
                 }
 
-                const size_t token_count = slot->cache_tokens.n_tokens();
+                const size_t token_count = slot->cache_tokens.size();
                 const int64_t t_start = ggml_time_us();
 
                 std::string filename = task.slot_action.filename;
@@ -2877,7 +2878,7 @@ struct server_context {
                 }
 
                 // Erase token cache
-                const size_t n_erased = slot->cache_tokens.n_tokens();
+                const size_t n_erased = slot->cache_tokens.size();
                 llama_kv_self_seq_rm(ctx, slot->id, -1, -1);
                 slot->cache_tokens.clear();
 
@@ -2957,11 +2958,11 @@ struct server_context {
             llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
 
             if (slot.params.cache_prompt) {
-                for (size_t i = n_keep + n_discard; i < slot.cache_tokens.chunks.size(); i++) {
-                    slot.cache_tokens.chunks[i - n_discard] = std::move(slot.cache_tokens.chunks[i]);
+                for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
+                    slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
                 }
 
-                slot.cache_tokens.chunks.resize(slot.cache_tokens.chunks.size() - n_discard);
+                slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
             }
 
             slot.n_past -= n_discard;
@@ -3004,7 +3005,7 @@ struct server_context {
             }
 
             SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
-                    slot.n_ctx, slot.n_past, (int) slot.cache_tokens.n_tokens(), slot.truncated);
+                    slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated);
         }
 
         // process in chunks of params.n_batch
@@ -3033,23 +3034,23 @@ struct server_context {
                     slot.t_start_generation = 0;
 
                     slot.n_past = 0;
-                    slot.n_prompt_tokens = prompt_tokens.n_tokens();
+                    slot.n_prompt_tokens = prompt_tokens.size();
                     slot.state = SLOT_STATE_PROCESSING_PROMPT;
 
                     SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
 
                     // print prompt tokens (for debugging)
-                    // if (1) {
-                    //     // first 16 tokens (avoid flooding logs)
-                    //     for (int i = 0; i < std::min<int>(16, prompt_tokens.size()); i++) {
-                    //         SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
-                    //     }
-                    // } else {
-                    //     // all
-                    //     for (int i = 0; i < (int) prompt_tokens.size(); i++) {
-                    //         SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
-                    //     }
-                    // }
+                    /*if (1) {
+                        // first 16 tokens (avoid flooding logs)
+                        for (int i = 0; i < std::min<int>(16, prompt_tokens.size()); i++) {
+                            SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+                        }
+                    } else {
+                        // all
+                        for (int i = 0; i < (int) prompt_tokens.size(); i++) {
+                            SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+                        }
+                    }*/
 
                     // empty prompt passed -> release the slot and send empty response
                     if (prompt_tokens.empty()) {
@@ -3113,7 +3114,7 @@ struct server_context {
                         prompt_tokens.set_text_tokens(new_tokens);
 
                         slot.truncated = true;
-                        slot.n_prompt_tokens = prompt_tokens.n_tokens();
+                        slot.n_prompt_tokens = prompt_tokens.size();
 
                         SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens);
 
@@ -3136,13 +3137,13 @@ struct server_context {
 
                             SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past);
 
-                            while (head_c < slot.cache_tokens.chunks.size() &&
-                                   head_p < prompt_tokens.chunks.size()) {
+                            while (head_c < slot.cache_tokens.size() &&
+                                   head_p < prompt_tokens.size()) {
 
                                 size_t n_match = 0;
-                                while (head_c + n_match < slot.cache_tokens.chunks.size() &&
-                                       head_p + n_match < prompt_tokens.chunks.size() &&
-                                       slot.cache_tokens.chunks[head_c + n_match].tok_text == prompt_tokens.chunks[head_p + n_match].tok_text) {
+                                while (head_c + n_match < slot.cache_tokens.size() &&
+                                       head_p + n_match < prompt_tokens.size() &&
+                                       slot.cache_tokens[head_c + n_match].txt == prompt_tokens[head_p + n_match].txt) {
 
                                     n_match++;
                                 }
@@ -3159,7 +3160,7 @@ struct server_context {
                                 llama_kv_self_seq_add(ctx, slot.id, head_c, head_c + n_match, kv_shift);
 
                                 for (size_t i = 0; i < n_match; i++) {
-                                    slot.cache_tokens.chunks[head_p + i].tok_text = slot.cache_tokens.chunks[head_c + i].tok_text;
+                                    slot.cache_tokens[head_p + i].txt = slot.cache_tokens[head_c + i].txt;
                                     slot.n_past++;
                                 }
 
@@ -3207,12 +3208,13 @@ struct server_context {
                 // remove the non-common part from the cache
                 slot.cache_tokens.keep_until(slot.n_past);
 
-                auto & curr_chunk = slot.prompt_tokens.get_chunk(slot.n_past);
+                auto & cur_tok = slot.prompt_tokens[slot.n_past];
 
                 // check if we should process the image
-                if (curr_chunk.tok_image) {
+                if (cur_tok.img) {
                     // process the image
-                    int32_t res = server_img_process(ctx, mctx, curr_chunk, batch_embd, slot.n_past, slot.id);
+                    int32_t res = server_img_process(ctx, mctx, cur_tok, batch_embd, slot.n_past, slot.id);
+                    int32_t n_tokens = mtmd_image_tokens_get_n_tokens(cur_tok.img.get());
                     if (res != 0) {
                         SLT_ERR(slot, "failed to process image, res = %d\n", res);
                         slot.release();
@@ -3221,27 +3223,30 @@ struct server_context {
                     }
 
                     if (slot.params.cache_prompt) {
-                        slot.cache_tokens.add_image_tokens(curr_chunk.tok_image);
+                        // add ALL image tokens at once
+                        for (int32_t i = 0; i < n_tokens; i++) {
+                            slot.cache_tokens.add_token(std::move(slot.prompt_tokens[slot.n_past + i]));
+                        }
                     }
 
-                    slot.n_past += curr_chunk.n_tokens;
-                    slot.n_prompt_tokens_processed += curr_chunk.n_tokens;
+                    slot.n_past += n_tokens;
+                    slot.n_prompt_tokens_processed += n_tokens;
                 }
 
                 // add prompt tokens for processing in the current batch
                 while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
                     // get next token to process
-                    auto & curr_chunk = slot.prompt_tokens.get_chunk(slot.n_past);
-                    if (curr_chunk.tok_text == LLAMA_TOKEN_NULL) {
+                    auto & curr_chunk = slot.prompt_tokens[slot.n_past];
+                    if (curr_chunk.txt == LLAMA_TOKEN_NULL) {
                         break; // end of text chunk
                     }
 
                     // without pooling, we want to output the embeddings for all the tokens in the batch
                     const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
 
-                    common_batch_add(batch, curr_chunk.tok_text, slot.n_past, { slot.id }, need_embd);
+                    common_batch_add(batch, curr_chunk.txt, slot.n_past, { slot.id }, need_embd);
                     if (slot.params.cache_prompt) {
-                        slot.cache_tokens.add_text_token(curr_chunk.tok_text);
+                        slot.cache_tokens.add_text_token(curr_chunk.txt);
                     }
 
                     slot.n_prompt_tokens_processed++;
@@ -3261,10 +3266,10 @@ struct server_context {
                 common_sampler_reset(slot.smpl);
 
                 // Process all prompt tokens through sampler system
-                for (size_t i = 0; i < slot.cache_tokens.n_tokens(); ++i) {
-                    auto & curr_chunk = slot.prompt_tokens.get_chunk(i);
-                    if (curr_chunk.tok_text != LLAMA_TOKEN_NULL) {
-                        common_sampler_accept(slot.smpl, curr_chunk.tok_text, false);
+                for (size_t i = 0; i < slot.cache_tokens.size(); ++i) {
+                    auto & cur_tok = slot.prompt_tokens[i];
+                    if (cur_tok.txt != LLAMA_TOKEN_NULL) {
+                        common_sampler_accept(slot.smpl, cur_tok.txt, false);
                     }
                 }
 
@@ -3289,7 +3294,6 @@ struct server_context {
            return;
        }
 
-        // debug
        SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
 
        if (slot_batched) {
@@ -3303,7 +3307,7 @@ struct server_context {
        for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
 
-            llama_batch batch_view = llama_batch{
+            llama_batch batch_view = {
                n_tokens,
                batch.token + i,
                nullptr,
@@ -4072,38 +4076,43 @@ int main(int argc, char ** argv) {
 
            // process files
            std::vector<mtmd_bitmap> bitmaps;
+            const bool has_mtmd = ctx_server.mctx != nullptr;
            {
+                if (!has_mtmd && !files.empty()) {
+                    throw std::runtime_error("This server does not support multimodal");
+                }
                for (auto & file : files) {
                    mtmd_bitmap bmp;
                    int32_t res = mtmd_helper_bitmap_init_from_buf(file.data(), file.size(), bmp);
                    if (res != 0) {
                        throw std::runtime_error("Failed to load image");
                    }
                    // calculate bitmap hash (for KV caching)
-                    bmp.id = server_inputs::fnv_hash(bmp.data.data(), bmp.data.size());
+                    bmp.id = server_tokens::fnv_hash(bmp.data.data(), bmp.data.size());
                    bitmaps.push_back(std::move(bmp));
                }
            }
 
-            std::vector<server_inputs> inputs;
-            if (oaicompat) {
-                if (!prompt.is_string()) {
-                    throw std::runtime_error("prompt must be a string");
-                } else {
-                    // SRV_INF("prompt: %s\n", prompt.get<std::string>().c_str());
-                    mtmd_input_text inp_txt = {
-                        prompt.get<std::string>(),
-                        /* add_special */ true,
-                        /* parse_special */ true,
-                    };
-                    mtmd_input_chunks chunks;
-                    int32_t tokenized = mtmd_tokenize(ctx_server.mctx, chunks, inp_txt, bitmaps);
-                    if (tokenized != 0) {
-                        throw std::runtime_error("Failed to tokenize prompt");
-                    }
-                    server_inputs tmp(chunks);
-                    inputs.push_back(std::move(tmp));
+            // process prompt
+            std::vector<server_tokens> inputs;
+            if (oaicompat && !prompt.is_string()) {
+                throw std::runtime_error("prompt must be a string");
+
+            } else if (oaicompat && has_mtmd) {
+                // multimodal
+                mtmd_input_text inp_txt = {
+                    prompt.get<std::string>(),
+                    /* add_special */ true,
+                    /* parse_special */ true,
+                };
+                mtmd_input_chunks chunks;
+                int32_t tokenized = mtmd_tokenize(ctx_server.mctx, chunks, inp_txt, bitmaps);
+                if (tokenized != 0) {
+                    throw std::runtime_error("Failed to tokenize prompt");
                }
+                server_tokens tmp(chunks, true);
+                inputs.push_back(std::move(tmp));
+
            } else {
                // non-multimodal version
                auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
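
The bitmap id assigned in the file-upload path above ("// calculate bitmap hash (for KV caching)") is an FNV-style hash of the raw image bytes, presumably so the same image sent again can be matched against already-cached chunks. A minimal standalone FNV-1a sketch for reference; the committed server_tokens::fnv_hash helper may differ, for example by returning the value formatted as a string id:

    #include <cstddef>
    #include <cstdint>

    // FNV-1a, 64-bit: hash the raw bytes of a buffer (e.g. a decoded bitmap)
    static uint64_t fnv1a_hash(const unsigned char * data, size_t len) {
        uint64_t hash = 0xcbf29ce484222325ULL;       // FNV offset basis
        const uint64_t fnv_prime = 0x100000001b3ULL; // FNV prime
        for (size_t i = 0; i < len; ++i) {
            hash ^= data[i];
            hash *= fnv_prime;
        }
        return hash;
    }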