Skip to content

Commit becf9b4

Browse files
committed
add tool-choice parameter
1 parent 352f79c commit becf9b4

File tree

4 files changed

+103
-31
lines changed

4 files changed

+103
-31
lines changed

common/arg.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1993,15 +1993,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
19931993
std::back_inserter(params.chat_template));
19941994
}
19951995
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE"));
1996+
19961997
add_opt(common_arg(
19971998
{"--tools"}, "JINJA_TOOLS",
1998-
string_format(
1999-
"set to a JSON array of tool definitions used for assistant function-calling "
2000-
"(requires --jinja)"),
1999+
"set to JSON array of tool definitions used for assistant function-calling (requires --jinja)",
20012000
[](common_params &params, const std::string & value) {
2002-
params.jinja_tools = value;
2003-
}
2004-
).set_examples({LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA_TOOLS"));
2001+
params.jinja_tools.tools(value);
2002+
2003+
}).set_examples({LLAMA_EXAMPLE_MAIN}));
2004+
2005+
add_opt(common_arg(
2006+
{"--tool-choice"}, "JINJA_TOOL_CHOICE",
2007+
"set to \"auto\", \"required\", \"none\" or a JSON object specifying a tool function (default: \"auto\")",
2008+
[](common_params &params, const std::string & value) {
2009+
params.jinja_tools.choice(value);
2010+
2011+
}).set_examples({LLAMA_EXAMPLE_MAIN}));
2012+
20052013
add_opt(common_arg(
20062014
{"-sps", "--slot-prompt-similarity"}, "SIMILARITY",
20072015
string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity),

common/common.cpp

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@
77

88
#include "common.h"
99
#include "log.h"
10-
// Change JSON_ASSERT from assert() to GGML_ASSERT:
11-
#define JSON_ASSERT GGML_ASSERT
12-
#include "json.hpp"
1310
#include "json-schema-to-grammar.h"
1411
#include "llama.h"
1512
#include "chat.hpp"
@@ -1772,6 +1769,42 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
17721769
// Chat template utils
17731770
//
17741771

1772+
// Builds the tool parameters from their textual forms.
//
// Both arguments are routed through the validating setters so construction and
// later modification share one code path; each setter reports invalid input by
// throwing std::invalid_argument.
common_params_tools::common_params_tools(std::string tools, std::string choice) {
    // `this->` is required: the parameters shadow the member functions of the
    // same name.
    this->tools(std::move(tools));
    this->choice(std::move(choice));
}
1776+
1777+
void common_params_tools::tools(std::string tools) {
1778+
try {
1779+
tools_ = std::make_shared<json>(json::parse(tools));
1780+
if (! tools_->is_array()) {
1781+
throw std::invalid_argument("tools must be a valid JSON array");
1782+
}
1783+
1784+
} catch (const json::exception & err) {
1785+
throw std::invalid_argument(err.what());
1786+
}
1787+
}
1788+
1789+
void common_params_tools::choice(std::string choice) {
1790+
try {
1791+
if (choice == "auto" || choice == "required" || choice == "none") {
1792+
tool_choice_ = std::move(choice);
1793+
1794+
} else {
1795+
auto choice_ptr = std::make_shared<json>(json::parse(choice));
1796+
tool_choice_ = choice_ptr;
1797+
if (! choice_ptr->is_object()) {
1798+
throw std::invalid_argument(
1799+
"tool choice must be a valid JSON object, \"auto\", \"required\", or \"none\"");
1800+
}
1801+
}
1802+
1803+
} catch (const json::exception & err) {
1804+
throw std::invalid_argument(err.what());
1805+
}
1806+
}
1807+
17751808
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
17761809
if (use_jinja) {
17771810
try {
@@ -1798,7 +1831,7 @@ std::string common_chat_apply_template(
17981831
const std::vector<common_chat_msg> & msgs,
17991832
bool add_ass,
18001833
bool use_jinja,
1801-
std::string tools_json_arr)
1834+
const common_params_tools & tools)
18021835
{
18031836
if (use_jinja) {
18041837
common_chat_inputs inputs;
@@ -1807,17 +1840,19 @@ std::string common_chat_apply_template(
18071840
for (const auto & msg : msgs) {
18081841
messages.push_back({{"role", msg.role}, {"content", msg.content}});
18091842
}
1843+
if (tools.tools() != nullptr) {
1844+
inputs.tools = *tools.tools();
1845+
}
1846+
auto choice = tools.choice();
1847+
if (std::holds_alternative<std::string>(choice)) {
1848+
inputs.tool_choice = std::get<std::string>(choice);
18101849

1811-
if (! tools_json_arr.empty()) {
1812-
try {
1813-
inputs.tools = tools_json_arr;
1814-
1815-
} catch (const json::exception & err) {
1816-
LOG_WRN("Failed to parse tools JSON array \"%s\": \"%s\". Ignoring tools...\n",
1817-
tools_json_arr.c_str(), err.what());
1850+
} else {
1851+
auto choice_ptr = std::get<common_params_tools::json_ptr>(choice);
1852+
if (choice_ptr != nullptr) {
1853+
inputs.tool_choice = *choice_ptr;
18181854
}
18191855
}
1820-
18211856
inputs.messages = messages;
18221857
inputs.add_generation_prompt = add_ass;
18231858
return common_chat_params_init(tmpl, inputs).prompt;
@@ -1858,11 +1893,11 @@ std::string common_chat_format_single(
18581893
const common_chat_msg & new_msg,
18591894
bool add_ass,
18601895
bool use_jinja,
1861-
std::string tools_json_arr)
1896+
const common_params_tools & tools)
18621897
{
18631898
std::ostringstream ss;
18641899
auto fmt_past_msg = past_msg.empty() ? ""
1865-
: common_chat_apply_template(tmpl, past_msg, false, use_jinja, tools_json_arr);
1900+
: common_chat_apply_template(tmpl, past_msg, false, use_jinja, tools);
18661901

18671902
std::vector<common_chat_msg> chat_new(past_msg);
18681903
// if the past_msg ends with a newline, we must preserve it in the formatted version
@@ -1871,7 +1906,7 @@ std::string common_chat_format_single(
18711906
};
18721907
// format chat with new_msg
18731908
chat_new.push_back(new_msg);
1874-
auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, tools_json_arr);
1909+
auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, tools);
18751910
// get the diff part
18761911
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
18771912
return ss.str();

common/common.h

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
#include <string>
99
#include <vector>
1010
#include <sstream>
11+
// Change JSON_ASSERT from assert() to GGML_ASSERT:
12+
#define JSON_ASSERT GGML_ASSERT
13+
#include "json.hpp"
1114

1215
#ifdef _WIN32
1316
#define DIRECTORY_SEPARATOR '\\'
@@ -202,6 +205,31 @@ struct common_params_vocoder {
202205
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
203206
};
204207

208+
class common_params_tools {
209+
public:
210+
using json = nlohmann::ordered_json;
211+
using json_ptr = std::shared_ptr<json>;
212+
using tool_choice_t = std::variant<std::string, json_ptr>;
213+
214+
common_params_tools(std::string tools = "",
215+
std::string choice = "auto");
216+
217+
common_params_tools(const common_params_tools & other) = default;
218+
common_params_tools(common_params_tools && other) noexcept = default;
219+
common_params_tools & operator=(const common_params_tools & other) = default;
220+
common_params_tools & operator=(common_params_tools && other) noexcept = default;
221+
222+
void tools(std::string tools);
223+
const json * tools() const { return tools_.get(); }
224+
225+
void choice(std::string choice);
226+
const tool_choice_t & choice() const { return tool_choice_; }
227+
228+
private:
229+
json_ptr tools_;
230+
tool_choice_t tool_choice_;
231+
};
232+
205233
struct common_params {
206234
int32_t n_predict = -1; // new tokens to predict
207235
int32_t n_ctx = 4096; // context size
@@ -346,7 +374,8 @@ struct common_params {
346374
std::string chat_template = ""; // NOLINT
347375
bool use_jinja = false; // NOLINT
348376
bool enable_chat_template = true;
349-
std::string jinja_tools = "";
377+
common_params_tools jinja_tools;
378+
350379
std::vector<std::string> api_keys;
351380

352381
std::string ssl_file_key = ""; // NOLINT
@@ -649,7 +678,7 @@ std::string common_chat_apply_template(
649678
const std::vector<common_chat_msg> & chat,
650679
bool add_ass,
651680
bool use_jinja,
652-
std::string tools_json_arr = std::string());
681+
const common_params_tools & tools = common_params_tools());
653682

654683
// Format single message, while taking into account the position of that message in chat history
655684
std::string common_chat_format_single(
@@ -658,7 +687,7 @@ std::string common_chat_format_single(
658687
const common_chat_msg & new_msg,
659688
bool add_ass,
660689
bool use_jinja,
661-
std::string tools_json_arr = std::string());
690+
const common_params_tools & tools = common_params_tools());
662691

663692
// Returns an example of formatted chat
664693
std::string common_chat_format_example(

examples/main/main.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -263,14 +263,14 @@ int main(int argc, char ** argv) {
263263

264264
std::vector<llama_token> embd_inp;
265265

266-
auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role,
267-
const std::string & content,
268-
const std::string & tools = std::string()) {
266+
auto chat_add_and_format = [&chat_msgs, &chat_templates](
267+
const std::string & role, const std::string & content,
268+
const common_params_tools & tools = common_params_tools())
269+
{
269270
common_chat_msg new_msg{role, content, {}};
270-
auto formatted = common_chat_format_single(*chat_templates.template_default,
271-
chat_msgs, new_msg,
272-
role == "user",
273-
g_params->use_jinja, tools);
271+
auto formatted = common_chat_format_single(
272+
*chat_templates.template_default, chat_msgs, new_msg, role == "user",
273+
g_params->use_jinja, tools);
274274

275275
chat_msgs.push_back({role, content, {}});
276276
LOG_DBG("formatted: '%s'\n", formatted.c_str());

0 commit comments

Comments
 (0)