|
1 | 1 | #include "json-schema-to-grammar.h" |
| 2 | +#include "common.h" |
| 3 | + |
2 | 4 | #include <algorithm> |
3 | 5 | #include <fstream> |
4 | 6 | #include <map> |
|
11 | 13 |
|
12 | 14 | using json = nlohmann::ordered_json; |
13 | 15 |
|
14 | | -template <typename Iterator> |
15 | | -static std::string join(Iterator begin, Iterator end, const std::string & separator); |
16 | | - |
17 | | -static std::string repeat(const std::string & str, size_t n); |
18 | | - |
19 | 16 | static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") { |
20 | 17 | auto has_max = max_items != std::numeric_limits<int>::max(); |
21 | 18 |
|
@@ -128,8 +125,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream & |
128 | 125 | if (sub_len > 0) { |
129 | 126 | auto from_sub = from.substr(i + 1); |
130 | 127 | auto to_sub = to.substr(i + 1); |
131 | | - auto sub_zeros = repeat("0", sub_len); |
132 | | - auto sub_nines = repeat("9", sub_len); |
| 128 | + auto sub_zeros = string_repeat("0", sub_len); |
| 129 | + auto sub_nines = string_repeat("9", sub_len); |
133 | 130 |
|
134 | 131 | auto to_reached = false; |
135 | 132 | out << "("; |
@@ -188,8 +185,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream & |
188 | 185 | auto max_digits = max_s.length(); |
189 | 186 |
|
190 | 187 | for (auto digits = min_digits; digits < max_digits; digits++) { |
191 | | - uniform_range(min_s, repeat("9", digits)); |
192 | | - min_s = "1" + repeat("0", digits); |
| 188 | + uniform_range(min_s, string_repeat("9", digits)); |
| 189 | + min_s = "1" + string_repeat("0", digits); |
193 | 190 | out << " | "; |
194 | 191 | } |
195 | 192 | uniform_range(min_s, max_s); |
@@ -318,49 +315,6 @@ std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = { |
318 | 315 | std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'}; |
319 | 316 | std::unordered_set<char> ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'}; |
320 | 317 |
|
// Concatenates the elements of [begin, end) into a single string.
//
// Each element is written through operator<< on a std::ostringstream, and
// `separator` is inserted between consecutive elements (never before the
// first or after the last).
//
// Params:
//   begin, end - iterator pair delimiting the elements to join
//   separator  - text placed between adjacent elements
// Returns:
//   the joined string; empty when begin == end.
template <typename Iterator>
std::string join(Iterator begin, Iterator end, const std::string & separator) {
    std::ostringstream result;
    if (begin != end) {
        Iterator it = begin;
        result << *it;
        // Step with ++ instead of `begin + 1`: operator+ exists only for
        // random-access iterators, whereas ++ also works for list/set/map
        // iterators, so the helper is usable with any forward range.
        for (++it; it != end; ++it) {
            result << separator << *it;
        }
    }
    return result.str();
}
332 | | - |
// Splits `str` on every occurrence of `delimiter`.
//
// Params:
//   str       - text to split
//   delimiter - separator substring; may be longer than one character
// Returns:
//   the pieces between delimiters, in order. Adjacent delimiters and
//   leading/trailing delimiters produce empty tokens, and the result always
//   contains at least one element (the whole input when no delimiter occurs).
//   An empty `delimiter` yields a single token holding `str` unchanged.
static std::vector<std::string> split(const std::string & str, const std::string & delimiter) {
    std::vector<std::string> tokens;

    // An empty delimiter matches at every position without consuming any
    // input, so the scan below would never advance and would grow `tokens`
    // without bound. Treat it as "nothing to split on".
    if (delimiter.empty()) {
        tokens.push_back(str);
        return tokens;
    }

    size_t start = 0;
    size_t end = str.find(delimiter);

    while (end != std::string::npos) {
        tokens.push_back(str.substr(start, end - start));
        // Resume the search just past the delimiter that was found.
        start = end + delimiter.length();
        end = str.find(delimiter, start);
    }

    // Whatever follows the last delimiter (possibly empty) is the final token.
    tokens.push_back(str.substr(start));

    return tokens;
}
348 | | - |
// Builds a string made of `n` back-to-back copies of `str`.
//
// Params:
//   str - the piece to duplicate
//   n   - how many copies to emit
// Returns:
//   the concatenation; empty when `n` is zero.
static std::string repeat(const std::string & str, size_t n) {
    std::string out;

    if (n > 0) {
        // Size the buffer once up front so the appends below do not reallocate.
        out.reserve(str.size() * n);

        for (size_t remaining = n; remaining > 0; --remaining) {
            out.append(str);
        }
    }

    return out;
}
363 | | - |
364 | 318 | static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function<std::string(const std::smatch &)> & replacement) { |
365 | 319 | std::smatch match; |
366 | 320 | std::string result; |
@@ -389,6 +343,7 @@ static std::string format_literal(const std::string & literal) { |
389 | 343 |
|
390 | 344 | class SchemaConverter { |
391 | 345 | private: |
| 346 | + friend std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb); |
392 | 347 | std::function<json(const std::string &)> _fetch_json; |
393 | 348 | bool _dotall; |
394 | 349 | std::map<std::string, std::string> _rules; |
@@ -418,7 +373,7 @@ class SchemaConverter { |
418 | 373 | for (size_t i = 0; i < alt_schemas.size(); i++) { |
419 | 374 | rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i))); |
420 | 375 | } |
421 | | - return join(rules.begin(), rules.end(), " | "); |
| 376 | + return string_join(rules, " | "); |
422 | 377 | } |
423 | 378 |
|
424 | 379 | std::string _visit_pattern(const std::string & pattern, const std::string & name) { |
@@ -481,7 +436,7 @@ class SchemaConverter { |
481 | 436 | for (const auto & item : ret) { |
482 | 437 | results.push_back(to_rule(item)); |
483 | 438 | } |
484 | | - return std::make_pair(join(results.begin(), results.end(), " "), false); |
| 439 | + return std::make_pair(string_join(results, " "), false); |
485 | 440 | }; |
486 | 441 |
|
487 | 442 | while (i < length) { |
@@ -539,7 +494,7 @@ class SchemaConverter { |
539 | 494 | } |
540 | 495 | curly_brackets += '}'; |
541 | 496 | i++; |
542 | | - auto nums = split(curly_brackets.substr(1, curly_brackets.length() - 2), ","); |
| 497 | + auto nums = string_split(curly_brackets.substr(1, curly_brackets.length() - 2), ","); |
543 | 498 | int min_times = 0; |
544 | 499 | int max_times = std::numeric_limits<int>::max(); |
545 | 500 | try { |
@@ -854,7 +809,7 @@ class SchemaConverter { |
854 | 809 | return; |
855 | 810 | } |
856 | 811 | std::string pointer = ref.substr(ref.find('#') + 1); |
857 | | - std::vector<std::string> tokens = split(pointer, "/"); |
| 812 | + std::vector<std::string> tokens = string_split(pointer, "/"); |
858 | 813 | for (size_t i = 1; i < tokens.size(); ++i) { |
859 | 814 | std::string sel = tokens[i]; |
860 | 815 | if (target.is_null() || !target.contains(sel)) { |
@@ -905,7 +860,7 @@ class SchemaConverter { |
905 | 860 | for (const auto & v : schema["enum"]) { |
906 | 861 | enum_values.push_back(_generate_constant_rule(v)); |
907 | 862 | } |
908 | | - return _add_rule(rule_name, "(" + join(enum_values.begin(), enum_values.end(), " | ") + ") space"); |
| 863 | + return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space"); |
909 | 864 | } else if ((schema_type.is_null() || schema_type == "object") |
910 | 865 | && (schema.contains("properties") || |
911 | 866 | (schema.contains("additionalProperties") && schema["additionalProperties"] != true))) { |
@@ -1019,10 +974,10 @@ class SchemaConverter { |
1019 | 974 |
|
1020 | 975 | void check_errors() { |
1021 | 976 | if (!_errors.empty()) { |
1022 | | - throw std::runtime_error("JSON schema conversion failed:\n" + join(_errors.begin(), _errors.end(), "\n")); |
| 977 | + throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n")); |
1023 | 978 | } |
1024 | 979 | if (!_warnings.empty()) { |
1025 | | - fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", join(_warnings.begin(), _warnings.end(), "; ").c_str()); |
| 980 | + fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str()); |
1026 | 981 | } |
1027 | 982 | } |
1028 | 983 |
|
|
0 commit comments