
Commit 2df8c1a

disable some features when mtmd is on
1 parent 19b9fe1 commit 2df8c1a
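
In short: with a multimodal projector (mtmd/mmproj) loaded, the server now disables context shift and prompt-cache reuse, refuses to start when a speculative-decoding draft model is configured, and rejects the slot save/restore/erase tasks as not supported. A condensed sketch of the load-time gating follows; the Params struct and apply_multimodal_restrictions() helper are illustrative stand-ins, not the real common_params or server_context code.

#include <cstdio>
#include <string>

// Simplified stand-in for the relevant common_params fields (illustrative only).
struct Params {
    bool        ctx_shift     = true;
    int         n_cache_reuse = 256;
    std::string speculative_model; // path to a draft model, if any
};

// Mirrors the new block in load_model(): once the multimodal projector is
// loaded, incompatible features are turned off, or model loading fails.
static bool apply_multimodal_restrictions(Params & p) {
    if (p.ctx_shift) {
        p.ctx_shift = false;
        std::puts("ctx_shift is not supported by multimodal, it will be disabled");
    }
    if (p.n_cache_reuse) {
        p.n_cache_reuse = 0;
        std::puts("cache_reuse is not supported by multimodal, it will be disabled");
    }
    if (!p.speculative_model.empty()) {
        std::puts("err: speculative decode is not supported by multimodal");
        return false; // load_model() fails in this case
    }
    return true;
}

int main() {
    Params p; // no draft model configured
    return apply_multimodal_restrictions(p) ? 0 : 1;
}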

2 files changed: +112 additions, −44 deletions


examples/server/server.cpp

Lines changed: 91 additions & 44 deletions
@@ -1983,6 +1983,21 @@ struct server_context {
                 return false;
             }
             SRV_INF("loaded multimodal model, '%s'\n", mmproj_path.c_str());
+
+            if (params_base.ctx_shift) {
+                params_base.ctx_shift = false;
+                SRV_INF("%s\n", "ctx_shift is not supported by multimodal, it will be disabled");
+            }
+
+            if (params_base.n_cache_reuse) {
+                params_base.n_cache_reuse = 0;
+                SRV_INF("%s\n", "cache_reuse is not supported by multimodal, it will be disabled");
+            }
+
+            if (!params_base.speculative.model.path.empty()) {
+                SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal");
+                return false;
+            }
         }
 
         return true;
@@ -2432,14 +2447,15 @@ struct server_context {
 
     void send_final_response(server_slot & slot) {
         auto res = std::make_unique<server_task_result_cmpl_final>();
+        llama_tokens text_tokens = slot.prompt_tokens.get_text_tokens();
         res->id = slot.id_task;
         res->id_slot = slot.id;
 
         res->index = slot.index;
         res->content = std::move(slot.generated_text);
         res->tokens = std::move(slot.generated_tokens);
         res->timings = slot.get_timings();
-        //res->prompt = common_detokenize(ctx, slot.prompt_tokens, true); // TODO @ngxson : hacky, need to fix
+        res->prompt = common_detokenize(ctx, text_tokens, true);
         res->response_fields = std::move(slot.params.response_fields);
 
         res->truncated = slot.truncated;
@@ -2747,10 +2763,14 @@ struct server_context {
                     }
                     queue_results.send(std::move(res));
                 } break;
-            /*case SERVER_TASK_TYPE_SLOT_SAVE:
+            case SERVER_TASK_TYPE_SLOT_SAVE:
                 {
                     int id_slot = task.slot_action.slot_id;
                     server_slot * slot = get_slot_by_id(id_slot);
+                    if (mctx) {
+                        send_error(task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED);
+                        break;
+                    }
                     if (slot == nullptr) {
                         send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
                         break;
@@ -2762,13 +2782,14 @@ struct server_context {
                         break;
                     }
 
-                    const size_t token_count = slot->cache_tokens.size();
+                    const size_t token_count = slot->cache_tokens.n_tokens();
                     const int64_t t_start = ggml_time_us();
 
                     std::string filename = task.slot_action.filename;
                     std::string filepath = task.slot_action.filepath;
 
-                    const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);
+                    const llama_tokens tokens = slot->cache_tokens.get_text_tokens();
+                    const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count);
 
                     const int64_t t_end = ggml_time_us();
                     const double t_save_ms = (t_end - t_start) / 1000.0;
@@ -2785,6 +2806,10 @@ struct server_context {
                 } break;
             case SERVER_TASK_TYPE_SLOT_RESTORE:
                 {
+                    if (mctx) {
+                        send_error(task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED);
+                        break;
+                    }
                     int id_slot = task.slot_action.slot_id;
                     server_slot * slot = get_slot_by_id(id_slot);
                     if (slot == nullptr) {
@@ -2803,15 +2828,17 @@ struct server_context {
                     std::string filename = task.slot_action.filename;
                     std::string filepath = task.slot_action.filepath;
 
-                    slot->cache_tokens.resize(slot->n_ctx);
+                    llama_tokens tokens;
+                    tokens.resize(slot->n_ctx);
                     size_t token_count = 0;
-                    size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
+                    size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count);
                     if (nread == 0) {
-                        slot->cache_tokens.resize(0);
+                        slot->cache_tokens.clear(); // KV may already been invalidated?
                         send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
                         break;
                     }
-                    slot->cache_tokens.resize(token_count);
+                    tokens.resize(token_count);
+                    slot->cache_tokens.set_text_tokens(tokens);
 
                     const int64_t t_end = ggml_time_us();
                     const double t_restore_ms = (t_end - t_start) / 1000.0;
@@ -2828,6 +2855,10 @@ struct server_context {
                 } break;
             case SERVER_TASK_TYPE_SLOT_ERASE:
                 {
+                    if (mctx) {
+                        send_error(task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED);
+                        break;
+                    }
                     int id_slot = task.slot_action.slot_id;
                     server_slot * slot = get_slot_by_id(id_slot);
                     if (slot == nullptr) {
@@ -2842,7 +2873,7 @@ struct server_context {
                     }
 
                     // Erase token cache
-                    const size_t n_erased = slot->cache_tokens.size();
+                    const size_t n_erased = slot->cache_tokens.n_tokens();
                     llama_kv_self_seq_rm(ctx, slot->id, -1, -1);
                     slot->cache_tokens.clear();
 
@@ -2851,11 +2882,7 @@ struct server_context {
                     res->id_slot = id_slot;
                     res->n_erased = n_erased;
                     queue_results.send(std::move(res));
-                } break;*/
-            case SERVER_TASK_TYPE_SLOT_SAVE:
-            case SERVER_TASK_TYPE_SLOT_RESTORE:
-            case SERVER_TASK_TYPE_SLOT_ERASE:
-                GGML_ASSERT(false && "TODO @ngxson : removed due to not compat with multimodal");
+                } break;
             case SERVER_TASK_TYPE_SET_LORA:
                 {
                     params_base.lora_adapters = std::move(task.set_lora);
@@ -2899,8 +2926,7 @@ struct server_context {
 
         // apply context-shift if needed
        // TODO: simplify and improve
-        // TODO @ngxson : hacky, need to disable context shift for multimodal
-        /*for (server_slot & slot : slots) {
+        for (server_slot & slot : slots) {
             if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) {
                 if (!params_base.ctx_shift) {
                     // this check is redundant (for good)
@@ -2910,6 +2936,12 @@ struct server_context {
                     continue;
                 }
 
+                if (mctx) {
+                    // we should never reach this because params_base.ctx_shift is automatically disabled if mmproj is loaded
+                    // we don't support ctx_shift because an image chunk may contains multiple tokens
+                    GGML_ABORT("not supported by multimodal");
+                }
+
                 // Shift context
                 const int n_keep = slot.params.n_keep + add_bos_token;
                 const int n_left = slot.n_past - n_keep;
@@ -2921,18 +2953,18 @@ struct server_context {
                 llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
 
                 if (slot.params.cache_prompt) {
-                    for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
-                        slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
+                    for (size_t i = n_keep + n_discard; i < slot.cache_tokens.chunks.size(); i++) {
+                        slot.cache_tokens.chunks[i - n_discard] = std::move(slot.cache_tokens.chunks[i]);
                     }
 
-                    slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
+                    slot.cache_tokens.chunks.resize(slot.cache_tokens.chunks.size() - n_discard);
                 }
 
                 slot.n_past -= n_discard;
 
                 slot.truncated = true;
             }
-        }*/
+        }
 
         // start populating the batch for this iteration
         common_batch_clear(batch);
@@ -3054,51 +3086,59 @@ struct server_context {
                     slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
 
                     // if input prompt is too big, truncate it
-                    // TODO @ngxson : this won't work with multimodal
-                    /*if (slot.n_prompt_tokens >= slot.n_ctx) {
+                    if (slot.n_prompt_tokens >= slot.n_ctx) {
+                        if (mctx) {
+                            // we should never reach this
+                            GGML_ABORT("not supported by multimodal");
+                        }
+                        llama_tokens curr_tokens = slot.prompt_tokens.get_text_tokens();
                         const int n_left = slot.n_ctx - slot.params.n_keep;
 
                         const int n_block_size = n_left / 2;
                         const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
 
                         llama_tokens new_tokens(
-                                prompt_tokens.begin(),
-                                prompt_tokens.begin() + slot.params.n_keep);
+                                curr_tokens.begin(),
+                                curr_tokens.begin() + slot.params.n_keep);
 
                         new_tokens.insert(
                                 new_tokens.end(),
-                                prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
-                                prompt_tokens.end());
+                                curr_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
+                                curr_tokens.end());
 
-                        prompt_tokens = std::move(new_tokens);
+                        prompt_tokens.set_text_tokens(new_tokens);
 
                         slot.truncated = true;
-                        slot.n_prompt_tokens = prompt_tokens.size();
+                        slot.n_prompt_tokens = prompt_tokens.n_tokens();
 
                         SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens);
 
                         GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
-                    }*/
+                    }
 
                     if (slot.params.cache_prompt) {
                         // reuse any previously computed tokens that are common with the new prompt
                         slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens);
 
                         // reuse chunks from the cached prompt by shifting their KV cache in the new position
-                        // TODO @ngxson : this won't work with multimodal
-                        /*if (params_base.n_cache_reuse > 0) {
+                        if (params_base.n_cache_reuse > 0) {
                             size_t head_c = slot.n_past; // cache
                             size_t head_p = slot.n_past; // current prompt
 
+                            if (mctx) {
+                                // we should never reach this
+                                GGML_ABORT("not supported by multimodal");
+                            }
+
                             SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past);
 
-                            while (head_c < slot.cache_tokens.size() &&
-                                   head_p < prompt_tokens.size()) {
+                            while (head_c < slot.cache_tokens.chunks.size() &&
+                                   head_p < prompt_tokens.chunks.size()) {
 
                                 size_t n_match = 0;
-                                while (head_c + n_match < slot.cache_tokens.size() &&
-                                       head_p + n_match < prompt_tokens.size() &&
-                                       slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
+                                while (head_c + n_match < slot.cache_tokens.chunks.size() &&
+                                       head_p + n_match < prompt_tokens.chunks.size() &&
+                                       slot.cache_tokens.chunks[head_c + n_match].tok_text == prompt_tokens.chunks[head_p + n_match].tok_text) {
 
                                     n_match++;
                                 }
@@ -3115,7 +3155,7 @@ struct server_context {
                                     llama_kv_self_seq_add(ctx, slot.id, head_c, head_c + n_match, kv_shift);
 
                                     for (size_t i = 0; i < n_match; i++) {
-                                        slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
+                                        slot.cache_tokens.chunks[head_p + i].tok_text = slot.cache_tokens.chunks[head_c + i].tok_text;
                                         slot.n_past++;
                                     }
 
@@ -3127,7 +3167,7 @@ struct server_context {
                             }
 
                             SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
-                        }*/
+                        }
                     }
                 }
 
@@ -3359,8 +3399,7 @@ struct server_context {
             }
 
             // do speculative decoding
-            // TODO @ngxson : remove speculative decoding for multimodal
-            /*for (auto & slot : slots) {
+            for (auto & slot : slots) {
                 if (!slot.is_processing() || !slot.can_speculate()) {
                     continue;
                 }
@@ -3369,6 +3408,11 @@ struct server_context {
                     continue;
                 }
 
+                if (mctx) {
+                    // we should never reach this
+                    GGML_ABORT("not supported by multimodal");
+                }
+
                 // determine the max draft that fits the current slot state
                 int n_draft_max = slot.params.speculative.n_max;
 
@@ -3395,7 +3439,8 @@ struct server_context {
                 params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
                 params_spec.p_min = slot.params.speculative.p_min;
 
-                llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
+                llama_tokens cached_text_tokens = slot.cache_tokens.get_text_tokens();
+                llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
 
                 // keep track of total number of tokens generated in the draft
                 slot.n_draft_total += draft.size();
@@ -3428,8 +3473,10 @@ struct server_context {
                 // update how many tokens out of draft was accepted
                 slot.n_draft_accepted += ids.size() - 1;
 
-                slot.cache_tokens.push_back(id);
-                slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
+                slot.cache_tokens.add_text_token(id);
+                for (auto & t : ids) {
+                    slot.cache_tokens.add_text_token(t);
+                }
 
                 llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1);
 
@@ -3453,7 +3500,7 @@ struct server_context {
                 }
 
                 SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
-            }*/
+            }
         }
 
         SRV_DBG("%s", "run slots completed\n");
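
The re-enabled SLOT_SAVE / SLOT_RESTORE / SLOT_ERASE handlers above share one guard: if a multimodal context (mctx) is active, the task is rejected with ERROR_TYPE_NOT_SUPPORTED before any slot state is touched, while paths that should be unreachable in that configuration (context shift, cache reuse, speculative decoding) abort via GGML_ABORT. A minimal standalone sketch of the rejection pattern, using stand-in names (MultimodalCtx, ErrorType, handle_slot_save) rather than the real server types:

#include <cstdio>
#include <string>

// Simplified stand-ins for the server types; the real code uses mtmd_context,
// server_task and send_error() instead.
struct MultimodalCtx { /* opaque multimodal projector state */ };

enum class ErrorType { NOT_SUPPORTED };

static void send_error(const std::string & msg, ErrorType /*type*/) {
    std::printf("error: %s\n", msg.c_str());
}

// Pattern used by SLOT_SAVE / SLOT_RESTORE / SLOT_ERASE: refuse the task early
// when a multimodal projector is loaded, because the token cache may contain
// image chunks that cannot be serialized as plain llama_tokens.
static bool handle_slot_save(const MultimodalCtx * mctx) {
    if (mctx) {
        send_error("This feature is not supported by multimodal", ErrorType::NOT_SUPPORTED);
        return false;
    }
    // ... save the slot's text tokens with llama_state_seq_save_file() ...
    return true;
}

int main() {
    MultimodalCtx mm;
    handle_slot_save(nullptr); // text-only server: proceeds
    handle_slot_save(&mm);     // multimodal server: rejected
}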

examples/server/utils.hpp

Lines changed: 21 additions & 0 deletions
@@ -1174,6 +1174,27 @@ struct server_inputs {
         }
         return std::to_string(hash);
     }
+
+    // TODO: maybe implement a (de)seralizer for this struct, so we can get rid of functions below
+
+    // return all text tokens (for legacy code), to be used by save/load slot
+    llama_tokens get_text_tokens() {
+        llama_tokens output;
+        for (auto & chunk : chunks) {
+            if (chunk.tok_text != LLAMA_TOKEN_NULL) {
+                output.push_back(chunk.tok_text);
+            }
+        }
+        return output;
+    }
+
+    // clear and set text tokens (for legacy code), to be used by save/load slot
+    void set_text_tokens(llama_tokens tokens) {
+        chunks.clear();
+        for (auto & tok : tokens) {
+            add_text_token(tok);
+        }
+    }
 };
 
 // helper struct to make working with embd batch easier
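
The two helpers added above flatten the chunk list into plain llama_tokens and rebuild it, dropping any chunk that is not a text token; that is what lets slot save/restore keep working for text-only content, since image chunks cannot be serialized this way. A self-contained sketch of the same round trip, with a deliberately simplified chunk type (the real struct carries image data for non-text chunks):

#include <cstdint>
#include <vector>

using llama_token  = std::int32_t;
using llama_tokens = std::vector<llama_token>;

constexpr llama_token LLAMA_TOKEN_NULL = -1;

// Greatly simplified chunk: the real struct also holds image data for
// non-text chunks, which is why those chunks are skipped on serialization.
struct chunk { llama_token tok_text = LLAMA_TOKEN_NULL; };

struct server_inputs_sketch {
    std::vector<chunk> chunks;

    void add_text_token(llama_token tok) { chunks.push_back({tok}); }

    // flatten to plain tokens, dropping image chunks (tok_text == LLAMA_TOKEN_NULL)
    llama_tokens get_text_tokens() const {
        llama_tokens out;
        for (const auto & c : chunks) {
            if (c.tok_text != LLAMA_TOKEN_NULL) {
                out.push_back(c.tok_text);
            }
        }
        return out;
    }

    // rebuild the chunk list from plain tokens (used after slot restore)
    void set_text_tokens(const llama_tokens & tokens) {
        chunks.clear();
        for (llama_token tok : tokens) {
            add_text_token(tok);
        }
    }
};

int main() {
    server_inputs_sketch inp;
    inp.add_text_token(1);
    inp.chunks.push_back({}); // image chunk placeholder (tok_text == LLAMA_TOKEN_NULL)
    inp.add_text_token(2);

    llama_tokens text = inp.get_text_tokens(); // {1, 2}: the image chunk is dropped
    inp.set_text_tokens(text);                 // rebuilds a text-only chunk list
    return (int) inp.chunks.size() - 2;        // 0 on the expected round trip
}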
