Skip to content

Commit 01bae64

Browse files
committed
Fixes needed when rebasing
1 parent e89a50a commit 01bae64

File tree

1 file changed

+59
-18
lines changed

1 file changed

+59
-18
lines changed

src/xss-common-argsort.h

Lines changed: 59 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -482,16 +482,13 @@ X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr,
482482
}
483483
}
484484

485-
template <typename vtype, typename type_t>
485+
template <typename vtype, typename argtype, typename type_t>
486486
X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr,
487487
arrsize_t *arg,
488488
arrsize_t left,
489489
arrsize_t right,
490490
arrsize_t max_iters)
491491
{
492-
using argtype = typename std::conditional<vtype::numlanes == 4,
493-
avx2_vector<arrsize_t>,
494-
zmm_vector<arrsize_t>>::type;
495492
/*
496493
* Resort to std::sort if quicksort isn't making any progress
497494
*/
@@ -503,7 +500,8 @@ X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr,
503500
* Base case: use bitonic networks to sort arrays <= 256
504501
*/
505502
if (right + 1 - left <= 256) {
506-
argsort_n<vtype, 256>(arr, arg + left, (int32_t)(right + 1 - left));
503+
argsort_n<vtype, argtype, 256>(
504+
arr, arg + left, (int32_t)(right + 1 - left));
507505
return;
508506
}
509507
type_t pivot = get_pivot_64bit<vtype>(arr, arg, left, right);
@@ -512,22 +510,21 @@ X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr,
512510
arrsize_t pivot_index = partition_avx512_unrolled<vtype, argtype, 4>(
513511
arr, arg, left, right + 1, pivot, &smallest, &biggest);
514512
if (pivot != smallest)
515-
argsort_64bit_<vtype>(arr, arg, left, pivot_index - 1, max_iters - 1);
513+
argsort_64bit_<vtype, argtype>(
514+
arr, arg, left, pivot_index - 1, max_iters - 1);
516515
if (pivot != biggest)
517-
argsort_64bit_<vtype>(arr, arg, pivot_index, right, max_iters - 1);
516+
argsort_64bit_<vtype, argtype>(
517+
arr, arg, pivot_index, right, max_iters - 1);
518518
}
519519

