Skip to content

Commit a3d22a5

Browse files
author
Raghuveer Devulapalli
committed
Replace openmp with std::threads: almost entirely written by copilot AI
1 parent 2c39de4 commit a3d22a5

File tree

3 files changed

+276
-63
lines changed

3 files changed

+276
-63
lines changed

src/avx512-16bit-qsort.hpp

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -548,25 +548,35 @@ avx512_qsort_fp16_helper(uint16_t *arr, arrsize_t arrsize)
548548
using T = uint16_t;
549549
using vtype = zmm_vector<float16>;
550550

551-
#ifdef XSS_COMPILE_OPENMP
551+
#ifdef XSS_BUILD_WITH_STD_THREADS
552552
bool use_parallel = arrsize > 100000;
553+
#else
554+
bool use_parallel = false;
555+
#endif
553556

554557
if (use_parallel) {
555-
// This thread limit was determined experimentally; it may be better for it to be the number of physical cores on the system
558+
#ifdef XSS_BUILD_WITH_STD_THREADS
559+
560+
// This thread limit was determined experimentally
556561
constexpr int thread_limit = 8;
557-
int thread_count = std::min(thread_limit, omp_get_max_threads());
562+
int thread_count = std::min(thread_limit,
563+
(int)std::thread::hardware_concurrency());
558564
arrsize_t task_threshold = std::max((arrsize_t)100000, arrsize / 100);
559565

560-
// We use omp parallel and then omp single to setup the threads that will run the omp task calls in qsort_
561-
// The omp single prevents multiple threads from running the initial qsort_ simultaneously and causing problems
562-
// Note that we do not use the if(...) clause built into OpenMP, because it causes a performance regression for small arrays
563-
#pragma omp parallel num_threads(thread_count)
564-
#pragma omp single
565-
qsort_<vtype, comparator, T>(arr,
566-
0,
567-
arrsize - 1,
568-
2 * (arrsize_t)log2(arrsize),
569-
task_threshold);
566+
// Create a thread pool
567+
ThreadPool pool(thread_count);
568+
569+
// Initial sort task
570+
qsort_threads<vtype, comparator, T>(arr,
571+
0,
572+
arrsize - 1,
573+
2 * (arrsize_t)log2(arrsize),
574+
task_threshold,
575+
pool);
576+
577+
// Wait for all tasks to complete
578+
pool.wait_all();
579+
#endif
570580
}
571581
else {
572582
qsort_<vtype, comparator, T>(arr,
@@ -575,11 +585,6 @@ avx512_qsort_fp16_helper(uint16_t *arr, arrsize_t arrsize)
575585
2 * (arrsize_t)log2(arrsize),
576586
std::numeric_limits<arrsize_t>::max());
577587
}
578-
#pragma omp taskwait
579-
#else
580-
qsort_<vtype, comparator, T>(
581-
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize), 0);
582-
#endif
583588
}
584589

585590
[[maybe_unused]] X86_SIMD_SORT_INLINE void

src/xss-common-qsort.h

Lines changed: 126 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
#ifndef XSS_COMMON_QSORT
1212
#define XSS_COMMON_QSORT
1313

