Skip to content

Commit 70ec318

Browse files
Update server-mmojo.cpp
Signed-off-by: Brad Hutchings <[email protected]>
1 parent 5a6d55e commit 70ec318

File tree

1 file changed

+41
-10
lines changed

1 file changed

+41
-10
lines changed

tools/server/server-mmojo.cpp

Lines changed: 41 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -502,6 +502,33 @@ struct server_task {
502502
}
503503
}
504504
}
505+
} else if (logit_bias != data.end() && logit_bias->is_object()) {
506+
const int n_vocab = llama_vocab_n_tokens(vocab);
507+
for (const auto & el : logit_bias->items()) {
508+
float bias;
509+
const auto & key = el.key();
510+
const auto & value = el.value();
511+
if (value.is_number()) {
512+
bias = value.get<float>();
513+
} else if (value.is_boolean() && !value.get<bool>()) {
514+
bias = -INFINITY;
515+
} else {
516+
continue;
517+
}
518+
519+
char *end;
520+
llama_token tok = strtol(key.c_str(), &end, 10);
521+
if (*end == 0) {
522+
if (tok >= 0 && tok < n_vocab) {
523+
params.sampling.logit_bias.push_back({tok, bias});
524+
}
525+
} else {
526+
auto toks = common_tokenize(vocab, key, false);
527+
for (auto tok : toks) {
528+
params.sampling.logit_bias.push_back({tok, bias});
529+
}
530+
}
531+
}
505532
}
506533

507534
params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos);
@@ -1000,8 +1027,8 @@ struct server_task_result_cmpl_partial : server_task_result {
10001027
{"progress", progress},
10011028
};
10021029
}
1003-
// mmojo-server END
1004-
1030+
// mmojo-server END
1031+
10051032
return res;
10061033
}
10071034

@@ -1954,6 +1981,7 @@ struct server_context {
19541981
mtmd_context * mctx = nullptr;
19551982

19561983
const llama_vocab * vocab = nullptr;
1984+
bool vocab_dft_compatible = true;
19571985

19581986
llama_model * model_dft = nullptr;
19591987

@@ -2044,10 +2072,9 @@ struct server_context {
20442072
return false;
20452073
}
20462074

2047-
if (!common_speculative_are_compatible(ctx, llama_init_dft.context.get())) {
2048-
SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str());
2049-
2050-
return false;
2075+
vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft.context.get());
2076+
if (!vocab_dft_compatible) {
2077+
SRV_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str());
20512078
}
20522079

20532080
const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get());
@@ -2137,11 +2164,14 @@ struct server_context {
21372164
return;
21382165
}
21392166

2140-
slot.spec = common_speculative_init(slot.ctx_dft);
2167+
slot.spec = common_speculative_init(slot.ctx, slot.ctx_dft);
21412168
if (slot.spec == nullptr) {
21422169
SRV_ERR("%s", "failed to create speculator\n");
21432170
return;
21442171
}
2172+
for (auto &pair : params_base.speculative.replacements) {
2173+
common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str());
2174+
}
21452175
}
21462176

21472177
SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
@@ -2640,6 +2670,7 @@ struct server_context {
26402670
}
26412671
// mmojo-server END
26422672

2673+
26432674
void send_final_response(server_slot & slot) {
26442675
auto res = std::make_unique<server_task_result_cmpl_final>();
26452676
res->id = slot.id_task;
@@ -3492,7 +3523,7 @@ struct server_context {
34923523
SLT_INF(slot, "%s", "Finished sleep after batch.\n");
34933524
}
34943525
// mmojo-server END
3495-
3526+
34963527
// entire prompt has been processed
34973528
if (slot.n_past == slot.n_prompt_tokens) {
34983529
slot.state = SLOT_STATE_DONE_PROMPT;
@@ -3887,7 +3918,7 @@ int main(int argc, char ** argv) {
38873918
// User supplied args override argsFilename and zipArgsFilename args.
38883919
#endif
38893920
// mmojo-server END
3890-
3921+
38913922
// own arguments required by this example
38923923
common_params params;
38933924

@@ -5086,7 +5117,7 @@ int main(int argc, char ** argv) {
50865117
return false;
50875118
});
50885119
// mmojo-server END
5089-
5120+
50905121
// register API routes
50915122
svr->Get (params.api_prefix + "/health", handle_health); // public endpoint (no API key check)
50925123
svr->Get (params.api_prefix + "/metrics", handle_metrics);

0 commit comments

Comments
 (0)