Skip to content

Commit d8a9d28

Browse files
committed
Add OpenMP acceleration for normal quicksort
1 parent 59e298d commit d8a9d28

File tree

6 files changed

+85
-13
lines changed

6 files changed

+85
-13
lines changed

src/avx512-16bit-qsort.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,7 @@ avx512_qsort_fp16(uint16_t *arr,
556556
{
557557
using vtype = zmm_vector<float16>;
558558

559+
// TODO multithreading support here
559560
if (arrsize > 1) {
560561
arrsize_t nan_count = 0;
561562
if (UNLIKELY(hasnan)) {
@@ -564,11 +565,11 @@ avx512_qsort_fp16(uint16_t *arr,
564565
}
565566
if (descending) {
566567
qsort_<vtype, Comparator<vtype, true>, uint16_t>(
567-
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
568+
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize), 0);
568569
}
569570
else {
570571
qsort_<vtype, Comparator<vtype, false>, uint16_t>(
571-
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
572+
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize), 0);
572573
}
573574
replace_inf_with_nan(arr, arrsize, nan_count, descending);
574575
}

src/avx512-64bit-common.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -968,7 +968,8 @@ struct zmm_vector<uint64_t> {
968968
static_assert(sizeof(size_t) == sizeof(uint64_t),
969969
"Size of size_t and uint64_t are not the same");
970970
template <>
971-
struct zmm_vector<size_t> : public zmm_vector<uint64_t> {};
971+
struct zmm_vector<size_t> : public zmm_vector<uint64_t> {
972+
};
972973
#endif
973974

974975
template <>

src/xss-common-includes.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,11 @@
8282
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, \
8383
21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31
8484

85+
#if defined(XSS_USE_OPENMP) && defined(_OPENMP)
86+
#define XSS_COMPILE_OPENMP
87+
#include <omp.h>
88+
#endif
89+
8590
template <class... T>
8691
constexpr bool always_false = false;
8792

src/xss-common-keyvaluesort.hpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,6 @@
1111
#include "xss-common-qsort.h"
1212
#include "xss-network-keyvaluesort.hpp"
1313

14-
#if defined(XSS_USE_OPENMP) && defined(_OPENMP)
15-
#define XSS_COMPILE_OPENMP
16-
#include <omp.h>
17-
#endif
18-
1914
/*
2015
* Parition one ZMM register based on the pivot and returns the index of the
2116
* last element that is less than equal to the pivot.

src/xss-common-qsort.h

Lines changed: 71 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -521,8 +521,11 @@ template <typename vtype, int maxN>
521521
void sort_n(typename vtype::type_t *arr, int N);
522522

523523
template <typename vtype, typename comparator, typename type_t>
524-
static void
525-
qsort_(type_t *arr, arrsize_t left, arrsize_t right, arrsize_t max_iters)
524+
static void qsort_(type_t *arr,
525+
arrsize_t left,
526+
arrsize_t right,
527+
arrsize_t max_iters,
528+
arrsize_t task_threshold)
526529
{
527530
/*
528531
* Resort to std::sort if quicksort isnt making any progress
@@ -559,10 +562,40 @@ qsort_(type_t *arr, arrsize_t left, arrsize_t right, arrsize_t max_iters)
559562
type_t leftmostValue = comparator::leftmost(smallest, biggest);
560563
type_t rightmostValue = comparator::rightmost(smallest, biggest);
561564

565+
#ifdef XSS_COMPILE_OPENMP
566+
if (pivot != leftmostValue) {
567+
bool parallel_left = (pivot_index - left) > task_threshold;
568+
if (parallel_left) {
569+
#pragma omp task
570+
qsort_<vtype, comparator>(
571+
arr, left, pivot_index - 1, max_iters - 1, task_threshold);
572+
}
573+
else {
574+
qsort_<vtype, comparator>(
575+
arr, left, pivot_index - 1, max_iters - 1, task_threshold);
576+
}
577+
}
578+
if (pivot != rightmostValue) {
579+
bool parallel_right = (right - pivot_index) > task_threshold;
580+
581+
if (parallel_right) {
582+
#pragma omp task
583+
qsort_<vtype, comparator>(
584+
arr, pivot_index, right, max_iters - 1, task_threshold);
585+
}
586+
else {
587+
qsort_<vtype, comparator>(
588+
arr, pivot_index, right, max_iters - 1, task_threshold);
589+
}
590+
}
591+
#else
592+
UNUSED(task_threshold);
593+
562594
if (pivot != leftmostValue)
563-
qsort_<vtype, comparator>(arr, left, pivot_index - 1, max_iters - 1);
595+
qsort_<vtype, comparator>(arr, left, pivot_index - 1, max_iters - 1, 0);
564596
if (pivot != rightmostValue)
565-
qsort_<vtype, comparator>(arr, pivot_index, right, max_iters - 1);
597+
qsort_<vtype, comparator>(arr, pivot_index, right, max_iters - 1, 0);
598+
#endif
566599
}
567600

568601
template <typename vtype, typename comparator, typename type_t>
@@ -627,8 +660,41 @@ X86_SIMD_SORT_INLINE void xss_qsort(T *arr, arrsize_t arrsize, bool hasnan)
627660
}
628661

629662
UNUSED(hasnan);
663+
664+
#ifdef XSS_COMPILE_OPENMP
665+
666+
bool use_parallel = arrsize > 10000;
667+
668+
if (use_parallel) {
669+
// This thread limit was determined experimentally; it may be better for it to be the number of physical cores on the system
670+
constexpr int thread_limit = 8;
671+
int thread_count = std::min(thread_limit, omp_get_max_threads());
672+
arrsize_t task_threshold
673+
= std::max((arrsize_t)10000, arrsize / 100);
674+
675+
// We use omp parallel and then omp single to setup the threads that will run the omp task calls in qsort_
676+
// The omp single prevents multiple threads from running the initial qsort_ simultaneously and causing problems
677+
// Note that we do not use the if(...) clause built into OpenMP, because it causes a performance regression for small arrays
678+
#pragma omp parallel num_threads(thread_count)
679+
#pragma omp single
680+
qsort_<vtype, comparator, T>(arr,
681+
0,
682+
arrsize - 1,
683+
2 * (arrsize_t)log2(arrsize),
684+
task_threshold);
685+
}
686+
else {
687+
qsort_<vtype, comparator, T>(arr,
688+
0,
689+
arrsize - 1,
690+
2 * (arrsize_t)log2(arrsize),
691+
std::numeric_limits<arrsize_t>::max());
692+
}
693+
#pragma omp taskwait
694+
#else
630695
qsort_<vtype, comparator, T>(
631-
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
696+
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize), 0);
697+
#endif
632698

633699
replace_inf_with_nan(arr, arrsize, nan_count, descending);
634700
}

tests/test-qsort.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ class simdsort : public ::testing::Test {
1111
simdsort()
1212
{
1313
std::iota(arrsize.begin(), arrsize.end(), 1);
14+
arrsize.push_back(10'000);
15+
arrsize.push_back(100'000);
16+
arrsize.push_back(1'000'000);
17+
1418
arrtype = {"random",
1519
"constant",
1620
"sorted",

0 commit comments

Comments
 (0)