Skip to content

Commit a4fb64c

Browse files
committed
update first mask.
Signed-off-by: Yuchuan <[email protected]>
1 parent 59fa487 commit a4fb64c

File tree

1 file changed

+33
-22
lines changed

1 file changed

+33
-22
lines changed

cpp/grammar_compiler.cc

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,13 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser {
8686
const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab
8787
);
8888

89+
/*!
90+
* \brief Get the first character mask for the initial state, i.e., which characters can be
91+
* accepted as the first character in the initial state.
92+
* \return The first character mask.
93+
*/
94+
std::bitset<256> GetFirstCharacterMask();
95+
8996
// The id of the initial rule.
9097
int32_t init_rule_id;
9198

@@ -371,10 +378,9 @@ bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterCheck(
371378
if (i < last_rejected_range) {
372379
if (fill_reject_indices) {
373380
tmp_rejected_indices_.push_back(i);
374-
fill_reject_indices =
375-
tmp_rejected_indices_.size() >= AdaptiveTokenMask::USE_BITSET_THRESHOLD
376-
? false
377-
: fill_reject_indices;
381+
if (tmp_rejected_indices_.size() >= AdaptiveTokenMask::USE_BITSET_THRESHOLD) {
382+
fill_reject_indices = false;
383+
}
378384
} else {
379385
i = last_rejected_range - 1;
380386
}
@@ -498,9 +504,9 @@ bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterCheck(
498504
for (int i = interval.second; i < next_interval.first; ++i) {
499505
tmp_rejected_indices_.push_back(i);
500506
}
501-
fill_reject_indices = tmp_rejected_indices_.size() >= AdaptiveTokenMask::USE_BITSET_THRESHOLD
502-
? false
503-
: fill_reject_indices;
507+
if (tmp_rejected_indices_.size() >= AdaptiveTokenMask::USE_BITSET_THRESHOLD) {
508+
fill_reject_indices = false;
509+
}
504510
}
505511
}
506512

@@ -532,6 +538,25 @@ AdaptiveTokenMask GrammarMatcherForTokenMaskCache::GetAdaptiveTokenMask(
532538
// the rule when matching until this character. Store it in a stack for later rollback.
533539
tmp_can_reach_end_stack_.push_back(false);
534540
tmp_can_reach_end_prefix_or_stack_.push_back(false);
541+
std::bitset<256> first_character_mask = GetFirstCharacterMask();
542+
bool rejected_indices_are_filled =
543+
GetTokenMaskWithFirstCharacterCheck(tokenizer_info, first_character_mask, is_root_rule);
544+
if (rejected_indices_are_filled) {
545+
return AdaptiveTokenMask(
546+
vocab_size,
547+
sorted_decoded_vocab,
548+
tmp_accepted_indices_,
549+
tmp_rejected_indices_,
550+
tmp_uncertain_indices_
551+
);
552+
} else {
553+
return AdaptiveTokenMask(
554+
vocab_size, sorted_decoded_vocab, tmp_accepted_indices_, tmp_uncertain_indices_
555+
);
556+
}
557+
}
558+
559+
std::bitset<256> GrammarMatcherForTokenMaskCache::GetFirstCharacterMask() {
535560
std::bitset<256> first_character_mask;
536561
const auto& sequence = grammar_->GetGrammarExpr(initial_state.sequence_id);
537562
if (!grammar_->per_rule_fsms[init_rule_id].has_value()) {
@@ -579,21 +604,7 @@ AdaptiveTokenMask GrammarMatcherForTokenMaskCache::GetAdaptiveTokenMask(
579604
}
580605
}
581606
}
582-
bool rejected_indices_are_filled =
583-
GetTokenMaskWithFirstCharacterCheck(tokenizer_info, first_character_mask, is_root_rule);
584-
if (rejected_indices_are_filled) {
585-
return AdaptiveTokenMask(
586-
vocab_size,
587-
sorted_decoded_vocab,
588-
tmp_accepted_indices_,
589-
tmp_rejected_indices_,
590-
tmp_uncertain_indices_
591-
);
592-
} else {
593-
return AdaptiveTokenMask(
594-
vocab_size, sorted_decoded_vocab, tmp_accepted_indices_, tmp_uncertain_indices_
595-
);
596-
}
607+
return first_character_mask;
597608
}
598609

599610
/******************* GrammarCompilerNoCache *******************/

0 commit comments

Comments
 (0)