Skip to content

Commit 9462365

Browse files
author
ochafik
committed
refactor parser w/ optionals
1 parent cd3681d commit 9462365

File tree

5 files changed

+522
-427
lines changed

5 files changed

+522
-427
lines changed

common/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ add_library(${TARGET} STATIC
5858
base64.hpp
5959
chat.cpp
6060
chat.h
61+
chat-parser.cpp
62+
chat-parser.h
6163
common.cpp
6264
common.h
6365
console.cpp

common/chat-parser.cpp

Lines changed: 295 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,295 @@
1+
#include "chat-parser.h"
2+
#include "common.h"
3+
#include "log.h"
4+
// #include "json-partial.h"
5+
#include "regex-partial.h"
6+
7+
#include <cstdio>
8+
#include <optional>
9+
#include <stdexcept>
10+
#include <string>
11+
#include <vector>
12+
13+
// Shorthand used throughout this file for nlohmann's ordered_json type.
using json = nlohmann::ordered_json;
14+
15+
// Creates a parser over a copy of `input`.
//
// The result message is pre-filled with the "assistant" role, and a healing
// marker is chosen: a decimal rendering of std::rand() that does not occur
// anywhere in `input`.
common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, bool extract_reasoning)
    : input_(input), is_partial_(is_partial), extract_reasoning_(extract_reasoning)
{
    result_.role = "assistant";

    // Keep drawing candidates until one is found that is absent from the input.
    do {
        healing_marker_ = std::to_string(std::rand());
    } while (input.find(healing_marker_) != std::string::npos);
}
28+
29+
std::string common_chat_msg_parser::str(const common_string_range & rng) const {
30+
GGML_ASSERT(rng.begin <= rng.end);
31+
return input_.substr(rng.begin, rng.end - rng.begin);
32+
}
33+
34+
// Appends `content` to the content of the message being built.
void common_chat_msg_parser::add_content(const std::string &content) {
    result_.content.append(content);
}
37+
38+
// Appends `reasoning_content` to the reasoning content of the message being built.
void common_chat_msg_parser::add_reasoning_content(const std::string &reasoning_content) {
    result_.reasoning_content.append(reasoning_content);
}
41+
42+
bool common_chat_msg_parser::add_tool_call(const std::string & name, const std::string & id, const std::string & arguments, const common_healing_marker & healing_marker) {
43+
if (name.empty()) {
44+
return false;
45+
}
46+
47+
auto marker_idx = std::string::npos;
48+
if (!arguments.empty() && !healing_marker.marker.empty()) {
49+
marker_idx = arguments.find(healing_marker.json_dump_marker);
50+
if (marker_idx == std::string::npos) {
51+
marker_idx = arguments.find(healing_marker.marker);
52+
}
53+
}
54+
55+
common_chat_tool_call tool_call;
56+
tool_call.name = name;
57+
tool_call.arguments = marker_idx != std::string::npos ? arguments.substr(0, marker_idx) : arguments;
58+
tool_call.id = id;
59+
60+
if (tool_call.arguments == "\"") {
61+
// This happens because of completing `:"$magic` after `"arguments"`
62+
tool_call.arguments = "";
63+
}
64+
LOG_DBG("Tool call arguments:\n\traw: %s\n\tresult: %s\n", arguments.c_str(), tool_call.arguments.c_str());
65+
result_.tool_calls.emplace_back(tool_call);
66+
return true;
67+
}
68+
bool common_chat_msg_parser::add_tool_call(const json & tool_call, const common_healing_marker & healing_marker) {
69+
std::string name = tool_call.contains("name") ? tool_call.at("name") : "";
70+
std::string id = tool_call.contains("id") ? tool_call.at("id") : "";
71+
std::string arguments = tool_call.contains("arguments") ? tool_call.at("arguments").dump() : "";
72+
return add_tool_call(name, id, arguments, healing_marker);
73+
}
74+
75+
bool common_chat_msg_parser::add_tool_calls(const json & arr, const common_healing_marker & healing_marker) {
76+
for (const auto & item : arr) {
77+
if (!add_tool_call(item, healing_marker)) {
78+
return false;
79+
}
80+
}
81+
return true;
82+
}
83+
void common_chat_msg_parser::finish() {
84+
if (!is_partial_ && pos_ != input_.size()) {
85+
throw std::runtime_error("Unexpected content at end of input: " + input_.substr(pos_));
86+
}
87+
result_.reasoning_content = string_strip(result_.reasoning_content);
88+
if (!result_.tool_calls.empty()) {
89+
result_.content = string_strip(result_.content);
90+
}
91+
}
92+
93+
void common_chat_msg_parser::incomplete(const std::string & message) {
94+
if (is_partial_) {
95+
finish();
96+
}
97+
throw common_chat_msg_partial_exception(message);
98+
}
99+
100+
bool common_chat_msg_parser::consume_spaces() {
101+
const auto length = input_.size();
102+
auto consumed = false;
103+
while (pos_ < length && std::isspace(input_[pos_])) {
104+
++pos_;
105+
consumed = true;
106+
}
107+
return consumed;
108+
}
109+
110+
bool common_chat_msg_parser::try_consume_literal(const std::string & literal) {
111+
auto pos = pos_;
112+
for (auto i = 0u; i < literal.size(); ++i) {
113+
if (pos >= input_.size()) {
114+
return false;
115+
}
116+
if (input_[pos] != literal[i]) {
117+
return false;
118+
}
119+
++pos;
120+
}
121+
pos_ = pos;
122+
return true;
123+
}
124+
125+
void common_chat_msg_parser::consume_literal(const std::string & literal) {
126+
if (!try_consume_literal(literal)) {
127+
incomplete("Expected literal '" + literal + "' at position " + std::to_string(pos_));
128+
}
129+
}
130+
131+
void common_chat_msg_parser::try_consume_think_tags(const common_regex & start_think_regex, const common_regex & end_think_regex) {
132+
if (extract_reasoning_) {
133+
if (try_consume_regex(start_think_regex)) {
134+
if (auto res = try_find_regex(end_think_regex)) {
135+
result_.reasoning_content = res->prelude;
136+
consume_spaces();
137+
} else {
138+
result_.reasoning_content = consume_rest();
139+
incomplete("Failed to find end of reasoning tag " + end_think_regex.str());
140+
}
141+
} else if (auto res = try_find_regex(end_think_regex)) {
142+
result_.reasoning_content = res->prelude;
143+
consume_spaces();
144+
}
145+
}
146+
}
147+
148+
// Returns everything from the current position to the end of the input and
// moves the position to the end.
std::string common_chat_msg_parser::consume_rest() {
    std::string remainder = input_.substr(pos_);
    pos_ = input_.size();
    return remainder;
}
153+
154+
// Tries to find the regex, consumes it (pos right after it) and gives the prelude (right before it) and the groups to the callback.
155+
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_regex(const common_regex & regex) {
156+
auto m = regex.search(input_, pos_);
157+
if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
158+
return std::nullopt;
159+
}
160+
if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
161+
incomplete(regex.str());
162+
return std::nullopt;
163+
}
164+
auto prelude = input_.substr(pos_, m.groups[0].begin - pos_);
165+
pos_ = m.groups[0].end;
166+
167+
return find_regex_result{prelude, m.groups};
168+
}
169+
170+
// Consumes `regex` at the current position and returns its match groups.
// When the regex does not match, the input is reported as incomplete (which throws).
common_chat_msg_parser::consume_regex_result common_chat_msg_parser::consume_regex(const common_regex & regex) {
    auto matched = try_consume_regex(regex);
    if (!matched) {
        incomplete("Failed to consume regex: " + regex.str());
        return {};
    }
    return *matched;
}
177+
178+
std::optional<common_chat_msg_parser::consume_regex_result> common_chat_msg_parser::try_consume_regex(const common_regex & regex) {
179+
if (!regex.at_start()) {
180+
throw std::runtime_error("try_consume_regex requires a common_regex w/ at_start=true");
181+
}
182+
auto m = regex.search(input_, pos_);
183+
if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
184+
return std::nullopt;
185+
}
186+
if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
187+
incomplete(regex.str());
188+
return std::nullopt;
189+
}
190+
pos_ = m.groups[0].end;
191+
192+
return consume_regex_result{m.groups};
193+
}
194+
195+
// Calls the callback, *then* explodes w/ a partial match exception if it's partial
196+
common_json common_chat_msg_parser::consume_json(
197+
const std::vector<std::vector<std::string>> & args_paths
198+
) {
199+
if (auto result = try_consume_json(args_paths)) {
200+
return *result;
201+
}
202+
incomplete("Failed to consume JSON");
203+
return {};
204+
}
205+
206+
std::optional<common_json> common_chat_msg_parser::try_consume_json(
207+
const std::vector<std::vector<std::string>> & args_paths
208+
) {
209+
auto it = input_.cbegin() + pos_;
210+
const auto end = input_.cend();
211+
common_json result;
212+
if (!common_json_parse(it, end, healing_marker_, result)) {
213+
return std::nullopt;
214+
}
215+
pos_ = std::distance(input_.cbegin(), it);
216+
if (result.healing_marker.marker.empty()) {
217+
// No healing marker, just return the parsed json
218+
return result;
219+
}
220+
if (!is_partial_) {
221+
incomplete("JSON is incomplete");
222+
return std::nullopt; // Actually unreachable
223+
}
224+
225+
LOG_DBG("Parsed partial JSON: %s (json_healing_marker: %s)\n", result.json.dump().c_str(), result.healing_marker.json_dump_marker.c_str());
226+
227+
// Healing marker found, we need to visit the json and removed objects that we didn't want to heal
228+
auto is_arguments_path = [&](const std::vector<std::string> & path) {
229+
return std::find(args_paths.begin(), args_paths.end(), path) != args_paths.end();
230+
};
231+
232+
std::vector<std::string> path;
233+
std::function<json(const json &)> remove_unsupported_healings = [&](const json & j) {
234+
if (j.is_object()) {
235+
auto obj = json::object();
236+
for (const auto & p : j.items()) {
237+
const auto & key = p.key();
238+
const auto & value = p.value();
239+
const std::string key_str = key; // NOLINT
240+
auto idx = key_str.find(healing_marker_);
241+
if (idx != std::string::npos) {//} && idx != 0) {
242+
// Don't heal keys halfway, cut just after their opening quotes
243+
obj[result.healing_marker.marker] = 1;
244+
if (idx != 0) {
245+
result.healing_marker.json_dump_marker = result.healing_marker.marker;
246+
}
247+
break;
248+
}
249+
path.push_back(key_str);
250+
auto is_args = is_arguments_path(path);
251+
if (is_args) {
252+
obj[key] = value;
253+
} else if (value.is_string()) {
254+
const std::string value_str = value;
255+
if (value_str.find(healing_marker_) == std::string::npos) {
256+
obj[key] = value;
257+
} else {
258+
obj[result.healing_marker.marker] = 1;
259+
result.healing_marker.json_dump_marker = result.healing_marker.marker;
260+
}
261+
} else {
262+
obj[key] = remove_unsupported_healings(value);
263+
}
264+
path.pop_back();
265+
}
266+
return obj;
267+
}
268+
if (j.is_array()) {
269+
auto arr = json::array();
270+
for (const auto & value : j) {
271+
if (value.is_string()) {
272+
std::string str = value;
273+
auto idx = str.find(healing_marker_);
274+
if (idx != std::string::npos) {
275+
// Don't heal array values that aren't in the arguments.
276+
arr.push_back(result.healing_marker.marker);
277+
result.healing_marker.json_dump_marker = result.healing_marker.marker;
278+
break;
279+
}
280+
}
281+
arr.push_back(remove_unsupported_healings(value));
282+
}
283+
return arr;
284+
}
285+
return j;
286+
};
287+
288+
if (!is_arguments_path({})) {
289+
auto cleaned = remove_unsupported_healings(result.json);
290+
LOG_DBG("Cleaned up JSON %s to %s (json_healing_marker : '%s')\n", result.json.dump().c_str(), cleaned.dump().c_str(), result.healing_marker.json_dump_marker.c_str());
291+
result.json = cleaned;
292+
}
293+
LOG_DBG("Half-healed json: %s\n", result.json.dump().c_str());
294+
return result;
295+
}

