Skip to content

Commit bae0eff

Browse files
author
Raghuveer Devulapalli
committed
Add hasnan = false to all methods
1 parent 64908e7 commit bae0eff

12 files changed

+94
-63
lines changed

lib/x86simdsort-avx2.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -4,9 +4,9 @@
44

55
#define DEFINE_ALL_METHODS(type) \
66
template <> \
7-
void qsort(type *arr, size_t arrsize) \
7+
void qsort(type *arr, size_t arrsize, bool hasnan) \
88
{ \
9-
avx2_qsort(arr, arrsize); \
9+
avx2_qsort(arr, arrsize, hasnan); \
1010
} \
1111
template <> \
1212
void qselect(type *arr, size_t k, size_t arrsize, bool hasnan) \
@@ -24,5 +24,5 @@ namespace avx2 {
2424
DEFINE_ALL_METHODS(uint32_t)
2525
DEFINE_ALL_METHODS(int32_t)
2626
DEFINE_ALL_METHODS(float)
27-
} // namespace avx512
27+
} // namespace avx2
2828
} // namespace xss

lib/x86simdsort-icl.cpp

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -5,9 +5,9 @@
55
namespace xss {
66
namespace avx512 {
77
template <>
8-
void qsort(uint16_t *arr, size_t size)
8+
void qsort(uint16_t *arr, size_t size, bool hasnan)
99
{
10-
avx512_qsort(arr, size);
10+
avx512_qsort(arr, size, hasnan);
1111
}
1212
template <>
1313
void qselect(uint16_t *arr, size_t k, size_t arrsize, bool hasnan)
@@ -20,9 +20,9 @@ namespace avx512 {
2020
avx512_partial_qsort(arr, k, arrsize, hasnan);
2121
}
2222
template <>
23-
void qsort(int16_t *arr, size_t size)
23+
void qsort(int16_t *arr, size_t size, bool hasnan)
2424
{
25-
avx512_qsort(arr, size);
25+
avx512_qsort(arr, size, hasnan);
2626
}
2727
template <>
2828
void qselect(int16_t *arr, size_t k, size_t arrsize, bool hasnan)

lib/x86simdsort-internal.h

Lines changed: 12 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@ namespace xss {
88
namespace avx512 {
99
// quicksort
1010
template <typename T>
11-
XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize);
11+
XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false);
1212
// quickselect
1313
template <typename T>
1414
XSS_HIDE_SYMBOL void
@@ -19,16 +19,17 @@ namespace avx512 {
1919
partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false);
2020
// argsort
2121
template <typename T>
22-
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr, size_t arrsize);
22+
XSS_HIDE_SYMBOL std::vector<size_t>
23+
argsort(T *arr, size_t arrsize, bool hasnan = false);
2324
// argselect
2425
template <typename T>
2526
XSS_HIDE_SYMBOL std::vector<size_t>
26-
argselect(T *arr, size_t k, size_t arrsize);
27+
argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false);
2728
} // namespace avx512
2829
namespace avx2 {
2930
// quicksort
3031
template <typename T>
31-
XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize);
32+
XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false);
3233
// quickselect
3334
template <typename T>
3435
XSS_HIDE_SYMBOL void
@@ -39,16 +40,17 @@ namespace avx2 {
3940
partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false);
4041
// argsort
4142
template <typename T>
42-
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr, size_t arrsize);
43+
XSS_HIDE_SYMBOL std::vector<size_t>
44+
argsort(T *arr, size_t arrsize, bool hasnan = false);
4345
// argselect
4446
template <typename T>
4547
XSS_HIDE_SYMBOL std::vector<size_t>
46-
argselect(T *arr, size_t k, size_t arrsize);
48+
argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false);
4749
} // namespace avx2
4850
namespace scalar {
4951
// quicksort
5052
template <typename T>
51-
XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize);
53+
XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false);
5254
// quickselect
5355
template <typename T>
5456
XSS_HIDE_SYMBOL void
@@ -59,11 +61,12 @@ namespace scalar {
5961
partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false);
6062
// argsort
6163
template <typename T>
62-
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr, size_t arrsize);
64+
XSS_HIDE_SYMBOL std::vector<size_t>
65+
argsort(T *arr, size_t arrsize, bool hasnan = false);
6366
// argselect
6467
template <typename T>
6568
XSS_HIDE_SYMBOL std::vector<size_t>
66-
argselect(T *arr, size_t k, size_t arrsize);
69+
argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false);
6770
} // namespace scalar
6871
} // namespace xss
6972
#endif

