Skip to content

Commit 1da52cb

Browse files
committed
fix: use persistent thread pool
1 parent 911a7a3 commit 1da52cb

File tree

3 files changed

+62
-60
lines changed

3 files changed

+62
-60
lines changed

cpp/earley_parser.cc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -713,9 +713,7 @@ bool RepeatDetector::IsVisited(const ParserState& state) const {
713713

714714
void RepeatDetector::Insert(const ParserState& state) {
715715
if (size_ == transition_threshold_) {
716-
for (const auto& s : visited_vector_) {
717-
visited_set_.insert(s);
718-
}
716+
visited_set_.insert(visited_vector_.begin(), visited_vector_.begin() + size_);
719717
}
720718
size_++;
721719
if (size_ > transition_threshold_) {

cpp/grammar_compiler.cc

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "grammar_functor.h"
2121
#include "grammar_impl.h"
2222
#include "support/logging.h"
23+
#include "support/reflection.h"
2324
#include "support/thread_pool.h"
2425
#include "support/thread_safe_cache.h"
2526
#include "support/utils.h"
@@ -544,7 +545,12 @@ AdaptiveTokenMask GrammarMatcherForTokenMaskCache::GetAdaptiveTokenMask(
544545
class GrammarCompilerNoCache {
545546
public:
546547
// Constructs a compiler bound to `tokenizer_info`.
// The initializer list appears twice below: the `547-` variant stored max_threads in a
// member, the `548+` variant default-constructs thread_pool_ (an empty optional) instead.
// \param tokenizer_info Copied into tokenizer_info_.
// \param max_threads Size of the thread pool to create; a pool is only created when this
//        value is greater than 1.
GrammarCompilerNoCache(const TokenizerInfo& tokenizer_info, int max_threads)
547-
: tokenizer_info_(tokenizer_info), max_threads_(max_threads) {}
548+
: tokenizer_info_(tokenizer_info), thread_pool_() {
549+
// For max_threads <= 1 thread_pool_ stays empty, so no ThreadPool object is ever built.
if (max_threads > 1) {
550+
/// NOTE: maybe we can allow max_threads = 1, and use 0 as no extra thread.
551+
// The pool is constructed once here and kept as a member, rather than per compile call.
thread_pool_.emplace(max_threads);
552+
}
553+
}
548554

549555
CompiledGrammar CompileBuiltinJSONGrammar();
550556

@@ -571,8 +577,9 @@ class GrammarCompilerNoCache {
571577

572578
/*! \brief The vocabulary associated with this storage class. */
573579
const TokenizerInfo tokenizer_info_;
574-
/*! \brief The maximum number of threads to use. */
575-
const int max_threads_;
580+
581+
/*! \brief The persistent thread pool for multi-threading. */
582+
std::optional<ThreadPool> thread_pool_;
576583
};
577584

578585
CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar grammar) {
@@ -597,12 +604,9 @@ CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar gramma
597604
// TODO(Charlie): Figure out how to support ThreadPool and std::mutex in WebAssembly.
598605
// Only declare ThreadPool and mutex if max_threads > 1, so when max_threads = 1, we do
599606
// not need ThreadPool or std::mutex, which throws error in runtime in WebAssembly.
600-
std::optional<ThreadPool> thread_pool;
601-
std::optional<std::mutex> adaptive_token_mask_cache_mutex;
602-
603-
if (max_threads_ > 1) {
604-
thread_pool.emplace(max_threads_);
605-
adaptive_token_mask_cache_mutex.emplace();
607+
std::optional<TaskCounter> task_counter;
608+
if (thread_pool_) {
609+
task_counter.emplace();
606610
}
607611

608612
auto add_adaptive_token_mask = [&](const ParserState& state, bool is_root_rule) {
@@ -613,18 +617,20 @@ CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar gramma
613617
tokenizer_info_.GetTrieSubtreeNodesRange(),
614618
is_root_rule
615619
);
616-
if (max_threads_ > 1) {
617-
std::lock_guard<std::mutex> lock(adaptive_token_mask_cache_mutex.value());
618-
compiled_grammar_impl->adaptive_token_mask_cache[state] = cur_adaptive_token_mask_cache;
620+
if (thread_pool_) {
621+
task_counter->CompleteOne([&] {
622+
compiled_grammar_impl->adaptive_token_mask_cache[state] = cur_adaptive_token_mask_cache;
623+
});
619624
} else {
620625
compiled_grammar_impl->adaptive_token_mask_cache[state] = cur_adaptive_token_mask_cache;
621626
}
622627
};
623628

624629
auto add_task_adaptive_token_mask = [&](const ParserState& state, bool is_root_rule) {
625630
// Execute depending on whether we use thread_pool
626-
if (max_threads_ > 1) {
627-
thread_pool->Execute([add_adaptive_token_mask, state, is_root_rule]() {
631+
if (thread_pool_) {
632+
task_counter->AddOne();
633+
thread_pool_->Execute([add_adaptive_token_mask, state, is_root_rule] {
628634
add_adaptive_token_mask(state, is_root_rule);
629635
});
630636
} else {
@@ -685,8 +691,8 @@ CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar gramma
685691
}
686692
}
687693

688-
if (max_threads_ > 1) {
689-
thread_pool->Join();
694+
if (thread_pool_) {
695+
task_counter->Wait();
690696
}
691697

692698
return CompiledGrammar(compiled_grammar_impl);
@@ -916,7 +922,7 @@ CompiledGrammar GrammarCompiler::Impl::Compute(const UnionKey& key) {
916922
} else if constexpr (std::is_same_v<KeyType, BuiltinJSONGrammarKey>) {
917923
return this->no_cache_compiler_.CompileBuiltinJSONGrammar();
918924
} else {
919-
XGRAMMAR_UNREACHABLE();
925+
static_assert(detail::reflection::false_v<KeyType>, "non-exhaustive visitor!");
920926
}
921927
},
922928
key

cpp/support/thread_pool.h

Lines changed: 38 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66
#ifndef XGRAMMAR_SUPPORT_THREAD_POOL_H_
77
#define XGRAMMAR_SUPPORT_THREAD_POOL_H_
88

9+
#include <atomic>
910
#include <condition_variable>
11+
#include <cstddef>
1012
#include <functional>
11-
#include <future>
1213
#include <mutex>
1314
#include <queue>
1415
#include <thread>
15-
#include <type_traits>
1616
#include <vector>
1717

1818
#include "logging.h"
@@ -35,8 +35,9 @@ class ThreadPool {
3535
*/
3636
ThreadPool(size_t num_threads = std::thread::hardware_concurrency()) {
3737
// Initialize thread pool with num_threads threads
38-
for (size_t i = 0; i < num_threads; ++i) {
39-
workers_.emplace_back([this] {
38+
workers_.resize(num_threads);
39+
for (auto& worker : workers_) {
40+
worker = std::thread([this] {
4041
while (true) {
4142
std::function<void()> task;
4243
{
@@ -58,38 +59,6 @@ class ThreadPool {
5859
}
5960
}
6061

61-
/*!
62-
* \brief Add a new task to be executed by the thread pool.
63-
* \tparam F Type of the function to execute
64-
* \tparam Args Types of the arguments to pass to the function
65-
* \param f Function to execute
66-
* \param args Arguments to pass to the function
67-
* \return std::shared_future containing the result of the function call
68-
* \note Tasks are executed in FIFO order but may complete in any order.
69-
*/
70-
template <class F, class... Args>
71-
auto Submit(F&& f, Args&&... args) -> std::shared_future<std::invoke_result_t<F, Args...>> {
72-
using return_type = std::invoke_result_t<F, Args...>;
73-
74-
// Package the task with its arguments into a shared pointer
75-
auto task = std::make_shared<std::packaged_task<return_type()>>(
76-
std::bind(std::forward<F>(f), std::forward<Args>(args)...)
77-
);
78-
79-
std::shared_future<return_type> res = task->get_future().share();
80-
81-
{
82-
std::unique_lock<std::mutex> lock(queue_mutex_);
83-
XGRAMMAR_CHECK(!shutdown_) << "Cannot submit task to stopped ThreadPool";
84-
++unfinished_task_count_; // Increment task count
85-
86-
// Directly add the task without wrapping
87-
task_queue_.emplace([task]() { (*task)(); });
88-
}
89-
queue_condition_.notify_one();
90-
return res;
91-
}
92-
9362
/*!
9463
* \brief Add a new task to be executed by the thread pool without returning a future.
9564
* \tparam F Type of the function to execute
@@ -98,21 +67,20 @@ class ThreadPool {
9867
* \param args Arguments to pass to the function
9968
* \note Tasks are executed asynchronously by the worker threads.
10069
*/
101-
// Enqueue a task for the worker threads; no future or result is returned.
// Two signatures are shown: the `101-`/`102-` variadic template (callable plus arguments,
// bound together with std::bind) and the `70+` replacement taking a ready-made
// std::function<void()> by value.
// NOTE(review): the doc comment above this method still lists `\tparam F` and `\param args`,
// which no longer match the non-template signature — confirm and update upstream.
template <class F, class... Args>
102-
void Execute(F&& f, Args&&... args) {
70+
void Execute(std::function<void()> f) {
10371
{
10472
// Everything in this inner scope runs with queue_mutex_ held.
std::unique_lock<std::mutex> lock(queue_mutex_);
10573
// Submitting after shutdown_ has been set is treated as a checked failure.
XGRAMMAR_CHECK(!shutdown_) << "Cannot execute task in stopped ThreadPool";
10674
++unfinished_task_count_; // Increment task count
10775

10876
// Directly add the task without wrapping
109-
task_queue_.emplace(std::bind(std::forward<F>(f), std::forward<Args>(args)...));
77+
// `f` was taken by value, so it is moved into the queue rather than copied again.
task_queue_.emplace(std::move(f));
11078
}
11179
// The notification is issued after the lock scope above has closed.
queue_condition_.notify_one();
11280
}
11381

11482
// Block the calling thread until unfinished_task_count_ is zero.
void Wait() {
115-
std::unique_lock<std::mutex> lock(queue_mutex_);
83+
// Same lock as the `115-` line, spelled with class template argument deduction.
auto lock = std::unique_lock{queue_mutex_};
11684
// The predicate is evaluated with queue_mutex_ held; wait() releases it while blocked.
tasks_done_condition_.wait(lock, [this] { return unfinished_task_count_ == 0; });
11785
}
11886

@@ -147,6 +115,8 @@ class ThreadPool {
147115
ThreadPool& operator=(const ThreadPool&) = delete;
148116
ThreadPool& operator=(ThreadPool&&) = delete;
149117

118+
// Returns the number of entries in workers_ (one std::thread per worker).
std::size_t NumThreads() const { return workers_.size(); }
119+
150120
private:
151121
void TaskComplete() {
152122
std::unique_lock<std::mutex> lock(queue_mutex_);
@@ -172,6 +142,34 @@ class ThreadPool {
172142
int unfinished_task_count_ = 0;
173143
};
174144

145+
// Counts in-flight tasks and lets a thread block until the count returns to zero.
// Usage shown in this change: AddOne() before a task is submitted, CompleteOne() from inside
// the task, Wait() by the submitting thread once everything has been queued.
class TaskCounter {
146+
public:
147+
// Runs `f` and then marks one task as finished.
// `f` is invoked while mutex_ is held, so callbacks passed by concurrent callers never
// overlap; callers can use it to publish a result without a lock of their own.
// NOTE(review): if `f` throws, the decrement below is skipped and a later Wait() would
// not see the count reach zero — confirm callbacks cannot throw.
template <typename F>
148+
void CompleteOne(F&& f) {
149+
const auto lock = std::lock_guard{mutex_};
150+
std::forward<F>(f)();
151+
// fetch_sub returns the previous value; subtracting 1 gives the count after this task.
const auto working = working_.fetch_sub(1, std::memory_order_relaxed) - 1;
152+
// Only notify when the count hit zero and some thread is currently inside Wait().
if (working == 0 && waiting_ > 0) cv_.notify_all();
153+
}
154+

155+
// Registers one more in-flight task and returns the count after the increment.
// Takes no lock, unlike CompleteOne() and Wait().
// This can be called by other threads, so we must use atomic.
156+
// We don't rely on any happens-before relationship, so we use relaxed order.
157+
std::size_t AddOne() { return working_.fetch_add(1, std::memory_order_relaxed) + 1; }
158+

159+
// Blocks until working_ reads zero. Returns without blocking if it is already zero.
void Wait() {
160+
auto lock = std::unique_lock{mutex_};
161+
// waiting_ is only modified here, with mutex_ held, and read in CompleteOne() under
// the same mutex, so it does not need to be atomic.
++waiting_;
162+
cv_.wait(lock, [this] { return working_.load(std::memory_order_relaxed) == 0; });
163+
--waiting_;
164+
}
165+

166+
private:
167+
// Guards waiting_ and serializes CompleteOne() callbacks.
std::mutex mutex_;
168+
// Signalled by CompleteOne() when the last in-flight task finishes.
std::condition_variable cv_;
169+
// Number of threads currently blocked in Wait().
std::size_t waiting_ = 0;
170+
// Number of tasks added via AddOne() and not yet finished via CompleteOne().
std::atomic_size_t working_ = 0;
171+
};
172+
175173
inline void ParallelFor(int low, int high, int num_threads, std::function<void(int)> f) {
176174
if (high - low == 1) {
177175
f(low);

0 commit comments

Comments
 (0)