Skip to content

Commit 4358d5d

Browse files
author
ochafik
committed
rm common_regex.at_start
1 parent c46d4da commit 4358d5d

File tree

5 files changed

+33
-46
lines changed

5 files changed

+33
-46
lines changed

common/chat-parser.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,9 +177,6 @@ common_chat_msg_parser::consume_regex_result common_chat_msg_parser::consume_reg
177177
}
178178

179179
std::optional<common_chat_msg_parser::consume_regex_result> common_chat_msg_parser::try_consume_regex(const common_regex & regex) {
180-
if (!regex.at_start()) {
181-
throw std::runtime_error("try_consume_regex requires a common_regex w/ at_start=true");
182-
}
183180
auto m = regex.search(input_, pos_);
184181
if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
185182
return std::nullopt;
@@ -188,6 +185,10 @@ std::optional<common_chat_msg_parser::consume_regex_result> common_chat_msg_pars
188185
incomplete(regex.str());
189186
return std::nullopt;
190187
}
188+
if (m.groups[0].begin != pos_) {
189+
// Didn't match at the current position.
190+
return std::nullopt;
191+
}
191192
pos_ = m.groups[0].end;
192193

193194
return consume_regex_result{m.groups};

common/chat.cpp

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
#include <vector>
1717

1818

19-
static const common_regex default_start_think_regex("<think>", /* at_start= */ true);
19+
static const common_regex default_start_think_regex("<think>");
2020
static const common_regex default_end_think_regex("</think>");
2121

2222
static std::string string_diff(const std::string & last, const std::string & current) {
@@ -915,13 +915,13 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
915915
}
916916

917917
static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) {
918-
static const common_regex start_thinking_regex("<\\|START_THINKING\\|>", /* at_start= */ true);
918+
static const common_regex start_thinking_regex("<\\|START_THINKING\\|>");
919919
static const common_regex end_thinking_regex("<\\|END_THINKING\\|>");
920920

921921
builder.try_consume_think_tags(start_thinking_regex, end_thinking_regex);
922922

923923
static const common_regex start_action_regex("<\\|START_ACTION\\|>");
924-
static const common_regex end_action_regex("<\\|END_ACTION\\|>", /* at_start= */ true);
924+
static const common_regex end_action_regex("<\\|END_ACTION\\|>");
925925
static const common_regex start_response_regex("<\\|START_RESPONSE\\|>");
926926
static const common_regex end_response_regex("<\\|END_RESPONSE\\|>");
927927

@@ -1048,12 +1048,12 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
10481048
}
10491049
static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) {
10501050
static const common_regex function_regex(
1051-
"\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: ", /* at_start= */ true);
1052-
static const common_regex close_regex("\\}\\s*", /* at_start= */ true);
1053-
static const common_regex builtin_call_regex("<\\|python_tag\\|>", /* at_start= */ true);
1051+
"\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: ");
1052+
static const common_regex close_regex("\\}\\s*");
1053+
static const common_regex builtin_call_regex("<\\|python_tag\\|>");
10541054

1055-
static const common_regex function_name_regex("\\s*(\\w+)\\s*\\.\\s*call\\(", /* at_start= */ true);
1056-
static const common_regex arg_name_regex("\\s*(\\w+)\\s*=\\s*", /* at_start= */ true);
1055+
static const common_regex function_name_regex("\\s*(\\w+)\\s*\\.\\s*call\\(");
1056+
static const common_regex arg_name_regex("\\s*(\\w+)\\s*=\\s*");
10571057

