@@ -638,61 +638,64 @@ X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr,
638
638
/* argsort methods for 32-bit and 64-bit dtypes */
639
639
template <typename T>
640
640
X86_SIMD_SORT_INLINE void
641
- avx512_argsort (T *arr, arrsize_t *arg, arrsize_t arrsize)
641
+ avx512_argsort (T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false )
642
642
{
643
643
using vectype = typename std::conditional<sizeof (T) == sizeof (int32_t ),
644
644
ymm_vector<T>,
645
645
zmm_vector<T>>::type;
646
646
if (arrsize > 1 ) {
647
647
if constexpr (std::is_floating_point_v<T>) {
648
- if (has_nan <vectype>(arr, arrsize)) {
648
+ if ((hasnan) && (array_has_nan <vectype>(arr, arrsize) )) {
649
649
std_argsort_withnan (arr, arg, 0 , arrsize);
650
650
return ;
651
651
}
652
652
}
653
+ UNUSED (hasnan);
653
654
argsort_64bit_<vectype>(
654
655
arr, arg, 0 , arrsize - 1 , 2 * (arrsize_t )log2 (arrsize));
655
656
}
656
657
}
657
658
658
659
template <typename T>
659
660
X86_SIMD_SORT_INLINE std::vector<arrsize_t > avx512_argsort (T *arr,
660
- arrsize_t arrsize)
661
+ arrsize_t arrsize,
662
+ bool hasnan = false )
661
663
{
662
664
std::vector<arrsize_t > indices (arrsize);
663
665
std::iota (indices.begin (), indices.end (), 0 );
664
- avx512_argsort<T>(arr, indices.data (), arrsize);
666
+ avx512_argsort<T>(arr, indices.data (), arrsize, hasnan );
665
667
return indices;
666
668
}
667
669
668
670
/* argselect methods for 32-bit and 64-bit dtypes */
669
671
template <typename T>
670
672
X86_SIMD_SORT_INLINE void
671
- avx512_argselect (T *arr, arrsize_t *arg, arrsize_t k, arrsize_t arrsize)
673
+ avx512_argselect (T *arr, arrsize_t *arg, arrsize_t k, arrsize_t arrsize, bool hasnan = false )
672
674
{
673
675
using vectype = typename std::conditional<sizeof (T) == sizeof (int32_t ),
674
676
ymm_vector<T>,
675
677
zmm_vector<T>>::type;
676
678
677
679
if (arrsize > 1 ) {
678
680
if constexpr (std::is_floating_point_v<T>) {
679
- if (has_nan <vectype>(arr, arrsize)) {
681
+ if ((hasnan) && (array_has_nan <vectype>(arr, arrsize) )) {
680
682
std_argselect_withnan (arr, arg, k, 0 , arrsize);
681
683
return ;
682
684
}
683
685
}
686
+ UNUSED (hasnan);
684
687
argselect_64bit_<vectype>(
685
688
arr, arg, k, 0 , arrsize - 1 , 2 * (arrsize_t )log2 (arrsize));
686
689
}
687
690
}
688
691
689
692
template <typename T>
690
693
X86_SIMD_SORT_INLINE std::vector<arrsize_t >
691
- avx512_argselect (T *arr, arrsize_t k, arrsize_t arrsize)
694
+ avx512_argselect (T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false )
692
695
{
693
696
std::vector<arrsize_t > indices (arrsize);
694
697
std::iota (indices.begin (), indices.end (), 0 );
695
- avx512_argselect<T>(arr, indices.data (), k, arrsize);
698
+ avx512_argselect<T>(arr, indices.data (), k, arrsize, hasnan );
696
699
return indices;
697
700
}
698
701
0 commit comments