@@ -64,7 +64,7 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser {
6464 * \returns True if the rejected indices are filled as usual, False otherwise.
6565 * It's used to determine which construction function will be used.
6666 */
67- bool GetTokenMaskWithFirstCharacterCheck (
67+ bool GetTokenMaskWithFirstCharacterOptimization (
6868 const TokenizerInfo& tokenizer_info,
6969 const std::bitset<256>& first_char_mask,
7070 bool is_root_rule
@@ -93,6 +93,20 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser {
9393 */
9494 std::bitset<256> GetFirstCharacterMask();
9595
96+ /*!
97+ * \brief Check all intervals for possible tokens.
98+ * \note All the possible tokens will be divided into accepted, rejected and uncertain tokens.
99+ */
100+ bool CheckAllPossibleTokens(
101+ const TokenizerInfo& tokenizer_info,
102+ const std::vector<std::pair<int32_t, int32_t>>& possible_intervals,
103+ bool speculative_calculation,
104+ const std::bitset<256>& speculative_mask,
105+ const std::optional<const DynamicBitset*>& definite_accepted_bitset,
106+ bool is_root_rule,
107+ bool fill_reject_indices
108+ );
109+
96110 // The id of the initial rule.
97111 int32_t init_rule_id;
98112
@@ -322,12 +336,11 @@ std::pair<bool, std::bitset<256>> GrammarMatcherForTokenMaskCache::GetSpeculativ
322336 return {can_be_applied, speculative_mask};
323337}
324338
325- bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterCheck (
339+ bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterOptimization (
326340 const TokenizerInfo& tokenizer_info, const std::bitset<256>& first_char_mask, bool is_root_rule
327341) {
328342 // the pair (a, b) means [a, b). Initialize the possible intervals.
329343 const auto& sorted_decoded_vocab = tokenizer_info.GetSortedDecodedVocab();
330- const auto& subtree_nodes_range = tokenizer_info.GetTrieSubtreeNodesRange();
331344 std::vector<std::pair<int32_t, int32_t>> possible_intervals;
332345 int possible_token_num =
333346 GetPossibleTokenIntervals(sorted_decoded_vocab, first_char_mask, possible_intervals);
@@ -357,9 +370,6 @@ bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterCheck(
357370 GetSpeculativeCalculation(sorted_decoded_vocab);
358371 }
359372
360- int prev_matched_size = 0;
361- int last_rejected_range = 0;
362- const bool& is_exact_lookahead = grammar_->GetRule(init_rule_id).is_exact_lookahead;
363373 std::optional<const DynamicBitset*> definite_accepted_bitset = std::nullopt;
364374 const bool is_tag_dispatch_rule =
365375 grammar_->GetGrammarExpr(grammar_->GetRule(init_rule_id).body_expr_id).type ==
@@ -369,6 +379,125 @@ bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterCheck(
369379 definite_accepted_bitset = &tag_dispatch_rule_id_to_second_slicing_bitset.at(init_rule_id);
370380 }
371381
382+ fill_reject_indices = CheckAllPossibleTokens(
383+ tokenizer_info,
384+ possible_intervals,
385+ speculative_calculation,
386+ speculative_mask,
387+ definite_accepted_bitset,
388+ is_root_rule,
389+ fill_reject_indices
390+ );
391+
392+ if (possible_intervals.back().second != static_cast<int>(sorted_decoded_vocab.size()) &&
393+ fill_reject_indices) {
394+ // If the last interval is not closed, we need to reject the remaining tokens.
395+ for (int i = possible_intervals.back().second;
396+ i < static_cast<int>(sorted_decoded_vocab.size());
397+ ++i) {
398+ tmp_rejected_indices_.push_back(i);
399+ }
400+ }
401+
402+ return fill_reject_indices;
403+ }
404+
405+ AdaptiveTokenMask GrammarMatcherForTokenMaskCache::GetAdaptiveTokenMask(
406+ const TokenizerInfo& tokenizer_info, bool is_root_rule
407+ ) {
408+ const auto& sorted_decoded_vocab = tokenizer_info.GetSortedDecodedVocab();
409+ const int vocab_size = tokenizer_info.GetVocabSize();
410+ tmp_accepted_indices_.clear();
411+ tmp_rejected_indices_.clear();
412+ tmp_uncertain_indices_.clear();
413+ // For every character in the current token, stores whether it is possible to reach the end of
414+ // the rule when matching until this character. Store it in a stack for later rollback.
415+ tmp_can_reach_end_stack_.push_back(false);
416+ tmp_can_reach_end_prefix_or_stack_.push_back(false);
417+ std::bitset<256> first_character_mask = GetFirstCharacterMask();
418+ bool rejected_indices_are_filled = GetTokenMaskWithFirstCharacterOptimization(
419+ tokenizer_info, first_character_mask, is_root_rule
420+ );
421+ if (rejected_indices_are_filled) {
422+ return AdaptiveTokenMask(
423+ vocab_size,
424+ sorted_decoded_vocab,
425+ tmp_accepted_indices_,
426+ tmp_rejected_indices_,
427+ tmp_uncertain_indices_
428+ );
429+ } else {
430+ return AdaptiveTokenMask(
431+ vocab_size, sorted_decoded_vocab, tmp_accepted_indices_, tmp_uncertain_indices_
432+ );
433+ }
434+ }
435+
436+ std::bitset<256> GrammarMatcherForTokenMaskCache::GetFirstCharacterMask() {
437+ std::bitset<256> first_character_mask;
438+ const auto& sequence = grammar_->GetGrammarExpr(initial_state.sequence_id);
439+ if (!grammar_->per_rule_fsms[init_rule_id].has_value()) {
440+ const auto& sub_sequence = grammar_->GetGrammarExpr(sequence[initial_state.element_id]);
441+ switch (sub_sequence.type) {
442+ case Grammar::Impl::GrammarExprType::kByteString: {
443+ first_character_mask[sub_sequence[initial_state.sub_element_id]] = true;
444+ break;
445+ }
446+ case xgrammar::Grammar::Impl::GrammarExprType::kCharacterClass:
447+ case xgrammar::Grammar::Impl::GrammarExprType::kCharacterClassStar: {
448+ if (initial_state.sub_element_id == 0) {
449+ bool is_negative = sub_sequence[0];
450+ for (int i = 1; i < sub_sequence.size(); i += 2) {
451+ int left_char = static_cast<uint8_t>(sub_sequence[i]);
452+ int right_char = static_cast<uint8_t>(sub_sequence[i + 1]);
453+ for (int c = left_char; c <= right_char; ++c) {
454+ first_character_mask[c] = true;
455+ }
456+ }
457+ if (is_negative) {
458+ first_character_mask = ~first_character_mask;
459+ }
460+ break;
461+ }
462+ // Otherwise, it's matching a UTF-8 character. We can optimize the matching process
463+ // here.
464+ for (size_t i = 0x80; i < 0xC0; ++i) {
465+ first_character_mask[i] = true;
466+ }
467+ break;
468+ }
469+ default: {
470+ XGRAMMAR_LOG(FATAL) << "Unsupported grammar expr type: " << static_cast<int>(sequence.type);
471+ }
472+ }
473+ } else {
474+ const auto& fsm = grammar_->per_rule_fsms[init_rule_id].value();
475+ const auto& edges = fsm.GetFsm().GetEdges(initial_state.element_id);
476+ for (const auto& edge : edges) {
477+ if (edge.IsCharRange()) {
478+ for (int c = edge.min; c <= edge.max; ++c) {
479+ first_character_mask[c] = true;
480+ }
481+ }
482+ }
483+ }
484+ return first_character_mask;
485+ }
486+
487+ bool GrammarMatcherForTokenMaskCache::CheckAllPossibleTokens(
488+ const TokenizerInfo& tokenizer_info,
489+ const std::vector<std::pair<int32_t, int32_t>>& possible_intervals,
490+ bool speculative_calculation,
491+ const std::bitset<256>& speculative_mask,
492+ const std::optional<const DynamicBitset*>& definite_accepted_bitset,
493+ bool is_root_rule,
494+ bool fill_reject_indices
495+ ) {
496+ int prev_matched_size = 0;
497+ int last_rejected_range = 0;
498+ const auto& sorted_decoded_vocab = tokenizer_info.GetSortedDecodedVocab();
499+ const auto& subtree_nodes_range = tokenizer_info.GetTrieSubtreeNodesRange();
500+ const bool& is_exact_lookahead = grammar_->GetRule(init_rule_id).is_exact_lookahead;
372501 const std::string* prev_token = nullptr;
373502 for (size_t interval_idx = 0; interval_idx < possible_intervals.size(); ++interval_idx) {
374503 const auto& interval = possible_intervals[interval_idx];
@@ -512,101 +641,9 @@ bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterCheck(
512641
513642 // Rollback the last matched part.
514643 PopLastStates(prev_matched_size);
515-
516- if (possible_intervals.back().second != static_cast<int>(sorted_decoded_vocab.size()) &&
517- fill_reject_indices) {
518- // If the last interval is not closed, we need to reject the rest tokens.
519- for (int i = possible_intervals.back().second;
520- i < static_cast<int>(sorted_decoded_vocab.size());
521- ++i) {
522- tmp_rejected_indices_.push_back(i);
523- }
524- }
525-
526644 return fill_reject_indices;
527645}
528646
529- AdaptiveTokenMask GrammarMatcherForTokenMaskCache::GetAdaptiveTokenMask(
530- const TokenizerInfo& tokenizer_info, bool is_root_rule
531- ) {
532- const auto& sorted_decoded_vocab = tokenizer_info.GetSortedDecodedVocab();
533- const int vocab_size = tokenizer_info.GetVocabSize();
534- tmp_accepted_indices_.clear();
535- tmp_rejected_indices_.clear();
536- tmp_uncertain_indices_.clear();
537- // For every character in the current token, stores whether it is possible to reach the end of
538- // the rule when matching until this character. Store it in a stack for later rollback.
539- tmp_can_reach_end_stack_.push_back(false);
540- tmp_can_reach_end_prefix_or_stack_.push_back(false);
541- std::bitset<256> first_character_mask = GetFirstCharacterMask();
542- bool rejected_indices_are_filled =
543- GetTokenMaskWithFirstCharacterCheck(tokenizer_info, first_character_mask, is_root_rule);
544- if (rejected_indices_are_filled) {
545- return AdaptiveTokenMask(
546- vocab_size,
547- sorted_decoded_vocab,
548- tmp_accepted_indices_,
549- tmp_rejected_indices_,
550- tmp_uncertain_indices_
551- );
552- } else {
553- return AdaptiveTokenMask(
554- vocab_size, sorted_decoded_vocab, tmp_accepted_indices_, tmp_uncertain_indices_
555- );
556- }
557- }
558-
559- std::bitset<256> GrammarMatcherForTokenMaskCache::GetFirstCharacterMask() {
560- std::bitset<256> first_character_mask;
561- const auto& sequence = grammar_->GetGrammarExpr(initial_state.sequence_id);
562- if (!grammar_->per_rule_fsms[init_rule_id].has_value()) {
563- const auto& sub_sequence = grammar_->GetGrammarExpr(sequence[initial_state.element_id]);
564- switch (sub_sequence.type) {
565- case Grammar::Impl::GrammarExprType::kByteString: {
566- first_character_mask[sub_sequence[initial_state.sub_element_id]] = true;
567- break;
568- }
569- case xgrammar::Grammar::Impl::GrammarExprType::kCharacterClass:
570- case xgrammar::Grammar::Impl::GrammarExprType::kCharacterClassStar: {
571- if (initial_state.sub_element_id == 0) {
572- bool is_negative = sub_sequence[0];
573- for (int i = 1; i < sub_sequence.size(); i += 2) {
574- int left_char = static_cast<uint8_t>(sub_sequence[i]);
575- int right_char = static_cast<uint8_t>(sub_sequence[i + 1]);
576- for (int c = left_char; c <= right_char; ++c) {
577- first_character_mask[c] = true;
578- }
579- }
580- if (is_negative) {
581- first_character_mask = ~first_character_mask;
582- }
583- break;
584- }
585- // Otherwise, it's matching a UTF-8 character. We can optimize the matching process
586- // here.
587- for (size_t i = 0x80; i < 0xC0; ++i) {
588- first_character_mask[i] = true;
589- }
590- break;
591- }
592- default: {
593- XGRAMMAR_LOG(FATAL) << "Unsupported grammar expr type: " << static_cast<int>(sequence.type);
594- }
595- }
596- } else {
597- const auto& fsm = grammar_->per_rule_fsms[init_rule_id].value();
598- const auto& edges = fsm.GetFsm().GetEdges(initial_state.element_id);
599- for (const auto& edge : edges) {
600- if (edge.IsCharRange()) {
601- for (int c = edge.min; c <= edge.max; ++c) {
602- first_character_mask[c] = true;
603- }
604- }
605- }
606- }
607- return first_character_mask;
608- }
609-
610647/******************* GrammarCompilerNoCache *******************/
611648
612649/*!
0 commit comments