Skip to content

Commit 31f0845

Browse files
committed
refactor interval.
Signed-off-by: Yuchuan <[email protected]>
1 parent 192b641 commit 31f0845

File tree

1 file changed

+188
-124
lines changed

1 file changed

+188
-124
lines changed

cpp/grammar_compiler.cc

Lines changed: 188 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,15 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser {
9595

9696
/*!
9797
* \brief Check all intervals for possible tokens.
98+
* \param tokenizer_info The tokenizer info.
99+
* \param possible_intervals The possible intervals for tokens.
100+
* \param speculative_calculation Whether to use speculative calculation.
101+
* \param speculative_mask The speculative mask for speculative calculation.
102+
* \param definite_accepted_bitset The definite accepted bitset for TagDispatch rules.
103+
* \param is_root_rule Whether to consider the parent rule. If false, there will be
104+
* no uncertain tokens. Useful for the root rule.
105+
* \param fill_reject_indices Whether to fill the rejected indices.
106+
* \return True if the rejected indices are filled as usual, False otherwise.
98107
* \note All the possible tokens will be divided into accepted, rejected and uncertain tokens.
99108
*/
100109
bool CheckAllPossibleTokens(
@@ -107,6 +116,36 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser {
107116
bool fill_reject_indices
108117
);
109118

119+
/*! \brief Check each token in a given interval.
120+
\param tokenizer_info The tokenizer info.
121+
\param interval The interval to check.
122+
\param speculative_calculation Whether to use speculative calculation.
123+
\param speculative_mask The speculative mask for speculative calculation.
124+
\param definite_accepted_bitset The definite accepted bitset for TagDispatch rules.
125+
\param is_root_rule Whether to consider the parent rule. If false, there will be
126+
no uncertain tokens. Useful for the root rule.
127+
\param fill_reject_indices Whether to fill the rejected indices.
128+
\param last_rejected_range The last rejected subtree range. If a token's index is less than
129+
this value, it will be rejected directly.
130+
\param prev_token The previous token parsed in the parser.
131+
\param prev_matched_size The matched size of the previous token.
132+
\return True if the rejected indices are filled as usual, False otherwise.
133+
\note All the tokens in the given interval will be divided into accepted, rejected and
134+
uncertain tokens.
135+
*/
136+
bool CheckTokensInInterval(
137+
const TokenizerInfo& tokenizer_info,
138+
const std::pair<int, int>& interval,
139+
bool speculative_calculation,
140+
const std::bitset<256>& speculative_mask,
141+
const std::optional<const DynamicBitset*>& definite_accepted_bitset,
142+
bool is_root_rule,
143+
bool fill_reject_indices,
144+
int* last_rejected_range,
145+
const std::string* prev_token,
146+
int* prev_matched_size
147+
);
148+
110149
// The id of the initial rule.
111150
int32_t init_rule_id;
112151

@@ -495,153 +534,178 @@ bool GrammarMatcherForTokenMaskCache::CheckAllPossibleTokens(
495534
) {
496535
int prev_matched_size = 0;
497536
int last_rejected_range = 0;
498-
const auto& sorted_decoded_vocab = tokenizer_info.GetSortedDecodedVocab();
499-
const auto& subtree_nodes_range = tokenizer_info.GetTrieSubtreeNodesRange();
500-
const bool& is_exact_lookahead = grammar_->GetRule(init_rule_id).is_exact_lookahead;
501537
const std::string* prev_token = nullptr;
502538
for (size_t interval_idx = 0; interval_idx < possible_intervals.size(); ++interval_idx) {
503539
const auto& interval = possible_intervals[interval_idx];
504-
for (int i = interval.first; i < interval.second; ++i) {
505-
// Check if the current token is in the rejected range. i.e. check if the current token
506-
// is on the subtree of the rejected token.
507-
if (i < last_rejected_range) {
508-
if (fill_reject_indices) {
509-
tmp_rejected_indices_.push_back(i);
510-
if (tmp_rejected_indices_.size() >= AdaptiveTokenMask::USE_BITSET_THRESHOLD) {
511-
fill_reject_indices = false;
512-
}
513-
} else {
514-
i = last_rejected_range - 1;
540+
fill_reject_indices = CheckTokensInInterval(
541+
tokenizer_info,
542+
interval,
543+
speculative_calculation,
544+
speculative_mask,
545+
definite_accepted_bitset,
546+
is_root_rule,
547+
fill_reject_indices,
548+
&last_rejected_range,
549+
prev_token,
550+
&prev_matched_size
551+
);
552+
if (interval_idx != possible_intervals.size() - 1 && fill_reject_indices) {
553+
const auto& next_interval = possible_intervals[interval_idx + 1];
554+
for (int i = interval.second; i < next_interval.first; ++i) {
555+
tmp_rejected_indices_.push_back(i);
556+
}
557+
if (tmp_rejected_indices_.size() >= AdaptiveTokenMask::USE_BITSET_THRESHOLD) {
558+
fill_reject_indices = false;
559+
}
560+
}
561+
}
562+
563+
// Rollback the last matched part.
564+
PopLastStates(prev_matched_size);
565+
return fill_reject_indices;
566+
}
567+
568+
bool GrammarMatcherForTokenMaskCache::CheckTokensInInterval(
569+
const TokenizerInfo& tokenizer_info,
570+
const std::pair<int, int>& interval,
571+
bool speculative_calculation,
572+
const std::bitset<256>& speculative_mask,
573+
const std::optional<const DynamicBitset*>& definite_accepted_bitset,
574+
bool is_root_rule,
575+
bool fill_reject_indices,
576+
int* last_rejected_range,
577+
const std::string* prev_token,
578+
int* prev_matched_size
579+
) {
580+
const auto& sorted_decoded_vocab = tokenizer_info.GetSortedDecodedVocab();
581+
const auto& subtree_nodes_range = tokenizer_info.GetTrieSubtreeNodesRange();
582+
const bool& is_exact_lookahead = grammar_->GetRule(init_rule_id).is_exact_lookahead;
583+
for (int i = interval.first; i < interval.second; ++i) {
584+
// Check if the current token is in the rejected range. i.e. check if the current token
585+
// is on the subtree of the rejected token.
586+
if (i < *last_rejected_range) {
587+
if (fill_reject_indices) {
588+
tmp_rejected_indices_.push_back(i);
589+
if (tmp_rejected_indices_.size() >= AdaptiveTokenMask::USE_BITSET_THRESHOLD) {
590+
fill_reject_indices = false;
515591
}
516-
continue;
592+
} else {
593+
i = *last_rejected_range - 1;
517594
}
595+
continue;
596+
}
518597

519-
const auto& token = sorted_decoded_vocab[i].second;
520-
// This optimization is useful for simple self-recursive rules, like string content.
521-
if (speculative_calculation) {
522-
// Optimization for tag dispatch rules.
523-
if (definite_accepted_bitset.has_value()) {
524-
// If the token is empty, it must be accepted.
525-
if (token.empty()) {
526-
tmp_accepted_indices_.push_back(i);
527-
continue;
528-
}
529-
// If the token doesn't contain tags or stop strings since the second character, and it
530-
// will transit to the start state after consuming the first character, it must be
531-
// accepted.
532-
if (speculative_mask[static_cast<uint8_t>(token[0])] &&
533-
(*definite_accepted_bitset.value())[i]) {
534-
tmp_accepted_indices_.push_back(i);
535-
continue;
536-
}
537-
} else {
538-
bool all_accepted = true;
539-
for (char ch : token) {
540-
// If the first character is not the ascii character or can't be accepted by the
541-
// first character mask, we need to check them in the parser.
542-
if (isascii(ch) == 0 || !speculative_mask[static_cast<uint8_t>(ch)]) {
543-
all_accepted = false;
544-
break;
545-
}
546-
}
547-
if (all_accepted) {
548-
tmp_accepted_indices_.push_back(i);
549-
continue;
598+
const auto& token = sorted_decoded_vocab[i].second;
599+
// This optimization is useful for simple self-recursive rules, like string content.
600+
if (speculative_calculation) {
601+
// Optimization for tag dispatch rules.
602+
if (definite_accepted_bitset.has_value()) {
603+
// If the token is empty, it must be accepted.
604+
if (token.empty()) {
605+
tmp_accepted_indices_.push_back(i);
606+
continue;
607+
}
608+
// If the token doesn't contain tags or stop strings since the second character, and it
609+
// will transit to the start state after consuming the first character, it must be
610+
// accepted.
611+
if (speculative_mask[static_cast<uint8_t>(token[0])] &&
612+
(*definite_accepted_bitset.value())[i]) {
613+
tmp_accepted_indices_.push_back(i);
614+
continue;
615+
}
616+
} else {
617+
bool all_accepted = true;
618+
for (char ch : token) {
619+
// If the first character is not the ascii character or can't be accepted by the
620+
// first character mask, we need to check them in the parser.
621+
if (isascii(ch) == 0 || !speculative_mask[static_cast<uint8_t>(ch)]) {
622+
all_accepted = false;
623+
break;
550624
}
551625
}
552-
}
553-
// Many tokens may contain the same prefix, so we will avoid unnecessary matching
554-
// by finding the longest common prefix with the previous token.
555-
bool accepted = true;
556-
if (prev_token != nullptr) {
557-
int lcp_len =
558-
std::mismatch(token.begin(), token.end(), prev_token->begin(), prev_token->end())
559-
.first -
560-
token.begin();
561-
if (lcp_len > prev_matched_size) {
562-
// Case 1. The common prefix is rejected by the matcher in the last token. Reject
563-
// directly.
564-
accepted = false;
565-
} else if (lcp_len < prev_matched_size) {
566-
// Case 2. The common prefix is shorter than the previous matched size. Rollback
567-
// the non-common part.
568-
PopLastStates(prev_matched_size - lcp_len);
569-
tmp_can_reach_end_stack_.erase(
570-
tmp_can_reach_end_stack_.end() - (prev_matched_size - lcp_len),
571-
tmp_can_reach_end_stack_.end()
572-
);
573-
tmp_can_reach_end_prefix_or_stack_.erase(
574-
tmp_can_reach_end_prefix_or_stack_.end() - (prev_matched_size - lcp_len),
575-
tmp_can_reach_end_prefix_or_stack_.end()
576-
);
626+
if (all_accepted) {
627+
tmp_accepted_indices_.push_back(i);
628+
continue;
577629
}
578-
prev_matched_size = std::min(prev_matched_size, lcp_len);
579630
}
631+
}
632+
// Many tokens may contain the same prefix, so we will avoid unnecessary matching
633+
// by finding the longest common prefix with the previous token.
634+
bool accepted = true;
635+
if (prev_token != nullptr) {
636+
int lcp_len =
637+
std::mismatch(token.begin(), token.end(), prev_token->begin(), prev_token->end()).first -
638+
token.begin();
639+
if (lcp_len > *prev_matched_size) {
640+
// Case 1. The common prefix is rejected by the matcher in the last token. Reject
641+
// directly.
642+
accepted = false;
643+
} else if (lcp_len < *prev_matched_size) {
644+
// Case 2. The common prefix is shorter than the previous matched size. Rollback
645+
// the non-common part.
646+
PopLastStates(*prev_matched_size - lcp_len);
647+
tmp_can_reach_end_stack_.erase(
648+
tmp_can_reach_end_stack_.end() - (*prev_matched_size - lcp_len),
649+
tmp_can_reach_end_stack_.end()
650+
);
651+
tmp_can_reach_end_prefix_or_stack_.erase(
652+
tmp_can_reach_end_prefix_or_stack_.end() - (*prev_matched_size - lcp_len),
653+
tmp_can_reach_end_prefix_or_stack_.end()
654+
);
655+
}
656+
*prev_matched_size = std::min(*prev_matched_size, lcp_len);
657+
}
580658

581-
prev_token = &token;
659+
prev_token = &token;
582660

583-
if (accepted) {
584-
// Accept the rest chars one by one.
585-
for (int j = prev_matched_size; j < static_cast<int>(token.size()); ++j) {
586-
if (!Advance(token[j])) {
587-
accepted = false;
588-
break;
589-
}
590-
tmp_can_reach_end_stack_.push_back(IsCompleted());
591-
tmp_can_reach_end_prefix_or_stack_.push_back(
592-
tmp_can_reach_end_stack_.back() || tmp_can_reach_end_prefix_or_stack_.back()
593-
);
594-
prev_matched_size = j + 1;
661+
if (accepted) {
662+
// Accept the rest chars one by one.
663+
for (int j = *prev_matched_size; j < static_cast<int>(token.size()); ++j) {
664+
if (!Advance(token[j])) {
665+
accepted = false;
666+
break;
595667
}
668+
tmp_can_reach_end_stack_.push_back(IsCompleted());
669+
tmp_can_reach_end_prefix_or_stack_.push_back(
670+
tmp_can_reach_end_stack_.back() || tmp_can_reach_end_prefix_or_stack_.back()
671+
);
672+
*prev_matched_size = j + 1;
596673
}
674+
}
597675

598-
bool can_reach_end = tmp_can_reach_end_prefix_or_stack_.back();
676+
bool can_reach_end = tmp_can_reach_end_prefix_or_stack_.back();
599677

600-
if (accepted) {
601-
tmp_accepted_indices_.push_back(i);
602-
} else {
603-
auto lookahead_result_pair = IsTokenPassLookaheadAssertion(token, tmp_can_reach_end_stack_);
604-
if (can_reach_end && !is_root_rule && lookahead_result_pair.first &&
605-
prev_matched_size > 0) {
606-
// 1. If the current rule is the root rule (is_root_rule=true), there are no
607-
// uncertain tokens. Not accepted tokens are just rejected.
608-
// 2. If a token cannot pass the lookahead assertion, it is rejected.
609-
if ((!lookahead_result_pair.second) && is_exact_lookahead) {
610-
tmp_accepted_indices_.push_back(i);
611-
} else {
612-
tmp_uncertain_indices_.push_back(i);
613-
// On the subtree, they are all uncertain tokens.
614-
if (lookahead_result_pair.second) {
615-
for (int j = i + 1; j < subtree_nodes_range[i]; ++j) {
616-
tmp_uncertain_indices_.push_back(j);
617-
}
618-
i = subtree_nodes_range[i] - 1; // Skip the subtree nodes.
678+
if (accepted) {
679+
tmp_accepted_indices_.push_back(i);
680+
} else {
681+
auto lookahead_result_pair = IsTokenPassLookaheadAssertion(token, tmp_can_reach_end_stack_);
682+
if (can_reach_end && !is_root_rule && lookahead_result_pair.first && *prev_matched_size > 0) {
683+
// 1. If the current rule is the root rule (is_root_rule=true), there are no
684+
// uncertain tokens. Not accepted tokens are just rejected.
685+
// 2. If a token cannot pass the lookahead assertion, it is rejected.
686+
if ((!lookahead_result_pair.second) && is_exact_lookahead) {
687+
tmp_accepted_indices_.push_back(i);
688+
} else {
689+
tmp_uncertain_indices_.push_back(i);
690+
// On the subtree, they are all uncertain tokens.
691+
if (lookahead_result_pair.second) {
692+
for (int j = i + 1; j < subtree_nodes_range[i]; ++j) {
693+
tmp_uncertain_indices_.push_back(j);
619694
}
695+
i = subtree_nodes_range[i] - 1; // Skip the subtree nodes.
620696
}
621-
} else {
622-
tmp_rejected_indices_.push_back(i);
623-
last_rejected_range = subtree_nodes_range[i];
624-
fill_reject_indices =
625-
tmp_rejected_indices_.size() >= AdaptiveTokenMask::USE_BITSET_THRESHOLD
626-
? false
627-
: fill_reject_indices;
628697
}
629-
}
630-
}
631-
if (interval_idx != possible_intervals.size() - 1 && fill_reject_indices) {
632-
const auto& next_interval = possible_intervals[interval_idx + 1];
633-
for (int i = interval.second; i < next_interval.first; ++i) {
698+
} else {
634699
tmp_rejected_indices_.push_back(i);
635-
}
636-
if (tmp_rejected_indices_.size() >= AdaptiveTokenMask::USE_BITSET_THRESHOLD) {
637-
fill_reject_indices = false;
700+
*last_rejected_range = subtree_nodes_range[i];
701+
fill_reject_indices =
702+
tmp_rejected_indices_.size() >= AdaptiveTokenMask::USE_BITSET_THRESHOLD
703+
? false
704+
: fill_reject_indices;
638705
}
639706
}
640707
}
641-
642-
// Rollback the last matched part.
643-
PopLastStates(prev_matched_size);
644-
return fill_reject_indices;
708+
return false;
645709
}
646710

647711
/******************* GrammarCompilerNoCache *******************/

0 commit comments

Comments
 (0)