Skip to content

Commit 4973514

Browse files
authored
[src] Assert empty workqueue on ThreadPoolLight destruction (#4568)
Other minor tweaks:
* Convert runtime asserts to static asserts in the template.
* Remove ThreadPoolLight::nworkers_, as it is equal to the .size() of the worker vector and was used only once.
* Replace pointer indirection with std::move and make ThreadPoolLightWorker::thread_ a direct class member.
* Update to coding guidelines and IWYU (include-what-you-use).
1 parent 9796afd commit 4973514

File tree

1 file changed

+43
-38
lines changed

1 file changed

+43
-38
lines changed

src/cudadecoder/thread-pool-light.h

Lines changed: 43 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,15 @@
1919
#define KALDI_CUDADECODER_THREAD_POOL_LIGHT_H_
2020

2121
#include <atomic>
22+
#include <memory>
2223
#include <thread>
2324
#include <vector>
2425

25-
#include "util/stl-utils.h"
26-
2726
namespace kaldi {
2827
namespace cuda_decoder {
2928

30-
const double kSleepForWorkerAvailable = 1e-3;
29+
constexpr double kSleepForWorkAvailable = 1e-3;
30+
constexpr double kSleepForWorkerAvailable = 1e-3;
3131

3232
struct ThreadPoolLightTask {
3333
void (*func_ptr)(void *, uint64_t, void *);
@@ -39,20 +39,17 @@ struct ThreadPoolLightTask {
3939
template <int QUEUE_SIZE>
4040
// Single producer, multiple consumer
4141
class ThreadPoolLightSPMCQueue {
42-
static const unsigned int QUEUE_MASK = QUEUE_SIZE - 1;
42+
static constexpr unsigned int QUEUE_MASK = QUEUE_SIZE - 1;
4343
std::vector<ThreadPoolLightTask> tasks_;
4444
std::atomic<int> back_;
4545
std::atomic<int> front_;
46-
int inc(int curr) { return ((curr + 1) & QUEUE_MASK); }
46+
static int inc(int curr) { return ((curr + 1) & QUEUE_MASK); }
4747

4848
public:
49-
ThreadPoolLightSPMCQueue() {
50-
KALDI_ASSERT(QUEUE_SIZE > 1);
51-
bool is_power_of_2 = ((QUEUE_SIZE & (QUEUE_SIZE - 1)) == 0);
52-
KALDI_ASSERT(is_power_of_2); // validity of QUEUE_MASK
53-
tasks_.resize(QUEUE_SIZE);
54-
front_.store(0);
55-
back_.store(0);
49+
ThreadPoolLightSPMCQueue() : tasks_(QUEUE_SIZE), front_(0), back_(0) {
50+
KALDI_COMPILE_TIME_ASSERT(QUEUE_SIZE > 1);
51+
constexpr bool is_power_of_2 = ((QUEUE_SIZE & (QUEUE_SIZE - 1)) == 0);
52+
KALDI_COMPILE_TIME_ASSERT(is_power_of_2); // validity of QUEUE_MASK
5653
}
5754

5855
bool TryPush(const ThreadPoolLightTask &task) {
@@ -70,33 +67,35 @@ class ThreadPoolLightSPMCQueue {
7067
bool TryPop(ThreadPoolLightTask *front_task) {
7168
while (true) {
7269
int front = front_.load(std::memory_order_relaxed);
73-
if (front == back_.load(std::memory_order_acquire))
70+
if (front == back_.load(std::memory_order_acquire)) {
7471
return false; // queue is empty
72+
}
7573
*front_task = tasks_[front];
7674
if (front_.compare_exchange_weak(front, inc(front),
77-
std::memory_order_release))
75+
std::memory_order_release)) {
7876
return true;
77+
}
7978
}
8079
}
8180
};
8281

83-
class ThreadPoolLightWorker {
82+
class ThreadPoolLightWorker final {
8483
// Multi consumer queue, because worker can steal work
8584
ThreadPoolLightSPMCQueue<512> queue_;
8685
// If this thread has no more work to do, it will try to steal work from
8786
// the other worker (see other_).
88-
std::unique_ptr<std::thread> thread_;
89-
bool run_thread_;
87+
std::thread thread_;
88+
volatile bool run_thread_;
9089
ThreadPoolLightTask curr_task_;
9190
std::weak_ptr<ThreadPoolLightWorker> other_;
9291

9392
void Work() {
9493
while (run_thread_) {
9594
bool got_task = queue_.TryPop(&curr_task_);
9695
if (!got_task) {
97-
if (auto other_sp = other_.lock()) {
98-
got_task = other_sp->TrySteal(&curr_task_);
99-
}
96+
if (auto other_sp = other_.lock()) {
97+
got_task = other_sp->TrySteal(&curr_task_);
98+
}
10099
}
101100
if (got_task) {
102101
// Not calling func_ptr as a member function,
@@ -106,71 +105,77 @@ class ThreadPoolLightWorker {
106105
(curr_task_.func_ptr)(curr_task_.obj_ptr, curr_task_.arg1,
107106
curr_task_.arg2);
108107
} else {
109-
Sleep(1e-3f); // TODO
108+
Sleep(kSleepForWorkAvailable); // TODO
110109
}
111110
}
112111
}
113112

114-
protected:
115113
// Another worker can steal a task from this queue
116114
// This is done so that a very long task computed by one thread does not
117115
// hold up the entire thread pool from completing a time-sensitive task.
118116
bool TrySteal(ThreadPoolLightTask *task) { return queue_.TryPop(task); }
119117

120118
public:
121119
ThreadPoolLightWorker() : run_thread_(true), other_() {}
122-
virtual ~ThreadPoolLightWorker() { Stop(); }
123-
bool TryPush(const ThreadPoolLightTask &task) { return queue_.TryPush(task); }
120+
~ThreadPoolLightWorker() {
121+
KALDI_ASSERT(!queue_.TryPop(&curr_task_));
122+
}
123+
bool TryPush(const ThreadPoolLightTask &task) {
124+
return queue_.TryPush(task);
125+
}
124126
void SetOtherWorkerToStealFrom(
125127
const std::shared_ptr<ThreadPoolLightWorker>& other) {
126128
other_ = other;
127129
}
128130
void Start() {
129131
KALDI_ASSERT("Please call SetOtherWorkerToStealFrom() first" && !other_.expired());
130-
thread_.reset(new std::thread(&ThreadPoolLightWorker::Work, this));
132+
thread_ = std::move(std::thread(&ThreadPoolLightWorker::Work, this));
131133
}
132134
void Stop() {
133135
run_thread_ = false;
134-
thread_->join();
136+
thread_.join();
137+
other_.reset();
135138
}
136139
};
137140

138141
class ThreadPoolLight {
139142
std::vector<std::shared_ptr<ThreadPoolLightWorker>> workers_;
140143
int curr_iworker_; // next call on tryPush will post work on this
141144
// worker
142-
int nworkers_;
143-
144145
public:
145146
ThreadPoolLight(int32 nworkers = std::thread::hardware_concurrency())
146-
: curr_iworker_(0), nworkers_(nworkers) {
147+
: workers_(nworkers), curr_iworker_(0) {
147148
KALDI_ASSERT(nworkers > 1);
148-
workers_.resize(nworkers);
149-
for (int i = 0; i < workers_.size(); ++i)
149+
for (size_t i = 0; i < workers_.size(); ++i) {
150150
workers_[i] = std::make_shared<ThreadPoolLightWorker>();
151-
152-
for (int i = 0; i < workers_.size(); ++i) {
151+
}
152+
for (size_t i = 0; i < workers_.size(); ++i) {
153153
int iother = (i + nworkers / 2) % nworkers;
154154
workers_[i]->SetOtherWorkerToStealFrom(workers_[iother]);
155155
workers_[i]->Start();
156156
}
157157
}
158158

159+
~ThreadPoolLight() {
160+
for (auto& wkr : workers_) wkr->Stop();
161+
}
162+
159163
bool TryPush(const ThreadPoolLightTask &task) {
160164
if (!workers_[curr_iworker_]->TryPush(task)) return false;
161165
++curr_iworker_;
162-
if (curr_iworker_ == nworkers_) curr_iworker_ = 0;
166+
if (curr_iworker_ == workers_.size()) curr_iworker_ = 0;
163167
return true;
164168
}
165169

166170
void Push(const ThreadPoolLightTask &task) {
167171
// Could try another curr_iworker_
168-
while (!TryPush(task))
172+
while (!TryPush(task)) {
169173
Sleep(kSleepForWorkerAvailable);
174+
}
170175
}
171176
};
172177

173-
} // end namespace cuda_decoder
174-
} // end namespace kaldi
178+
} // namespace cuda_decoder
179+
} // namespace kaldi
175180

176-
#endif // KALDI_CUDADECODER_THREAD_POOL_H_
181+
#endif // KALDI_CUDADECODER_THREAD_POOL_LIGHT_H_

0 commit comments

Comments
 (0)