
Commit c10151c

refactor: extract CheckAllPossibleTokens from the token mask computation and rename GetTokenMaskWithFirstCharacterCheck to GetTokenMaskWithFirstCharacterOptimization.
Signed-off-by: Yuchuan <[email protected]>
1 parent d85138d commit c10151c

File tree: 1 file changed, +135 -98 lines changed

cpp/grammar_compiler.cc

Lines changed: 135 additions & 98 deletions
@@ -64,7 +64,7 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser {
    * \returns True if the rejected indices are filled as usual, False otherwise.
    * It's used to determine which construction function will be used.
    */
-  bool GetTokenMaskWithFirstCharacterCheck(
+  bool GetTokenMaskWithFirstCharacterOptimization(
       const TokenizerInfo& tokenizer_info,
       const std::bitset<256>& first_char_mask,
       bool is_root_rule
@@ -93,6 +93,20 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser {
    */
   std::bitset<256> GetFirstCharacterMask();
 
+  /*!
+   * \brief Check all intervals for possible tokens.
+   * \note All the possible tokens will be divided into accepted, rejected and uncertain tokens.
+   */
+  bool CheckAllPossibleTokens(
+      const TokenizerInfo& tokenizer_info,
+      const std::vector<std::pair<int32_t, int32_t>>& possible_intervals,
+      bool speculative_calculation,
+      const std::bitset<256>& speculative_mask,
+      const std::optional<const DynamicBitset*>& definite_accepted_bitset,
+      bool is_root_rule,
+      bool fill_reject_indices
+  );
+
   // The id of the initial rule.
   int32_t init_rule_id;
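
The note above says CheckAllPossibleTokens splits the candidate vocabulary three ways: accepted, rejected, and uncertain. As a rough standalone illustration of that split (PartitionTokens, TokenVerdict, and the classify callback below are hypothetical names, not part of this file; the real matcher derives the verdict from its Earley-parser state rather than from a callback):

#include <cstdint>
#include <functional>
#include <string>
#include <vector>

// Hypothetical three-way verdict for one token.
enum class TokenVerdict { kAccepted, kRejected, kUncertain };

struct TokenPartition {
  std::vector<int32_t> accepted;
  std::vector<int32_t> rejected;
  std::vector<int32_t> uncertain;
};

// Sketch: bucket every token id of the vocabulary into exactly one of the
// three lists, mirroring the accepted/rejected/uncertain split described
// in the \note above.
TokenPartition PartitionTokens(
    const std::vector<std::string>& vocab,
    const std::function<TokenVerdict(const std::string&)>& classify
) {
  TokenPartition out;
  for (int32_t id = 0; id < static_cast<int32_t>(vocab.size()); ++id) {
    switch (classify(vocab[id])) {
      case TokenVerdict::kAccepted: out.accepted.push_back(id); break;
      case TokenVerdict::kRejected: out.rejected.push_back(id); break;
      case TokenVerdict::kUncertain: out.uncertain.push_back(id); break;
    }
  }
  return out;
}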

@@ -322,12 +336,11 @@ std::pair<bool, std::bitset<256>> GrammarMatcherForTokenMaskCache::GetSpeculativ
   return {can_be_applied, speculative_mask};
 }
 
-bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterCheck(
+bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterOptimization(
     const TokenizerInfo& tokenizer_info, const std::bitset<256>& first_char_mask, bool is_root_rule
 ) {
   // the pair (a, b) means [a, b). Intialize the possible intervals.
   const auto& sorted_decoded_vocab = tokenizer_info.GetSortedDecodedVocab();
-  const auto& subtree_nodes_range = tokenizer_info.GetTrieSubtreeNodesRange();
   std::vector<std::pair<int32_t, int32_t>> possible_intervals;
   int possible_token_num =
       GetPossibleTokenIntervals(sorted_decoded_vocab, first_char_mask, possible_intervals);
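
GetPossibleTokenIntervals itself is not shown in this hunk; what the call above relies on is that the decoded vocabulary is sorted by byte value, so all tokens sharing an allowed first byte form a contiguous index range. A minimal sketch of that idea, assuming every token is non-empty (PossibleTokenIntervals is a hypothetical name, not the function used here):

#include <bitset>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// Sketch: given a vocabulary sorted by raw byte value and a mask of allowed
// first bytes, return half-open index intervals [begin, end) covering exactly
// the tokens whose first byte is allowed. Runs over adjacent allowed first
// bytes are merged into a single interval.
std::vector<std::pair<int32_t, int32_t>> PossibleTokenIntervals(
    const std::vector<std::string>& sorted_vocab, const std::bitset<256>& first_char_mask
) {
  std::vector<std::pair<int32_t, int32_t>> intervals;
  const int32_t n = static_cast<int32_t>(sorted_vocab.size());
  int32_t i = 0;
  while (i < n) {
    // Find the contiguous run [i, j) of tokens sharing the current first byte.
    const uint8_t first = static_cast<uint8_t>(sorted_vocab[i][0]);
    int32_t j = i + 1;
    while (j < n && static_cast<uint8_t>(sorted_vocab[j][0]) == first) {
      ++j;
    }
    if (first_char_mask[first]) {
      if (!intervals.empty() && intervals.back().second == i) {
        intervals.back().second = j;  // extend the previous interval
      } else {
        intervals.emplace_back(i, j);
      }
    }
    i = j;
  }
  return intervals;
}
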
@@ -357,9 +370,6 @@ bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterCheck(
     GetSpeculativeCalculation(sorted_decoded_vocab);
   }
 
-  int prev_matched_size = 0;
-  int last_rejected_range = 0;
-  const bool& is_exact_lookahead = grammar_->GetRule(init_rule_id).is_exact_lookahead;
   std::optional<const DynamicBitset*> definite_accepted_bitset = std::nullopt;
   const bool is_tag_dispatch_rule =
       grammar_->GetGrammarExpr(grammar_->GetRule(init_rule_id).body_expr_id).type ==
@@ -369,6 +379,125 @@ bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterCheck(
     definite_accepted_bitset = &tag_dispatch_rule_id_to_second_slicing_bitset.at(init_rule_id);
   }
 
+  fill_reject_indices = CheckAllPossibleTokens(
+      tokenizer_info,
+      possible_intervals,
+      speculative_calculation,
+      speculative_mask,
+      definite_accepted_bitset,
+      is_root_rule,
+      fill_reject_indices
+  );
+
+  if (possible_intervals.back().second != static_cast<int>(sorted_decoded_vocab.size()) &&
+      fill_reject_indices) {
+    // If the last interval is not closed, we need to reject the rest tokens.
+    for (int i = possible_intervals.back().second;
+         i < static_cast<int>(sorted_decoded_vocab.size());
+         ++i) {
+      tmp_rejected_indices_.push_back(i);
+    }
+  }
+
+  return fill_reject_indices;
+}
+
+AdaptiveTokenMask GrammarMatcherForTokenMaskCache::GetAdaptiveTokenMask(
+    const TokenizerInfo& tokenizer_info, bool is_root_rule
+) {
+  const auto& sorted_decoded_vocab = tokenizer_info.GetSortedDecodedVocab();
+  const int vocab_size = tokenizer_info.GetVocabSize();
+  tmp_accepted_indices_.clear();
+  tmp_rejected_indices_.clear();
+  tmp_uncertain_indices_.clear();
+  // For every character in the current token, stores whether it is possible to reach the end of
+  // the rule when matching until this character. Store it in a stack for later rollback.
+  tmp_can_reach_end_stack_.push_back(false);
+  tmp_can_reach_end_prefix_or_stack_.push_back(false);
+  std::bitset<256> first_character_mask = GetFirstCharacterMask();
+  bool rejected_indices_are_filled = GetTokenMaskWithFirstCharacterOptimization(
+      tokenizer_info, first_character_mask, is_root_rule
+  );
+  if (rejected_indices_are_filled) {
+    return AdaptiveTokenMask(
+        vocab_size,
+        sorted_decoded_vocab,
+        tmp_accepted_indices_,
+        tmp_rejected_indices_,
+        tmp_uncertain_indices_
+    );
+  } else {
+    return AdaptiveTokenMask(
+        vocab_size, sorted_decoded_vocab, tmp_accepted_indices_, tmp_uncertain_indices_
+    );
+  }
+}
+
+std::bitset<256> GrammarMatcherForTokenMaskCache::GetFirstCharacterMask() {
+  std::bitset<256> first_character_mask;
+  const auto& sequence = grammar_->GetGrammarExpr(initial_state.sequence_id);
+  if (!grammar_->per_rule_fsms[init_rule_id].has_value()) {
+    const auto& sub_sequence = grammar_->GetGrammarExpr(sequence[initial_state.element_id]);
+    switch (sub_sequence.type) {
+      case Grammar::Impl::GrammarExprType::kByteString: {
+        first_character_mask[sub_sequence[initial_state.sub_element_id]] = true;
+        break;
+      }
+      case xgrammar::Grammar::Impl::GrammarExprType::kCharacterClass:
+      case xgrammar::Grammar::Impl::GrammarExprType::kCharacterClassStar: {
+        if (initial_state.sub_element_id == 0) {
+          bool is_negative = sub_sequence[0];
+          for (int i = 1; i < sub_sequence.size(); i += 2) {
+            int left_char = static_cast<uint8_t>(sub_sequence[i]);
+            int right_char = static_cast<uint8_t>(sub_sequence[i + 1]);
+            for (int c = left_char; c <= right_char; ++c) {
+              first_character_mask[c] = true;
+            }
+          }
+          if (is_negative) {
+            first_character_mask = ~first_character_mask;
+          }
+          break;
+        }
+        // Otherwise, it's matching a UTF-8 character. We can optimize the matching process
+        // here.
+        for (size_t i = 0x80; i < 0xC0; ++i) {
+          first_character_mask[i] = true;
+        }
+        break;
+      }
+      default: {
+        XGRAMMAR_LOG(FATAL) << "Unsupported grammar expr type: " << static_cast<int>(sequence.type);
+      }
+    }
+  } else {
+    const auto& fsm = grammar_->per_rule_fsms[init_rule_id].value();
+    const auto& edges = fsm.GetFsm().GetEdges(initial_state.element_id);
+    for (const auto& edge : edges) {
+      if (edge.IsCharRange()) {
+        for (int c = edge.min; c <= edge.max; ++c) {
+          first_character_mask[c] = true;
+        }
+      }
+    }
+  }
+  return first_character_mask;
+}
+
+bool GrammarMatcherForTokenMaskCache::CheckAllPossibleTokens(
+    const TokenizerInfo& tokenizer_info,
+    const std::vector<std::pair<int32_t, int32_t>>& possible_intervals,
+    bool speculative_calculation,
+    const std::bitset<256>& speculative_mask,
+    const std::optional<const DynamicBitset*>& definite_accepted_bitset,
+    bool is_root_rule,
+    bool fill_reject_indices
+) {
+  int prev_matched_size = 0;
+  int last_rejected_range = 0;
+  const auto& sorted_decoded_vocab = tokenizer_info.GetSortedDecodedVocab();
+  const auto& subtree_nodes_range = tokenizer_info.GetTrieSubtreeNodesRange();
+  const bool& is_exact_lookahead = grammar_->GetRule(init_rule_id).is_exact_lookahead;
   const std::string* prev_token = nullptr;
   for (size_t interval_idx = 0; interval_idx < possible_intervals.size(); ++interval_idx) {
     const auto& interval = possible_intervals[interval_idx];
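
The character-class branch of GetFirstCharacterMask added above expands inclusive byte ranges into a std::bitset<256> and flips the whole mask when the class is negated. A self-contained sketch of just that transformation (CharClassToMask is a hypothetical helper, not part of the diff):

#include <bitset>
#include <cstdint>
#include <utility>
#include <vector>

// Sketch: expand a character class given as inclusive byte ranges into a
// 256-bit mask of allowed first bytes, flipping the mask when the class is
// negated, as in the kCharacterClass branch above.
std::bitset<256> CharClassToMask(
    const std::vector<std::pair<uint8_t, uint8_t>>& ranges, bool is_negative
) {
  std::bitset<256> mask;
  for (const auto& range : ranges) {
    for (int c = range.first; c <= range.second; ++c) {
      mask[c] = true;
    }
  }
  return is_negative ? ~mask : mask;
}

// Usage: [a-z0-9] -> CharClassToMask({{'a', 'z'}, {'0', '9'}}, false);
//        [^a-z]   -> CharClassToMask({{'a', 'z'}}, true).
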
@@ -512,101 +641,9 @@ bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterCheck(
 
   // Rollback the last matched part.
   PopLastStates(prev_matched_size);
-
-  if (possible_intervals.back().second != static_cast<int>(sorted_decoded_vocab.size()) &&
-      fill_reject_indices) {
-    // If the last interval is not closed, we need to reject the rest tokens.
-    for (int i = possible_intervals.back().second;
-         i < static_cast<int>(sorted_decoded_vocab.size());
-         ++i) {
-      tmp_rejected_indices_.push_back(i);
-    }
-  }
-
   return fill_reject_indices;
 }
 
-AdaptiveTokenMask GrammarMatcherForTokenMaskCache::GetAdaptiveTokenMask(
-    const TokenizerInfo& tokenizer_info, bool is_root_rule
-) {
-  const auto& sorted_decoded_vocab = tokenizer_info.GetSortedDecodedVocab();
-  const int vocab_size = tokenizer_info.GetVocabSize();
-  tmp_accepted_indices_.clear();
-  tmp_rejected_indices_.clear();
-  tmp_uncertain_indices_.clear();
-  // For every character in the current token, stores whether it is possible to reach the end of
-  // the rule when matching until this character. Store it in a stack for later rollback.
-  tmp_can_reach_end_stack_.push_back(false);
-  tmp_can_reach_end_prefix_or_stack_.push_back(false);
-  std::bitset<256> first_character_mask = GetFirstCharacterMask();
-  bool rejected_indices_are_filled =
-      GetTokenMaskWithFirstCharacterCheck(tokenizer_info, first_character_mask, is_root_rule);
-  if (rejected_indices_are_filled) {
-    return AdaptiveTokenMask(
-        vocab_size,
-        sorted_decoded_vocab,
-        tmp_accepted_indices_,
-        tmp_rejected_indices_,
-        tmp_uncertain_indices_
-    );
-  } else {
-    return AdaptiveTokenMask(
-        vocab_size, sorted_decoded_vocab, tmp_accepted_indices_, tmp_uncertain_indices_
-    );
-  }
-}
-
-std::bitset<256> GrammarMatcherForTokenMaskCache::GetFirstCharacterMask() {
-  std::bitset<256> first_character_mask;
-  const auto& sequence = grammar_->GetGrammarExpr(initial_state.sequence_id);
-  if (!grammar_->per_rule_fsms[init_rule_id].has_value()) {
-    const auto& sub_sequence = grammar_->GetGrammarExpr(sequence[initial_state.element_id]);
-    switch (sub_sequence.type) {
-      case Grammar::Impl::GrammarExprType::kByteString: {
-        first_character_mask[sub_sequence[initial_state.sub_element_id]] = true;
-        break;
-      }
-      case xgrammar::Grammar::Impl::GrammarExprType::kCharacterClass:
-      case xgrammar::Grammar::Impl::GrammarExprType::kCharacterClassStar: {
-        if (initial_state.sub_element_id == 0) {
-          bool is_negative = sub_sequence[0];
-          for (int i = 1; i < sub_sequence.size(); i += 2) {
-            int left_char = static_cast<uint8_t>(sub_sequence[i]);
-            int right_char = static_cast<uint8_t>(sub_sequence[i + 1]);
-            for (int c = left_char; c <= right_char; ++c) {
-              first_character_mask[c] = true;
-            }
-          }
-          if (is_negative) {
-            first_character_mask = ~first_character_mask;
-          }
-          break;
-        }
-        // Otherwise, it's matching a UTF-8 character. We can optimize the matching process
-        // here.
-        for (size_t i = 0x80; i < 0xC0; ++i) {
-          first_character_mask[i] = true;
-        }
-        break;
-      }
-      default: {
-        XGRAMMAR_LOG(FATAL) << "Unsupported grammar expr type: " << static_cast<int>(sequence.type);
-      }
-    }
-  } else {
-    const auto& fsm = grammar_->per_rule_fsms[init_rule_id].value();
-    const auto& edges = fsm.GetFsm().GetEdges(initial_state.element_id);
-    for (const auto& edge : edges) {
-      if (edge.IsCharRange()) {
-        for (int c = edge.min; c <= edge.max; ++c) {
-          first_character_mask[c] = true;
-        }
-      }
-    }
-  }
-  return first_character_mask;
-}
-
 /******************* GrammarCompilerNoCache *******************/
 
 /*!
