Skip to content

Commit aefc8a4

Browse files
author
ochafik
committed
refactor + test chat parser (try_consume_json_with_dumped_args, literal based thinking tags parsing)
1 parent b48ab23 commit aefc8a4

File tree

8 files changed

+496
-169
lines changed

8 files changed

+496
-169
lines changed

common/chat-parser.cpp

Lines changed: 125 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -37,42 +37,30 @@ void common_chat_msg_parser::add_reasoning_content(const std::string &reasoning_
3737
result_.reasoning_content += reasoning_content;
3838
}
3939

40-
bool common_chat_msg_parser::add_tool_call(const std::string & name, const std::string & id, const std::string & arguments, const common_healing_marker & healing_marker) {
40+
bool common_chat_msg_parser::add_tool_call(const std::string & name, const std::string & id, const std::string & arguments) {
4141
if (name.empty()) {
4242
return false;
4343
}
4444

45-
auto marker_idx = std::string::npos;
46-
if (!arguments.empty() && !healing_marker.marker.empty()) {
47-
marker_idx = arguments.find(healing_marker.json_dump_marker);
48-
if (marker_idx == std::string::npos) {
49-
marker_idx = arguments.find(healing_marker.marker);
50-
}
51-
}
52-
5345
common_chat_tool_call tool_call;
5446
tool_call.name = name;
55-
tool_call.arguments = marker_idx != std::string::npos ? arguments.substr(0, marker_idx) : arguments;
47+
tool_call.arguments = arguments;
5648
tool_call.id = id;
5749

58-
if (tool_call.arguments == "\"") {
59-
// This happens because of completing `:"$magic` after `"arguments"`
60-
tool_call.arguments = "";
61-
}
6250
LOG_DBG("Tool call arguments:\n\traw: %s\n\tresult: %s\n", arguments.c_str(), tool_call.arguments.c_str());
6351
result_.tool_calls.emplace_back(tool_call);
6452
return true;
6553
}
66-
bool common_chat_msg_parser::add_tool_call(const json & tool_call, const common_healing_marker & healing_marker) {
54+
bool common_chat_msg_parser::add_tool_call(const json & tool_call) {
6755
std::string name = tool_call.contains("name") ? tool_call.at("name") : "";
6856
std::string id = tool_call.contains("id") ? tool_call.at("id") : "";
69-
std::string arguments = tool_call.contains("arguments") ? tool_call.at("arguments").dump() : "";
70-
return add_tool_call(name, id, arguments, healing_marker);
57+
std::string arguments = tool_call.contains("arguments") ? tool_call.at("arguments") : "";
58+
return add_tool_call(name, id, arguments);
7159
}
7260

73-
bool common_chat_msg_parser::add_tool_calls(const json & arr, const common_healing_marker & healing_marker) {
61+
bool common_chat_msg_parser::add_tool_calls(const json & arr) {
7462
for (const auto & item : arr) {
75-
if (!add_tool_call(item, healing_marker)) {
63+
if (!add_tool_call(item)) {
7664
return false;
7765
}
7866
}
@@ -121,30 +109,71 @@ bool common_chat_msg_parser::try_consume_literal(const std::string & literal) {
121109
return true;
122110
}
123111

112+
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_literal(const std::string & literal) {
113+
auto idx = input_.find(literal, pos_);
114+
if (idx != std::string::npos) {
115+
find_regex_result res;
116+
res.prelude = input_.substr(pos_, idx - pos_);
117+
auto end = idx + literal.size();
118+
res.groups.emplace_back(common_string_range{idx, end});
119+
move_to(end);
120+
return res;
121+
}
122+
if (is_partial_) {
123+
idx = string_find_partial_stop(input_, literal);
124+
if (idx != std::string::npos && idx >= pos_) {
125+
find_regex_result res;
126+
res.prelude = input_.substr(pos_, idx - pos_);
127+
auto end = input_.size();
128+
res.groups.emplace_back(common_string_range{idx, end});
129+
move_to(end);
130+
return res;
131+
}
132+
}
133+
return std::nullopt;
134+
}
135+
124136
void common_chat_msg_parser::consume_literal(const std::string & literal) {
125137
if (!try_consume_literal(literal)) {
126138
incomplete("Expected literal '" + literal + "' at position " + std::to_string(pos_));
127139
}
128140
}
129141

130-
void common_chat_msg_parser::try_consume_think_tags(const common_regex & start_think_regex, const common_regex & end_think_regex) {
142+
bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think, const std::string & end_think) {
143+
auto handle_reasoning = [&](const std::string & reasoning, bool closed) {
144+
if (syntax_.reasoning_in_content) {
145+
add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "<think>" : start_think);
146+
add_content(reasoning);
147+
if (closed) {
148+
add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "</think>" : end_think);
149+
}
150+
} else {
151+
add_reasoning_content(reasoning);
152+
}
153+
};
131154
if (syntax_.reasoning_format != COMMON_REASONING_FORMAT_NONE) {
132-
if (syntax_.thinking_forced_open || try_consume_regex(start_think_regex)) {
133-
if (auto res = try_find_regex(end_think_regex)) {
134-
result_.reasoning_content = res->prelude;
155+
if (syntax_.thinking_forced_open || try_consume_literal(start_think)) {
156+
if (auto res = try_find_literal(end_think)) {
157+
handle_reasoning(res->prelude, /* closed */ true);
135158
consume_spaces();
136-
} else {
137-
result_.reasoning_content = consume_rest();
138-
if (!syntax_.thinking_forced_open) {
139-
incomplete("Failed to find end of reasoning tag " + end_think_regex.str());
140-
}
141-
return;
159+
return true;
160+
}
161+
auto rest = consume_rest();
162+
if (!rest.empty()) {
163+
handle_reasoning(consume_rest(), /* closed */ !is_partial());
142164
}
143-
} else if (auto res = try_find_regex(end_think_regex)) {
144-
result_.reasoning_content = res->prelude;
165+
if (!syntax_.thinking_forced_open) {
166+
incomplete("Failed to find end of reasoning tag " + end_think);
167+
}
168+
return true;
169+
}
170+
if (auto res = try_find_literal(end_think)) {
171+
handle_reasoning(res->prelude, /* closed */ true);
145172
consume_spaces();
173+
return true;
146174
}
147175
}
176+
return false;
148177
}
149178

150179
std::string common_chat_msg_parser::consume_rest() {
@@ -195,19 +224,7 @@ std::optional<common_chat_msg_parser::consume_regex_result> common_chat_msg_pars
195224
return consume_regex_result{m.groups};
196225
}
197226

198-
// Calls the callback, *then* explodes w/ a partial match exception if it's partial
199-
common_json common_chat_msg_parser::consume_json(
200-
const std::vector<std::vector<std::string>> & args_paths
201-
) {
202-
if (auto result = try_consume_json(args_paths)) {
203-
return *result;
204-
}
205-
incomplete("Failed to consume JSON");
206-
}
207-
208-
std::optional<common_json> common_chat_msg_parser::try_consume_json(
209-
const std::vector<std::vector<std::string>> & args_paths
210-
) {
227+
std::optional<common_json> common_chat_msg_parser::try_consume_json() {
211228
auto it = input_.cbegin() + pos_;
212229
const auto end = input_.cend();
213230
common_json result;
@@ -222,45 +239,84 @@ std::optional<common_json> common_chat_msg_parser::try_consume_json(
222239
if (!is_partial()) {
223240
incomplete("JSON is incomplete");
224241
}
242+
return result;
243+
}
225244

226-
LOG_DBG("Parsed partial JSON: %s (json_healing_marker: %s)\n", result.json.dump().c_str(), result.healing_marker.json_dump_marker.c_str());
245+
common_json common_chat_msg_parser::consume_json() {
246+
if (auto result = try_consume_json()) {
247+
return *result;
248+
}
249+
incomplete("Failed to consume JSON");
250+
}
227251

228-
// Healing marker found, we need to visit the json and removed objects that we didn't want to heal
252+
nlohmann::ordered_json common_chat_msg_parser::consume_json_with_dumped_args(
253+
const std::vector<std::vector<std::string>> & args_paths
254+
) {
255+
if (auto result = try_consume_json_with_dumped_args(args_paths)) {
256+
return *result;
257+
}
258+
incomplete("Failed to consume JSON");
259+
}
260+
261+
std::optional<nlohmann::ordered_json> common_chat_msg_parser::try_consume_json_with_dumped_args(
262+
const std::vector<std::vector<std::string>> & args_paths
263+
) {
264+
auto partial = try_consume_json();
265+
if (!partial) {
266+
return std::nullopt;
267+
}
229268
auto is_arguments_path = [&](const std::vector<std::string> & path) {
230269
return std::find(args_paths.begin(), args_paths.end(), path) != args_paths.end();
231270
};
232271

272+
if (partial->healing_marker.marker.empty()) {
273+
if (args_paths.empty()) {
274+
// No arguments to dump, and JSON was parsed fully.
275+
return partial->json;
276+
}
277+
if (is_arguments_path({})) {
278+
// Entire JSON is the arguments and was parsed fully.
279+
return partial->json.dump();
280+
}
281+
}
282+
283+
LOG_DBG("Parsed partial JSON: %s (json_healing_marker: %s)\n", partial->json.dump().c_str(), partial->healing_marker.json_dump_marker.c_str());
284+
233285
std::vector<std::string> path;
234-
std::function<json(const json &)> remove_unsupported_healings = [&](const json & j) {
286+
std::function<json(const json &)> remove_unsupported_healings_and_dump_args = [&](const json & j) -> json {
287+
if (is_arguments_path(path)) {
288+
auto arguments = j.dump();
289+
if (is_partial() && !partial->healing_marker.marker.empty()) {
290+
auto idx = arguments.find(partial->healing_marker.json_dump_marker);
291+
if (idx != std::string::npos) {
292+
arguments.resize(idx);
293+
}
294+
if (arguments == "\"") {
295+
// This happens because of completing `:"$magic` after `"arguments"`
296+
arguments = "";
297+
}
298+
}
299+
return arguments;
300+
}
235301
if (j.is_object()) {
236302
auto obj = json::object();
237303
for (const auto & p : j.items()) {
238304
const auto & key = p.key();
239305
const auto & value = p.value();
240306
const std::string key_str = key; // NOLINT
241307
auto idx = key_str.find(healing_marker_);
242-
if (idx != std::string::npos) {//} && idx != 0) {
243-
// Don't heal keys halfway, cut just after their opening quotes
244-
obj[result.healing_marker.marker] = 1;
245-
if (idx != 0) {
246-
result.healing_marker.json_dump_marker = result.healing_marker.marker;
247-
}
308+
if (idx != std::string::npos) {
248309
break;
249310
}
250311
path.push_back(key_str);
251-
auto is_args = is_arguments_path(path);
252-
if (is_args) {
253-
obj[key] = value;
254-
} else if (value.is_string()) {
312+
if (value.is_string()) {
255313
const std::string value_str = value;
256-
if (value_str.find(healing_marker_) == std::string::npos) {
257-
obj[key] = value;
258-
} else {
259-
obj[result.healing_marker.marker] = 1;
260-
result.healing_marker.json_dump_marker = result.healing_marker.marker;
314+
if (value_str.find(healing_marker_) != std::string::npos) {
315+
break;
261316
}
317+
obj[key] = value;
262318
} else {
263-
obj[key] = remove_unsupported_healings(value);
319+
obj[key] = remove_unsupported_healings_and_dump_args(value);
264320
}
265321
path.pop_back();
266322
}
@@ -274,23 +330,19 @@ std::optional<common_json> common_chat_msg_parser::try_consume_json(
274330
auto idx = str.find(healing_marker_);
275331
if (idx != std::string::npos) {
276332
// Don't heal array values that aren't in the arguments.
277-
arr.push_back(result.healing_marker.marker);
278-
result.healing_marker.json_dump_marker = result.healing_marker.marker;
333+
// arr.push_back(partial->healing_marker.marker);
334+
// partial->healing_marker.json_dump_marker = partial->healing_marker.marker;
279335
break;
280336
}
281337
}
282-
arr.push_back(remove_unsupported_healings(value));
338+
arr.push_back(remove_unsupported_healings_and_dump_args(value));
283339
}
284340
return arr;
285341
}
286342
return j;
287343
};
288344

289-
if (!is_arguments_path({})) {
290-
auto cleaned = remove_unsupported_healings(result.json);
291-
LOG_DBG("Cleaned up JSON %s to %s (json_healing_marker : '%s')\n", result.json.dump().c_str(), cleaned.dump().c_str(), result.healing_marker.json_dump_marker.c_str());
292-
result.json = cleaned;
293-
}
294-
LOG_DBG("Half-healed json: %s\n", result.json.dump().c_str());
295-
return result;
345+
auto cleaned = remove_unsupported_healings_and_dump_args(partial->json);
346+
LOG_DBG("Cleaned up JSON %s to %s (json_healing_marker : '%s')\n", partial->json.dump().c_str(), cleaned.dump().c_str(), partial->healing_marker.json_dump_marker.c_str());
347+
return cleaned;
296348
}

common/chat-parser.h

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include "chat.h"
44
#include "json-partial.h"
5+
#include "json.hpp"
56
#include "regex-partial.h"
67

78
#include <optional>
@@ -53,13 +54,13 @@ class common_chat_msg_parser {
5354
void add_reasoning_content(const std::string & reasoning_content);
5455

5556
// Adds a tool call to the result. If the tool call is too incomplete (e.g. name empty), it won't add anything.
56-
bool add_tool_call(const std::string & name, const std::string & id, const std::string & arguments, const common_healing_marker & healing_marker);
57+
bool add_tool_call(const std::string & name, const std::string & id, const std::string & arguments);
5758

5859
// Adds a tool call using the "name", "id" and "arguments" fields of the json object
59-
bool add_tool_call(const nlohmann::ordered_json & tool_call, const common_healing_marker & healing_marker);
60+
bool add_tool_call(const nlohmann::ordered_json & tool_call);
6061

6162
// Adds an array of tool calls using their "name", "id" and "arguments" fields.
62-
bool add_tool_calls(const nlohmann::ordered_json & arr, const common_healing_marker & healing_marker);
63+
bool add_tool_calls(const nlohmann::ordered_json & arr);
6364

6465
void finish();
6566

@@ -68,11 +69,9 @@ class common_chat_msg_parser {
6869

6970
bool consume_spaces();
7071

71-
bool try_consume_literal(const std::string & literal);
72-
7372
void consume_literal(const std::string & literal);
7473

75-
void try_consume_think_tags(const common_regex & start_think_regex, const common_regex & end_think_regex);
74+
bool try_parse_reasoning(const std::string & start_think, const std::string & end_think);
7675

7776
std::string consume_rest();
7877

@@ -83,18 +82,24 @@ class common_chat_msg_parser {
8382

8483
std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos);
8584

85+
bool try_consume_literal(const std::string & literal);
86+
87+
std::optional<find_regex_result> try_find_literal(const std::string & literal);
88+
8689
struct consume_regex_result {
8790
std::vector<common_string_range> groups;
8891
};
8992
consume_regex_result consume_regex(const common_regex & regex);
9093

9194
std::optional<consume_regex_result> try_consume_regex(const common_regex & regex);
9295

93-
common_json consume_json(
96+
std::optional<common_json> try_consume_json();
97+
common_json consume_json();
98+
99+
nlohmann::ordered_json consume_json_with_dumped_args(
94100
const std::vector<std::vector<std::string>> & args_paths = {}
95101
);
96-
97-
std::optional<common_json> try_consume_json(
102+
std::optional<nlohmann::ordered_json> try_consume_json_with_dumped_args(
98103
const std::vector<std::vector<std::string>> & args_paths = {}
99104
);
100105
};

0 commit comments

Comments
 (0)