#ifndef XGRAMMAR_SUPPORT_THREAD_POOL_H_
#define XGRAMMAR_SUPPORT_THREAD_POOL_H_

#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

#include "logging.h"
@@ -35,8 +35,9 @@ class ThreadPool {
3535 */
3636 ThreadPool (size_t num_threads = std::thread::hardware_concurrency()) {
3737 // Initialize thread pool with num_threads threads
38- for (size_t i = 0 ; i < num_threads; ++i) {
39- workers_.emplace_back ([this ] {
38+ workers_.resize (num_threads);
39+ for (auto & worker : workers_) {
40+ worker = std::thread ([this ] {
4041 while (true ) {
4142 std::function<void ()> task;
4243 {
@@ -58,38 +59,6 @@ class ThreadPool {
5859 }
5960 }
6061
61- /* !
62- * \brief Add a new task to be executed by the thread pool.
63- * \tparam F Type of the function to execute
64- * \tparam Args Types of the arguments to pass to the function
65- * \param f Function to execute
66- * \param args Arguments to pass to the function
67- * \return std::shared_future containing the result of the function call
68- * \note Tasks are executed in FIFO order but may complete in any order.
69- */
70- template <class F , class ... Args>
71- auto Submit (F&& f, Args&&... args) -> std::shared_future<std::invoke_result_t<F, Args...>> {
72- using return_type = std::invoke_result_t <F, Args...>;
73-
74- // Package the task with its arguments into a shared pointer
75- auto task = std::make_shared<std::packaged_task<return_type ()>>(
76- std::bind (std::forward<F>(f), std::forward<Args>(args)...)
77- );
78-
79- std::shared_future<return_type> res = task->get_future ().share ();
80-
81- {
82- std::unique_lock<std::mutex> lock (queue_mutex_);
83- XGRAMMAR_CHECK (!shutdown_) << " Cannot submit task to stopped ThreadPool" ;
84- ++unfinished_task_count_; // Increment task count
85-
86- // Directly add the task without wrapping
87- task_queue_.emplace ([task]() { (*task)(); });
88- }
89- queue_condition_.notify_one ();
90- return res;
91- }
92-
9362 /* !
9463 * \brief Add a new task to be executed by the thread pool without returning a future.
9564 * \tparam F Type of the function to execute
@@ -98,21 +67,20 @@ class ThreadPool {
9867 * \param args Arguments to pass to the function
9968 * \note Tasks are executed asynchronously by the worker threads.
10069 */
101- template <class F , class ... Args>
102- void Execute (F&& f, Args&&... args) {
70+ void Execute (std::function<void ()> f) {
10371 {
10472 std::unique_lock<std::mutex> lock (queue_mutex_);
10573 XGRAMMAR_CHECK (!shutdown_) << " Cannot execute task in stopped ThreadPool" ;
10674 ++unfinished_task_count_; // Increment task count
10775
10876 // Directly add the task without wrapping
109- task_queue_.emplace (std::bind (std::forward<F>(f), std::forward<Args>(args)... ));
77+ task_queue_.emplace (std::move (f ));
11078 }
11179 queue_condition_.notify_one ();
11280 }
11381
11482 void Wait () {
115- std::unique_lock<std::mutex> lock ( queue_mutex_) ;
83+ auto lock = std::unique_lock{ queue_mutex_} ;
11684 tasks_done_condition_.wait (lock, [this ] { return unfinished_task_count_ == 0 ; });
11785 }
11886
@@ -147,6 +115,8 @@ class ThreadPool {
147115 ThreadPool& operator =(const ThreadPool&) = delete ;
148116 ThreadPool& operator =(ThreadPool&&) = delete ;
149117
118+ std::size_t NumThreads () const { return workers_.size (); }
119+
150120 private:
151121 void TaskComplete () {
152122 std::unique_lock<std::mutex> lock (queue_mutex_);
@@ -172,6 +142,34 @@ class ThreadPool {
172142 int unfinished_task_count_ = 0 ;
173143};
174144
/*!
 * \brief Counts in-flight tasks and lets callers block until all of them complete.
 */
class TaskCounter {
 public:
  /*!
   * \brief Run a completion callback and mark one task as finished.
   * \tparam F Callable type invoked while the internal mutex is held.
   * \param f Callback executed under the lock, before the count is decremented.
   */
  template <typename F>
  void CompleteOne(F&& f) {
    const auto lock = std::lock_guard{mutex_};
    std::forward<F>(f)();
    // fetch_sub returns the previous value; subtract one to get the new count.
    const auto remaining = working_.fetch_sub(1, std::memory_order_relaxed) - 1;
    // Only signal when someone is actually blocked in Wait().
    if (remaining == 0 && waiting_ > 0) cv_.notify_all();
  }

  // This can be called by other threads, so we must use an atomic.
  // We don't rely on any happens-before relationship, so relaxed order suffices.
  std::size_t AddOne() { return working_.fetch_add(1, std::memory_order_relaxed) + 1; }

  /*! \brief Block until the in-flight task count drops to zero. */
  void Wait() {
    auto lock = std::unique_lock{mutex_};
    ++waiting_;
    cv_.wait(lock, [this] { return working_.load(std::memory_order_relaxed) == 0; });
    --waiting_;
  }

 private:
  std::mutex mutex_;                // guards waiting_ and serializes completion callbacks
  std::condition_variable cv_;      // signaled when the last in-flight task completes
  std::size_t waiting_ = 0;         // threads currently blocked inside Wait()
  std::atomic_size_t working_ = 0;  // tasks added but not yet completed
};
172+
175173inline void ParallelFor (int low, int high, int num_threads, std::function<void (int )> f) {
176174 if (high - low == 1 ) {
177175 f (low);
0 commit comments