diff --git a/cpp/grammar_compiler.cc b/cpp/grammar_compiler.cc index 4df5b847..1742a2fc 100644 --- a/cpp/grammar_compiler.cc +++ b/cpp/grammar_compiler.cc @@ -27,6 +27,7 @@ #include "support/thread_safe_cache.h" #include "support/utils.h" #include "xgrammar/grammar.h" +#include "xgrammar/tokenizer_info.h" namespace xgrammar { @@ -52,12 +53,7 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser { * \param is_root_rule Whether to consider the parent rule. If false, there will be * no uncertain tokens. Useful for the root rule. */ - AdaptiveTokenMask GetAdaptiveTokenMask( - size_t vocab_size, - const std::vector>& sorted_decoded_vocab, - const std::vector& subtree_nodes_range, - bool is_root_rule - ); + AdaptiveTokenMask GetAdaptiveTokenMask(const TokenizerInfo& tokenizer_info, bool is_root_rule); /*! * \brief Get the token mask for the given ParserState. @@ -68,10 +64,9 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser { * \returns True if the rejected indices are filled as usual, False otherwise. * It's used to determine which construction function will be used. */ - bool GetTokenMaskWithFirstCharacterCheck( - const std::vector>& sorted_decoded_vocab, + bool GetTokenMaskWithFirstCharacterOptimization( + const TokenizerInfo& tokenizer_info, const std::bitset<256>& first_char_mask, - const std::vector& subtree_nodes_range, bool is_root_rule ); @@ -82,7 +77,8 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser { ); /*! - * \brief Check if speculative calculation will be applied. + * \brief Check if speculative calculation will be applied. It will detect self-recursive-like + * patterns, and utilize them to optimize the token mask calculation. * \return first: whether speculative calculation is applicable. * \return second: part of the first character mask, * which can be used in speculative calculation. @@ -91,21 +87,108 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser { const std::vector>& sorted_decoded_vocab ); - // The id of the initial rule. + /*! + * \brief Get the first character mask for the initial state. i.e. which characters can be + * accepted as the first character in the initial state. + * \return The first character mask. + */ + std::bitset<256> GetFirstCharacterMask(); + + /*! + * \brief Check all intervals for possible tokens. + * \param tokenizer_info The tokenizer info. + * \param possible_intervals The possible intervals for tokens. + * \param speculative_calculation_applied Whether to use speculative calculation. + * \param speculative_mask The speculative mask for speculative calculation. + * \param definite_accepted_bitset The definite accepted bitset for TagDispatch rules. + * \param is_root_rule Whether to consider the parent rule. If false, there will be + * no uncertain tokens. Useful for the root rule. + * \param fill_reject_indices Whether to fill the rejected indices. + * \return True if the rejected indices are filled as usual, False otherwise. + * \note All the possible tokens will be divided into accepted, rejected and uncertain tokens. + */ + bool CheckAllPossibleTokens( + const TokenizerInfo& tokenizer_info, + const std::vector>& possible_intervals, + bool speculative_calculation_applied, + const std::bitset<256>& speculative_mask, + const std::optional& definite_accepted_bitset, + bool is_root_rule, + bool fill_reject_indices + ); + + /*! \brief Check each token in a given interval. + * \param tokenizer_info The tokenizer info. + * \param interval The interval to check. + * \param speculative_calculation_applied Whether to use speculative calculation. + * \param speculative_mask The speculative mask for speculative calculation. + * \param definite_accepted_bitset The definite accepted bitset for TagDispatch rules. + * \param is_root_rule Whether to consider the parent rule. If false, there will be + * no uncertain tokens. Useful for the root rule. + * \param fill_reject_indices Whether to fill the rejected indices. + * \param last_rejected_range The last rejected subtree range. If a token's index is less than + * this value, it will be rejected directly. + * \param prev_token The previous token parsed in the parser. + * \param prev_matched_size The matched size of the previous token. + * \return True if the rejected indices are filled as usual, False otherwise. + * \note All the tokens in the given interval will be divided into accepted, rejected and + * uncertain tokens. + */ + bool CheckTokensInInterval( + const TokenizerInfo& tokenizer_info, + const std::pair& interval, + bool speculative_calculation_applied, + const std::bitset<256>& speculative_mask, + const std::optional& definite_accepted_bitset, + bool is_root_rule, + bool fill_reject_indices, + int* last_rejected_range, + const std::string*& prev_token, + int* prev_matched_size + ); + + /*! \brief Apply speculative calculation for a token. + * \param token The token to check. + * \param index The index of the token in the vocabulary. + * \param speculative_mask The speculative mask for speculative calculation. + * \param definite_accepted_bitset The definite accepted bitset for TagDispatch rules. + * \return True if the token is accepted by speculative calculation, False otherwise. + */ + bool ApplySpeculativeCalculation( + const std::string& token, + int32_t index, + const std::bitset<256>& speculative_mask, + const std::optional& definite_accepted_bitset + ); + + /*! \brief Find the common prefix size with the previous token. + * \param token The current token. + * \param prev_token The previous token. + * \param prev_matched_size The matched size of the previous token. + * \param accepted Whether the current token is accepted. + */ + void FindCommonPrefixWithPreviousToken( + const std::string& token, + const std::string*& prev_token, + int* prev_matched_size, + bool* accepted + ); + + /*! \brief The id of the initial rule. */ int32_t init_rule_id; - // The initial state of the parser. + /*! \brief The initial state of the parser. */ ParserState initial_state; /*! - \brief This is a mapping from TagDispatch rule id to the bitset used for second slicing. - \note If a rule is a TagDispatch rule, then there will be an AC automaton for its triggers. - Which means that it can accept a lot of tokens. However, it will be slow to check a lot of - tokens. The DynamicBitset here is used to do a second slicing: if a token's substr(1, n - 1) - can be accepted by the start state of the AC automaton, then it will be True in the bitset. - When we check a token, we first check if its first character can transit to the start state. - If yes, then we check if it is in the bitset. If yes, then we accept it directly. - */ + * \brief This is a mapping from TagDispatch rule id to the bitset used for second slicing. + * \note If a rule is a TagDispatch rule, then there will be an AC automaton for its triggers. + * Which means that it can accept a lot of tokens. However, it will be slow to check a lot of + * tokens. The DynamicBitset here is used to do a second slicing: if a token's substr(1, n - 1) + * can be accepted by the start state of the AC automaton, then it will be True in the bitset. + * When we check a token, we first check if its first character can transit to the start state. + * If yes, then we check if it is in the bitset. If yes, then we accept it directly. + */ const std::unordered_map& tag_dispatch_rule_id_to_second_slicing_bitset; // Temporary data for GetAdaptiveTokenMask. @@ -320,13 +403,11 @@ std::pair> GrammarMatcherForTokenMaskCache::GetSpeculativ return {can_be_applied, speculative_mask}; } -bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterCheck( - const std::vector>& sorted_decoded_vocab, - const std::bitset<256>& first_char_mask, - const std::vector& subtree_nodes_range, - bool is_root_rule +bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterOptimization( + const TokenizerInfo& tokenizer_info, const std::bitset<256>& first_char_mask, bool is_root_rule ) { // the pair (a, b) means [a, b). Intialize the possible intervals. + const auto& sorted_decoded_vocab = tokenizer_info.GetSortedDecodedVocab(); std::vector> possible_intervals; int possible_token_num = GetPossibleTokenIntervals(sorted_decoded_vocab, first_char_mask, possible_intervals); @@ -344,21 +425,18 @@ bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterCheck( } } - bool speculative_calculation = false; + bool speculative_calculation_applied = false; std::bitset<256> speculative_mask; if (init_rule_id == -1 || !grammar_->per_rule_fsms[init_rule_id].has_value()) { - speculative_calculation = + speculative_calculation_applied = GetSpeculativeCalculation(sorted_decoded_vocab).first && (possible_token_num >= static_cast(sorted_decoded_vocab.size() / 4)); speculative_mask = first_char_mask; } else { - std::tie(speculative_calculation, speculative_mask) = + std::tie(speculative_calculation_applied, speculative_mask) = GetSpeculativeCalculation(sorted_decoded_vocab); } - int prev_matched_size = 0; - int last_rejected_range = 0; - const bool& is_exact_lookahead = grammar_->GetRule(init_rule_id).is_exact_lookahead; std::optional definite_accepted_bitset = std::nullopt; const bool is_tag_dispatch_rule = grammar_->GetGrammarExpr(grammar_->GetRule(init_rule_id).body_expr_id).type == @@ -368,150 +446,15 @@ bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterCheck( definite_accepted_bitset = &tag_dispatch_rule_id_to_second_slicing_bitset.at(init_rule_id); } - const std::string* prev_token = nullptr; - for (size_t interval_idx = 0; interval_idx < possible_intervals.size(); ++interval_idx) { - const auto& interval = possible_intervals[interval_idx]; - for (int i = interval.first; i < interval.second; ++i) { - // Check if the current token is in the rejected range. i.e. check if the current token - // is on the subtree of the rejected token. - if (i < last_rejected_range) { - if (fill_reject_indices) { - tmp_rejected_indices_.push_back(i); - fill_reject_indices = - tmp_rejected_indices_.size() >= AdaptiveTokenMask::USE_BITSET_THRESHOLD - ? false - : fill_reject_indices; - } else { - i = last_rejected_range - 1; - } - continue; - } - - const auto& token = sorted_decoded_vocab[i].second; - // This optimization is useful for simple self-recursive rules, like string content. - if (speculative_calculation) { - // Optimization for tag dispatch rules. - if (definite_accepted_bitset.has_value()) { - // If the token is empty, it must be accepted. - if (token.empty()) { - tmp_accepted_indices_.push_back(i); - continue; - } - // If the token doesn't contain tags or stop strings since the second character, and it - // will transit to the start state after consuming the first character, it must be - // accepted. - if (speculative_mask[static_cast(token[0])] && - (*definite_accepted_bitset.value())[i]) { - tmp_accepted_indices_.push_back(i); - continue; - } - } else { - bool all_accepted = true; - for (char ch : token) { - // If the first character is not the ascii character or can't be accepted by the - // first character mask, we need to check them in the parser. - if (isascii(ch) == 0 || !speculative_mask[static_cast(ch)]) { - all_accepted = false; - break; - } - } - if (all_accepted) { - tmp_accepted_indices_.push_back(i); - continue; - } - } - } - // Many tokens may contain the same prefix, so we will avoid unnecessary matching - // by finding the longest common prefix with the previous token. - bool accepted = true; - if (prev_token != nullptr) { - int lcp_len = - std::mismatch(token.begin(), token.end(), prev_token->begin(), prev_token->end()) - .first - - token.begin(); - if (lcp_len > prev_matched_size) { - // Case 1. The common prefix is rejected by the matcher in the last token. Reject - // directly. - accepted = false; - } else if (lcp_len < prev_matched_size) { - // Case 2. The common prefix is shorter than the previous matched size. Rollback - // the non-common part. - PopLastStates(prev_matched_size - lcp_len); - tmp_can_reach_end_stack_.erase( - tmp_can_reach_end_stack_.end() - (prev_matched_size - lcp_len), - tmp_can_reach_end_stack_.end() - ); - tmp_can_reach_end_prefix_or_stack_.erase( - tmp_can_reach_end_prefix_or_stack_.end() - (prev_matched_size - lcp_len), - tmp_can_reach_end_prefix_or_stack_.end() - ); - } - prev_matched_size = std::min(prev_matched_size, lcp_len); - } - - prev_token = &token; - - if (accepted) { - // Accept the rest chars one by one. - for (int j = prev_matched_size; j < static_cast(token.size()); ++j) { - if (!Advance(token[j])) { - accepted = false; - break; - } - tmp_can_reach_end_stack_.push_back(IsCompleted()); - tmp_can_reach_end_prefix_or_stack_.push_back( - tmp_can_reach_end_stack_.back() || tmp_can_reach_end_prefix_or_stack_.back() - ); - prev_matched_size = j + 1; - } - } - - bool can_reach_end = tmp_can_reach_end_prefix_or_stack_.back(); - - if (accepted) { - tmp_accepted_indices_.push_back(i); - } else { - auto lookahead_result_pair = IsTokenPassLookaheadAssertion(token, tmp_can_reach_end_stack_); - if (can_reach_end && !is_root_rule && lookahead_result_pair.first && - prev_matched_size > 0) { - // 1. If the current rule is the root rule (is_root_rule=true), there are no - // uncertain tokens. Not accepted tokens are just rejected. - // 2. If a token cannot pass the lookahead assertion, it is rejected. - if ((!lookahead_result_pair.second) && is_exact_lookahead) { - tmp_accepted_indices_.push_back(i); - } else { - tmp_uncertain_indices_.push_back(i); - // On the subtree, they are all uncertain tokens. - if (lookahead_result_pair.second) { - for (int j = i + 1; j < subtree_nodes_range[i]; ++j) { - tmp_uncertain_indices_.push_back(j); - } - i = subtree_nodes_range[i] - 1; // Skip the subtree nodes. - } - } - } else { - tmp_rejected_indices_.push_back(i); - last_rejected_range = subtree_nodes_range[i]; - fill_reject_indices = - tmp_rejected_indices_.size() >= AdaptiveTokenMask::USE_BITSET_THRESHOLD - ? false - : fill_reject_indices; - } - } - } - if (interval_idx != possible_intervals.size() - 1 && fill_reject_indices) { - const auto& next_interval = possible_intervals[interval_idx + 1]; - for (int i = interval.second; i < next_interval.first; ++i) { - tmp_rejected_indices_.push_back(i); - } - fill_reject_indices = tmp_rejected_indices_.size() >= AdaptiveTokenMask::USE_BITSET_THRESHOLD - ? false - : fill_reject_indices; - } - } - - // Rollback the last matched part. - PopLastStates(prev_matched_size); + fill_reject_indices = CheckAllPossibleTokens( + tokenizer_info, + possible_intervals, + speculative_calculation_applied, + speculative_mask, + definite_accepted_bitset, + is_root_rule, + fill_reject_indices + ); if (possible_intervals.back().second != static_cast(sorted_decoded_vocab.size()) && fill_reject_indices) { @@ -527,11 +470,10 @@ bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterCheck( } AdaptiveTokenMask GrammarMatcherForTokenMaskCache::GetAdaptiveTokenMask( - size_t vocab_size, - const std::vector>& sorted_decoded_vocab, - const std::vector& subtree_nodes_range, - bool is_root_rule + const TokenizerInfo& tokenizer_info, bool is_root_rule ) { + const auto& sorted_decoded_vocab = tokenizer_info.GetSortedDecodedVocab(); + const int vocab_size = tokenizer_info.GetVocabSize(); tmp_accepted_indices_.clear(); tmp_rejected_indices_.clear(); tmp_uncertain_indices_.clear(); @@ -539,6 +481,26 @@ AdaptiveTokenMask GrammarMatcherForTokenMaskCache::GetAdaptiveTokenMask( // the rule when matching until this character. Store it in a stack for later rollback. tmp_can_reach_end_stack_.push_back(false); tmp_can_reach_end_prefix_or_stack_.push_back(false); + std::bitset<256> first_character_mask = GetFirstCharacterMask(); + bool rejected_indices_are_filled = GetTokenMaskWithFirstCharacterOptimization( + tokenizer_info, first_character_mask, is_root_rule + ); + if (rejected_indices_are_filled) { + return AdaptiveTokenMask( + vocab_size, + sorted_decoded_vocab, + tmp_accepted_indices_, + tmp_rejected_indices_, + tmp_uncertain_indices_ + ); + } else { + return AdaptiveTokenMask( + vocab_size, sorted_decoded_vocab, tmp_accepted_indices_, tmp_uncertain_indices_ + ); + } +} + +std::bitset<256> GrammarMatcherForTokenMaskCache::GetFirstCharacterMask() { std::bitset<256> first_character_mask; const auto& sequence = grammar_->GetGrammarExpr(initial_state.sequence_id); if (!grammar_->per_rule_fsms[init_rule_id].has_value()) { @@ -586,22 +548,210 @@ AdaptiveTokenMask GrammarMatcherForTokenMaskCache::GetAdaptiveTokenMask( } } } - bool rejected_indices_are_filled = GetTokenMaskWithFirstCharacterCheck( - sorted_decoded_vocab, first_character_mask, subtree_nodes_range, is_root_rule - ); - if (rejected_indices_are_filled) { - return AdaptiveTokenMask( - vocab_size, - sorted_decoded_vocab, - tmp_accepted_indices_, - tmp_rejected_indices_, - tmp_uncertain_indices_ + return first_character_mask; +} + +bool GrammarMatcherForTokenMaskCache::CheckAllPossibleTokens( + const TokenizerInfo& tokenizer_info, + const std::vector>& possible_intervals, + bool speculative_calculation_applied, + const std::bitset<256>& speculative_mask, + const std::optional& definite_accepted_bitset, + bool is_root_rule, + bool fill_reject_indices +) { + int prev_matched_size = 0; + int last_rejected_range = 0; + const std::string* prev_token = nullptr; + for (size_t interval_idx = 0; interval_idx < possible_intervals.size(); ++interval_idx) { + const auto& interval = possible_intervals[interval_idx]; + fill_reject_indices = CheckTokensInInterval( + tokenizer_info, + interval, + speculative_calculation_applied, + speculative_mask, + definite_accepted_bitset, + is_root_rule, + fill_reject_indices, + &last_rejected_range, + prev_token, + &prev_matched_size ); + if (interval_idx != possible_intervals.size() - 1 && fill_reject_indices) { + const auto& next_interval = possible_intervals[interval_idx + 1]; + for (int i = interval.second; i < next_interval.first; ++i) { + tmp_rejected_indices_.push_back(i); + } + if (tmp_rejected_indices_.size() >= AdaptiveTokenMask::USE_BITSET_THRESHOLD) { + fill_reject_indices = false; + } + } + } + + // Rollback the last matched part. + PopLastStates(prev_matched_size); + return fill_reject_indices; +} + +bool GrammarMatcherForTokenMaskCache::ApplySpeculativeCalculation( + const std::string& token, + int32_t index, + const std::bitset<256>& speculative_mask, + const std::optional& definite_accepted_bitset +) { + // This optimization is useful for simple self-recursive rules, like string content. + // Optimization for tag dispatch rules. + if (definite_accepted_bitset.has_value()) { + // If the token is empty, it must be accepted. + if (token.empty()) { + tmp_accepted_indices_.push_back(index); + return true; + } + // If the token doesn't contain tags or stop strings since the second character, and it + // will transit to the start state after consuming the first character, it must be + // accepted. + if (speculative_mask[static_cast(token[0])] && + (*definite_accepted_bitset.value())[index]) { + tmp_accepted_indices_.push_back(index); + return true; + } } else { - return AdaptiveTokenMask( - vocab_size, sorted_decoded_vocab, tmp_accepted_indices_, tmp_uncertain_indices_ - ); + bool all_accepted = true; + for (char ch : token) { + // If the first character is not the ascii character or can't be accepted by the + // first character mask, we need to check them in the parser. + if (isascii(ch) == 0 || !speculative_mask[static_cast(ch)]) { + all_accepted = false; + break; + } + } + if (all_accepted) { + tmp_accepted_indices_.push_back(index); + return true; + } } + return false; +} + +void GrammarMatcherForTokenMaskCache::FindCommonPrefixWithPreviousToken( + const std::string& token, const std::string*& prev_token, int* prev_matched_size, bool* accepted +) { + if (prev_token != nullptr) { + int lcp_len = + std::mismatch(token.begin(), token.end(), prev_token->begin(), prev_token->end()).first - + token.begin(); + if (lcp_len > *prev_matched_size) { + // Case 1. The common prefix is rejected by the matcher in the last token. Reject + // directly. + *accepted = false; + } else if (lcp_len < *prev_matched_size) { + // Case 2. The common prefix is shorter than the previous matched size. Rollback + // the non-common part. + PopLastStates(*prev_matched_size - lcp_len); + tmp_can_reach_end_stack_.erase( + tmp_can_reach_end_stack_.end() - (*prev_matched_size - lcp_len), + tmp_can_reach_end_stack_.end() + ); + tmp_can_reach_end_prefix_or_stack_.erase( + tmp_can_reach_end_prefix_or_stack_.end() - (*prev_matched_size - lcp_len), + tmp_can_reach_end_prefix_or_stack_.end() + ); + } + *prev_matched_size = std::min(*prev_matched_size, lcp_len); + } +} + +bool GrammarMatcherForTokenMaskCache::CheckTokensInInterval( + const TokenizerInfo& tokenizer_info, + const std::pair& interval, + bool speculative_calculation_applied, + const std::bitset<256>& speculative_mask, + const std::optional& definite_accepted_bitset, + bool is_root_rule, + bool fill_reject_indices, + int* last_rejected_range, + const std::string*& prev_token, + int* prev_matched_size +) { + const auto& sorted_decoded_vocab = tokenizer_info.GetSortedDecodedVocab(); + const auto& subtree_nodes_range = tokenizer_info.GetTrieSubtreeNodesRange(); + const bool& is_exact_lookahead = grammar_->GetRule(init_rule_id).is_exact_lookahead; + for (int i = interval.first; i < interval.second; ++i) { + // Check if the current token is in the rejected range. i.e. check if the current token + // is on the subtree of the rejected token. + if (i < *last_rejected_range) { + if (fill_reject_indices) { + tmp_rejected_indices_.push_back(i); + if (tmp_rejected_indices_.size() >= AdaptiveTokenMask::USE_BITSET_THRESHOLD) { + fill_reject_indices = false; + } + } else { + i = *last_rejected_range - 1; + } + continue; + } + + const auto& token = sorted_decoded_vocab[i].second; + if (speculative_calculation_applied) { + bool speculative_accepted = + ApplySpeculativeCalculation(token, i, speculative_mask, definite_accepted_bitset); + if (speculative_accepted) { + continue; + } + } + // Many tokens may contain the same prefix, so we will avoid unnecessary matching + // by finding the longest common prefix with the previous token. + bool accepted = true; + FindCommonPrefixWithPreviousToken(token, prev_token, prev_matched_size, &accepted); + prev_token = &token; + + if (accepted) { + // Accept the rest chars one by one. + for (int j = *prev_matched_size; j < static_cast(token.size()); ++j) { + if (!Advance(token[j])) { + accepted = false; + break; + } + tmp_can_reach_end_stack_.push_back(IsCompleted()); + tmp_can_reach_end_prefix_or_stack_.push_back( + tmp_can_reach_end_stack_.back() || tmp_can_reach_end_prefix_or_stack_.back() + ); + *prev_matched_size = j + 1; + } + } + + bool can_reach_end = tmp_can_reach_end_prefix_or_stack_.back(); + + if (accepted) { + tmp_accepted_indices_.push_back(i); + } else { + auto lookahead_result_pair = IsTokenPassLookaheadAssertion(token, tmp_can_reach_end_stack_); + if (can_reach_end && !is_root_rule && lookahead_result_pair.first && *prev_matched_size > 0) { + // 1. If the current rule is the root rule (is_root_rule=true), there are no + // uncertain tokens. Not accepted tokens are just rejected. + // 2. If a token cannot pass the lookahead assertion, it is rejected. + if ((!lookahead_result_pair.second) && is_exact_lookahead) { + tmp_accepted_indices_.push_back(i); + } else { + tmp_uncertain_indices_.push_back(i); + // On the subtree, they are all uncertain tokens. + if (lookahead_result_pair.second) { + for (int j = i + 1; j < subtree_nodes_range[i]; ++j) { + tmp_uncertain_indices_.push_back(j); + } + i = subtree_nodes_range[i] - 1; // Skip the subtree nodes. + } + } + } else { + tmp_rejected_indices_.push_back(i); + *last_rejected_range = subtree_nodes_range[i]; + if (tmp_rejected_indices_.size() >= AdaptiveTokenMask::USE_BITSET_THRESHOLD) { + fill_reject_indices = false; + } + } + } + } + return fill_reject_indices; } /******************* GrammarCompilerNoCache *******************/ @@ -636,25 +786,45 @@ class GrammarCompilerNoCache { private: /*! \brief The main logic. Compile the grammar with multi-threading. */ CompiledGrammar MultiThreadCompileGrammar(Grammar grammar); + /*! \brief Optimization for TagDispatch: Precompute the definitely accepted tokens. */ + void TagDispatchOptimization( + decltype(std::make_shared()) compiled_grammar_impl, + std::unordered_map* tag_dispatch_rule_id_to_second_slicing_bitset + ); + /*! \brief Generate the token mask cache for all scannable states. */ + void GenerateTokenMaskCacheForScannableStates( + decltype(std::make_shared()) compiled_grammar_impl, + const std::unordered_map& + tag_dispatch_rule_id_to_second_slicing_bitset + ); /*! \brief The vocabulary associated with this storage class. */ const TokenizerInfo tokenizer_info_; /*! \brief The maximum number of threads to use. */ const int max_threads_; - /*! \brief Mapping from the rule_id to the definite accepted token mask. */ - std::unordered_map tag_dispatch_rule_id_to_second_slicing_bitset; }; CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar grammar_unoptimized) { - using GrammarExprType = Grammar::Impl::GrammarExprType; - auto compiled_grammar_impl = std::make_shared(); - compiled_grammar_impl->grammar = GrammarOptimizer::Apply(grammar_unoptimized); compiled_grammar_impl->tokenizer_info = tokenizer_info_; if (tokenizer_info_.GetVocabSize() == 0) { return CompiledGrammar(compiled_grammar_impl); } + std::unordered_map tag_dispatch_rule_id_to_second_slicing_bitset; + TagDispatchOptimization(compiled_grammar_impl, &tag_dispatch_rule_id_to_second_slicing_bitset); + GenerateTokenMaskCacheForScannableStates( + compiled_grammar_impl, tag_dispatch_rule_id_to_second_slicing_bitset + ); + return CompiledGrammar(compiled_grammar_impl); +} + +void GrammarCompilerNoCache::TagDispatchOptimization( + decltype(std::make_shared()) compiled_grammar_impl, + std::unordered_map* tag_dispatch_rule_id_to_second_slicing_bitset +) { + using GrammarExprType = Grammar::Impl::GrammarExprType; + tag_dispatch_rule_id_to_second_slicing_bitset->clear(); // Optimization for TagDispatch: Precompute the definitely accepted tokens. for (int i = 0; i < compiled_grammar_impl->grammar->NumRules(); i++) { @@ -695,15 +865,17 @@ CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar gramma definite_accepted_tokens_since_second_char.Set(i); } } - tag_dispatch_rule_id_to_second_slicing_bitset[i] = definite_accepted_tokens_since_second_char; + (*tag_dispatch_rule_id_to_second_slicing_bitset)[i] = + definite_accepted_tokens_since_second_char; } - // Step 3. Compute the adaptive token mask cache - // The token mask cache is computed for these positions in the grammar: - // 1. All character class or character class star (with last_utf8_bytes=0, 1, 2, 3) - // 2. All byte strings (with element_in_string=0, 1, 2, ...) - // since other positions will be expanded to the above positions +} + +void GrammarCompilerNoCache::GenerateTokenMaskCacheForScannableStates( + decltype(std::make_shared()) compiled_grammar_impl, + const std::unordered_map& tag_dispatch_rule_id_to_second_slicing_bitset +) { + using GrammarExprType = Grammar::Impl::GrammarExprType; - // TODO(Charlie): Figure out how to support ThreadPool and std::mutex in WebAssembly. // Only declare ThreadPool and mutex if max_threads > 1, so when max_threads = 1, we do // not need ThreadPool or std::mutex, which throws error in runtime in WebAssembly. std::optional thread_pool; @@ -712,17 +884,15 @@ CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar gramma thread_pool.emplace(max_threads_); adaptive_token_mask_cache_mutex.emplace(); } + // TODO(Charlie): Figure out how to support ThreadPool and std::mutex in WebAssembly. + // Function to add adaptive token mask for a given parser state. auto add_adaptive_token_mask = [&](const ParserState& state, bool is_root_rule) { auto grammar_matcher = GrammarMatcherForTokenMaskCache( compiled_grammar_impl->grammar, state, tag_dispatch_rule_id_to_second_slicing_bitset, false ); - auto cur_adaptive_token_mask_cache = grammar_matcher.GetAdaptiveTokenMask( - tokenizer_info_.GetVocabSize(), - tokenizer_info_.GetSortedDecodedVocab(), - tokenizer_info_.GetTrieSubtreeNodesRange(), - is_root_rule - ); + auto cur_adaptive_token_mask_cache = + grammar_matcher.GetAdaptiveTokenMask(tokenizer_info_, is_root_rule); if (max_threads_ > 1) { std::lock_guard lock(adaptive_token_mask_cache_mutex.value()); compiled_grammar_impl->adaptive_token_mask_cache[state] = cur_adaptive_token_mask_cache; @@ -744,6 +914,8 @@ CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar gramma auto root_rule_id = compiled_grammar_impl->grammar->GetRootRuleId(); + // Iterate through all rules and their scannable states to generate the adaptive token mask, + // since unscanable states will be expanded to the scannable states. for (int32_t rule_id = 0; rule_id < static_cast(compiled_grammar_impl->grammar->NumRules()); ++rule_id) { auto rule = compiled_grammar_impl->grammar->GetRule(rule_id); @@ -799,8 +971,6 @@ CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar gramma if (max_threads_ > 1) { thread_pool->Join(); } - - return CompiledGrammar(compiled_grammar_impl); } CompiledGrammar GrammarCompilerNoCache::CompileBuiltinJSONGrammar() {