@@ -95,6 +95,15 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser {
9595
9696 /* !
9797 * \brief Check all intervals for possible tokens.
98+ * \param tokenizer_info The tokenizer info.
99+ * \param possible_intervals The possible intervals for tokens.
100+ * \param speculative_calculation Whether to use speculative calculation.
101+ * \param speculative_mask The speculative mask for speculative calculation.
102+ * \param definite_accepted_bitset The definite accepted bitset for TagDispatch rules.
103+ * \param is_root_rule Whether the current rule is the root rule. If true, there will be
104+ * no uncertain tokens: tokens that are not accepted are rejected.
105+ * \param fill_reject_indices Whether to fill the rejected indices.
106+ * \return True if the rejected indices are filled as usual, False otherwise.
98107 * \note All the possible tokens will be divided into accepted, rejected and uncertain tokens.
99108 */
100109 bool CheckAllPossibleTokens (
@@ -107,6 +116,36 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser {
107116 bool fill_reject_indices
108117 );
109118
119+ /* ! \brief Check each token in a given interval.
120+ \param tokenizer_info The tokenizer info.
121+ \param interval The interval to check, a half-open range [first, second) of sorted-vocab indices.
122+ \param speculative_calculation Whether to use speculative calculation.
123+ \param speculative_mask The speculative mask for speculative calculation.
124+ \param definite_accepted_bitset The definite accepted bitset for TagDispatch rules.
125+ \param is_root_rule Whether the current rule is the root rule. If true, there will be
126+ no uncertain tokens: tokens that are not accepted are rejected.
127+ \param fill_reject_indices Whether to fill the rejected indices.
128+ \param last_rejected_range The last rejected subtree range. If a token's index is less than
129+ this value, it will be rejected directly. Updated in place when a token is rejected.
130+ \param prev_token The previous token parsed in the parser. Passed by value, so reassignments made by the callee are not visible to the caller.
131+ \param prev_matched_size The matched size of the previous token. Updated in place.
132+ \return True if the rejected indices are filled as usual, False otherwise.
133+ \note All the tokens in the given interval will be divided into accepted, rejected and
134+ uncertain tokens.
135+ */
136+ bool CheckTokensInInterval (
137+ const TokenizerInfo& tokenizer_info,
138+ const std::pair<int , int >& interval,
139+ bool speculative_calculation,
140+ const std::bitset<256 >& speculative_mask,
141+ const std::optional<const DynamicBitset*>& definite_accepted_bitset,
142+ bool is_root_rule,
143+ bool fill_reject_indices,
144+ int * last_rejected_range,
145+ const std::string* prev_token,
146+ int * prev_matched_size
147+ );
148+
110149 // The id of the initial rule.
111150 int32_t init_rule_id;
112151
@@ -495,153 +534,178 @@ bool GrammarMatcherForTokenMaskCache::CheckAllPossibleTokens(
495534) {
496535 int prev_matched_size = 0 ;
497536 int last_rejected_range = 0 ;
498- const auto & sorted_decoded_vocab = tokenizer_info.GetSortedDecodedVocab ();
499- const auto & subtree_nodes_range = tokenizer_info.GetTrieSubtreeNodesRange ();
500- const bool & is_exact_lookahead = grammar_->GetRule (init_rule_id).is_exact_lookahead ;
501537 const std::string* prev_token = nullptr ;
502538 for (size_t interval_idx = 0 ; interval_idx < possible_intervals.size (); ++interval_idx) {
503539 const auto & interval = possible_intervals[interval_idx];
504- for (int i = interval.first ; i < interval.second ; ++i) {
505- // Check if the current token is in the rejected range. i.e. check if the current token
506- // is on the subtree of the rejected token.
507- if (i < last_rejected_range) {
508- if (fill_reject_indices) {
509- tmp_rejected_indices_.push_back (i);
510- if (tmp_rejected_indices_.size () >= AdaptiveTokenMask::USE_BITSET_THRESHOLD) {
511- fill_reject_indices = false ;
512- }
513- } else {
514- i = last_rejected_range - 1 ;
540+ fill_reject_indices = CheckTokensInInterval (
541+ tokenizer_info,
542+ interval,
543+ speculative_calculation,
544+ speculative_mask,
545+ definite_accepted_bitset,
546+ is_root_rule,
547+ fill_reject_indices,
548+ &last_rejected_range,
549+ prev_token,
550+ &prev_matched_size
551+ );
552+ if (interval_idx != possible_intervals.size () - 1 && fill_reject_indices) {
553+ const auto & next_interval = possible_intervals[interval_idx + 1 ];
554+ for (int i = interval.second ; i < next_interval.first ; ++i) {
555+ tmp_rejected_indices_.push_back (i);
556+ }
557+ if (tmp_rejected_indices_.size () >= AdaptiveTokenMask::USE_BITSET_THRESHOLD) {
558+ fill_reject_indices = false ;
559+ }
560+ }
561+ }
562+
563+ // Rollback the last matched part.
564+ PopLastStates (prev_matched_size);
565+ return fill_reject_indices;
566+ }
567+
568+ bool GrammarMatcherForTokenMaskCache::CheckTokensInInterval (
569+ const TokenizerInfo& tokenizer_info,
570+ const std::pair<int , int >& interval,
571+ bool speculative_calculation,
572+ const std::bitset<256 >& speculative_mask,
573+ const std::optional<const DynamicBitset*>& definite_accepted_bitset,
574+ bool is_root_rule,
575+ bool fill_reject_indices,
576+ int * last_rejected_range,
577+ const std::string* prev_token,
578+ int * prev_matched_size
579+ ) {
580+ const auto & sorted_decoded_vocab = tokenizer_info.GetSortedDecodedVocab ();
581+ const auto & subtree_nodes_range = tokenizer_info.GetTrieSubtreeNodesRange ();
582+ const bool & is_exact_lookahead = grammar_->GetRule (init_rule_id).is_exact_lookahead ;
583+ for (int i = interval.first ; i < interval.second ; ++i) {
584+ // Check if the current token is in the rejected range. i.e. check if the current token
585+ // is on the subtree of the rejected token.
586+ if (i < *last_rejected_range) {
587+ if (fill_reject_indices) {
588+ tmp_rejected_indices_.push_back (i);
589+ if (tmp_rejected_indices_.size () >= AdaptiveTokenMask::USE_BITSET_THRESHOLD) {
590+ fill_reject_indices = false ;
515591 }
516- continue ;
592+ } else {
593+ i = *last_rejected_range - 1 ;
517594 }
595+ continue ;
596+ }
518597
519- const auto & token = sorted_decoded_vocab[i].second ;
520- // This optimization is useful for simple self-recursive rules, like string content.
521- if (speculative_calculation) {
522- // Optimization for tag dispatch rules.
523- if (definite_accepted_bitset.has_value ()) {
524- // If the token is empty, it must be accepted.
525- if (token.empty ()) {
526- tmp_accepted_indices_.push_back (i);
527- continue ;
528- }
529- // If the token doesn't contain tags or stop strings since the second character, and it
530- // will transit to the start state after consuming the first character, it must be
531- // accepted.
532- if (speculative_mask[static_cast <uint8_t >(token[0 ])] &&
533- (*definite_accepted_bitset.value ())[i]) {
534- tmp_accepted_indices_.push_back (i);
535- continue ;
536- }
537- } else {
538- bool all_accepted = true ;
539- for (char ch : token) {
540- // If the first character is not the ascii character or can't be accepted by the
541- // first character mask, we need to check them in the parser.
542- if (isascii (ch) == 0 || !speculative_mask[static_cast <uint8_t >(ch)]) {
543- all_accepted = false ;
544- break ;
545- }
546- }
547- if (all_accepted) {
548- tmp_accepted_indices_.push_back (i);
549- continue ;
598+ const auto & token = sorted_decoded_vocab[i].second ;
599+ // This optimization is useful for simple self-recursive rules, like string content.
600+ if (speculative_calculation) {
601+ // Optimization for tag dispatch rules.
602+ if (definite_accepted_bitset.has_value ()) {
603+ // If the token is empty, it must be accepted.
604+ if (token.empty ()) {
605+ tmp_accepted_indices_.push_back (i);
606+ continue ;
607+ }
608+ // If the token doesn't contain tags or stop strings since the second character, and it
609+ // will transit to the start state after consuming the first character, it must be
610+ // accepted.
611+ if (speculative_mask[static_cast <uint8_t >(token[0 ])] &&
612+ (*definite_accepted_bitset.value ())[i]) {
613+ tmp_accepted_indices_.push_back (i);
614+ continue ;
615+ }
616+ } else {
617+ bool all_accepted = true ;
618+ for (char ch : token) {
619+ // If the first character is not the ascii character or can't be accepted by the
620+ // first character mask, we need to check them in the parser.
621+ if (isascii (ch) == 0 || !speculative_mask[static_cast <uint8_t >(ch)]) {
622+ all_accepted = false ;
623+ break ;
550624 }
551625 }
552- }
553- // Many tokens may contain the same prefix, so we will avoid unnecessary matching
554- // by finding the longest common prefix with the previous token.
555- bool accepted = true ;
556- if (prev_token != nullptr ) {
557- int lcp_len =
558- std::mismatch (token.begin (), token.end (), prev_token->begin (), prev_token->end ())
559- .first -
560- token.begin ();
561- if (lcp_len > prev_matched_size) {
562- // Case 1. The common prefix is rejected by the matcher in the last token. Reject
563- // directly.
564- accepted = false ;
565- } else if (lcp_len < prev_matched_size) {
566- // Case 2. The common prefix is shorter than the previous matched size. Rollback
567- // the non-common part.
568- PopLastStates (prev_matched_size - lcp_len);
569- tmp_can_reach_end_stack_.erase (
570- tmp_can_reach_end_stack_.end () - (prev_matched_size - lcp_len),
571- tmp_can_reach_end_stack_.end ()
572- );
573- tmp_can_reach_end_prefix_or_stack_.erase (
574- tmp_can_reach_end_prefix_or_stack_.end () - (prev_matched_size - lcp_len),
575- tmp_can_reach_end_prefix_or_stack_.end ()
576- );
626+ if (all_accepted) {
627+ tmp_accepted_indices_.push_back (i);
628+ continue ;
577629 }
578- prev_matched_size = std::min (prev_matched_size, lcp_len);
579630 }
631+ }
632+ // Many tokens may contain the same prefix, so we will avoid unnecessary matching
633+ // by finding the longest common prefix with the previous token.
634+ bool accepted = true ;
635+ if (prev_token != nullptr ) {
636+ int lcp_len =
637+ std::mismatch (token.begin (), token.end (), prev_token->begin (), prev_token->end ()).first -
638+ token.begin ();
639+ if (lcp_len > *prev_matched_size) {
640+ // Case 1. The common prefix is rejected by the matcher in the last token. Reject
641+ // directly.
642+ accepted = false ;
643+ } else if (lcp_len < *prev_matched_size) {
644+ // Case 2. The common prefix is shorter than the previous matched size. Rollback
645+ // the non-common part.
646+ PopLastStates (*prev_matched_size - lcp_len);
647+ tmp_can_reach_end_stack_.erase (
648+ tmp_can_reach_end_stack_.end () - (*prev_matched_size - lcp_len),
649+ tmp_can_reach_end_stack_.end ()
650+ );
651+ tmp_can_reach_end_prefix_or_stack_.erase (
652+ tmp_can_reach_end_prefix_or_stack_.end () - (*prev_matched_size - lcp_len),
653+ tmp_can_reach_end_prefix_or_stack_.end ()
654+ );
655+ }
656+ *prev_matched_size = std::min (*prev_matched_size, lcp_len);
657+ }
580658
581- prev_token = &token;
659+ prev_token = &token;
582660
583- if (accepted) {
584- // Accept the rest chars one by one.
585- for (int j = prev_matched_size; j < static_cast <int >(token.size ()); ++j) {
586- if (!Advance (token[j])) {
587- accepted = false ;
588- break ;
589- }
590- tmp_can_reach_end_stack_.push_back (IsCompleted ());
591- tmp_can_reach_end_prefix_or_stack_.push_back (
592- tmp_can_reach_end_stack_.back () || tmp_can_reach_end_prefix_or_stack_.back ()
593- );
594- prev_matched_size = j + 1 ;
661+ if (accepted) {
662+ // Accept the rest chars one by one.
663+ for (int j = *prev_matched_size; j < static_cast <int >(token.size ()); ++j) {
664+ if (!Advance (token[j])) {
665+ accepted = false ;
666+ break ;
595667 }
668+ tmp_can_reach_end_stack_.push_back (IsCompleted ());
669+ tmp_can_reach_end_prefix_or_stack_.push_back (
670+ tmp_can_reach_end_stack_.back () || tmp_can_reach_end_prefix_or_stack_.back ()
671+ );
672+ *prev_matched_size = j + 1 ;
596673 }
674+ }
597675
598- bool can_reach_end = tmp_can_reach_end_prefix_or_stack_.back ();
676+ bool can_reach_end = tmp_can_reach_end_prefix_or_stack_.back ();
599677
600- if (accepted) {
601- tmp_accepted_indices_.push_back (i);
602- } else {
603- auto lookahead_result_pair = IsTokenPassLookaheadAssertion (token, tmp_can_reach_end_stack_);
604- if (can_reach_end && !is_root_rule && lookahead_result_pair.first &&
605- prev_matched_size > 0 ) {
606- // 1. If the current rule is the root rule (is_root_rule=true), there are no
607- // uncertain tokens. Not accepted tokens are just rejected.
608- // 2. If a token cannot pass the lookahead assertion, it is rejected.
609- if ((!lookahead_result_pair.second ) && is_exact_lookahead) {
610- tmp_accepted_indices_.push_back (i);
611- } else {
612- tmp_uncertain_indices_.push_back (i);
613- // On the subtree, they are all uncertain tokens.
614- if (lookahead_result_pair.second ) {
615- for (int j = i + 1 ; j < subtree_nodes_range[i]; ++j) {
616- tmp_uncertain_indices_.push_back (j);
617- }
618- i = subtree_nodes_range[i] - 1 ; // Skip the subtree nodes.
678+ if (accepted) {
679+ tmp_accepted_indices_.push_back (i);
680+ } else {
681+ auto lookahead_result_pair = IsTokenPassLookaheadAssertion (token, tmp_can_reach_end_stack_);
682+ if (can_reach_end && !is_root_rule && lookahead_result_pair.first && *prev_matched_size > 0 ) {
683+ // 1. If the current rule is the root rule (is_root_rule=true), there are no
684+ // uncertain tokens. Not accepted tokens are just rejected.
685+ // 2. If a token cannot pass the lookahead assertion, it is rejected.
686+ if ((!lookahead_result_pair.second ) && is_exact_lookahead) {
687+ tmp_accepted_indices_.push_back (i);
688+ } else {
689+ tmp_uncertain_indices_.push_back (i);
690+ // On the subtree, they are all uncertain tokens.
691+ if (lookahead_result_pair.second ) {
692+ for (int j = i + 1 ; j < subtree_nodes_range[i]; ++j) {
693+ tmp_uncertain_indices_.push_back (j);
619694 }
695+ i = subtree_nodes_range[i] - 1 ; // Skip the subtree nodes.
620696 }
621- } else {
622- tmp_rejected_indices_.push_back (i);
623- last_rejected_range = subtree_nodes_range[i];
624- fill_reject_indices =
625- tmp_rejected_indices_.size () >= AdaptiveTokenMask::USE_BITSET_THRESHOLD
626- ? false
627- : fill_reject_indices;
628697 }
629- }
630- }
631- if (interval_idx != possible_intervals.size () - 1 && fill_reject_indices) {
632- const auto & next_interval = possible_intervals[interval_idx + 1 ];
633- for (int i = interval.second ; i < next_interval.first ; ++i) {
698+ } else {
634699 tmp_rejected_indices_.push_back (i);
635- }
636- if (tmp_rejected_indices_.size () >= AdaptiveTokenMask::USE_BITSET_THRESHOLD) {
637- fill_reject_indices = false ;
700+ *last_rejected_range = subtree_nodes_range[i];
701+ fill_reject_indices =
702+ tmp_rejected_indices_.size () >= AdaptiveTokenMask::USE_BITSET_THRESHOLD
703+ ? false
704+ : fill_reject_indices;
638705 }
639706 }
640707 }
641-
642- // Rollback the last matched part.
643- PopLastStates (prev_matched_size);
644- return fill_reject_indices;
708+ return false ;
645709}
646710
647711/* ****************** GrammarCompilerNoCache *******************/
0 commit comments