1010#include < cctype>
1111#include < cstddef>
1212#include < cstdint>
13+ #include < mutex>
1314#include < optional>
15+ #include < type_traits>
1416#include < unordered_map>
1517#include < utility>
1618#include < variant>
3032
3133namespace xgrammar {
3234
/*!
 * \brief Empty placeholder type whose constructors accept and discard any arguments.
 *
 * Selected through std::conditional_t in place of ThreadPool::TaskCounter and std::mutex
 * when kUseMultiThread is false, so the same initialization expressions compile in both
 * the multi-threaded and single-threaded instantiations.
 */
35+ struct EmptyHolder {
36+ EmptyHolder () = default ;
// Variadic constructor: takes whatever the replaced type would be built from and ignores it.
37+ template <typename ... Args>
38+ explicit EmptyHolder (Args&&...) {}
39+ };
40+
3341/* ************* AdaptiveTokenMaskCache Generator **************/
3442
3543/* ! \brief The concrete implementation of GrammarMatcherNode. */
@@ -57,7 +65,7 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser {
5765 const std::vector<std::pair<int32_t , std::string>>& sorted_decoded_vocab,
5866 const std::vector<int32_t >& subtree_nodes_range,
5967 bool is_root_rule
60- );
68+ ) && ;
6169
6270 /* !
6371 * \brief Get the token mask for the given ParserState.
@@ -531,7 +539,7 @@ AdaptiveTokenMask GrammarMatcherForTokenMaskCache::GetAdaptiveTokenMask(
531539 const std::vector<std::pair<int32_t , std::string>>& sorted_decoded_vocab,
532540 const std::vector<int32_t >& subtree_nodes_range,
533541 bool is_root_rule
534- ) {
542+ ) && {
535543 tmp_accepted_indices_.clear ();
536544 tmp_rejected_indices_.clear ();
537545 tmp_uncertain_indices_.clear ();
@@ -590,17 +598,20 @@ AdaptiveTokenMask GrammarMatcherForTokenMaskCache::GetAdaptiveTokenMask(
590598 sorted_decoded_vocab, first_character_mask, subtree_nodes_range, is_root_rule
591599 );
592600 if (rejected_indices_are_filled) {
593- return AdaptiveTokenMask (
601+ return AdaptiveTokenMask{
594602 vocab_size,
595603 sorted_decoded_vocab,
596- tmp_accepted_indices_,
597- tmp_rejected_indices_,
598- tmp_uncertain_indices_
599- ) ;
604+ std::move ( tmp_accepted_indices_) ,
605+ std::move ( tmp_rejected_indices_) ,
606+ std::move ( tmp_uncertain_indices_),
607+ } ;
600608 } else {
601- return AdaptiveTokenMask (
602- vocab_size, sorted_decoded_vocab, tmp_accepted_indices_, tmp_uncertain_indices_
603- );
609+ return AdaptiveTokenMask{
610+ vocab_size,
611+ sorted_decoded_vocab,
612+ std::move (tmp_accepted_indices_),
613+ std::move (tmp_uncertain_indices_),
614+ };
604615 }
605616}
606617
@@ -611,8 +622,10 @@ AdaptiveTokenMask GrammarMatcherForTokenMaskCache::GetAdaptiveTokenMask(
611622 */
612623class GrammarCompilerNoCache {
613624 public:
614- GrammarCompilerNoCache (const TokenizerInfo& tokenizer_info, int max_threads)
615- : tokenizer_info_(tokenizer_info), max_threads_(max_threads) {}
625+ GrammarCompilerNoCache (const TokenizerInfo& tokenizer_info, const int max_threads)
626+ : tokenizer_info_(tokenizer_info) {
627+ if (max_threads > 1 ) thread_pool_.emplace (max_threads);
628+ }
616629
617630 CompiledGrammar CompileBuiltinJSONGrammar ();
618631
@@ -635,17 +648,28 @@ class GrammarCompilerNoCache {
635648
636649 private:
637650 /* ! \brief The main logic. Compile the grammar with multi-threading. */
638- CompiledGrammar MultiThreadCompileGrammar (Grammar grammar);
651+ CompiledGrammar MultiThreadCompileGrammar (Grammar grammar) {
652+ // dispatch based on whether thread_pool_ is set
653+ if (thread_pool_) {
654+ return MultiThreadCompileGrammarImpl<true >(std::move (grammar));
655+ } else {
656+ return MultiThreadCompileGrammarImpl<false >(std::move (grammar));
657+ }
658+ }
659+
660+ template <bool kUseMultiThread >
661+ CompiledGrammar MultiThreadCompileGrammarImpl (Grammar grammar);
639662
640663 /* ! \brief The vocabulary associated with this storage class. */
641664 const TokenizerInfo tokenizer_info_;
642665 /* ! \brief The thread pool used for multi-threaded compilation; left empty when max_threads <= 1. */
643- const int max_threads_ ;
666+ std::optional<ThreadPool> thread_pool_ ;
644667 /* ! \brief Mapping from the rule_id to the definite accepted token mask. */
645668 std::unordered_map<int32_t , DynamicBitset> tag_dispatch_rule_id_to_second_slicing_bitset;
646669};
647670
648- CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar (Grammar grammar_unoptimized) {
671+ template <bool kUseMultiThread >
672+ CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammarImpl (Grammar grammar_unoptimized) {
649673 using GrammarExprType = Grammar::Impl::GrammarExprType;
650674
651675 auto compiled_grammar_impl = std::make_shared<CompiledGrammar::Impl>();
@@ -703,46 +727,43 @@ CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar gramma
703727 // 2. All byte strings (with element_in_string=0, 1, 2, ...)
704728 // since other positions will be expanded to the above positions
705729
706- // TODO(Charlie): Figure out how to support ThreadPool and std::mutex in WebAssembly.
707- // Only declare ThreadPool and mutex if max_threads > 1, so when max_threads = 1, we do
708- // not need ThreadPool or std::mutex, which throws error in runtime in WebAssembly.
709- std::optional<ThreadPool> thread_pool;
710- std::optional<std::mutex> adaptive_token_mask_cache_mutex;
711- if (max_threads_ > 1 ) {
712- thread_pool.emplace (max_threads_);
713- adaptive_token_mask_cache_mutex.emplace ();
714- }
730+ using TaskCounter = std::conditional_t <kUseMultiThread , ThreadPool::TaskCounter, EmptyHolder>;
731+ using TokenMaskMutex = std::conditional_t <kUseMultiThread , std::mutex, EmptyHolder>;
715732
716- auto add_adaptive_token_mask = [&](const ParserState& state, bool is_root_rule) {
733+ [[maybe_unused]]
734+ auto thread_pool = TaskCounter{*thread_pool_};
735+ [[maybe_unused]]
736+ auto adaptive_token_mask_cache_mutex = TokenMaskMutex{};
737+
738+ const auto get_adaptive_token_mask = [&](const ParserState& state, bool is_root_rule) {
717739 auto grammar_matcher = GrammarMatcherForTokenMaskCache (
718740 compiled_grammar_impl->grammar , state, tag_dispatch_rule_id_to_second_slicing_bitset, false
719741 );
720- auto cur_adaptive_token_mask_cache = grammar_matcher.GetAdaptiveTokenMask (
721- tokenizer_info_.GetVocabSize (),
722- tokenizer_info_.GetSortedDecodedVocab (),
723- tokenizer_info_.GetTrieSubtreeNodesRange (),
724- is_root_rule
725- );
726- if (max_threads_ > 1 ) {
727- std::lock_guard<std::mutex> lock (adaptive_token_mask_cache_mutex.value ());
728- compiled_grammar_impl->adaptive_token_mask_cache [state] = cur_adaptive_token_mask_cache;
729- } else {
730- compiled_grammar_impl->adaptive_token_mask_cache [state] = cur_adaptive_token_mask_cache;
731- }
742+ return std::move (grammar_matcher)
743+ .GetAdaptiveTokenMask (
744+ tokenizer_info_.GetVocabSize (),
745+ tokenizer_info_.GetSortedDecodedVocab (),
746+ tokenizer_info_.GetTrieSubtreeNodesRange (),
747+ is_root_rule
748+ );
732749 };
733750
734- auto add_task_adaptive_token_mask = [&](const ParserState& state, bool is_root_rule) {
751+ const auto add_task_adaptive_token_mask = [&](const ParserState& state, bool is_root_rule) {
735752 // Execute depending on whether we use thread_pool
736- if (max_threads_ > 1 ) {
737- thread_pool->Execute ([add_adaptive_token_mask, state, is_root_rule]() {
738- add_adaptive_token_mask (state, is_root_rule);
753+ if constexpr (kUseMultiThread ) {
754+ thread_pool.Submit ([=, &adaptive_token_mask_cache_mutex] {
755+ auto cache = get_adaptive_token_mask (state, is_root_rule);
756+ const auto lock = std::lock_guard{adaptive_token_mask_cache_mutex};
757+ compiled_grammar_impl->adaptive_token_mask_cache [state] = std::move (cache);
739758 });
740759 } else {
741- add_adaptive_token_mask (state, is_root_rule);
760+ compiled_grammar_impl->adaptive_token_mask_cache .try_emplace (
761+ state, get_adaptive_token_mask (state, is_root_rule)
762+ );
742763 }
743764 };
744765
745- auto root_rule_id = compiled_grammar_impl->grammar ->GetRootRuleId ();
766+ const auto root_rule_id = compiled_grammar_impl->grammar ->GetRootRuleId ();
746767
747768 for (int32_t rule_id = 0 ; rule_id < static_cast <int >(compiled_grammar_impl->grammar ->NumRules ());
748769 ++rule_id) {
@@ -796,8 +817,8 @@ CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar gramma
796817 }
797818 }
798819
799- if (max_threads_ > 1 ) {
800- thread_pool-> Join ();
820+ if constexpr ( kUseMultiThread ) {
821+ thread_pool. WaitUntilComplete ();
801822 }
802823
803824 return CompiledGrammar (compiled_grammar_impl);
0 commit comments