Skip to content

Commit f2af859

Browse files
committed
Post-Merge refactoring
1 parent 6ce23b6 commit f2af859

File tree

8 files changed

+85
-101
lines changed

8 files changed

+85
-101
lines changed

common/chat.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -291,9 +291,11 @@ std::string common_chat_format_single(
291291
const std::vector<common_chat_msg> & past_msg,
292292
const common_chat_msg & new_msg,
293293
bool add_ass,
294-
bool use_jinja) {
294+
bool use_jinja,
295+
const struct common_chat_templates_inputs * input_extra,
296+
struct common_chat_params * out_params) {
295297

296-
common_chat_templates_inputs inputs;
298+
common_chat_templates_inputs inputs = input_extra ? *input_extra : common_chat_templates_inputs();
297299
inputs.use_jinja = use_jinja;
298300

299301
std::string fmt_past_msg;
@@ -310,9 +312,13 @@ std::string common_chat_format_single(
310312
// format chat with new_msg
311313
inputs.messages.push_back(new_msg);
312314
inputs.add_generation_prompt = add_ass;
313-
auto fmt_new_msg = common_chat_templates_apply(tmpls, inputs).prompt;
315+
auto chat_params = common_chat_templates_apply(tmpls, inputs);
316+
auto fmt_new_msg = chat_params.prompt;
314317
// get the diff part
315318
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
319+
if (out_params) {
320+
*out_params = std::move(chat_params);
321+
}
316322
return ss.str();
317323
}
318324

common/chat.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,9 @@ std::string common_chat_format_single(
111111
const std::vector<common_chat_msg> & past_msg,
112112
const common_chat_msg & new_msg,
113113
bool add_ass,
114-
bool use_jinja);
114+
bool use_jinja,
115+
const struct common_chat_templates_inputs * input_extra = nullptr,
116+
struct common_chat_params * out_params = nullptr);
115117

