diff --git a/cpp/grammar_compiler.cc b/cpp/grammar_compiler.cc index 4df5b847..06bfc940 100644 --- a/cpp/grammar_compiler.cc +++ b/cpp/grammar_compiler.cc @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -30,6 +31,12 @@ namespace xgrammar { +struct EmptyHolder { + EmptyHolder() = default; + template + explicit EmptyHolder(Args&&...) {} +}; + /************** AdaptiveTokenMaskCache Generator **************/ /*! \brief The concrete implementation of GrammarMatcherNode. */ @@ -57,7 +64,7 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser { const std::vector>& sorted_decoded_vocab, const std::vector& subtree_nodes_range, bool is_root_rule - ); + ) &&; /*! * \brief Get the token mask for the given ParserState. @@ -531,7 +538,7 @@ AdaptiveTokenMask GrammarMatcherForTokenMaskCache::GetAdaptiveTokenMask( const std::vector>& sorted_decoded_vocab, const std::vector& subtree_nodes_range, bool is_root_rule -) { +) && { tmp_accepted_indices_.clear(); tmp_rejected_indices_.clear(); tmp_uncertain_indices_.clear(); @@ -590,17 +597,20 @@ AdaptiveTokenMask GrammarMatcherForTokenMaskCache::GetAdaptiveTokenMask( sorted_decoded_vocab, first_character_mask, subtree_nodes_range, is_root_rule ); if (rejected_indices_are_filled) { - return AdaptiveTokenMask( + return AdaptiveTokenMask{ vocab_size, sorted_decoded_vocab, - tmp_accepted_indices_, - tmp_rejected_indices_, - tmp_uncertain_indices_ - ); + std::move(tmp_accepted_indices_), + std::move(tmp_rejected_indices_), + std::move(tmp_uncertain_indices_), + }; } else { - return AdaptiveTokenMask( - vocab_size, sorted_decoded_vocab, tmp_accepted_indices_, tmp_uncertain_indices_ - ); + return AdaptiveTokenMask{ + vocab_size, + sorted_decoded_vocab, + std::move(tmp_accepted_indices_), + std::move(tmp_uncertain_indices_), + }; } } @@ -611,8 +621,10 @@ AdaptiveTokenMask GrammarMatcherForTokenMaskCache::GetAdaptiveTokenMask( */ class GrammarCompilerNoCache { public: - GrammarCompilerNoCache(const TokenizerInfo& tokenizer_info, int max_threads) - : tokenizer_info_(tokenizer_info), max_threads_(max_threads) {} + GrammarCompilerNoCache(const TokenizerInfo& tokenizer_info, const int max_threads) + : tokenizer_info_(tokenizer_info) { + if (max_threads > 1) thread_pool_.emplace(max_threads); + } CompiledGrammar CompileBuiltinJSONGrammar(); @@ -635,17 +647,28 @@ class GrammarCompilerNoCache { private: /*! \brief The main logic. Compile the grammar with multi-threading. */ - CompiledGrammar MultiThreadCompileGrammar(Grammar grammar); + CompiledGrammar MultiThreadCompileGrammar(Grammar grammar) { + // dispatch based on whether thread_pool_ is set + if (thread_pool_) { + return MultiThreadCompileGrammarImpl(std::move(grammar)); + } else { + return MultiThreadCompileGrammarImpl(std::move(grammar)); + } + } + + template + CompiledGrammar MultiThreadCompileGrammarImpl(Grammar grammar); /*! \brief The vocabulary associated with this storage class. */ const TokenizerInfo tokenizer_info_; /*! \brief The maximum number of threads to use. */ - const int max_threads_; + std::optional thread_pool_; /*! \brief Mapping from the rule_id to the definite accepted token mask. 
*/ std::unordered_map tag_dispatch_rule_id_to_second_slicing_bitset; }; -CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar grammar_unoptimized) { +template +CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammarImpl(Grammar grammar_unoptimized) { using GrammarExprType = Grammar::Impl::GrammarExprType; auto compiled_grammar_impl = std::make_shared(); @@ -702,47 +725,46 @@ CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar gramma // 1. All character class or character class star (with last_utf8_bytes=0, 1, 2, 3) // 2. All byte strings (with element_in_string=0, 1, 2, ...) // since other positions will be expanded to the above positions + [[maybe_unused]] auto task_counter = [&] { + if constexpr (kUseMultiThread) { + return thread_pool_->CreateTaskCounter(); + } else { + return 0; + } + }(); - // TODO(Charlie): Figure out how to support ThreadPool and std::mutex in WebAssembly. - // Only declare ThreadPool and mutex if max_threads > 1, so when max_threads = 1, we do - // not need ThreadPool or std::mutex, which throws error in runtime in WebAssembly. - std::optional thread_pool; - std::optional adaptive_token_mask_cache_mutex; - if (max_threads_ > 1) { - thread_pool.emplace(max_threads_); - adaptive_token_mask_cache_mutex.emplace(); - } - - auto add_adaptive_token_mask = [&](const ParserState& state, bool is_root_rule) { + const auto get_adaptive_token_mask = [&](const ParserState& state, bool is_root_rule) { auto grammar_matcher = GrammarMatcherForTokenMaskCache( compiled_grammar_impl->grammar, state, tag_dispatch_rule_id_to_second_slicing_bitset, false ); - auto cur_adaptive_token_mask_cache = grammar_matcher.GetAdaptiveTokenMask( - tokenizer_info_.GetVocabSize(), - tokenizer_info_.GetSortedDecodedVocab(), - tokenizer_info_.GetTrieSubtreeNodesRange(), - is_root_rule - ); - if (max_threads_ > 1) { - std::lock_guard lock(adaptive_token_mask_cache_mutex.value()); - compiled_grammar_impl->adaptive_token_mask_cache[state] = cur_adaptive_token_mask_cache; - } else { - compiled_grammar_impl->adaptive_token_mask_cache[state] = cur_adaptive_token_mask_cache; - } + return std::move(grammar_matcher) + .GetAdaptiveTokenMask( + tokenizer_info_.GetVocabSize(), + tokenizer_info_.GetSortedDecodedVocab(), + tokenizer_info_.GetTrieSubtreeNodesRange(), + is_root_rule + ); }; - auto add_task_adaptive_token_mask = [&](const ParserState& state, bool is_root_rule) { + const auto add_task_adaptive_token_mask = [&](const ParserState& state, bool is_root_rule) { // Execute depending on whether we use thread_pool - if (max_threads_ > 1) { - thread_pool->Execute([add_adaptive_token_mask, state, is_root_rule]() { - add_adaptive_token_mask(state, is_root_rule); - }); + if constexpr (kUseMultiThread) { + task_counter.Submit( + // parallel part: construct the mask + [=] { return get_adaptive_token_mask(state, is_root_rule); }, + // protected region, insert into the cache + [=](AdaptiveTokenMask&& cache) { + compiled_grammar_impl->adaptive_token_mask_cache.try_emplace(state, std::move(cache)); + } + ); } else { - add_adaptive_token_mask(state, is_root_rule); + compiled_grammar_impl->adaptive_token_mask_cache.try_emplace( + state, get_adaptive_token_mask(state, is_root_rule) + ); } }; - auto root_rule_id = compiled_grammar_impl->grammar->GetRootRuleId(); + const auto root_rule_id = compiled_grammar_impl->grammar->GetRootRuleId(); for (int32_t rule_id = 0; rule_id < static_cast(compiled_grammar_impl->grammar->NumRules()); ++rule_id) { @@ -796,8 +818,8 @@ 
CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar gramma } } - if (max_threads_ > 1) { - thread_pool->Join(); + if constexpr (kUseMultiThread) { + task_counter.WaitUntilComplete(); } return CompiledGrammar(compiled_grammar_impl); diff --git a/cpp/grammar_matcher.cc b/cpp/grammar_matcher.cc index 60b3e520..ac9aee3d 100644 --- a/cpp/grammar_matcher.cc +++ b/cpp/grammar_matcher.cc @@ -389,6 +389,12 @@ class BatchGrammarMatcher::Impl { bool debug_print ); + void InitThreadPoolOnce() { + if (!thread_pool_.has_value() && max_threads_ > 1) { + thread_pool_.emplace(max_threads_); + } + } + private: std::optional thread_pool_ = std::nullopt; int32_t max_threads_ = 1; @@ -912,33 +918,25 @@ void BatchGrammarMatcher::Impl::BatchFillNextTokenBitmask( XGRAMMAR_CHECK(!indices.has_value() || indices->size() == matchers->size()) << "The size of indices (" << (indices.has_value() ? indices->size() : 0) << ") should be the same as the size of matchers (" << matchers->size() << ")."; - // Initialize the thread pool if needed. It should be initialized each time, - // because ThreadPool cannot be reused after Join(). - if (max_threads_ > 1) { - thread_pool_.emplace(max_threads_); - } + this->InitThreadPoolOnce(); + auto fill_next_token_mask = [&](int i) { + auto& matcher = (*matchers)[i]; + int index = indices.has_value() ? (*indices)[i] : i; + XGRAMMAR_CHECK(index >= 0 && index < next_token_bitmask->shape[0]) + << "The index " << index << " is out of range [0, " << next_token_bitmask->shape[0] + << ") for batch_id " << i << "."; + matcher->FillNextTokenBitmask(next_token_bitmask, index, debug_print); + }; if (!thread_pool_.has_value()) { for (int i = 0; i < static_cast(matchers->size()); i++) { - auto& matcher = (*matchers)[i]; - int index = indices.has_value() ? (*indices)[i] : i; - XGRAMMAR_CHECK(index >= 0 && index < next_token_bitmask->shape[0]) - << "The index " << index << " is out of range [0, " << next_token_bitmask->shape[0] - << ") for batch_id " << i << "."; - matcher->FillNextTokenBitmask(next_token_bitmask, index, debug_print); + fill_next_token_mask(i); } } else { - auto fill_next_token_mask = [&](int32_t batch_id) { - auto& matcher = (*matchers)[batch_id]; - int index = indices.has_value() ? 
(*indices)[batch_id] : batch_id; - XGRAMMAR_CHECK(index >= 0 && index < next_token_bitmask->shape[0]) - << "The index " << index << " is out of range [0, " << next_token_bitmask->shape[0] - << ") for batch_id " << batch_id << "."; - matcher->FillNextTokenBitmask(next_token_bitmask, index, debug_print); - }; + auto task_counter = thread_pool_->CreateTaskCounter(); for (int i = 0; i < static_cast(matchers->size()); i++) { - thread_pool_->Execute([fill_next_token_mask, i]() { fill_next_token_mask(i); }); + task_counter.Submit([fill_next_token_mask, i] { fill_next_token_mask(i); }); } - thread_pool_->Join(); + task_counter.WaitUntilComplete(); } } diff --git a/cpp/support/thread_pool.h b/cpp/support/thread_pool.h index 0e7064b7..68b2882e 100644 --- a/cpp/support/thread_pool.h +++ b/cpp/support/thread_pool.h @@ -6,13 +6,14 @@ #ifndef XGRAMMAR_SUPPORT_THREAD_POOL_H_ #define XGRAMMAR_SUPPORT_THREAD_POOL_H_ +#include #include +#include #include -#include +#include #include #include #include -#include #include #include "logging.h" @@ -28,6 +29,92 @@ namespace xgrammar { */ class ThreadPool { public: + struct TaskCounter { + public: + // A dummy callback that does nothing + inline static constexpr auto kCallback = [] {}; + inline static constexpr auto kNoLimit = std::numeric_limits::max(); + + explicit TaskCounter(ThreadPool& pool) : TaskCounter(pool, GetLimit(pool)) {} + explicit TaskCounter(ThreadPool& pool, std::size_t limit) + : m_active(0), m_pool(pool), m_rate_limit(limit) { + XGRAMMAR_CHECK(m_rate_limit > 0) << "TaskCounter rate limit must be greater than zero."; + m_pool.active_tasks_++; + } + + // not copy/moveable + TaskCounter(const TaskCounter&) = delete; + TaskCounter(TaskCounter&&) = delete; + TaskCounter& operator=(const TaskCounter&) = delete; + TaskCounter& operator=(TaskCounter&&) = delete; + + void WaitUntilComplete() { + auto lock = std::unique_lock{m_mutex}; + m_cv.wait(lock, [this] { return m_active == 0; }); + } + + template + void Submit(F&& f, C&& c = kCallback) { + using ResultType = std::invoke_result_t; + static_assert( + std::is_void_v || std::is_invocable_v, + "Callback must be invocable with the result of the task." + ); + + // real task to be executed by the thread pool + auto fn = std::function{[this, task = std::forward(f), callback = std::forward(c)] { + if constexpr (std::is_void_v) { + task(); + { + const auto lock = std::lock_guard{m_mutex}; + callback(); + m_active -= 1; + } + m_cv.notify_all(); + } else { + auto result = task(); + { + const auto lock = std::lock_guard{m_mutex}; + callback(std::move(result)); + m_active -= 1; + } + m_cv.notify_all(); + } + }}; + + // rate limiting before submitting the task + { + auto lock = std::unique_lock{m_mutex}; + m_active += 1; + m_cv.wait(lock, [this] { return m_active <= m_rate_limit; }); + } + + // emplace the task into the thread pool + { + const auto lock = std::lock_guard{m_pool.queue_mutex_}; + m_pool.task_queue_.push(std::move(fn)); + } + m_pool.queue_condition_.notify_one(); + } + + ~TaskCounter() { + this->WaitUntilComplete(); + m_pool.active_tasks_--; + } + + private: + friend class ThreadPool; + + // default no limit, yet we can still implement rate limiting if needed + static std::size_t GetLimit([[maybe_unused]] ThreadPool& pool) { return kNoLimit; } + + std::size_t m_active; + std::condition_variable m_cv; + std::mutex m_mutex; + ThreadPool& m_pool; + const std::size_t m_rate_limit; + }; + /*! * \brief Construct a new thread pool with the specified number of threads. 
* \param num_threads Number of worker threads to create. Defaults to hardware concurrency. @@ -35,8 +122,9 @@ class ThreadPool { */ ThreadPool(size_t num_threads = std::thread::hardware_concurrency()) { // Initialize thread pool with num_threads threads - for (size_t i = 0; i < num_threads; ++i) { - workers_.emplace_back([this] { + workers_.resize(num_threads); + for (auto& worker : workers_) { + worker = std::thread([this] { while (true) { std::function task; { @@ -52,69 +140,14 @@ class ThreadPool { task_queue_.pop(); } task(); - TaskComplete(); } }); } } - /*! - * \brief Add a new task to be executed by the thread pool. - * \tparam F Type of the function to execute - * \tparam Args Types of the arguments to pass to the function - * \param f Function to execute - * \param args Arguments to pass to the function - * \return std::shared_future containing the result of the function call - * \note Tasks are executed in FIFO order but may complete in any order. - */ - template - auto Submit(F&& f, Args&&... args) -> std::shared_future> { - using return_type = std::invoke_result_t; - - // Package the task with its arguments into a shared pointer - auto task = std::make_shared>( - std::bind(std::forward(f), std::forward(args)...) - ); - - std::shared_future res = task->get_future().share(); - - { - std::unique_lock lock(queue_mutex_); - XGRAMMAR_CHECK(!shutdown_) << "Cannot submit task to stopped ThreadPool"; - ++unfinished_task_count_; // Increment task count - - // Directly add the task without wrapping - task_queue_.emplace([task]() { (*task)(); }); - } - queue_condition_.notify_one(); - return res; - } - - /*! - * \brief Add a new task to be executed by the thread pool without returning a future. - * \tparam F Type of the function to execute - * \tparam Args Types of the arguments to pass to the function - * \param f Function to execute - * \param args Arguments to pass to the function - * \note Tasks are executed asynchronously by the worker threads. - */ - template - void Execute(F&& f, Args&&... args) { - { - std::unique_lock lock(queue_mutex_); - XGRAMMAR_CHECK(!shutdown_) << "Cannot execute task in stopped ThreadPool"; - ++unfinished_task_count_; // Increment task count - - // Directly add the task without wrapping - task_queue_.emplace(std::bind(std::forward(f), std::forward(args)...)); - } - queue_condition_.notify_one(); - } - - void Wait() { - std::unique_lock lock(queue_mutex_); - tasks_done_condition_.wait(lock, [this] { return unfinished_task_count_ == 0; }); - } + TaskCounter CreateTaskCounter() { return TaskCounter{*this}; } + TaskCounter CreateTaskCounterWithLimit(std::size_t limit) { return TaskCounter{*this, limit}; } + std::size_t NumThreads() const { return workers_.size(); } /*! * \brief Join all threads in the pool. @@ -139,7 +172,10 @@ class ThreadPool { /*! * \brief Destructor that ensures graceful shutdown of the thread pool. 
*/ - ~ThreadPool() { Join(); } + ~ThreadPool() { + Join(); + XGRAMMAR_CHECK(active_tasks_ == 0) << "ThreadPool destroyed while tasks are still active."; + } // Prevent copying or moving of the thread pool ThreadPool(const ThreadPool&) = delete; @@ -147,15 +183,16 @@ class ThreadPool { ThreadPool& operator=(const ThreadPool&) = delete; ThreadPool& operator=(ThreadPool&&) = delete; - private: - void TaskComplete() { - std::unique_lock lock(queue_mutex_); - --unfinished_task_count_; - if (unfinished_task_count_ == 0) { - tasks_done_condition_.notify_all(); // Notify waiting threads + // Debug only function to get thread IDs + std::vector DebugGetThreadIDs() const { + std::vector thread_ids; + for (const auto& worker : workers_) { + thread_ids.push_back(worker.get_id()); } + return thread_ids; } + private: /*! \brief Thread container */ std::vector workers_; /*! \brief Task queue */ @@ -168,36 +205,10 @@ class ThreadPool { std::condition_variable tasks_done_condition_; /*! \brief Flag to indicate thread pool shutdown */ bool shutdown_ = false; - /*! \brief Number of unfinished tasks */ - int unfinished_task_count_ = 0; + /*! \brief Number of active tasks */ + std::atomic_size_t active_tasks_{0}; }; -inline void ParallelFor(int low, int high, int num_threads, std::function f) { - if (high - low == 1) { - f(low); - return; - } - - ThreadPool pool(num_threads); - - int total = high - low; - int chunk_size = (total + num_threads - 1) / num_threads; - - for (int t = 0; t < num_threads; ++t) { - int start = low + t * chunk_size; - int end = std::min(start + chunk_size, high); - - if (start >= end) break; // No more iterations to process - - pool.Execute([f, start, end]() { - for (int i = start; i < end; ++i) { - f(i); - } - }); - } - pool.Join(); -} - } // namespace xgrammar #endif // XGRAMMAR_SUPPORT_THREAD_POOL_H_ diff --git a/include/xgrammar/grammar.h b/include/xgrammar/grammar.h index 175ed637..715cebb7 100644 --- a/include/xgrammar/grammar.h +++ b/include/xgrammar/grammar.h @@ -10,7 +10,6 @@ #include #include -#include #include #include #include diff --git a/tests/cpp/test_thread_pool.cc b/tests/cpp/test_thread_pool.cc index 8ae66b05..2d768223 100644 --- a/tests/cpp/test_thread_pool.cc +++ b/tests/cpp/test_thread_pool.cc @@ -1,66 +1,56 @@ #include #include +#include +#include #include "support/thread_pool.h" using namespace xgrammar; -TEST(XGramamrThreadPoolTest, FunctionalTest) { - ThreadPool pool(4); - - // Example 1: Use Submit to submit tasks with return values - std::vector> futures; - for (int i = 0; i < 8; ++i) { - auto fut = pool.Submit([i] { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - std::cout << "Task " << i << " is running in thread " << std::this_thread::get_id() << "\n"; - return i * i; - }); - futures.push_back(fut); - } - - for (auto& fut : futures) { - int result = fut.get(); - std::cout << "Result: " << result << "\n"; - } - - // Example 2: Use Execute to submit tasks without return values - for (int i = 0; i < 5; ++i) { - pool.Execute([i] { - std::this_thread::sleep_for(std::chrono::milliseconds(50)); - std::cout << "Execute task " << i << " is running in thread " << std::this_thread::get_id() - << "\n"; +TEST(XGrammarThreadPoolTest, FunctionalTest) { + ThreadPool pool(2); + + const auto tid_map = [&pool] { + const auto thread_ids = pool.DebugGetThreadIDs(); + std::unordered_map map; + for (size_t i = 0; i < thread_ids.size(); ++i) { + map[thread_ids[i]] = static_cast(i); + } + return map; + }(); + + std::thread threads[2]; + + // Example 1: Use Execute 
to submit tasks without return values + // with a rate limit of 4, meaning at most 4 tasks can be queued at any time. + const auto start = std::chrono::high_resolution_clock::now(); + for (int j = 0; j < 2; ++j) { + threads[j] = std::thread([j, &pool, start, &tid_map] { + auto counter = pool.CreateTaskCounterWithLimit(4); + if (j == 0) std::this_thread::sleep_for(std::chrono::milliseconds(50)); + for (int i = 0; i < 10; ++i) { + counter.Submit([i, j, start, &tid_map] { + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + const auto now = std::chrono::high_resolution_clock::now(); + const auto dur = std::chrono::duration_cast(now - start); + auto os = std::ostringstream{}; + os << "[" << dur.count() << " ms] [job <" << j << ">] "; + const auto tid = std::this_thread::get_id(); + os << "Execute task " << i << " is running in thread " << tid_map.at(tid) << "\n"; + std::cout << os.str(); + }); + const auto now = std::chrono::high_resolution_clock::now(); + const auto dur = std::chrono::duration_cast(now - start); + auto os = std::ostringstream{}; + os << "[" << dur.count() << " ms] [job <" << j << ">] "; + os << "Submit task " << i << "\n"; + std::cout << os.str(); + } }); } + threads[0].join(); + threads[1].join(); // Wait for task to complete pool.Join(); } - -// TEST(XGramamrThreadPoolTest, PressureTest) { -// const size_t num_threads = std::thread::hardware_concurrency(); -// ThreadPool pool(num_threads); - -// const size_t num_tasks = 1000; -// int counter = 0; -// std::mutex counter_mutex; - -// auto start_time = std::chrono::high_resolution_clock::now(); - -// for (size_t i = 0; i < num_tasks; ++i) { -// pool.Execute([&counter, &counter_mutex, i]() { -// std::this_thread::sleep_for(std::chrono::milliseconds(i % 50)); -// std::lock_guard lock(counter_mutex); -// counter++; -// }); -// } - -// pool.Wait(); - -// auto end_time = std::chrono::high_resolution_clock::now(); - -// EXPECT_EQ(counter, static_cast(num_tasks)); - -// auto duration = std::chrono::duration_cast(end_time - start_time); -// std::cout << "Pressure test completed, time taken: " << duration.count() << " milliseconds.\n"; -// }
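
Usage sketch (not part of the patch): the snippet below shows how the TaskCounter API introduced in cpp/support/thread_pool.h is meant to be driven — CreateTaskCounter() / CreateTaskCounterWithLimit() to group related work, Submit() with an optional result callback that runs inside the counter's mutex, and WaitUntilComplete() (or the counter's destructor) to block until the whole group finishes. Everything outside those calls — main(), the iostream output, the task bodies, and the limit of 2 — is an illustrative assumption, not code from this change.

// usage_sketch.cc — illustrative only; assumes support/thread_pool.h from this patch is on the include path.
#include <iostream>
#include <string>
#include <vector>

#include "support/thread_pool.h"

int main() {
  xgrammar::ThreadPool pool(4);

  // Tasks that return a value: the second argument to Submit() is a callback that
  // receives the result and runs while the counter's mutex is held, so the
  // push_back below is serialized and needs no extra locking.
  std::vector<int> squares;
  {
    auto counter = pool.CreateTaskCounter();
    for (int i = 0; i < 16; ++i) {
      counter.Submit(
          // parallel part: runs on a worker thread
          [i] { return i * i; },
          // protected part: serialized by the counter's mutex
          [&squares](int v) { squares.push_back(v); }
      );
    }
    counter.WaitUntilComplete();  // also done implicitly by ~TaskCounter()
  }
  std::cout << "collected " << squares.size() << " results on " << pool.NumThreads()
            << " worker threads\n";

  // Void tasks need no callback. A rate-limited counter blocks Submit() while
  // `limit` tasks submitted through it are still pending.
  {
    auto counter = pool.CreateTaskCounterWithLimit(2);
    for (int i = 0; i < 8; ++i) {
      counter.Submit([i] { std::cout << ("task " + std::to_string(i) + " done\n"); });
    }
  }  // ~TaskCounter() waits for the remaining tasks before returning

  pool.Join();
  return 0;
}

This mirrors how GrammarCompilerNoCache now submits mask-construction tasks: the expensive GetAdaptiveTokenMask call runs in the parallel lambda, while the insertion into adaptive_token_mask_cache happens in the callback, which is why the separate adaptive_token_mask_cache_mutex could be removed.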