@@ -502,6 +502,33 @@ struct server_task {
502502 }
503503 }
504504 }
505+ } else if (logit_bias != data.end () && logit_bias->is_object ()) {
506+ const int n_vocab = llama_vocab_n_tokens (vocab);
507+ for (const auto & el : logit_bias->items ()) {
508+ float bias;
509+ const auto & key = el.key ();
510+ const auto & value = el.value ();
511+ if (value.is_number ()) {
512+ bias = value.get <float >();
513+ } else if (value.is_boolean () && !value.get <bool >()) {
514+ bias = -INFINITY;
515+ } else {
516+ continue ;
517+ }
518+
519+ char *end;
520+ llama_token tok = strtol (key.c_str (), &end, 10 );
521+ if (*end == 0 ) {
522+ if (tok >= 0 && tok < n_vocab) {
523+ params.sampling .logit_bias .push_back ({tok, bias});
524+ }
525+ } else {
526+ auto toks = common_tokenize (vocab, key, false );
527+ for (auto tok : toks) {
528+ params.sampling .logit_bias .push_back ({tok, bias});
529+ }
530+ }
531+ }
505532 }
506533
507534 params.sampling .ignore_eos = json_value (data, " ignore_eos" , params_base.sampling .ignore_eos );
@@ -1000,8 +1027,8 @@ struct server_task_result_cmpl_partial : server_task_result {
10001027 {" progress" , progress},
10011028 };
10021029 }
1003- // mmojo-server END
1004-
1030+ // mmojo-server END
1031+
10051032 return res;
10061033 }
10071034
@@ -1954,6 +1981,7 @@ struct server_context {
19541981 mtmd_context * mctx = nullptr ;
19551982
19561983 const llama_vocab * vocab = nullptr ;
1984+ bool vocab_dft_compatible = true ;
19571985
19581986 llama_model * model_dft = nullptr ;
19591987
@@ -2044,10 +2072,9 @@ struct server_context {
20442072 return false ;
20452073 }
20462074
2047- if (!common_speculative_are_compatible (ctx, llama_init_dft.context .get ())) {
2048- SRV_ERR (" the draft model '%s' is not compatible with the target model '%s'\n " , params_base.speculative .model .path .c_str (), params_base.model .path .c_str ());
2049-
2050- return false ;
2075+ vocab_dft_compatible = common_speculative_are_compatible (ctx, llama_init_dft.context .get ());
2076+ if (!vocab_dft_compatible) {
2077+ SRV_INF (" the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n " , params_base.speculative .model .path .c_str (), params_base.model .path .c_str ());
20512078 }
20522079
20532080 const int n_ctx_dft = llama_n_ctx (llama_init_dft.context .get ());
@@ -2137,11 +2164,14 @@ struct server_context {
21372164 return ;
21382165 }
21392166
2140- slot.spec = common_speculative_init (slot.ctx_dft );
2167+ slot.spec = common_speculative_init (slot.ctx , slot. ctx_dft );
21412168 if (slot.spec == nullptr ) {
21422169 SRV_ERR (" %s" , " failed to create speculator\n " );
21432170 return ;
21442171 }
2172+ for (auto &pair : params_base.speculative .replacements ) {
2173+ common_speculative_add_replacement_tgt_dft (slot.spec , pair.first .c_str (), pair.second .c_str ());
2174+ }
21452175 }
21462176
21472177 SLT_INF (slot, " new slot n_ctx_slot = %d\n " , slot.n_ctx );
@@ -2640,6 +2670,7 @@ struct server_context {
26402670 }
26412671 // mmojo-server END
26422672
2673+
26432674 void send_final_response (server_slot & slot) {
26442675 auto res = std::make_unique<server_task_result_cmpl_final>();
26452676 res->id = slot.id_task ;
@@ -3492,7 +3523,7 @@ struct server_context {
34923523 SLT_INF (slot, " %s" , " Finished sleep after batch.\n " );
34933524 }
34943525 // mmojo-server END
3495-
3526+
34963527 // entire prompt has been processed
34973528 if (slot.n_past == slot.n_prompt_tokens ) {
34983529 slot.state = SLOT_STATE_DONE_PROMPT;
@@ -3887,7 +3918,7 @@ int main(int argc, char ** argv) {
38873918 // User supplied args override argsFilename and zipArgsFilename args.
38883919 #endif
38893920 // mmojo-server END
3890-
3921+
38913922 // own arguments required by this example
38923923 common_params params;
38933924
@@ -5086,7 +5117,7 @@ int main(int argc, char ** argv) {
50865117 return false ;
50875118 });
50885119 // mmojo-server END
5089-
5120+
50905121 // register API routes
50915122 svr->Get (params.api_prefix + " /health" , handle_health); // public endpoint (no API key check)
50925123 svr->Get (params.api_prefix + " /metrics" , handle_metrics);
0 commit comments