Skip to content

Commit 72eb422

Browse files
author
Raghuveer Devulapalli
committed
Switch to std::threads from openmp
1 parent b12fe4d commit 72eb422

File tree

5 files changed

+66
-63
lines changed

5 files changed

+66
-63
lines changed

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
test:
2-
meson setup -Dbuild_tests=true -Duse_openmp=false --warnlevel 2 --werror --buildtype release builddir
2+
meson setup -Dbuild_tests=true -Duse_openmp=true --warnlevel 2 --werror --buildtype release builddir
33
cd builddir && ninja
44

55
test_openmp:
66
meson setup -Dbuild_tests=true -Duse_openmp=true --warnlevel 2 --werror --buildtype release builddir
77
cd builddir && ninja
88

99
bench:
10-
meson setup -Dbuild_benchmarks=true --warnlevel 2 --werror --buildtype release builddir
10+
meson setup -Dbuild_benchmarks=true -Duse_openmp=true --warnlevel 2 --werror --buildtype release builddir
1111
cd builddir && ninja
1212

1313
debug:

scripts/bench-compare.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ if [ ! -d .bench/google-benchmark ]; then
1111
fi
1212
compare=$(realpath .bench/google-benchmark/tools/compare.py)
1313

14-
meson setup -Dbuild_benchmarks=true -Dbuild_ippbench=true --warnlevel 0 --buildtype release builddir-${branch}
14+
meson setup -Dbuild_benchmarks=true -Duse_openmp=true --warnlevel 0 --buildtype release builddir-${branch}
1515
cd builddir-${branch}
1616
ninja
1717
$compare filters ./benchexe $1 $2 --benchmark_repetitions=$3

src/avx512-16bit-qsort.hpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,6 @@
99

1010
#include "avx512-16bit-common.h"
1111

12-
struct float16 {
13-
uint16_t val;
14-
};
15-
1612
template <>
1713
struct zmm_vector<float16> {
1814
using type_t = uint16_t;
@@ -555,6 +551,7 @@ avx512_qsort_fp16(uint16_t *arr,
555551
bool descending = false)
556552
{
557553
using vtype = zmm_vector<float16>;
554+
struct threadmanager tm;
558555

559556
// TODO multithreading support here
560557
if (arrsize > 1) {
@@ -565,11 +562,11 @@ avx512_qsort_fp16(uint16_t *arr,
565562
}
566563
if (descending) {
567564
qsort_<vtype, Comparator<vtype, true>, uint16_t>(
568-
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize), 0);
565+
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize), tm);
569566
}
570567
else {
571568
qsort_<vtype, Comparator<vtype, false>, uint16_t>(
572-
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize), 0);
569+
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize), tm);
573570
}
574571
replace_inf_with_nan(arr, arrsize, nan_count, descending);
575572
}

src/xss-common-includes.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
#include <immintrin.h>
88
#include <limits>
99
#include <vector>
10+
#include <thread>
11+
#include <mutex>
1012
#include "xss-custom-float.h"
1113

1214
#define X86_SIMD_SORT_INFINITY std::numeric_limits<double>::infinity()
@@ -87,6 +89,10 @@
8789
#include <omp.h>
8890
#endif
8991

92+
struct float16 {
93+
uint16_t val;
94+
};
95+
9096
template <class... T>
9197
constexpr bool always_false = false;
9298

@@ -109,4 +115,26 @@ enum class simd_type : int { AVX2, AVX512 };
109115
template <typename vtype, typename T = typename vtype::type_t>
110116
X86_SIMD_SORT_INLINE bool comparison_func(const T &a, const T &b);
111117

118+
struct threadmanager {
119+
int max_thread_count;
120+
std::mutex mymutex;
121+
int sharedCount;
122+
arrsize_t task_threshold;
123+
124+
threadmanager() {
125+
#ifdef XSS_COMPILE_OPENMP
126+
max_thread_count = 8;
127+
#else
128+
max_thread_count = 0;
129+
#endif
130+
sharedCount = 0;
131+
task_threshold = 100000;
132+
};
133+
void incrementCount(int ii) {
134+
mymutex.lock();
135+
sharedCount += ii;
136+
mymutex.unlock();
137+
}
138+
};
139+
112140
#endif // XSS_COMMON_INCLUDES

src/xss-common-qsort.h

