Commit 2b7840e

server : add --reasoning-cache

1 parent 04e1626 · commit 2b7840e

3 files changed: +122 −0 lines changed

common/arg.cpp

Lines changed: 8 additions & 0 deletions
```diff
@@ -2960,6 +2960,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.reasoning_budget = value;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK_BUDGET"));
+    add_opt(common_arg(
+        {"--reasoning-cache"}, "N",
+        "controls the reasoning cache size for models that require reasoning content during inference. (default: 0)",
+        [](common_params & params, int value) {
+            if (value < 0) { throw std::invalid_argument("invalid value"); }
+            params.reasoning_cache = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_REASONING_CACHE"));
     add_opt(common_arg(
         {"--chat-template"}, "JINJA_TEMPLATE",
         string_format(
```
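Since the option is registered for both the server and CLI examples, the cache would be enabled at startup with something like `llama-server -m model.gguf --reasoning-cache 64` (illustrative invocation; the model path is a placeholder), or equivalently via the `LLAMA_ARG_REASONING_CACHE` environment variable. The default of 0 keeps the cache disabled, and negative values are rejected by the parsing lambda above.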

common/common.h

Lines changed: 1 addition & 0 deletions
```diff
@@ -398,6 +398,7 @@ struct common_params {
     bool enable_chat_template = true;
     common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO;
     int reasoning_budget = -1;
+    int reasoning_cache = 0;
     bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response

     std::vector<std::string> api_keys;
```

tools/server/server.cpp

Lines changed: 113 additions & 0 deletions
```diff
@@ -1,4 +1,5 @@
 #include "chat.h"
+#include "ggml.h"
 #include "utils.hpp"

 #include "arg.h"
```
```diff
@@ -597,6 +598,103 @@ struct result_timings {
     }
 };

+struct reasoning_cache {
+    struct cache_item {
+        std::string id;
+        std::string content;
+    };
+
+    std::unordered_map<std::string, cache_item> cache;
+    std::deque<std::string> ids;
+    std::mutex mutex;
+    size_t n_size;
+
+    void init(size_t size = 64) {
+        SRV_INF("initializing reasoning cache, n_size = %zu\n", size);
+        n_size = size;
+    }
+
+    bool enabled() const {
+        return n_size > 0;
+    }
+
+    std::optional<std::string> get(const std::string & id) {
+        if (n_size == 0) {
+            return std::nullopt;
+        }
+
+        std::unique_lock<std::mutex> lock(mutex);
+        auto it = cache.find(id);
+        if (it == cache.end()) {
+            SRV_DBG("reasoning cache miss: %s\n", id.c_str());
+            return std::nullopt;
+        }
+
+        std::string hit = it->second.content;
+        SRV_DBG("reasoning cache hit: %s\n", id.c_str());
+        return hit;
+    }
+
+    void insert(const std::string & id, const std::string & content) {
+        if (n_size == 0) {
+            return;
+        }
+
+        std::unique_lock<std::mutex> lock(mutex);
+
+        if (ids.size() >= n_size) {
+            // copy the id before pop_back() - a reference into the deque would dangle
+            const std::string last_id = ids.back();
+            ids.pop_back();
+            cache.erase(last_id);
+        }
+
+        ids.push_front(id);
+        cache[id] = {/* .id = */ id, /* .content = */ content};
+        SRV_DBG("reasoning cache add: %s\n", id.c_str());
+    }
+
+    // cache the reasoning of a generated message, keyed by each of its tool-call ids
+    void extract_from_message(const common_chat_msg & msg) {
+        for (const auto & t : msg.tool_calls) {
+            if (!t.id.empty() && !msg.reasoning_content.empty()) {
+                insert(t.id, msg.reasoning_content);
+            }
+        }
+    }
+
+    void inject_oaicompat_chat_params(json & body) {
+        if (!body.contains("messages")) {
+            return;
+        }
+
+        json & messages = body.at("messages");
+        if (!messages.is_array()) {
+            return;
+        }
+
+        for (auto & msg : messages) {
+            if (!msg.contains("tool_calls") || msg.contains("reasoning_content")) {
+                continue;
+            }
+
+            // inject cached reasoning into tool-call messages to support models that require it (gpt-oss)
+            const json & tool_calls = msg.at("tool_calls");
+            if (tool_calls.is_array() && !tool_calls.empty()) {
+                for (const auto & t : tool_calls) {
+                    std::string tool_id = json_value(t, "id", std::string());
+                    if (tool_id.empty()) {
+                        continue;
+                    }
+
+                    if (auto content = get(tool_id)) {
+                        msg["reasoning_content"] = *content; // dereference the optional before assigning into the json
+                        break;
+                    }
+                }
+            }
+        }
+    }
+};
+
 struct server_task_result {
     int id = -1;
     int id_slot = -1;
```
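The struct above is essentially a small FIFO-evicting map keyed by tool-call id. Below is a minimal standalone sketch of that eviction behavior, with the logging macros and llama.cpp types stripped; all names here are illustrative and not part of the commit.

```cpp
// Sketch of the bounded cache: a mutex-guarded map keyed by tool-call id,
// with a deque tracking insertion order (front = newest, back = oldest) so
// the oldest entry is dropped once n_size entries are held.
// Compile with: g++ -std=c++17 mini_cache.cpp
#include <deque>
#include <iostream>
#include <mutex>
#include <optional>
#include <string>
#include <unordered_map>

struct mini_reasoning_cache {
    std::unordered_map<std::string, std::string> cache;
    std::deque<std::string> ids;
    std::mutex mutex;
    size_t n_size = 0; // 0 = disabled, mirroring --reasoning-cache 0

    std::optional<std::string> get(const std::string & id) {
        if (n_size == 0) { return std::nullopt; }
        std::unique_lock<std::mutex> lock(mutex);
        auto it = cache.find(id);
        if (it == cache.end()) { return std::nullopt; }
        return it->second;
    }

    void insert(const std::string & id, const std::string & content) {
        if (n_size == 0) { return; }
        std::unique_lock<std::mutex> lock(mutex);
        if (ids.size() >= n_size) {
            const std::string oldest = ids.back(); // copy before pop_back()
            ids.pop_back();
            cache.erase(oldest);
        }
        ids.push_front(id);
        cache[id] = content;
    }
};

int main() {
    mini_reasoning_cache c;
    c.n_size = 2; // equivalent to --reasoning-cache 2

    c.insert("call_1", "reasoning for call 1");
    c.insert("call_2", "reasoning for call 2");
    c.insert("call_3", "reasoning for call 3"); // evicts call_1

    std::cout << "call_1 cached: " << c.get("call_1").has_value() << "\n"; // 0
    std::cout << "call_3 cached: " << c.get("call_3").has_value() << "\n"; // 1
}
```

Note that eviction follows insertion order rather than recency of use: `get()` never refreshes an entry's position in the deque, so even a frequently re-read id is evicted once `n_size` newer inserts have landed.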
```diff
@@ -1961,6 +2059,9 @@ struct server_context {
     common_chat_templates_ptr chat_templates;
     oaicompat_parser_options oai_parser_opt;

+    // reasoning cache
+    reasoning_cache cache_reasoning;
+
     ~server_context() {
         mtmd_free(mctx);
```

```diff
@@ -2161,6 +2262,8 @@ struct server_context {
             /* allow_audio     */ mctx ? mtmd_support_audio (mctx) : false,
             /* enable_thinking */ params_base.reasoning_budget != 0,
         };
+
+        cache_reasoning.init(params_base.reasoning_cache);
     }

     server_slot * get_slot_by_id(int id) {
```
```diff
@@ -2585,6 +2688,10 @@ struct server_context {
         res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
         res->oaicompat_msg     = slot.update_chat_msg(res->oaicompat_msg_diffs);

+        if (cache_reasoning.enabled()) {
+            cache_reasoning.extract_from_message(res->oaicompat_msg);
+        }
+
         // populate res.probs_output
         if (slot.params.sampling.n_probs > 0) {
             if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) {
```
```diff
@@ -4479,6 +4586,9 @@ int main(int argc, char ** argv) {

         auto body = json::parse(req.body);
         std::vector<raw_buffer> files;
+        if (ctx_server.cache_reasoning.enabled()) {
+            ctx_server.cache_reasoning.inject_oaicompat_chat_params(body);
+        }
         json data = oaicompat_chat_params_parse(
             body,
             ctx_server.oai_parser_opt,
```
```diff
@@ -4497,6 +4607,9 @@ int main(int argc, char ** argv) {
     const auto handle_apply_template = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
         auto body = json::parse(req.body);
         std::vector<raw_buffer> files; // dummy, unused
+        if (ctx_server.cache_reasoning.enabled()) {
+            ctx_server.cache_reasoning.inject_oaicompat_chat_params(body);
+        }
         json data = oaicompat_chat_params_parse(
             body,
             ctx_server.oai_parser_opt,
```
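To make the request-side hook concrete, here is a sketch of the transformation `inject_oaicompat_chat_params` applies to an incoming body, using nlohmann::json just as server.cpp does. The message contents, tool-call id, and cached text are invented for illustration.

```cpp
// Illustrative only: an OpenAI-style request body with an assistant turn that
// carries tool_calls but no reasoning_content; the loop below mirrors what the
// injection does when the cache holds an entry for that tool-call id.
#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    json body = {
        {"messages", json::array({
            {{"role", "user"}, {"content", "What is the weather in Paris?"}},
            {{"role", "assistant"}, {"tool_calls", json::array({
                {{"id", "call_abc123"}, {"type", "function"}}
            })}},
            {{"role", "tool"}, {"tool_call_id", "call_abc123"}, {"content", "18C, cloudy"}},
        })},
    };

    // suppose get("call_abc123") returns a cached string; the injection then
    // does the equivalent of:
    for (auto & msg : body.at("messages")) {
        if (!msg.contains("tool_calls") || msg.contains("reasoning_content")) {
            continue;
        }
        msg["reasoning_content"] = "(cached reasoning for call_abc123)";
        break;
    }

    std::cout << body.dump(2) << std::endl;
}
```

The point of the round trip: OpenAI-compatible clients typically do not echo `reasoning_content` back in the conversation history, but models such as gpt-oss expect the reasoning that accompanied a tool call to be present on the following turn. Caching it server-side keyed by tool-call id and re-injecting it restores that expectation without requiring client changes.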