common/chat-parser.h

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
#pragma once
2+
3+
#include "chat.h"
4+
#include "json-partial.h"
5+
#include "regex-partial.h"
6+
7+
#include <optional>
8+
#include <string>
9+
#include <vector>
10+
11+
// A list of string ranges; used below to hold the match groups of a regex match.
using common_string_ranges = std::vector<common_string_range>;
12+
13+
// Exception thrown by common_chat_msg_parser::incomplete() when the input ends
// before the message could be fully parsed.
class common_chat_msg_partial_exception : public std::runtime_error {
  public:
    common_chat_msg_partial_exception(const std::string & message)
        : std::runtime_error(message)
    {
    }
};
17+
18+
class common_chat_msg_parser {
19+
std::string input_;
20+
bool is_partial_;
21+
bool extract_reasoning_;
22+
size_t pos_ = 0;
23+
common_chat_msg result_;
24+
std::string healing_marker_;
25+
26+
public:
27+
common_chat_msg_parser(const std::string & input, bool is_partial, bool extract_reasoning);
28+
29+
const std::string & input() const { return input_; }
30+
const std::string & healing_marker() const { return healing_marker_; }
31+
const bool & is_partial() const { return is_partial_; }
32+
const bool & extract_reasoning() const { return extract_reasoning_; }
33+
const common_chat_msg & result() const { return result_; }
34+
35+
void move_to(size_t pos) {
36+
if (pos > input_.size()) {
37+
throw std::runtime_error("Invalid position!");
38+
}
39+
pos_ = pos;
40+
}
41+
void move_back(size_t n) {
42+
if (pos_ < n) {
43+
throw std::runtime_error("Can't move back that far!");
44+
}
45+
pos_ -= n;
46+
}
47+
48+
std::string str(const common_string_range & rng) const;
49+
50+
void add_content(const std::string & content);
51+
void add_reasoning_content(const std::string & reasoning_content);
52+
53+
bool add_tool_call(const std::string & name, const std::string & id, const std::string & arguments, const common_healing_marker & healing_marker);
54+
bool add_tool_call(const nlohmann::ordered_json & tool_call, const common_healing_marker & healing_marker);
55+
bool add_tool_calls(const nlohmann::ordered_json & arr, const common_healing_marker & healing_marker);
56+
57+
void finish();
58+
59+
void incomplete(const std::string & message);
60+
61+
bool consume_spaces();
62+
63+
bool try_consume_literal(const std::string & literal);
64+
65+
void consume_literal(const std::string & literal);
66+
67+
void try_consume_think_tags(const common_regex & start_think_regex, const common_regex & end_think_regex);
68+
69+
std::string consume_rest();
70+
71+
struct find_regex_result {
72+
std::string prelude;
73+
common_string_ranges groups;
74+
};
75+
76+
std::optional<find_regex_result> try_find_regex(const common_regex & regex);
77+
78+
struct consume_regex_result {
79+
common_string_ranges groups;
80+
};
81+
consume_regex_result consume_regex(const common_regex & regex);
82+
83+
std::optional<consume_regex_result> try_consume_regex(const common_regex & regex);
84+
85+
common_json consume_json(
86+
const std::vector<std::vector<std::string>> & args_paths = {}
87+
);
88+
89+
std::optional<common_json> try_consume_json(
90+
const std::vector<std::vector<std::string>> & args_paths = {}
91+
);
92+
};

0 commit comments

Comments
 (0)