10581058
if (with_builtin_tools) {
10591059
if (auto res = builder.try_find_regex(builtin_call_regex)) {
@@ -1164,9 +1164,9 @@ static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
11641164
builder.try_consume_think_tags(default_start_think_regex, default_end_think_regex);
11651165

11661166
static const common_regex tool_calls_begin("[\\s\\r\\n]*(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>)");
1167-
static const common_regex tool_calls_end("<|tool▁calls▁end|>", /* at_start= */ true);
1168-
static const common_regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n", /* at_start= */ true);
1169-
static const common_regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>", /* at_start= */ true);
1167+
static const common_regex tool_calls_end("<|tool▁calls▁end|>");
1168+
static const common_regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
1169+
static const common_regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>");
11701170

11711171
parse_json_tool_calls(builder, tool_calls_begin, function_regex, close_regex, tool_calls_end);
11721172
}
@@ -1272,7 +1272,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
12721272
}
12731273
static void common_chat_parse_functionary_v3_2(common_chat_msg_parser & builder) {
12741274
static const common_regex function_regex(R"((>>>)?(\w+\n\{|python\n|all\n))");
1275-
static const common_regex close_regex(R"(\s*)", /* at_start= */ true);
1275+
static const common_regex close_regex(R"(\s*)");
12761276

12771277
parse_json_tool_calls(builder, std::nullopt, function_regex, close_regex, std::nullopt, /* allow_raw_python= */ true,
12781278
/* get_function_name= */ [&](const auto & res) -> std::string {
@@ -1351,7 +1351,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
13511351
}
13521352
static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) {
13531353
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
1354-
static const common_regex python_tag_regex(regex_escape("<|python_tag|>"), /* at_start= */ true);
1354+
static const common_regex python_tag_regex(regex_escape("<|python_tag|>"));
13551355

13561356
if (auto res = builder.try_find_regex(python_tag_regex)) {
13571357
builder.add_content(res->prelude);
@@ -1368,8 +1368,8 @@ static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser
13681368
return;
13691369
}
13701370

1371-
static const common_regex function_regex(R"(<function=(\w+)>)", /* at_start= */ true);
1372-
static const common_regex close_regex(R"(</function>)", /* at_start= */ true);
1371+
static const common_regex function_regex(R"(<function=(\w+)>)");
1372+
static const common_regex close_regex(R"(</function>)");
13731373

13741374
parse_json_tool_calls(builder, std::nullopt, function_regex, close_regex, std::nullopt);
13751375
}
@@ -1490,9 +1490,8 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
14901490
")"
14911491
"|"
14921492
"(?:<function=([^>]+)>" // match 4 (function name)
1493-
"|<function name=\"([^\"]+)\">)" // match 5 (function name again)
1494-
"([\\s\\S]*)", // match 6 (function arguments + rest)})"
1495-
/* at_start= */ true
1493+
"|<function name=\"([^\"]+)\">)" // match 5 (function name again)
1494+
"([\\s\\S]*)" // match 6 (function arguments + rest)})"
14961495
);
14971496

