Skip to content

Commit a83618d

Browse files
committed
feat: rewrite the thread pool; set hard limit
1 parent de241aa commit a83618d

File tree

5 files changed

+210
-217
lines changed

5 files changed

+210
-217
lines changed

cpp/grammar_compiler.cc

Lines changed: 66 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
#include <cctype>
1111
#include <cstddef>
1212
#include <cstdint>
13+
#include <mutex>
1314
#include <optional>
15+
#include <type_traits>
1416
#include <unordered_map>
1517
#include <utility>
1618
#include <variant>
@@ -30,6 +32,12 @@
3032

3133
namespace xgrammar {
3234

35+
struct EmptyHolder {
36+
EmptyHolder() = default;
37+
template <typename... Args>
38+
explicit EmptyHolder(Args&&...) {}
39+
};
40+
3341
/************** AdaptiveTokenMaskCache Generator **************/
3442

3543
/*! \brief The concrete implementation of GrammarMatcherNode. */
@@ -57,7 +65,7 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser {
5765
const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab,
5866
const std::vector<int32_t>& subtree_nodes_range,
5967
bool is_root_rule
60-
);
68+
) &&;
6169

6270
/*!
6371
* \brief Get the token mask for the given ParserState.
@@ -531,7 +539,7 @@ AdaptiveTokenMask GrammarMatcherForTokenMaskCache::GetAdaptiveTokenMask(
531539
const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab,
532540
const std::vector<int32_t>& subtree_nodes_range,
533541
bool is_root_rule
534-
) {
542+
) && {
535543
tmp_accepted_indices_.clear();
536544
tmp_rejected_indices_.clear();
537545
tmp_uncertain_indices_.clear();
@@ -590,17 +598,20 @@ AdaptiveTokenMask GrammarMatcherForTokenMaskCache::GetAdaptiveTokenMask(
590598
sorted_decoded_vocab, first_character_mask, subtree_nodes_range, is_root_rule
591599
);
592600
if (rejected_indices_are_filled) {
593-
return AdaptiveTokenMask(
601+
return AdaptiveTokenMask{
594602
vocab_size,
595603
sorted_decoded_vocab,
596-
tmp_accepted_indices_,
597-
tmp_rejected_indices_,
598-
tmp_uncertain_indices_
599-
);
604+
std::move(tmp_accepted_indices_),
605+
std::move(tmp_rejected_indices_),
606+
std::move(tmp_uncertain_indices_),
607+
};
600608
} else {
601-
return AdaptiveTokenMask(
602-
vocab_size, sorted_decoded_vocab, tmp_accepted_indices_, tmp_uncertain_indices_
603-
);
609+
return AdaptiveTokenMask{
610+
vocab_size,
611+
sorted_decoded_vocab,
612+
std::move(tmp_accepted_indices_),
613+
std::move(tmp_uncertain_indices_),
614+
};
604615
}
605616
}
606617

@@ -611,8 +622,10 @@ AdaptiveTokenMask GrammarMatcherForTokenMaskCache::GetAdaptiveTokenMask(
611622
*/
612623
class GrammarCompilerNoCache {
613624
public:
614-
GrammarCompilerNoCache(const TokenizerInfo& tokenizer_info, int max_threads)
615-
: tokenizer_info_(tokenizer_info), max_threads_(max_threads) {}
625+
GrammarCompilerNoCache(const TokenizerInfo& tokenizer_info, const int max_threads)
626+
: tokenizer_info_(tokenizer_info) {
627+
if (max_threads > 1) thread_pool_.emplace(max_threads);
628+
}
616629

617630
CompiledGrammar CompileBuiltinJSONGrammar();
618631

@@ -635,17 +648,28 @@ class GrammarCompilerNoCache {
635648

636649
private:
637650
/*! \brief The main logic. Compile the grammar with multi-threading. */
638-
CompiledGrammar MultiThreadCompileGrammar(Grammar grammar);
651+
CompiledGrammar MultiThreadCompileGrammar(Grammar grammar) {
652+
// dispatch based on whether thread_pool_ is set
653+
if (thread_pool_) {
654+
return MultiThreadCompileGrammarImpl<true>(std::move(grammar));
655+
} else {
656+
return MultiThreadCompileGrammarImpl<false>(std::move(grammar));
657+
}
658+
}
659+
660+
template <bool kUseMultiThread>
661+
CompiledGrammar MultiThreadCompileGrammarImpl(Grammar grammar);
639662

640663
/*! \brief The vocabulary associated with this storage class. */
641664
const TokenizerInfo tokenizer_info_;
642665
/*! \brief The maximum number of threads to use. */
643-
const int max_threads_;
666+
std::optional<ThreadPool> thread_pool_;
644667
/*! \brief Mapping from the rule_id to the definite accepted token mask. */
645668
std::unordered_map<int32_t, DynamicBitset> tag_dispatch_rule_id_to_second_slicing_bitset;
646669
};
647670

648-
CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar grammar_unoptimized) {
671+
template <bool kUseMultiThread>
672+
CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammarImpl(Grammar grammar_unoptimized) {
649673
using GrammarExprType = Grammar::Impl::GrammarExprType;
650674

651675
auto compiled_grammar_impl = std::make_shared<CompiledGrammar::Impl>();
@@ -703,46 +727,43 @@ CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar gramma
703727
// 2. All byte strings (with element_in_string=0, 1, 2, ...)
704728
// since other positions will be expanded to the above positions
705729

706-
// TODO(Charlie): Figure out how to support ThreadPool and std::mutex in WebAssembly.
707-
// Only declare ThreadPool and mutex if max_threads > 1, so when max_threads = 1, we do
708-
// not need ThreadPool or std::mutex, which throws error in runtime in WebAssembly.
709-
std::optional<ThreadPool> thread_pool;
710-
std::optional<std::mutex> adaptive_token_mask_cache_mutex;
711-
if (max_threads_ > 1) {
712-
thread_pool.emplace(max_threads_);
713-
adaptive_token_mask_cache_mutex.emplace();
714-
}
730+
using TaskCounter = std::conditional_t<kUseMultiThread, ThreadPool::TaskCounter, EmptyHolder>;
731+
using TokenMaskMutex = std::conditional_t<kUseMultiThread, std::mutex, EmptyHolder>;
715732

716-
auto add_adaptive_token_mask = [&](const ParserState& state, bool is_root_rule) {
733+
[[maybe_unused]]
734+
auto thread_pool = TaskCounter{*thread_pool_};
735+
[[maybe_unused]]
736+
auto adaptive_token_mask_cache_mutex = TokenMaskMutex{};
737+
738+
const auto get_adaptive_token_mask = [&](const ParserState& state, bool is_root_rule) {
717739
auto grammar_matcher = GrammarMatcherForTokenMaskCache(
718740
compiled_grammar_impl->grammar, state, tag_dispatch_rule_id_to_second_slicing_bitset, false
719741
);
720-
auto cur_adaptive_token_mask_cache = grammar_matcher.GetAdaptiveTokenMask(
721-
tokenizer_info_.GetVocabSize(),
722-
tokenizer_info_.GetSortedDecodedVocab(),
723-
tokenizer_info_.GetTrieSubtreeNodesRange(),
724-
is_root_rule
725-
);
726-
if (max_threads_ > 1) {
727-
std::lock_guard<std::mutex> lock(adaptive_token_mask_cache_mutex.value());
728-
compiled_grammar_impl->adaptive_token_mask_cache[state] = cur_adaptive_token_mask_cache;
729-
} else {
730-
compiled_grammar_impl->adaptive_token_mask_cache[state] = cur_adaptive_token_mask_cache;
731-
}
742+
return std::move(grammar_matcher)
743+
.GetAdaptiveTokenMask(
744+
tokenizer_info_.GetVocabSize(),
745+
tokenizer_info_.GetSortedDecodedVocab(),
746+
tokenizer_info_.GetTrieSubtreeNodesRange(),
747+
is_root_rule
748+
);
732749
};
733750

734-
auto add_task_adaptive_token_mask = [&](const ParserState& state, bool is_root_rule) {
751+
const auto add_task_adaptive_token_mask = [&](const ParserState& state, bool is_root_rule) {
735752
// Execute depending on whether we use thread_pool
736-
if (max_threads_ > 1) {
737-
thread_pool->Execute([add_adaptive_token_mask, state, is_root_rule]() {
738-
add_adaptive_token_mask(state, is_root_rule);
753+
if constexpr (kUseMultiThread) {
754+
thread_pool.Submit([=, &adaptive_token_mask_cache_mutex] {
755+
auto cache = get_adaptive_token_mask(state, is_root_rule);
756+
const auto lock = std::lock_guard{adaptive_token_mask_cache_mutex};
757+
compiled_grammar_impl->adaptive_token_mask_cache[state] = std::move(cache);
739758
});
740759
} else {
741-
add_adaptive_token_mask(state, is_root_rule);
760+
compiled_grammar_impl->adaptive_token_mask_cache.try_emplace(
761+
state, get_adaptive_token_mask(state, is_root_rule)
762+
);
742763
}
743764
};
744765

745-
auto root_rule_id = compiled_grammar_impl->grammar->GetRootRuleId();
766+
const auto root_rule_id = compiled_grammar_impl->grammar->GetRootRuleId();
746767

747768
for (int32_t rule_id = 0; rule_id < static_cast<int>(compiled_grammar_impl->grammar->NumRules());
748769
++rule_id) {
@@ -796,8 +817,8 @@ CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar gramma
796817
}
797818
}
798819

799-
if (max_threads_ > 1) {
800-
thread_pool->Join();
820+
if constexpr (kUseMultiThread) {
821+
thread_pool.WaitUntilComplete();
801822
}
802823

803824
return CompiledGrammar(compiled_grammar_impl);

cpp/grammar_matcher.cc

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,12 @@ class BatchGrammarMatcher::Impl {
389389
bool debug_print
390390
);
391391

392+
void InitThreadPoolOnce() {
393+
if (!thread_pool_.has_value() && max_threads_ > 1) {
394+
thread_pool_.emplace(max_threads_);
395+
}
396+
}
397+
392398
private:
393399
std::optional<ThreadPool> thread_pool_ = std::nullopt;
394400
int32_t max_threads_ = 1;
@@ -912,33 +918,25 @@ void BatchGrammarMatcher::Impl::BatchFillNextTokenBitmask(
912918
XGRAMMAR_CHECK(!indices.has_value() || indices->size() == matchers->size())
913919
<< "The size of indices (" << (indices.has_value() ? indices->size() : 0)
914920
<< ") should be the same as the size of matchers (" << matchers->size() << ").";
915-
// Initialize the thread pool if needed. It should be initialized each time,
916-
// because ThreadPool cannot be reused after Join().
917-
if (max_threads_ > 1) {
918-
thread_pool_.emplace(max_threads_);
919-
}
921+
this->InitThreadPoolOnce();
922+
auto fill_next_token_mask = [&](int i) {
923+
auto& matcher = (*matchers)[i];
924+
int index = indices.has_value() ? (*indices)[i] : i;
925+
XGRAMMAR_CHECK(index >= 0 && index < next_token_bitmask->shape[0])
926+
<< "The index " << index << " is out of range [0, " << next_token_bitmask->shape[0]
927+
<< ") for batch_id " << i << ".";
928+
matcher->FillNextTokenBitmask(next_token_bitmask, index, debug_print);
929+
};
920930
if (!thread_pool_.has_value()) {
921931
for (int i = 0; i < static_cast<int32_t>(matchers->size()); i++) {
922-
auto& matcher = (*matchers)[i];
923-
int index = indices.has_value() ? (*indices)[i] : i;
924-
XGRAMMAR_CHECK(index >= 0 && index < next_token_bitmask->shape[0])
925-
<< "The index " << index << " is out of range [0, " << next_token_bitmask->shape[0]
926-
<< ") for batch_id " << i << ".";
927-
matcher->FillNextTokenBitmask(next_token_bitmask, index, debug_print);
932+
fill_next_token_mask(i);
928933
}
929934
} else {
930-
auto fill_next_token_mask = [&](int32_t batch_id) {
931-
auto& matcher = (*matchers)[batch_id];
932-
int index = indices.has_value() ? (*indices)[batch_id] : batch_id;
933-
XGRAMMAR_CHECK(index >= 0 && index < next_token_bitmask->shape[0])
934-
<< "The index " << index << " is out of range [0, " << next_token_bitmask->shape[0]
935-
<< ") for batch_id " << batch_id << ".";
936-
matcher->FillNextTokenBitmask(next_token_bitmask, index, debug_print);
937-
};
935+
auto task_counter = thread_pool_->CreateTaskCounter();
938936
for (int i = 0; i < static_cast<int32_t>(matchers->size()); i++) {
939-
thread_pool_->Execute([fill_next_token_mask, i]() { fill_next_token_mask(i); });
937+
task_counter.Submit([fill_next_token_mask, i] { fill_next_token_mask(i); });
940938
}
941-
thread_pool_->Join();
939+
task_counter.WaitUntilComplete();
942940
}
943941
}
944942

0 commit comments

Comments
 (0)