Skip to content

Commit 6c46d71

Browse files
author
Raghuveer Devulapalli
committed
Move threadpool under xss::tp:: namespace
1 parent 07b5a6e commit 6c46d71

File tree

3 files changed

+135
-116
lines changed

3 files changed

+135
-116
lines changed

src/avx512-16bit-qsort.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -564,7 +564,7 @@ avx512_qsort_fp16_helper(uint16_t *arr, arrsize_t arrsize)
564564
arrsize_t task_threshold = std::max((arrsize_t)100000, arrsize / 100);
565565

566566
// Create a thread pool
567-
ThreadPool pool(thread_count);
567+
xss::tp::ThreadPool pool(thread_count);
568568

569569
// Initial sort task
570570
qsort_threads<vtype, comparator, T>(arr,

src/xss-common-qsort.h

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -588,7 +588,7 @@ static void qsort_threads(type_t *arr,
588588
arrsize_t right,
589589
arrsize_t max_iters,
590590
arrsize_t task_threshold,
591-
ThreadPool &thread_pool)
591+
xss::tp::ThreadPool &thread_pool)
592592
{
593593
/*
594594
* Resort to std::sort if quicksort isn't making any progress
@@ -629,20 +629,21 @@ static void qsort_threads(type_t *arr,
629629
if (pivot != leftmostValue) {
630630
bool parallel_left = (pivot_index - left) > task_threshold;
631631
if (parallel_left) {
632-
submit_task(thread_pool,
633-
[arr,
634-
left,
635-
pivot_index,
636-
max_iters,
637-
task_threshold,
638-
&thread_pool]() {
639-
qsort_threads<vtype, comparator>(arr,
640-
left,
641-
pivot_index - 1,
642-
max_iters - 1,
643-
task_threshold,
644-
thread_pool);
645-
});
632+
xss::tp::submit_task(thread_pool,
633+
[arr,
634+
left,
635+
pivot_index,
636+
max_iters,
637+
task_threshold,
638+
&thread_pool]() {
639+
qsort_threads<vtype, comparator>(
640+
arr,
641+
left,
642+
pivot_index - 1,
643+
max_iters - 1,
644+
task_threshold,
645+
thread_pool);
646+
});
646647
}
647648
else {
648649
qsort_threads<vtype, comparator>(arr,
@@ -658,20 +659,21 @@ static void qsort_threads(type_t *arr,
658659
if (pivot != rightmostValue) {
659660
bool parallel_right = (right - pivot_index) > task_threshold;
660661
if (parallel_right) {
661-
submit_task(thread_pool,
662-
[arr,
663-
pivot_index,
664-
right,
665-
max_iters,
666-
task_threshold,
667-
&thread_pool]() {
668-
qsort_threads<vtype, comparator>(arr,
669-
pivot_index,
670-
right,
671-
max_iters - 1,
672-
task_threshold,
673-
thread_pool);
674-
});
662+
xss::tp::submit_task(thread_pool,
663+
[arr,
664+
pivot_index,
665+
right,
666+
max_iters,
667+
task_threshold,
668+
&thread_pool]() {
669+
qsort_threads<vtype, comparator>(
670+
arr,
671+
pivot_index,
672+
right,
673+
max_iters - 1,
674+
task_threshold,
675+
thread_pool);
676+
});
675677
}
676678
else {
677679
qsort_threads<vtype, comparator>(arr,
@@ -764,7 +766,7 @@ X86_SIMD_SORT_INLINE void xss_qsort(T *arr, arrsize_t arrsize, bool hasnan)
764766
= std::max((arrsize_t)100000, arrsize / 100);
765767

766768
// Create a thread pool
767-
ThreadPool pool(thread_count);
769+
xss::tp::ThreadPool pool(thread_count);
768770

769771
// Initial sort task
770772
qsort_threads<vtype, comparator, T>(arr,

src/xss-thread-pool.hpp

Lines changed: 102 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@
1616
#include <vector>
1717
#include <atomic>
1818

19-
/*
19+
namespace xss {
20+
namespace tp {
21+
22+
/*
2023
* ThreadPool class and doc: Generated by copilot
2124
* This thread pool implementation is a simple and efficient way to manage a
2225
* pool of threads for executing tasks concurrently. It uses a std::queue to store
@@ -25,102 +28,116 @@
2528
* execute it. The thread pool can be stopped gracefully, and it also provides
2629
* a way to wait for all tasks to complete before stopping.
2730
* */
28-
class ThreadPool {
29-
private:
30-
std::vector<std::thread> workers;
31-
std::queue<std::function<void()>> tasks;
32-
std::mutex queue_mutex;
33-
std::condition_variable condition; // Condition variable for task queue
34-
std::condition_variable done_condition; // Condition variable for waiting
35-
int active_tasks {0};
36-
bool stop;
31+
class ThreadPool {
32+
private:
33+
std::vector<std::thread> workers;
34+
std::queue<std::function<void()>> tasks;
35+
std::mutex queue_mutex;
36+
std::condition_variable condition; // Condition variable for task queue
37+
std::condition_variable
38+
done_condition; // Condition variable for waiting
39+
int active_tasks {0};
40+
bool stop;
3741

38-
public:
39-
ThreadPool(size_t num_threads) : stop(false)
40-
{
41-
for (size_t i = 0; i < num_threads; ++i) {
42-
// Create a worker thread and add it to the pool
43-
// Each thread will run a lambda function that waits for tasks
44-
workers.emplace_back([this] {
45-
while (true) {
46-
// Lock the queue mutex and wait for a task to be available
47-
std::unique_lock<std::mutex> lock(queue_mutex);
48-
// Wait until there is a task or the pool is stopped
49-
condition.wait(lock,
50-
[this] { return stop || !tasks.empty(); });
42+
public:
43+
ThreadPool(size_t num_threads) : stop(false)
44+
{
45+
for (size_t i = 0; i < num_threads; ++i) {
46+
// Create a worker thread and add it to the pool
47+
// Each thread will run a lambda function that waits for tasks
48+
workers.emplace_back([this] {
49+
while (true) {
50+
// Lock the queue mutex and wait for a task to be available
51+
std::unique_lock<std::mutex> lock(queue_mutex);
52+
// Wait until there is a task or the pool is stopped
53+
condition.wait(lock, [this] {
54+
return stop || !tasks.empty();
55+
});
5156

52-
// Check if we need to terminate the thread
53-
if (stop && tasks.empty()) { return; }
57+
// Check if we need to terminate the thread
58+
if (stop && tasks.empty()) { return; }
5459

55-
// Extract the next task from the queue
56-
auto task = std::move(tasks.front());
57-
tasks.pop();
58-
// Unlock the mutex before executing the task
59-
lock.unlock();
60-
// Execute the task:
61-
task();
62-
}
63-
});
60+
// Extract the next task from the queue
61+
auto task = std::move(tasks.front());
62+
tasks.pop();
63+
// Unlock the mutex before executing the task
64+
lock.unlock();
65+
// Execute the task:
66+
task();
67+
}
68+
});
69+
}
6470
}
65-
}
6671

67-
template <class F>
68-
void enqueue(F &&func)
69-
{
70-
// Add a new task to the queue and notify one of the worker threads
71-
std::unique_lock<std::mutex> lock(queue_mutex);
72-
tasks.emplace(std::forward<F>(func));
73-
condition.notify_one();
74-
}
72+
template <class F>
73+
void enqueue(F &&func)
74+
{
75+
// Add a new task to the queue and notify one of the worker threads
76+
std::unique_lock<std::mutex> lock(queue_mutex);
77+
tasks.emplace(std::forward<F>(func));
78+
condition.notify_one();
79+
}
7580

76-
~ThreadPool()
77-
{
78-
// Stop the thread pool and join all threads
79-
std::unique_lock<std::mutex> lock(queue_mutex);
80-
stop = true;
81-
lock.unlock();
82-
condition.notify_all();
83-
for (std::thread &worker : workers) {
84-
worker.join();
81+
~ThreadPool()
82+
{
83+
// Stop the thread pool and join all threads
84+
std::unique_lock<std::mutex> lock(queue_mutex);
85+
stop = true;
86+
lock.unlock();
87+
condition.notify_all();
88+
for (std::thread &worker : workers) {
89+
worker.join();
90+
}
8591
}
86-
}
8792

88-
// Wait for all tasks to complete before stopping the pool
89-
void wait_all()
90-
{
91-
std::unique_lock<std::mutex> lock(queue_mutex);
92-
done_condition.wait(
93-
lock, [this] { return tasks.empty() && (active_tasks == 0); });
94-
// lock is automatically released here
95-
}
93+
// Wait for all tasks to complete before stopping the pool
94+
void wait_all()
95+
{
96+
std::unique_lock<std::mutex> lock(queue_mutex);
97+
done_condition.wait(lock, [this] {
98+
return tasks.empty() && (active_tasks == 0);
99+
});
100+
// lock is automatically released here
101+
}
96102

97-
// Track the number of active tasks
98-
void task_start()
99-
{
100-
std::unique_lock<std::mutex> lock(queue_mutex);
101-
active_tasks++;
102-
// lock is automatically released here
103-
}
103+
// Track the number of active tasks
104+
void task_start()
105+
{
106+
std::unique_lock<std::mutex> lock(queue_mutex);
107+
active_tasks++;
108+
// lock is automatically released here
109+
}
110+
111+
// Decrement the active task count and notify if all tasks are done
112+
void task_end()
113+
{
114+
std::unique_lock<std::mutex> lock(queue_mutex);
115+
active_tasks--;
116+
if (tasks.empty() && active_tasks == 0) {
117+
done_condition.notify_all();
118+
}
119+
// lock is automatically released here
120+
}
121+
};
104122

105-
// Decrement the active task count and notify if all tasks are done
106-
void task_end()
123+
// Wrapper for submitting tasks to the thread pool with automatic tracking
124+
template <typename F>
125+
void submit_task(ThreadPool &pool, F &&f)
107126
{
108-
std::unique_lock<std::mutex> lock(queue_mutex);
109-
active_tasks--;
110-
if (tasks.empty() && active_tasks == 0) { done_condition.notify_all(); }
111-
// lock is automatically released here
127+
pool.task_start();
128+
pool.enqueue([f = std::forward<F>(f), &pool]() {
129+
try {
130+
f();
131+
} catch (...) {
132+
// Ensure task_end is called even if the task throws an exception
133+
pool.task_end();
134+
throw; // Re-throw the exception
135+
}
136+
pool.task_end();
137+
});
112138
}
113-
};
114139

115-
// Wrapper for submitting tasks to the thread pool with automatic tracking
116-
template <typename F>
117-
void submit_task(ThreadPool &pool, F &&f)
118-
{
119-
pool.task_start();
120-
pool.enqueue([f = std::forward<F>(f), &pool]() {
121-
f();
122-
pool.task_end();
123-
});
124-
}
140+
} // namespace tp
141+
} // namespace xss
125142

126143
#endif // XSS_THREAD_POOL

0 commit comments

Comments
 (0)