Commit a345aa9

Merge branch 'llamacli-tools-sse' into llamacli-tools
2 parents: 608304f + a9e3404

8 files changed: +198 −171 lines changed

common/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -75,6 +75,8 @@ add_library(${TARGET} STATIC
     sampling.h
     speculative.cpp
    speculative.h
+    ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/params.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/params.hpp
     ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/handler.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/handler.hpp
     ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/mcp_messages.cpp

common/common.cpp

Lines changed: 46 additions & 68 deletions
@@ -1768,42 +1768,18 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
     return text;
 }
 
-//
-// Chat template utils
-//
-
-bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
-    if (use_jinja) {
-        try {
-            auto chat_template = common_chat_template(tmpl, "<s>", "</s>");
-            common_chat_inputs inputs;
-            inputs.messages = json::array({{
-                {"role", "user"},
-                {"content", "test"},
-            }});
-            common_chat_params_init(chat_template, inputs);
-            return true;
-        } catch (const std::exception & e) {
-            LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
-            return false;
-        }
-    }
-    llama_chat_message chat[] = {{"user", "test"}};
-    const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
-    return res >= 0;
-}
-
-static void copy_chat_params(const common_chat_params & src, toolcall::sampling_updater * update_sparams)
+void common_chat_grammar_to_sampler(const common_chat_params * src,
+                                    const llama_vocab * vocab,
+                                    common_params_sampling * sparams)
 {
-    GGML_ASSERT(update_sparams && update_sparams->sparams && update_sparams->vocab);
+    GGML_ASSERT(src && vocab && sparams);
 
-    auto & dst = *update_sparams->sparams;
-    auto vocab = update_sparams->vocab;
+    auto & dst = *sparams;
 
-    dst.grammar = src.grammar;
-    dst.grammar_lazy = src.grammar_lazy;
+    dst.grammar = src->grammar;
+    dst.grammar_lazy = src->grammar_lazy;
 
-    for (const auto & trigger : src.grammar_triggers) {
+    for (const auto & trigger : src->grammar_triggers) {
         auto ids = common_tokenize(vocab, trigger.word, false, true);
 
         if (ids.size() == 1) {

@@ -1816,7 +1792,7 @@ static void copy_chat_params(const common_chat_params & src, toolcall::sampling_
         dst.grammar_trigger_words.push_back(trigger);
     }
 
-    for (const auto & preserved : src.preserved_tokens) {
+    for (const auto & preserved : src->preserved_tokens) {
         auto ids = common_tokenize(vocab, preserved, false, true);
         if (ids.size() == 1) {
             LOG_DBG("Preserved token: %d\n", ids[0]);

@@ -1831,19 +1807,45 @@ static void copy_chat_params(const common_chat_params & src, toolcall::sampling_
     }
 }
 
+
+//
+// Chat template utils
+//
+
+bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
+    if (use_jinja) {
+        try {
+            auto chat_template = common_chat_template(tmpl, "<s>", "</s>");
+            common_chat_inputs inputs;
+            inputs.messages = json::array({{
+                {"role", "user"},
+                {"content", "test"},
+            }});
+            common_chat_params_init(chat_template, inputs);
+            return true;
+        } catch (const std::exception & e) {
+            LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
+            return false;
+        }
+    }
+    llama_chat_message chat[] = {{"user", "test"}};
+    const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
+    return res >= 0;
+}
+
 std::string common_chat_apply_template(
         const common_chat_templates & tmpl,
         const std::vector<common_chat_msg> & msgs,
         bool add_ass,
         bool use_jinja,
-        toolcall::handler::ptr handler,
-        toolcall::sampling_updater * update_sparams)
+        const common_chat_inputs * inputs_,
+        common_chat_params * out_params)
 {
-    bool use_tool_template = (use_jinja && handler != nullptr) && tmpl.template_tool_use;
+    bool use_tool_template = use_jinja && tmpl.template_tool_use;
     const auto & tmpl_selected = use_tool_template ? *tmpl.template_tool_use : *tmpl.template_default;
 
     if (use_jinja) {
-        common_chat_inputs inputs;
+        common_chat_inputs inputs = inputs_ ? *inputs_ : common_chat_inputs();
 
         auto messages = json::array();
         for (const auto & msg : msgs) {

@@ -1852,35 +1854,11 @@ std::string common_chat_apply_template(
         inputs.messages = messages;
         inputs.add_generation_prompt = add_ass;
 
-        if (handler != nullptr) {
-            auto choice = handler->tool_choice();
-            if (std::holds_alternative<std::string>(choice)) {
-                inputs.tool_choice = std::get<std::string>(choice);
-
-            } else {
-                auto choice_ptr = std::get<toolcall::json_ptr>(choice);
-                if (choice_ptr != nullptr) {
-                    inputs.tool_choice = *choice_ptr;
-                }
-            }
-
-            inputs.tools = handler->tool_list();
-        }
-
         auto chat_params = common_chat_params_init(tmpl_selected, inputs);
-        if (update_sparams) {
-            copy_chat_params(chat_params, update_sparams);
-        }
-
-        auto prompt = chat_params.prompt;
-        if (handler != nullptr) {
-            json response;
-            handler->call(prompt, response);
-            return response; // Caller will determine what to do based upon last_action
-
-        } else {
-            return prompt;
+        if (out_params != nullptr) {
+            *out_params = chat_params;
         }
+        return chat_params.prompt;
     }
 
     int alloc_size = 0;

@@ -1918,12 +1896,12 @@ std::string common_chat_format_single(
         const common_chat_msg & new_msg,
         bool add_ass,
         bool use_jinja,
-        toolcall::handler::ptr handler,
-        toolcall::sampling_updater * update_sparams)
+        const common_chat_inputs * inputs,
+        common_chat_params * out_params)
 {
     std::ostringstream ss;
     auto fmt_past_msg = past_msg.empty() ? ""
-        : common_chat_apply_template(tmpl, past_msg, false, use_jinja, handler, update_sparams);
+        : common_chat_apply_template(tmpl, past_msg, false, use_jinja, inputs);
 
     std::vector<common_chat_msg> chat_new(past_msg);
     // if the past_msg ends with a newline, we must preserve it in the formatted version

@@ -1932,7 +1910,7 @@ std::string common_chat_format_single(
     };
     // format chat with new_msg
     chat_new.push_back(new_msg);
-    auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, handler, update_sparams);
+    auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, inputs, out_params);
     // get the diff part
     ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
     return ss.str();
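
Taken together, these hunks decouple the chat-template code from toolcall: common_chat_apply_template no longer reaches into a handler itself, but accepts an optional common_chat_inputs and hands the resulting common_chat_params back through out_params, while the grammar/trigger copying becomes the public common_chat_grammar_to_sampler. A minimal sketch of the new calling convention follows; the updated call sites live in files not shown in this diff, so the handler, tmpl, msgs, vocab, and sparams variables here are assumed to be in scope:

    // Hypothetical caller (the real call sites are among the files not shown above).
    common_chat_inputs inputs;
    inputs.tools       = handler->tool_list();  // plumbing that previously lived
    inputs.tool_choice = "auto";                // inside common_chat_apply_template

    common_chat_params chat_params;
    std::string prompt = common_chat_apply_template(
        tmpl, msgs, /* add_ass */ true, /* use_jinja */ true, &inputs, &chat_params);

    // Grammar and trigger words are now copied into the sampler explicitly:
    common_chat_grammar_to_sampler(&chat_params, vocab, &sparams);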

common/common.h

Lines changed: 11 additions & 11 deletions
@@ -618,6 +618,13 @@ std::string common_detokenize(
        const std::vector<llama_token> & tokens,
        bool special = true);
 
+struct common_chat_params;
+struct common_chat_inputs;
+void common_chat_grammar_to_sampler(const common_chat_params * src,
+                                    const llama_vocab * vocab,
+                                    common_params_sampling * sparams);
+
+
 //
 // Chat template utils
 //

@@ -651,13 +658,6 @@ struct common_chat_templates {
     std::unique_ptr<common_chat_template> template_tool_use;
 };
 
-namespace toolcall {
-    struct sampling_updater {
-        common_params_sampling * sparams;
-        const llama_vocab * vocab;
-    };
-}
-
 // CPP wrapper for llama_chat_apply_template
 // If the built-in template is not supported, we default to chatml
 // If the custom "tmpl" is not supported, we throw an error

@@ -666,8 +666,8 @@ std::string common_chat_apply_template(
        const std::vector<common_chat_msg> & chat,
        bool add_ass,
        bool use_jinja,
-       toolcall::handler::ptr handler = nullptr,
-       toolcall::sampling_updater * update_sparams = nullptr);
+       const common_chat_inputs * inputs = nullptr,
+       common_chat_params * out_params = nullptr);
 
 // Format single message, while taking into account the position of that message in chat history
 std::string common_chat_format_single(

@@ -676,8 +676,8 @@ std::string common_chat_format_single(
        const common_chat_msg & new_msg,
        bool add_ass,
        bool use_jinja,
-       toolcall::handler::ptr handler = nullptr,
-       toolcall::sampling_updater * update_sparams = nullptr);
+       const common_chat_inputs * inputs = nullptr,
+       common_chat_params * out_params = nullptr);
 
 // Returns an example of formatted chat
 std::string common_chat_format_example(
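
Because both new parameters default to nullptr, existing callers compile unchanged; only tool-aware callers opt in by passing them. A brief sketch of the revised common_chat_format_single, which returns only the text contributed by the new message; the tmpl, past_msgs, and inputs variables are assumed to be in scope:

    common_chat_msg new_msg;
    new_msg.role    = "user";
    new_msg.content = "test";

    common_chat_params out_params;
    // delta holds only the suffix appended for new_msg, ready to stream into the prompt
    std::string delta = common_chat_format_single(
        tmpl, past_msgs, new_msg, /* add_ass */ true, /* use_jinja */ true,
        &inputs, &out_params);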

common/toolcall/handler.cpp

Lines changed: 2 additions & 66 deletions
@@ -1,5 +1,7 @@
 
+#include "../json.hpp"
 #include "handler.hpp"
+#include "params.hpp"
 
 #ifdef LLAMA_USE_CURL
 # include "mcp_sse_transport.hpp"

@@ -9,16 +11,6 @@
 
 using json = toolcall::json;
 
-toolcall::params::params(std::string tools, std::string choice) {
-    this->tools(tools);
-    this->choice(choice);
-}
-
-static bool starts_with(const std::string & str, const std::string & prefix) {
-    return str.size() >= prefix.size()
-        && str.compare(0, prefix.size(), prefix) == 0;
-}
-
 std::shared_ptr<toolcall::handler> toolcall::create_handler(const toolcall::params & params) {
     std::shared_ptr<toolcall::handler> result;

@@ -42,62 +34,6 @@ std::shared_ptr<toolcall::handler> toolcall::create_handler(const toolcall::para
     return result;
 }
 
-void toolcall::params::tools(std::string tools) {
-    try {
-
-        if (tools.empty()) {
-            tools_ = std::move(tools);
-
-        } else if (starts_with(tools, "mcp+http")) {
-#ifdef LLAMA_USE_CURL
-            tools_ = std::move(tools);
-#else
-            throw std::invalid_argument(
-                "Model Context Protocol (MCP) only works when llama.cpp is compiled with libcurl");
-#endif
-        } else {
-            tools_ = std::make_shared<json>(json::parse(tools));
-            auto tools_ptr = std::get<std::shared_ptr<json>>(tools_);
-            if (! tools_ptr->is_array()) {
-                throw std::invalid_argument(
-                    "tools must be a URL of the form \"mcp+http(s)://hostname[:port]/\""
-                    ", or a valid JSON array containing tool definitions");
-            }
-        }
-
-    } catch (const json::exception & err) {
-        throw std::invalid_argument(err.what());
-    }
-}
-
-void toolcall::params::choice(std::string choice) {
-    try {
-        if (choice == "auto" || choice == "required" || choice == "none") {
-            tool_choice_ = std::move(choice);
-
-        } else {
-            auto choice_ptr = std::make_shared<json>(json::parse(choice));
-            tool_choice_ = choice_ptr;
-            if (! choice_ptr->is_object()) {
-                throw std::invalid_argument(
-                    "tool choice must be a valid JSON object, \"auto\", \"required\", or \"none\"");
-            }
-        }
-
-    } catch (const json::exception & err) {
-        throw std::invalid_argument(err.what());
-    }
-}
-
-toolcall::params::operator bool() const {
-    if (std::holds_alternative<std::string>(tools_)) {
-        return ! std::get<std::string>(tools_).empty();
-
-    } else {
-        return std::get<toolcall::json_ptr>(tools_) != nullptr;
-    }
-}
-
 json toolcall::handler::tool_list() {
     return impl_->tool_list();
 }
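
The constructor, tools()/choice() setters, and operator bool() removed here are not deleted outright; they move into the new common/toolcall/params.cpp added to the build in CMakeLists.txt above, leaving handler.cpp with only handler construction and dispatch. Assuming the validation logic moved verbatim, construction still behaves as before; the endpoint URL and tool JSON below are illustrative:

    // MCP SSE endpoint: accepted because it starts with "mcp+http" (requires libcurl).
    toolcall::params mcp_params("mcp+https://localhost:8080/", "auto");

    // Inline tool definitions: any JSON array passes the is_array() check.
    toolcall::params json_params(R"([{"name": "get_weather"}])", "required");

    // Anything else throws std::invalid_argument, e.g.:
    // toolcall::params bad("not-json-or-mcp-url");  // would throw

    auto handler = toolcall::create_handler(mcp_params);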

common/toolcall/handler.hpp

Lines changed: 2 additions & 24 deletions
@@ -1,11 +1,11 @@
 #pragma once
 
+#include "../json.hpp"
+#include "params.hpp" // TODO: make foreward decl.
 #include <string>
 #include <variant>
 #include <memory>
 
-#include "../json.hpp"
-
 namespace toolcall
 {
     using json = nlohmann::ordered_json;

@@ -36,28 +36,6 @@
         action last_action_;
     };
 
-    class params {
-    public:
-        params(std::string tools = "", std::string choice = "auto");
-
-        params(const params & other) = default;
-        params(params && other) noexcept = default;
-        params & operator=(const params & other) = default;
-        params & operator=(params && other) noexcept = default;
-
-        operator bool() const;
-
-        void tools(std::string tools);
-        const tools_t tools() const { return tools_; }
-
-        void choice(std::string choice);
-        const tool_choice_t & choice() const { return tool_choice_; }
-
-    private:
-        tools_t tools_;
-        tool_choice_t tool_choice_;
-    };
-
     std::shared_ptr<toolcall::handler> create_handler(const toolcall::params & params);
 
     class handler_impl {
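
The params class removed here is the same one re-imported through the new #include "params.hpp"; the TODO notes that a forward declaration should eventually suffice, since this header only passes const toolcall::params & to create_handler. A sketch of what the relocated header presumably contains, assuming the class moved unchanged (the tools_t and tool_choice_t aliases, omitted here, would have to move with it; the defaulted copy/move members are elided for brevity):

    // common/toolcall/params.hpp (assumed contents, relocated from handler.hpp)
    #pragma once

    #include <string>

    namespace toolcall
    {
        class params {
        public:
            params(std::string tools = "", std::string choice = "auto");

            operator bool() const;  // true when any tools are configured

            void tools(std::string tools);
            const tools_t tools() const { return tools_; }

            void choice(std::string choice);
            const tool_choice_t & choice() const { return tool_choice_; }

        private:
            tools_t tools_;
            tool_choice_t tool_choice_;
        };
    }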
