77
88#include  " common.h" 
99#include  " log.h" 
10+ //  Change JSON_ASSERT from assert() to GGML_ASSERT:
11+ #define  JSON_ASSERT  GGML_ASSERT
12+ #include  " json.hpp" 
1013#include  " json-schema-to-grammar.h" 
1114#include  " llama.h" 
1215#include  " chat.hpp" 
@@ -1769,46 +1772,6 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
17691772//  Chat template utils
17701773// 
17711774
// Constructor: forwards its two string arguments to this class's tools() and
// choice() member functions, in that order.
// NOTE(review): every line of this definition carries an "NNNN- " prefix, which
// looks like the removed side of a diff -- confirm whether it still exists.
1772- common_params_tools::common_params_tools (std::string tools, std::string choice) {
1773-     this ->tools (tools);
1774-     this ->choice (choice);
1775- }
1776- 
// Sets tools_ from a JSON string.
//   - Empty input: tools_ is reset and the function returns immediately.
//   - Otherwise: the text is parsed with json::parse and stored in a
//     shared json object; std::invalid_argument is thrown when the parsed
//     value is not an array.
// A json::exception raised while parsing is caught and rethrown as
// std::invalid_argument carrying the original what() text.
// NOTE(review): the "NNNN- " prefixes look like the removed side of a diff --
// confirm whether this definition still exists.
1777- void  common_params_tools::tools (std::string tools) {
1778-     if  (tools.empty ()) {
1779-         tools_.reset ();
1780-         return ;
1781-     }
1782-     try  {
// tools_ is assigned before the is_array() check, so it already holds the
// parsed value when the non-array error below is thrown.
1783-         tools_ = std::make_shared<json>(json::parse (tools));
1784-         if  (! tools_->is_array ()) {
// NOTE(review): this throw statement appears truncated in this view (no
// closing ");") -- presumably extraction damage; verify against the real file.
1785-             throw  std::invalid_argument (" tools must be a valid JSON array" 
1786-         }
1787- 
1788-     } catch  (const  json::exception & err) {
1789-         throw  std::invalid_argument (err.what ());
1790-     }
1791- }
1792- 
// Sets tool_choice_ from a string.
//   - When the first branch is taken, the string itself is moved into
//     tool_choice_.
//   - Otherwise the string is parsed with json::parse, the resulting shared
//     json object is stored in tool_choice_, and std::invalid_argument is
//     thrown if that value is not a JSON object.
// A json::exception raised while parsing is caught and rethrown as
// std::invalid_argument carrying the original what() text.
// NOTE(review): the "NNNN- " prefixes look like the removed side of a diff --
// confirm whether this definition still exists.
1793- void  common_params_tools::choice (std::string choice) {
1794-     try  {
// NOTE(review): the condition below is garbled in this view (three adjacent
// literals, no operators, no closing ") {"). Presumably it compared choice
// against each keyword named in the error message further down -- verify.
1795-         if  (choice == " auto" " required" " none" 
1796-             tool_choice_ = std::move (choice);
1797- 
1798-         } else  {
1799-             auto  choice_ptr = std::make_shared<json>(json::parse (choice));
// tool_choice_ is assigned before the is_object() check, so it already holds
// the parsed value when the non-object error below is thrown.
1800-             tool_choice_ = choice_ptr;
1801-             if  (! choice_ptr->is_object ()) {
// NOTE(review): this throw statement appears truncated in this view (no
// closing ");") -- presumably extraction damage; verify against the real file.
1802-                 throw  std::invalid_argument (
1803-                     " tool choice must be a valid JSON object, \" auto\" , \" required\" , or \" none\" " 
1804-             }
1805-         }
1806- 
1807-     } catch  (const  json::exception & err) {
1808-         throw  std::invalid_argument (err.what ());
1809-     }
1810- }
1811- 
18121775bool  common_chat_verify_template (const  std::string & tmpl, bool  use_jinja) {
18131776    if  (use_jinja) {
18141777        try  {
@@ -1830,7 +1793,7 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
18301793    return  res >= 0 ;
18311794}
18321795
1833- static  void  copy_chat_params (const  common_chat_params & src, common_chat_sampling_updater  * update_sparams)
1796+ static  void  copy_chat_params (const  common_chat_params & src, toolcall::sampling_updater  * update_sparams)
18341797{
18351798    GGML_ASSERT (update_sparams && update_sparams->sparams  && update_sparams->vocab );
18361799
@@ -1873,11 +1836,11 @@ std::string common_chat_apply_template(
18731836        const  std::vector<common_chat_msg> & msgs,
18741837        bool  add_ass,
18751838        bool  use_jinja,
1876-         const  common_params_tools & tools ,
1877-         common_chat_sampling_updater  * update_sparams)
1839+         toolcall::handler::ptr handler ,
1840+         toolcall::sampling_updater  * update_sparams)
18781841{
18791842    const  auto  & tmpl_selected =
1880-         tools. tools ()  && tmpl.template_tool_use  ? *tmpl.template_tool_use  : *tmpl.template_default ;
1843+         handler !=  nullptr  && tmpl.template_tool_use  ? *tmpl.template_tool_use  : *tmpl.template_default ;
18811844
18821845    if  (use_jinja) {
18831846        common_chat_inputs inputs;
@@ -1886,29 +1849,38 @@ std::string common_chat_apply_template(
18861849        for  (const  auto  & msg : msgs) {
18871850            messages.push_back ({{" role" role }, {" content" content }});
18881851        }
1852+         inputs.messages  = messages;
1853+         inputs.add_generation_prompt  = add_ass;
18891854
1890-         if  (tools.tools () != nullptr ) {
1891-             inputs.tools  = *tools.tools ();
1892-         }
1893- 
1894-         auto  choice = tools.choice ();
1895-         if  (std::holds_alternative<std::string>(choice)) {
1896-             inputs.tool_choice  = std::get<std::string>(choice);
1855+         if  (handler != nullptr ) {
1856+             auto  choice = handler->tool_choice ();
1857+             if  (std::holds_alternative<std::string>(choice)) {
1858+                 inputs.tool_choice  = std::get<std::string>(choice);
18971859
1898-         } else  {
1899-             auto  choice_ptr = std::get<common_params_tools::json_ptr>(choice);
1900-             if  (choice_ptr != nullptr ) {
1901-                 inputs.tool_choice  = *choice_ptr;
1860+             } else  {
1861+                 auto  choice_ptr = std::get<toolcall::json_ptr>(choice);
1862+                 if  (choice_ptr != nullptr ) {
1863+                     inputs.tool_choice  = *choice_ptr;
1864+                 }
19021865            }
1866+ 
1867+             inputs.tools  = handler->tool_list ();
19031868        }
19041869
1905-         inputs.messages  = messages;
1906-         inputs.add_generation_prompt  = add_ass;
19071870        auto  chat_params = common_chat_params_init (tmpl_selected, inputs);
19081871        if  (update_sparams) {
19091872            copy_chat_params (chat_params, update_sparams);
19101873        }
1911-         return  chat_params.prompt ;
1874+ 
1875+         auto  prompt = chat_params.prompt ;
1876+         if  (handler != nullptr ) {
1877+             json response;
1878+             handler->call (prompt, response);
1879+             return  response; //  Caller will determine what to do based upon last_action
1880+ 
1881+         } else  {
1882+             return  prompt;
1883+         }
19121884    }
19131885
19141886    int  alloc_size = 0 ;
@@ -1946,12 +1918,12 @@ std::string common_chat_format_single(
19461918        const  common_chat_msg & new_msg,
19471919        bool  add_ass,
19481920        bool  use_jinja,
1949-         const  common_params_tools & tools ,
1950-         common_chat_sampling_updater  * update_sparams)
1921+         toolcall::handler::ptr handler ,
1922+         toolcall::sampling_updater  * update_sparams)
19511923{
19521924    std::ostringstream ss;
19531925    auto  fmt_past_msg = past_msg.empty () ? " " 
1954-         : common_chat_apply_template (tmpl, past_msg, false , use_jinja, tools , update_sparams);
1926+         : common_chat_apply_template (tmpl, past_msg, false , use_jinja, handler , update_sparams);
19551927
19561928    std::vector<common_chat_msg> chat_new (past_msg);
19571929    //  if the past_msg ends with a newline, we must preserve it in the formatted version
@@ -1960,7 +1932,7 @@ std::string common_chat_format_single(
19601932    };
19611933    //  format chat with new_msg
19621934    chat_new.push_back (new_msg);
1963-     auto  fmt_new_msg = common_chat_apply_template (tmpl, chat_new, add_ass, use_jinja, tools , update_sparams);
1935+     auto  fmt_new_msg = common_chat_apply_template (tmpl, chat_new, add_ass, use_jinja, handler , update_sparams);
19641936    //  get the diff part
19651937    ss << fmt_new_msg.substr (fmt_past_msg.size (), fmt_new_msg.size () - fmt_past_msg.size ());
19661938    return  ss.str ();
0 commit comments