Skip to content

Commit 853c7ac

Browse files
[Fix] Fix lookahead assertion analyzer. (#401)
This PR fixes the issue reported in #400, i.e. the incorrect behavior when the lookahead assertion is nullable. Moreover, this PR adds exact lookahead to improve the efficiency at runtime. --------- Signed-off-by: Yuchuan <[email protected]>
1 parent 3b9c3ae commit 853c7ac

File tree

7 files changed

+123
-31
lines changed

7 files changed

+123
-31
lines changed

cpp/grammar_builder.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,12 @@ class GrammarBuilder {
279279
grammar_->rules_[rule_id].lookahead_assertion_id = lookahead_assertion_id;
280280
}
281281

282+
void UpdateLookaheadExact(int32_t rule_id, bool is_exact = true) {
283+
XGRAMMAR_CHECK(rule_id < static_cast<int32_t>(grammar_->rules_.size()))
284+
<< "Rule id " << rule_id << " is out of range.";
285+
grammar_->rules_[rule_id].is_exact_lookahead = is_exact;
286+
}
287+
282288
/*!
283289
* \brief Add a lookahead assertion to a rule referred by the given name. The lookahead
284290
* assertion should be a sequence GrammarExpr id. An id of -1 means no lookahead assertion.

cpp/grammar_compiler.cc

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -114,14 +114,21 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser {
114114
std::pair<bool, bool> GrammarMatcherForTokenMaskCache::IsTokenPassLookaheadAssertion(
115115
const std::string& token, const std::vector<bool>& can_reach_end_stack
116116
) {
117+
bool accepted = true;
118+
bool can_reach_end = true;
117119
auto lookahead_assertion_id = grammar_->GetRule(init_rule_id).lookahead_assertion_id;
118120
if (lookahead_assertion_id == -1) {
119-
return {true, true};
121+
return {accepted, can_reach_end};
120122
}
121123
auto lookahead_state =
122124
ParserState(/*rule_id*/ -1, lookahead_assertion_id, 0, ParserState::kNoPrevInputPos, 0);
123125
PushStateAndExpand(lookahead_state);
124126
int token_len = token.size();
127+
if (IsCompleted()) {
128+
// If the lookahead assertion is already completed, we can accept the token.
129+
PopLastStates(1);
130+
return {accepted, can_reach_end};
131+
}
125132

126133
// Find all positions that can come to an end. Then check if the suffix from that position
127134
// can be accepted by the lookahead assertion.
@@ -140,20 +147,23 @@ std::pair<bool, bool> GrammarMatcherForTokenMaskCache::IsTokenPassLookaheadAsser
140147
// accepted chars: pos - i + 1
141148
// we need to rollback the pushed initial state as well
142149
PopLastStates(pos - i + 2);
143-
return {true, true};
150+
return {accepted, can_reach_end};
144151
}
145152
}
146153
// Case 2. The whole token is accepted
147154
if (last_accept_pos == token_len - 1) {
148155
PopLastStates(last_accept_pos - i + 2);
149-
return {true, false};
156+
can_reach_end = false;
157+
return {accepted, can_reach_end};
150158
}
151159
// Case 3. The token is not accepted. Check the next position.
152160
PopLastStates(last_accept_pos - i + 1);
153161
}
154162

155163
PopLastStates(1);
156-
return {false, false};
164+
can_reach_end = false;
165+
accepted = false;
166+
return {accepted, can_reach_end};
157167
}
158168

