@@ -86,6 +86,13 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser {
8686 const std::vector<std::pair<int32_t , std::string>>& sorted_decoded_vocab
8787 );
8888
89+ /*!
90+ * \brief Get the first character mask for the initial state, i.e., the set of characters that
91+ * can be accepted as the first character in the initial state.
92+ * \return The first character mask.
93+ */
94+ std::bitset<256 > GetFirstCharacterMask ();
95+
8996 // The id of the initial rule.
9097 int32_t init_rule_id;
9198
@@ -371,10 +378,9 @@ bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterCheck(
371378 if (i < last_rejected_range) {
372379 if (fill_reject_indices) {
373380 tmp_rejected_indices_.push_back (i);
374- fill_reject_indices =
375- tmp_rejected_indices_.size () >= AdaptiveTokenMask::USE_BITSET_THRESHOLD
376- ? false
377- : fill_reject_indices;
381+ if (tmp_rejected_indices_.size () >= AdaptiveTokenMask::USE_BITSET_THRESHOLD) {
382+ fill_reject_indices = false ;
383+ }
378384 } else {
379385 i = last_rejected_range - 1 ;
380386 }
@@ -498,9 +504,9 @@ bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterCheck(
498504 for (int i = interval.second ; i < next_interval.first ; ++i) {
499505 tmp_rejected_indices_.push_back (i);
500506 }
501- fill_reject_indices = tmp_rejected_indices_.size () >= AdaptiveTokenMask::USE_BITSET_THRESHOLD
502- ? false
503- : fill_reject_indices;
507+ if ( tmp_rejected_indices_.size () >= AdaptiveTokenMask::USE_BITSET_THRESHOLD) {
508+ fill_reject_indices = false ;
509+ }
504510 }
505511 }
506512
@@ -532,6 +538,25 @@ AdaptiveTokenMask GrammarMatcherForTokenMaskCache::GetAdaptiveTokenMask(
532538 // the rule when matching until this character. Store it in a stack for later rollback.
533539 tmp_can_reach_end_stack_.push_back (false );
534540 tmp_can_reach_end_prefix_or_stack_.push_back (false );
541+ std::bitset<256 > first_character_mask = GetFirstCharacterMask ();
542+ bool rejected_indices_are_filled =
543+ GetTokenMaskWithFirstCharacterCheck (tokenizer_info, first_character_mask, is_root_rule);
544+ if (rejected_indices_are_filled) {
545+ return AdaptiveTokenMask (
546+ vocab_size,
547+ sorted_decoded_vocab,
548+ tmp_accepted_indices_,
549+ tmp_rejected_indices_,
550+ tmp_uncertain_indices_
551+ );
552+ } else {
553+ return AdaptiveTokenMask (
554+ vocab_size, sorted_decoded_vocab, tmp_accepted_indices_, tmp_uncertain_indices_
555+ );
556+ }
557+ }
558+
559+ std::bitset<256 > GrammarMatcherForTokenMaskCache::GetFirstCharacterMask () {
535560 std::bitset<256 > first_character_mask;
536561 const auto & sequence = grammar_->GetGrammarExpr (initial_state.sequence_id );
537562 if (!grammar_->per_rule_fsms [init_rule_id].has_value ()) {
@@ -579,21 +604,7 @@ AdaptiveTokenMask GrammarMatcherForTokenMaskCache::GetAdaptiveTokenMask(
579604 }
580605 }
581606 }
582- bool rejected_indices_are_filled =
583- GetTokenMaskWithFirstCharacterCheck (tokenizer_info, first_character_mask, is_root_rule);
584- if (rejected_indices_are_filled) {
585- return AdaptiveTokenMask (
586- vocab_size,
587- sorted_decoded_vocab,
588- tmp_accepted_indices_,
589- tmp_rejected_indices_,
590- tmp_uncertain_indices_
591- );
592- } else {
593- return AdaptiveTokenMask (
594- vocab_size, sorted_decoded_vocab, tmp_accepted_indices_, tmp_uncertain_indices_
595- );
596- }
607+ return first_character_mask;
597608}
598609
599610/* ****************** GrammarCompilerNoCache *******************/
0 commit comments