Skip to content

Commit c5c3482

Browse files
author
ochafik
committed
try_consume_regex: basic tests + fix non-partial case
1 parent 02913b0 commit c5c3482

File tree

3 files changed

+70
-14
lines changed

3 files changed

+70
-14
lines changed

common/chat-parser.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser
135135

136136
void common_chat_msg_parser::consume_literal(const std::string & literal) {
137137
if (!try_consume_literal(literal)) {
138-
incomplete("Expected literal '" + literal + "' at position " + std::to_string(pos_));
138+
incomplete(literal);
139139
}
140140
}
141141

@@ -166,7 +166,7 @@ bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think
166166
handle_reasoning(consume_rest(), /* closed */ !is_partial());
167167
}
168168
if (!syntax_.thinking_forced_open) {
169-
incomplete("Failed to find end of reasoning tag " + end_think);
169+
incomplete(end_think);
170170
}
171171
return true;
172172
} else {
@@ -209,7 +209,7 @@ common_chat_msg_parser::consume_regex_result common_chat_msg_parser::consume_reg
209209
if (auto result = try_consume_regex(regex)) {
210210
return *result;
211211
}
212-
incomplete("Failed to consume regex: " + regex.str());
212+
incomplete(regex.str());
213213
}
214214

215215
std::optional<common_chat_msg_parser::consume_regex_result> common_chat_msg_parser::try_consume_regex(const common_regex & regex) {
@@ -218,7 +218,10 @@ std::optional<common_chat_msg_parser::consume_regex_result> common_chat_msg_pars
218218
return std::nullopt;
219219
}
220220
if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
221-
incomplete(regex.str());
221+
if (is_partial()) {
222+
incomplete(regex.str());
223+
}
224+
return std::nullopt;
222225
}
223226
if (m.groups[0].begin != pos_) {
224227
// Didn't match at the current position.
@@ -242,7 +245,7 @@ std::optional<common_json> common_chat_msg_parser::try_consume_json() {
242245
return result;
243246
}
244247
if (!is_partial()) {
245-
incomplete("JSON is incomplete");
248+
incomplete("JSON");
246249
}
247250
return result;
248251
}
@@ -251,7 +254,7 @@ common_json common_chat_msg_parser::consume_json() {
251254
if (auto result = try_consume_json()) {
252255
return *result;
253256
}
254-
incomplete("Failed to consume JSON");
257+
incomplete("JSON");
255258
}
256259

257260
common_chat_msg_parser::consume_json_result common_chat_msg_parser::consume_json_with_dumped_args(
@@ -260,7 +263,7 @@ common_chat_msg_parser::consume_json_result common_chat_msg_parser::consume_json
260263
if (auto result = try_consume_json_with_dumped_args(args_paths)) {
261264
return *result;
262265
}
263-
incomplete("Failed to consume JSON");
266+
incomplete("JSON");
264267
}
265268