lib/x86simdsort-scalar.h

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -5,9 +5,14 @@
55
namespace xss {
66
namespace scalar {
77
template <typename T>
8-
void qsort(T *arr, size_t arrsize)
8+
void qsort(T *arr, size_t arrsize, bool hasnan)
99
{
10-
std::sort(arr, arr + arrsize, compare<T, std::less<T>>());
10+
if (hasnan) {
11+
std::sort(arr, arr + arrsize, compare<T, std::less<T>>());
12+
}
13+
else {
14+
std::sort(arr, arr + arrsize);
15+
}
1116
}
1217
template <typename T>
1318
void qselect(T *arr, size_t k, size_t arrsize, bool hasnan)
@@ -32,16 +37,18 @@ namespace scalar {
3237
}
3338
}
3439
template <typename T>
35-
std::vector<size_t> argsort(T *arr, size_t arrsize)
40+
std::vector<size_t> argsort(T *arr, size_t arrsize, bool hasnan)
3641
{
42+
UNUSED(hasnan);
3743
std::vector<size_t> arg(arrsize);
3844
std::iota(arg.begin(), arg.end(), 0);
3945
std::sort(arg.begin(), arg.end(), compare_arg<T, std::less<T>>(arr));
4046
return arg;
4147
}
4248
template <typename T>
43-
std::vector<size_t> argselect(T *arr, size_t k, size_t arrsize)
49+
std::vector<size_t> argselect(T *arr, size_t k, size_t arrsize, bool hasnan)
4450
{
51+
UNUSED(hasnan);
4552
std::vector<size_t> arg(arrsize);
4653
std::iota(arg.begin(), arg.end(), 0);
4754
std::nth_element(arg.begin(),

lib/x86simdsort-skx.cpp

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -6,9 +6,9 @@
66

77
#define DEFINE_ALL_METHODS(type) \
88
template <> \
9-
void qsort(type *arr, size_t arrsize) \
9+
void qsort(type *arr, size_t arrsize, bool hasnan) \
1010
{ \
11-
avx512_qsort(arr, arrsize); \
11+
avx512_qsort(arr, arrsize, hasnan); \
1212
} \
1313
template <> \
1414
void qselect(type *arr, size_t k, size_t arrsize, bool hasnan) \
@@ -21,14 +21,15 @@
2121
avx512_partial_qsort(arr, k, arrsize, hasnan); \
2222
} \
2323
template <> \
24-
std::vector<size_t> argsort(type *arr, size_t arrsize) \
24+
std::vector<size_t> argsort(type *arr, size_t arrsize, bool hasnan) \
2525
{ \
26-
return avx512_argsort(arr, arrsize); \
26+
return avx512_argsort(arr, arrsize, hasnan); \
2727
} \
2828
template <> \
29-
std::vector<size_t> argselect(type *arr, size_t k, size_t arrsize) \
29+
std::vector<size_t> argselect( \
30+
type *arr, size_t k, size_t arrsize, bool hasnan) \
3031
{ \
31-
return avx512_argselect(arr, k, arrsize); \
32+
return avx512_argselect(arr, k, arrsize, hasnan); \
3233
}
3334

3435
namespace xss {

lib/x86simdsort-spr.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,9 +5,9 @@
55
namespace xss {
66
namespace avx512 {
77
template <>
8-
void qsort(_Float16 *arr, size_t size)
8+
void qsort(_Float16 *arr, size_t size, bool hasnan)
99
{
10-
avx512_qsort(arr, size);
10+
avx512_qsort(arr, size, hasnan);
1111
}
1212
template <>
1313
void qselect(_Float16 *arr, size_t k, size_t arrsize, bool hasnan)

lib/x86simdsort.cpp

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -55,11 +55,11 @@ dispatch_requested(std::string_view cpurequested,
5555
#define CAT(a, b) CAT_(a, b)
5656

5757
#define DECLARE_INTERNAL_qsort(TYPE) \
58-
static void (*internal_qsort##TYPE)(TYPE *, size_t) = NULL; \
58+
static void (*internal_qsort##TYPE)(TYPE *, size_t, bool) = NULL; \
5959
template <> \
60-
void qsort(TYPE *arr, size_t arrsize) \
60+
void qsort(TYPE *arr, size_t arrsize, bool hasnan) \
6161
{ \
62-
(*internal_qsort##TYPE)(arr, arrsize); \
62+
(*internal_qsort##TYPE)(arr, arrsize, hasnan); \
6363
}
6464

6565
#define DECLARE_INTERNAL_qselect(TYPE) \
@@ -81,22 +81,23 @@ dispatch_requested(std::string_view cpurequested,
8181
}
8282

8383
#define DECLARE_INTERNAL_argsort(TYPE) \
84-
static std::vector<size_t> (*internal_argsort##TYPE)(TYPE *, size_t) \
84+
static std::vector<size_t> (*internal_argsort##TYPE)(TYPE *, size_t, bool) \
8585
= NULL; \
8686
template <> \
87-
std::vector<size_t> argsort(TYPE *arr, size_t arrsize) \
87+
std::vector<size_t> argsort(TYPE *arr, size_t arrsize, bool hasnan) \
8888
{ \
89-
return (*internal_argsort##TYPE)(arr, arrsize); \
89+
return (*internal_argsort##TYPE)(arr, arrsize, hasnan); \
9090
}
9191

9292
#define DECLARE_INTERNAL_argselect(TYPE) \
9393
static std::vector<size_t> (*internal_argselect##TYPE)( \
94-
TYPE *, size_t, size_t) \
94+
TYPE *, size_t, size_t, bool) \
9595
= NULL; \
9696
template <> \
97-
std::vector<size_t> argselect(TYPE *arr, size_t k, size_t arrsize) \
97+
std::vector<size_t> argselect( \
98+
TYPE *arr, size_t k, size_t arrsize, bool hasnan) \
9899
{ \
99-
return (*internal_argselect##TYPE)(arr, k, arrsize); \
100+
return (*internal_argselect##TYPE)(arr, k, arrsize, hasnan); \
100101
}
101102

102103
/* runtime dispatch mechanism */

lib/x86simdsort.h

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -6,25 +6,33 @@
66

77
#define XSS_EXPORT_SYMBOL __attribute__((visibility("default")))
88
#define XSS_HIDE_SYMBOL __attribute__((visibility("hidden")))
9+
#define UNUSED(x) (void)(x)
910

1011
namespace x86simdsort {
12+
1113
// quicksort
1214
template <typename T>
13-
XSS_EXPORT_SYMBOL void qsort(T *arr, size_t arrsize);
15+
XSS_EXPORT_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false);
16+
1417
// quickselect
1518
template <typename T>
1619
XSS_EXPORT_SYMBOL void
1720
qselect(T *arr, size_t k, size_t arrsize, bool hasnan = false);
21+
1822
// partial sort
1923
template <typename T>
2024
XSS_EXPORT_SYMBOL void
2125
partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false);
26+
2227
// argsort
2328
template <typename T>
24-
XSS_EXPORT_SYMBOL std::vector<size_t> argsort(T *arr, size_t arrsize);
29+
XSS_EXPORT_SYMBOL std::vector<size_t>
30+
argsort(T *arr, size_t arrsize, bool hasnan = false);
31+
2532
// argselect
2633
template <typename T>
2734
XSS_EXPORT_SYMBOL std::vector<size_t>
28-
argselect(T *arr, size_t k, size_t arrsize);
35+
argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false);
36+
2937
} // namespace x86simdsort
3038
#endif

src/avx512-16bit-qsort.hpp

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -519,12 +519,14 @@ bool is_a_nan<uint16_t>(uint16_t elem)
519519
}
520520

521521
X86_SIMD_SORT_INLINE
522-
void avx512_qsort_fp16(uint16_t *arr, arrsize_t arrsize)
522+
void avx512_qsort_fp16(uint16_t *arr, arrsize_t arrsize, bool hasnan = false)
523523
{
524524
if (arrsize > 1) {
525-
arrsize_t nan_count
526-
= replace_nan_with_inf<zmm_vector<float16>, uint16_t>(arr,
527-
arrsize);
525+
arrsize_t nan_count = 0;
526+
if (UNLIKELY(hasnan)) {
527+
nan_count = replace_nan_with_inf<zmm_vector<float16>, uint16_t>(
528+
arr, arrsize);
529+
}
528530
qsort_<zmm_vector<float16>, uint16_t>(
529531
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
530532
replace_inf_with_nan(arr, arrsize, nan_count);
@@ -535,7 +537,7 @@ X86_SIMD_SORT_INLINE
535537
void avx512_qselect_fp16(uint16_t *arr,
536538
arrsize_t k,
537539
arrsize_t arrsize,
538-
bool hasnan = true)
540+
bool hasnan = false)
539541
{
540542
arrsize_t indx_last_elem = arrsize - 1;
541543
if (UNLIKELY(hasnan)) {

src/avx512-64bit-argsort.hpp

Lines changed: 11 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -638,61 +638,64 @@ X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr,
638638
/* argsort methods for 32-bit and 64-bit dtypes */
639639
template <typename T>
640640
X86_SIMD_SORT_INLINE void
641-
avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize)
641+
avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false)
642642
{
643643
using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
644644
ymm_vector<T>,
645645
zmm_vector<T>>::type;
646646
if (arrsize > 1) {
647647
if constexpr (std::is_floating_point_v<T>) {
648-
if (has_nan<vectype>(arr, arrsize)) {
648+
if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
649649
std_argsort_withnan(arr, arg, 0, arrsize);
650650
return;
651651
}
652652
}
653+
UNUSED(hasnan);
653654
argsort_64bit_<vectype>(
654655
arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
655656
}
656657
}
657658

658659
template <typename T>
659660
X86_SIMD_SORT_INLINE std::vector<arrsize_t> avx512_argsort(T *arr,
660-
arrsize_t arrsize)
661+
arrsize_t arrsize,
662+
bool hasnan = false)
661663
{
662664
std::vector<arrsize_t> indices(arrsize);
663665
std::iota(indices.begin(), indices.end(), 0);
664-
avx512_argsort<T>(arr, indices.data(), arrsize);
666+
avx512_argsort<T>(arr, indices.data(), arrsize, hasnan);
665667
return indices;
666668
}
667669

668670
/* argselect methods for 32-bit and 64-bit dtypes */
669671
template <typename T>
670672
X86_SIMD_SORT_INLINE void
671-
avx512_argselect(T *arr, arrsize_t *arg, arrsize_t k, arrsize_t arrsize)
673+
avx512_argselect(T *arr, arrsize_t *arg, arrsize_t k, arrsize_t arrsize, bool hasnan = false)
672674
{
673675
using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
674676
ymm_vector<T>,
675677
zmm_vector<T>>::type;
676678

677679
if (arrsize > 1) {
678680
if constexpr (std::is_floating_point_v<T>) {
679-
if (has_nan<vectype>(arr, arrsize)) {
681+
if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
680682
std_argselect_withnan(arr, arg, k, 0, arrsize);
681683
return;
682684
}
683685
}
686+
UNUSED(hasnan);
684687
argselect_64bit_<vectype>(
685688
arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
686689
}
687690
}
688691

689692
template <typename T>
690693
X86_SIMD_SORT_INLINE std::vector<arrsize_t>
691-
avx512_argselect(T *arr, arrsize_t k, arrsize_t arrsize)
694+
avx512_argselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false)
692695
{
693696
std::vector<arrsize_t> indices(arrsize);
694697
std::iota(indices.begin(), indices.end(), 0);
695-
avx512_argselect<T>(arr, indices.data(), k, arrsize);
698+
avx512_argselect<T>(arr, indices.data(), k, arrsize, hasnan);
696699
return indices;
697700
}
698701

0 commit comments

Comments
 (0)