Skip to content

Commit 8bd7661

Browse files
Fix issues with wait_for_tasks() (#68)
* test: add new test for wait_for_tasks() * docs: rename variable and add comments * test: split test to two separate ones * chore: update memory ordering of atomics * wip: use std::barrier in wait_for_tasks() Use std::barrier with wait for tasks. This requires std::move_only_function to be available. * chore: use std::atomic_bool instead of a barrier * simplify test to remove extra variables * reset the thread complete signal any time we add new tasks to an empty queue. * add extra test * chore: remove unnecessary try/catch We suppress exceptions before enqueuing the task so it doesn't seem necessary to have try/catch when we invoke the enqueued task * test: add variety to tests and added missing include * format: auto formatting fixes * docs: add more doc strings and comments * fix: catch exceptions in the thread init function * chore: use `store` with atomic bool --------- Co-authored-by: Justin Davis <[email protected]>
1 parent 98f7b95 commit 8bd7661

File tree

2 files changed

+123
-17
lines changed

2 files changed

+123
-17
lines changed

include/thread_pool/thread_pool.h

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#pragma once
22

33
#include <atomic>
4-
#include <barrier>
54
#include <concepts>
65
#include <deque>
76
#include <functional>
@@ -47,20 +46,29 @@ namespace dp {
4746
try {
4847
threads_.emplace_back([&, id = current_id,
4948
init](const std::stop_token &stop_tok) {
50-
init(id);
49+
// invoke the init function on the thread
50+
try {
51+
std::invoke(init, id);
52+
} catch (...) {
53+
// suppress exceptions
54+
}
55+
5156
do {
5257
// wait until signaled
5358
tasks_[id].signal.acquire();
5459

5560
do {
5661
// invoke the task
5762
while (auto task = tasks_[id].tasks.pop_front()) {
58-
try {
59-
unassigned_tasks_.fetch_sub(1, std::memory_order_release);
60-
std::invoke(std::move(task.value()));
61-
completed_tasks_.fetch_sub(1, std::memory_order_release);
62-
} catch (...) {
63-
}
63+
// decrement the unassigned tasks as the task is now going
64+
// to be executed
65+
unassigned_tasks_.fetch_sub(1, std::memory_order_release);
66+
// invoke the task
67+
std::invoke(std::move(task.value()));
68+
// the above task can push more work onto the pool, so we
69+
// only decrement the in flights once the task has been
70+
// executed because it's no longer "in flight"
71+
in_flight_tasks_.fetch_sub(1, std::memory_order_release);
6472
}
6573

6674
// try to steal a task
@@ -70,7 +78,7 @@ namespace dp {
7078
// steal a task
7179
unassigned_tasks_.fetch_sub(1, std::memory_order_release);
7280
std::invoke(std::move(task.value()));
73-
completed_tasks_.fetch_sub(1, std::memory_order_release);
81+
in_flight_tasks_.fetch_sub(1, std::memory_order_release);
7482
// stop stealing once we have invoked a stolen task
7583
break;
7684
}
@@ -82,8 +90,9 @@ namespace dp {
8290
priority_queue_.rotate_to_front(id);
8391
// check if all tasks are completed and, if so, set the completion
8492
// signal (atomic_bool) to wake any thread blocked in wait_for_tasks()
85-
if (completed_tasks_.load(std::memory_order_acquire) == 0) {
86-
threads_done_.release();
93+
if (in_flight_tasks_.load(std::memory_order_acquire) == 0) {
94+
threads_complete_signal_.store(true, std::memory_order_release);
95+
threads_complete_signal_.notify_one();
8796
}
8897

8998
} while (!stop_tok.stop_requested());
@@ -214,16 +223,21 @@ namespace dp {
214223
}));
215224
}
216225

226+
/**
227+
* @brief Returns the number of threads in the pool.
228+
*
229+
* @return std::size_t The number of threads in the pool.
230+
*/
217231
[[nodiscard]] auto size() const { return threads_.size(); }
218232

219233
/**
220234
* @brief Wait for all tasks to finish.
221235
* @details This function will block until all tasks have been completed.
222236
*/
223237
void wait_for_tasks() {
224-
if (completed_tasks_.load(std::memory_order_acquire) > 0) {
238+
if (in_flight_tasks_.load(std::memory_order_acquire) > 0) {
225239
// wait for all tasks to finish
226-
threads_done_.acquire();
240+
threads_complete_signal_.wait(false);
227241
}
228242
}
229243

@@ -235,9 +249,19 @@ namespace dp {
235249
// would only be a problem if there are zero threads
236250
return;
237251
}
252+
// get the index
238253
auto i = *(i_opt);
239-
unassigned_tasks_.fetch_add(1, std::memory_order_relaxed);
240-
completed_tasks_.fetch_add(1, std::memory_order_relaxed);
254+
255+
// increment the unassigned tasks and in flight tasks
256+
unassigned_tasks_.fetch_add(1, std::memory_order_release);
257+
const auto prev_in_flight = in_flight_tasks_.fetch_add(1, std::memory_order_release);
258+
259+
// reset the in flight signal if the list was previously empty
260+
if (prev_in_flight == 0) {
261+
threads_complete_signal_.store(false, std::memory_order_release);
262+
}
263+
264+
// assign work
241265
tasks_[i].tasks.push_back(std::forward<Function>(f));
242266
tasks_[i].signal.release();
243267
}
@@ -250,8 +274,9 @@ namespace dp {
250274
std::vector<ThreadType> threads_;
251275
std::deque<task_item> tasks_;
252276
dp::thread_safe_queue<std::size_t> priority_queue_;
253-
std::atomic_int_fast64_t unassigned_tasks_{}, completed_tasks_{};
254-
std::binary_semaphore threads_done_{0};
277+
// guarantee these get zero-initialized
278+
std::atomic_int_fast64_t unassigned_tasks_{0}, in_flight_tasks_{0};
279+
std::atomic_bool threads_complete_signal_{false};
255280
};
256281

257282
/**

test/source/thread_pool.cpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <thread_pool/version.h>
55

66
#include <algorithm>
7+
#include <array>
78
#include <iostream>
89
#include <numeric>
910
#include <random>
@@ -469,6 +470,86 @@ TEST_CASE("Ensure wait_for_tasks() properly blocks current execution.") {
469470
CHECK_EQ(counter.load(), total_tasks);
470471
}
471472

473+
TEST_CASE("Ensure wait_for_tasks() properly waits for tasks to fully complete") {
474+
class counter_wrapper {
475+
public:
476+
std::atomic_int counter = 0;
477+
478+
void increment_counter() { counter.fetch_add(1, std::memory_order_release); }
479+
};
480+
481+
dp::thread_pool local_pool{};
482+
constexpr auto task_count = 10;
483+
std::array<int, task_count> counts{{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}};
484+
for (size_t i = 0; i < task_count; i++) {
485+
counter_wrapper cnt_wrp{};
486+
487+
for (size_t var1 = 0; var1 < 17; var1++) {
488+
for (int var2 = 0; var2 < 12; var2++) {
489+
local_pool.enqueue_detach([&cnt_wrp]() { cnt_wrp.increment_counter(); });
490+
}
491+
}
492+
local_pool.wait_for_tasks();
493+
// std::cout << cnt_wrp.counter << std::endl;
494+
counts[i] = cnt_wrp.counter.load(std::memory_order_acquire);
495+
}
496+
497+
auto all_correct_count =
498+
std::ranges::all_of(counts, [](int count) { return count == 17 * 12; });
499+
const auto sum = std::accumulate(counts.begin(), counts.end(), 0);
500+
CHECK_EQ(sum, 17 * 12 * task_count);
501+
CHECK(all_correct_count);
502+
}
503+
504+
TEST_CASE("Ensure wait_for_tasks() can be called multiple times on the same pool") {
505+
class counter_wrapper {
506+
public:
507+
std::atomic_int counter = 0;
508+
509+
void increment_counter() { counter.fetch_add(1, std::memory_order_release); }
510+
};
511+
512+
dp::thread_pool local_pool{};
513+
constexpr auto task_count = 10;
514+
std::array<int, task_count> counts{{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}};
515+
for (size_t i = 0; i < task_count; i++) {
516+
counter_wrapper cnt_wrp{};
517+
518+
for (size_t var1 = 0; var1 < 16; var1++) {
519+
for (int var2 = 0; var2 < 13; var2++) {
520+
local_pool.enqueue_detach([&cnt_wrp]() { cnt_wrp.increment_counter(); });
521+
}
522+
}
523+
local_pool.wait_for_tasks();
524+
// std::cout << cnt_wrp.counter << std::endl;
525+
counts[i] = cnt_wrp.counter.load(std::memory_order_acquire);
526+
}
527+
528+
auto all_correct_count =
529+
std::ranges::all_of(counts, [](int count) { return count == 16 * 13; });
530+
auto sum = std::accumulate(counts.begin(), counts.end(), 0);
531+
CHECK_EQ(sum, 16 * 13 * task_count);
532+
CHECK(all_correct_count);
533+
534+
for (size_t i = 0; i < task_count; i++) {
535+
counter_wrapper cnt_wrp{};
536+
537+
for (size_t var1 = 0; var1 < 17; var1++) {
538+
for (int var2 = 0; var2 < 12; var2++) {
539+
local_pool.enqueue_detach([&cnt_wrp]() { cnt_wrp.increment_counter(); });
540+
}
541+
}
542+
local_pool.wait_for_tasks();
543+
// std::cout << cnt_wrp.counter << std::endl;
544+
counts[i] = cnt_wrp.counter.load(std::memory_order_acquire);
545+
}
546+
547+
all_correct_count = std::ranges::all_of(counts, [](int count) { return count == 17 * 12; });
548+
sum = std::accumulate(counts.begin(), counts.end(), 0);
549+
CHECK_EQ(sum, 17 * 12 * task_count);
550+
CHECK(all_correct_count);
551+
}
552+
472553
TEST_CASE("Initialization function is called") {
473554
std::atomic_int counter = 0;
474555
{

0 commit comments

Comments
 (0)