77
88#include " common.h"
99#include " log.h"
10- // Change JSON_ASSERT from assert() to GGML_ASSERT:
11- #define JSON_ASSERT GGML_ASSERT
12- #include " json.hpp"
1310#include " json-schema-to-grammar.h"
1411#include " llama.h"
1512#include " chat.hpp"
@@ -1772,6 +1769,42 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
17721769// Chat template utils
17731770//
17741771
1772+ common_params_tools::common_params_tools (std::string tools, std::string choice) {
1773+ this ->tools (tools);
1774+ this ->choice (choice);
1775+ }
1776+
1777+ void common_params_tools::tools (std::string tools) {
1778+ try {
1779+ tools_ = std::make_shared<json>(json::parse (tools));
1780+ if (! tools_->is_array ()) {
1781+ throw std::invalid_argument (" tools must be a valid JSON array" );
1782+ }
1783+
1784+ } catch (const json::exception & err) {
1785+ throw std::invalid_argument (err.what ());
1786+ }
1787+ }
1788+
1789+ void common_params_tools::choice (std::string choice) {
1790+ try {
1791+ if (choice == " auto" || choice == " required" || choice == " none" ) {
1792+ tool_choice_ = std::move (choice);
1793+
1794+ } else {
1795+ auto choice_ptr = std::make_shared<json>(json::parse (choice));
1796+ tool_choice_ = choice_ptr;
1797+ if (! choice_ptr->is_object ()) {
1798+ throw std::invalid_argument (
1799+ " tool choice must be a valid JSON object, \" auto\" , \" required\" , or \" none\" " );
1800+ }
1801+ }
1802+
1803+ } catch (const json::exception & err) {
1804+ throw std::invalid_argument (err.what ());
1805+ }
1806+ }
1807+
17751808bool common_chat_verify_template (const std::string & tmpl, bool use_jinja) {
17761809 if (use_jinja) {
17771810 try {
@@ -1798,7 +1831,7 @@ std::string common_chat_apply_template(
17981831 const std::vector<common_chat_msg> & msgs,
17991832 bool add_ass,
18001833 bool use_jinja,
1801- std::string tools_json_arr )
1834+ const common_params_tools & tools )
18021835{
18031836 if (use_jinja) {
18041837 common_chat_inputs inputs;
@@ -1807,17 +1840,19 @@ std::string common_chat_apply_template(
18071840 for (const auto & msg : msgs) {
18081841 messages.push_back ({{" role" , msg.role }, {" content" , msg.content }});
18091842 }
1843+ if (tools.tools () != nullptr ) {
1844+ inputs.tools = *tools.tools ();
1845+ }
1846+ auto choice = tools.choice ();
1847+ if (std::holds_alternative<std::string>(choice)) {
1848+ inputs.tool_choice = std::get<std::string>(choice);
18101849
1811- if (! tools_json_arr.empty ()) {
1812- try {
1813- inputs.tools = tools_json_arr;
1814-
1815- } catch (const json::exception & err) {
1816- LOG_WRN (" Failed to parse tools JSON array \" %s\" : \" %s\" . Ignoring tools...\n " ,
1817- tools_json_arr.c_str (), err.what ());
1850+ } else {
1851+ auto choice_ptr = std::get<common_params_tools::json_ptr>(choice);
1852+ if (choice_ptr != nullptr ) {
1853+ inputs.tool_choice = *choice_ptr;
18181854 }
18191855 }
1820-
18211856 inputs.messages = messages;
18221857 inputs.add_generation_prompt = add_ass;
18231858 return common_chat_params_init (tmpl, inputs).prompt ;
@@ -1858,11 +1893,11 @@ std::string common_chat_format_single(
18581893 const common_chat_msg & new_msg,
18591894 bool add_ass,
18601895 bool use_jinja,
1861- std::string tools_json_arr )
1896+ const common_params_tools & tools )
18621897{
18631898 std::ostringstream ss;
18641899 auto fmt_past_msg = past_msg.empty () ? " "
1865- : common_chat_apply_template (tmpl, past_msg, false , use_jinja, tools_json_arr );
1900+ : common_chat_apply_template (tmpl, past_msg, false , use_jinja, tools );
18661901
18671902 std::vector<common_chat_msg> chat_new (past_msg);
18681903 // if the past_msg ends with a newline, we must preserve it in the formatted version
@@ -1871,7 +1906,7 @@ std::string common_chat_format_single(
18711906 };
18721907 // format chat with new_msg
18731908 chat_new.push_back (new_msg);
1874- auto fmt_new_msg = common_chat_apply_template (tmpl, chat_new, add_ass, use_jinja, tools_json_arr );
1909+ auto fmt_new_msg = common_chat_apply_template (tmpl, chat_new, add_ass, use_jinja, tools );
18751910 // get the diff part
18761911 ss << fmt_new_msg.substr (fmt_past_msg.size (), fmt_new_msg.size () - fmt_past_msg.size ());
18771912 return ss.str ();
0 commit comments