|
11 | 11 | #ifndef XSS_COMMON_QSORT
|
12 | 12 | #define XSS_COMMON_QSORT
|
13 | 13 |
|
| 14 | +#ifdef XSS_BUILD_WITH_STD_THREADS |
| 15 | +#include "xss-thread-pool.hpp" |
| 16 | +#endif |
| 17 | + |
14 | 18 | /*
|
15 | 19 | * Quicksort using AVX-512. The ideas and code are based on these two research
|
16 | 20 | * papers [1] and [2]. On a high level, the idea is to vectorize quicksort
|
@@ -533,8 +537,61 @@ static void qsort_(type_t *arr,
|
533 | 537 | arrsize_t max_iters,
|
534 | 538 | arrsize_t task_threshold)
|
535 | 539 | {
|
| 540 | + UNUSED(task_threshold); |
536 | 541 | /*
|
537 |
| - * Resort to std::sort if quicksort isnt making any progress |
| 542 | + * Resort to std::sort if quicksort isn't making any progress |
| 543 | + */ |
| 544 | + if (max_iters <= 0) { |
| 545 | + std::sort(arr + left, arr + right + 1, comparator::STDSortComparator); |
| 546 | + return; |
| 547 | + } |
| 548 | + /* |
| 549 | + * Base case: use bitonic networks to sort arrays <= vtype::network_sort_threshold |
| 550 | + */ |
| 551 | + if (right + 1 - left <= vtype::network_sort_threshold) { |
| 552 | + sort_n<vtype, comparator, vtype::network_sort_threshold>( |
| 553 | + arr + left, (int32_t)(right + 1 - left)); |
| 554 | + return; |
| 555 | + } |
| 556 | + |
| 557 | + auto pivot_result |
| 558 | + = get_pivot_smart<vtype, comparator, type_t>(arr, left, right); |
| 559 | + type_t pivot = pivot_result.pivot; |
| 560 | + |
| 561 | + if (pivot_result.result == pivot_result_t::Sorted) { return; } |
| 562 | + |
| 563 | + type_t smallest = vtype::type_max(); |
| 564 | + type_t biggest = vtype::type_min(); |
| 565 | + |
| 566 | + arrsize_t pivot_index = partition_unrolled<vtype, |
| 567 | + comparator, |
| 568 | + vtype::partition_unroll_factor>( |
| 569 | + arr, left, right + 1, pivot, &smallest, &biggest); |
| 570 | + |
| 571 | + if (pivot_result.result == pivot_result_t::Only2Values) { return; } |
| 572 | + |
| 573 | + type_t leftmostValue = comparator::leftmost(smallest, biggest); |
| 574 | + type_t rightmostValue = comparator::rightmost(smallest, biggest); |
| 575 | + |
| 576 | + // Sequential recursion |
| 577 | + if (pivot != leftmostValue) |
| 578 | + qsort_<vtype, comparator>(arr, left, pivot_index - 1, max_iters - 1, 0); |
| 579 | + if (pivot != rightmostValue) |
| 580 | + qsort_<vtype, comparator>(arr, pivot_index, right, max_iters - 1, 0); |
| 581 | +} |
| 582 | + |
| 583 | +// Template function for std::thread-based parallel quicksort implementation |
| 584 | +#ifdef XSS_BUILD_WITH_STD_THREADS |
| 585 | +template <typename vtype, typename comparator, typename type_t> |
| 586 | +static void qsort_threads(type_t *arr, |
| 587 | + arrsize_t left, |
| 588 | + arrsize_t right, |
| 589 | + arrsize_t max_iters, |
| 590 | + arrsize_t task_threshold, |
| 591 | + ThreadPool &thread_pool) |
| 592 | +{ |
| 593 | + /* |
| 594 | + * Resort to std::sort if quicksort isn't making any progress |
538 | 595 | */
|
539 | 596 | if (max_iters <= 0) {
|
540 | 597 | std::sort(arr + left, arr + right + 1, comparator::STDSortComparator);
|
@@ -568,41 +625,65 @@ static void qsort_(type_t *arr,
|
568 | 625 | type_t leftmostValue = comparator::leftmost(smallest, biggest);
|
569 | 626 | type_t rightmostValue = comparator::rightmost(smallest, biggest);
|
570 | 627 |
|
571 |
| -#ifdef XSS_COMPILE_OPENMP |
| 628 | + // Process left partition |
572 | 629 | if (pivot != leftmostValue) {
|
573 | 630 | bool parallel_left = (pivot_index - left) > task_threshold;
|
574 | 631 | if (parallel_left) {
|
575 |
| -#pragma omp task |
576 |
| - qsort_<vtype, comparator>( |
577 |
| - arr, left, pivot_index - 1, max_iters - 1, task_threshold); |
| 632 | + submit_task(thread_pool, |
| 633 | + [arr, |
| 634 | + left, |
| 635 | + pivot_index, |
| 636 | + max_iters, |
| 637 | + task_threshold, |
| 638 | + &thread_pool]() { |
| 639 | + qsort_threads<vtype, comparator>(arr, |
| 640 | + left, |
| 641 | + pivot_index - 1, |
| 642 | + max_iters - 1, |
| 643 | + task_threshold, |
| 644 | + thread_pool); |
| 645 | + }); |
578 | 646 | }
|
579 | 647 | else {
|
580 |
| - qsort_<vtype, comparator>( |
581 |
| - arr, left, pivot_index - 1, max_iters - 1, task_threshold); |
| 648 | + qsort_threads<vtype, comparator>(arr, |
| 649 | + left, |
| 650 | + pivot_index - 1, |
| 651 | + max_iters - 1, |
| 652 | + task_threshold, |
| 653 | + thread_pool); |
582 | 654 | }
|
583 | 655 | }
|
| 656 | + |
| 657 | + // Process right partition |
584 | 658 | if (pivot != rightmostValue) {
|
585 | 659 | bool parallel_right = (right - pivot_index) > task_threshold;
|
586 |
| - |
587 | 660 | if (parallel_right) {
|
588 |
| -#pragma omp task |
589 |
| - qsort_<vtype, comparator>( |
590 |
| - arr, pivot_index, right, max_iters - 1, task_threshold); |
| 661 | + submit_task(thread_pool, |
| 662 | + [arr, |
| 663 | + pivot_index, |
| 664 | + right, |
| 665 | + max_iters, |
| 666 | + task_threshold, |
| 667 | + &thread_pool]() { |
| 668 | + qsort_threads<vtype, comparator>(arr, |
| 669 | + pivot_index, |
| 670 | + right, |
| 671 | + max_iters - 1, |
| 672 | + task_threshold, |
| 673 | + thread_pool); |
| 674 | + }); |
591 | 675 | }
|
592 | 676 | else {
|
593 |
| - qsort_<vtype, comparator>( |
594 |
| - arr, pivot_index, right, max_iters - 1, task_threshold); |
| 677 | + qsort_threads<vtype, comparator>(arr, |
| 678 | + pivot_index, |
| 679 | + right, |
| 680 | + max_iters - 1, |
| 681 | + task_threshold, |
| 682 | + thread_pool); |
595 | 683 | }
|
596 | 684 | }
|
597 |
| -#else |
598 |
| - UNUSED(task_threshold); |
599 |
| - |
600 |
| - if (pivot != leftmostValue) |
601 |
| - qsort_<vtype, comparator>(arr, left, pivot_index - 1, max_iters - 1, 0); |
602 |
| - if (pivot != rightmostValue) |
603 |
| - qsort_<vtype, comparator>(arr, pivot_index, right, max_iters - 1, 0); |
604 |
| -#endif |
605 | 685 | }
|
| 686 | +#endif // XSS_BUILD_WITH_STD_THREADS |
606 | 687 |
|
607 | 688 | template <typename vtype, typename comparator, typename type_t>
|
608 | 689 | X86_SIMD_SORT_INLINE void qselect_(type_t *arr,
|
@@ -667,40 +748,40 @@ X86_SIMD_SORT_INLINE void xss_qsort(T *arr, arrsize_t arrsize, bool hasnan)
|
667 | 748 |
|
668 | 749 | UNUSED(hasnan);
|
669 | 750 |
|
670 |
| -#ifdef XSS_COMPILE_OPENMP |
671 |
| - |
| 751 | +#ifdef XSS_BUILD_WITH_STD_THREADS |
672 | 752 | bool use_parallel = arrsize > 100000;
|
| 753 | +#else |
| 754 | + bool use_parallel = false; |
| 755 | +#endif |
673 | 756 |
|
674 | 757 | if (use_parallel) {
|
675 |
| - // This thread limit was determined experimentally; it may be better for it to be the number of physical cores on the system |
| 758 | +#ifdef XSS_BUILD_WITH_STD_THREADS |
| 759 | + // This thread limit was determined experimentally |
676 | 760 | constexpr int thread_limit = 8;
|
677 |
| - int thread_count = std::min(thread_limit, omp_get_max_threads()); |
| 761 | + int thread_count = std::min( |
| 762 | + thread_limit, (int)std::thread::hardware_concurrency()); |
678 | 763 | arrsize_t task_threshold
|
679 | 764 | = std::max((arrsize_t)100000, arrsize / 100);
|
680 | 765 |
|
681 |
| - // We use omp parallel and then omp single to setup the threads that will run the omp task calls in qsort_ |
682 |
| - // The omp single prevents multiple threads from running the initial qsort_ simultaneously and causing problems |
683 |
| - // Note that we do not use the if(...) clause built into OpenMP, because it causes a performance regression for small arrays |
684 |
| -#pragma omp parallel num_threads(thread_count) |
685 |
| -#pragma omp single |
686 |
| - qsort_<vtype, comparator, T>(arr, |
687 |
| - 0, |
688 |
| - arrsize - 1, |
689 |
| - 2 * (arrsize_t)log2(arrsize), |
690 |
| - task_threshold); |
691 |
| -#pragma omp taskwait |
| 766 | + // Create a thread pool |
| 767 | + ThreadPool pool(thread_count); |
| 768 | + |
| 769 | + // Initial sort task |
| 770 | + qsort_threads<vtype, comparator, T>(arr, |
| 771 | + 0, |
| 772 | + arrsize - 1, |
| 773 | + 2 * (arrsize_t)log2(arrsize), |
| 774 | + task_threshold, |
| 775 | + pool); |
| 776 | + // Wait for all tasks to complete |
| 777 | + pool.wait_all(); |
| 778 | +#endif |
692 | 779 | }
|
693 | 780 | else {
|
694 |
| - qsort_<vtype, comparator, T>(arr, |
695 |
| - 0, |
696 |
| - arrsize - 1, |
697 |
| - 2 * (arrsize_t)log2(arrsize), |
698 |
| - std::numeric_limits<arrsize_t>::max()); |
| 781 | + // For small arrays, just use the sequential version |
| 782 | + qsort_<vtype, comparator, T>( |
| 783 | + arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize), 0); |
699 | 784 | }
|
700 |
| -#else |
701 |
| - qsort_<vtype, comparator, T>( |
702 |
| - arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize), 0); |
703 |
| -#endif |
704 | 785 |
|
705 | 786 | replace_inf_with_nan(arr, arrsize, nan_count, descending);
|
706 | 787 | }
|
|
0 commit comments