520-
template <typename vtype, typename type_t>
520+
template <typename vtype, typename argtype, typename type_t>
521521
X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr,
522522
arrsize_t *arg,
523523
arrsize_t pos,
524524
arrsize_t left,
525525
arrsize_t right,
526526
arrsize_t max_iters)
527527
{
528-
using argtype = typename std::conditional<vtype::numlanes == 4,
529-
avx2_vector<arrsize_t>,
530-
zmm_vector<arrsize_t>>::type;
531528
/*
532529
* Resort to std::sort if quicksort isn't making any progress
533530
*/
@@ -539,7 +536,8 @@ X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr,
539536
* Base case: use bitonic networks to sort arrays <= 256
540537
*/
541538
if (right + 1 - left <= 256) {
542-
argsort_n<vtype, 256>(arr, arg + left, (int32_t)(right + 1 - left));
539+
argsort_n<vtype, argtype, 256>(
540+
arr, arg + left, (int32_t)(right + 1 - left));
543541
return;
544542
}
545543
type_t pivot = get_pivot_64bit<vtype>(arr, arg, left, right);
@@ -548,10 +546,10 @@ X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr,
548546
arrsize_t pivot_index = partition_avx512_unrolled<vtype, argtype, 4>(
549547
arr, arg, left, right + 1, pivot, &smallest, &biggest);
550548
if ((pivot != smallest) && (pos < pivot_index))
551-
argselect_64bit_<vtype>(
549+
argselect_64bit_<vtype, argtype>(
552550
arr, arg, pos, left, pivot_index - 1, max_iters - 1);
553551
else if ((pivot != biggest) && (pos >= pivot_index))
554-
argselect_64bit_<vtype>(
552+
argselect_64bit_<vtype, argtype>(
555553
arr, arg, pos, pivot_index, right, max_iters - 1);
556554
}
557555

@@ -560,9 +558,25 @@ template <typename T>
560558
X86_SIMD_SORT_INLINE void
561559
avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false)
562560
{
561+
/* TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */
563562
using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
564563
ymm_vector<T>,
565564
zmm_vector<T>>::type;
565+
566+
/* Workaround for a NumPy build failure on macOS x86_64: implicit instantiation of
567+
* undefined template 'zmm_vector<unsigned long>'*/
568+
#ifdef __APPLE__
569+
using argtype =
570+
typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
571+
ymm_vector<uint32_t>,
572+
zmm_vector<uint64_t>>::type;
573+
#else
574+
using argtype =
575+
typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
576+
ymm_vector<arrsize_t>,
577+
zmm_vector<arrsize_t>>::type;
578+
#endif
579+
566580
if (arrsize > 1) {
567581
if constexpr (std::is_floating_point_v<T>) {
568582
if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
@@ -571,7 +585,7 @@ avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false)
571585
}
572586
}
573587
UNUSED(hasnan);
574-
argsort_64bit_<vectype>(
588+
argsort_64bit_<vectype, argtype>(
575589
arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
576590
}
577591
}
@@ -594,6 +608,13 @@ avx2_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false)
594608
using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
595609
avx2_half_vector<T>,
596610
avx2_vector<T>>::type;
611+
612+
using argtype =
613+
typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
614+
avx2_half_vector<arrsize_t>,
615+
avx2_vector<arrsize_t>>::type;
616+
617+
597618
if (arrsize > 1) {
598619
if constexpr (std::is_floating_point_v<T>) {
599620
if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
@@ -602,7 +623,7 @@ avx2_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false)
602623
}
603624
}
604625
UNUSED(hasnan);
605-
argsort_64bit_<vectype>(
626+
argsort_64bit_<vectype, argtype>(
606627
arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
607628
}
608629
}
@@ -625,10 +646,25 @@ X86_SIMD_SORT_INLINE void avx512_argselect(T *arr,
625646
arrsize_t arrsize,
626647
bool hasnan = false)
627648
{
649+
/* TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */
628650
using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
629651
ymm_vector<T>,
630652
zmm_vector<T>>::type;
631653

654+
/* Workaround for a NumPy build failure on macOS x86_64: implicit instantiation of
655+
* undefined template 'zmm_vector<unsigned long>'*/
656+
#ifdef __APPLE__
657+
using argtype =
658+
typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
659+
ymm_vector<uint32_t>,
660+
zmm_vector<uint64_t>>::type;
661+
#else
662+
using argtype =
663+
typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
664+
ymm_vector<arrsize_t>,
665+
zmm_vector<arrsize_t>>::type;
666+
#endif
667+
632668
if (arrsize > 1) {
633669
if constexpr (std::is_floating_point_v<T>) {
634670
if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
@@ -637,7 +673,7 @@ X86_SIMD_SORT_INLINE void avx512_argselect(T *arr,
637673
}
638674
}
639675
UNUSED(hasnan);
640-
argselect_64bit_<vectype>(
676+
argselect_64bit_<vectype, argtype>(
641677
arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
642678
}
643679
}
@@ -664,6 +700,11 @@ X86_SIMD_SORT_INLINE void avx2_argselect(T *arr,
664700
avx2_half_vector<T>,
665701
avx2_vector<T>>::type;
666702

703+
using argtype =
704+
typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
705+
avx2_half_vector<arrsize_t>,
706+
avx2_vector<arrsize_t>>::type;
707+
667708
if (arrsize > 1) {
668709
if constexpr (std::is_floating_point_v<T>) {
669710
if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
@@ -672,7 +713,7 @@ X86_SIMD_SORT_INLINE void avx2_argselect(T *arr,
672713
}
673714
}
674715
UNUSED(hasnan);
675-
argselect_64bit_<vectype>(
716+
argselect_64bit_<vectype, argtype>(
676717
arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
677718
}
678719
}

0 commit comments

Comments
 (0)