@@ -64,7 +64,7 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser {
   * \returns True if the rejected indices are filled as usual, False otherwise.
   * It's used to determine which construction function will be used.
   */
-  bool GetTokenMaskWithFirstCharacterCheck(
+  bool GetTokenMaskWithFirstCharacterOptimization(
       const TokenizerInfo& tokenizer_info,
       const std::bitset<256>& first_char_mask,
       bool is_root_rule
@@ -93,6 +93,20 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser {
   */
   std::bitset<256> GetFirstCharacterMask();

+  /*!
+   * \brief Check all intervals of possible tokens.
+   * \note The possible tokens are divided into accepted, rejected and uncertain tokens.
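+   * \returns True if the rejected indices are filled, False otherwise.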
+   */
+  bool CheckAllPossibleTokens(
+      const TokenizerInfo& tokenizer_info,
+      const std::vector<std::pair<int32_t, int32_t>>& possible_intervals,
+      bool speculative_calculation,
+      const std::bitset<256>& speculative_mask,
+      const std::optional<const DynamicBitset*>& definite_accepted_bitset,
+      bool is_root_rule,
+      bool fill_reject_indices
+  );
+
   // The id of the initial rule.
   int32_t init_rule_id;

@@ -322,12 +336,11 @@ std::pair<bool, std::bitset<256>> GrammarMatcherForTokenMaskCache::GetSpeculativ
   return {can_be_applied, speculative_mask};
 }

-bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterCheck(
+bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterOptimization(
     const TokenizerInfo& tokenizer_info, const std::bitset<256>& first_char_mask, bool is_root_rule
 ) {
   // The pair (a, b) means [a, b). Initialize the possible intervals.
   const auto& sorted_decoded_vocab = tokenizer_info.GetSortedDecodedVocab();
-  const auto& subtree_nodes_range = tokenizer_info.GetTrieSubtreeNodesRange();
   std::vector<std::pair<int32_t, int32_t>> possible_intervals;
   int possible_token_num =
       GetPossibleTokenIntervals(sorted_decoded_vocab, first_char_mask, possible_intervals);
@@ -357,9 +370,6 @@ bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterCheck(
         GetSpeculativeCalculation(sorted_decoded_vocab);
   }

-  int prev_matched_size = 0;
-  int last_rejected_range = 0;
-  const bool& is_exact_lookahead = grammar_->GetRule(init_rule_id).is_exact_lookahead;
   std::optional<const DynamicBitset*> definite_accepted_bitset = std::nullopt;
   const bool is_tag_dispatch_rule =
       grammar_->GetGrammarExpr(grammar_->GetRule(init_rule_id).body_expr_id).type ==
@@ -369,6 +379,125 @@ bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterCheck(
     definite_accepted_bitset = &tag_dispatch_rule_id_to_second_slicing_bitset.at(init_rule_id);
   }

+  fill_reject_indices = CheckAllPossibleTokens(
+      tokenizer_info,
+      possible_intervals,
+      speculative_calculation,
+      speculative_mask,
+      definite_accepted_bitset,
+      is_root_rule,
+      fill_reject_indices
+  );
+
+  if (possible_intervals.back().second != static_cast<int>(sorted_decoded_vocab.size()) &&
+      fill_reject_indices) {
+    // If the last interval is not closed, we need to reject the remaining tokens.
+    for (int i = possible_intervals.back().second;
+         i < static_cast<int>(sorted_decoded_vocab.size());
+         ++i) {
+      tmp_rejected_indices_.push_back(i);
+    }
+  }
+
+  return fill_reject_indices;
+}
+
+AdaptiveTokenMask GrammarMatcherForTokenMaskCache::GetAdaptiveTokenMask(
+    const TokenizerInfo& tokenizer_info, bool is_root_rule
+) {
+  const auto& sorted_decoded_vocab = tokenizer_info.GetSortedDecodedVocab();
+  const int vocab_size = tokenizer_info.GetVocabSize();
+  tmp_accepted_indices_.clear();
+  tmp_rejected_indices_.clear();
+  tmp_uncertain_indices_.clear();
+  // For every character in the current token, store whether it is possible to reach the end of
+  // the rule after matching up to this character. The values are kept on a stack for later
+  // rollback.
+  tmp_can_reach_end_stack_.push_back(false);
+  tmp_can_reach_end_prefix_or_stack_.push_back(false);
+  std::bitset<256> first_character_mask = GetFirstCharacterMask();
+  bool rejected_indices_are_filled = GetTokenMaskWithFirstCharacterOptimization(
+      tokenizer_info, first_character_mask, is_root_rule
+  );
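+  // The returned flag decides which AdaptiveTokenMask constructor is used below.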
+  if (rejected_indices_are_filled) {
+    return AdaptiveTokenMask(
+        vocab_size,
+        sorted_decoded_vocab,
+        tmp_accepted_indices_,
+        tmp_rejected_indices_,
+        tmp_uncertain_indices_
+    );
+  } else {
+    return AdaptiveTokenMask(
+        vocab_size, sorted_decoded_vocab, tmp_accepted_indices_, tmp_uncertain_indices_
+    );
+  }
+}
+
+std::bitset<256> GrammarMatcherForTokenMaskCache::GetFirstCharacterMask() {
+  std::bitset<256> first_character_mask;
+  const auto& sequence = grammar_->GetGrammarExpr(initial_state.sequence_id);
+  if (!grammar_->per_rule_fsms[init_rule_id].has_value()) {
+    const auto& sub_sequence = grammar_->GetGrammarExpr(sequence[initial_state.element_id]);
+    switch (sub_sequence.type) {
+      case Grammar::Impl::GrammarExprType::kByteString: {
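+        // A byte string admits exactly one next byte: the byte at the current position.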
+        first_character_mask[sub_sequence[initial_state.sub_element_id]] = true;
+        break;
+      }
+      case xgrammar::Grammar::Impl::GrammarExprType::kCharacterClass:
+      case xgrammar::Grammar::Impl::GrammarExprType::kCharacterClassStar: {
+        if (initial_state.sub_element_id == 0) {
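+          // At the first byte of the character class, build the mask from its character
+          // ranges; a negated class inverts the mask afterwards.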
+          bool is_negative = sub_sequence[0];
+          for (int i = 1; i < sub_sequence.size(); i += 2) {
+            int left_char = static_cast<uint8_t>(sub_sequence[i]);
+            int right_char = static_cast<uint8_t>(sub_sequence[i + 1]);
+            for (int c = left_char; c <= right_char; ++c) {
+              first_character_mask[c] = true;
+            }
+          }
+          if (is_negative) {
+            first_character_mask = ~first_character_mask;
+          }
+          break;
+        }
+        // Otherwise, we are in the middle of a multi-byte UTF-8 character, so the next byte
+        // must be a continuation byte in the range 0x80-0xBF.
+        for (size_t i = 0x80; i < 0xC0; ++i) {
+          first_character_mask[i] = true;
+        }
+        break;
+      }
+      default: {
+        XGRAMMAR_LOG(FATAL) << "Unsupported grammar expr type: " << static_cast<int>(sequence.type);
+      }
+    }
+  } else {
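+    // The rule is compiled to an FSM: collect the character ranges on the outgoing edges of
+    // the current FSM state.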
+    const auto& fsm = grammar_->per_rule_fsms[init_rule_id].value();
+    const auto& edges = fsm.GetFsm().GetEdges(initial_state.element_id);
+    for (const auto& edge : edges) {
+      if (edge.IsCharRange()) {
+        for (int c = edge.min; c <= edge.max; ++c) {
+          first_character_mask[c] = true;
+        }
+      }
+    }
+  }
+  return first_character_mask;
+}
+
+bool GrammarMatcherForTokenMaskCache::CheckAllPossibleTokens(
+    const TokenizerInfo& tokenizer_info,
+    const std::vector<std::pair<int32_t, int32_t>>& possible_intervals,
+    bool speculative_calculation,
+    const std::bitset<256>& speculative_mask,
+    const std::optional<const DynamicBitset*>& definite_accepted_bitset,
+    bool is_root_rule,
+    bool fill_reject_indices
+) {
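+  // Walk every interval of possible tokens and classify each token as accepted, rejected or
+  // uncertain.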
+  int prev_matched_size = 0;
+  int last_rejected_range = 0;
+  const auto& sorted_decoded_vocab = tokenizer_info.GetSortedDecodedVocab();
+  const auto& subtree_nodes_range = tokenizer_info.GetTrieSubtreeNodesRange();
+  const bool& is_exact_lookahead = grammar_->GetRule(init_rule_id).is_exact_lookahead;
   const std::string* prev_token = nullptr;
   for (size_t interval_idx = 0; interval_idx < possible_intervals.size(); ++interval_idx) {
     const auto& interval = possible_intervals[interval_idx];
@@ -512,101 +641,9 @@ bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterCheck(

   // Rollback the last matched part.
   PopLastStates(prev_matched_size);
-
-  if (possible_intervals.back().second != static_cast<int>(sorted_decoded_vocab.size()) &&
-      fill_reject_indices) {
-    // If the last interval is not closed, we need to reject the rest tokens.
-    for (int i = possible_intervals.back().second;
-         i < static_cast<int>(sorted_decoded_vocab.size());
-         ++i) {
-      tmp_rejected_indices_.push_back(i);
-    }
-  }
-
   return fill_reject_indices;
 }

-AdaptiveTokenMask GrammarMatcherForTokenMaskCache::GetAdaptiveTokenMask(
-    const TokenizerInfo& tokenizer_info, bool is_root_rule
-) {
-  const auto& sorted_decoded_vocab = tokenizer_info.GetSortedDecodedVocab();
-  const int vocab_size = tokenizer_info.GetVocabSize();
-  tmp_accepted_indices_.clear();
-  tmp_rejected_indices_.clear();
-  tmp_uncertain_indices_.clear();
-  // For every character in the current token, stores whether it is possible to reach the end of
-  // the rule when matching until this character. Store it in a stack for later rollback.
-  tmp_can_reach_end_stack_.push_back(false);
-  tmp_can_reach_end_prefix_or_stack_.push_back(false);
-  std::bitset<256> first_character_mask = GetFirstCharacterMask();
-  bool rejected_indices_are_filled =
-      GetTokenMaskWithFirstCharacterCheck(tokenizer_info, first_character_mask, is_root_rule);
-  if (rejected_indices_are_filled) {
-    return AdaptiveTokenMask(
-        vocab_size,
-        sorted_decoded_vocab,
-        tmp_accepted_indices_,
-        tmp_rejected_indices_,
-        tmp_uncertain_indices_
-    );
-  } else {
-    return AdaptiveTokenMask(
-        vocab_size, sorted_decoded_vocab, tmp_accepted_indices_, tmp_uncertain_indices_
-    );
-  }
-}
-
-std::bitset<256> GrammarMatcherForTokenMaskCache::GetFirstCharacterMask() {
-  std::bitset<256> first_character_mask;
-  const auto& sequence = grammar_->GetGrammarExpr(initial_state.sequence_id);
-  if (!grammar_->per_rule_fsms[init_rule_id].has_value()) {
-    const auto& sub_sequence = grammar_->GetGrammarExpr(sequence[initial_state.element_id]);
-    switch (sub_sequence.type) {
-      case Grammar::Impl::GrammarExprType::kByteString: {
-        first_character_mask[sub_sequence[initial_state.sub_element_id]] = true;
-        break;
-      }
-      case xgrammar::Grammar::Impl::GrammarExprType::kCharacterClass:
-      case xgrammar::Grammar::Impl::GrammarExprType::kCharacterClassStar: {
-        if (initial_state.sub_element_id == 0) {
-          bool is_negative = sub_sequence[0];
-          for (int i = 1; i < sub_sequence.size(); i += 2) {
-            int left_char = static_cast<uint8_t>(sub_sequence[i]);
-            int right_char = static_cast<uint8_t>(sub_sequence[i + 1]);
-            for (int c = left_char; c <= right_char; ++c) {
-              first_character_mask[c] = true;
-            }
-          }
-          if (is_negative) {
-            first_character_mask = ~first_character_mask;
-          }
-          break;
-        }
-        // Otherwise, it's matching a UTF-8 character. We can optimize the matching process
-        // here.
-        for (size_t i = 0x80; i < 0xC0; ++i) {
-          first_character_mask[i] = true;
-        }
-        break;
-      }
-      default: {
-        XGRAMMAR_LOG(FATAL) << "Unsupported grammar expr type: " << static_cast<int>(sequence.type);
-      }
-    }
-  } else {
-    const auto& fsm = grammar_->per_rule_fsms[init_rule_id].value();
-    const auto& edges = fsm.GetFsm().GetEdges(initial_state.element_id);
-    for (const auto& edge : edges) {
-      if (edge.IsCharRange()) {
-        for (int c = edge.min; c <= edge.max; ++c) {
-          first_character_mask[c] = true;
-        }
-      }
-    }
-  }
-  return first_character_mask;
-}
-
 /****************** GrammarCompilerNoCache *******************/

 /*!