Skip to content

Commit ac9aa92

Browse files
committed
Thread pool2 tests and implementation finishing touches
1 parent ad2b5de commit ac9aa92

File tree

3 files changed

+99
-11
lines changed

3 files changed

+99
-11
lines changed

cpr/threadpool2.cpp

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
#include "cpr/threadpool2.h"
22
#include <cassert>
3-
#include <cstddef>
3+
#include <chrono>
44
#include <condition_variable>
5-
#include <memory>
5+
#include <cstddef>
66
#include <functional>
7+
#include <memory>
78
#include <mutex>
8-
#include <chrono>
99
#include <thread>
10+
#include <utility>
1011

1112
namespace cpr {
1213
size_t ThreadPool2::DEFAULT_MAX_THREAD_COUNT = std::thread::hardware_concurrency();
@@ -49,15 +50,12 @@ void ThreadPool2::SetMaxThreadCount(size_t maxThreadCount) {
4950
void ThreadPool2::Start() {
5051
const std::unique_lock lock(controlMutex);
5152
setState(State::RUNNING);
52-
53-
for (size_t i = 0; i < minThreadCount; i++) {
54-
addThread();
55-
}
5653
}
5754

5855
void ThreadPool2::Stop() {
5956
const std::unique_lock controlLock(controlMutex);
6057
setState(State::STOP);
58+
taskQueueCondVar.notify_all();
6159

6260
// Join all workers
6361
const std::unique_lock workersLock{workerMutex};
@@ -70,11 +68,21 @@ void ThreadPool2::Stop() {
7068
}
7169
}
7270

71+
void ThreadPool2::Wait() {
72+
while (true) {
73+
if ((state != State::RUNNING && curThreadCount <= 0) || (tasks.empty() && curThreadCount <= idleThreadCount)) {
74+
break;
75+
}
76+
std::this_thread::yield();
77+
}
78+
}
79+
7380
void ThreadPool2::setState(State state) {
7481
const std::unique_lock lock(controlMutex);
7582
if (this->state == state) {
7683
return;
7784
}
85+
this->state = state;
7886
}
7987

8088
void ThreadPool2::addThread() {
@@ -84,14 +92,17 @@ void ThreadPool2::addThread() {
8492
workers.emplace_back();
8593
workers.back().thread = std::make_unique<std::thread>(&ThreadPool2::threadFunc, this, std::ref(workers.back()));
8694
curThreadCount++;
95+
idleThreadCount++;
8796
}
8897

8998
void ThreadPool2::threadFunc(WorkerThread& workerThread) {
9099
while (true) {
91100
std::cv_status result{std::cv_status::timeout};
92101
{
93102
std::unique_lock lock(taskQueueMutex);
94-
result = taskQueueCondVar.wait_for(lock, std::chrono::milliseconds(250));
103+
if (tasks.empty()) {
104+
result = taskQueueCondVar.wait_for(lock, std::chrono::milliseconds(250));
105+
}
95106
}
96107

97108
if (state == State::STOP) {
@@ -109,6 +120,16 @@ void ThreadPool2::threadFunc(WorkerThread& workerThread) {
109120
}
110121

111122
// Check for tasks and execute one
123+
const std::unique_lock lock(taskQueueMutex);
124+
if (!tasks.empty()) {
125+
idleThreadCount--;
126+
const std::function<void()> task = std::move(tasks.front());
127+
tasks.pop();
128+
129+
// Execute the task
130+
task();
131+
}
132+
idleThreadCount++;
112133
}
113134

114135
workerThread.state = State::STOP;

include/cpr/threadpool2.h

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@
44
#include <condition_variable>
55
#include <cstddef>
66
#include <cstdint>
7+
#include <functional>
8+
#include <future>
79
#include <list>
810
#include <memory>
911
#include <mutex>
12+
#include <queue>
1013
#include <thread>
1114

1215
namespace cpr {
@@ -16,7 +19,7 @@ class ThreadPool2 {
1619
static size_t DEFAULT_MAX_THREAD_COUNT;
1720

1821
private:
19-
enum class State : uint8_t { STOP, RUNNING, PAUSE };
22+
enum class State : uint8_t { STOP, RUNNING };
2023
struct WorkerThread {
2124
std::unique_ptr<std::thread> thread{nullptr};
2225
State state{State::RUNNING};
@@ -28,11 +31,13 @@ class ThreadPool2 {
2831

2932
std::mutex taskQueueMutex;
3033
std::condition_variable taskQueueCondVar;
34+
std::queue<std::function<void()>> tasks;
3135

3236
std::atomic<State> state = State::STOP;
3337
std::atomic_size_t minThreadCount;
3438
std::atomic_size_t curThreadCount{0};
3539
std::atomic_size_t maxThreadCount;
40+
std::atomic_size_t idleThreadCount{0};
3641

3742
std::recursive_mutex controlMutex;
3843

@@ -55,6 +60,36 @@ class ThreadPool2 {
5560

5661
void Start();
5762
void Stop();
63+
void Wait();
64+
65+
/**
66+
* Return a future, calling future.get() will wait task done and return RetType.
67+
* Submit(fn, args...)
68+
* Submit(std::bind(&Class::mem_fn, &obj))
69+
* Submit(std::mem_fn(&Class::mem_fn, &obj))
70+
**/
71+
template <class Fn, class... Args>
72+
auto Submit(Fn&& fn, Args&&... args) {
73+
// Add a new worker thread in case the tasks queue is not empty and we still can add a thread
74+
{
75+
std::unique_lock lock(taskQueueMutex);
76+
if (idleThreadCount < tasks.size() && curThreadCount < maxThreadCount) {
77+
addThread();
78+
}
79+
}
80+
81+
// Add task to queue
82+
using RetType = decltype(fn(args...));
83+
const std::shared_ptr<std::packaged_task<RetType()>> task = std::make_shared<std::packaged_task<RetType()>>([fn = std::forward<Fn>(fn), args...]() mutable { return std::invoke(fn, args...); });
84+
std::future<RetType> future = task->get_future();
85+
{
86+
std::unique_lock lock(taskQueueMutex);
87+
tasks.emplace([task] { (*task)(); });
88+
}
89+
90+
taskQueueCondVar.notify_one();
91+
return future;
92+
}
5893

5994
private:
6095
void setState(State newState);

test/threadpool2_tests.cpp

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,40 @@
44

55
#include "cpr/threadpool2.h"
66

7-
TEST(ThreadPool2Tests, StartStop) {
8-
cpr::ThreadPool2 tp(1, 1);
7+
TEST(ThreadPool2Tests, BasicWorkOneThread) {
8+
std::atomic_uint32_t invCount{0};
9+
uint32_t invCountExpected{100};
10+
11+
{
12+
cpr::ThreadPool2 tp(1, 1);
13+
14+
for (size_t i = 0; i < invCountExpected; ++i) {
15+
tp.Submit([&invCount]() -> void { invCount++; });
16+
}
17+
18+
// Wait for the thread pool to finish its work
19+
tp.Wait();
20+
}
21+
22+
EXPECT_EQ(invCount, invCountExpected);
23+
}
24+
25+
TEST(ThreadPool2Tests, BasicWorkMultipleThreads) {
26+
std::atomic_uint32_t invCount{0};
27+
uint32_t invCountExpected{100};
28+
29+
{
30+
cpr::ThreadPool2 tp(1, 10);
31+
32+
for (size_t i = 0; i < invCountExpected; ++i) {
33+
tp.Submit([&invCount]() -> void { invCount++; });
34+
}
35+
36+
// Wait for the thread pool to finish its work
37+
tp.Wait();
38+
}
39+
40+
EXPECT_EQ(invCount, invCountExpected);
941
}
1042

1143
int main(int argc, char** argv) {

0 commit comments

Comments
 (0)