1
1
#include " chat.h"
2
+ #include " ggml.h"
2
3
#include " utils.hpp"
3
4
4
5
#include " arg.h"
@@ -593,6 +594,103 @@ struct result_timings {
593
594
}
594
595
};
595
596
597
+ struct reasoning_cache {
598
+ struct cache_item {
599
+ std::string id;
600
+ std::string content;
601
+ };
602
+
603
+ std::unordered_map<std::string, cache_item> cache;
604
+ std::deque<std::string> ids;
605
+ std::mutex mutex;
606
+ size_t n_size;
607
+
608
+ void init (size_t size = 64 ) {
609
+ SRV_INF (" initializing reasoning cache, n_size = %ld\n " , size);
610
+ n_size = size;
611
+ }
612
+
613
+ bool enabled () const {
614
+ return n_size > 0 ;
615
+ }
616
+
617
+ std::optional<std::string> get (const std::string & id) {
618
+ if (n_size <= 0 ) {
619
+ return std::nullopt;
620
+ }
621
+
622
+ std::unique_lock<std::mutex> lock (mutex);
623
+ auto it = cache.find (id);
624
+ if (it == cache.end ()) {
625
+ SRV_DBG (" reasoning cache miss: %s\n " , id.c_str ());
626
+ return std::nullopt;
627
+ }
628
+
629
+ std::string hit = it->second .content ;
630
+ SRV_DBG (" reasoning cache hit: %s\n " , id.c_str ());
631
+ return hit;
632
+ }
633
+
634
+ void insert (const std::string & id, const std::string & content) {
635
+ if (n_size <= 0 ) {
636
+ return ;
637
+ }
638
+
639
+ std::unique_lock<std::mutex> lock (mutex);
640
+
641
+ if (ids.size () >= n_size) {
642
+ const std::string & last_id = ids.back ();
643
+ ids.pop_back ();
644
+ cache.erase (last_id);
645
+ }
646
+
647
+ ids.push_front (id);
648
+ cache[id] = {/* .id = */ id, /* .content = */ content};
649
+ SRV_DBG (" reasoning cache add: %s\n " , id.c_str ());
650
+ }
651
+
652
+ void extract_from_message (const common_chat_msg & msg) {
653
+ for (const auto & t : msg.tool_calls ) {
654
+ if (!t.id .empty () && !msg.reasoning_content .empty ()) {
655
+ insert (t.id , msg.reasoning_content );
656
+ }
657
+ }
658
+ }
659
+
660
+ void inject_oaicompat_chat_params (json & body) {
661
+ if (!body.contains (" messages" )) {
662
+ return ;
663
+ }
664
+
665
+ json & messages = body.at (" messages" );
666
+ if (!messages.is_array ()) {
667
+ return ;
668
+ }
669
+
670
+ for (auto &msg : messages) {
671
+ if (!msg.contains (" tool_calls" ) || msg.contains (" reasoning_content" )) {
672
+ continue ;
673
+ }
674
+
675
+ // inject cached reasoning to tool call messages to support models that require it (gpt-oss)
676
+ const json & tool_calls = msg.at (" tool_calls" );
677
+ if (tool_calls.is_array () && !tool_calls.empty ()) {
678
+ for (const auto & t : tool_calls) {
679
+ std::string tool_id = json_value (t, " id" , std::string ());
680
+ if (tool_id.empty ()) {
681
+ continue ;
682
+ }
683
+
684
+ if (auto content = get (tool_id)) {
685
+ msg[" reasoning_content" ] = content;
686
+ break ;
687
+ }
688
+ }
689
+ }
690
+ }
691
+ }
692
+ };
693
+
596
694
struct server_task_result {
597
695
int id = -1 ;
598
696
int id_slot = -1 ;
@@ -1957,6 +2055,9 @@ struct server_context {
1957
2055
common_chat_templates_ptr chat_templates;
1958
2056
oaicompat_parser_options oai_parser_opt;
1959
2057
2058
+ // reasoning cache
2059
+ reasoning_cache cache_reasoning;
2060
+
1960
2061
~server_context () {
1961
2062
mtmd_free (mctx);
1962
2063
@@ -2157,6 +2258,8 @@ struct server_context {
2157
2258
/* allow_audio */ mctx ? mtmd_support_audio (mctx) : false ,
2158
2259
/* enable_thinking */ params_base.reasoning_budget != 0 ,
2159
2260
};
2261
+
2262
+ cache_reasoning.init (params_base.reasoning_cache );
2160
2263
}
2161
2264
2162
2265
server_slot * get_slot_by_id (int id) {
@@ -2581,6 +2684,10 @@ struct server_context {
2581
2684
res->oaicompat_cmpl_id = slot.params .oaicompat_cmpl_id ;
2582
2685
res->oaicompat_msg = slot.update_chat_msg (res->oaicompat_msg_diffs );
2583
2686
2687
+ if (cache_reasoning.enabled ()) {
2688
+ cache_reasoning.extract_from_message (res->oaicompat_msg );
2689
+ }
2690
+
2584
2691
// populate res.probs_output
2585
2692
if (slot.params .sampling .n_probs > 0 ) {
2586
2693
if (!slot.params .stream && slot.stop == STOP_TYPE_WORD) {
@@ -4475,6 +4582,9 @@ int main(int argc, char ** argv) {
4475
4582
4476
4583
auto body = json::parse (req.body );
4477
4584
std::vector<raw_buffer> files;
4585
+ if (ctx_server.cache_reasoning .enabled ()) {
4586
+ ctx_server.cache_reasoning .inject_oaicompat_chat_params (body);
4587
+ }
4478
4588
json data = oaicompat_chat_params_parse (
4479
4589
body,
4480
4590
ctx_server.oai_parser_opt ,
@@ -4493,6 +4603,9 @@ int main(int argc, char ** argv) {
4493
4603
const auto handle_apply_template = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
4494
4604
auto body = json::parse (req.body );
4495
4605
std::vector<raw_buffer> files; // dummy, unused
4606
+ if (ctx_server.cache_reasoning .enabled ()) {
4607
+ ctx_server.cache_reasoning .inject_oaicompat_chat_params (body);
4608
+ }
4496
4609
json data = oaicompat_chat_params_parse (
4497
4610
body,
4498
4611
ctx_server.oai_parser_opt ,
0 commit comments