From f72c6eb27562ae452153e3516994a3d287345930 Mon Sep 17 00:00:00 2001
From: Alde Rojas
Date: Sat, 9 Aug 2025 18:14:23 -0500
Subject: [PATCH 1/2] server : add --reasoning-cache

---
 common/arg.cpp          |   8 +++
 common/common.h         |   1 +
 tools/server/server.cpp | 113 ++++++++++++++++++++++++++++++++++++++++
 3 files changed, 122 insertions(+)

diff --git a/common/arg.cpp b/common/arg.cpp
index 98baac4c14da2..c580d183d3ecc 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -3002,6 +3002,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.reasoning_budget = value;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK_BUDGET"));
+    add_opt(common_arg(
+        {"--reasoning-cache"}, "N",
+        "controls the reasoning cache size (number of entries) for models that require reasoning content during inference (default: 0 = disabled)",
+        [](common_params & params, int value) {
+            if (value < 0) { throw std::invalid_argument("invalid value"); }
+            params.reasoning_cache = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_REASONING_CACHE"));
     add_opt(common_arg(
         {"--chat-template"}, "JINJA_TEMPLATE",
         string_format(
diff --git a/common/common.h b/common/common.h
index 75596e6b32979..1d444deccfd07 100644
--- a/common/common.h
+++ b/common/common.h
@@ -428,6 +428,7 @@ struct common_params {
     bool enable_chat_template = true;
     common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO;
     int reasoning_budget = -1;
+    int reasoning_cache = 0;
     bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
 
     std::vector<std::string> api_keys;
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 0b40f7bfa4258..0e6c5d544b6de 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -1,4 +1,5 @@
 #include "chat.h"
+#include "ggml.h"
 #include "utils.hpp"
 
 #include "arg.h"
@@ -597,6 +598,103 @@ struct result_timings {
     }
 };
 
+struct reasoning_cache {
+    struct cache_item {
+        std::string id;
+        std::string content;
+    };
+
+    std::unordered_map<std::string, cache_item> cache;
+    std::deque<std::string> ids;
+    std::mutex mutex;
+    size_t n_size;
+
+    void init(size_t size = 64) {
+        SRV_INF("initializing reasoning cache, n_size = %zu\n", size);
+        n_size = size;
+    }
+
+    bool enabled() const {
+        return n_size > 0;
+    }
+
+    std::optional<std::string> get(const std::string & id) {
+        if (n_size == 0) {
+            return std::nullopt;
+        }
+
+        std::unique_lock<std::mutex> lock(mutex);
+        auto it = cache.find(id);
+        if (it == cache.end()) {
+            SRV_DBG("reasoning cache miss: %s\n", id.c_str());
+            return std::nullopt;
+        }
+
+        std::string hit = it->second.content;
+        SRV_DBG("reasoning cache hit: %s\n", id.c_str());
+        return hit;
+    }
+
+    void insert(const std::string & id, const std::string & content) {
+        if (n_size == 0) {
+            return;
+        }
+
+        std::unique_lock<std::mutex> lock(mutex);
+
+        if (ids.size() >= n_size) {
+            // evict the oldest entry; erase the map entry while last_id still refers to a live element
+            const std::string & last_id = ids.back();
+            cache.erase(last_id);
+            ids.pop_back();
+        }
+
+        ids.push_front(id);
+        cache[id] = {/* .id = */ id, /* .content = */ content};
+        SRV_DBG("reasoning cache add: %s\n", id.c_str());
+    }
+
+    void extract_from_message(const common_chat_msg & msg) {
+        for (const auto & t : msg.tool_calls) {
+            if (!t.id.empty() && !msg.reasoning_content.empty()) {
+                insert(t.id, msg.reasoning_content);
+            }
+        }
+    }
+
+    void inject_oaicompat_chat_params(json & body) {
+        if (!body.contains("messages")) {
+            return;
+        }
+
+        json & messages = body.at("messages");
+        if (!messages.is_array()) {
+            return;
+        }
+
+        for (auto & msg : messages) {
+            if (!msg.contains("tool_calls") ||
+                msg.contains("reasoning_content")) {
+                continue;
+            }
+            // inject cached reasoning into tool-call messages to support models that require it (gpt-oss)
+            const json & tool_calls = msg.at("tool_calls");
+            if (tool_calls.is_array() && !tool_calls.empty()) {
+                for (const auto & t : tool_calls) {
+                    std::string tool_id = json_value(t, "id", std::string());
+                    if (tool_id.empty()) {
+                        continue;
+                    }
+
+                    if (auto content = get(tool_id)) {
+                        msg["reasoning_content"] = *content;
+                        break;
+                    }
+                }
+            }
+        }
+    }
+};
+
 struct server_task_result {
     int id = -1;
     int id_slot = -1;
@@ -1970,6 +2068,9 @@ struct server_context {
     common_chat_templates_ptr chat_templates;
 
     oaicompat_parser_options oai_parser_opt;
 
+    // reasoning cache
+    reasoning_cache cache_reasoning;
+
     ~server_context() {
         mtmd_free(mctx);
@@ -2174,6 +2275,8 @@ struct server_context {
             /* allow_audio     */ mctx ? mtmd_support_audio(mctx) : false,
             /* enable_thinking */ params_base.reasoning_budget != 0,
         };
+
+        cache_reasoning.init(params_base.reasoning_cache);
     }
 
     server_slot * get_slot_by_id(int id) {
@@ -2598,6 +2701,10 @@
             res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
             res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);
 
+            if (cache_reasoning.enabled()) {
+                cache_reasoning.extract_from_message(res->oaicompat_msg);
+            }
+
             // populate res.probs_output
             if (slot.params.sampling.n_probs > 0) {
                 if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) {
@@ -4573,6 +4680,9 @@ int main(int argc, char ** argv) {
         auto body = json::parse(req.body);
         std::vector<raw_buffer> files;
 
+        if (ctx_server.cache_reasoning.enabled()) {
+            ctx_server.cache_reasoning.inject_oaicompat_chat_params(body);
+        }
         json data = oaicompat_chat_params_parse(
             body,
             ctx_server.oai_parser_opt,
@@ -4591,6 +4701,9 @@ int main(int argc, char ** argv) {
     const auto handle_apply_template = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
         auto body = json::parse(req.body);
         std::vector<raw_buffer> files; // dummy, unused
+        if (ctx_server.cache_reasoning.enabled()) {
+            ctx_server.cache_reasoning.inject_oaicompat_chat_params(body);
+        }
         json data = oaicompat_chat_params_parse(
             body,
             ctx_server.oai_parser_opt,

From 9d1d2454a0e68bc9b607956d8506b3e8d1e198b9 Mon Sep 17 00:00:00 2001
From: Alde Rojas
Date: Sun, 10 Aug 2025 16:28:19 -0500
Subject: [PATCH 2/2] server : do not inject reasoning if content is present

---
 tools/server/server.cpp | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 0e6c5d544b6de..0d0dabb7a8c2f 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -672,10 +672,27 @@ struct reasoning_cache {
         }
 
         for (auto & msg : messages) {
+            std::string role = json_value(msg, "role", std::string());
+            if (role != "assistant") {
+                continue;
+            }
+
             if (!msg.contains("tool_calls") ||
                 msg.contains("reasoning_content")) {
                 continue;
             }
+            // do not inject if the message contains non-empty content
+            if (msg.contains("content")) {
+                if (msg.at("content").is_string()) {
+                    std::string content = json_value(msg, "content", std::string());
+                    if (!content.empty()) {
+                        continue;
+                    }
+                } else if (!msg.at("content").empty()) {
+                    continue;
+                }
+            }
+
             // inject cached reasoning into tool-call messages to support models that require it (gpt-oss)
             const json & tool_calls = msg.at("tool_calls");
             if (tool_calls.is_array() && !tool_calls.empty()) {
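
Note (not part of the patches themselves): below is a minimal standalone sketch of the bounded deque + unordered_map scheme that reasoning_cache uses, for readers who want to see the eviction behavior in isolation. Names and the capacity are illustrative; the real struct additionally serializes access with a std::mutex and maps tool-call ids to cache_item values. Because get() never reorders ids, eviction is FIFO over insertion order rather than LRU. The cache is off by default and is enabled with e.g. --reasoning-cache 64 or the LLAMA_ARG_REASONING_CACHE environment variable.

#include <deque>
#include <iostream>
#include <string>
#include <unordered_map>

int main() {
    const size_t n_size = 2; // capacity, as set by e.g. --reasoning-cache 2
    std::deque<std::string> ids;                        // insertion order, newest id at the front
    std::unordered_map<std::string, std::string> cache; // tool-call id -> reasoning text

    auto insert = [&](const std::string & id, const std::string & content) {
        if (ids.size() >= n_size) {
            cache.erase(ids.back()); // evict the oldest entry before dropping its id
            ids.pop_back();
        }
        ids.push_front(id);
        cache[id] = content;
    };

    insert("call_1", "reasoning for call 1");
    insert("call_2", "reasoning for call 2");
    insert("call_3", "reasoning for call 3"); // capacity is 2, so call_1 is evicted

    std::cout << "call_1 cached: " << cache.count("call_1") << "\n"; // prints 0
    std::cout << "call_3 cached: " << cache.count("call_3") << "\n"; // prints 1
    return 0;
}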