Commit 9b0557d

feat: add clear_tasks() (#69)
* feat: add `clear_tasks()`
* test: add tests for `thread_safe_queue::clear()`
* test: add tests for `thread_pool::clear_tasks()`
1 parent: 8bd7661

File tree

4 files changed: +144 additions, −0 deletions


include/thread_pool/thread_pool.h

Lines changed: 16 additions & 0 deletions
@@ -241,6 +241,22 @@ namespace dp {
             }
         }
 
+        /**
+         * @brief Makes a best-effort attempt to clear all tasks from the thread_pool
+         * @details Note that this does not guarantee that all tasks will be cleared, as currently
+         * running tasks could add additional tasks. Also, a thread could steal a task from another
+         * in the middle of this.
+         * @return number of tasks cleared
+         */
+        size_t clear_tasks() {
+            size_t removed_task_count{0};
+            for (auto &task_list : tasks_) removed_task_count += task_list.tasks.clear();
+            in_flight_tasks_.fetch_sub(removed_task_count, std::memory_order_release);
+            unassigned_tasks_.fetch_sub(removed_task_count, std::memory_order_release);
+
+            return removed_task_count;
+        }
+
       private:
         template <typename Function>
         void enqueue_task(Function &&f) {
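
A minimal usage sketch of the new API (a hypothetical example, not part of this commit; the include path is assumed from the file layout, while `enqueue_detach`, `wait_for_tasks`, and `clear_tasks` are taken from this repository's headers and tests):

    #include <chrono>
    #include <cstdio>
    #include <thread>

    #include <thread_pool/thread_pool.h>

    int main() {
        dp::thread_pool pool(2);

        // Enqueue more work than the pool has threads; the surplus tasks wait in the
        // per-thread queues until a worker picks them up.
        for (int i = 0; i < 8; ++i) {
            pool.enqueue_detach([] { std::this_thread::sleep_for(std::chrono::milliseconds(50)); });
        }

        // Best-effort: drops tasks that have not started yet and returns how many were
        // removed. Tasks that are already running are unaffected.
        const auto dropped = pool.clear_tasks();
        std::printf("cleared %zu pending tasks\n", dropped);

        pool.wait_for_tasks();  // wait only for the tasks that were already running
    }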

include/thread_pool/thread_safe_queue.h

Lines changed: 8 additions & 0 deletions
@@ -44,6 +44,14 @@ namespace dp {
             return data_.empty();
         }
 
+        size_type clear() {
+            std::scoped_lock lock(mutex_);
+            auto size = data_.size();
+            data_.clear();
+
+            return size;
+        }
+
         [[nodiscard]] std::optional<T> pop_front() {
             std::scoped_lock lock(mutex_);
             if (data_.empty()) return std::nullopt;
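
For context, a small standalone sketch of the new `thread_safe_queue::clear()` (a hypothetical example; the include path is assumed from the file layout, and `push_front`/`empty` come from the existing tests):

    #include <cassert>

    #include <thread_pool/thread_safe_queue.h>

    int main() {
        dp::thread_safe_queue<int> queue;
        queue.push_front(1);
        queue.push_front(2);
        queue.push_front(3);

        // clear() locks the queue, empties it, and returns the number of elements removed.
        const auto removed = queue.clear();
        assert(removed == 3);
        assert(queue.empty());
    }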

test/source/thread_pool.cpp

Lines changed: 85 additions & 0 deletions
@@ -5,9 +5,11 @@
 
 #include <algorithm>
 #include <array>
+#include <barrier>
 #include <iostream>
 #include <numeric>
 #include <random>
+#include <shared_mutex>
 #include <string>
 #include <thread>
 
@@ -560,3 +562,86 @@ TEST_CASE("Initialization function is called") {
     }
     CHECK_EQ(counter.load(), 4);
 }
+
+TEST_CASE("Check clear_tasks() can be called from a task") {
+    // Here:
+    // - we use a barrier to trigger clear_tasks() once all threads are busy;
+    // - to prevent race conditions (e.g. clear_tasks() getting called whilst we are still adding
+    //   tasks), we use a mutex to keep the tasks from running until all tasks have been added
+    //   to the pool.
+
+    unsigned int thread_count = 0;
+
+    SUBCASE("with single thread") { thread_count = 1; }
+    SUBCASE("with multiple threads") { thread_count = 4; }
+
+    std::atomic<unsigned int> counter = 0;
+    dp::thread_pool pool(thread_count);
+    std::shared_mutex mutex;
+
+    {
+        /* Clear the thread_pool when the barrier is hit; this must not throw */
+        auto clear_func = [&pool]() noexcept {
+            try {
+                pool.clear_tasks();
+            } catch (...) {
+            }
+        };
+        std::barrier sync_point(thread_count, clear_func);
+
+        auto func = [&counter, &sync_point, &mutex]() {
+            std::shared_lock lock(mutex);
+            counter.fetch_add(1);
+            sync_point.arrive_and_wait();
+        };
+
+        {
+            std::unique_lock lock(mutex);
+            for (int i = 0; i < 10; i++) pool.enqueue_detach(func);
+        }
+
+        pool.wait_for_tasks();
+    }
+
+    CHECK_EQ(counter.load(), thread_count);
+}
+
+TEST_CASE("Check clear_tasks() clears tasks") {
+    // Here we:
+    // - add twice as many tasks to the pool as can run simultaneously
+    // - use a lock to prevent race conditions (e.g. clear_tasks() running whilst another task is
+    //   being added)
+
+    unsigned int thread_count{4};
+    size_t cleared_tasks{0};
+    std::atomic<unsigned int> counter{0};
+
+    SUBCASE("with no thread") { thread_count = 0; }
+    SUBCASE("with single thread") { thread_count = 1; }
+    SUBCASE("with multiple threads") { thread_count = 4; }
+
+    {
+        std::mutex mutex;
+        dp::thread_pool pool(thread_count);
+
+        std::function<void(void)> func;
+        func = [&counter, &mutex]() {
+            counter.fetch_add(1);
+            std::lock_guard lock(mutex);
+        };
+
+        {
+            /* fill the thread_pool twice over, and wait until all threads are running and
+             * locked in a task */
+            std::lock_guard lock(mutex);
+            for (unsigned int i = 0; i < 2 * thread_count; i++) pool.enqueue_detach(func);
+
+            while (counter != thread_count)
+                std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+            cleared_tasks = pool.clear_tasks();
+        }
+    }
+    CHECK_EQ(cleared_tasks, static_cast<size_t>(thread_count));
+    CHECK_EQ(thread_count, counter.load());
+}

test/source/thread_safe_queue.cpp

Lines changed: 35 additions & 0 deletions
@@ -51,3 +51,38 @@ TEST_CASE("Ensure insert and pop works with thread contention") {
     CHECK_NE(res2, res3);
     CHECK_NE(res3, res1);
 }
+
+TEST_CASE("Ensure clear() works and returns correct count") {
+    // create a synchronization barrier to ensure our threads have started before executing the
+    // code that clears the queue
+
+    // here, we check that:
+    // - the queue is cleared
+    // - clear() returns the correct count
+
+    std::barrier barrier(3);
+    std::atomic<size_t> removed_count{0};
+
+    dp::thread_safe_queue<int> queue;
+    {
+        std::jthread t1([&queue, &barrier, &removed_count] {
+            queue.push_front(1);
+            barrier.arrive_and_wait();
+            removed_count = queue.clear();
+            barrier.arrive_and_wait();
+        });
+        std::jthread t2([&queue, &barrier] {
+            queue.push_front(2);
+            barrier.arrive_and_wait();
+            barrier.arrive_and_wait();
+        });
+        std::jthread t3([&queue, &barrier] {
+            queue.push_front(3);
+            barrier.arrive_and_wait();
+            barrier.arrive_and_wait();
+        });
+    }
+
+    CHECK(queue.empty());
+    CHECK_EQ(removed_count, 3);
+}
}
