@@ -97,7 +97,7 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser {
9797 * \brief Check all intervals for possible tokens.
9898 * \param tokenizer_info The tokenizer info.
9999 * \param possible_intervals The possible intervals for tokens.
100- * \param speculative_calculation Whether to use speculative calculation.
100+ * \param speculative_calculation_applied Whether to use speculative calculation.
101101 * \param speculative_mask The speculative mask for speculative calculation.
102102 * \param definite_accepted_bitset The definite accepted bitset for TagDispatch rules.
103103 * \param is_root_rule Whether to consider the parent rule. If false, there will be
@@ -109,7 +109,7 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser {
109109 bool CheckAllPossibleTokens (
110110 const TokenizerInfo& tokenizer_info,
111111 const std::vector<std::pair<int32_t , int32_t >>& possible_intervals,
112- bool speculative_calculation ,
112+ bool speculative_calculation_applied ,
113113 const std::bitset<256 >& speculative_mask,
114114 const std::optional<const DynamicBitset*>& definite_accepted_bitset,
115115 bool is_root_rule,
@@ -119,7 +119,7 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser {
119119 /* ! \brief Check each token in a given interval.
120120 \param tokenizer_info The tokenizer info.
121121 \param interval The interval to check.
122- \param speculative_calculation Whether to use speculative calculation.
122+ \param speculative_calculation_applied Whether to use speculative calculation.
123123 \param speculative_mask The speculative mask for speculative calculation.
124124 \param definite_accepted_bitset The definite accepted bitset for TagDispatch rules.
125125 \param is_root_rule Whether to consider the parent rule. If false, there will be
@@ -136,7 +136,7 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser {
136136 bool CheckTokensInInterval (
137137 const TokenizerInfo& tokenizer_info,
138138 const std::pair<int , int >& interval,
139- bool speculative_calculation ,
139+ bool speculative_calculation_applied ,
140140 const std::bitset<256 >& speculative_mask,
141141 const std::optional<const DynamicBitset*>& definite_accepted_bitset,
142142 bool is_root_rule,
@@ -146,6 +146,20 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser {
146146 int * prev_matched_size
147147 );
148148
149+ /* ! \brief Apply speculative calculation for a token.
150+ \param token The token to check.
151+ \param index The index of the token in the vocabulary.
152+ \param speculative_mask The speculative mask for speculative calculation.
153+ \param definite_accepted_bitset The definite accepted bitset for TagDispatch rules.
154+ \return True if the token is accepted by speculative calculation, False otherwise.
155+ */
156+ bool ApplySpeculativeCalculation (
157+ const std::string& token,
158+ int32_t index,
159+ const std::bitset<256 >& speculative_mask,
160+ const std::optional<const DynamicBitset*>& definite_accepted_bitset
161+ );
162+
149163 // The id of the initial rule.
150164 int32_t init_rule_id;
151165
@@ -397,15 +411,15 @@ bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterOptimization
397411 }
398412 }
399413
400- bool speculative_calculation = false ;
414+ bool speculative_calculation_applied = false ;
401415 std::bitset<256 > speculative_mask;
402416 if (init_rule_id == -1 || !grammar_->per_rule_fsms [init_rule_id].has_value ()) {
403- speculative_calculation =
417+ speculative_calculation_applied =
404418 GetSpeculativeCalculation (sorted_decoded_vocab).first &&
405419 (possible_token_num >= static_cast <int >(sorted_decoded_vocab.size () / 4 ));
406420 speculative_mask = first_char_mask;
407421 } else {
408- std::tie (speculative_calculation , speculative_mask) =
422+ std::tie (speculative_calculation_applied , speculative_mask) =
409423 GetSpeculativeCalculation (sorted_decoded_vocab);
410424 }
411425
@@ -421,7 +435,7 @@ bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterOptimization
421435 fill_reject_indices = CheckAllPossibleTokens (
422436 tokenizer_info,
423437 possible_intervals,
424- speculative_calculation ,
438+ speculative_calculation_applied ,
425439 speculative_mask,
426440 definite_accepted_bitset,
427441 is_root_rule,
@@ -526,7 +540,7 @@ std::bitset<256> GrammarMatcherForTokenMaskCache::GetFirstCharacterMask() {
526540bool GrammarMatcherForTokenMaskCache::CheckAllPossibleTokens (
527541 const TokenizerInfo& tokenizer_info,
528542 const std::vector<std::pair<int32_t , int32_t >>& possible_intervals,
529- bool speculative_calculation ,
543+ bool speculative_calculation_applied ,
530544 const std::bitset<256 >& speculative_mask,
531545 const std::optional<const DynamicBitset*>& definite_accepted_bitset,
532546 bool is_root_rule,
@@ -540,7 +554,7 @@ bool GrammarMatcherForTokenMaskCache::CheckAllPossibleTokens(
540554 fill_reject_indices = CheckTokensInInterval (
541555 tokenizer_info,
542556 interval,
543- speculative_calculation ,
557+ speculative_calculation_applied ,
544558 speculative_mask,
545559 definite_accepted_bitset,
546560 is_root_rule,
@@ -565,10 +579,50 @@ bool GrammarMatcherForTokenMaskCache::CheckAllPossibleTokens(
565579 return fill_reject_indices;
566580}
567581
582+ bool GrammarMatcherForTokenMaskCache::ApplySpeculativeCalculation (
583+ const std::string& token,
584+ int32_t index,
585+ const std::bitset<256 >& speculative_mask,
586+ const std::optional<const DynamicBitset*>& definite_accepted_bitset
587+ ) {
588+ // This optimization is useful for simple self-recursive rules, like string content.
589+ // Optimization for tag dispatch rules.
590+ if (definite_accepted_bitset.has_value ()) {
591+ // If the token is empty, it must be accepted.
592+ if (token.empty ()) {
593+ tmp_accepted_indices_.push_back (index);
594+ return true ;
595+ }
596+ // If the token doesn't contain tags or stop strings since the second character, and it
597+ // will transit to the start state after consuming the first character, it must be
598+ // accepted.
599+ if (speculative_mask[static_cast <uint8_t >(token[0 ])] &&
600+ (*definite_accepted_bitset.value ())[index]) {
601+ tmp_accepted_indices_.push_back (index);
602+ return true ;
603+ }
604+ } else {
605+ bool all_accepted = true ;
606+ for (char ch : token) {
607+ // If the first character is not the ascii character or can't be accepted by the
608+ // first character mask, we need to check them in the parser.
609+ if (isascii (ch) == 0 || !speculative_mask[static_cast <uint8_t >(ch)]) {
610+ all_accepted = false ;
611+ break ;
612+ }
613+ }
614+ if (all_accepted) {
615+ tmp_accepted_indices_.push_back (index);
616+ return true ;
617+ }
618+ }
619+ return false ;
620+ }
621+
568622bool GrammarMatcherForTokenMaskCache::CheckTokensInInterval (
569623 const TokenizerInfo& tokenizer_info,
570624 const std::pair<int , int >& interval,
571- bool speculative_calculation ,
625+ bool speculative_calculation_applied ,
572626 const std::bitset<256 >& speculative_mask,
573627 const std::optional<const DynamicBitset*>& definite_accepted_bitset,
574628 bool is_root_rule,
@@ -596,37 +650,11 @@ bool GrammarMatcherForTokenMaskCache::CheckTokensInInterval(
596650 }
597651
598652 const auto & token = sorted_decoded_vocab[i].second ;
599- // This optimization is useful for simple self-recursive rules, like string content.
600- if (speculative_calculation) {
601- // Optimization for tag dispatch rules.
602- if (definite_accepted_bitset.has_value ()) {
603- // If the token is empty, it must be accepted.
604- if (token.empty ()) {
605- tmp_accepted_indices_.push_back (i);
606- continue ;
607- }
608- // If the token doesn't contain tags or stop strings since the second character, and it
609- // will transit to the start state after consuming the first character, it must be
610- // accepted.
611- if (speculative_mask[static_cast <uint8_t >(token[0 ])] &&
612- (*definite_accepted_bitset.value ())[i]) {
613- tmp_accepted_indices_.push_back (i);
614- continue ;
615- }
616- } else {
617- bool all_accepted = true ;
618- for (char ch : token) {
619- // If the first character is not the ascii character or can't be accepted by the
620- // first character mask, we need to check them in the parser.
621- if (isascii (ch) == 0 || !speculative_mask[static_cast <uint8_t >(ch)]) {
622- all_accepted = false ;
623- break ;
624- }
625- }
626- if (all_accepted) {
627- tmp_accepted_indices_.push_back (i);
628- continue ;
629- }
653+ if (speculative_calculation_applied) {
654+ bool speculative_accepted =
655+ ApplySpeculativeCalculation (token, i, speculative_mask, definite_accepted_bitset);
656+ if (speculative_accepted) {
657+ continue ;
630658 }
631659 }
632660 // Many tokens may contain the same prefix, so we will avoid unnecessary matching
0 commit comments