Skip to content

Commit b5a74d1

Browse files
author
ochafik
committed
Simplify parser defs (incremental parsing for streaming will need more thinking)
1 parent ec4aeaf commit b5a74d1

File tree

4 files changed

+34
-97
lines changed

4 files changed

+34
-97
lines changed

common/chat-handler.cpp

Lines changed: 25 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -152,50 +152,6 @@ static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& in
152152
return result;
153153
}
154154

155-
class text_chat_parser : public common_chat_parser {
156-
public:
157-
std::optional<common_chat_msg> parse_partial(const std::string & input) override {
158-
return parse_final(input);
159-
}
160-
161-
common_chat_msg parse_final(const std::string & input) override {
162-
return {
163-
/* .role = */ "assistant",
164-
/* .content = */ input,
165-
/* .tool_calls = */ {},
166-
};
167-
}
168-
169-
std::unique_ptr<common_chat_parser> clone() const override {
170-
return std::make_unique<text_chat_parser>();
171-
}
172-
};
173-
174-
class monolithic_chat_parser : public common_chat_parser {
175-
176-
std::string input_buffer_;
177-
std::function<common_chat_msg(const std::string & input)> parse_final_;
178-
179-
public:
180-
monolithic_chat_parser(const std::function<common_chat_msg(const std::string & input)> & parse_final) : parse_final_(parse_final) {}
181-
182-
std::optional<common_chat_msg> parse_partial(const std::string & input) override {
183-
input_buffer_ += input;
184-
return std::nullopt;
185-
}
186-
187-
common_chat_msg parse_final(const std::string & input) override {
188-
input_buffer_ += input;
189-
auto out = parse_final_(input_buffer_);
190-
input_buffer_.clear();
191-
return out;
192-
}
193-
194-
std::unique_ptr<common_chat_parser> clone() const override {
195-
return std::make_unique<monolithic_chat_parser>(parse_final_);
196-
}
197-
};
198-
199155
static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
200156
for (const auto & tool : tools) {
201157
if (!tool.contains("type") || tool["type"] != "function" || !tool.contains("function")) {
@@ -289,7 +245,7 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
289245

290246
data.prompt = tmpl.apply(tweaked_messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
291247
data.format = "generic tool calls";
292-
data.parser = std::make_unique<monolithic_chat_parser>([&](const std::string & input) {
248+
data.parser = [&](const std::string & input) {
293249
json data = json::parse(input);
294250
common_chat_msg result;
295251
result.role = "assistant";
@@ -312,7 +268,7 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
312268
result.content = response.is_string() ? response.get<std::string>() : response.dump(2);
313269
}
314270
return result;
315-
});
271+
};
316272
return data;
317273
}
318274

@@ -355,9 +311,9 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha
355311
data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
356312
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
357313
data.format = "mistral nemo tool calls";
358-
data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) {
314+
data.parser = [](const std::string & input) {
359315
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
360-
});
316+
};
361317
return data;
362318
}
363319

@@ -441,7 +397,7 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
441397
{"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
442398
});
443399
data.format = "llama 3.1 tool calls";
444-
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
400+
data.parser = [params](const std::string & input) -> common_chat_msg {
445401
static std::regex function_regex("\\{\"name\": \"([^\"]+)\", \"parameters\": ");
446402
static std::regex close_regex("\\}");
447403
static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)");
@@ -472,7 +428,7 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
472428
};
473429
}
474430
return parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
475-
});
431+
};
476432
return data;
477433
}
478434

@@ -505,12 +461,12 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_
505461
data.additional_stops.push_back("<|eom_id|>");
506462
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt, {});
507463
data.format = "llama 3.2 tool calls";
508-
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) {
464+
data.parser = [params](const std::string & input) {
509465
static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
510466
static std::regex close_regex("\\}");
511467
auto res = parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
512468
return res;
513-
});
469+
};
514470
return data;
515471
}
516472

