11#include " json-schema-to-grammar.h"
2+ #include " common.h"
3+
24#include < algorithm>
35#include < fstream>
46#include < map>
1113
1214using json = nlohmann::ordered_json;
1315
14- template <typename Iterator>
15- static std::string join (Iterator begin, Iterator end, const std::string & separator);
16-
17- static std::string repeat (const std::string & str, size_t n);
18-
1916static std::string build_repetition (const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = " " ) {
2017 auto has_max = max_items != std::numeric_limits<int >::max ();
2118
@@ -128,8 +125,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
128125 if (sub_len > 0 ) {
129126 auto from_sub = from.substr (i + 1 );
130127 auto to_sub = to.substr (i + 1 );
131- auto sub_zeros = repeat (" 0" , sub_len);
132- auto sub_nines = repeat (" 9" , sub_len);
128+ auto sub_zeros = string_repeat (" 0" , sub_len);
129+ auto sub_nines = string_repeat (" 9" , sub_len);
133130
134131 auto to_reached = false ;
135132 out << " (" ;
@@ -188,8 +185,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
188185 auto max_digits = max_s.length ();
189186
190187 for (auto digits = min_digits; digits < max_digits; digits++) {
191- uniform_range (min_s, repeat (" 9" , digits));
192- min_s = " 1" + repeat (" 0" , digits);
188+ uniform_range (min_s, string_repeat (" 9" , digits));
189+ min_s = " 1" + string_repeat (" 0" , digits);
193190 out << " | " ;
194191 }
195192 uniform_range (min_s, max_s);
@@ -318,49 +315,6 @@ std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
318315std::unordered_set<char > NON_LITERAL_SET = {' |' , ' .' , ' (' , ' )' , ' [' , ' ]' , ' {' , ' }' , ' *' , ' +' , ' ?' };
319316std::unordered_set<char > ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {' ^' , ' $' , ' .' , ' [' , ' ]' , ' (' , ' )' , ' |' , ' {' , ' }' , ' *' , ' +' , ' ?' };
320317
321- template <typename Iterator>
322- std::string join (Iterator begin, Iterator end, const std::string & separator) {
323- std::ostringstream result;
324- if (begin != end) {
325- result << *begin;
326- for (Iterator it = begin + 1 ; it != end; ++it) {
327- result << separator << *it;
328- }
329- }
330- return result.str ();
331- }
332-
333- static std::vector<std::string> split (const std::string & str, const std::string & delimiter) {
334- std::vector<std::string> tokens;
335- size_t start = 0 ;
336- size_t end = str.find (delimiter);
337-
338- while (end != std::string::npos) {
339- tokens.push_back (str.substr (start, end - start));
340- start = end + delimiter.length ();
341- end = str.find (delimiter, start);
342- }
343-
344- tokens.push_back (str.substr (start));
345-
346- return tokens;
347- }
348-
349- static std::string repeat (const std::string & str, size_t n) {
350- if (n == 0 ) {
351- return " " ;
352- }
353-
354- std::string result;
355- result.reserve (str.length () * n);
356-
357- for (size_t i = 0 ; i < n; ++i) {
358- result += str;
359- }
360-
361- return result;
362- }
363-
364318static std::string replacePattern (const std::string & input, const std::regex & regex, const std::function<std::string(const std::smatch &)> & replacement) {
365319 std::smatch match;
366320 std::string result;
@@ -389,6 +343,7 @@ static std::string format_literal(const std::string & literal) {
389343
390344class SchemaConverter {
391345private:
346+ friend std::string build_grammar (const std::function<void (const llama_grammar_builder &)> & cb);
392347 std::function<json(const std::string &)> _fetch_json;
393348 bool _dotall;
394349 std::map<std::string, std::string> _rules;
@@ -418,7 +373,7 @@ class SchemaConverter {
418373 for (size_t i = 0 ; i < alt_schemas.size (); i++) {
419374 rules.push_back (visit (alt_schemas[i], name + (name.empty () ? " alternative-" : " -" ) + std::to_string (i)));
420375 }
421- return join (rules. begin (), rules. end () , " | " );
376+ return string_join (rules, " | " );
422377 }
423378
424379 std::string _visit_pattern (const std::string & pattern, const std::string & name) {
@@ -481,7 +436,7 @@ class SchemaConverter {
481436 for (const auto & item : ret) {
482437 results.push_back (to_rule (item));
483438 }
484- return std::make_pair (join (results. begin (), results. end () , " " ), false );
439+ return std::make_pair (string_join (results, " " ), false );
485440 };
486441
487442 while (i < length) {
@@ -539,7 +494,7 @@ class SchemaConverter {
539494 }
540495 curly_brackets += ' }' ;
541496 i++;
542- auto nums = split (curly_brackets.substr (1 , curly_brackets.length () - 2 ), " ," );
497+ auto nums = string_split (curly_brackets.substr (1 , curly_brackets.length () - 2 ), " ," );
543498 int min_times = 0 ;
544499 int max_times = std::numeric_limits<int >::max ();
545500 try {
@@ -854,7 +809,7 @@ class SchemaConverter {
854809 return ;
855810 }
856811 std::string pointer = ref.substr (ref.find (' #' ) + 1 );
857- std::vector<std::string> tokens = split (pointer, " /" );
812+ std::vector<std::string> tokens = string_split (pointer, " /" );
858813 for (size_t i = 1 ; i < tokens.size (); ++i) {
859814 std::string sel = tokens[i];
860815 if (target.is_null () || !target.contains (sel)) {
@@ -905,7 +860,7 @@ class SchemaConverter {
905860 for (const auto & v : schema[" enum" ]) {
906861 enum_values.push_back (_generate_constant_rule (v));
907862 }
908- return _add_rule (rule_name, " (" + join (enum_values. begin (), enum_values. end () , " | " ) + " ) space" );
863+ return _add_rule (rule_name, " (" + string_join (enum_values, " | " ) + " ) space" );
909864 } else if ((schema_type.is_null () || schema_type == " object" )
910865 && (schema.contains (" properties" ) ||
911866 (schema.contains (" additionalProperties" ) && schema[" additionalProperties" ] != true ))) {
@@ -1019,10 +974,10 @@ class SchemaConverter {
1019974
1020975 void check_errors () {
1021976 if (!_errors.empty ()) {
1022- throw std::runtime_error (" JSON schema conversion failed:\n " + join (_errors. begin (), _errors. end () , " \n " ));
977+ throw std::runtime_error (" JSON schema conversion failed:\n " + string_join (_errors, " \n " ));
1023978 }
1024979 if (!_warnings.empty ()) {
1025- fprintf (stderr, " WARNING: JSON schema conversion was incomplete: %s\n " , join (_warnings. begin (), _warnings. end () , " ; " ).c_str ());
980+ fprintf (stderr, " WARNING: JSON schema conversion was incomplete: %s\n " , string_join (_warnings, " ; " ).c_str ());
1026981 }
1027982 }
1028983
@@ -1036,10 +991,27 @@ class SchemaConverter {
1036991};
1037992
1038993std::string json_schema_to_grammar (const json & schema) {
1039- SchemaConverter converter ([](const std::string &) { return json::object (); }, /* dotall= */ false );
1040- auto copy = schema;
1041- converter.resolve_refs (copy, " input" );
1042- converter.visit (copy, " " );
994+ return build_grammar ([&](const llama_grammar_builder & callbacks) {
995+ auto copy = schema;
996+ callbacks.resolve_refs (copy);
997+ callbacks.add_schema (" " , copy);
998+ });
999+ }
1000+
1001+ std::string build_grammar (const std::function<void (const llama_grammar_builder &)> & cb) {
1002+ SchemaConverter converter ([&](const std::string &) { return json (); }, /* dotall= */ false );
1003+ llama_grammar_builder builder {
1004+ /* .add_rule = */ [&](const std::string & name, const std::string & rule) {
1005+ return converter._add_rule (name, rule);
1006+ },
1007+ /* .add_schema = */ [&](const std::string & name, const nlohmann::ordered_json & schema) {
1008+ return converter.visit (schema, name == " root" ? " " : name);
1009+ },
1010+ /* .resolve_refs = */ [&](nlohmann::ordered_json & schema) {
1011+ converter.resolve_refs (schema, " " );
1012+ }
1013+ };
1014+ cb (builder);
10431015 converter.check_errors ();
10441016 return converter.format_grammar ();
10451017}
0 commit comments