Skip to content

Commit 1dd2e3b

Browse files
committed
Add handler and MCP message types
1 parent a024747 commit 1dd2e3b

File tree

8 files changed

+709
-107
lines changed

8 files changed

+709
-107
lines changed

common/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ add_library(${TARGET} STATIC
7575
sampling.h
7676
speculative.cpp
7777
speculative.h
78+
${CMAKE_CURRENT_SOURCE_DIR}/toolcall/handler.cpp
79+
${CMAKE_CURRENT_SOURCE_DIR}/toolcall/handler.hpp
80+
${CMAKE_CURRENT_SOURCE_DIR}/toolcall/mcp_messages.cpp
81+
${CMAKE_CURRENT_SOURCE_DIR}/toolcall/mcp_messages.hpp
7882
)
7983

8084
if (BUILD_SHARED_LIBS)

common/common.cpp

Lines changed: 34 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
#include "common.h"
99
#include "log.h"
10+
// Change JSON_ASSERT from assert() to GGML_ASSERT:
11+
#define JSON_ASSERT GGML_ASSERT
12+
#include "json.hpp"
1013
#include "json-schema-to-grammar.h"
1114
#include "llama.h"
1215
#include "chat.hpp"
@@ -1769,46 +1772,6 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
17691772
// Chat template utils
17701773
//
17711774