116118
// Returns an example of formatted chat
117119
std::string common_chat_format_example(

common/common.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "json.hpp"
1313
#include "json-schema-to-grammar.h"
1414
#include "llama.h"
15+
#include "chat.h"
1516

1617
#include <algorithm>
1718
#include <cinttypes>

common/toolcall/handler.cpp

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,82 +1,74 @@
11

22
#include <json.hpp>
33
#include "handler.hpp"
4-
#include "params.hpp"
54

65
#ifdef LLAMA_USE_CURL
76
# include "mcp_sse_transport.hpp"
87
#endif
98

109
#include "mcp_stdio_transport.hpp"
1110

12-
using json = toolcall::json;
11+
using json = nlohmann::json;
1312

1413
std::shared_ptr<toolcall::handler> toolcall::create_handler(const toolcall::params & params) {
15-
std::shared_ptr<toolcall::handler> result;
14+
std::shared_ptr<toolcall::handler> handler;
1615

1716
auto tools = params.tools();
1817
auto choice = params.choice();
19-
bool has_uri = std::holds_alternative<std::string>(tools);
20-
if (has_uri) {
18+
if (params.has_uri()) {
2119
#ifdef LLAMA_USE_CURL
22-
auto tools_str = std::get<std::string>(tools);
23-
if (! tools_str.empty()) {
24-
result.reset(new toolcall::handler(std::make_unique<toolcall::mcp_impl>(tools_str, choice)));
25-
}
20+
handler.reset(new toolcall::handler(std::make_unique<toolcall::mcp_impl>(tools, choice)));
2621
#endif
2722
} else {
28-
auto tools_ptr = std::get<toolcall::json_ptr>(tools);
29-
if (tools_ptr != nullptr) {
30-
result.reset(new toolcall::handler(std::make_unique<toolcall::loopback_impl>(*tools_ptr, choice)));
31-
}
23+
handler.reset(new toolcall::handler(std::make_unique<toolcall::loopback_impl>(tools, choice)));
3224
}
3325

34-
return result;
26+
return handler;
3527
}
3628

37-
json toolcall::handler::tool_list() {
29+
std::string toolcall::handler::tool_list() {
3830
return impl_->tool_list();
3931
}
4032

41-
toolcall::action toolcall::handler::call(const json & request, json & response) {
33+
toolcall::action toolcall::handler::call(const std::string & request, std::string & response) {
4234
last_action_ = impl_->call(request, response);
4335
return last_action_;
4436
}
4537

46-
const toolcall::tool_choice_t & toolcall::handler::tool_choice() const {
38+
const std::string & toolcall::handler::tool_choice() const {
4739
return impl_->tool_choice();
4840
}
4941
toolcall::action toolcall::handler::last_action() const {
5042
return last_action_;
5143
}
5244

5345
#ifdef LLAMA_USE_CURL
54-
toolcall::mcp_impl::mcp_impl(std::string server_uri, tool_choice_t tool_choice)
46+
toolcall::mcp_impl::mcp_impl(std::string server_uri, std::string tool_choice)
5547
: handler_impl(tool_choice),
5648
transport_(new mcp_sse_transport(server_uri))
5749
{
5850
transport_->start();
5951
}
6052
#else
61-
toolcall::mcp_impl::mcp_impl(std::string /*server_uri*/, tool_choice_t tool_choice)
53+
toolcall::mcp_impl::mcp_impl(std::string /*server_uri*/, std::string tool_choice)
6254
: handler_impl(tool_choice)
6355
{
6456
}
6557
#endif
6658

67-
toolcall::mcp_impl::mcp_impl(std::vector<std::string> argv, tool_choice_t tool_choice)
59+
toolcall::mcp_impl::mcp_impl(std::vector<std::string> argv, std::string tool_choice)
6860
: handler_impl(tool_choice),
6961
transport_(new mcp_stdio_transport(argv))
7062
{
7163
transport_->start();
7264
}
7365

74-
json toolcall::mcp_impl::tool_list() {
66+
std::string toolcall::mcp_impl::tool_list() {
7567
// Construct tools/list call and send to transport
7668
return json{};// TODO
7769
}
7870

79-
toolcall::action toolcall::mcp_impl::call(const json & /*request*/, json & /*response*/) {
71+
toolcall::action toolcall::mcp_impl::call(const std::string & /*request*/, std::string & /*response*/) {
8072
// Construct tool call and send to transport
8173
return toolcall::ACCEPT; // TODO
8274
}

common/toolcall/handler.hpp

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,13 @@
11
#pragma once
22

3-
#include <json.hpp> // TODO: remove dependence on this
4-
#include "params.hpp" // TODO: make foreward decl.
3+
#include "params.hpp"
54
#include <string>
65
#include <variant>
76
#include <memory>
7+
#include <vector>
88

99
namespace toolcall
1010
{
11-
using json = nlohmann::ordered_json;
12-
using json_ptr = std::shared_ptr<json>;
13-
using tools_t = std::variant<std::string, json_ptr>;
14-
using tool_choice_t = std::variant<std::string, json_ptr>;
15-
1611
enum action {
1712
ACCEPT,
1813
PENDING,
@@ -26,9 +21,9 @@ namespace toolcall
2621

2722
handler(std::unique_ptr<handler_impl> impl) : impl_(std::move(impl)) {}
2823

29-
json tool_list();
30-
action call(const json & request, json & response);
31-
const tool_choice_t & tool_choice() const;
24+
std::string tool_list();
25+
action call(const std::string & request, std::string & response);
26+
const std::string & tool_choice() const;
3227
action last_action() const;
3328

3429
private:
@@ -40,45 +35,45 @@ namespace toolcall
4035

4136
class handler_impl {
4237
public:
43-
handler_impl(tool_choice_t tool_choice)
38+
handler_impl(std::string tool_choice)
4439
: tool_choice_(std::move(tool_choice)) {}
4540

4641
virtual ~handler_impl() = default;
47-
virtual json tool_list() = 0;
48-
virtual action call(const json & request, json & response) = 0;
42+
virtual std::string tool_list() = 0;
43+
virtual action call(const std::string & request, std::string & response) = 0;
4944

50-
const tool_choice_t & tool_choice() const { return tool_choice_; }
45+
const std::string & tool_choice() const { return tool_choice_; }
5146

5247
protected:
53-
tool_choice_t tool_choice_;
48+
std::string tool_choice_;
5449
};
5550

5651
class loopback_impl : public handler_impl {
5752
public:
58-
loopback_impl(json tools, tool_choice_t tool_choice)
53+
loopback_impl(std::string tools, std::string tool_choice)
5954
: handler_impl(tool_choice), tools_(std::move(tools)) {}
6055

61-
virtual json tool_list() override {
56+
virtual std::string tool_list() override {
6257
return tools_;
6358
}
6459

65-
virtual action call(const json & request, json & response) override {
60+
virtual action call(const std::string & request, std::string & response) override {
6661
response = request;
6762
return toolcall::DEFER;
6863
}
6964

7065
private:
71-
json tools_;
66+
std::string tools_;
7267
};
7368

7469
class mcp_transport;
7570
class mcp_impl : public handler_impl {
7671
public:
77-
mcp_impl(std::string server_uri, tool_choice_t tool_choice);
78-
mcp_impl(std::vector<std::string> argv, tool_choice_t tool_choice);
72+
mcp_impl(std::string server_uri, std::string tool_choice);
73+
mcp_impl(std::vector<std::string> argv, std::string tool_choice);
7974

80-
virtual json tool_list() override;
81-
virtual action call(const json & request, json & response) override;
75+
virtual std::string tool_list() override;
76+
virtual action call(const std::string & request, std::string & response) override;
8277

8378
private:
8479
std::unique_ptr<mcp_transport> transport_;

common/toolcall/params.cpp

Lines changed: 21 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11

2-
#include "../json.hpp" // Must come before params due to forward decl.
32
#include "params.hpp"
43
#include <stdexcept>
4+
#include <json.hpp>
55

6-
using json = nlohmann::ordered_json;
6+
using json = nlohmann::json;
77

88
static bool starts_with(const std::string & str, const std::string & prefix) {
99
return str.size() >= prefix.size()
@@ -17,25 +17,25 @@ toolcall::params::params(std::string tools, std::string choice) {
1717

1818
void toolcall::params::tools(std::string tools) {
1919
try {
20-
if (tools.empty()) {
21-
tools_ = std::move(tools);
22-
23-
} else if (starts_with(tools, "mcp+http")) {
24-
#ifdef LLAMA_USE_CURL
25-
tools_ = std::move(tools);
26-
#else
27-
throw std::invalid_argument(
28-
"Model Context Protocol (MCP) only works when llama.cpp is compiled with libcurl");
29-
#endif
30-
} else {
31-
tools_ = std::make_shared<json>(json::parse(tools));
32-
auto tools_ptr = std::get<std::shared_ptr<json>>(tools_);
33-
if (! tools_ptr->is_array()) {
20+
if (! tools.empty()) {
21+
if (starts_with(tools, "mcp+http")) {
22+
#ifndef LLAMA_USE_CURL
3423
throw std::invalid_argument(
35-
"tools must be a URL of the form \"mcp+http(s)://hostname[:port]/\""
36-
", or a valid JSON array containing tool definitions");
24+
"Model Context Protocol (MCP) only works when llama.cpp is compiled with libcurl");
25+
#endif
26+
has_uri_ = true;
27+
28+
} else {
29+
json j = json::parse(tools); // Just for early validation
30+
if (! j.is_array()) {
31+
throw std::invalid_argument(
32+
"tools must be a URL of the form \"mcp+http(s)://hostname[:port]/\""
33+
", or a valid JSON array containing tool definitions");
34+
}
35+
has_uri_ = false;
3736
}
3837
}
38+
tools_ = std::move(tools);
3939

4040
} catch (const json::exception & err) {
4141
throw std::invalid_argument(err.what());
@@ -48,12 +48,8 @@ void toolcall::params::choice(std::string choice) {
4848
tool_choice_ = std::move(choice);
4949

5050
} else {
51-
auto choice_ptr = std::make_shared<json>(json::parse(choice));
52-
tool_choice_ = choice_ptr;
53-
if (! choice_ptr->is_object()) {
54-
throw std::invalid_argument(
55-
"tool choice must be a valid JSON object, \"auto\", \"required\", or \"none\"");
56-
}
51+
throw std::invalid_argument(
52+
"tool choice must be set to \"auto\", \"required\", or \"none\"");
5753
}
5854

5955
} catch (const json::exception & err) {
@@ -62,10 +58,5 @@ void toolcall::params::choice(std::string choice) {
6258
}
6359

6460
toolcall::params::operator bool() const {
65-
if (std::holds_alternative<std::string>(tools_)) {
66-
return ! std::get<std::string>(tools_).empty();
67-
68-
} else {
69-
return std::get<json_ptr>(tools_) != nullptr;
70-
}
61+
return ! tools_.empty();
7162
}

common/toolcall/params.hpp

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,10 @@
44
#include <variant>
55
#include <memory>
66

7-
#include <json.hpp> // TODO: switch to foreward decl.
8-
// namespace nlohmann { class ordered_json; }
9-
107
namespace toolcall
118
{
129
class params {
1310
public:
14-
using json_ptr = std::shared_ptr<nlohmann::ordered_json>;
15-
using tools_t = std::variant<std::string, json_ptr>;
16-
using tool_choice_t = std::variant<std::string, json_ptr>;
17-
1811
params(std::string tools = "", std::string choice = "auto");
1912

2013
params(const params & other) = default;
@@ -25,13 +18,16 @@ namespace toolcall
2518
operator bool() const;
2619

2720
void tools(std::string tools);
28-
const tools_t tools() const { return tools_; }
21+
const std::string & tools() const { return tools_; }
2922

3023
void choice(std::string choice);
31-
const tool_choice_t & choice() const { return tool_choice_; }
24+
const std::string & choice() const { return tool_choice_; }
25+
26+
bool has_uri() const { return has_uri_; }
3227

3328
private:
34-
tools_t tools_;
35-
tool_choice_t tool_choice_;
29+
std::string tools_;
30+
std::string tool_choice_;
31+
bool has_uri_;
3632
};
3733
}

0 commit comments

Comments
 (0)