Skip to content

Commit 6dcff43

Browse files
author
ochafik
committed
add common_json w/ support for truncated json healing
1 parent 16c9c63 commit 6dcff43

File tree

5 files changed

+342
-0
lines changed

5 files changed

+342
-0
lines changed

common/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ add_library(${TARGET} STATIC
6464
console.h
6565
json-schema-to-grammar.cpp
6666
json.hpp
67+
json-partial.h
68+
json-partial.cpp
6769
llguidance.cpp
6870
log.cpp
6971
log.h

common/json-partial.cpp

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
#include <json-partial.h>
2+
#include "ggml.h"
3+
#include "log.h"
4+
#include <string>
5+
6+
#include <json.hpp>
7+
8+
using json = nlohmann::ordered_json;
9+
10+
// Kind of entry kept on the parser's open-container stack while locating a
// JSON parse error (see json_error_locator in common_json_parse).
enum common_json_stack_element_type {
11+
// Pushed by start_object(), popped by end_object().
COMMON_JSON_STACK_ELEMENT_OBJECT,
12+
// Pushed by key(); popped as soon as the value belonging to that key is complete.
COMMON_JSON_STACK_ELEMENT_KEY,
13+
// Pushed by start_array(), popped by end_array().
COMMON_JSON_STACK_ELEMENT_ARRAY,
14+
};
15+
16+
// One entry of the open-container stack recorded during SAX parsing.
struct common_json_stack_element {
17+
// What this entry represents (object, pending key, or array).
common_json_stack_element_type type;
18+
// The key text for KEY entries; OBJECT and ARRAY entries are pushed with "".
std::string key;
19+
};
20+
21+
bool common_json_parse(
22+
const std::string & input,
23+
const std::string & healing_marker,
24+
common_json & out)
25+
{
26+
std::string::const_iterator it = input.begin();
27+
const auto end = input.end();
28+
return common_json_parse(it, end, healing_marker, out);
29+
}
30+
31+
bool common_json_parse(
32+
std::string::const_iterator & it,
33+
const std::string::const_iterator & end,
34+
const std::string & healing_marker,
35+
common_json & out)
36+
{
37+
// // https://json.nlohmann.me/features/parsing/sax_interface/
38+
struct json_error_locator : public nlohmann::json_sax<json> {
39+
std::size_t position;
40+
bool found_error;
41+
std::string last_token;
42+
std::string exception_message;
43+
std::vector<common_json_stack_element> stack;
44+
45+
json_error_locator() : position(0), found_error(false) {}
46+
47+
bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override { // NOLINT
48+
this->position = position - 1;
49+
this->found_error = true;
50+
this->last_token = last_token;
51+
this->exception_message = ex.what();
52+
return false;
53+
}
54+
void close_value() {
55+
if (!stack.empty() && (stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY)) {
56+
stack.pop_back();
57+
}
58+
}
59+
bool null() override { // NOLINT
60+
close_value();
61+
return true;
62+
}
63+
bool boolean(bool) override { // NOLINT
64+
close_value();
65+
return true;
66+
}
67+
bool number_integer(number_integer_t) override { // NOLINT
68+
close_value();
69+
return true;
70+
}
71+
bool number_unsigned(number_unsigned_t) override { // NOLINT
72+
close_value();
73+
return true;
74+
}
75+
bool number_float(number_float_t, const string_t &) override { // NOLINT
76+
close_value();
77+
return true;
78+
}
79+
bool string(string_t &) override { // NOLINT
80+
close_value();
81+
return true;
82+
}
83+
bool binary(binary_t &) override { // NOLINT
84+
close_value();
85+
return true;
86+
}
87+
bool start_object(std::size_t) override { // NOLINT
88+
stack.push_back({COMMON_JSON_STACK_ELEMENT_OBJECT, ""});
89+
return true;
90+
}
91+
bool end_object() override {
92+
GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT);
93+
stack.pop_back();
94+
close_value();
95+
return true;
96+
}
97+
bool key(string_t & key) override { // NOLINT
98+
stack.push_back({COMMON_JSON_STACK_ELEMENT_KEY, key});
99+
return true;
100+
}
101+
bool start_array(std::size_t) override { // NOLINT
102+
stack.push_back({COMMON_JSON_STACK_ELEMENT_ARRAY, ""});
103+
return true;
104+
}
105+
bool end_array() override {
106+
GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY);
107+
stack.pop_back();
108+
close_value();
109+
return true;
110+
}
111+
};
112+
json_error_locator err_loc;
113+
auto start = it;
114+
json::sax_parse(it, end, &err_loc);
115+
116+
// std::string::const_iterator temptative_end;
117+
if (err_loc.found_error) {
118+
it = start;
119+
auto temptative_end = it + err_loc.position;
120+
// fprintf(stderr, "Error at position %zu (is_end = %s): %s\n", err_loc.position, temptative_end == end ? "true" : "false", err_loc.exception_message.c_str());
121+
122+
auto input = std::string(it, temptative_end);
123+
try {
124+
out.json = json::parse(input);
125+
// out.json = json::parse(it, temptative_end);
126+
it = temptative_end;
127+
return true;
128+
} catch (const std::exception & ex) {
129+
// No, needs healing.
130+
LOG_DBG("Failed to parse up to error: %s: <<<%s>>>\n", ex.what(), std::string(it, temptative_end).c_str());
131+
}
132+
auto can_parse = [](const std::string & str) {
133+
try {
134+
auto _ = json::parse(str); // NOLINT
135+
return true;
136+
} catch (const std::exception &) {
137+
return false;
138+
}
139+
};
140+
if (!healing_marker.empty() && !err_loc.stack.empty()) {
141+
std::string str(it, temptative_end);
142+
auto last_non_sp_pos = str.find_last_not_of(" \n\r\t");
143+
if (last_non_sp_pos == std::string::npos) {
144+
throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
145+
}
146+
auto last_non_sp_char = str[last_non_sp_pos];
147+
148+
std::string closing;
149+
for (size_t i = err_loc.stack.size(); i > 0; i--) {
150+
auto & el = err_loc.stack[i - 1];
151+
if (el.type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
152+
closing += "}";
153+
} else if (el.type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
154+
closing += "]";
155+
} else if (el.type != COMMON_JSON_STACK_ELEMENT_KEY) {
156+
throw std::runtime_error("Unexpected stack element type");
157+
}
158+
}
159+
160+
const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$";
161+
162+
if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
163+
// We're inside an object value
164+
if (last_non_sp_char == ':') {
165+
// Was about to create an object value
166+
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
167+
} else if (can_parse(str + ": 1" + closing)) {
168+
str += (out.healing_marker.json_dump_marker = ":\"" + magic_seed) + "\"" + closing;
169+
} else if (last_non_sp_char == '{') {
170+
// Was about to create an object
171+
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
172+
} else if (can_parse(str + "\"" + closing)) {
173+
// Was inside an object value string
174+
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
175+
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
176+
// Was inside an object value string after an escape
177+
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
178+
} else {
179+
// find last :
180+
auto last_pos = str.find_last_of(':');
181+
if (last_pos == std::string::npos) {
182+
throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
183+
}
184+
// Cutting back to opening : for object value
185+
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
186+
}
187+
} else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
188+
if (last_non_sp_char == ',' || last_non_sp_char == '[') {
189+
// Was about to create an array value
190+
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
191+
} else if (can_parse(str + "\"" + closing)) {
192+
// Was inside an array value string
193+
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
194+
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
195+
// Was inside an array value string after an escape
196+
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
197+
} else if (!std::isdigit(last_non_sp_char) && last_non_sp_char != '.' && last_non_sp_char != 'e' && last_non_sp_char != 'E' && last_non_sp_char != '-' && can_parse(str + ", 1" + closing)) {
198+
// Had just finished a value
199+
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
200+
} else {
201+
auto last_pos = str.find_last_of("[,");
202+
if (last_pos == std::string::npos) {
203+
throw std::runtime_error("Cannot heal a truncated JSON array stopped in an unknown location");
204+
}
205+
// Cutting back to last [ or , for array value
206+
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
207+
}
208+
} else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
209+
if (last_non_sp_char == ',' || last_non_sp_char == '{') {
210+
// Was about to create an object key+value
211+
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
212+
} else if (can_parse(str + ",\"\": 1" + closing)) {
213+
// Was about to create an object key+value
214+
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\": 1" + closing;
215+
} else if (can_parse(str + "\": 1" + closing)) {
216+
// Was inside an object key string
217+
str += (out.healing_marker.json_dump_marker = magic_seed) + "\": 1" + closing;
218+
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
219+
// Was inside an object key string after an escape
220+
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
221+
} else {
222+
auto last_pos = str.find_last_of(':');
223+
if (last_pos == std::string::npos) {
224+
throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
225+
}
226+
// fprintf(stderr, "Cutting back to last : for object key+value\n");
227+
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
228+
}
229+
} else {
230+
throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
231+
}
232+
// fprintf(stderr, "HEALED:\nSTRING <<<\n%s\n>>>\n\nmagic_cut: <<<\n%s\n>>>\n\n", str.c_str(), out.healing_marker.json_dump_marker.c_str());
233+
out.json = json::parse(str);
234+
it = temptative_end;
235+
return true;
236+
}
237+
// TODO: handle unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
238+
// fprintf(stderr, "Closing: TODO\n");
239+
return false;
240+
}
241+
out.json = json::parse(it, end);
242+
it = end;
243+
return true;
244+
}

