Skip to content

Commit 3b9c3ae

Browse files
[Optim] Optim repetition expressions when the max repetition time is unbounded. (#403)
This PR optims the behavior of repetition expression when the max repetition time is unbounded. It keeps the invariant of the repetition states, instead of setting the upper as a large number. This can significantly speed lower the uncertainty when the max repetition time is unbounded, and the min repetition time is very small, like this schema: ``` { "type": "string", "minLength": 1 } ``` Signed-off-by: Yuchuan [[email protected]](mailto:[email protected])
1 parent 2ce4f23 commit 3b9c3ae

File tree

3 files changed

+73
-57
lines changed

3 files changed

+73
-57
lines changed

cpp/earley_parser.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ class StateEqualForParsing {
156156
bool operator()(const ParserState& lhs, const ParserState& rhs) const {
157157
return lhs.rule_id == rhs.rule_id && lhs.sequence_id == rhs.sequence_id &&
158158
lhs.element_id == rhs.element_id && lhs.rule_start_pos == rhs.rule_start_pos &&
159-
lhs.sub_element_id == rhs.sub_element_id;
159+
lhs.sub_element_id == rhs.sub_element_id && lhs.repeat_count == rhs.repeat_count;
160160
}
161161
};
162162

cpp/grammar_parser.cc

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <picojson.h>
99

1010
#include <cstdint>
11+
#include <string>
1112
#include <variant>
1213
#include <vector>
1314

@@ -758,12 +759,33 @@ int32_t EBNFParser::HandleQuestionQuantifier(int32_t grammar_expr_id) {
758759
int32_t EBNFParser::HandleRepetitionRange(
759760
const int32_t grammar_expr_id, int64_t lower, int64_t upper
760761
) {
762+
bool is_unbounded = false;
763+
int32_t new_element;
761764
if (upper == -1) {
762765
// The repeation is unbounded, e.g. {2,}
763-
upper = 0x7FFFFFFF; // Use a large number to represent unbounded
766+
is_unbounded = true;
767+
const auto& rule_expr = builder_.GetGrammarExpr(grammar_expr_id);
768+
if (rule_expr.type == GrammarBuilder::GrammarExprType::kCharacterClass) {
769+
std::vector<GrammarBuilder::CharacterClassElement> character_ranges;
770+
bool is_negative = rule_expr[0];
771+
for (int i = 1; i < static_cast<int>(rule_expr.size()); i += 2) {
772+
character_ranges.push_back({rule_expr[i], rule_expr[i + 1]});
773+
}
774+
new_element = builder_.AddCharacterClassStar(character_ranges, is_negative);
775+
} else {
776+
const auto& unbounded_rule_id =
777+
builder_.AddEmptyRule(builder_.GetNewRuleName(cur_rule_name_ + "_repeat_inf"));
778+
int recursion_sequence =
779+
builder_.AddSequence({grammar_expr_id, builder_.AddRuleRef(unbounded_rule_id)});
780+
int recursion_choice = builder_.AddChoices({builder_.AddEmptyStr(), recursion_sequence});
781+
builder_.UpdateRuleBody(unbounded_rule_id, recursion_choice);
782+
new_element = builder_.AddRuleRef(unbounded_rule_id);
783+
}
784+
upper = lower;
764785
}
765-
const auto repeat_name = builder_.GetNewRuleName(cur_rule_name_) + "_xgrammar_repetition_context";
766786
std::vector<int32_t> elements;
787+
const auto repeat_name = cur_rule_name_ + "_repeat_";
788+
int cnt = 1;
767789
int splited_count = lower >= 4 ? 4 : lower;
768790
int nullable_splited_count = 0;
769791
if (splited_count != 4) {
@@ -772,9 +794,9 @@ int32_t EBNFParser::HandleRepetitionRange(
772794
}
773795
// The repetition sentence.
774796
if (upper != (splited_count + nullable_splited_count)) {
775-
auto new_rule_name = builder_.GetNewRuleName(repeat_name);
776797
auto new_grammar_expr_id = builder_.AddChoices({builder_.AddSequence({grammar_expr_id})});
777-
auto new_rule_id = builder_.AddRule(new_rule_name, new_grammar_expr_id);
798+
auto new_rule_id =
799+
builder_.AddRuleWithHint(repeat_name + std::to_string(cnt++), new_grammar_expr_id);
778800
elements.push_back(builder_.AddRepeat(
779801
new_rule_id, lower - splited_count, upper - splited_count - nullable_splited_count
780802
));
@@ -783,20 +805,22 @@ int32_t EBNFParser::HandleRepetitionRange(
783805

784806
// The nullable exprs.
785807
for (int i = 0; i < nullable_splited_count; i++) {
786-
auto new_rule_name = builder_.GetNewRuleName(repeat_name);
787808
auto new_grammar_expr_id =
788809
builder_.AddChoices({builder_.AddEmptyStr(), builder_.AddSequence({grammar_expr_id})});
789-
auto new_rule_id = builder_.AddRule(new_rule_name, new_grammar_expr_id);
810+
auto new_rule_id =
811+
builder_.AddRuleWithHint(repeat_name + std::to_string(cnt++), new_grammar_expr_id);
790812
elements.push_back(builder_.AddRuleRef(new_rule_id));
791813
}
792814

793815
for (int i = 0; i < splited_count; i++) {
794-
auto new_rule_name = builder_.GetNewRuleName(repeat_name);
795816
auto new_grammar_expr_id = builder_.AddChoices({builder_.AddSequence({grammar_expr_id})});
796-
auto new_rule_id = builder_.AddRule(new_rule_name, new_grammar_expr_id);
817+
auto new_rule_id =
818+
builder_.AddRuleWithHint(repeat_name + std::to_string(cnt++), new_grammar_expr_id);
797819
elements.push_back(builder_.AddRuleRef(new_rule_id));
798820
}
799-
821+
if (is_unbounded) {
822+
elements.push_back(new_element);
823+
}
800824
// Add the lookahead elements
801825
std::vector<int32_t> lookahead_elements = elements;
802826
if (elements.empty()) {

tests/python/test_grammar_parser.py

Lines changed: 39 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -146,10 +146,10 @@ def test_repetition_range_exact():
146146
"""Test repetition range with exact count {n}."""
147147
before = """root ::= "a"{3}
148148
"""
149-
expected = """root ::= (((root_1_xgrammar_repetition_context root_1_xgrammar_repetition_context_1 root_1_xgrammar_repetition_context_2)))
150-
root_1_xgrammar_repetition_context ::= (("a")) (=(root_1_xgrammar_repetition_context_1 root_1_xgrammar_repetition_context_2))
151-
root_1_xgrammar_repetition_context_1 ::= (("a")) (=(root_1_xgrammar_repetition_context_2))
152-
root_1_xgrammar_repetition_context_2 ::= (("a"))
149+
expected = """root ::= (((root_repeat_1 root_repeat_2 root_repeat_3)))
150+
root_repeat_1 ::= (("a")) (=(root_repeat_2 root_repeat_3))
151+
root_repeat_2 ::= (("a")) (=(root_repeat_3))
152+
root_repeat_3 ::= (("a"))
153153
"""
154154
grammar = _ebnf_to_grammar_no_normalization(before)
155155
after = str(grammar)
@@ -160,11 +160,11 @@ def test_repetition_range_min_max():
160160
"""Test repetition range with min and max {n,m}."""
161161
before = """root ::= "a"{2,4}
162162
"""
163-
expected = """root ::= (((root_1_xgrammar_repetition_context root_1_xgrammar_repetition_context_1 root_1_xgrammar_repetition_context_2 root_1_xgrammar_repetition_context_3)))
164-
root_1_xgrammar_repetition_context ::= ("" | ("a")) (=(root_1_xgrammar_repetition_context_1 root_1_xgrammar_repetition_context_2 root_1_xgrammar_repetition_context_3))
165-
root_1_xgrammar_repetition_context_1 ::= ("" | ("a")) (=(root_1_xgrammar_repetition_context_2 root_1_xgrammar_repetition_context_3))
166-
root_1_xgrammar_repetition_context_2 ::= (("a")) (=(root_1_xgrammar_repetition_context_3))
167-
root_1_xgrammar_repetition_context_3 ::= (("a"))
163+
expected = """root ::= (((root_repeat_1 root_repeat_2 root_repeat_3 root_repeat_4)))
164+
root_repeat_1 ::= ("" | ("a")) (=(root_repeat_2 root_repeat_3 root_repeat_4))
165+
root_repeat_2 ::= ("" | ("a")) (=(root_repeat_3 root_repeat_4))
166+
root_repeat_3 ::= (("a")) (=(root_repeat_4))
167+
root_repeat_4 ::= (("a"))
168168
"""
169169
grammar = _ebnf_to_grammar_no_normalization(before)
170170
after = str(grammar)
@@ -175,12 +175,10 @@ def test_repetition_range_min_only():
175175
"""Test repetition range with only min {n,}."""
176176
before = """root ::= "a"{2,}
177177
"""
178-
expected = """root ::= (((root_1_xgrammar_repetition_context{0, 2147483643} root_1_xgrammar_repetition_context_1 root_1_xgrammar_repetition_context_2 root_1_xgrammar_repetition_context_3 root_1_xgrammar_repetition_context_4)))
179-
root_1_xgrammar_repetition_context ::= (("a")) (=(root_1_xgrammar_repetition_context_1 root_1_xgrammar_repetition_context_2 root_1_xgrammar_repetition_context_3 root_1_xgrammar_repetition_context_4))
180-
root_1_xgrammar_repetition_context_1 ::= ("" | ("a")) (=(root_1_xgrammar_repetition_context_2 root_1_xgrammar_repetition_context_3 root_1_xgrammar_repetition_context_4))
181-
root_1_xgrammar_repetition_context_2 ::= ("" | ("a")) (=(root_1_xgrammar_repetition_context_3 root_1_xgrammar_repetition_context_4))
182-
root_1_xgrammar_repetition_context_3 ::= (("a")) (=(root_1_xgrammar_repetition_context_4))
183-
root_1_xgrammar_repetition_context_4 ::= (("a"))
178+
expected = """root ::= (((root_repeat_1 root_repeat_2 root_repeat_inf)))
179+
root_repeat_inf ::= ("" | ("a" root_repeat_inf))
180+
root_repeat_1 ::= (("a")) (=(root_repeat_2 root_repeat_inf))
181+
root_repeat_2 ::= (("a")) (=(root_repeat_inf))
184182
"""
185183
grammar = _ebnf_to_grammar_no_normalization(before)
186184
after = str(grammar)
@@ -275,12 +273,12 @@ def test_combined_features():
275273
rule2 ::= [0-9]+ "." [0-9]*
276274
"""
277275
expected = """root ::= (("start" root_1 "end"))
278-
rule1 ::= (((rule1_1_xgrammar_repetition_context rule1_1_xgrammar_repetition_context_1 rule1_1_xgrammar_repetition_context_2))) (=((":")))
276+
rule1 ::= (((rule1_repeat_1 rule1_repeat_2 rule1_repeat_3))) (=((":")))
279277
rule2 ::= ((rule2_1 "." [0-9]*))
280278
root_1 ::= ((((rule1) | (rule2)) root_1) | ((rule1) | (rule2)))
281-
rule1_1_xgrammar_repetition_context ::= ("" | ([a-z])) (=(rule1_1_xgrammar_repetition_context_1 rule1_1_xgrammar_repetition_context_2))
282-
rule1_1_xgrammar_repetition_context_1 ::= ("" | ([a-z])) (=(rule1_1_xgrammar_repetition_context_2))
283-
rule1_1_xgrammar_repetition_context_2 ::= (([a-z]))
279+
rule1_repeat_1 ::= ("" | ([a-z])) (=(rule1_repeat_2 rule1_repeat_3))
280+
rule1_repeat_2 ::= ("" | ([a-z])) (=(rule1_repeat_3))
281+
rule1_repeat_3 ::= (([a-z]))
284282
rule2_1 ::= (([0-9] rule2_1) | [0-9])
285283
"""
286284
grammar = _ebnf_to_grammar_no_normalization(before)
@@ -354,35 +352,29 @@ def test_repetition_range():
354352
"""
355353

356354
expected = """root ::= ((a b c d e f g))
357-
a ::= ((a_1_xgrammar_repetition_context a_1_xgrammar_repetition_context_1))
358-
b ::= ((b_1_xgrammar_repetition_context{0, 1} b_1_xgrammar_repetition_context_1 b_1_xgrammar_repetition_context_2 b_1_xgrammar_repetition_context_3 b_1_xgrammar_repetition_context_4))
359-
c ::= ((c_1_xgrammar_repetition_context c_1_xgrammar_repetition_context_1))
360-
d ::= ((d_1_xgrammar_repetition_context{0, 2147483643} d_1_xgrammar_repetition_context_1 d_1_xgrammar_repetition_context_2 d_1_xgrammar_repetition_context_3 d_1_xgrammar_repetition_context_4))
361-
e ::= ((e_1_xgrammar_repetition_context{0, 2147483643} e_1_xgrammar_repetition_context_1 e_1_xgrammar_repetition_context_2 e_1_xgrammar_repetition_context_3 e_1_xgrammar_repetition_context_4))
362-
f ::= ((f_1_xgrammar_repetition_context f_1_xgrammar_repetition_context_1 f_1_xgrammar_repetition_context_2))
355+
a ::= ((a_repeat_1 a_repeat_2))
356+
b ::= ((b_repeat_1{0, 1} b_repeat_2 b_repeat_3 b_repeat_4 b_repeat_5))
357+
c ::= ((c_repeat_1 c_repeat_2))
358+
d ::= ((d_repeat_inf))
359+
e ::= ((e_repeat_1 e_repeat_2 e_repeat_inf))
360+
f ::= ((f_repeat_1 f_repeat_2 f_repeat_3))
363361
g ::= ("")
364-
a_1_xgrammar_repetition_context ::= ("" | ("a")) (=(a_1_xgrammar_repetition_context_1))
365-
a_1_xgrammar_repetition_context_1 ::= (("a"))
366-
b_1_xgrammar_repetition_context ::= ((a) | ("b")) (=(b_1_xgrammar_repetition_context_1 b_1_xgrammar_repetition_context_2 b_1_xgrammar_repetition_context_3 b_1_xgrammar_repetition_context_4))
367-
b_1_xgrammar_repetition_context_1 ::= ("" | (a) | ("b")) (=(b_1_xgrammar_repetition_context_2 b_1_xgrammar_repetition_context_3 b_1_xgrammar_repetition_context_4))
368-
b_1_xgrammar_repetition_context_2 ::= ("" | (a) | ("b")) (=(b_1_xgrammar_repetition_context_3 b_1_xgrammar_repetition_context_4))
369-
b_1_xgrammar_repetition_context_3 ::= ("" | (a) | ("b")) (=(b_1_xgrammar_repetition_context_4))
370-
b_1_xgrammar_repetition_context_4 ::= ((a) | ("b"))
371-
c_1_xgrammar_repetition_context ::= ("" | ("c")) (=(c_1_xgrammar_repetition_context_1))
372-
c_1_xgrammar_repetition_context_1 ::= ("" | ("c"))
373-
d_1_xgrammar_repetition_context ::= (("d")) (=(d_1_xgrammar_repetition_context_1 d_1_xgrammar_repetition_context_2 d_1_xgrammar_repetition_context_3 d_1_xgrammar_repetition_context_4))
374-
d_1_xgrammar_repetition_context_1 ::= ("" | ("d")) (=(d_1_xgrammar_repetition_context_2 d_1_xgrammar_repetition_context_3 d_1_xgrammar_repetition_context_4))
375-
d_1_xgrammar_repetition_context_2 ::= ("" | ("d")) (=(d_1_xgrammar_repetition_context_3 d_1_xgrammar_repetition_context_4))
376-
d_1_xgrammar_repetition_context_3 ::= ("" | ("d")) (=(d_1_xgrammar_repetition_context_4))
377-
d_1_xgrammar_repetition_context_4 ::= ("" | ("d"))
378-
e_1_xgrammar_repetition_context ::= (("e")) (=(e_1_xgrammar_repetition_context_1 e_1_xgrammar_repetition_context_2 e_1_xgrammar_repetition_context_3 e_1_xgrammar_repetition_context_4))
379-
e_1_xgrammar_repetition_context_1 ::= ("" | ("e")) (=(e_1_xgrammar_repetition_context_2 e_1_xgrammar_repetition_context_3 e_1_xgrammar_repetition_context_4))
380-
e_1_xgrammar_repetition_context_2 ::= ("" | ("e")) (=(e_1_xgrammar_repetition_context_3 e_1_xgrammar_repetition_context_4))
381-
e_1_xgrammar_repetition_context_3 ::= (("e")) (=(e_1_xgrammar_repetition_context_4))
382-
e_1_xgrammar_repetition_context_4 ::= (("e"))
383-
f_1_xgrammar_repetition_context ::= (("f")) (=(f_1_xgrammar_repetition_context_1 f_1_xgrammar_repetition_context_2))
384-
f_1_xgrammar_repetition_context_1 ::= (("f")) (=(f_1_xgrammar_repetition_context_2))
385-
f_1_xgrammar_repetition_context_2 ::= (("f"))
362+
a_repeat_1 ::= ("" | ("a")) (=(a_repeat_2))
363+
a_repeat_2 ::= (("a"))
364+
b_repeat_1 ::= ((a) | ("b")) (=(b_repeat_2 b_repeat_3 b_repeat_4 b_repeat_5))
365+
b_repeat_2 ::= ("" | (a) | ("b")) (=(b_repeat_3 b_repeat_4 b_repeat_5))
366+
b_repeat_3 ::= ("" | (a) | ("b")) (=(b_repeat_4 b_repeat_5))
367+
b_repeat_4 ::= ("" | (a) | ("b")) (=(b_repeat_5))
368+
b_repeat_5 ::= ((a) | ("b"))
369+
c_repeat_1 ::= ("" | ("c")) (=(c_repeat_2))
370+
c_repeat_2 ::= ("" | ("c"))
371+
d_repeat_inf ::= ("" | ("d" d_repeat_inf))
372+
e_repeat_inf ::= ("" | ("e" e_repeat_inf))
373+
e_repeat_1 ::= (("e")) (=(e_repeat_2 e_repeat_inf))
374+
e_repeat_2 ::= (("e")) (=(e_repeat_inf))
375+
f_repeat_1 ::= (("f")) (=(f_repeat_2 f_repeat_3))
376+
f_repeat_2 ::= (("f")) (=(f_repeat_3))
377+
f_repeat_3 ::= (("f"))
386378
"""
387379

388380
grammar = _ebnf_to_grammar_no_normalization(before)

0 commit comments

Comments
 (0)