8 changes: 8 additions & 0 deletions common/arg.cpp
@@ -3002,6 +3002,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.reasoning_budget = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK_BUDGET"));
add_opt(common_arg(
{"--reasoning-cache"}, "N",
"controls the reasoning cache size for models that require reasoning content during inference. (default: 0)",
[](common_params & params, int value) {
if (value < 0) { throw std::invalid_argument("invalid value"); }
params.reasoning_cache = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_REASONING_CACHE"));
add_opt(common_arg(
{"--chat-template"}, "JINJA_TEMPLATE",
string_format(
1 change: 1 addition & 0 deletions common/common.h
@@ -428,6 +428,7 @@ struct common_params {
bool enable_chat_template = true;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO;
int reasoning_budget = -1;
int reasoning_cache = 0; // number of cached reasoning entries, 0 = disabled
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response

std::vector<std::string> api_keys;
130 changes: 130 additions & 0 deletions tools/server/server.cpp
@@ -1,4 +1,5 @@
#include "chat.h"
#include "ggml.h"
#include "utils.hpp"

#include "arg.h"
@@ -597,6 +598,120 @@ struct result_timings {
}
};

struct reasoning_cache {
struct cache_item {
std::string id;
std::string content;
};

std::unordered_map<std::string, cache_item> cache;
std::deque<std::string> ids;
std::mutex mutex;
size_t n_size = 0; // number of entries to keep; 0 = disabled

void init(size_t size = 64) {
SRV_INF("initializing reasoning cache, n_size = %ld\n", size);
n_size = size;
}

bool enabled() const {
return n_size > 0;
}

std::optional<std::string> get(const std::string & id) {
if (n_size <= 0) {
return std::nullopt;
}

std::unique_lock<std::mutex> lock(mutex);
auto it = cache.find(id);
if (it == cache.end()) {
SRV_DBG("reasoning cache miss: %s\n", id.c_str());
return std::nullopt;
}

std::string hit = it->second.content;
SRV_DBG("reasoning cache hit: %s\n", id.c_str());
return hit;
}

void insert(const std::string & id, const std::string & content) {
if (n_size <= 0) {
return;
}

std::unique_lock<std::mutex> lock(mutex);

if (ids.size() >= n_size) {
const std::string & last_id = ids.back();
ids.pop_back();
cache.erase(last_id);
}

ids.push_front(id);
cache[id] = {/* .id = */ id, /* .content = */ content};
SRV_DBG("reasoning cache add: %s\n", id.c_str());
}

void extract_from_message(const common_chat_msg & msg) {
for (const auto & t : msg.tool_calls) {
if (!t.id.empty() && !msg.reasoning_content.empty()) {
insert(t.id, msg.reasoning_content);
}
}
}

void inject_oaicompat_chat_params(json & body) {
if (!body.contains("messages")) {
return;
}

json & messages = body.at("messages");
if (!messages.is_array()) {
return;
}

for (auto &msg : messages) {
std::string role = json_value(msg, "role", std::string());
if (role != "assistant") {
continue;
}

if (!msg.contains("tool_calls") || msg.contains("reasoning_content")) {
continue;
}

// do not inject if the message contains a non-empty content
if (msg.contains("content")) {
if (msg.at("content").is_string()) {
std::string content = json_value(msg, "content", std::string());
if (!content.empty()) {
continue;
}
} else if (!msg.at("content").empty()) {
continue;
}
}

// inject cached reasoning into tool-call messages to support models that require it (e.g. gpt-oss)
const json & tool_calls = msg.at("tool_calls");
if (tool_calls.is_array() && !tool_calls.empty()) {
for (const auto & t : tool_calls) {
std::string tool_id = json_value(t, "id", std::string());
if (tool_id.empty()) {
continue;
}

if (auto content = get(tool_id)) {
msg["reasoning_content"] = content;
break;
}
}
}
}
}
};
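
Annotation, not part of the diff: a minimal usage sketch of the cache above, assuming the reasoning_cache struct and ggml.h's GGML_ASSERT are in scope as in server.cpp; the ids and the size are illustrative only.

// Illustrative sketch: bounded, FIFO-evicting cache keyed by tool-call id.
static void reasoning_cache_sketch() {
    reasoning_cache cache;
    cache.init(/* size = */ 2);                              // 0 would disable the cache

    cache.insert("call_1", "reasoning for tool call 1");
    cache.insert("call_2", "reasoning for tool call 2");
    cache.insert("call_3", "reasoning for tool call 3");     // evicts "call_1" (oldest id)

    const auto hit  = cache.get("call_2");                   // holds "reasoning for tool call 2"
    const auto miss = cache.get("call_1");                   // std::nullopt: evicted above
    GGML_ASSERT(hit.has_value() && !miss.has_value());
}

Note that get() does not refresh an entry's position, so eviction follows insertion order (FIFO) rather than true LRU.
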

struct server_task_result {
int id = -1;
int id_slot = -1;
@@ -1970,6 +2085,9 @@ struct server_context {
common_chat_templates_ptr chat_templates;
oaicompat_parser_options oai_parser_opt;

// reasoning cache
reasoning_cache cache_reasoning;

~server_context() {
mtmd_free(mctx);

@@ -2174,6 +2292,8 @@ struct server_context {
/* allow_audio */ mctx ? mtmd_support_audio (mctx) : false,
/* enable_thinking */ params_base.reasoning_budget != 0,
};

cache_reasoning.init(params_base.reasoning_cache);
}

server_slot * get_slot_by_id(int id) {
@@ -2598,6 +2718,10 @@ struct server_context {
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);

if (cache_reasoning.enabled()) {
cache_reasoning.extract_from_message(res->oaicompat_msg);
}

// populate res.probs_output
if (slot.params.sampling.n_probs > 0) {
if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) {
@@ -4573,6 +4697,9 @@ int main(int argc, char ** argv) {

auto body = json::parse(req.body);
std::vector<raw_buffer> files;
if (ctx_server.cache_reasoning.enabled()) {
ctx_server.cache_reasoning.inject_oaicompat_chat_params(body);
}
json data = oaicompat_chat_params_parse(
body,
ctx_server.oai_parser_opt,
@@ -4591,6 +4718,9 @@
const auto handle_apply_template = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
auto body = json::parse(req.body);
std::vector<raw_buffer> files; // dummy, unused
if (ctx_server.cache_reasoning.enabled()) {
ctx_server.cache_reasoning.inject_oaicompat_chat_params(body);
}
json data = oaicompat_chat_params_parse(
body,
ctx_server.oai_parser_opt,
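
Annotation, not part of the diff: a sketch of the request-body rewrite performed by inject_oaicompat_chat_params, assuming the reasoning_cache struct above and the nlohmann json alias used throughout server.cpp; the message shapes and the "call_abc" id are illustrative. An assistant turn that replays a tool call without its reasoning_content gets the cached reasoning re-attached before the chat template is applied.

// Illustrative round trip: reasoning captured after generation, re-injected on the next request.
static void reasoning_injection_sketch(reasoning_cache & cache) {
    // captured earlier from a generated assistant message via extract_from_message()
    cache.insert("call_abc", "the model's reasoning for this tool call");

    json body = {
        {"messages", json::array({
            {{"role", "user"}, {"content", "What is the weather?"}},
            // assistant turn replayed by the client: tool call present, reasoning omitted
            {{"role", "assistant"}, {"content", ""},
             {"tool_calls", json::array({ {{"id", "call_abc"}, {"type", "function"}} })}},
            {{"role", "tool"}, {"tool_call_id", "call_abc"}, {"content", "sunny"}},
        })},
    };

    cache.inject_oaicompat_chat_params(body);

    // body["messages"][1]["reasoning_content"] now holds the cached reasoning,
    // so templates that require it (e.g. gpt-oss) can render the replayed tool-call turn.
}

On the live server this rewrite happens inside the chat-completions and apply-template handlers shown above, gated on cache_reasoning.enabled(), before oaicompat_chat_params_parse sees the body.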