Skip to content

Commit b4a64a9

Browse files
committed
Cleanup to argsort logic
1 parent 369f639 commit b4a64a9

File tree

1 file changed

+57
-72
lines changed

1 file changed

+57
-72
lines changed

src/xss-common-argsort.h

Lines changed: 57 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr,
468468
}
469469

470470
template <typename vtype, typename argtype, typename type_t>
471-
X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr,
471+
X86_SIMD_SORT_INLINE void argsort_(type_t *arr,
472472
arrsize_t *arg,
473473
arrsize_t left,
474474
arrsize_t right,
@@ -495,15 +495,15 @@ X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr,
495495
arrsize_t pivot_index = argpartition_unrolled<vtype, argtype, 4>(
496496
arr, arg, left, right + 1, pivot, &smallest, &biggest);
497497
if (pivot != smallest)
498-
argsort_64bit_<vtype, argtype>(
498+
argsort_<vtype, argtype>(
499499
arr, arg, left, pivot_index - 1, max_iters - 1);
500500
if (pivot != biggest)
501-
argsort_64bit_<vtype, argtype>(
501+
argsort_<vtype, argtype>(
502502
arr, arg, pivot_index, right, max_iters - 1);
503503
}
504504

505505
template <typename vtype, typename argtype, typename type_t>
506-
X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr,
506+
X86_SIMD_SORT_INLINE void argselect_(type_t *arr,
507507
arrsize_t *arg,
508508
arrsize_t pos,
509509
arrsize_t left,
@@ -531,30 +531,34 @@ X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr,
531531
arrsize_t pivot_index = argpartition_unrolled<vtype, argtype, 4>(
532532
arr, arg, left, right + 1, pivot, &smallest, &biggest);
533533
if ((pivot != smallest) && (pos < pivot_index))
534-
argselect_64bit_<vtype, argtype>(
534+
argselect_<vtype, argtype>(
535535
arr, arg, pos, left, pivot_index - 1, max_iters - 1);
536536
else if ((pivot != biggest) && (pos >= pivot_index))
537-
argselect_64bit_<vtype, argtype>(
537+
argselect_<vtype, argtype>(
538538
arr, arg, pos, pivot_index, right, max_iters - 1);
539539
}
540540

541541
/* argsort methods for 32-bit and 64-bit dtypes */
542-
template <typename T>
543-
X86_SIMD_SORT_INLINE void avx512_argsort(T *arr,
542+
template <typename T,
543+
template <typename...>
544+
typename full_vector,
545+
template <typename...>
546+
typename half_vector>
547+
X86_SIMD_SORT_INLINE void xss_argsort(T *arr,
544548
arrsize_t *arg,
545549
arrsize_t arrsize,
546550
bool hasnan = false,
547551
bool descending = false)
548552
{
549-
/* TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */
553+
/* TODO optimization: on 32-bit builds (presumably where arrsize_t is 32 bits wide -- confirm), use full_vector instead of half_vector for 32-bit dtypes */
550554
using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
551-
ymm_vector<T>,
552-
zmm_vector<T>>::type;
555+
half_vector<T>,
556+
full_vector<T>>::type;
553557

554558
using argtype =
555559
typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
556-
ymm_vector<arrsize_t>,
557-
zmm_vector<arrsize_t>>::type;
560+
half_vector<arrsize_t>,
561+
full_vector<arrsize_t>>::type;
558562

559563
if (arrsize > 1) {
560564
if constexpr (xss::fp::is_floating_point_v<T>) {
@@ -567,64 +571,54 @@ X86_SIMD_SORT_INLINE void avx512_argsort(T *arr,
567571
}
568572
}
569573
UNUSED(hasnan);
570-
argsort_64bit_<vectype, argtype>(
574+
argsort_<vectype, argtype>(
571575
arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
572576

573577
if (descending) { std::reverse(arg, arg + arrsize); }
574578
}
575579
}
576580

577-
/* argsort methods for 32-bit and 64-bit dtypes */
578581
template <typename T>
579-
X86_SIMD_SORT_INLINE void avx2_argsort(T *arr,
580-
arrsize_t *arg,
581-
arrsize_t arrsize,
582-
bool hasnan = false,
583-
bool descending = false)
582+
X86_SIMD_SORT_INLINE void avx512_argsort(T *arr,
583+
arrsize_t *arg,
584+
arrsize_t arrsize,
585+
bool hasnan = false,
586+
bool descending = false)
584587
{
585-
using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
586-
avx2_half_vector<T>,
587-
avx2_vector<T>>::type;
588-
589-
using argtype =
590-
typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
591-
avx2_half_vector<arrsize_t>,
592-
avx2_vector<arrsize_t>>::type;
593-
if (arrsize > 1) {
594-
if constexpr (xss::fp::is_floating_point_v<T>) {
595-
if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
596-
std_argsort_withnan(arr, arg, 0, arrsize);
597-
598-
if (descending) { std::reverse(arg, arg + arrsize); }
599-
600-
return;
601-
}
602-
}
603-
UNUSED(hasnan);
604-
argsort_64bit_<vectype, argtype>(
605-
arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
588+
xss_argsort<T, zmm_vector, ymm_vector>(arr, arg, arrsize, hasnan, descending);
589+
}
606590

607-
if (descending) { std::reverse(arg, arg + arrsize); }
608-
}
591+
template <typename T>
592+
X86_SIMD_SORT_INLINE void avx2_argsort(T *arr,
593+
arrsize_t *arg,
594+
arrsize_t arrsize,
595+
bool hasnan = false,
596+
bool descending = false)
597+
{
598+
xss_argsort<T, avx2_vector, avx2_half_vector>(arr, arg, arrsize, hasnan, descending);
609599
}
610600

611601
/* argselect methods for 32-bit and 64-bit dtypes */
612-
template <typename T>
613-
X86_SIMD_SORT_INLINE void avx512_argselect(T *arr,
602+
template <typename T,
603+
template <typename...>
604+
typename full_vector,
605+
template <typename...>
606+
typename half_vector>
607+
X86_SIMD_SORT_INLINE void xss_argselect(T *arr,
614608
arrsize_t *arg,
615609
arrsize_t k,
616610
arrsize_t arrsize,
617611
bool hasnan = false)
618612
{
619-
/* TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */
613+
/* TODO optimization: on 32-bit builds (presumably where arrsize_t is 32 bits wide -- confirm), use full_vector instead of half_vector for 32-bit dtypes */
620614
using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
621-
ymm_vector<T>,
622-
zmm_vector<T>>::type;
615+
half_vector<T>,
616+
full_vector<T>>::type;
623617

624618
using argtype =
625619
typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
626-
ymm_vector<arrsize_t>,
627-
zmm_vector<arrsize_t>>::type;
620+
half_vector<arrsize_t>,
621+
full_vector<arrsize_t>>::type;
628622

629623
if (arrsize > 1) {
630624
if constexpr (xss::fp::is_floating_point_v<T>) {
@@ -634,38 +628,29 @@ X86_SIMD_SORT_INLINE void avx512_argselect(T *arr,
634628
}
635629
}
636630
UNUSED(hasnan);
637-
argselect_64bit_<vectype, argtype>(
631+
argselect_<vectype, argtype>(
638632
arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
639633
}
640634
}
641635

642-
/* argselect methods for 32-bit and 64-bit dtypes */
643636
template <typename T>
644-
X86_SIMD_SORT_INLINE void avx2_argselect(T *arr,
637+
X86_SIMD_SORT_INLINE void avx512_argselect(T *arr,
645638
arrsize_t *arg,
646639
arrsize_t k,
647640
arrsize_t arrsize,
648641
bool hasnan = false)
649642
{
650-
using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
651-
avx2_half_vector<T>,
652-
avx2_vector<T>>::type;
653-
654-
using argtype =
655-
typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
656-
avx2_half_vector<arrsize_t>,
657-
avx2_vector<arrsize_t>>::type;
643+
xss_argselect<T, zmm_vector, ymm_vector>(arr, arg, k, arrsize, hasnan);
644+
}
658645

659-
if (arrsize > 1) {
660-
if constexpr (xss::fp::is_floating_point_v<T>) {
661-
if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
662-
std_argselect_withnan(arr, arg, k, 0, arrsize);
663-
return;
664-
}
665-
}
666-
UNUSED(hasnan);
667-
argselect_64bit_<vectype, argtype>(
668-
arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
669-
}
646+
template <typename T>
647+
X86_SIMD_SORT_INLINE void avx2_argselect(T *arr,
648+
arrsize_t *arg,
649+
arrsize_t k,
650+
arrsize_t arrsize,
651+
bool hasnan = false)
652+
{
653+
xss_argselect<T, avx2_vector, avx2_half_vector>(arr, arg, k, arrsize, hasnan);
670654
}
655+
671656
#endif // XSS_COMMON_ARGSORT

0 commit comments

Comments
 (0)