Commit a6d888f

Use semaphore. Numerous fixes.
Signed-off-by: Michał Zientkiewicz <mzient@gmail.com>
1 parent 74e41dc commit a6d888f

File tree

  dali/core/exec/thread_pool_base.cc
  include/dali/core/exec/thread_pool_base.h

2 files changed: +50 -24 lines

dali/core/exec/thread_pool_base.cc

Lines changed: 21 additions & 14 deletions
@@ -21,6 +21,8 @@ Job::~Job() noexcept(false) {
   if (!tasks_.empty() && !waited_for_) {
     throw std::logic_error("The job is not empty, but hasn't been scrapped or waited for.");
   }
+  while (running_)
+    std::this_thread::yield();
 }
 
 void Job::Wait() {
@@ -30,8 +32,8 @@ void Job::Wait() {
   if (waited_for_)
     throw std::logic_error("This job has already been waited for.");
 
+  auto ready = [&]() { return num_pending_tasks_ == 0; };
   if (ThreadPoolBase::this_thread_pool() != nullptr) {
-    auto ready = [&]() { return num_pending_tasks_ == 0; };
     bool result = ThreadPoolBase::this_thread_pool()->WaitOrRunTasks(cv_, ready);
     waited_for_ = true;
     if (!result)
@@ -62,6 +64,7 @@ void Job::Run(ThreadPoolBase &tp, bool wait) {
   if (started_)
     throw std::logic_error("This job has already been started.");
   started_ = true;
+  running_ = !tasks_.empty();
   {
     auto batch = tp.BulkAdd();
     num_pending_tasks_ += tasks_.size();
@@ -86,6 +89,8 @@ IncrementalJob::~IncrementalJob() noexcept(false) {
   if (!tasks_.empty() && !waited_for_) {
     throw std::logic_error("The job is not empty, but hasn't been scrapped or waited for.");
   }
+  while (running_)
+    std::this_thread::yield();
 }
 
 void IncrementalJob::Wait() {
@@ -95,8 +100,8 @@ void IncrementalJob::Wait() {
   if (waited_for_)
     throw std::logic_error("This job has already been waited for.");
 
+  auto ready = [&]() { return num_pending_tasks_ == 0; };
   if (ThreadPoolBase::this_thread_pool() != nullptr) {
-    auto ready = [&]() { return num_pending_tasks_ == 0; };
     bool result = ThreadPoolBase::this_thread_pool()->WaitOrRunTasks(cv_, ready);
     waited_for_ = true;
     if (!result)
@@ -134,6 +139,8 @@ void IncrementalJob::Run(ThreadPoolBase &tp, bool wait) {
       batch.Add(std::move(it->func));
       last_task_run_ = it;
     }
+    running_ = batch.Size() > 0;
+    batch.Submit();
   }
   if (wait && !tasks_.empty())
     Wait();
@@ -148,7 +155,7 @@ void IncrementalJob::Scrap() {
 ///////////////////////////////////////////////////////////////////////////
 
 thread_local ThreadPoolBase *ThreadPoolBase::this_thread_pool_ = nullptr;
-thread_local int ThreadPoolBase::this_thread_idx_ = -1;;
+thread_local int ThreadPoolBase::this_thread_idx_ = -1;
 
 void ThreadPoolBase::Init(int num_threads, const std::function<OnThreadStartFn> &on_thread_start) {
   if (shutdown_pending_)
@@ -162,20 +169,19 @@ void ThreadPoolBase::Init(int num_threads, const std::function<OnThreadStartFn>
 }
 
 void ThreadPoolBase::Shutdown(bool join) {
-  if (shutdown_pending_ && !join)
+  if ((shutdown_pending_ && !join) || threads_.empty())
     return;
   {
     std::lock_guard<std::mutex> g(mtx_);
     if (shutdown_pending_ && !join)
       return;
     shutdown_pending_ = true;
-    cv_.notify_all();
+    sem_.release(threads_.size());
   }
 
   for (auto &t : threads_)
     t.join();
-
-  assert(tasks_.empty());
+  threads_.clear();
 }
 
 void ThreadPoolBase::AddTaskNoLock(TaskFunc &&f) {
@@ -189,7 +195,6 @@ void ThreadPoolBase::AddTask(TaskFunc &&f) {
     std::lock_guard<std::mutex> g(mtx_);
     AddTaskNoLock(std::move(f));
   }
-  cv_.notify_one();
 }
 
 void ThreadPoolBase::Run(
@@ -200,11 +205,12 @@ void ThreadPoolBase::Run(
   std::any scope;
   if (on_thread_start)
     scope = on_thread_start(index);
-  std::unique_lock lock(mtx_);
   while (!shutdown_pending_ || !tasks_.empty()) {
-    cv_.wait(lock, [&]() { return shutdown_pending_ || !tasks_.empty(); });
-    if (tasks_.empty())
+    sem_.acquire();
+    std::unique_lock lock(mtx_);
+    if (shutdown_pending_)
       break;
+    assert(!tasks_.empty() && "Semaphore acquired but no tasks present.");
     PopAndRunTask(lock);
   }
 }
@@ -228,11 +234,12 @@ bool ThreadPoolBase::WaitOrRunTasks(std::condition_variable &cv, Condition &&con
 
     if (ret || condition())  // re-evaluate the condition, just in case
       return true;
-    if (tasks_.empty()) {
-      assert(shutdown_pending_);
+    if (shutdown_pending_)
       return condition();
-    }
+    if (!sem_.try_acquire())
+      continue;
 
+    assert(!tasks_.empty() && "Semaphore acquired but no tasks present.");
    PopAndRunTask(lock);
   }
   return condition();

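For context, the switch from a condition variable to a counting semaphore in Run(), AddTask()/Submit() and Shutdown() follows the pattern sketched below. This is a minimal, self-contained illustration, not the DALI code: the MiniPool name is hypothetical, and it assumes C++20 std::counting_semaphore in place of the counting_semaphore from dali/core/semaphore.h.

// Minimal sketch of the semaphore-driven worker loop this commit adopts.
#include <cassert>
#include <cstddef>
#include <functional>
#include <mutex>
#include <queue>
#include <semaphore>
#include <thread>
#include <vector>

class MiniPool {  // hypothetical name, for illustration only
 public:
  explicit MiniPool(int num_threads) {
    for (int i = 0; i < num_threads; i++)
      threads_.emplace_back([this]() { Worker(); });
  }

  ~MiniPool() { Shutdown(); }

  void AddTask(std::function<void()> f) {
    {
      std::lock_guard<std::mutex> g(mtx_);
      tasks_.push(std::move(f));
    }
    sem_.release();  // one permit per queued task
  }

  void Shutdown() {
    if (threads_.empty())
      return;  // already shut down (or never started)
    {
      std::lock_guard<std::mutex> g(mtx_);
      shutdown_pending_ = true;
    }
    // Wake every worker so each of them can observe the shutdown flag.
    sem_.release(static_cast<std::ptrdiff_t>(threads_.size()));
    for (auto &t : threads_)
      t.join();
    threads_.clear();
  }

 private:
  void Worker() {
    for (;;) {
      sem_.acquire();  // blocks until a task permit or a shutdown permit arrives
      std::unique_lock<std::mutex> lock(mtx_);
      if (shutdown_pending_)
        break;  // shutdown takes precedence, as in the patched Run()
      assert(!tasks_.empty());  // a non-shutdown permit implies a queued task
      auto task = std::move(tasks_.front());
      tasks_.pop();
      lock.unlock();
      task();  // run outside the lock
    }
  }

  std::mutex mtx_;
  std::counting_semaphore<> sem_{0};
  bool shutdown_pending_ = false;
  std::queue<std::function<void()>> tasks_;
  std::vector<std::thread> threads_;
};

Releasing one permit per queued task replaces notify_one()/notify_all(), and releasing threads_.size() permits on shutdown replaces cv_.notify_all(), which is why the diff can drop the condition-variable wait in the worker loop entirely.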
include/dali/core/exec/thread_pool_base.h

Lines changed: 29 additions & 10 deletions
@@ -29,6 +29,7 @@
 #include <vector>
 #include <mutex>
 #include <condition_variable>
+#include "dali/core/semaphore.h"
 #include "dali/core/api_helper.h"
 #include "dali/core/multi_error.h"
 #include "dali/core/mm/detail/aux_alloc.h"
@@ -66,10 +67,14 @@ class DLL_PUBLIC Job {
 
         if (--num_pending_tasks_ == 0) {
           num_pending_tasks_.notify_all();
-          std::cerr << make_string((void *)this, " notified.") << std::endl;
+          (void)std::lock_guard(mtx_);
           cv_.notify_all();
+          // We need this second flag to avoid a race condition where the
+          // destructor is called between decrementing num_pending_tasks_ and notification
+          // without excessive use of mutexes. This must be the very last operation in the task
+          // function that touches `this`.
+          running_ = false;
         }
-        assert(num_pending_tasks_ >= 0);
       };
     } catch (...) {  // if, for whatever reason, we cannot initialize the task, we should erase it
       tasks_.erase(it);
@@ -88,8 +93,10 @@ class DLL_PUBLIC Job {
 
  private:
   // atomic wait has no timeout, so we're stuck with condvar for reentrance
+  std::mutex mtx_;
   std::condition_variable cv_;
   std::atomic_int num_pending_tasks_{0};
+  std::atomic_bool running_{false};
   bool started_ = false;
   bool waited_for_ = false;
 
@@ -139,8 +146,10 @@ class DLL_PUBLIC IncrementalJob {
   const void *executor_ = nullptr;
   bool waited_for_ = false;
   // atomic wait has no timeout, so we're stuck with condvar for reentrance
+  std::mutex mtx_;
   std::condition_variable cv_;
   std::atomic_int num_pending_tasks_{0};
+  std::atomic_bool running_{false};
   using task_list_t = std::list<Task, mm::detail::object_pool_allocator<Task>>;
   task_list_t tasks_;
   std::optional<task_list_t::iterator> last_task_run_;
@@ -192,18 +201,20 @@ class DLL_PUBLIC ThreadPoolBase {
     void Submit() {
       if (lock.owns_lock()) {
        lock.unlock();
-        if (tasks_added > 1)
-          owner->cv_.notify_all();
-        else
-          owner->cv_.notify_one();
+        owner->sem_.release(tasks_added);
       }
     }
+
+    int Size() const {
+      return tasks_added;
+    }
+
    private:
     friend class ThreadPoolBase;
     explicit TaskBulkAdd(ThreadPoolBase *o) : owner(o), lock(o->mtx_, std::defer_lock) {}
-    ThreadPoolBase *owner;
+    ThreadPoolBase *owner = nullptr;
     std::unique_lock<std::mutex> lock;
-    int tasks_added;
+    int tasks_added = 0;
   };
   friend class TaskBulkAdd;
 
@@ -247,7 +258,7 @@ class DLL_PUBLIC ThreadPoolBase {
   void Run(int index, const std::function<OnThreadStartFn> &on_thread_start) noexcept;
 
   std::mutex mtx_;
-  std::condition_variable cv_;
+  counting_semaphore sem_{0};
   bool shutdown_pending_ = false;
   std::queue<TaskFunc> tasks_;
   std::vector<std::thread> threads_;
@@ -289,13 +300,15 @@ void Job::Run(Executor &executor, bool wait) {
   if (started_)
     throw std::logic_error("This job has already been started.");
   started_ = true;
+  running_ = !tasks_.empty();
   for (auto &x : tasks_) {
     num_pending_tasks_++;
     try {
       executor.AddTask(std::move(x.second.func));
     } catch (...) {
       if (--num_pending_tasks_ == 0) {
         num_pending_tasks_.notify_all();
+        (void)std::lock_guard(mtx_);
         cv_.notify_all();
       }
       throw;
@@ -325,9 +338,14 @@ IncrementalJob::AddTask(Runnable &&runnable) {
 
         if (--num_pending_tasks_ == 0) {
           num_pending_tasks_.notify_all();
+          (void)std::lock_guard(mtx_);
           cv_.notify_all();
+          // We need this second flag to avoid a race condition where the
+          // destructor is called between decrementing num_pending_tasks_ and notification
+          // without excessive use of mutexes. This must be the very last operation in the task
+          // function that touches `this`.
+          running_ = false;
         }
-        assert(num_pending_tasks_ >= 0);
       };
     } catch (...) {  // if, for whatever reason, we cannot initialize the task, we should erase it
       tasks_.erase(it);
@@ -346,6 +364,7 @@ void IncrementalJob::Run(Executor &executor, bool wait) {
   executor_ = &executor;
   auto it = last_task_run_.has_value() ? std::next(*last_task_run_) : tasks_.begin();
   for (; it != tasks_.end(); ++it) {
+    running_ = true;
     executor.AddTask(std::move(it->func));
     last_task_run_ = it;
   }

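The running_ flag added to Job and IncrementalJob (together with the while (running_) std::this_thread::yield(); loop in their destructors) exists because the waiter can be released the moment num_pending_tasks_ reaches zero, so destruction may begin while the finishing task is still between the decrement and the notification. A minimal sketch of that idiom follows; the MiniJob name is hypothetical (not the DALI class) and C++20 is assumed for the atomic notify_all.

// Sketch of the destructor-race guard: running_ = false must be the very last
// access to *this in the task body, and the destructor spins until it is seen.
#include <atomic>
#include <condition_variable>
#include <mutex>
#include <thread>

struct MiniJob {  // hypothetical, for illustration only
  ~MiniJob() {
    // Wait() may return right after the decrement below, so destruction could
    // begin while the finishing task is still notifying; defer until it is done.
    while (running_)
      std::this_thread::yield();
  }

  // What each scheduled task does when it completes.
  void OnTaskDone() {
    if (--num_pending_tasks_ == 0) {
      num_pending_tasks_.notify_all();
      { std::lock_guard<std::mutex> g(mtx_); }  // serialize with the waiter's predicate check
      cv_.notify_all();
      running_ = false;  // the very last access to *this
    }
  }

  void Wait() {
    std::unique_lock<std::mutex> lock(mtx_);
    cv_.wait(lock, [&]() { return num_pending_tasks_ == 0; });
  }

  std::mutex mtx_;
  std::condition_variable cv_;
  std::atomic_int num_pending_tasks_{1};  // pretend one task was scheduled
  std::atomic_bool running_{true};        // set when tasks are submitted
};

Briefly taking and dropping mtx_ before cv_.notify_all() (the (void)std::lock_guard(mtx_) in the diff) closes the lost-wakeup window in which the waiter has evaluated the predicate as false but has not yet gone to sleep.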