1772-
common_params_tools::common_params_tools(std::string tools, std::string choice) {
1773-
this->tools(tools);
1774-
this->choice(choice);
1775-
}
1776-
1777-
void common_params_tools::tools(std::string tools) {
1778-
if (tools.empty()) {
1779-
tools_.reset();
1780-
return;
1781-
}
1782-
try {
1783-
tools_ = std::make_shared<json>(json::parse(tools));
1784-
if (! tools_->is_array()) {
1785-
throw std::invalid_argument("tools must be a valid JSON array");
1786-
}
1787-
1788-
} catch (const json::exception & err) {
1789-
throw std::invalid_argument(err.what());
1790-
}
1791-
}
1792-
1793-
void common_params_tools::choice(std::string choice) {
1794-
try {
1795-
if (choice == "auto" || choice == "required" || choice == "none") {
1796-
tool_choice_ = std::move(choice);
1797-
1798-
} else {
1799-
auto choice_ptr = std::make_shared<json>(json::parse(choice));
1800-
tool_choice_ = choice_ptr;
1801-
if (! choice_ptr->is_object()) {
1802-
throw std::invalid_argument(
1803-
"tool choice must be a valid JSON object, \"auto\", \"required\", or \"none\"");
1804-
}
1805-
}
1806-
1807-
} catch (const json::exception & err) {
1808-
throw std::invalid_argument(err.what());
1809-
}
1810-
}
1811-
18121775
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
18131776
if (use_jinja) {
18141777
try {
@@ -1830,7 +1793,7 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
18301793
return res >= 0;
18311794
}
18321795

1833-
static void copy_chat_params(const common_chat_params & src, common_chat_sampling_updater * update_sparams)
1796+
static void copy_chat_params(const common_chat_params & src, toolcall::sampling_updater * update_sparams)
18341797
{
18351798
GGML_ASSERT(update_sparams && update_sparams->sparams && update_sparams->vocab);
18361799

@@ -1873,11 +1836,11 @@ std::string common_chat_apply_template(
18731836
const std::vector<common_chat_msg> & msgs,
18741837
bool add_ass,
18751838
bool use_jinja,
1876-
const common_params_tools & tools,
1877-
common_chat_sampling_updater * update_sparams)
1839+
toolcall::handler::ptr handler,
1840+
toolcall::sampling_updater * update_sparams)
18781841
{
18791842
const auto & tmpl_selected =
1880-
tools.tools() && tmpl.template_tool_use ? *tmpl.template_tool_use : *tmpl.template_default;
1843+
handler != nullptr && tmpl.template_tool_use ? *tmpl.template_tool_use : *tmpl.template_default;
18811844

18821845
if (use_jinja) {
18831846
common_chat_inputs inputs;
@@ -1886,29 +1849,38 @@ std::string common_chat_apply_template(
18861849
for (const auto & msg : msgs) {
18871850
messages.push_back({{"role", msg.role}, {"content", msg.content}});
18881851
}
1852+
inputs.messages = messages;
1853+
inputs.add_generation_prompt = add_ass;
18891854

1890-
if (tools.tools() != nullptr) {
1891-
inputs.tools = *tools.tools();
1892-
}
1893-
1894-
auto choice = tools.choice();
1895-
if (std::holds_alternative<std::string>(choice)) {
1896-
inputs.tool_choice = std::get<std::string>(choice);
1855+
if (handler != nullptr) {
1856+
auto choice = handler->tool_choice();
1857+
if (std::holds_alternative<std::string>(choice)) {
1858+
inputs.tool_choice = std::get<std::string>(choice);
18971859

1898-
} else {
1899-
auto choice_ptr = std::get<common_params_tools::json_ptr>(choice);
1900-
if (choice_ptr != nullptr) {
1901-
inputs.tool_choice = *choice_ptr;
1860+
} else {
1861+
auto choice_ptr = std::get<toolcall::json_ptr>(choice);
1862+
if (choice_ptr != nullptr) {
1863+
inputs.tool_choice = *choice_ptr;
1864+
}
19021865
}
1866+
1867+
inputs.tools = handler->tool_list();
19031868
}
19041869

1905-
inputs.messages = messages;
1906-
inputs.add_generation_prompt = add_ass;
19071870
auto chat_params = common_chat_params_init(tmpl_selected, inputs);
19081871
if (update_sparams) {
19091872
copy_chat_params(chat_params, update_sparams);
19101873
}
1911-
return chat_params.prompt;
1874+
1875+
auto prompt = chat_params.prompt;
1876+
if (handler != nullptr) {
1877+
json response;
1878+
handler->call(prompt, response);
1879+
return response; // Caller will determine what to do based upon last_action
1880+
1881+
} else {
1882+
return prompt;
1883+
}
19121884
}
19131885

19141886
int alloc_size = 0;
@@ -1946,12 +1918,12 @@ std::string common_chat_format_single(
19461918
const common_chat_msg & new_msg,
19471919
bool add_ass,
19481920
bool use_jinja,
1949-
const common_params_tools & tools,
1950-
common_chat_sampling_updater * update_sparams)
1921+
toolcall::handler::ptr handler,
1922+
toolcall::sampling_updater * update_sparams)
19511923
{
19521924
std::ostringstream ss;
19531925
auto fmt_past_msg = past_msg.empty() ? ""
1954-
: common_chat_apply_template(tmpl, past_msg, false, use_jinja, tools, update_sparams);
1926+
: common_chat_apply_template(tmpl, past_msg, false, use_jinja, handler, update_sparams);
19551927

19561928
std::vector<common_chat_msg> chat_new(past_msg);
19571929
// if the past_msg ends with a newline, we must preserve it in the formatted version
@@ -1960,7 +1932,7 @@ std::string common_chat_format_single(
19601932
};
19611933
// format chat with new_msg
19621934
chat_new.push_back(new_msg);
1963-
auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, tools, update_sparams);
1935+
auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, handler, update_sparams);
19641936
// get the diff part
19651937
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
19661938
return ss.str();

common/common.h

Lines changed: 12 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,12 @@
33
#pragma once
44

55
#include "llama-cpp.h"
6-
6+
#include "toolcall/handler.hpp"
77
#include <set>
88
#include <string>
99
#include <vector>
1010
#include <sstream>
1111
#include <variant>
12-
// Change JSON_ASSERT from assert() to GGML_ASSERT:
13-
#define JSON_ASSERT GGML_ASSERT
14-
#include "json.hpp"
1512

1613
#ifdef _WIN32
1714
#define DIRECTORY_SEPARATOR '\\'
@@ -206,31 +203,6 @@ struct common_params_vocoder {
206203
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
207204
};
208205

209-
class common_params_tools {
210-
public:
211-
using json = nlohmann::ordered_json;
212-
using json_ptr = std::shared_ptr<json>;
213-
using tool_choice_t = std::variant<std::string, json_ptr>;
214-
215-
common_params_tools(std::string tools = "",
216-
std::string choice = "auto");
217-
218-
common_params_tools(const common_params_tools & other) = default;
219-
common_params_tools(common_params_tools && other) noexcept = default;
220-
common_params_tools & operator=(const common_params_tools & other) = default;
221-
common_params_tools & operator=(common_params_tools && other) noexcept = default;
222-
223-
void tools(std::string tools);
224-
const json * tools() const { return tools_.get(); }
225-
226-
void choice(std::string choice);
227-
const tool_choice_t & choice() const { return tool_choice_; }
228-
229-
private:
230-
json_ptr tools_;
231-
tool_choice_t tool_choice_;
232-
};
233-
234206
struct common_params {
235207
int32_t n_predict = -1; // new tokens to predict
236208
int32_t n_ctx = 4096; // context size
@@ -375,7 +347,7 @@ struct common_params {
375347
std::string chat_template = ""; // NOLINT
376348
bool use_jinja = false; // NOLINT
377349
bool enable_chat_template = true;
378-
common_params_tools jinja_tools;
350+
toolcall::params jinja_tools;
379351

380352
std::vector<std::string> api_keys;
381353

@@ -671,10 +643,12 @@ struct common_chat_templates {
671643
std::unique_ptr<common_chat_template> template_tool_use;
672644
};
673645

674-
struct common_chat_sampling_updater {
675-
common_params_sampling * sparams;
676-
const llama_vocab * vocab;
677-
};
646+
namespace toolcall {
647+
struct sampling_updater {
648+
common_params_sampling * sparams;
649+
const llama_vocab * vocab;
650+
};
651+
}
678652

679653
// CPP wrapper for llama_chat_apply_template
680654
// If the built-in template is not supported, we default to chatml
@@ -684,8 +658,8 @@ std::string common_chat_apply_template(
684658
const std::vector<common_chat_msg> & chat,
685659
bool add_ass,
686660
bool use_jinja,
687-
const common_params_tools & tools = common_params_tools(),
688-
common_chat_sampling_updater * update_sparams = nullptr);
661+
toolcall::handler::ptr handler = nullptr,
662+
toolcall::sampling_updater * update_sparams = nullptr);
689663

690664
// Format single message, while taking into account the position of that message in chat history
691665
std::string common_chat_format_single(
@@ -694,8 +668,8 @@ std::string common_chat_format_single(
694668
const common_chat_msg & new_msg,
695669
bool add_ass,
696670
bool use_jinja,
697-
const common_params_tools & tools = common_params_tools(),
698-
common_chat_sampling_updater * update_sparams = nullptr);
671+
toolcall::handler::ptr handler = nullptr,
672+
toolcall::sampling_updater * update_sparams = nullptr);
699673

700674
// Returns an example of formatted chat
701675
std::string common_chat_format_example(

common/toolcall/handler.cpp

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
2+
#include "handler.hpp"
3+
4+
using json = toolcall::json;
5+
6+
toolcall::params::params(std::string tools, std::string choice) {
7+
this->tools(tools);
8+
this->choice(choice);
9+
}
10+
11+
static bool starts_with(const std::string & str, const std::string & prefix) {
12+
return str.size() >= prefix.size()
13+
&& str.compare(0, prefix.size(), prefix) == 0;
14+
}
15+
16+
std::shared_ptr<toolcall::handler> toolcall::create_handler(const toolcall::params & params) {
17+
std::shared_ptr<toolcall::handler> result;
18+
19+
auto tools = params.tools();
20+
auto choice = params.choice();
21+
bool has_uri = std::holds_alternative<std::string>(tools);
22+
if (has_uri) {
23+
auto tools_str = std::get<std::string>(tools);
24+
result.reset(new toolcall::handler(std::make_unique<toolcall::mcp_impl>(tools_str, choice)));
25+
26+
} else {
27+
auto tools_ptr = std::get<toolcall::json_ptr>(tools);
28+
if (tools_ptr != nullptr) {
29+
result.reset(new toolcall::handler(std::make_unique<toolcall::loopback_impl>(*tools_ptr, choice)));
30+
}
31+
}
32+
33+
return result;
34+
}
35+
36+
void toolcall::params::tools(std::string tools) {
37+
try {
38+
if (tools.empty() || starts_with(tools, "mcp+http")) {
39+
tools_ = std::move(tools);
40+
41+
} else {
42+
tools_ = std::make_shared<json>(json::parse(tools));
43+
auto tools_ptr = std::get<std::shared_ptr<json>>(tools_);
44+
if (! tools_ptr->is_array()) {
45+
throw std::invalid_argument("tools must be a valid JSON array");
46+
}
47+
}
48+
49+
} catch (const json::exception & err) {
50+
throw std::invalid_argument(err.what());
51+
}
52+
}
53+
54+
void toolcall::params::choice(std::string choice) {
55+
try {
56+
if (choice == "auto" || choice == "required" || choice == "none") {
57+
tool_choice_ = std::move(choice);
58+
59+
} else {
60+
auto choice_ptr = std::make_shared<json>(json::parse(choice));
61+
tool_choice_ = choice_ptr;
62+
if (! choice_ptr->is_object()) {
63+
throw std::invalid_argument(
64+
"tool choice must be a valid JSON object, \"auto\", \"required\", or \"none\"");
65+
}
66+
}
67+
68+
} catch (const json::exception & err) {
69+
throw std::invalid_argument(err.what());
70+
}
71+
}
72+
73+
toolcall::params::operator bool() const {
74+
if (std::holds_alternative<std::string>(tools_)) {
75+
return ! std::get<std::string>(tools_).empty();
76+
77+
} else {
78+
return std::get<toolcall::json_ptr>(tools_) != nullptr;
79+
}
80+
}
81+
82+
json toolcall::handler::tool_list() {
83+
return impl_->tool_list();
84+
}
85+
86+
toolcall::action toolcall::handler::call(const json & request, json & response) {
87+
last_action_ = impl_->call(request, response);
88+
return last_action_;
89+
}
90+
91+
const toolcall::tool_choice_t & toolcall::handler::tool_choice() const {
92+
return impl_->tool_choice();
93+
}
94+
toolcall::action toolcall::handler::last_action() const {
95+
return last_action_;
96+
}
97+
98+
toolcall::mcp_impl::mcp_impl(std::string server_uri, tool_choice_t tool_choice)
99+
: handler_impl(tool_choice)
100+
{
101+
// TODO
102+
}
103+
104+
json toolcall::mcp_impl::tool_list() {
105+
return json{};// TODO
106+
}
107+
108+
toolcall::action toolcall::mcp_impl::call(const json & request, json & response) {
109+
return toolcall::ACCEPT; // TODO
110+
}

0 commit comments

Comments
 (0)