14981497
if (auto res = builder.try_find_regex(open_regex)) {

common/regex-partial.cpp

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,10 @@
22
#include "common.h"
33
#include <functional>
44

5-
common_regex::common_regex(const std::string & pattern, bool at_start) :
5+
common_regex::common_regex(const std::string & pattern) :
66
pattern(pattern),
77
rx(pattern),
8-
rx_reversed_partial(regex_to_reversed_partial_regex(pattern)),
9-
at_start_(at_start) {}
8+
rx_reversed_partial(regex_to_reversed_partial_regex(pattern)) {}
109

1110
common_regex_match common_regex::search(const std::string & input, size_t pos, bool as_match) const {
1211
std::smatch match;
@@ -18,23 +17,21 @@ common_regex_match common_regex::search(const std::string & input, size_t pos, b
1817
? std::regex_match(start, input.end(), match, rx)
1918
: std::regex_search(start, input.end(), match, rx);
2019
if (found) {
21-
if (as_match || !at_start_ || match.position() == 0) {
22-
common_regex_match res;
23-
res.type = COMMON_REGEX_MATCH_TYPE_FULL;
24-
for (size_t i = 0; i < match.size(); ++i) {
25-
auto begin = pos + match.position(i);
26-
res.groups.emplace_back(begin, begin + match.length(i));
27-
}
28-
return res;
20+
common_regex_match res;
21+
res.type = COMMON_REGEX_MATCH_TYPE_FULL;
22+
for (size_t i = 0; i < match.size(); ++i) {
23+
auto begin = pos + match.position(i);
24+
res.groups.emplace_back(begin, begin + match.length(i));
2925
}
26+
return res;
3027
}
3128
std::match_results<std::string::const_reverse_iterator> srmatch;
3229
if (std::regex_match(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial)) {
3330
auto group = srmatch[1].str();
3431
if (group.length() != 0) {
3532
auto it = srmatch[1].second.base();
3633
// auto position = static_cast<size_t>(std::distance(input.begin(), it));
37-
if ((!as_match && !at_start_) || it == input.begin()) {
34+
if ((!as_match) || it == input.begin()) {
3835
common_regex_match res;
3936
res.type = COMMON_REGEX_MATCH_TYPE_PARTIAL;
4037
auto begin = std::distance(input.begin(), it);

common/regex-partial.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,13 @@ class common_regex {
4242
std::string pattern;
4343
std::regex rx;
4444
std::regex rx_reversed_partial;
45-
bool at_start_;
4645

4746
public:
48-
common_regex(const std::string & pattern, bool at_start = false);
47+
common_regex(const std::string & pattern);
4948

5049
common_regex_match search(const std::string & input, size_t pos, bool as_match = false) const;
5150

5251
const std::string & str() const { return pattern; }
53-
bool at_start() const { return at_start_; }
5452
};
5553

5654
// For testing only (pretty print of failures).

tests/test-regex-partial.cpp

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ template <class T> static void assert_equals(const T & expected, const T & actua
1616

1717
struct test_case {
1818
std::string pattern;
19-
bool at_start = false;
2019
struct input_output {
2120
std::string input;
2221
common_regex_match output;
@@ -28,7 +27,6 @@ static void test_regex() {
2827
std::vector<test_case> test_cases {
2928
test_case {
3029
"a",
31-
/* .at_start = */ false,
3230
{
3331
{"a", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 1}}}},
3432
{"b", {COMMON_REGEX_MATCH_TYPE_NONE, {}}},
@@ -38,7 +36,6 @@ static void test_regex() {
3836
},
3937
test_case {
4038
"abcd",
41-
/* .at_start = */ false,
4239
{
4340
{"abcd", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
4441
{"abcde", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
@@ -56,7 +53,6 @@ static void test_regex() {
5653
},
5754
test_case {
5855
".*?ab",
59-
/* .at_start = */ false,
6056
{
6157
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
6258
{"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
@@ -68,7 +64,6 @@ static void test_regex() {
6864
},
6965
test_case {
7066
"a.*?b",
71-
/* .at_start = */ false,
7267
{
7368
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
7469
{"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
@@ -81,7 +76,6 @@ static void test_regex() {
8176
},
8277
test_case {
8378
"ab(?:cd){2,4}ef",
84-
/* .at_start = */ false,
8579
{
8680
// {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, 0, {}}},
8781
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
@@ -99,7 +93,6 @@ static void test_regex() {
9993
},
10094
test_case {
10195
"a(?:rte| pure )fact",
102-
/* .at_start = */ false,
10396
{
10497
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
10598
{"art", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
@@ -118,7 +111,6 @@ static void test_regex() {
118111
},
119112
test_case {
120113
"abc",
121-
/* .at_start = */ true,
122114
{
123115
{" abcc", {}},
124116
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
@@ -129,8 +121,8 @@ static void test_regex() {
129121
};
130122

131123
for (const auto & test_case : test_cases) {
132-
common_regex cr(test_case.pattern, test_case.at_start);
133-
std::cout << "Testing pattern: /" << test_case.pattern << "/ (at_start = " << (test_case.at_start ? "true" : "false") << ")\n";
124+
common_regex cr(test_case.pattern);
125+
std::cout << "Testing pattern: /" << test_case.pattern << "/\n";
134126
// std::cout << " partial rev: " << cr.reversed_partial_pattern.str() << '\n';
135127
for (const auto & input_output : test_case.inputs_outputs) {
136128
std::cout << " Input: " << input_output.input << '\n';

0 commit comments

Comments
 (0)