@@ -532,12 +488,12 @@ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat
532488
}, grammar_options);
533489
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
534490
data.format = "deepseek r1 tool calls";
535-
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) {
491+
data.parser = [params](const std::string & input) {
536492
static std::regex trigger_regex("<|tool▁calls▁begin|>");
537493
static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
538494
static std::regex close_regex("```<|tool▁call▁end|>");
539495
return parse_json_tool_calls(params.tools, input, trigger_regex, function_regex, close_regex, /* check_names= */ true);
540-
});
496+
};
541497
return data;
542498
}
543499

@@ -573,9 +529,9 @@ static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_
573529
data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
574530
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
575531
data.format = "firefunction v2 tool calls";
576-
data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) {
532+
data.parser = [](const std::string & input) {
577533
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
578-
});
534+
};
579535
return data;
580536
}
581537

@@ -610,7 +566,7 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
610566

611567
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
612568
data.format = "functionary v3.2 tool calls";
613-
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) {
569+
data.parser = [params](const std::string & input) {
614570
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
615571
static std::regex close_regex(R"($|(?=>>>))");
616572

@@ -619,7 +575,7 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
619575
res.content = res.content.substr(4);
620576
}
621577
return res;
622-
});
578+
};
623579
return data;
624580
}
625581

@@ -674,7 +630,7 @@ static common_chat_data common_chat_init_functionary_v3_1_llama_3_1_tool_call(co
674630

675631
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
676632
data.format = "functionary v3.1 llama 3.1 tool calls";
677-
data.parser = std::make_unique<monolithic_chat_parser>([params, has_raw_python, python_code_argument_name](const std::string & input) -> common_chat_msg {
633+
data.parser = [params, has_raw_python, python_code_argument_name](const std::string & input) -> common_chat_msg {
678634
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
679635
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
680636
std::smatch match;
@@ -695,7 +651,7 @@ static common_chat_data common_chat_init_functionary_v3_1_llama_3_1_tool_call(co
695651
static std::regex function_regex(R"(<function=(\w+)>)");
696652
static std::regex close_regex(R"(</function>)");
697653
return parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ false, has_raw_python);
698-
});
654+
};
699655
return data;
700656
}
701657

@@ -726,7 +682,7 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
726682

727683
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
728684
data.format = "hermes 2 pro tool calls";
729-
data.parser = std::make_unique<monolithic_chat_parser>([&](const std::string & input) -> common_chat_msg {
685+
data.parser = [&](const std::string & input) -> common_chat_msg {
730686
try {
731687
std::regex start_pattern(R"([\n\s]*<tool_call>)");
732688
std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
@@ -779,15 +735,21 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
779735
/* .tool_calls = */ {},
780736
};
781737
}
782-
});
738+
};
783739
return data;
784740
}
785741

