20
20
* Sort all the NaNs to the end of the array and return the index of the last elem
21
21
* in the array which is not a NaN
22
22
*/
23
- template <typename T1, typename T2>
23
+ template <typename T1, typename T2, typename vtype >
24
24
X86_SIMD_SORT_INLINE arrsize_t move_nans_to_end_of_array (T1 *keys,
25
25
T2 *vals,
26
26
arrsize_t size)
27
27
{
28
+ using reg_t = typename vtype::reg_t ;
29
+
28
30
arrsize_t jj = size - 1 ;
29
31
arrsize_t ii = 0 ;
30
32
arrsize_t count = 0 ;
33
+
34
+ while (ii + vtype::numlanes < jj) {
35
+ reg_t in = vtype::loadu (keys + ii);
36
+ auto nanmask = vtype::convert_mask_to_int (
37
+ vtype::template fpclass<0x01 | 0x80 >(in));
38
+
39
+ // Check if there are any nans in this vector, and process them if so
40
+ if (nanmask != 0x00 ) {
41
+ for (size_t offset = 0 ; offset < vtype::numlanes; offset++) {
42
+ if (is_a_nan (keys[ii])) {
43
+ std::swap (keys[ii], keys[jj]);
44
+ std::swap (vals[ii], vals[jj]);
45
+ jj -= 1 ;
46
+ count++;
47
+ }
48
+ else {
49
+ ii += 1 ;
50
+ }
51
+ }
52
+ }
53
+ else {
54
+ ii += vtype::numlanes;
55
+ }
56
+ }
57
+
58
+ // Handle the remaining elements once fewer than one full vector's worth is left
31
59
while (ii < jj) {
32
60
if (is_a_nan (keys[ii])) {
33
61
std::swap (keys[ii], keys[jj]);
@@ -39,6 +67,7 @@ X86_SIMD_SORT_INLINE arrsize_t move_nans_to_end_of_array(T1 *keys,
39
67
ii += 1 ;
40
68
}
41
69
}
70
+
42
71
/* Haven't checked for nan when ii == jj */
43
72
if (is_a_nan (keys[ii])) { count++; }
44
73
return size - count - 1 ;
@@ -570,7 +599,8 @@ X86_SIMD_SORT_INLINE void xss_qsort_kv(
570
599
if constexpr (xss::fp::is_floating_point_v<T1>) {
571
600
if (UNLIKELY (hasnan)) {
572
601
index_last_elem
573
- = move_nans_to_end_of_array (keys, indexes, arrsize);
602
+ = move_nans_to_end_of_array<T1, T2, full_vector<T1>>(
603
+ keys, indexes, arrsize);
574
604
}
575
605
}
576
606
else {
@@ -660,7 +690,8 @@ X86_SIMD_SORT_INLINE void xss_select_kv(T1 *keys,
660
690
if constexpr (xss::fp::is_floating_point_v<T1>) {
661
691
if (UNLIKELY (hasnan)) {
662
692
index_last_elem
663
- = move_nans_to_end_of_array (keys, indexes, arrsize);
693
+ = move_nans_to_end_of_array<T1, T2, full_vector<T1>>(
694
+ keys, indexes, arrsize);
664
695
}
665
696
}
666
697
0 commit comments