Skip to content

Commit c462928

Browse files
committed
finish.
Signed-off-by: Yuchuan <[email protected]>
1 parent b755d00 commit c462928

File tree

1 file changed

+47
-30
lines changed

1 file changed

+47
-30
lines changed

cpp/grammar_compiler.cc

Lines changed: 47 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,19 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser {
161161
const std::optional<const DynamicBitset*>& definite_accepted_bitset
162162
);
163163

164+
/*! \brief Find the common prefix size with the previous token and roll back the matcher
state that lies beyond that prefix.
165+
\param token The current token.
166+
\param prev_token The previous token. When it is nullptr the function does nothing.
167+
\param prev_matched_size In/out. The matched size of the previous token. When prev_token is
not nullptr, it is lowered to the common prefix length if that length is smaller.
168+
\param accepted Out. Set to false when the common prefix is longer than prev_matched_size;
left unchanged in every other case, so the caller must initialize it.
169+
*/
170+
void FindCommonPrefixWithPreviousToken(
171+
const std::string& token,
172+
const std::string*& prev_token,
173+
int* prev_matched_size,
174+
bool* accepted
175+
);
176+
164177
// The id of the initial rule.
165178
int32_t init_rule_id;
166179

@@ -620,6 +633,34 @@ bool GrammarMatcherForTokenMaskCache::ApplySpeculativeCalculation(
620633
return false;
621634
}
622635

636+
// Computes the length of the longest common prefix of `token` and `*prev_token` and compares
// it with `*prev_matched_size`. Does nothing when `prev_token` is nullptr. `*accepted` is only
// ever written with false here; it is never set to true.
void GrammarMatcherForTokenMaskCache::FindCommonPrefixWithPreviousToken(
637+
const std::string& token, const std::string*& prev_token, int* prev_matched_size, bool* accepted
638+
) {
639+
if (prev_token != nullptr) {
640+
// std::mismatch yields the first position where the two strings differ; its distance from
// token.begin() is the common prefix length.
int lcp_len =
641+
std::mismatch(token.begin(), token.end(), prev_token->begin(), prev_token->end()).first -
642+
token.begin();
643+
if (lcp_len > *prev_matched_size) {
644+
// Case 1. The common prefix is rejected by the matcher in the last token. Reject
645+
// directly.
646+
*accepted = false;
647+
} else if (lcp_len < *prev_matched_size) {
648+
// Case 2. The common prefix is shorter than the previous matched size. Rollback
649+
// the non-common part.
650+
PopLastStates(*prev_matched_size - lcp_len);
651+
// Remove the same number of trailing entries from both temporary stacks.
tmp_can_reach_end_stack_.erase(
652+
tmp_can_reach_end_stack_.end() - (*prev_matched_size - lcp_len),
653+
tmp_can_reach_end_stack_.end()
654+
);
655+
tmp_can_reach_end_prefix_or_stack_.erase(
656+
tmp_can_reach_end_prefix_or_stack_.end() - (*prev_matched_size - lcp_len),
657+
tmp_can_reach_end_prefix_or_stack_.end()
658+
);
659+
}
660+
// After this assignment *prev_matched_size never exceeds the common prefix length.
*prev_matched_size = std::min(*prev_matched_size, lcp_len);
661+
}
662+
}
663+
623664
bool GrammarMatcherForTokenMaskCache::CheckTokensInInterval(
624665
const TokenizerInfo& tokenizer_info,
625666
const std::pair<int, int>& interval,
@@ -661,30 +702,7 @@ bool GrammarMatcherForTokenMaskCache::CheckTokensInInterval(
661702
// Many tokens may contain the same prefix, so we will avoid unnecessary matching
662703
// by finding the longest common prefix with the previous token.
663704
bool accepted = true;
664-
if (prev_token != nullptr) {
665-
int lcp_len =
666-
std::mismatch(token.begin(), token.end(), prev_token->begin(), prev_token->end()).first -
667-
token.begin();
668-
if (lcp_len > *prev_matched_size) {
669-
// Case 1. The common prefix is rejected by the matcher in the last token. Reject
670-
// directly.
671-
accepted = false;
672-
} else if (lcp_len < *prev_matched_size) {
673-
// Case 2. The common prefix is shorter than the previous matched size. Rollback
674-
// the non-common part.
675-
PopLastStates(*prev_matched_size - lcp_len);
676-
tmp_can_reach_end_stack_.erase(
677-
tmp_can_reach_end_stack_.end() - (*prev_matched_size - lcp_len),
678-
tmp_can_reach_end_stack_.end()
679-
);
680-
tmp_can_reach_end_prefix_or_stack_.erase(
681-
tmp_can_reach_end_prefix_or_stack_.end() - (*prev_matched_size - lcp_len),
682-
tmp_can_reach_end_prefix_or_stack_.end()
683-
);
684-
}
685-
*prev_matched_size = std::min(*prev_matched_size, lcp_len);
686-
}
687-
705+
FindCommonPrefixWithPreviousToken(token, prev_token, prev_matched_size, &accepted);
688706
prev_token = &token;
689707

690708
if (accepted) {
@@ -727,10 +745,9 @@ bool GrammarMatcherForTokenMaskCache::CheckTokensInInterval(
727745
} else {
728746
tmp_rejected_indices_.push_back(i);
729747
*last_rejected_range = subtree_nodes_range[i];
730-
fill_reject_indices =
731-
tmp_rejected_indices_.size() >= AdaptiveTokenMask::USE_BITSET_THRESHOLD
732-
? false
733-
: fill_reject_indices;
748+
if (tmp_rejected_indices_.size() >= AdaptiveTokenMask::USE_BITSET_THRESHOLD) {
749+
fill_reject_indices = false;
750+
}
734751
}
735752
}
736753
}
@@ -898,8 +915,8 @@ void GrammarCompilerNoCache::GenerateTokenMaskCacheForScannableStates(
898915

899916
auto root_rule_id = compiled_grammar_impl->grammar->GetRootRuleId();
900917

901-
// Iterate through all rules and their scannable states to generate the adaptive token mask, since
902-
// unscanable states will be expanded to the scannable states.
918+
// Iterate through all rules and their scannable states to generate the adaptive token mask,
919+
// since unscannable states will be expanded to the scannable states.
903920
for (int32_t rule_id = 0; rule_id < static_cast<int>(compiled_grammar_impl->grammar->NumRules());
904921
++rule_id) {
905922
auto rule = compiled_grammar_impl->grammar->GetRule(rule_id);

0 commit comments

Comments
 (0)