Skip to content

Commit dbc7f19

Browse files
ExtReMLapinCNE Pierre FICHEPOILSeven-Streams
authored
Json schema generation, limit the number of whitespaces (#414)
Fixes #412 --------- Signed-off-by: Yuchuan <[email protected]> Co-authored-by: CNE Pierre FICHEPOIL <[email protected]> Co-authored-by: Yuchuan <[email protected]>
1 parent c4d985e commit dbc7f19

File tree

13 files changed

+241
-41
lines changed

13 files changed

+241
-41
lines changed

cpp/grammar.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,11 @@ Grammar Grammar::FromJSONSchema(
4646
std::optional<int> indent,
4747
std::optional<std::pair<std::string, std::string>> separators,
4848
bool strict_mode,
49+
std::optional<int> max_whitespace_cnt,
4950
bool print_converted_ebnf
5051
) {
51-
auto ebnf_string = JSONSchemaToEBNF(schema, any_whitespace, indent, separators, strict_mode);
52+
auto ebnf_string =
53+
JSONSchemaToEBNF(schema, any_whitespace, indent, separators, strict_mode, max_whitespace_cnt);
5254
if (print_converted_ebnf) {
5355
XGRAMMAR_LOG(INFO) << "Converted EBNF: " << ebnf_string << std::endl;
5456
}

cpp/grammar_compiler.cc

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -554,8 +554,13 @@ AdaptiveTokenMask GrammarMatcherForTokenMaskCache::GetAdaptiveTokenMask(
554554

555555
/******************* GrammarCompiler::Impl *******************/
556556

557-
using SchemaKey =
558-
std::tuple<std::string, bool, std::optional<int>, std::pair<std::string, std::string>, bool>;
557+
using SchemaKey = std::tuple<
558+
std::string,
559+
bool,
560+
std::optional<int>,
561+
std::pair<std::string, std::string>,
562+
bool,
563+
std::optional<int>>;
559564
using StructuralTagKey = std::tuple<std::vector<StructuralTagItem>, std::vector<std::string>>;
560565
using GrammarKey = std::pair<std::string, std::string>;
561566

@@ -605,7 +610,8 @@ class GrammarCompiler::Impl {
605610
bool any_whitespace,
606611
std::optional<int> indent,
607612
std::optional<std::pair<std::string, std::string>> separators,
608-
bool strict_mode = true
613+
bool strict_mode = true,
614+
std::optional<int> max_whitespace_cnt = std::nullopt
609615
);
610616

611617
CompiledGrammar CompileStructuralTag(
@@ -768,8 +774,10 @@ CompiledGrammar GrammarCompiler::Impl::CompileJson() {
768774

769775
template <>
770776
CompiledGrammar GrammarCompiler::Impl::Compute<SchemaKey>(const SchemaKey& key) {
771-
const auto& [schema, any_whitespace, indent, separators, strict_mode] = key;
772-
auto grammar = Grammar::FromJSONSchema(schema, any_whitespace, indent, separators, strict_mode);
777+
const auto& [schema, any_whitespace, indent, separators, strict_mode, max_whitespace_cnt] = key;
778+
auto grammar = Grammar::FromJSONSchema(
779+
schema, any_whitespace, indent, separators, strict_mode, max_whitespace_cnt
780+
);
773781
return MultiThreadCompileGrammar(grammar);
774782
}
775783

@@ -802,17 +810,20 @@ CompiledGrammar GrammarCompiler::Impl::CompileJSONSchema(
802810
bool any_whitespace,
803811
std::optional<int> indent,
804812
std::optional<std::pair<std::string, std::string>> separators,
805-
bool strict_mode
813+
bool strict_mode,
814+
std::optional<int> max_whitespace_cnt
806815
) {
807816
if (!cache_enabled_) {
808-
return MultiThreadCompileGrammar(
809-
Grammar::FromJSONSchema(schema, any_whitespace, indent, separators, strict_mode)
810-
);
817+
return MultiThreadCompileGrammar(Grammar::FromJSONSchema(
818+
schema, any_whitespace, indent, separators, strict_mode, max_whitespace_cnt
819+
));
811820
}
812821
auto separators_value = separators.value_or(
813822
(indent == std::nullopt) ? std::make_pair(", ", ": ") : std::make_pair(",", ": ")
814823
);
815-
auto key = std::make_tuple(schema, any_whitespace, indent, separators_value, strict_mode);
824+
auto key = std::make_tuple(
825+
schema, any_whitespace, indent, separators_value, strict_mode, max_whitespace_cnt
826+
);
816827
return compile_cache_.Get(key);
817828
}
818829

@@ -876,9 +887,12 @@ CompiledGrammar GrammarCompiler::CompileJSONSchema(
876887
bool any_whitespace,
877888
std::optional<int> indent,
878889
std::optional<std::pair<std::string, std::string>> separators,
879-
bool strict_mode
890+
bool strict_mode,
891+
std::optional<int> max_whitespace_cnt
880892
) {
881-
return pimpl_->CompileJSONSchema(schema, any_whitespace, indent, separators, strict_mode);
893+
return pimpl_->CompileJSONSchema(
894+
schema, any_whitespace, indent, separators, strict_mode, max_whitespace_cnt
895+
);
882896
}
883897

884898
CompiledGrammar GrammarCompiler::CompileBuiltinJSONGrammar() {

cpp/grammar_functor.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,9 @@ class UsedRulesAnalyzer : public GrammarVisitor<std::vector<int32_t>> {
492492
visited.insert(rule_id);
493493
auto rule = base_grammar_->GetRule(rule_id);
494494
VisitExpr(rule.body_expr_id);
495+
if (rule.lookahead_assertion_id != -1) {
496+
VisitExpr(rule.lookahead_assertion_id);
497+
}
495498
}
496499

497500
return std::vector<int32_t>(visited.begin(), visited.end());

cpp/json_schema_converter.cc

Lines changed: 97 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,23 @@ using SchemaError = TypedError<SchemaErrorType>;
3939
*/
4040
class IndentManager {
4141
public:
42-
IndentManager(std::optional<int> indent, const std::string& separator, bool any_whitespace)
42+
IndentManager(
43+
std::optional<int> indent,
44+
const std::string& separator,
45+
bool any_whitespace,
46+
std::optional<int> max_whitespace_cnt
47+
)
4348
: any_whitespace_(any_whitespace),
4449
enable_newline_(indent.has_value()),
4550
indent_(indent.value_or(0)),
4651
separator_(separator),
4752
total_indent_(0),
48-
is_first_({true}) {}
53+
is_first_({true}),
54+
max_whitespace_cnt_(max_whitespace_cnt) {
55+
if (max_whitespace_cnt.has_value() && max_whitespace_cnt.value() <= 0) {
56+
XGRAMMAR_LOG(FATAL) << ("max_whitespace_cnt must be positive.");
57+
}
58+
}
4959

5060
/*! \brief Enter a new indent level. */
5161
void StartIndent() {
@@ -104,12 +114,17 @@ class IndentManager {
104114
std::string separator_;
105115
int64_t total_indent_;
106116
std::vector<bool> is_first_;
117+
std::optional<int> max_whitespace_cnt_;
107118
friend class JSONSchemaConverter;
108119
};
109120

110121
std::string IndentManager::StartSeparator() {
111122
if (any_whitespace_) {
112-
return "[ \\n\\t]*";
123+
if (!max_whitespace_cnt_.has_value()) {
124+
return "[ \\n\\t]*";
125+
} else {
126+
return "[ \\n\\t]{0," + std::to_string(max_whitespace_cnt_.value()) + "}";
127+
}
113128
}
114129
if (!enable_newline_) {
115130
return "\"\"";
@@ -119,7 +134,13 @@ std::string IndentManager::StartSeparator() {
119134

120135
std::string IndentManager::MiddleSeparator() {
121136
if (any_whitespace_) {
122-
return "[ \\n\\t]* \"" + separator_ + "\" [ \\n\\t]*";
137+
std::string whitespace_part;
138+
if (!max_whitespace_cnt_.has_value()) {
139+
whitespace_part = "[ \\n\\t]*";
140+
} else {
141+
whitespace_part = "[ \\n\\t]{0," + std::to_string(max_whitespace_cnt_.value()) + "}";
142+
}
143+
return whitespace_part + " \"" + separator_ + "\" " + whitespace_part;
123144
}
124145
if (!enable_newline_) {
125146
return "\"" + separator_ + "\"";
@@ -129,7 +150,11 @@ std::string IndentManager::MiddleSeparator() {
129150

130151
std::string IndentManager::EndSeparator() {
131152
if (any_whitespace_) {
132-
return "[ \\n\\t]*";
153+
if (!max_whitespace_cnt_.has_value()) {
154+
return "[ \\n\\t]*";
155+
} else {
156+
return "[ \\n\\t]{0," + std::to_string(max_whitespace_cnt_.value()) + "}";
157+
}
133158
}
134159
if (!enable_newline_) {
135160
return "\"\"";
@@ -139,7 +164,11 @@ std::string IndentManager::EndSeparator() {
139164

140165
std::string IndentManager::EmptySeparator() {
141166
if (any_whitespace_) {
142-
return "[ \\n\\t]*";
167+
if (!max_whitespace_cnt_.has_value()) {
168+
return "[ \\n\\t]*";
169+
} else {
170+
return "[ \\n\\t]{0," + std::to_string(max_whitespace_cnt_.value()) + "}";
171+
}
143172
}
144173
return "\"\"";
145174
}
@@ -148,9 +177,19 @@ std::string IndentManager::NextSeparator(bool is_end) {
148177
if (any_whitespace_) {
149178
if (is_first_.back() || is_end) {
150179
is_first_.back() = false;
151-
return "[ \\n\\t]*";
180+
if (!max_whitespace_cnt_.has_value()) {
181+
return "[ \\n\\t]*";
182+
} else {
183+
return "[ \\n\\t]{0," + std::to_string(max_whitespace_cnt_.value()) + "}";
184+
}
152185
} else {
153-
return "[ \\n\\t]* \"" + separator_ + "\" [ \\n\\t]*";
186+
std::string whitespace_part;
187+
if (!max_whitespace_cnt_.has_value()) {
188+
whitespace_part = "[ \\n\\t]*";
189+
} else {
190+
whitespace_part = "[ \\n\\t]{0," + std::to_string(max_whitespace_cnt_.value()) + "}";
191+
}
192+
return whitespace_part + " \"" + separator_ + "\" " + whitespace_part;
154193
}
155194
}
156195

@@ -189,6 +228,7 @@ class JSONSchemaConverter {
189228
std::optional<int> indent,
190229
std::optional<std::pair<std::string, std::string>> separators,
191230
bool strict_mode,
231+
std::optional<int> max_whitespace_cnt = std::nullopt,
192232
JSONFormat json_format = JSONFormat::kJSON
193233
);
194234

@@ -224,7 +264,6 @@ class JSONSchemaConverter {
224264
inline static const std::string kXMLEscape = "xml_escape";
225265
inline static const std::string kXMLString = "xml_string";
226266
inline static const std::string kXMLVariableName = "xml_variable_name";
227-
inline static const std::string kWhiteSpace = "[ \\n\\t]*";
228267

229268
/*! \brief Add the basic rules to the rules list and the basic_rules_cache. */
230269
void AddBasicRules(JSONFormat json_format);
@@ -517,6 +556,13 @@ class JSONSchemaConverter {
517556
bool any_whitespace_;
518557
// The cache for URI to rule. Mapping from the URI to the rule name.
519558
std::unordered_map<std::string, std::string> uri_to_rule_cache_;
559+
// The maximum number of whitespaces allowed when any_whitespace_ is true.
560+
std::optional<int> max_whitespace_cnt_;
561+
562+
const std::string kWhiteSpace =
563+
max_whitespace_cnt_.has_value()
564+
? "[ \\n\\t]{0," + std::to_string(max_whitespace_cnt_.value()) + "}"
565+
: "[ \\n\\t]*";
520566
};
521567

522568
JSONSchemaConverter::JSONSchemaConverter(
@@ -525,9 +571,13 @@ JSONSchemaConverter::JSONSchemaConverter(
525571
std::optional<int> indent,
526572
std::optional<std::pair<std::string, std::string>> separators,
527573
bool strict_mode,
574+
std::optional<int> max_whitespace_cnt,
528575
JSONFormat json_format
529576
)
530-
: json_schema_(json_schema), strict_mode_(strict_mode), any_whitespace_(any_whitespace) {
577+
: json_schema_(json_schema),
578+
strict_mode_(strict_mode),
579+
any_whitespace_(any_whitespace),
580+
max_whitespace_cnt_(max_whitespace_cnt) {
531581
if (!separators.has_value()) {
532582
if (indent == std::nullopt) {
533583
separators = std::make_pair(", ", ": ");
@@ -538,9 +588,15 @@ JSONSchemaConverter::JSONSchemaConverter(
538588
if (any_whitespace) {
539589
separators = std::make_pair(",", ":");
540590
}
541-
indentManager_ = IndentManager(indent, separators->first, any_whitespace);
591+
indentManager_ = IndentManager(indent, separators->first, any_whitespace, max_whitespace_cnt);
542592
if (any_whitespace) {
543-
colon_pattern_ = "[ \\n\\t]* \"" + separators->second + "\" [ \\n\\t]*";
593+
std::string whitespace_part;
594+
if (!max_whitespace_cnt_.has_value()) {
595+
whitespace_part = "[ \\n\\t]*";
596+
} else {
597+
whitespace_part = "[ \\n\\t]{0," + std::to_string(max_whitespace_cnt_.value()) + "}";
598+
}
599+
colon_pattern_ = whitespace_part + " \"" + separators->second + "\" " + whitespace_part;
544600
} else {
545601
colon_pattern_ = "\"" + separators->second + "\"";
546602
}
@@ -579,9 +635,9 @@ void JSONSchemaConverter::AddBasicRules(JSONFormat json_format) {
579635

580636
auto past_indent_manager = indentManager_;
581637
if (any_whitespace_) {
582-
indentManager_ = IndentManager(std::nullopt, ",", true);
638+
indentManager_ = IndentManager(std::nullopt, ",", true, std::nullopt);
583639
} else {
584-
indentManager_ = IndentManager(std::nullopt, ", ", false);
640+
indentManager_ = IndentManager(std::nullopt, ", ", false, std::nullopt);
585641
}
586642
AddJSONHelperRules();
587643
if (json_format == JSONFormat::kXML) {
@@ -628,14 +684,28 @@ void JSONSchemaConverter::AddJSONHelperRules() {
628684
ebnf_script_creator_.AddRule(
629685
kBasicEscape, "[\"\\\\/bfnrt] | \"u\" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]"
630686
);
687+
std::string whitespace_part;
688+
if (!max_whitespace_cnt_.has_value()) {
689+
whitespace_part = "[ \\n\\t]*";
690+
} else {
691+
whitespace_part = "[ \\n\\t]{0," + std::to_string(max_whitespace_cnt_.value()) + "}";
692+
}
631693
ebnf_script_creator_.AddRule(
632694
kBasicStringSub,
633695
"(\"\\\"\" | [^\\0-\\x1f\\\"\\\\\\r\\n] " + kBasicStringSub + " | \"\\\\\" " + kBasicEscape +
634-
" " + kBasicStringSub + ") (= [ \\n\\t]* [,}\\]:])"
696+
" " + kBasicStringSub + ") (= " + whitespace_part + " [,}\\]:])"
635697
);
636698
}
637699

638700
void JSONSchemaConverter::AddXMLHelperRules() {
701+
std::string whitespace_part;
702+
if (any_whitespace_) {
703+
if (!max_whitespace_cnt_.has_value()) {
704+
whitespace_part = "[ \\n\\t]*";
705+
} else {
706+
whitespace_part = "[ \\n\\t]{0," + std::to_string(max_whitespace_cnt_.value()) + "}";
707+
}
708+
}
639709
ebnf_script_creator_.AddRule(
640710
kXMLEscape, "[\"\\\\/bfnrt] | \"u\" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]"
641711
);
@@ -645,7 +715,7 @@ void JSONSchemaConverter::AddXMLHelperRules() {
645715
ebnf_script_creator_.AddRule(
646716
kXMLString,
647717
"(\"\" | [^<>&\\0-\\x1f\\\\\\r\\n] " + kXMLString + " | \"\\\\\" " + kXMLEscape + " " +
648-
kXMLString + " | " + kXMLEntity + " " + kXMLString + ") (= [ \\n\\t]*)"
718+
kXMLString + " | " + kXMLEntity + " " + kXMLString + ") (= " + whitespace_part + ")"
649719
);
650720
ebnf_script_creator_.AddRule(kXMLVariableName, "[a-zA-Z_] [a-zA-Z0-9_]*");
651721
}
@@ -3278,7 +3348,13 @@ std::string JSONSchemaConverter::VisitObject(
32783348
result += " \"}\"";
32793349
if (could_be_empty) {
32803350
// result = (result) | {}
3281-
auto rest = "\"{\" " + std::string(any_whitespace_ ? "[ \\n\\t]* " : "") + "\"}\"";
3351+
std::string whitespace_part;
3352+
if (max_whitespace_cnt_ == std::nullopt) {
3353+
whitespace_part = "[ \\n\\t]* ";
3354+
} else {
3355+
whitespace_part = "[ \\n\\t]{0," + std::to_string(*max_whitespace_cnt_) + "} ";
3356+
}
3357+
auto rest = "\"{\" " + std::string(any_whitespace_ ? whitespace_part : "") + "\"}\"";
32823358
if (result == "\"{\" \"}\"") {
32833359
result = rest;
32843360
} else {
@@ -3329,14 +3405,15 @@ std::string JSONSchemaToEBNF(
33293405
std::optional<int> indent,
33303406
std::optional<std::pair<std::string, std::string>> separators,
33313407
bool strict_mode,
3408+
std::optional<int> max_whitespace_cnt,
33323409
JSONFormat json_format
33333410
) {
33343411
picojson::value schema_value;
33353412
std::string err = picojson::parse(schema_value, schema);
33363413
XGRAMMAR_CHECK(err.empty()) << "Failed to parse JSON: " << err
33373414
<< ". The JSON string is:" << schema;
33383415
return JSONSchemaToEBNF(
3339-
schema_value, any_whitespace, indent, separators, strict_mode, json_format
3416+
schema_value, any_whitespace, indent, separators, strict_mode, max_whitespace_cnt, json_format
33403417
);
33413418
}
33423419

@@ -3346,10 +3423,11 @@ std::string JSONSchemaToEBNF(
33463423
std::optional<int> indent,
33473424
std::optional<std::pair<std::string, std::string>> separators,
33483425
bool strict_mode,
3426+
std::optional<int> max_whitespace_cnt,
33493427
JSONFormat json_format
33503428
) {
33513429
JSONSchemaConverter converter(
3352-
schema, any_whitespace, indent, separators, strict_mode, json_format
3430+
schema, any_whitespace, indent, separators, strict_mode, max_whitespace_cnt, json_format
33533431
);
33543432
return converter.Convert(json_format);
33553433
}

cpp/json_schema_converter.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ std::string JSONSchemaToEBNF(
4747
std::optional<int> indent = std::nullopt,
4848
std::optional<std::pair<std::string, std::string>> separators = std::nullopt,
4949
bool strict_mode = true,
50+
std::optional<int> max_whitespace_cnt = std::nullopt,
5051
JSONFormat json_format = JSONFormat::kJSON
5152
);
5253

@@ -76,6 +77,7 @@ std::string JSONSchemaToEBNF(
7677
std::optional<int> indent = std::nullopt,
7778
std::optional<std::pair<std::string, std::string>> separators = std::nullopt,
7879
bool strict_mode = true,
80+
std::optional<int> max_whitespace_cnt = std::nullopt,
7981
JSONFormat json_format = JSONFormat::kJSON
8082
);
8183

0 commit comments

Comments
 (0)