159169
// Comparator for std::pair<int32_t, std::string> based on the string value.
@@ -322,6 +332,7 @@ bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterCheck(
322332

323333
int prev_matched_size = 0;
324334
int last_rejected_range = 0;
335+
const bool& is_exact_lookahead = grammar_->GetRule(init_rule_id).is_exact_lookahead;
325336
const std::string* prev_token = nullptr;
326337
for (size_t interval_idx = 0; interval_idx < possible_intervals.size(); ++interval_idx) {
327338
const auto& interval = possible_intervals[interval_idx];
@@ -414,10 +425,7 @@ bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterCheck(
414425
// 1. If the current rule is the root rule (is_root_rule=true), there are no
415426
// uncertain tokens. Not accepted tokens are just rejected.
416427
// 2. If a token cannot pass the lookahead assertion, it is rejected.
417-
if ((!lookahead_result_pair.second) &&
418-
(std::binary_search(
419-
grammar_->exact_lookahead.begin(), grammar_->exact_lookahead.end(), init_rule_id
420-
))) {
428+
if ((!lookahead_result_pair.second) && is_exact_lookahead) {
421429
tmp_accepted_indices_.push_back(i);
422430
} else {
423431
tmp_uncertain_indices_.push_back(i);

cpp/grammar_functor.cc

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -576,17 +576,63 @@ class LookaheadAssertionAnalyzerImpl : public GrammarMutator {
576576
}
577577
for (int i = 0; i < static_cast<int>(grammar->NumRules()); ++i) {
578578
auto rule = grammar->GetRule(i);
579-
if (i == grammar->GetRootRuleId() || rule.lookahead_assertion_id != -1) {
579+
if (i == grammar->GetRootRuleId()) {
580+
continue;
581+
}
582+
if (rule.lookahead_assertion_id != -1) {
583+
builder_->UpdateLookaheadExact(i, IsExactLookaheadAssertion(i));
580584
continue;
581585
}
582586
auto look_head_assertion_id = DetectLookaheadAssertion(i);
583587
if (look_head_assertion_id != -1) {
584588
builder_->UpdateLookaheadAssertion(i, look_head_assertion_id);
589+
builder_->UpdateLookaheadExact(i);
585590
}
586591
}
587592
return builder_->Get(grammar->GetRootRuleId());
588593
}
589594

595+
bool IsExactLookaheadAssertion(int32_t rule_id) {
596+
XGRAMMAR_DCHECK(base_grammar_->GetRule(rule_id).lookahead_assertion_id != -1);
597+
bool found = false;
598+
for (int i = 0; i < static_cast<int>(base_grammar_->NumRules()); ++i) {
599+
auto rule = base_grammar_->GetRule(i);
600+
auto grammar_expr = base_grammar_->GetGrammarExpr(rule.body_expr_id);
601+
if (grammar_expr.type == GrammarExprType::kTagDispatch) {
602+
for (int j = 1; j < grammar_expr.size() - 3; j += 2) {
603+
if (grammar_expr[j] == rule_id) {
604+
return false;
605+
}
606+
}
607+
continue;
608+
}
609+
XGRAMMAR_DCHECK(grammar_expr.type == GrammarExprType::kChoices);
610+
for (auto sequence_id : grammar_expr) {
611+
auto sequence_expr = base_grammar_->GetGrammarExpr(sequence_id);
612+
if (sequence_expr.type != GrammarExprType::kSequence) {
613+
continue;
614+
}
615+
auto last_element = base_grammar_->GetGrammarExpr(sequence_expr.end()[-1]);
616+
if (last_element.type == GrammarExprType::kRuleRef && last_element[0] == rule_id &&
617+
i != rule_id) {
618+
return false;
619+
}
620+
621+
for (int j = 0; j < sequence_expr.size() - 1; ++j) {
622+
auto element_expr = base_grammar_->GetGrammarExpr(sequence_expr[j]);
623+
if (element_expr.type != GrammarExprType::kRuleRef || element_expr[0] != rule_id) {
624+
continue;
625+
}
626+
if (found) {
627+
return false;
628+
}
629+
found = true;
630+
}
631+
}
632+
}
633+
return found;
634+
}
635+
590636
int32_t DetectLookaheadAssertion(int32_t rule_id) {
591637
std::vector<int32_t> found_sequence; // Element ids
592638
bool found = false;
@@ -1648,7 +1694,7 @@ class RepetitionNormalizerImpl {
16481694
continue;
16491695
}
16501696
int repeat_rule_id = expr[0];
1651-
(*grammar)->exact_lookahead.push_back(repeat_rule_id);
1697+
grammar->ImplPtr()->GetRule(repeat_rule_id).is_exact_lookahead = true;
16521698
if (std::binary_search(
16531699
(*grammar)->allow_empty_rule_ids.begin(),
16541700
(*grammar)->allow_empty_rule_ids.end(),
@@ -1658,7 +1704,6 @@ class RepetitionNormalizerImpl {
16581704
expr.SetData(1, 0); // Set min repeat to 0
16591705
}
16601706
}
1661-
std::sort((*grammar)->exact_lookahead.begin(), (*grammar)->exact_lookahead.end());
16621707
}
16631708
};
16641709

cpp/grammar_impl.h

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ class Grammar::Impl {
7878
/*! \brief The id of the associated lookahead assertion expr. For now it must be a id of a
7979
* sequence GrammarExpr. -1 if not exists. */
8080
int32_t lookahead_assertion_id = -1;
81+
/*! \brief Whether the lookahead assertion is exact. */
82+
bool is_exact_lookahead = false;
8183
};
8284

8385
/*! \brief Get the number of rules. */
@@ -88,6 +90,11 @@ class Grammar::Impl {
8890
<< "rule_id " << rule_id << " is out of bound";
8991
return rules_[rule_id];
9092
}
93+
Rule& GetRule(int32_t rule_id) {
94+
XGRAMMAR_DCHECK(rule_id >= 0 && rule_id < static_cast<int32_t>(rules_.size()))
95+
<< "rule_id " << rule_id << " is out of bound";
96+
return rules_[rule_id];
97+
}
9198
/*! \brief Get the root rule id of the grammar. */
9299
int32_t GetRootRuleId() const { return root_rule_id_; }
93100
/*! \brief Get the root rule of the grammar. */
@@ -250,9 +257,6 @@ class Grammar::Impl {
250257
/*! \brief The ids of the rules that are allowed to be empty. */
251258
std::vector<int32_t> allow_empty_rule_ids;
252259

253-
/*! \brief Store the lookahead which are exact, used to reduce uncertainty.*/
254-
std::vector<int32_t> exact_lookahead;
255-
256260
friend class GrammarBuilder;
257261
friend class GrammarCompiler;
258262

@@ -264,7 +268,8 @@ XGRAMMAR_MEMBER_ARRAY(
264268
Grammar::Impl::Rule,
265269
&Grammar::Impl::Rule::name,
266270
&Grammar::Impl::Rule::body_expr_id,
267-
&Grammar::Impl::Rule::lookahead_assertion_id
271+
&Grammar::Impl::Rule::lookahead_assertion_id,
272+
&Grammar::Impl::Rule::is_exact_lookahead
268273
);
269274

270275
XGRAMMAR_MEMBER_TABLE(
@@ -282,9 +287,7 @@ XGRAMMAR_MEMBER_TABLE(
282287
"per_rule_fsms",
283288
&Grammar::Impl::per_rule_fsms,
284289
"allow_empty_rule_ids",
285-
&Grammar::Impl::allow_empty_rule_ids,
286-
"exact_lookahead",
287-
&Grammar::Impl::exact_lookahead
290+
&Grammar::Impl::allow_empty_rule_ids
288291
);
289292

290293
} // namespace xgrammar

cpp/support/json_serializer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class SerializeVersion {
6262
* \brief The current serialization version. When the serialization result of any object in
6363
* XGrammar is changed, this version should be bumped.
6464
*/
65-
static constexpr const char kXGrammarSerializeVersion[] = "v4";
65+
static constexpr const char kXGrammarSerializeVersion[] = "v5";
6666
};
6767

6868
/*!

tests/python/test_grammar_matcher_regex.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import time
33

44
import pytest
5+
import torch
56
from transformers import AutoTokenizer
67

78
import xgrammar as xgr
@@ -173,5 +174,37 @@ def test_regex_with_large_range_compilation():
173174
print(f"Time to compile regex with large range: {(time_end - time_start) / 1e3} us")
174175

175176

177+
@pytest.mark.hf_token_required
178+
def test_regression_lookahead_already_completed():
179+
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
180+
tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
181+
xgr_compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=1)
182+
compiled_grammar = xgr_compiler.compile_regex(r"\/\*(\*+[^*\/]|[^*])*\*+\/")
183+
matcher = xgr.GrammarMatcher(compiled_grammar)
184+
185+
token_bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)
186+
187+
def process_logit(input_ids: list, logit: torch.Tensor) -> torch.Tensor:
188+
if input_ids:
189+
last_token = input_ids[-1]
190+
assert matcher.accept_token(last_token)
191+
matcher.fill_next_token_bitmask(token_bitmask)
192+
xgr.apply_token_bitmask_inplace(logit, token_bitmask)
193+
return logit
194+
195+
def process_tokens(tokens: list):
196+
for i in range(len(tokens)):
197+
logit = torch.zeros((tokenizer_info.vocab_size,), dtype=torch.float)
198+
visible_tokens = tokens[:i]
199+
masked_logit = process_logit(visible_tokens, logit)
200+
assert masked_logit[tokens[i]] != float(
201+
"-inf"
202+
), f"token {i} ({tokens[i]}, {tokenizer.decode(tokens[i])!r}) is masked"
203+
204+
text = "/* */"
205+
tokens = tokenizer.encode(text)
206+
process_tokens(tokens)
207+
208+
176209
if __name__ == "__main__":
177210
pytest.main(sys.argv)

tests/python/test_serialization.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,15 @@ def construct_compiled_grammar():
4141

4242
def test_get_serialization_version():
4343
"""Test the version of the serialized JSON string."""
44-
assert xgr.get_serialization_version() == "v4"
44+
assert xgr.get_serialization_version() == "v5"
4545

4646

4747
def test_serialize_grammar():
4848
"""Test Grammar serialization produces expected JSON string."""
4949
grammar = construct_grammar()
5050
serialized = grammar.serialize_json()
5151
expected_json = {
52-
"rules": [["rule1", 4, 9], ["root_rule", 8, -1]],
52+
"rules": [["rule1", 4, 9, True], ["root_rule", 8, -1, False]],
5353
"grammar_expr_data": [0, 2, 7, 10, 14, 18, 21, 24, 28, 31],
5454
"grammar_expr_indptr": [
5555
# fmt: off
@@ -60,8 +60,7 @@ def test_serialize_grammar():
6060
"complete_fsm": None,
6161
"per_rule_fsms": [],
6262
"allow_empty_rule_ids": [],
63-
"exact_lookahead": [],
64-
"__VERSION__": "v4",
63+
"__VERSION__": "v5",
6564
}
6665
# The fsms are the same one, but the start state and end states are different.
6766
assert json.loads(serialized) == expected_json
@@ -70,7 +69,7 @@ def test_serialize_grammar():
7069
def test_serialize_grammar_exception():
7170
"""Test Grammar serialization produces expected JSON string."""
7271
expected_json = {
73-
"rules": [["rule1", 4, 9], ["root_rule", 8, -1]],
72+
"rules": [["rule1", 4, 9, True], ["root_rule", 8, -1, False]],
7473
"grammar_expr_data": [0, 2, 7, 10, 14, 18, 21, 24, 28, 31],
7574
"grammar_expr_indptr": [
7675
# fmt: off
@@ -81,15 +80,14 @@ def test_serialize_grammar_exception():
8180
"allow_empty_rule_ids": [],
8281
"complete_fsm": None,
8382
"per_rule_fsms": [],
84-
"exact_lookahead": [],
85-
"__VERSION__": "v4",
83+
"__VERSION__": "v5",
8684
}
8785

8886
expected_json["__VERSION__"] = "v1" # Change version to trigger error
8987
with pytest.raises(xgr.DeserializeVersionError):
9088
xgr.Grammar.deserialize_json(json.dumps(expected_json))
9189

92-
expected_json["__VERSION__"] = "v4"
90+
expected_json["__VERSION__"] = "v5"
9391
expected_json.pop("rules") # Remove required field to trigger error
9492
with pytest.raises(xgr.DeserializeFormatError):
9593
xgr.Grammar.deserialize_json(json.dumps(expected_json))
@@ -141,7 +139,7 @@ def test_serialize_tokenizer_info():
141139
'"decoded_vocab":["1","212","a","A","b","\\u00e4\\u00b8\\u0080","-","aBc","abc"],'
142140
'"sorted_decoded_vocab":[[6,"-"],[3,"A"],[2,"a"],[7,"aBc"],[8,"abc"],[4,"b"],[5,"\\u00e4\\u00b8\\u0080"]],'
143141
'"trie_subtree_nodes_range":[1,2,5,4,5,6,7],'
144-
'"__VERSION__":"v4"}'
142+
'"__VERSION__":"v5"}'
145143
)
146144
assert json.loads(serialized) == json.loads(expected_json)
147145

@@ -195,7 +193,7 @@ def test_serialize_compiled_grammar():
195193

196194
expected_json = {
197195
"grammar": {
198-
"rules": [["rule1", 4, 6], ["root_rule", 10, -1]],
196+
"rules": [["rule1", 4, 6, True], ["root_rule", 10, -1, False]],
199197
"grammar_expr_data": [0, 2, 7, 10, 14, 18, 21, 24, 27, 30, 34],
200198
"grammar_expr_indptr": [
201199
# fmt: off
@@ -215,15 +213,14 @@ def test_serialize_compiled_grammar():
215213
[{'data_': [[0, 47, 3], [58, 127, 3], [192, 223, 1], [224, 239, 4], [240, 247, 5], [128, 191, 3], [-2, 0, 2], [128, 191, 1], [128, 191, 4], [-2, 0, 8], [97, 97, 6]],
216214
'indptr_': [0, 5, 6, 6, 7, 8, 9, 9, 10, 11]}, 7, [6], False]],
217215
# fmt: on
218-
"exact_lookahead": [],
219216
},
220217
"tokenizer_metadata": {
221218
"vocab_type": 1,
222219
"vocab_size": 10,
223220
"add_prefix_space": True,
224221
"stop_token_ids": [0, 1],
225222
},
226-
"__VERSION__": "v4",
223+
"__VERSION__": "v5",
227224
}
228225

229226
class AdaptiveTokenMask(BaseModel):

0 commit comments

Comments
 (0)