|
| 1 | +#include <json-partial.h> |
| 2 | +#include "ggml.h" |
| 3 | +#include "log.h" |
| 4 | +#include <string> |
| 5 | + |
| 6 | +#include <json.hpp> |
| 7 | + |
| 8 | +using json = nlohmann::ordered_json; |
| 9 | + |
// Kind of syntactic context that is still open while scanning a JSON document.
enum common_json_stack_element_type {
    // Inside `{ ... }`, at a position where a key (or the closing brace) comes next.
    COMMON_JSON_STACK_ELEMENT_OBJECT = 0,
    // A key has been read and its value has not been completed yet.
    COMMON_JSON_STACK_ELEMENT_KEY    = 1,
    // Inside `[ ... ]`.
    COMMON_JSON_STACK_ELEMENT_ARRAY  = 2,
};
| 15 | + |
| 16 | +struct common_json_stack_element { |
| 17 | + common_json_stack_element_type type; |
| 18 | + std::string key; |
| 19 | +}; |
| 20 | + |
| 21 | +bool common_json_parse( |
| 22 | + const std::string & input, |
| 23 | + const std::string & healing_marker, |
| 24 | + common_json & out) |
| 25 | +{ |
| 26 | + std::string::const_iterator it = input.begin(); |
| 27 | + const auto end = input.end(); |
| 28 | + return common_json_parse(it, end, healing_marker, out); |
| 29 | +} |
| 30 | + |
| 31 | +bool common_json_parse( |
| 32 | + std::string::const_iterator & it, |
| 33 | + const std::string::const_iterator & end, |
| 34 | + const std::string & healing_marker, |
| 35 | + common_json & out) |
| 36 | +{ |
| 37 | + // // https://json.nlohmann.me/features/parsing/sax_interface/ |
| 38 | + struct json_error_locator : public nlohmann::json_sax<json> { |
| 39 | + std::size_t position; |
| 40 | + bool found_error; |
| 41 | + std::string last_token; |
| 42 | + std::string exception_message; |
| 43 | + std::vector<common_json_stack_element> stack; |
| 44 | + |
| 45 | + json_error_locator() : position(0), found_error(false) {} |
| 46 | + |
| 47 | + bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override { // NOLINT |
| 48 | + this->position = position - 1; |
| 49 | + this->found_error = true; |
| 50 | + this->last_token = last_token; |
| 51 | + this->exception_message = ex.what(); |
| 52 | + return false; |
| 53 | + } |
| 54 | + void close_value() { |
| 55 | + if (!stack.empty() && (stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY)) { |
| 56 | + stack.pop_back(); |
| 57 | + } |
| 58 | + } |
| 59 | + bool null() override { // NOLINT |
| 60 | + close_value(); |
| 61 | + return true; |
| 62 | + } |
| 63 | + bool boolean(bool) override { // NOLINT |
| 64 | + close_value(); |
| 65 | + return true; |
| 66 | + } |
| 67 | + bool number_integer(number_integer_t) override { // NOLINT |
| 68 | + close_value(); |
| 69 | + return true; |
| 70 | + } |
| 71 | + bool number_unsigned(number_unsigned_t) override { // NOLINT |
| 72 | + close_value(); |
| 73 | + return true; |
| 74 | + } |
| 75 | + bool number_float(number_float_t, const string_t &) override { // NOLINT |
| 76 | + close_value(); |
| 77 | + return true; |
| 78 | + } |
| 79 | + bool string(string_t &) override { // NOLINT |
| 80 | + close_value(); |
| 81 | + return true; |
| 82 | + } |
| 83 | + bool binary(binary_t &) override { // NOLINT |
| 84 | + close_value(); |
| 85 | + return true; |
| 86 | + } |
| 87 | + bool start_object(std::size_t) override { // NOLINT |
| 88 | + stack.push_back({COMMON_JSON_STACK_ELEMENT_OBJECT, ""}); |
| 89 | + return true; |
| 90 | + } |
| 91 | + bool end_object() override { |
| 92 | + GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT); |
| 93 | + stack.pop_back(); |
| 94 | + close_value(); |
| 95 | + return true; |
| 96 | + } |
| 97 | + bool key(string_t & key) override { // NOLINT |
| 98 | + stack.push_back({COMMON_JSON_STACK_ELEMENT_KEY, key}); |
| 99 | + return true; |
| 100 | + } |
| 101 | + bool start_array(std::size_t) override { // NOLINT |
| 102 | + stack.push_back({COMMON_JSON_STACK_ELEMENT_ARRAY, ""}); |
| 103 | + return true; |
| 104 | + } |
| 105 | + bool end_array() override { |
| 106 | + GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY); |
| 107 | + stack.pop_back(); |
| 108 | + close_value(); |
| 109 | + return true; |
| 110 | + } |
| 111 | + }; |
| 112 | + json_error_locator err_loc; |
| 113 | + auto start = it; |
| 114 | + json::sax_parse(it, end, &err_loc); |
| 115 | + |
| 116 | + // std::string::const_iterator temptative_end; |
| 117 | + if (err_loc.found_error) { |
| 118 | + it = start; |
| 119 | + auto temptative_end = it + err_loc.position; |
| 120 | + // fprintf(stderr, "Error at position %zu (is_end = %s): %s\n", err_loc.position, temptative_end == end ? "true" : "false", err_loc.exception_message.c_str()); |
| 121 | + |
| 122 | + auto input = std::string(it, temptative_end); |
| 123 | + try { |
| 124 | + out.json = json::parse(input); |
| 125 | + // out.json = json::parse(it, temptative_end); |
| 126 | + it = temptative_end; |
| 127 | + return true; |
| 128 | + } catch (const std::exception & ex) { |
| 129 | + // No, needs healing. |
| 130 | + LOG_DBG("Failed to parse up to error: %s: <<<%s>>>\n", ex.what(), std::string(it, temptative_end).c_str()); |
| 131 | + } |
| 132 | + auto can_parse = [](const std::string & str) { |
| 133 | + try { |
| 134 | + auto _ = json::parse(str); // NOLINT |
| 135 | + return true; |
| 136 | + } catch (const std::exception &) { |
| 137 | + return false; |
| 138 | + } |
| 139 | + }; |
| 140 | + if (!healing_marker.empty() && !err_loc.stack.empty()) { |
| 141 | + std::string str(it, temptative_end); |
| 142 | + auto last_non_sp_pos = str.find_last_not_of(" \n\r\t"); |
| 143 | + if (last_non_sp_pos == std::string::npos) { |
| 144 | + throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location"); |
| 145 | + } |
| 146 | + auto last_non_sp_char = str[last_non_sp_pos]; |
| 147 | + |
| 148 | + std::string closing; |
| 149 | + for (size_t i = err_loc.stack.size(); i > 0; i--) { |
| 150 | + auto & el = err_loc.stack[i - 1]; |
| 151 | + if (el.type == COMMON_JSON_STACK_ELEMENT_OBJECT) { |
| 152 | + closing += "}"; |
| 153 | + } else if (el.type == COMMON_JSON_STACK_ELEMENT_ARRAY) { |
| 154 | + closing += "]"; |
| 155 | + } else if (el.type != COMMON_JSON_STACK_ELEMENT_KEY) { |
| 156 | + throw std::runtime_error("Unexpected stack element type"); |
| 157 | + } |
| 158 | + } |
| 159 | + |
| 160 | + const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$"; |
| 161 | + |
| 162 | + if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) { |
| 163 | + // We're inside an object value |
| 164 | + if (last_non_sp_char == ':') { |
| 165 | + // Was about to create an object value |
| 166 | + str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing; |
| 167 | + } else if (can_parse(str + ": 1" + closing)) { |
| 168 | + str += (out.healing_marker.json_dump_marker = ":\"" + magic_seed) + "\"" + closing; |
| 169 | + } else if (last_non_sp_char == '{') { |
| 170 | + // Was about to create an object |
| 171 | + str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing; |
| 172 | + } else if (can_parse(str + "\"" + closing)) { |
| 173 | + // Was inside an object value string |
| 174 | + str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing; |
| 175 | + } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) { |
| 176 | + // Was inside an object value string after an escape |
| 177 | + str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing; |
| 178 | + } else { |
| 179 | + // find last : |
| 180 | + auto last_pos = str.find_last_of(':'); |
| 181 | + if (last_pos == std::string::npos) { |
| 182 | + throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location"); |
| 183 | + } |
| 184 | + // Cutting back to opening : for object value |
| 185 | + str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing; |
| 186 | + } |
| 187 | + } else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY) { |
| 188 | + if (last_non_sp_char == ',' || last_non_sp_char == '[') { |
| 189 | + // Was about to create an array value |
| 190 | + str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing; |
| 191 | + } else if (can_parse(str + "\"" + closing)) { |
| 192 | + // Was inside an array value string |
| 193 | + str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing; |
| 194 | + } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) { |
| 195 | + // Was inside an array value string after an escape |
| 196 | + str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing; |
| 197 | + } else if (!std::isdigit(last_non_sp_char) && last_non_sp_char != '.' && last_non_sp_char != 'e' && last_non_sp_char != 'E' && last_non_sp_char != '-' && can_parse(str + ", 1" + closing)) { |
| 198 | + // Had just finished a value |
| 199 | + str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing; |
| 200 | + } else { |
| 201 | + auto last_pos = str.find_last_of("[,"); |
| 202 | + if (last_pos == std::string::npos) { |
| 203 | + throw std::runtime_error("Cannot heal a truncated JSON array stopped in an unknown location"); |
| 204 | + } |
| 205 | + // Cutting back to last [ or , for array value |
| 206 | + str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing; |
| 207 | + } |
| 208 | + } else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT) { |
| 209 | + if (last_non_sp_char == ',' || last_non_sp_char == '{') { |
| 210 | + // Was about to create an object key+value |
| 211 | + str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing; |
| 212 | + } else if (can_parse(str + ",\"\": 1" + closing)) { |
| 213 | + // Was about to create an object key+value |
| 214 | + str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\": 1" + closing; |
| 215 | + } else if (can_parse(str + "\": 1" + closing)) { |
| 216 | + // Was inside an object key string |
| 217 | + str += (out.healing_marker.json_dump_marker = magic_seed) + "\": 1" + closing; |
| 218 | + } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) { |
| 219 | + // Was inside an object key string after an escape |
| 220 | + str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing; |
| 221 | + } else { |
| 222 | + auto last_pos = str.find_last_of(':'); |
| 223 | + if (last_pos == std::string::npos) { |
| 224 | + throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location"); |
| 225 | + } |
| 226 | + // fprintf(stderr, "Cutting back to last : for object key+value\n"); |
| 227 | + str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing; |
| 228 | + } |
| 229 | + } else { |
| 230 | + throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location"); |
| 231 | + } |
| 232 | + // fprintf(stderr, "HEALED:\nSTRING <<<\n%s\n>>>\n\nmagic_cut: <<<\n%s\n>>>\n\n", str.c_str(), out.healing_marker.json_dump_marker.c_str()); |
| 233 | + out.json = json::parse(str); |
| 234 | + it = temptative_end; |
| 235 | + return true; |
| 236 | + } |
| 237 | + // TODO: handle unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...) |
| 238 | + // fprintf(stderr, "Closing: TODO\n"); |
| 239 | + return false; |
| 240 | + } |
| 241 | + out.json = json::parse(it, end); |
| 242 | + it = end; |
| 243 | + return true; |
| 244 | +} |
0 commit comments