@@ -161,6 +161,19 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser {
161161 const std::optional<const DynamicBitset*>& definite_accepted_bitset
162162 );
163163
164+ /*! \brief Compare the current token with the previous one and reuse the previous matching progress over their longest common prefix. Does nothing when prev_token is nullptr.
165+ \param token The current token.
166+ \param prev_token The previous token, or nullptr when there is none. It is only read here, never reassigned.
167+ \param prev_matched_size In/out. On entry, the matched size of the previous token; on exit, the minimum of that value and the common prefix length. If it was larger than the common prefix, the non-common part is rolled back first.
168+ \param accepted Out. Set to false when the common prefix is longer than prev_matched_size (the prefix was already rejected for the previous token); left unchanged otherwise, so the caller must initialize it.
169+ */
170+ void FindCommonPrefixWithPreviousToken (
171+ const std::string& token,
172+ const std::string*& prev_token,
173+ int * prev_matched_size,
174+ bool * accepted
175+ );
176+
164177 // The id of the initial rule.
165178 int32_t init_rule_id;
166179
@@ -620,6 +633,34 @@ bool GrammarMatcherForTokenMaskCache::ApplySpeculativeCalculation(
620633 return false ;
621634}
622635
636+ void GrammarMatcherForTokenMaskCache::FindCommonPrefixWithPreviousToken (
637+ const std::string& token, const std::string*& prev_token, int * prev_matched_size, bool * accepted
638+ ) {
639+ if (prev_token != nullptr ) {
640+ int lcp_len =
641+ std::mismatch (token.begin (), token.end (), prev_token->begin (), prev_token->end ()).first -
642+ token.begin ();
643+ if (lcp_len > *prev_matched_size) {
644+ // Case 1. The common prefix is rejected by the matcher in the last token. Reject
645+ // directly.
646+ *accepted = false ;
647+ } else if (lcp_len < *prev_matched_size) {
648+ // Case 2. The common prefix is shorter than the previous matched size. Rollback
649+ // the non-common part.
650+ PopLastStates (*prev_matched_size - lcp_len);
651+ tmp_can_reach_end_stack_.erase (
652+ tmp_can_reach_end_stack_.end () - (*prev_matched_size - lcp_len),
653+ tmp_can_reach_end_stack_.end ()
654+ );
655+ tmp_can_reach_end_prefix_or_stack_.erase (
656+ tmp_can_reach_end_prefix_or_stack_.end () - (*prev_matched_size - lcp_len),
657+ tmp_can_reach_end_prefix_or_stack_.end ()
658+ );
659+ }
660+ *prev_matched_size = std::min (*prev_matched_size, lcp_len);
661+ }
662+ }
663+
623664bool GrammarMatcherForTokenMaskCache::CheckTokensInInterval (
624665 const TokenizerInfo& tokenizer_info,
625666 const std::pair<int , int >& interval,
@@ -661,30 +702,7 @@ bool GrammarMatcherForTokenMaskCache::CheckTokensInInterval(
661702 // Many tokens may contain the same prefix, so we will avoid unnecessary matching
662703 // by finding the longest common prefix with the previous token.
663704 bool accepted = true ;
664- if (prev_token != nullptr ) {
665- int lcp_len =
666- std::mismatch (token.begin (), token.end (), prev_token->begin (), prev_token->end ()).first -
667- token.begin ();
668- if (lcp_len > *prev_matched_size) {
669- // Case 1. The common prefix is rejected by the matcher in the last token. Reject
670- // directly.
671- accepted = false ;
672- } else if (lcp_len < *prev_matched_size) {
673- // Case 2. The common prefix is shorter than the previous matched size. Rollback
674- // the non-common part.
675- PopLastStates (*prev_matched_size - lcp_len);
676- tmp_can_reach_end_stack_.erase (
677- tmp_can_reach_end_stack_.end () - (*prev_matched_size - lcp_len),
678- tmp_can_reach_end_stack_.end ()
679- );
680- tmp_can_reach_end_prefix_or_stack_.erase (
681- tmp_can_reach_end_prefix_or_stack_.end () - (*prev_matched_size - lcp_len),
682- tmp_can_reach_end_prefix_or_stack_.end ()
683- );
684- }
685- *prev_matched_size = std::min (*prev_matched_size, lcp_len);
686- }
687-
705+ FindCommonPrefixWithPreviousToken (token, prev_token, prev_matched_size, &accepted);
688706 prev_token = &token;
689707
690708 if (accepted) {
@@ -727,10 +745,9 @@ bool GrammarMatcherForTokenMaskCache::CheckTokensInInterval(
727745 } else {
728746 tmp_rejected_indices_.push_back (i);
729747 *last_rejected_range = subtree_nodes_range[i];
730- fill_reject_indices =
731- tmp_rejected_indices_.size () >= AdaptiveTokenMask::USE_BITSET_THRESHOLD
732- ? false
733- : fill_reject_indices;
748+ if (tmp_rejected_indices_.size () >= AdaptiveTokenMask::USE_BITSET_THRESHOLD) {
749+ fill_reject_indices = false ;
750+ }
734751 }
735752 }
736753 }
@@ -898,8 +915,8 @@ void GrammarCompilerNoCache::GenerateTokenMaskCacheForScannableStates(
898915
899916 auto root_rule_id = compiled_grammar_impl->grammar ->GetRootRuleId ();
900917
901- // Iterate through all rules and their scannable states to generate the adaptive token mask, since
902- // unscanable states will be expanded to the scannable states.
918+ // Iterate through all rules and their scannable states to generate the adaptive token mask,
919+ // since unscannable states will be expanded to the scannable states.
903920 for (int32_t rule_id = 0 ; rule_id < static_cast <int >(compiled_grammar_impl->grammar ->NumRules ());
904921 ++rule_id) {
905922 auto rule = compiled_grammar_impl->grammar ->GetRule (rule_id);
0 commit comments