786742
static common_chat_data common_chat_init_without_tools(const common_chat_template & tmpl, const struct common_chat_params & params) {
787743
common_chat_data data;
788744
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
789745
data.format = "content-only";
790-
data.parser = std::make_unique<text_chat_parser>();
746+
data.parser = [](const std::string & input) -> common_chat_msg {
747+
return {
748+
/* .role = */ "assistant",
749+
/* .content = */ input,
750+
/* .tool_calls = */ {},
751+
};
752+
};
791753
data.grammar_lazy = false;
792754
if (!params.json_schema.is_null()) {
793755
if (!params.grammar.empty()) {

common/chat-handler.hpp

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,21 +27,14 @@ struct common_chat_params {
2727
bool add_generation_prompt = true;
2828
};
2929

30-
class common_chat_parser {
31-
public:
32-
virtual ~common_chat_parser() = default;
33-
34-
virtual std::optional<common_chat_msg> parse_partial(const std::string & input) = 0;
35-
virtual common_chat_msg parse_final(const std::string & input) = 0;
36-
virtual std::unique_ptr<common_chat_parser> clone() const = 0;
37-
};
30+
typedef std::function<common_chat_msg(const std::string & input)> common_chat_parser;
3831

3932
struct common_chat_data {
4033
json prompt;
4134
std::string grammar;
4235
std::vector<common_grammar_trigger> grammar_triggers;
43-
std::vector<std::string> additional_stops;
44-
std::unique_ptr<class common_chat_parser> parser;
36+
std::vector<std::string> additional_stops;// std::unique_ptr<class common_chat_parser> parser;
37+
common_chat_parser parser;
4538
std::string format; // For debugging and testing.
4639
bool grammar_lazy = false;
4740
};

examples/server/server.cpp

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ struct slot_params {
117117
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
118118
std::string oaicompat_model;
119119
std::string oaicompat_cmpl_id;
120-
std::shared_ptr<common_chat_parser> chat_parser;
120+
common_chat_parser chat_parser;
121121

122122
json to_json() const {
123123
std::vector<std::string> samplers;
@@ -768,7 +768,6 @@ struct server_task_result_cmpl_partial : server_task_result {
768768
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
769769
std::string oaicompat_model;
770770
std::string oaicompat_cmpl_id;
771-
common_chat_msg oaicompat_chat_msg;
772771
std::shared_ptr<common_chat_parser> chat_parser;
773772

774773
virtual int get_index() override {
@@ -2220,16 +2219,6 @@ struct server_context {
22202219
}
22212220

22222221
void send_partial_response(server_slot & slot, const completion_token_output & tkn) {
2223-
common_chat_msg msg;
2224-
if (slot.params.chat_parser) {
2225-
if (auto opt_msg = slot.params.chat_parser->parse_partial(tkn.text_to_send)) {
2226-
msg = *opt_msg;
2227-
} else {
2228-
return;
2229-
}
2230-
} else {
2231-
msg.content = tkn.text_to_send;
2232-
}
22332222
auto res = std::make_unique<server_task_result_cmpl_partial>();
22342223

22352224
res->id = slot.id_task;
@@ -2245,7 +2234,6 @@ struct server_context {
22452234
res->oaicompat = slot.params.oaicompat;
22462235
res->oaicompat_model = slot.params.oaicompat_model;
22472236
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
2248-
res->oaicompat_chat_msg = msg;
22492237

22502238
// populate res.probs_output
22512239
if (slot.params.sampling.n_probs > 0) {
@@ -2286,18 +2274,14 @@ struct server_context {
22862274
res->oaicompat = slot.params.oaicompat;
22872275
res->oaicompat_model = slot.params.oaicompat_model;
22882276
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
2289-
if (!slot.params.chat_parser) {
2277+
if (slot.params.chat_parser) {
2278+
res->oaicompat_chat_msg = slot.params.chat_parser(slot.generated_text);
2279+
} else {
22902280
res->oaicompat_chat_msg = {
22912281
/* .role = */ "assistant",
22922282
/* .content = */ slot.generated_text,
22932283
/* .tool_calls = */ {}
22942284
};
2295-
} else if (slot.stop == STOP_TYPE_LIMIT) {
2296-
if (auto opt_msg = slot.params.chat_parser->parse_partial(slot.generated_text)) {
2297-
res->oaicompat_chat_msg = *opt_msg;
2298-
}
2299-
} else {
2300-
res->oaicompat_chat_msg = slot.params.chat_parser->parse_final(slot.generated_text);
23012285
}
23022286
// populate res.probs_output
23032287
if (slot.params.sampling.n_probs > 0) {
@@ -3835,9 +3819,7 @@ int main(int argc, char ** argv) {
38353819
task.params.sampling.grammar_trigger_words.push_back(trigger);
38363820
}
38373821
task.params.antiprompt = chat_data.additional_stops;
3838-
if (chat_data.parser) {
3839-
task.params.chat_parser = i == tokenized_prompts.size() ? std::move(chat_data.parser) : std::move(chat_data.parser->clone());
3840-
}
3822+
task.params.chat_parser = chat_data.parser;
38413823
if (task.params.sampling.grammar_lazy) {
38423824
GGML_ASSERT(task.params.sampling.grammar_trigger_tokens.size() > 0 || task.params.sampling.grammar_trigger_words.size() > 0);
38433825
}

tests/test-chat-handler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
397397
auto full_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, test_message, tools);
398398
std::cout << "Full delta:\n```\n" << full_delta << "\n```" << std::endl;
399399

400-
const auto msg = chat_data.parser->parse_final(full_delta);
400+
const auto msg = chat_data.parser(full_delta);
401401
assert_msg_equals(expected_msg, msg);
402402

403403
auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, {

0 commit comments

Comments
 (0)