1919#define KALDI_CUDADECODER_THREAD_POOL_LIGHT_H_
2020
2121#include < atomic>
22+ #include < memory>
2223#include < thread>
2324#include < vector>
2425
25- #include " util/stl-utils.h"
26-
2726namespace kaldi {
2827namespace cuda_decoder {
2928
// Interval handed to Sleep() by an idle worker thread that found no task,
// neither in its own queue nor in the queue it steals from, before it polls
// again.
// NOTE(review): presumably expressed in seconds (i.e. 1 ms) -- confirm
// against the declaration of Sleep().
constexpr double kSleepForWorkAvailable = 1e-3;
// Interval handed to Sleep() by ThreadPoolLight::Push() between two attempts
// when the targeted worker refused the task.
// NOTE(review): same unit assumption as above -- confirm.
constexpr double kSleepForWorkerAvailable = 1e-3;
3131
3232struct ThreadPoolLightTask {
3333 void (*func_ptr)(void *, uint64_t , void *);
@@ -39,20 +39,17 @@ struct ThreadPoolLightTask {
3939template <int QUEUE_SIZE>
4040// Single producer, multiple consumer
4141class ThreadPoolLightSPMCQueue {
42- static const unsigned int QUEUE_MASK = QUEUE_SIZE - 1 ;
42+ static constexpr unsigned int QUEUE_MASK = QUEUE_SIZE - 1 ;
4343 std::vector<ThreadPoolLightTask> tasks_;
4444 std::atomic<int > back_;
4545 std::atomic<int > front_;
46- int inc (int curr) { return ((curr + 1 ) & QUEUE_MASK); }
46+ static int inc (int curr) { return ((curr + 1 ) & QUEUE_MASK); }
4747
4848 public:
49- ThreadPoolLightSPMCQueue () {
50- KALDI_ASSERT (QUEUE_SIZE > 1 );
51- bool is_power_of_2 = ((QUEUE_SIZE & (QUEUE_SIZE - 1 )) == 0 );
52- KALDI_ASSERT (is_power_of_2); // validity of QUEUE_MASK
53- tasks_.resize (QUEUE_SIZE);
54- front_.store (0 );
55- back_.store (0 );
49+ ThreadPoolLightSPMCQueue () : tasks_(QUEUE_SIZE), front_(0 ), back_(0 ) {
50+ KALDI_COMPILE_TIME_ASSERT (QUEUE_SIZE > 1 );
51+ constexpr bool is_power_of_2 = ((QUEUE_SIZE & (QUEUE_SIZE - 1 )) == 0 );
52+ KALDI_COMPILE_TIME_ASSERT (is_power_of_2); // validity of QUEUE_MASK
5653 }
5754
5855 bool TryPush (const ThreadPoolLightTask &task) {
@@ -70,33 +67,35 @@ class ThreadPoolLightSPMCQueue {
7067 bool TryPop (ThreadPoolLightTask *front_task) {
7168 while (true ) {
7269 int front = front_.load (std::memory_order_relaxed);
73- if (front == back_.load (std::memory_order_acquire))
70+ if (front == back_.load (std::memory_order_acquire)) {
7471 return false ; // queue is empty
72+ }
7573 *front_task = tasks_[front];
7674 if (front_.compare_exchange_weak (front, inc (front),
77- std::memory_order_release))
75+ std::memory_order_release)) {
7876 return true ;
77+ }
7978 }
8079 }
8180};
8281
83- class ThreadPoolLightWorker {
82+ class ThreadPoolLightWorker final {
8483 // Multi consumer queue, because worker can steal work
8584 ThreadPoolLightSPMCQueue<512 > queue_;
8685 // If this thread has no more work to do, it will try to steal work from
8786 // other
88- std::unique_ptr<std:: thread> thread_;
89- bool run_thread_;
87+ std::thread thread_;
88+ volatile bool run_thread_;
9089 ThreadPoolLightTask curr_task_;
9190 std::weak_ptr<ThreadPoolLightWorker> other_;
9291
9392 void Work () {
9493 while (run_thread_) {
9594 bool got_task = queue_.TryPop (&curr_task_);
9695 if (!got_task) {
97- if (auto other_sp = other_.lock ()) {
98- got_task = other_sp->TrySteal (&curr_task_);
99- }
96+ if (auto other_sp = other_.lock ()) {
97+ got_task = other_sp->TrySteal (&curr_task_);
98+ }
10099 }
101100 if (got_task) {
102101 // Not calling func_ptr as a member function,
@@ -106,71 +105,77 @@ class ThreadPoolLightWorker {
106105 (curr_task_.func_ptr )(curr_task_.obj_ptr , curr_task_.arg1 ,
107106 curr_task_.arg2 );
108107 } else {
109- Sleep (1e- 3f ); // TODO
108+ Sleep (kSleepForWorkAvailable ); // TODO
110109 }
111110 }
112111 }
113112
114- protected:
115113 // Another worker can steal a task from this queue
116114 // This is done so that a very long task computed by one thread does not
117115 // hold the entire threadpool to complete a time-sensitive task
118116 bool TrySteal (ThreadPoolLightTask *task) { return queue_.TryPop (task); }
119117
120118 public:
121119 ThreadPoolLightWorker () : run_thread_(true ), other_() {}
122- virtual ~ThreadPoolLightWorker () { Stop (); }
123- bool TryPush (const ThreadPoolLightTask &task) { return queue_.TryPush (task); }
120+ ~ThreadPoolLightWorker () {
121+ KALDI_ASSERT (!queue_.TryPop (&curr_task_));
122+ }
123+ bool TryPush (const ThreadPoolLightTask &task) {
124+ return queue_.TryPush (task);
125+ }
124126 void SetOtherWorkerToStealFrom (
125127 const std::shared_ptr<ThreadPoolLightWorker>& other) {
126128 other_ = other;
127129 }
128130 void Start () {
129131 KALDI_ASSERT (" Please call SetOtherWorkerToStealFrom() first" && !other_.expired ());
130- thread_. reset ( new std::thread (&ThreadPoolLightWorker::Work, this ));
132+ thread_ = std::move ( std::thread (&ThreadPoolLightWorker::Work, this ));
131133 }
132134 void Stop () {
133135 run_thread_ = false ;
134- thread_->join ();
136+ thread_.join ();
137+ other_.reset ();
135138 }
136139};
137140
138141class ThreadPoolLight {
139142 std::vector<std::shared_ptr<ThreadPoolLightWorker>> workers_;
140143 int curr_iworker_; // next call on tryPush will post work on this
141144 // worker
142- int nworkers_;
143-
144145 public:
145146 ThreadPoolLight (int32 nworkers = std::thread::hardware_concurrency())
146- : curr_iworker_( 0 ), nworkers_(nworkers ) {
147+ : workers_(nworkers ), curr_iworker_( 0 ) {
147148 KALDI_ASSERT (nworkers > 1 );
148- workers_.resize (nworkers);
149- for (int i = 0 ; i < workers_.size (); ++i)
149+ for (size_t i = 0 ; i < workers_.size (); ++i) {
150150 workers_[i] = std::make_shared<ThreadPoolLightWorker>();
151-
152- for (int i = 0 ; i < workers_.size (); ++i) {
151+ }
152+ for (size_t i = 0 ; i < workers_.size (); ++i) {
153153 int iother = (i + nworkers / 2 ) % nworkers;
154154 workers_[i]->SetOtherWorkerToStealFrom (workers_[iother]);
155155 workers_[i]->Start ();
156156 }
157157 }
158158
159+ ~ThreadPoolLight () {
160+ for (auto & wkr : workers_) wkr->Stop ();
161+ }
162+
159163 bool TryPush (const ThreadPoolLightTask &task) {
160164 if (!workers_[curr_iworker_]->TryPush (task)) return false ;
161165 ++curr_iworker_;
162- if (curr_iworker_ == nworkers_ ) curr_iworker_ = 0 ;
166+ if (curr_iworker_ == workers_. size () ) curr_iworker_ = 0 ;
163167 return true ;
164168 }
165169
166170 void Push (const ThreadPoolLightTask &task) {
167171 // Could try another curr_iworker_
168- while (!TryPush (task))
172+ while (!TryPush (task)) {
169173 Sleep (kSleepForWorkerAvailable );
174+ }
170175 }
171176};
172177
173- } // end namespace cuda_decoder
174- } // end namespace kaldi
178+ } // namespace cuda_decoder
179+ } // namespace kaldi
175180
176- #endif // KALDI_CUDADECODER_THREAD_POOL_H_
181+ #endif // KALDI_CUDADECODER_THREAD_POOL_LIGHT_H_
0 commit comments