14+
#ifdef XSS_BUILD_WITH_STD_THREADS
15+
#include "xss-thread-pool.hpp"
16+
#endif
17+
1418
/*
1519
* Quicksort using AVX-512. The ideas and code are based on these two research
1620
* papers [1] and [2]. On a high level, the idea is to vectorize quicksort
@@ -533,8 +537,61 @@ static void qsort_(type_t *arr,
533537
arrsize_t max_iters,
534538
arrsize_t task_threshold)
535539
{
540+
UNUSED(task_threshold);
536541
/*
537-
* Resort to std::sort if quicksort isnt making any progress
542+
* Resort to std::sort if quicksort isn't making any progress
543+
*/
544+
if (max_iters <= 0) {
545+
std::sort(arr + left, arr + right + 1, comparator::STDSortComparator);
546+
return;
547+
}
548+
/*
549+
* Base case: use bitonic networks to sort arrays <= vtype::network_sort_threshold
550+
*/
551+
if (right + 1 - left <= vtype::network_sort_threshold) {
552+
sort_n<vtype, comparator, vtype::network_sort_threshold>(
553+
arr + left, (int32_t)(right + 1 - left));
554+
return;
555+
}
556+
557+
auto pivot_result
558+
= get_pivot_smart<vtype, comparator, type_t>(arr, left, right);
559+
type_t pivot = pivot_result.pivot;
560+
561+
if (pivot_result.result == pivot_result_t::Sorted) { return; }
562+
563+
type_t smallest = vtype::type_max();
564+
type_t biggest = vtype::type_min();
565+
566+
arrsize_t pivot_index = partition_unrolled<vtype,
567+
comparator,
568+
vtype::partition_unroll_factor>(
569+
arr, left, right + 1, pivot, &smallest, &biggest);
570+
571+
if (pivot_result.result == pivot_result_t::Only2Values) { return; }
572+
573+
type_t leftmostValue = comparator::leftmost(smallest, biggest);
574+
type_t rightmostValue = comparator::rightmost(smallest, biggest);
575+
576+
// Sequential recursion
577+
if (pivot != leftmostValue)
578+
qsort_<vtype, comparator>(arr, left, pivot_index - 1, max_iters - 1, 0);
579+
if (pivot != rightmostValue)
580+
qsort_<vtype, comparator>(arr, pivot_index, right, max_iters - 1, 0);
581+
}
582+
583+
// Template function for std::thread-based parallel quicksort implementation
584+
#ifdef XSS_BUILD_WITH_STD_THREADS
585+
template <typename vtype, typename comparator, typename type_t>
586+
static void qsort_threads(type_t *arr,
587+
arrsize_t left,
588+
arrsize_t right,
589+
arrsize_t max_iters,
590+
arrsize_t task_threshold,
591+
ThreadPool &thread_pool)
592+
{
593+
/*
594+
* Resort to std::sort if quicksort isn't making any progress
538595
*/
539596
if (max_iters <= 0) {
540597
std::sort(arr + left, arr + right + 1, comparator::STDSortComparator);
@@ -568,41 +625,65 @@ static void qsort_(type_t *arr,
568625
type_t leftmostValue = comparator::leftmost(smallest, biggest);
569626
type_t rightmostValue = comparator::rightmost(smallest, biggest);
570627

571-
#ifdef XSS_COMPILE_OPENMP
628+
// Process left partition
572629
if (pivot != leftmostValue) {
573630
bool parallel_left = (pivot_index - left) > task_threshold;
574631
if (parallel_left) {
575-
#pragma omp task
576-
qsort_<vtype, comparator>(
577-
arr, left, pivot_index - 1, max_iters - 1, task_threshold);
632+
submit_task(thread_pool,
633+
[arr,
634+
left,
635+
pivot_index,
636+
max_iters,
637+
task_threshold,
638+
&thread_pool]() {
639+
qsort_threads<vtype, comparator>(arr,
640+
left,
641+
pivot_index - 1,
642+
max_iters - 1,
643+
task_threshold,
644+
thread_pool);
645+
});
578646
}
579647
else {
580-
qsort_<vtype, comparator>(
581-
arr, left, pivot_index - 1, max_iters - 1, task_threshold);
648+
qsort_threads<vtype, comparator>(arr,
649+
left,
650+
pivot_index - 1,
651+
max_iters - 1,
652+
task_threshold,
653+
thread_pool);
582654
}
583655
}
656+
657+
// Process right partition
584658
if (pivot != rightmostValue) {
585659
bool parallel_right = (right - pivot_index) > task_threshold;
586-
587660
if (parallel_right) {
588-
#pragma omp task
589-
qsort_<vtype, comparator>(
590-
arr, pivot_index, right, max_iters - 1, task_threshold);
661+
submit_task(thread_pool,
662+
[arr,
663+
pivot_index,
664+
right,
665+
max_iters,
666+
task_threshold,
667+
&thread_pool]() {
668+
qsort_threads<vtype, comparator>(arr,
669+
pivot_index,
670+
right,
671+
max_iters - 1,
672+
task_threshold,
673+
thread_pool);
674+
});
591675
}
592676
else {
593-
qsort_<vtype, comparator>(
594-
arr, pivot_index, right, max_iters - 1, task_threshold);
677+
qsort_threads<vtype, comparator>(arr,
678+
pivot_index,
679+
right,
680+
max_iters - 1,
681+
task_threshold,
682+
thread_pool);
595683
}
596684
}
597-
#else
598-
UNUSED(task_threshold);
599-
600-
if (pivot != leftmostValue)
601-
qsort_<vtype, comparator>(arr, left, pivot_index - 1, max_iters - 1, 0);
602-
if (pivot != rightmostValue)
603-
qsort_<vtype, comparator>(arr, pivot_index, right, max_iters - 1, 0);
604-
#endif
605685
}
686+
#endif // XSS_BUILD_WITH_STD_THREADS
606687

607688
template <typename vtype, typename comparator, typename type_t>
608689
X86_SIMD_SORT_INLINE void qselect_(type_t *arr,
@@ -667,40 +748,40 @@ X86_SIMD_SORT_INLINE void xss_qsort(T *arr, arrsize_t arrsize, bool hasnan)
667748

668749
UNUSED(hasnan);
669750

670-
#ifdef XSS_COMPILE_OPENMP
671-
751+
#ifdef XSS_BUILD_WITH_STD_THREADS
672752
bool use_parallel = arrsize > 100000;
753+
#else
754+
bool use_parallel = false;
755+
#endif
673756

674757
if (use_parallel) {
675-
// This thread limit was determined experimentally; it may be better for it to be the number of physical cores on the system
758+
#ifdef XSS_BUILD_WITH_STD_THREADS
759+
// This thread limit was determined experimentally
676760
constexpr int thread_limit = 8;
677-
int thread_count = std::min(thread_limit, omp_get_max_threads());
761+
int thread_count = std::min(
762+
thread_limit, (int)std::thread::hardware_concurrency());
678763
arrsize_t task_threshold
679764
= std::max((arrsize_t)100000, arrsize / 100);
680765

681-
// We use omp parallel and then omp single to setup the threads that will run the omp task calls in qsort_
682-
// The omp single prevents multiple threads from running the initial qsort_ simultaneously and causing problems
683-
// Note that we do not use the if(...) clause built into OpenMP, because it causes a performance regression for small arrays
684-
#pragma omp parallel num_threads(thread_count)
685-
#pragma omp single
686-
qsort_<vtype, comparator, T>(arr,
687-
0,
688-
arrsize - 1,
689-
2 * (arrsize_t)log2(arrsize),
690-
task_threshold);
691-
#pragma omp taskwait
766+
// Create a thread pool
767+
ThreadPool pool(thread_count);
768+
769+
// Initial sort task
770+
qsort_threads<vtype, comparator, T>(arr,
771+
0,
772+
arrsize - 1,
773+
2 * (arrsize_t)log2(arrsize),
774+
task_threshold,
775+
pool);
776+
// Wait for all tasks to complete
777+
pool.wait_all();
778+
#endif
692779
}
693780
else {
694-
qsort_<vtype, comparator, T>(arr,
695-
0,
696-
arrsize - 1,
697-
2 * (arrsize_t)log2(arrsize),
698-
std::numeric_limits<arrsize_t>::max());
781+
// For small arrays, just use the sequential version
782+
qsort_<vtype, comparator, T>(
783+
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize), 0);
699784
}
700-
#else
701-
qsort_<vtype, comparator, T>(
702-
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize), 0);
703-
#endif
704785

705786
replace_inf_with_nan(arr, arrsize, nan_count, descending);
706787
}

0 commit comments

Comments
 (0)