Commit 11ac33a

server : add --reasoning-cache

1 parent 379b652

3 files changed: 122 additions, 0 deletions

common/arg.cpp

Lines changed: 8 additions & 0 deletions

@@ -2964,6 +2964,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.reasoning_budget = value;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK_BUDGET"));
+    add_opt(common_arg(
+        {"--reasoning-cache"}, "N",
+        "controls the reasoning cache size for models that require reasoning content during inference (default: 0)",
+        [](common_params & params, int value) {
+            if (value < 0) { throw std::invalid_argument("invalid value"); }
+            params.reasoning_cache = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_REASONING_CACHE"));
     add_opt(common_arg(
         {"--chat-template"}, "JINJA_TEMPLATE",
         string_format(
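
The option can be set on the command line (--reasoning-cache 64) or through the LLAMA_ARG_REASONING_CACHE environment variable. As a minimal standalone sketch of what the handler above enforces (illustrative only, not part of the commit; parse_reasoning_cache is a made-up name):

    #include <stdexcept>
    #include <string>

    // Mirrors the lambda registered for --reasoning-cache: "N" must be a
    // non-negative integer, and 0 (the default) leaves the cache disabled.
    static int parse_reasoning_cache(const std::string & arg) {
        const int value = std::stoi(arg);
        if (value < 0) {
            throw std::invalid_argument("invalid value");
        }
        return value;
    }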

common/common.h

Lines changed: 1 addition & 0 deletions

@@ -398,6 +398,7 @@ struct common_params {
     bool enable_chat_template = true;
     common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO;
     int reasoning_budget = -1;
+    int reasoning_cache = 0;
     bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response

     std::vector<std::string> api_keys;

tools/server/server.cpp

Lines changed: 113 additions & 0 deletions

@@ -1,4 +1,5 @@
 #include "chat.h"
+#include "ggml.h"
 #include "utils.hpp"

 #include "arg.h"

@@ -593,6 +594,103 @@ struct result_timings {
     }
 };

+struct reasoning_cache {
+    struct cache_item {
+        std::string id;
+        std::string content;
+    };
+
+    std::unordered_map<std::string, cache_item> cache;
+    std::deque<std::string> ids;
+    std::mutex mutex;
+    size_t n_size = 0;
+
+    void init(size_t size = 64) {
+        SRV_INF("initializing reasoning cache, n_size = %zu\n", size);
+        n_size = size;
+    }
+
+    bool enabled() const {
+        return n_size > 0;
+    }
+
+    std::optional<std::string> get(const std::string & id) {
+        if (n_size == 0) {
+            return std::nullopt;
+        }
+
+        std::unique_lock<std::mutex> lock(mutex);
+        auto it = cache.find(id);
+        if (it == cache.end()) {
+            SRV_DBG("reasoning cache miss: %s\n", id.c_str());
+            return std::nullopt;
+        }
+
+        std::string hit = it->second.content;
+        SRV_DBG("reasoning cache hit: %s\n", id.c_str());
+        return hit;
+    }
+
+    void insert(const std::string & id, const std::string & content) {
+        if (n_size == 0) {
+            return;
+        }
+
+        std::unique_lock<std::mutex> lock(mutex);
+
+        if (ids.size() >= n_size) {
+            // copy the id before pop_back(), which would otherwise leave a dangling reference
+            const std::string last_id = ids.back();
+            ids.pop_back();
+            cache.erase(last_id);
+        }
+
+        ids.push_front(id);
+        cache[id] = {/* .id = */ id, /* .content = */ content};
+        SRV_DBG("reasoning cache add: %s\n", id.c_str());
+    }
+
+    void extract_from_message(const common_chat_msg & msg) {
+        for (const auto & t : msg.tool_calls) {
+            if (!t.id.empty() && !msg.reasoning_content.empty()) {
+                insert(t.id, msg.reasoning_content);
+            }
+        }
+    }
+
+    void inject_oaicompat_chat_params(json & body) {
+        if (!body.contains("messages")) {
+            return;
+        }
+
+        json & messages = body.at("messages");
+        if (!messages.is_array()) {
+            return;
+        }
+
+        for (auto & msg : messages) {
+            if (!msg.contains("tool_calls") || msg.contains("reasoning_content")) {
+                continue;
+            }
+
+            // inject cached reasoning into tool-call messages to support models that require it (gpt-oss)
+            const json & tool_calls = msg.at("tool_calls");
+            if (tool_calls.is_array() && !tool_calls.empty()) {
+                for (const auto & t : tool_calls) {
+                    std::string tool_id = json_value(t, "id", std::string());
+                    if (tool_id.empty()) {
+                        continue;
+                    }
+
+                    if (auto content = get(tool_id)) {
+                        msg["reasoning_content"] = *content;
+                        break;
+                    }
+                }
+            }
+        }
+    }
+};
+
 struct server_task_result {
     int id = -1;
     int id_slot = -1;
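
For orientation, a short usage sketch of the struct above, assuming the definitions as committed (the ids and strings are invented). New entries are pushed to the front of the ids deque, and once n_size entries exist the oldest one is evicted from the back; this is FIFO rather than LRU, since get() does not refresh an entry's position:

    #include <cassert>

    // assumes the reasoning_cache struct defined above
    static void reasoning_cache_demo() {
        reasoning_cache rc;
        rc.init(2);                                  // keep at most 2 entries

        rc.insert("call_1", "reasoning for tool A");
        rc.insert("call_2", "reasoning for tool B");
        rc.insert("call_3", "reasoning for tool C"); // evicts the oldest, "call_1"

        assert(!rc.get("call_1").has_value());       // miss: already evicted
        assert(*rc.get("call_3") == "reasoning for tool C"); // hit
    }
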
@@ -1957,6 +2055,9 @@ struct server_context {
     common_chat_templates_ptr chat_templates;
     oaicompat_parser_options oai_parser_opt;

+    // reasoning cache
+    reasoning_cache cache_reasoning;
+
     ~server_context() {
         mtmd_free(mctx);

@@ -2157,6 +2258,8 @@ struct server_context {
             /* allow_audio */ mctx ? mtmd_support_audio (mctx) : false,
             /* enable_thinking */ params_base.reasoning_budget != 0,
         };
+
+        cache_reasoning.init(params_base.reasoning_cache);
     }

     server_slot * get_slot_by_id(int id) {

@@ -2581,6 +2684,10 @@ struct server_context {
         res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
         res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);

+        if (cache_reasoning.enabled()) {
+            cache_reasoning.extract_from_message(res->oaicompat_msg);
+        }
+
         // populate res.probs_output
         if (slot.params.sampling.n_probs > 0) {
             if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) {

@@ -4475,6 +4582,9 @@ int main(int argc, char ** argv) {

         auto body = json::parse(req.body);
         std::vector<raw_buffer> files;
+        if (ctx_server.cache_reasoning.enabled()) {
+            ctx_server.cache_reasoning.inject_oaicompat_chat_params(body);
+        }
         json data = oaicompat_chat_params_parse(
             body,
             ctx_server.oai_parser_opt,

@@ -4493,6 +4603,9 @@ int main(int argc, char ** argv) {
     const auto handle_apply_template = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
         auto body = json::parse(req.body);
         std::vector<raw_buffer> files; // dummy, unused
+        if (ctx_server.cache_reasoning.enabled()) {
+            ctx_server.cache_reasoning.inject_oaicompat_chat_params(body);
+        }
         json data = oaicompat_chat_params_parse(
             body,
             ctx_server.oai_parser_opt,

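To make the handler wiring concrete, a hedged before/after sketch of what inject_oaicompat_chat_params does to an incoming request body; the tool-call id and message texts are invented:

    // An assistant tool-call message arrives without reasoning_content
    // (OpenAI-compatible clients typically do not echo it back):
    json body = {
        {"messages", json::array({
            {{"role", "user"}, {"content", "What is the weather in Paris?"}},
            {{"role", "assistant"}, {"tool_calls", json::array({
                {{"id", "call_abc"}, {"type", "function"}},
            })}},
        })},
    };

    // If an earlier response cached reasoning under the id "call_abc",
    // ctx_server.cache_reasoning.inject_oaicompat_chat_params(body) restores it:
    //
    //   body["messages"][1]["reasoning_content"] == "<cached reasoning>"
    //
    // so chat templates that require reasoning alongside tool calls
    // (e.g. gpt-oss) can render the turn correctly.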