1
1
#include " chat.h"
2
+ #include " ggml.h"
2
3
#include " utils.hpp"
3
4
4
5
#include " arg.h"
@@ -597,6 +598,103 @@ struct result_timings {
597
598
}
598
599
};
599
600
601
+ struct reasoning_cache {
602
+ struct cache_item {
603
+ std::string id;
604
+ std::string content;
605
+ };
606
+
607
+ std::unordered_map<std::string, cache_item> cache;
608
+ std::deque<std::string> ids;
609
+ std::mutex mutex;
610
+ size_t n_size;
611
+
612
+ void init (size_t size = 64 ) {
613
+ SRV_INF (" initializing reasoning cache, n_size = %ld\n " , size);
614
+ n_size = size;
615
+ }
616
+
617
+ bool enabled () const {
618
+ return n_size > 0 ;
619
+ }
620
+
621
+ std::optional<std::string> get (const std::string & id) {
622
+ if (n_size <= 0 ) {
623
+ return std::nullopt;
624
+ }
625
+
626
+ std::unique_lock<std::mutex> lock (mutex);
627
+ auto it = cache.find (id);
628
+ if (it == cache.end ()) {
629
+ SRV_DBG (" reasoning cache miss: %s\n " , id.c_str ());
630
+ return std::nullopt;
631
+ }
632
+
633
+ std::string hit = it->second .content ;
634
+ SRV_DBG (" reasoning cache hit: %s\n " , id.c_str ());
635
+ return hit;
636
+ }
637
+
638
+ void insert (const std::string & id, const std::string & content) {
639
+ if (n_size <= 0 ) {
640
+ return ;
641
+ }
642
+
643
+ std::unique_lock<std::mutex> lock (mutex);
644
+
645
+ if (ids.size () >= n_size) {
646
+ const std::string & last_id = ids.back ();
647
+ ids.pop_back ();
648
+ cache.erase (last_id);
649
+ }
650
+
651
+ ids.push_front (id);
652
+ cache[id] = {/* .id = */ id, /* .content = */ content};
653
+ SRV_DBG (" reasoning cache add: %s\n " , id.c_str ());
654
+ }
655
+
656
+ void extract_from_message (const common_chat_msg & msg) {
657
+ for (const auto & t : msg.tool_calls ) {
658
+ if (!t.id .empty () && !msg.reasoning_content .empty ()) {
659
+ insert (t.id , msg.reasoning_content );
660
+ }
661
+ }
662
+ }
663
+
664
+ void inject_oaicompat_chat_params (json & body) {
665
+ if (!body.contains (" messages" )) {
666
+ return ;
667
+ }
668
+
669
+ json & messages = body.at (" messages" );
670
+ if (!messages.is_array ()) {
671
+ return ;
672
+ }
673
+
674
+ for (auto &msg : messages) {
675
+ if (!msg.contains (" tool_calls" ) || msg.contains (" reasoning_content" )) {
676
+ continue ;
677
+ }
678
+
679
+ // inject cached reasoning to tool call messages to support models that require it (gpt-oss)
680
+ const json & tool_calls = msg.at (" tool_calls" );
681
+ if (tool_calls.is_array () && !tool_calls.empty ()) {
682
+ for (const auto & t : tool_calls) {
683
+ std::string tool_id = json_value (t, " id" , std::string ());
684
+ if (tool_id.empty ()) {
685
+ continue ;
686
+ }
687
+
688
+ if (auto content = get (tool_id)) {
689
+ msg[" reasoning_content" ] = content;
690
+ break ;
691
+ }
692
+ }
693
+ }
694
+ }
695
+ }
696
+ };
697
+
600
698
struct server_task_result {
601
699
int id = -1 ;
602
700
int id_slot = -1 ;
@@ -1961,6 +2059,9 @@ struct server_context {
1961
2059
common_chat_templates_ptr chat_templates;
1962
2060
oaicompat_parser_options oai_parser_opt;
1963
2061
2062
+ // reasoning cache
2063
+ reasoning_cache cache_reasoning;
2064
+
1964
2065
~server_context () {
1965
2066
mtmd_free (mctx);
1966
2067
@@ -2161,6 +2262,8 @@ struct server_context {
2161
2262
/* allow_audio */ mctx ? mtmd_support_audio (mctx) : false ,
2162
2263
/* enable_thinking */ params_base.reasoning_budget != 0 ,
2163
2264
};
2265
+
2266
+ cache_reasoning.init (params_base.reasoning_cache );
2164
2267
}
2165
2268
2166
2269
server_slot * get_slot_by_id (int id) {
@@ -2585,6 +2688,10 @@ struct server_context {
2585
2688
res->oaicompat_cmpl_id = slot.params .oaicompat_cmpl_id ;
2586
2689
res->oaicompat_msg = slot.update_chat_msg (res->oaicompat_msg_diffs );
2587
2690
2691
+ if (cache_reasoning.enabled ()) {
2692
+ cache_reasoning.extract_from_message (res->oaicompat_msg );
2693
+ }
2694
+
2588
2695
// populate res.probs_output
2589
2696
if (slot.params .sampling .n_probs > 0 ) {
2590
2697
if (!slot.params .stream && slot.stop == STOP_TYPE_WORD) {
@@ -4479,6 +4586,9 @@ int main(int argc, char ** argv) {
4479
4586
4480
4587
auto body = json::parse (req.body );
4481
4588
std::vector<raw_buffer> files;
4589
+ if (ctx_server.cache_reasoning .enabled ()) {
4590
+ ctx_server.cache_reasoning .inject_oaicompat_chat_params (body);
4591
+ }
4482
4592
json data = oaicompat_chat_params_parse (
4483
4593
body,
4484
4594
ctx_server.oai_parser_opt ,
@@ -4497,6 +4607,9 @@ int main(int argc, char ** argv) {
4497
4607
const auto handle_apply_template = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
4498
4608
auto body = json::parse (req.body );
4499
4609
std::vector<raw_buffer> files; // dummy, unused
4610
+ if (ctx_server.cache_reasoning .enabled ()) {
4611
+ ctx_server.cache_reasoning .inject_oaicompat_chat_params (body);
4612
+ }
4500
4613
json data = oaicompat_chat_params_parse (
4501
4614
body,
4502
4615
ctx_server.oai_parser_opt ,
0 commit comments