Skip to content

Commit 1e761b7

Browse files
[Fix] Fix the grammar_compiler. (#395)
This PR fixes a minor bug in the Grammar Compiler, reported in #382. The bug is caused by a small optimization that reduces the time spent filling `tmp_rejected_indices`. In the implementation, if `fill_reject_indices == false`, it must not be set to `true` again, since some information would be lost. Unfortunately, when `fill_reject_indices` is initialized to `false` and is then set to `true` and stays that way until the end of the function, an incorrect `tmp_rejected_indices` is passed on. Signed-off-by: Yuchuan <[email protected]> --------- Signed-off-by: Yuchuan <[email protected]>
1 parent 2b4775c commit 1e761b7

File tree

2 files changed

+36
-3
lines changed

2 files changed

+36
-3
lines changed

cpp/grammar_compiler.cc

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,9 @@ bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterCheck(
288288
if (fill_reject_indices) {
289289
tmp_rejected_indices_.push_back(i);
290290
fill_reject_indices =
291-
tmp_rejected_indices_.size() < AdaptiveTokenMask::USE_BITSET_THRESHOLD;
291+
tmp_rejected_indices_.size() >= AdaptiveTokenMask::USE_BITSET_THRESHOLD
292+
? false
293+
: fill_reject_indices;
292294
} else {
293295
i = last_rejected_range - 1;
294296
}
@@ -387,7 +389,9 @@ bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterCheck(
387389
tmp_rejected_indices_.push_back(i);
388390
last_rejected_range = subtree_nodes_range[i];
389391
fill_reject_indices =
390-
tmp_rejected_indices_.size() < AdaptiveTokenMask::USE_BITSET_THRESHOLD;
392+
tmp_rejected_indices_.size() >= AdaptiveTokenMask::USE_BITSET_THRESHOLD
393+
? false
394+
: fill_reject_indices;
391395
}
392396
}
393397
}
@@ -396,7 +400,9 @@ bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterCheck(
396400
for (int i = interval.second; i < next_interval.first; ++i) {
397401
tmp_rejected_indices_.push_back(i);
398402
}
399-
fill_reject_indices = tmp_rejected_indices_.size() < AdaptiveTokenMask::USE_BITSET_THRESHOLD;
403+
fill_reject_indices = tmp_rejected_indices_.size() >= AdaptiveTokenMask::USE_BITSET_THRESHOLD
404+
? false
405+
: fill_reject_indices;
400406
}
401407
}
402408

tests/python/test_grammar_matcher_json_schema.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,5 +511,32 @@ def test_implicit_left_recursion_schema():
511511
_ = grammar_compiler.compile_json_schema(schema=json.dumps(json_schema))
512512

513513

514+
@pytest.mark.hf_token_required
515+
def test_regression_accept_invalid_token():
516+
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-235B-A22B-Instruct-2507-FP8")
517+
vocab_size = 151936
518+
tokenizer_info = xgr.TokenizerInfo.from_huggingface(
519+
tokenizer, vocab_size=vocab_size, stop_token_ids=[tokenizer.eos_token_id]
520+
)
521+
grammar_compiler = xgr.GrammarCompiler(tokenizer_info=tokenizer_info)
522+
ctx = grammar_compiler.compile_json_schema(
523+
schema="""{"type": "object", "properties": {"value": {"type": ["string", "null"], "maxLength": 10}, "nested": {"type": "object", "properties": {"value": {"type": ["string", "null"]}, "nested_nested": {"type": "array", "items": {"type": ["string", "null"]}}}, "required": ["value", "nested_nested"], "maxItems": 10, "minItems": 1}}, "required": ["value", "nested"], "additionalProperties": false}"""
524+
)
525+
matcher = xgr.GrammarMatcher(ctx, max_rollback_tokens=200, override_stop_tokens=None)
526+
token_bitmask = xgr.allocate_token_bitmask(vocab_size=vocab_size, batch_size=7)
527+
token_bitmask.fill_(0)
528+
for i, token in enumerate([4913, 957, 788, 330, 1072, 67212, 788]):
529+
if i == 0:
530+
accepted = True
531+
else:
532+
parent_pos = i - 1
533+
curr_token_id = token
534+
parent_bitmask = token_bitmask[parent_pos]
535+
# 32 boolean bitmask values are packed into 32-bit integers
536+
accepted = (parent_bitmask[curr_token_id // 32] & (1 << (curr_token_id % 32))) != 0
537+
assert matcher.accept_token(token) == accepted
538+
matcher.fill_next_token_bitmask(token_bitmask, i)
539+
540+
514541
if __name__ == "__main__":
515542
pytest.main(sys.argv)

0 commit comments

Comments
 (0)