Commit 5ec4c5e

Author: ochafik
Commit message: reshuffle chat handlers
1 parent: 43385b2

3 files changed: +331, -270 lines


common/chat-handler.cpp

Lines changed: 95 additions & 49 deletions
@@ -76,7 +76,7 @@ static common_chat_msg parse_json_tool_calls(const json & tools, const std::stri
             if (type == "function") {
                 tool_names.push_back(tool["function"]["name"]);
             } else if (type == "code_interpreter") {
-                tool_names.push_back("ipython");
+                tool_names.push_back("python");
             }
         }
     }
@@ -171,6 +171,10 @@ class text_chat_parser : public common_chat_parser {
             /* .tool_calls = */ {},
         };
     }
+
+    std::unique_ptr<common_chat_parser> clone() const override {
+        return std::make_unique<text_chat_parser>();
+    }
 };

 class monolithic_chat_parser : public common_chat_parser {
@@ -192,13 +196,48 @@ class monolithic_chat_parser : public common_chat_parser {
         input_buffer_.clear();
         return out;
     }
+
+    std::unique_ptr<common_chat_parser> clone() const override {
+        return std::make_unique<monolithic_chat_parser>(parse_final_);
+    }
 };

-static common_chat_data build_generic_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
+const auto python_tool = json::parse(R"({
+    "type": "function",
+    "function": {
+        "name": "python",
+        "description": "an ipython interpreter",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "code": {
+                    "type": "string",
+                    "description": "Python code to execute."
+                }
+            },
+            "required": ["code"]
+        }
+    }
+})");
+
+static void foreach_normalized_tool(const json & tools, const std::function<void(const json &)> & fn) {
+    for (const auto & tool : tools) {
+        if (!tool.contains("type")) {
+            continue;
+        }
+        if (tool["type"] == "code_interpreter") {
+            fn(python_tool);
+        } else {
+            fn(tool);
+        }
+    }
+}
+
+static common_chat_data common_chat_init_generic_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
     common_chat_data data;

     auto tool_call_schemas = json::array();
-    for (const auto & tool : params.tools) {
+    foreach_normalized_tool(params.tools, [&](const json & tool) {
         const auto & function = tool["function"];
         auto tool_schema = json {
             {"type", "object"},
@@ -222,7 +261,7 @@ static common_chat_data build_generic_tool_call_handler(const common_chat_templa
             tool_schema["required"].push_back("id");
         }
         tool_call_schemas.emplace_back(tool_schema);
-    }
+    });
     const auto tool_call =
         params.parallel_tool_calls
             ? json {
@@ -276,7 +315,7 @@ static common_chat_data build_generic_tool_call_handler(const common_chat_templa
         "Respond in JSON format, either with a request to call tools or with a response to the user's request. Here is the schema for all responses:\n\n```json\n" + schema.dump(2) + "\n```");

     data.prompt = tmpl.apply(tweaked_messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
-    data.handler = std::make_unique<monolithic_chat_parser>([&](const std::string & input) {
+    data.parser = std::make_unique<monolithic_chat_parser>([&](const std::string & input) {
         json data = json::parse(input);
         common_chat_msg result;
         result.role = "assistant";
@@ -303,13 +342,11 @@ static common_chat_data build_generic_tool_call_handler(const common_chat_templa
     return data;
 }

-static common_chat_data build_mistral_nemo_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
+static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
     common_chat_data data;
-    auto builtin_tools = json {"wolfram_alpha", "brave_search"};
-
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         auto schemas = json::array();
-        for (const auto & tool : params.tools) {
+        foreach_normalized_tool(params.tools, [&](const json & tool) {
             const auto & function = tool["function"];
             schemas.push_back({
                 {"type", "object"},
@@ -329,7 +366,7 @@ static common_chat_data build_mistral_nemo_tool_call_handler(const common_chat_t
                 }},
                 {"required", json::array({"name", "arguments", "id"})},
             });
-        }
+        });
         auto schema = json {
             {"type", "array"},
             {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
@@ -344,24 +381,14 @@ static common_chat_data build_mistral_nemo_tool_call_handler(const common_chat_t
         data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
     }
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
-    data.handler = std::make_unique<monolithic_chat_parser>([](const std::string & input) -> common_chat_msg {
+    data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) -> common_chat_msg {
         return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
     });
     return data;
 }

-static common_chat_data build_llama_3_tool_calls_handler(const common_chat_template & tmpl, const struct common_chat_params & params, bool uses_python_tag, bool eagerly_match_any_json) {
+static common_chat_data common_chat_init_llama_3_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params, bool uses_python_tag, bool eagerly_match_any_json) {
     auto builtin_tools = json {"wolfram_alpha", "brave_search"};
-    for (const auto & tool : params.tools) {
-        if (!tool.contains("type")) {
-            continue;
-        }
-        if (tool["type"] == "code_interpreter") {
-            builtin_tools.push_back("code_interpreter");
-            break;
-        }
-    }
-
     common_chat_data data;

     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
@@ -375,6 +402,7 @@ static common_chat_data build_llama_3_tool_calls_handler(const common_chat_templ
             }

             if (tool["type"] == "code_interpreter") {
+                builtin_tools.push_back("code_interpreter");
                 has_python = true;
             } else if (tool["type"] == "function" && tool.contains("function")) {
                 const auto & function = tool["function"];
@@ -422,8 +450,10 @@ static common_chat_data build_llama_3_tool_calls_handler(const common_chat_templ
         builder.add_rule("root", string_join(tool_rules, " | "));
     }, grammar_options);
     data.additional_stops.push_back("<|eom_id|>");
-    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
-    data.handler = std::make_unique<monolithic_chat_parser>([params, uses_python_tag](const std::string & input) -> common_chat_msg {
+    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true, {
+        {"builtin_tools", builtin_tools},
+    });
+    data.parser = std::make_unique<monolithic_chat_parser>([params, uses_python_tag](const std::string & input) -> common_chat_msg {
         if (uses_python_tag) {
             static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
             std::smatch match;
@@ -448,11 +478,11 @@ static common_chat_data build_llama_3_tool_calls_handler(const common_chat_templ
     return data;
 }

-static common_chat_data build_firefunction_v2_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
+static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
     common_chat_data data;
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         auto schemas = json::array();
-        for (const auto & tool : params.tools) {
+        foreach_normalized_tool(params.tools, [&](const json & tool) {
             const auto & function = tool["function"];
             schemas.push_back({
                 {"type", "object"},
@@ -465,7 +495,7 @@ static common_chat_data build_firefunction_v2_tool_call_handler(const common_cha
                 }},
                 {"required", json::array({"name", "arguments", "id"})},
             });
-        }
+        });
         auto schema = json {
             {"type", "array"},
             {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
@@ -480,13 +510,13 @@ static common_chat_data build_firefunction_v2_tool_call_handler(const common_cha
         data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
     }
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
-    data.handler = std::make_unique<monolithic_chat_parser>([](const std::string & input) -> common_chat_msg {
+    data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) -> common_chat_msg {
         return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
     });
     return data;
 }

-static common_chat_data build_functionary_v3_llama_3_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
+static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
     // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
     // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
     common_chat_data data;
@@ -530,19 +560,20 @@ static common_chat_data build_functionary_v3_llama_3_tool_call_handler(const com
     }, grammar_options);

     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
-    data.handler = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
+    data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
         static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
         static std::regex close_regex(R"($|(?=>>>))");
         return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true);
     });
     return data;
 }

-static common_chat_data build_functionary_v3_llama_3_1_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
+static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
     // ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
     // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
     // TODO: handle tool {type: code_interpreter} as python
     common_chat_data data;
+    json tools = params.tools.is_null() ? params.tools : json::array();

     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         std::vector<std::string> tool_rules;
@@ -578,7 +609,7 @@ static common_chat_data build_functionary_v3_llama_3_1_tool_call_handler(const c
     }, grammar_options);

     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
-    data.handler = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
+    data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
         // This version of Functionary still supports the llama 3.1 tool call format for the python tool.
         static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
         std::smatch match;
@@ -602,12 +633,12 @@ static common_chat_data build_functionary_v3_llama_3_1_tool_call_handler(const c
     return data;
 }

-static common_chat_data build_hermes_2_pro_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
+static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
     common_chat_data data;
     // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         std::vector<std::string> tool_rules;
-        for (const auto & tool : params.tools) {
+        foreach_normalized_tool(params.tools, [&](const json & tool) {
             const auto & function = tool["function"];
             std::string name = function["name"];
             auto parameters = function["parameters"];
@@ -620,8 +651,7 @@ static common_chat_data build_hermes_2_pro_tool_call_handler(const common_chat_t
                 }},
                 {"required", json::array({"name", "arguments"})},
             }));
-        }
-
+        });
         auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
         builder.add_rule("root", params.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
         if (params.tool_choice != "required") {
@@ -630,7 +660,7 @@ static common_chat_data build_hermes_2_pro_tool_call_handler(const common_chat_t
     }, grammar_options);

     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
-    data.handler = std::make_unique<monolithic_chat_parser>([&](const std::string & input) -> common_chat_msg {
+    data.parser = std::make_unique<monolithic_chat_parser>([&](const std::string & input) -> common_chat_msg {
         try {
             std::regex start_pattern(R"([\n\s]*<tool_call>)");
             std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
@@ -677,24 +707,40 @@ static common_chat_data build_hermes_2_pro_tool_call_handler(const common_chat_t
     return data;
 }

+static common_chat_data common_chat_init_without_tools(const common_chat_template & tmpl, const struct common_chat_params & params) {
+    common_chat_data data;
+    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
+    data.parser = std::make_unique<text_chat_parser>();
+    if (!params.json_schema.is_null()) {
+        if (!params.grammar.empty()) {
+            throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
+        }
+        data.grammar = json_schema_to_grammar(params.json_schema);
+    } else {
+        data.grammar = params.grammar.empty();
+    }
+    return data;
+}
+
 common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params) {
     if (params.tools.is_null()) {
-        common_chat_data data;
-        data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
-        data.handler = std::make_unique<text_chat_parser>();
-        return data;
+        return common_chat_init_without_tools(tmpl, params);
     }
-    const auto & src = tmpl.source();

+    if (!params.grammar.empty()) {
+        throw std::runtime_error("Cannot specify grammar with tools");
+    }
+
+    const auto & src = tmpl.source();
     if (src.find("<tool_call>") != std::string::npos) {
-        return build_hermes_2_pro_tool_call_handler(tmpl, params);
+        return common_chat_init_hermes_2_pro_tool_call(tmpl, params);
     }
     if (src.find(">>>all") != std::string::npos) {
-        return build_functionary_v3_llama_3_tool_call_handler(tmpl, params);
+        return common_chat_init_functionary_v3_llama_3_tool_call(tmpl, params);
     }
     if (src.find("<|start_header_id|>") != std::string::npos
         && src.find("<function=") != std::string::npos) {
-        return build_functionary_v3_llama_3_1_tool_call_handler(tmpl, params);
+        return common_chat_init_functionary_v3_llama_3_1_tool_call(tmpl, params);
     }
     if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
         auto uses_python_tag = src.find("<|python_tag|>") != std::string::npos;
@@ -705,16 +751,16 @@ common_chat_data common_chat_init(const common_chat_template & tmpl, const struc
         // TODO: make this conditional on a very small model (e.g. 1B / 3B).
         auto eagerly_match_any_json = false; // style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_2;

-        return build_llama_3_tool_calls_handler(tmpl, params, uses_python_tag, eagerly_match_any_json);
+        return common_chat_init_llama_3_tool_calls(tmpl, params, uses_python_tag, eagerly_match_any_json);
     }
     // if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
     //     TODO: Command-R-Plus
     // }
     if (src.find("[TOOL_CALLS]") != std::string::npos) {
-        return build_mistral_nemo_tool_call_handler(tmpl, params);
+        return common_chat_init_mistral_nemo_tool_call(tmpl, params);
     }
     if (src.find(" functools[") != std::string::npos) {
-        return build_firefunction_v2_tool_call_handler(tmpl, params);
+        return common_chat_init_firefunction_v2_tool_call(tmpl, params);
     }
-    return build_generic_tool_call_handler(tmpl, params);
+    return common_chat_init_generic_tool_call(tmpl, params);
 }
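
Note on the normalization helper added above: the following is a minimal, standalone sketch of python_tool and foreach_normalized_tool as they appear in this diff, wrapped in an illustrative main() that is not part of the commit. It assumes nlohmann::json and simply shows that a code_interpreter tool entry is replaced by the built-in python function definition before each handler builds its per-tool grammar and schema.

// Standalone illustration (not part of the commit): how foreach_normalized_tool
// rewrites a tool list. Assumes nlohmann::json and a C++17 compiler.
#include <functional>
#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

// Built-in "python" tool definition, copied from the diff above.
static const json python_tool = json::parse(R"({
    "type": "function",
    "function": {
        "name": "python",
        "description": "an ipython interpreter",
        "parameters": {
            "type": "object",
            "properties": {
                "code": {"type": "string", "description": "Python code to execute."}
            },
            "required": ["code"]
        }
    }
})");

// Same logic as the helper added in chat-handler.cpp: code_interpreter tools are
// swapped for python_tool, everything else is forwarded unchanged.
static void foreach_normalized_tool(const json & tools, const std::function<void(const json &)> & fn) {
    for (const auto & tool : tools) {
        if (!tool.contains("type")) {
            continue;
        }
        if (tool["type"] == "code_interpreter") {
            fn(python_tool);
        } else {
            fn(tool);
        }
    }
}

int main() {
    // Hypothetical request: one ordinary function tool plus an OpenAI-style code_interpreter tool.
    json tools = json::parse(R"([
        {"type": "function", "function": {"name": "get_weather", "parameters": {"type": "object"}}},
        {"type": "code_interpreter"}
    ])");
    foreach_normalized_tool(tools, [](const json & tool) {
        // Prints "get_weather" and then "python" (substituted for code_interpreter).
        std::cout << tool["function"]["name"].get<std::string>() << "\n";
    });
    return 0;
}

This is the same helper that lets the generic, Mistral Nemo, Firefunction v2 and Hermes 2 Pro handlers in this commit treat code_interpreter requests as calls to the python tool.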

common/chat-handler.hpp

Lines changed: 4 additions & 2 deletions
@@ -22,6 +22,7 @@ struct common_chat_params {
     json json_schema;
     bool parallel_tool_calls;
     bool stream;
+    std::string grammar;
 };

 class common_chat_parser {
@@ -30,14 +31,15 @@ class common_chat_parser {

     virtual std::optional<common_chat_msg> parse_partial(const std::string & input) = 0;
     virtual common_chat_msg parse_final(const std::string & input) = 0;
+    virtual std::unique_ptr<common_chat_parser> clone() const = 0;
 };

 struct common_chat_data {
-    std::string prompt;
+    json prompt;
     std::string grammar;
     std::vector<common_grammar_trigger> grammar_triggers;
     std::vector<std::string> additional_stops;
-    std::unique_ptr<class common_chat_parser> handler;
+    std::unique_ptr<class common_chat_parser> parser;
 };

 struct common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params);
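
For the header changes above, here is a hedged, caller-side sketch of how the renamed parser member and the new clone() method might be used: one common_chat_data is built per request, and each parallel completion slot works on its own cloned parser. Only the declarations visible in this diff (common_chat_init, common_chat_data, common_chat_parser::parse_partial / parse_final / clone) are taken from the commit; the handle_request wrapper, the way messages and tools are produced, and the decode loop are hypothetical.

// Hypothetical usage sketch; not part of the commit. It relies only on the
// chat-handler.hpp API shown in this diff.
#include "chat-handler.hpp"

#include <memory>
#include <string>

static void handle_request(const common_chat_template & tmpl, const json & messages, const json & tools, int n_slots) {
    common_chat_params params;
    params.messages            = messages;   // assumed json member, as used by tmpl.apply(params.messages, ...)
    params.tools               = tools;      // assumed json member, as used throughout chat-handler.cpp
    params.parallel_tool_calls = false;
    params.stream              = true;

    common_chat_data data = common_chat_init(tmpl, params);
    // data.prompt (now json), data.grammar, data.grammar_triggers and
    // data.additional_stops would be handed to the sampler / tokenizer here.

    for (int slot = 0; slot < n_slots; slot++) {
        // Each slot gets independent parser state via the new clone() method.
        std::unique_ptr<common_chat_parser> parser = data.parser->clone();

        std::string generated;  // filled incrementally by the decode loop (omitted)
        if (auto partial = parser->parse_partial(generated)) {
            // stream the partial common_chat_msg to the client
            (void) *partial;
        }
        common_chat_msg final_msg = parser->parse_final(generated);
        (void) final_msg;
    }
}

The clone() method appears to exist because monolithic_chat_parser keeps an input_buffer_ across parse_partial calls, so each concurrent generation presumably needs its own parser copy rather than sharing one instance.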
