Skip to content

Commit 06b72b9

Browse files
author
ochafik
committed
add preludes to content on partial regex match
1 parent f5cd27b commit 06b72b9

File tree

4 files changed

+10
-18
lines changed

4 files changed

+10
-18
lines changed

common/chat-parser.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,20 +170,23 @@ std::string common_chat_msg_parser::consume_rest() {
170170
}
171171

172172
// Tries to find the regex, consumes it (pos right after it) and gives the prelude (right before it) and the groups to the callback.
173-
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from) {
173+
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from, bool add_prelude_to_content) {
174174
auto m = regex.search(input_, from == std::string::npos ? pos_ : from);
175175
if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
176176
return std::nullopt;
177177
}
178+
auto prelude = input_.substr(pos_, m.groups[0].begin - pos_);
179+
pos_ = m.groups[0].end;
180+
181+
if (add_prelude_to_content) {
182+
add_content(prelude);
183+
}
178184
if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
179185
if (is_partial()) {
180186
throw common_chat_msg_partial_exception(regex.str());
181187
}
182188
return std::nullopt;
183189
}
184-
auto prelude = input_.substr(pos_, m.groups[0].begin - pos_);
185-
pos_ = m.groups[0].end;
186-
187190
return find_regex_result{prelude, m.groups};
188191
}
189192

common/chat-parser.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ class common_chat_msg_parser {
7777
std::vector<common_string_range> groups;
7878
};
7979

80-
std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos);
80+
std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos, bool add_prelude_to_content = true);
8181

8282
bool try_consume_literal(const std::string & literal);
8383

common/chat.cpp

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -644,7 +644,6 @@ static void parse_json_tool_calls(
644644
}
645645
from = std::string::npos;
646646

647-
builder.add_content(res->prelude);
648647
auto maybe_raw_python = name == "python" && allow_raw_python;
649648
if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) {
650649
if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) {
@@ -674,7 +673,6 @@ static void parse_json_tool_calls(
674673
};
675674
if (block_open) {
676675
if (auto res = builder.try_find_regex(*block_open)) {
677-
builder.add_content(res->prelude);
678676
parse_tool_calls();
679677
} else {
680678
builder.add_content(builder.consume_rest());
@@ -687,7 +685,6 @@ static void parse_json_tool_calls(
687685
static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder, const common_regex & prefix, size_t rstrip_prefix = 0) {
688686
static const std::vector<std::vector<std::string>> args_paths = {{"arguments"}};
689687
if (auto res = builder.try_find_regex(prefix)) {
690-
builder.add_content(res->prelude);
691688
builder.move_back(rstrip_prefix);
692689
auto tool_calls = builder.consume_json_with_dumped_args(args_paths);
693690
if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) {
@@ -983,7 +980,6 @@ static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) {
983980

984981
if (auto res = builder.try_find_regex(start_action_regex)) {
985982
// If we didn't extract thoughts, prelude includes them.
986-
builder.add_content(res->prelude);
987983
auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}});
988984
for (const auto & tool_call : tool_calls.value) {
989985
std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : "";
@@ -998,10 +994,8 @@ static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) {
998994
}
999995
builder.consume_regex(end_action_regex);
1000996
} else if (auto res = builder.try_find_regex(start_response_regex)) {
1001-
// If we didn't extract thoughts, prelude includes them.
1002-
builder.add_content(res->prelude);
1003997
if (auto res = builder.try_find_regex(end_response_regex)) {
1004-
builder.add_content(res->prelude);
998+
// If we didn't extract thoughts, prelude includes them.
1005999
} else {
10061000
builder.add_content(builder.consume_rest());
10071001
throw common_chat_msg_partial_exception(end_response_regex.str());
@@ -1120,8 +1114,6 @@ static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool w
11201114
if (with_builtin_tools) {
11211115
static const common_regex builtin_call_regex("<\\|python_tag\\|>");
11221116
if (auto res = builder.try_find_regex(builtin_call_regex)) {
1123-
builder.add_content(res->prelude);
1124-
11251117
auto fun_res = builder.consume_regex(function_name_regex);
11261118
auto function_name = builder.str(fun_res.groups[1]);
11271119

@@ -1439,7 +1431,6 @@ static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser
14391431
static const common_regex python_tag_regex(regex_escape("<|python_tag|>"));
14401432

14411433
if (auto res = builder.try_find_regex(python_tag_regex)) {
1442-
builder.add_content(res->prelude);
14431434
auto arguments = wrap_code_as_arguments(builder, builder.consume_rest());
14441435
builder.add_tool_call("python", "", arguments);
14451436
return;
@@ -1584,8 +1575,6 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
15841575
);
15851576

15861577
if (auto res = builder.try_find_regex(open_regex)) {
1587-
builder.add_content(res->prelude);
1588-
15891578
const auto & block_start = res->groups[1];
15901579
std::string block_end = block_start.empty() ? "" : "```";
15911580

tests/test-chat.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -780,7 +780,7 @@ static void test_template_output_parsers() {
780780
/* .thinking_forced_open = */ false,
781781
}));
782782
assert_msg_equals(
783-
simple_assist_msg(""),
783+
simple_assist_msg("Let's call something\n"),
784784
common_chat_parse(
785785
"Let's call something\n"
786786
"<tool_call>{\"name",

0 commit comments

Comments
 (0)