common/json-partial.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#pragma once
2+
#include <json.hpp>
3+
4+
// Describes the marker that common_json_parse inserted when it had to complete
// a truncated JSON document.
struct common_healing_marker {
5+
// The marker string supplied by the caller of common_json_parse.
std::string marker;
6+
// The marker preceded by whatever extra characters were inserted right before
// it (e.g. a quote, comma, colon or backslash). The tests locate this text in
// out.json.dump() and cut the dump there. Left as-is when no healing happened.
std::string json_dump_marker;
7+
};
8+
9+
// Result of common_json_parse: the parsed document plus healing information.
struct common_json {
10+
// The parsed JSON; when healing took place it contains the healing marker.
nlohmann::ordered_json json;
11+
// Filled in by common_json_parse only when the input had to be healed.
common_healing_marker healing_marker;
12+
};
13+
14+
// Parses `input` as JSON into `out`, completing ("healing") a document that was
// truncated inside an object or array by inserting `healing_marker`.
// Returns true when `out.json` was produced, false when the input has an error
// that cannot be healed (empty `healing_marker`, or the error is not inside an
// object/array). May throw std::runtime_error if no healing strategy applies.
bool common_json_parse(
15+
const std::string & input,
16+
const std::string & healing_marker,
17+
common_json & out);
18+
19+
// Iterator variant of common_json_parse: parses JSON from [it, end).
// On success `it` is advanced to where parsing stopped: `end` for well-formed
// input, or the position of the first parse error when only a prefix was
// parsed or the document was healed. On failure (false) `it` is left at its
// starting position. May throw std::runtime_error if no healing strategy applies.
bool common_json_parse(
20+
std::string::const_iterator & it,
21+
const std::string::const_iterator & end,
22+
const std::string & healing_marker,
23+
common_json & out);

tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ endif()
133133
llama_target_and_test(test-log.cpp)
134134
llama_target_and_test(test-arg-parser.cpp)
135135
llama_target_and_test(test-chat-template.cpp)
136+
llama_target_and_test(test-json-partial.cpp)
136137
llama_target_and_test(test-regex-partial.cpp)
137138

138139
# llama_target_and_test(test-opt.cpp) # SLOW

tests/test-json-partial.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
#include "common.h"
2+
#include "json-partial.h"
3+
#include <exception>
4+
#include <iostream>
5+
#include <stdexcept>
6+
7+
// Test helper: returns quietly when both values compare equal; otherwise
// reports the two values on stderr and throws std::runtime_error("Test failed").
template <class T> static void assert_equals(const T & expected, const T & actual) {
    const bool mismatch = expected != actual;
    if (!mismatch) {
        return;
    }
    std::cerr << "Expected: " << expected << std::endl;
    std::cerr << "Actual: " << actual << std::endl;
    std::cerr << std::flush;
    throw std::runtime_error("Test failed");
}
15+
16+
// Exercises common_json_parse on truncated JSON: first on every proper prefix
// of a few documents (which must all parse or heal), then on a handful of
// inputs whose healed dump is checked exactly.
static void test_json_healing() {
17+
// Parses `str` with a fixed marker, logs the outcome to stderr, and throws if
// parsing fails or if the reported json_dump_marker cannot be found in the dump.
auto parse = [](const std::string & str) {
18+
std::cerr << "# Parsing: " << str << '\n';
19+
std::string::const_iterator it = str.begin();
20+
const auto end = str.end();
21+
common_json out;
22+
std::string healing_marker = "$llama.cpp.json$";
23+
if (common_json_parse(it, end, healing_marker, out)) {
24+
auto dump = out.json.dump();
25+
std::cerr << "Parsed: " << dump << '\n';
26+
std::cerr << "Magic: " << out.healing_marker.json_dump_marker << '\n';
27+
// `result` is the part of the dump that precedes the inserted marker
// (or the whole dump when nothing was healed).
std::string result;
28+
if (!out.healing_marker.json_dump_marker.empty()) {
29+
auto i = dump.find(out.healing_marker.json_dump_marker);
30+
if (i == std::string::npos) {
31+
throw std::runtime_error("Failed to find magic in dump " + dump + " (magic: " + out.healing_marker.json_dump_marker + ")");
32+
}
33+
result = dump.substr(0, i);
34+
} else {
35+
result = dump;
36+
}
37+
std::cerr << "Result: " << result << '\n';
38+
// NOTE(review): this logs "Failure!" when `str` DOES start with `result`,
// which looks inverted, and it only logs rather than failing the test.
// Since dump() drops the whitespace present in `str`, a strict prefix
// comparison would not hold for every input either - confirm the intent.
if (string_starts_with(str, result)) {
39+
std::cerr << "Failure!\n";
40+
}
41+
// return dump;
42+
} else {
43+
throw std::runtime_error("Failed to parse: " + str);
44+
}
45+
46+
};
47+
// Runs `parse` on the prefixes of `str` of length 1 .. size-1; the complete
// string itself is not parsed by this loop.
auto parse_all = [&](const std::string & str) {
48+
for (size_t i = 1; i < str.size(); i++) {
49+
parse(str.substr(0, i));
50+
}
51+
};
52+
parse_all("{\"a\": \"b\"}");
53+
parse_all("{\"hey\": 1, \"ho\\\"ha\": [1]}");
54+
55+
parse_all("[{\"a\": \"b\"}]");
56+
57+
// The same `out` object is reused for the exact-output checks below.
common_json out;
58+
assert_equals(true, common_json_parse("[{\"a\": \"b\"}", "$foo", out));
59+
assert_equals<std::string>("[{\"a\":\"b\"},\"$foo\"]", out.json.dump());
60+
61+
assert_equals(true, common_json_parse("{ \"code", "$foo", out));
62+
assert_equals<std::string>("{\"code$foo\":1}", out.json.dump());
63+
assert_equals<std::string>("$foo", out.healing_marker.json_dump_marker);
64+
65+
assert_equals(true, common_json_parse("{ \"code\"", "$foo", out));
66+
assert_equals<std::string>("{\"code\":\"$foo\"}", out.json.dump());
67+
}
68+
69+
// Test entry point: runs the healing tests and returns 0. Failures are
// signalled by exceptions thrown from the tests, which are not caught here.
int main() {
70+
test_json_healing();
71+
return 0;
72+
}

0 commit comments

Comments
 (0)