Skip to content

Commit 7e9690f

Browse files
committed
refactor.
Signed-off-by: Yuchuan <[email protected]>
1 parent a3238b6 commit 7e9690f

File tree

1 file changed

+70
-42
lines changed

1 file changed

+70
-42
lines changed

cpp/grammar_compiler.cc

Lines changed: 70 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser {
9797
* \brief Check all intervals for possible tokens.
9898
* \param tokenizer_info The tokenizer info.
9999
* \param possible_intervals The possible intervals for tokens.
100-
* \param speculative_calculation Whether to use speculative calculation.
100+
* \param speculative_calculation_applied Whether to use speculative calculation.
101101
* \param speculative_mask The speculative mask for speculative calculation.
102102
* \param definite_accepted_bitset The definite accepted bitset for TagDispatch rules.
103103
* \param is_root_rule Whether to consider the parent rule. If false, there will be
@@ -109,7 +109,7 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser {
109109
bool CheckAllPossibleTokens(
110110
const TokenizerInfo& tokenizer_info,
111111
const std::vector<std::pair<int32_t, int32_t>>& possible_intervals,
112-
bool speculative_calculation,
112+
bool speculative_calculation_applied,
113113
const std::bitset<256>& speculative_mask,
114114
const std::optional<const DynamicBitset*>& definite_accepted_bitset,
115115
bool is_root_rule,
@@ -119,7 +119,7 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser {
119119
/*! \brief Check each token in a given interval.
120120
\param tokenizer_info The tokenizer info.
121121
\param interval The interval to check.
122-
\param speculative_calculation Whether to use speculative calculation.
122+
\param speculative_calculation_applied Whether to use speculative calculation.
123123
\param speculative_mask The speculative mask for speculative calculation.
124124
\param definite_accepted_bitset The definite accepted bitset for TagDispatch rules.
125125
\param is_root_rule Whether to consider the parent rule. If false, there will be
@@ -136,7 +136,7 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser {
136136
bool CheckTokensInInterval(
137137
const TokenizerInfo& tokenizer_info,
138138
const std::pair<int, int>& interval,
139-
bool speculative_calculation,
139+
bool speculative_calculation_applied,
140140
const std::bitset<256>& speculative_mask,
141141
const std::optional<const DynamicBitset*>& definite_accepted_bitset,
142142
bool is_root_rule,
@@ -146,6 +146,20 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser {
146146
int* prev_matched_size
147147
);
148148

149+
/*! \brief Apply speculative calculation for a token.
150+
\param token The token to check.
151+
\param index The index of the token in the sorted decoded vocabulary.
152+
\param speculative_mask The speculative mask for speculative calculation.
153+
\param definite_accepted_bitset The definite accepted bitset for TagDispatch rules.
154+
\return True if the token is accepted by speculative calculation (its index is then appended to tmp_accepted_indices_), false otherwise.
155+
*/
156+
bool ApplySpeculativeCalculation(
157+
const std::string& token,
158+
int32_t index,
159+
const std::bitset<256>& speculative_mask,
160+
const std::optional<const DynamicBitset*>& definite_accepted_bitset
161+
);
162+
149163
// The id of the initial rule.
150164
int32_t init_rule_id;
151165

@@ -397,15 +411,15 @@ bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterOptimization
397411
}
398412
}
399413

400-
bool speculative_calculation = false;
414+
bool speculative_calculation_applied = false;
401415
std::bitset<256> speculative_mask;
402416
if (init_rule_id == -1 || !grammar_->per_rule_fsms[init_rule_id].has_value()) {
403-
speculative_calculation =
417+
speculative_calculation_applied =
404418
GetSpeculativeCalculation(sorted_decoded_vocab).first &&
405419
(possible_token_num >= static_cast<int>(sorted_decoded_vocab.size() / 4));
406420
speculative_mask = first_char_mask;
407421
} else {
408-
std::tie(speculative_calculation, speculative_mask) =
422+
std::tie(speculative_calculation_applied, speculative_mask) =
409423
GetSpeculativeCalculation(sorted_decoded_vocab);
410424
}
411425

@@ -421,7 +435,7 @@ bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterOptimization
421435
fill_reject_indices = CheckAllPossibleTokens(
422436
tokenizer_info,
423437
possible_intervals,
424-
speculative_calculation,
438+
speculative_calculation_applied,
425439
speculative_mask,
426440
definite_accepted_bitset,
427441
is_root_rule,
@@ -526,7 +540,7 @@ std::bitset<256> GrammarMatcherForTokenMaskCache::GetFirstCharacterMask() {
526540
bool GrammarMatcherForTokenMaskCache::CheckAllPossibleTokens(
527541
const TokenizerInfo& tokenizer_info,
528542
const std::vector<std::pair<int32_t, int32_t>>& possible_intervals,
529-
bool speculative_calculation,
543+
bool speculative_calculation_applied,
530544
const std::bitset<256>& speculative_mask,
531545
const std::optional<const DynamicBitset*>& definite_accepted_bitset,
532546
bool is_root_rule,
@@ -540,7 +554,7 @@ bool GrammarMatcherForTokenMaskCache::CheckAllPossibleTokens(
540554
fill_reject_indices = CheckTokensInInterval(
541555
tokenizer_info,
542556
interval,
543-
speculative_calculation,
557+
speculative_calculation_applied,
544558
speculative_mask,
545559
definite_accepted_bitset,
546560
is_root_rule,
@@ -565,10 +579,50 @@ bool GrammarMatcherForTokenMaskCache::CheckAllPossibleTokens(
565579
return fill_reject_indices;
566580
}
567581

582+
bool GrammarMatcherForTokenMaskCache::ApplySpeculativeCalculation(
583+
const std::string& token,
584+
int32_t index,
585+
const std::bitset<256>& speculative_mask,
586+
const std::optional<const DynamicBitset*>& definite_accepted_bitset
587+
) {
588+
// Try to accept the token up front; this is useful for simple self-recursive rules, like string content.
589+
// Tag dispatch rules: a definite-accepted bitset is provided and is consulted per token index.
590+
if (definite_accepted_bitset.has_value()) {
591+
// If the token is empty, it must be accepted.
592+
if (token.empty()) {
593+
tmp_accepted_indices_.push_back(index);
594+
return true;
595+
}
596+
// If the token doesn't contain tags or stop strings from the second character onward, and
597+
// it will transit to the start state after consuming the first character, it must be
598+
// accepted.
599+
if (speculative_mask[static_cast<uint8_t>(token[0])] &&
600+
(*definite_accepted_bitset.value())[index]) {
601+
tmp_accepted_indices_.push_back(index);
602+
return true;
603+
}
604+
} else {
605+
bool all_accepted = true;
606+
for (char ch : token) {
607+
// If any character of the token is not an ASCII character or is not set in the
608+
// speculative mask, the token cannot be accepted here and must be checked in the parser.
609+
if (isascii(ch) == 0 || !speculative_mask[static_cast<uint8_t>(ch)]) {
610+
all_accepted = false;
611+
break;
612+
}
613+
}
614+
if (all_accepted) {
615+
tmp_accepted_indices_.push_back(index);
616+
return true;
617+
}
618+
}
619+
return false;  // Not accepted speculatively; nothing was recorded for this token.
620+
}
621+
568622
bool GrammarMatcherForTokenMaskCache::CheckTokensInInterval(
569623
const TokenizerInfo& tokenizer_info,
570624
const std::pair<int, int>& interval,
571-
bool speculative_calculation,
625+
bool speculative_calculation_applied,
572626
const std::bitset<256>& speculative_mask,
573627
const std::optional<const DynamicBitset*>& definite_accepted_bitset,
574628
bool is_root_rule,
@@ -596,37 +650,11 @@ bool GrammarMatcherForTokenMaskCache::CheckTokensInInterval(
596650
}
597651

598652
const auto& token = sorted_decoded_vocab[i].second;
599-
// This optimization is useful for simple self-recursive rules, like string content.
600-
if (speculative_calculation) {
601-
// Optimization for tag dispatch rules.
602-
if (definite_accepted_bitset.has_value()) {
603-
// If the token is empty, it must be accepted.
604-
if (token.empty()) {
605-
tmp_accepted_indices_.push_back(i);
606-
continue;
607-
}
608-
// If the token doesn't contain tags or stop strings since the second character, and it
609-
// will transit to the start state after consuming the first character, it must be
610-
// accepted.
611-
if (speculative_mask[static_cast<uint8_t>(token[0])] &&
612-
(*definite_accepted_bitset.value())[i]) {
613-
tmp_accepted_indices_.push_back(i);
614-
continue;
615-
}
616-
} else {
617-
bool all_accepted = true;
618-
for (char ch : token) {
619-
// If the first character is not the ascii character or can't be accepted by the
620-
// first character mask, we need to check them in the parser.
621-
if (isascii(ch) == 0 || !speculative_mask[static_cast<uint8_t>(ch)]) {
622-
all_accepted = false;
623-
break;
624-
}
625-
}
626-
if (all_accepted) {
627-
tmp_accepted_indices_.push_back(i);
628-
continue;
629-
}
653+
if (speculative_calculation_applied) {
654+
bool speculative_accepted =
655+
ApplySpeculativeCalculation(token, i, speculative_mask, definite_accepted_bitset);
656+
if (speculative_accepted) {
657+
continue;
630658
}
631659
}
632660
// Many tokens may contain the same prefix, so we will avoid unnecessary matching

0 commit comments

Comments
 (0)