Lines changed: 32 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,7 @@ static void qsort_(type_t *arr,
525525
arrsize_t left,
526526
arrsize_t right,
527527
arrsize_t max_iters,
528-
arrsize_t task_threshold)
528+
struct threadmanager &tm)
529529
{
530530
/*
531531
* Resort to std::sort if quicksort isnt making any progress
@@ -562,40 +562,49 @@ static void qsort_(type_t *arr,
562562
type_t leftmostValue = comparator::leftmost(smallest, biggest);
563563
type_t rightmostValue = comparator::rightmost(smallest, biggest);
564564

565-
#ifdef XSS_COMPILE_OPENMP
565+
std::thread t1, t2;
566+
bool parallel_left = ((pivot_index - left) > tm.task_threshold)
567+
&& (tm.sharedCount < tm.max_thread_count);
566568
if (pivot != leftmostValue) {
567-
bool parallel_left = (pivot_index - left) > task_threshold;
568569
if (parallel_left) {
569-
#pragma omp task
570-
qsort_<vtype, comparator>(
571-
arr, left, pivot_index - 1, max_iters - 1, task_threshold);
570+
tm.incrementCount(1);
571+
t1 = std::thread(qsort_<vtype, comparator, type_t>,
572+
arr,
573+
left,
574+
pivot_index - 1,
575+
max_iters - 1,
576+
std::ref(tm));
572577
}
573578
else {
574579
qsort_<vtype, comparator>(
575-
arr, left, pivot_index - 1, max_iters - 1, task_threshold);
580+
arr, left, pivot_index - 1, max_iters - 1, tm);
576581
}
577582
}
583+
bool parallel_right = ((right - pivot_index) > tm.task_threshold)
584+
&& (tm.sharedCount < tm.max_thread_count);
578585
if (pivot != rightmostValue) {
579-
bool parallel_right = (right - pivot_index) > task_threshold;
580-
581586
if (parallel_right) {
582-
#pragma omp task
583-
qsort_<vtype, comparator>(
584-
arr, pivot_index, right, max_iters - 1, task_threshold);
587+
tm.incrementCount(1);
588+
t2 = std::thread(qsort_<vtype, comparator, type_t>,
589+
arr,
590+
pivot_index,
591+
right,
592+
max_iters - 1,
593+
std::ref(tm));
585594
}
586595
else {
587596
qsort_<vtype, comparator>(
588-
arr, pivot_index, right, max_iters - 1, task_threshold);
597+
arr, pivot_index, right, max_iters - 1, tm);
589598
}
590599
}
591-
#else
592-
UNUSED(task_threshold);
593-
594-
if (pivot != leftmostValue)
595-
qsort_<vtype, comparator>(arr, left, pivot_index - 1, max_iters - 1, 0);
596-
if (pivot != rightmostValue)
597-
qsort_<vtype, comparator>(arr, pivot_index, right, max_iters - 1, 0);
598-
#endif
600+
if (t1.joinable()) {
601+
t1.join();
602+
tm.incrementCount(-1);
603+
}
604+
if (t2.joinable()) {
605+
t2.join();
606+
tm.incrementCount(-1);
607+
}
599608
}
600609

601610
template <typename vtype, typename comparator, typename type_t>
@@ -661,40 +670,9 @@ X86_SIMD_SORT_INLINE void xss_qsort(T *arr, arrsize_t arrsize, bool hasnan)
661670

662671
UNUSED(hasnan);
663672

664-
#ifdef XSS_COMPILE_OPENMP
665-
666-
bool use_parallel = arrsize > 100000;
667-
668-
if (use_parallel) {
669-
// This thread limit was determined experimentally; it may be better for it to be the number of physical cores on the system
670-
constexpr int thread_limit = 8;
671-
int thread_count = std::min(thread_limit, omp_get_max_threads());
672-
arrsize_t task_threshold
673-
= std::max((arrsize_t)100000, arrsize / 100);
674-
675-
// We use omp parallel and then omp single to setup the threads that will run the omp task calls in qsort_
676-
// The omp single prevents multiple threads from running the initial qsort_ simultaneously and causing problems
677-
// Note that we do not use the if(...) clause built into OpenMP, because it causes a performance regression for small arrays
678-
#pragma omp parallel num_threads(thread_count)
679-
#pragma omp single
680-
qsort_<vtype, comparator, T>(arr,
681-
0,
682-
arrsize - 1,
683-
2 * (arrsize_t)log2(arrsize),
684-
task_threshold);
685-
}
686-
else {
687-
qsort_<vtype, comparator, T>(arr,
688-
0,
689-
arrsize - 1,
690-
2 * (arrsize_t)log2(arrsize),
691-
std::numeric_limits<arrsize_t>::max());
692-
}
693-
#pragma omp taskwait
694-
#else
673+
struct threadmanager tm;
695674
qsort_<vtype, comparator, T>(
696-
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize), 0);
697-
#endif
675+
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize), tm);
698676

699677
replace_inf_with_nan(arr, arrsize, nan_count, descending);
700678
}

0 commit comments

Comments
 (0)