Skip to content

Commit b383a5b

Browse files
committed
Make NAN moving logic for kvsort use simd when possible
1 parent a55a655 commit b383a5b

File tree

1 file changed

+34
-3
lines changed

1 file changed

+34
-3
lines changed

src/xss-common-keyvaluesort.hpp

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,42 @@
2020
* Move all NaNs to the end of the array and return the index of the last element
2121
* in the array that is not a NaN
2222
*/
23-
template <typename T1, typename T2>
23+
template <typename T1, typename T2, typename vtype>
2424
X86_SIMD_SORT_INLINE arrsize_t move_nans_to_end_of_array(T1 *keys,
2525
T2 *vals,
2626
arrsize_t size)
2727
{
28+
using reg_t = typename vtype::reg_t;
29+
2830
arrsize_t jj = size - 1;
2931
arrsize_t ii = 0;
3032
arrsize_t count = 0;
33+
34+
while (ii + vtype::numlanes < jj) {
35+
reg_t in = vtype::loadu(keys + ii);
36+
auto nanmask = vtype::convert_mask_to_int(
37+
vtype::template fpclass<0x01 | 0x80>(in));
38+
39+
// Check if there are any nans in this vector, and process them if so
40+
if (nanmask != 0x00) {
41+
for (size_t offset = 0; offset < vtype::numlanes; offset++) {
42+
if (is_a_nan(keys[ii])) {
43+
std::swap(keys[ii], keys[jj]);
44+
std::swap(vals[ii], vals[jj]);
45+
jj -= 1;
46+
count++;
47+
}
48+
else {
49+
ii += 1;
50+
}
51+
}
52+
}
53+
else {
54+
ii += vtype::numlanes;
55+
}
56+
}
57+
58+
// Handle the remaining elements once fewer than one full vector's worth is left
3159
while (ii < jj) {
3260
if (is_a_nan(keys[ii])) {
3361
std::swap(keys[ii], keys[jj]);
@@ -39,6 +67,7 @@ X86_SIMD_SORT_INLINE arrsize_t move_nans_to_end_of_array(T1 *keys,
3967
ii += 1;
4068
}
4169
}
70+
4271
/* Haven't checked for nan when ii == jj */
4372
if (is_a_nan(keys[ii])) { count++; }
4473
return size - count - 1;
@@ -570,7 +599,8 @@ X86_SIMD_SORT_INLINE void xss_qsort_kv(
570599
if constexpr (xss::fp::is_floating_point_v<T1>) {
571600
if (UNLIKELY(hasnan)) {
572601
index_last_elem
573-
= move_nans_to_end_of_array(keys, indexes, arrsize);
602+
= move_nans_to_end_of_array<T1, T2, full_vector<T1>>(
603+
keys, indexes, arrsize);
574604
}
575605
}
576606
else {
@@ -660,7 +690,8 @@ X86_SIMD_SORT_INLINE void xss_select_kv(T1 *keys,
660690
if constexpr (xss::fp::is_floating_point_v<T1>) {
661691
if (UNLIKELY(hasnan)) {
662692
index_last_elem
663-
= move_nans_to_end_of_array(keys, indexes, arrsize);
693+
= move_nans_to_end_of_array<T1, T2, full_vector<T1>>(
694+
keys, indexes, arrsize);
664695
}
665696
}
666697

0 commit comments

Comments
 (0)