266269
std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parser::try_consume_json_with_dumped_args(

common/chat-parser.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@ class common_chat_msg_parser {
1818
std::string input_;
1919
bool is_partial_;
2020
common_chat_syntax syntax_;
21+
std::string healing_marker_;
2122

2223
size_t pos_ = 0;
2324
common_chat_msg result_;
24-
std::string healing_marker_;
2525

2626
public:
2727
common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax);

tests/test-chat-parser.cpp

Lines changed: 59 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
//
66
// cmake -B build && cmake --build build --parallel && ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
77
//
8+
#include <exception>
89
#include <iostream>
910
#include <json.hpp>
1011
#include <string>
@@ -29,9 +30,27 @@ static void assert_equals(const char * expected, const std::string & actual) {
2930
return assert_equals<std::string>(expected, actual);
3031
}
3132

33+
template <class T = std::exception>
34+
static void assert_throws(const std::function<void()> & fn, const std::string & expected_exception_pattern = "") {
35+
try {
36+
fn();
37+
} catch (const T & e) {
38+
if (expected_exception_pattern.empty()) {
39+
return;
40+
}
41+
std::regex expected_exception_regex(expected_exception_pattern);
42+
std::string actual_message = e.what();
43+
if (std::regex_search(actual_message, expected_exception_regex)) {
44+
return;
45+
}
46+
throw std::runtime_error("Exception doesn't match expected pattern: " + actual_message + " (pattern: " + expected_exception_pattern + ")");
47+
}
48+
throw std::runtime_error("Exception was expected but not thrown");
49+
}
50+
3251
static void test_reasoning() {
3352
{
34-
common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", false, {
53+
common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", /* is_partial= */ false, {
3554
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
3655
/* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
3756
/* .reasoning_in_content = */ false,
@@ -41,7 +60,7 @@ static void test_reasoning() {
4160
assert_equals("<tnk>Cogito</tnk>Ergo sum", builder.consume_rest());
4261
}
4362
{
44-
common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", false, {
63+
common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", /* is_partial= */ false, {
4564
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
4665
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
4766
/* .reasoning_in_content = */ false,
@@ -52,7 +71,7 @@ static void test_reasoning() {
5271
assert_equals("Ergo sum", builder.consume_rest());
5372
}
5473
{
55-
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", false, {
74+
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
5675
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
5776
/* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
5877
/* .reasoning_in_content = */ false,
@@ -62,7 +81,7 @@ static void test_reasoning() {
6281
assert_equals("Cogito</tnk>Ergo sum", builder.consume_rest());
6382
}
6483
{
65-
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", false, {
84+
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
6685
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
6786
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
6887
/* .reasoning_in_content = */ false,
@@ -73,7 +92,7 @@ static void test_reasoning() {
7392
assert_equals("Ergo sum", builder.consume_rest());
7493
}
7594
{
76-
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", false, {
95+
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
7796
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
7897
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
7998
/* .reasoning_in_content = */ true,
@@ -86,8 +105,42 @@ static void test_reasoning() {
86105
}
87106

88107
static void test_regex() {
108+
auto test_throws = [](const std::string & input, const std::string & regex, const std::string & expected_exception_pattern = "") {
109+
common_chat_msg_parser builder(input, /* is_partial= */ false, {});
110+
assert_throws([&]() { builder.consume_regex(common_regex(regex)); }, expected_exception_pattern);
111+
};
112+
113+
test_throws("Hello, world!", "abc", "^abc$");
114+
test_throws("Hello, world!", "e", "^e$");
115+
89116
{
90-
common_chat_msg_parser builder("Hello, world!", false, common_chat_syntax());
117+
common_chat_msg_parser builder("Hello, world!", /* is_partial= */ false, {});
118+
builder.consume_regex(common_regex("Hello"));
119+
assert_equals(", world!", builder.consume_rest());
120+
}
121+
122+
{
123+
// When in non partial mode, we can say whether the regex was consumed or not.
124+
common_chat_msg_parser builder("Hello,", /* is_partial= */ false, {});
125+
assert_equals(false, builder.try_consume_regex(common_regex("Hello, world!")).has_value());
126+
assert_equals(true, builder.try_consume_regex(common_regex("Hell(o, world!)?")).has_value());
127+
}
128+
{
129+
// But in partial mode, we have a partial final match / can't decide, so we throw a partial exception.
130+
common_chat_msg_parser builder("Hello,", /* is_partial= */ true, {});
131+
assert_throws<common_chat_msg_partial_exception>([&]() {
132+
builder.try_consume_regex(common_regex("Hello, world!"));
133+
}, "^Hello, world!$");
134+
}
135+
136+
// Now regardless of the mode, we can tell these aren't a match.
137+
for (const auto is_partial : {false, true}) {
138+
common_chat_msg_parser builder("Hello,", is_partial, {});
139+
assert_equals(false, builder.try_consume_regex(common_regex("a(b|c)(d|e)f")).has_value());
140+
}
141+
for (const auto is_partial : {false, true}) {
142+
common_chat_msg_parser builder("Hello,", is_partial, {});
143+
assert_equals(false, builder.try_consume_literal("Oh"));
91144
}
92145
}
93146

0 commit comments

Comments
 (0)