Skip to content

Commit b27f82b

Browse files
author
Raghuveer Devulapalli
authored
Merge pull request #178 from sterrettm2/kvsort-nantests
Fix kvsort/kvselect nan behavior and added tests for mixed nan/inf arrays
2 parents 9427923 + 89de98d commit b27f82b

File tree

5 files changed

+150
-44
lines changed

5 files changed

+150
-44
lines changed

src/xss-common-keyvaluesort.hpp

Lines changed: 85 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,63 @@
1616
#include <omp.h>
1717
#endif
1818

19+
/*
20+
* Sort all the NAN's to end of the array and return the index of the last elem
21+
* in the array which is not a nan
22+
*/
23+
template <typename T1, typename T2, typename vtype>
24+
X86_SIMD_SORT_INLINE arrsize_t move_nans_to_end_of_array(T1 *keys,
25+
T2 *vals,
26+
arrsize_t size)
27+
{
28+
using reg_t = typename vtype::reg_t;
29+
30+
arrsize_t jj = size - 1;
31+
arrsize_t ii = 0;
32+
arrsize_t count = 0;
33+
34+
while (ii + vtype::numlanes < jj) {
35+
reg_t in = vtype::loadu(keys + ii);
36+
auto nanmask = vtype::convert_mask_to_int(
37+
vtype::template fpclass<0x01 | 0x80>(in));
38+
39+
// Check if there are any nans in this vector, and process them if so
40+
if (nanmask != 0x00) {
41+
for (size_t offset = 0; offset < vtype::numlanes; offset++) {
42+
if (is_a_nan(keys[ii])) {
43+
std::swap(keys[ii], keys[jj]);
44+
std::swap(vals[ii], vals[jj]);
45+
jj -= 1;
46+
count++;
47+
}
48+
else {
49+
ii += 1;
50+
}
51+
}
52+
}
53+
else {
54+
ii += vtype::numlanes;
55+
}
56+
}
57+
58+
// Handle the remainders once we have less than 1 vector worth
59+
while (ii < jj) {
60+
if (is_a_nan(keys[ii])) {
61+
std::swap(keys[ii], keys[jj]);
62+
std::swap(vals[ii], vals[jj]);
63+
jj -= 1;
64+
count++;
65+
}
66+
else {
67+
ii += 1;
68+
}
69+
}
70+
71+
/* Haven't checked for nan when ii == jj */
72+
if (is_a_nan(keys[ii])) { count++; }
73+
return size - count - 1;
74+
}
75+
1976
/*
2077
* Partition one ZMM register based on the pivot and returns the index of the
2178
* last element that is less than or equal to the pivot.
@@ -529,6 +586,9 @@ X86_SIMD_SORT_INLINE void xss_qsort_kv(
529586
half_vector<T2>,
530587
full_vector<T2>>::type;
531588

589+
// Exit early if no work would be done
590+
if (arrsize <= 1) return;
591+
532592
#ifdef XSS_TEST_KEYVALUE_BASE_CASE
533593
int maxiters = -1;
534594
bool minarrsize = true;
@@ -538,11 +598,12 @@ X86_SIMD_SORT_INLINE void xss_qsort_kv(
538598
#endif // XSS_TEST_KEYVALUE_BASE_CASE
539599

540600
if (minarrsize) {
541-
arrsize_t nan_count = 0;
601+
arrsize_t index_last_elem = arrsize - 1;
542602
if constexpr (xss::fp::is_floating_point_v<T1>) {
543603
if (UNLIKELY(hasnan)) {
544-
nan_count
545-
= replace_nan_with_inf<full_vector<T1>>(keys, arrsize);
604+
index_last_elem
605+
= move_nans_to_end_of_array<T1, T2, full_vector<T1>>(
606+
keys, indexes, arrsize);
546607
}
547608
}
548609
else {
@@ -565,24 +626,27 @@ X86_SIMD_SORT_INLINE void xss_qsort_kv(
565626
// Note that we do not use the if(...) clause built into OpenMP, because it causes a performance regression for small arrays
566627
#pragma omp parallel num_threads(thread_count)
567628
#pragma omp single
568-
kvsort_<keytype, valtype>(
569-
keys, indexes, 0, arrsize - 1, maxiters, task_threshold);
629+
kvsort_<keytype, valtype>(keys,
630+
indexes,
631+
0,
632+
index_last_elem,
633+
maxiters,
634+
task_threshold);
570635
}
571636
else {
572637
kvsort_<keytype, valtype>(keys,
573638
indexes,
574639
0,
575-
arrsize - 1,
640+
index_last_elem,
576641
maxiters,
577642
std::numeric_limits<arrsize_t>::max());
578643
}
579644
#pragma omp taskwait
580645
#else
581-
kvsort_<keytype, valtype>(keys, indexes, 0, arrsize - 1, maxiters, 0);
646+
kvsort_<keytype, valtype>(
647+
keys, indexes, 0, index_last_elem, maxiters, 0);
582648
#endif
583649

584-
replace_inf_with_nan(keys, arrsize, nan_count);
585-
586650
if (descending) {
587651
std::reverse(keys, keys + arrsize);
588652
std::reverse(indexes, indexes + arrsize);
@@ -614,6 +678,9 @@ X86_SIMD_SORT_INLINE void xss_select_kv(T1 *keys,
614678
half_vector<T2>,
615679
full_vector<T2>>::type;
616680

681+
// Exit early if no work would be done
682+
if (arrsize <= 1) return;
683+
617684
#ifdef XSS_TEST_KEYVALUE_BASE_CASE
618685
int maxiters = -1;
619686
bool minarrsize = true;
@@ -625,20 +692,19 @@ X86_SIMD_SORT_INLINE void xss_select_kv(T1 *keys,
625692
if (minarrsize) {
626693
if (descending) { k = arrsize - 1 - k; }
627694

628-
if constexpr (std::is_floating_point_v<T1>) {
629-
arrsize_t nan_count = 0;
695+
arrsize_t index_last_elem = arrsize - 1;
696+
if constexpr (xss::fp::is_floating_point_v<T1>) {
630697
if (UNLIKELY(hasnan)) {
631-
nan_count
632-
= replace_nan_with_inf<full_vector<T1>>(keys, arrsize);
698+
index_last_elem
699+
= move_nans_to_end_of_array<T1, T2, full_vector<T1>>(
700+
keys, indexes, arrsize);
633701
}
634-
kvselect_<keytype, valtype>(
635-
keys, indexes, k, 0, arrsize - 1, maxiters);
636-
replace_inf_with_nan(keys, arrsize, nan_count);
637702
}
638-
else {
639-
UNUSED(hasnan);
703+
704+
UNUSED(hasnan);
705+
if (index_last_elem >= k) {
640706
kvselect_<keytype, valtype>(
641-
keys, indexes, k, 0, arrsize - 1, maxiters);
707+
keys, indexes, k, 0, index_last_elem, maxiters);
642708
}
643709

644710
if (descending) {

tests/test-keyvalue.cpp

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ class simdkvsort : public ::testing::Test {
2626
"smallrange",
2727
"max_at_the_end",
2828
"random_5d",
29-
"rand_max"};
29+
"rand_max",
30+
"rand_with_nan",
31+
"rand_with_max_and_nan"};
3032
}
3133
std::vector<std::string> arrtype;
3234
std::vector<size_t> arrsize = std::vector<size_t>(1024);
@@ -123,27 +125,36 @@ bool is_kv_partialsorted(T1 *keys_comp,
123125
}
124126

125127
// Now, we need to do some more work to handle keys exactly equal to the true kth
128+
// There may be more values after the kth element with the same key,
129+
// and thus we can find that the values of the kth elements do not match,
130+
// even though the sort is correct.
131+
126132
// First, fully kvsort both arrays
127133
xss::scalar::keyvalue_qsort<T1, T2>(keys_ref, vals_ref, size, true, false);
128134
xss::scalar::keyvalue_qsort<T1, T2>(
129135
keys_comp, vals_comp, size, true, false);
130136

131-
auto trueKth = keys_ref[k];
132-
bool notFoundFirst = true;
137+
auto trueKthKey = keys_ref[k];
138+
bool foundFirstKthKey = false;
133139
size_t i = 0;
134140

141+
// Search forwards until we find the block of keys that match the kth key,
142+
// then find where it ends
135143
for (; i < size; i++) {
136-
if (notFoundFirst && cmp_eq(keys_ref[i], trueKth)) {
137-
notFoundFirst = false;
144+
if (!foundFirstKthKey && cmp_eq(keys_ref[i], trueKthKey)) {
145+
foundFirstKthKey = true;
138146
i_start = i;
139147
}
140-
else if (!notFoundFirst && !cmp_eq(keys_ref[i], trueKth)) {
148+
else if (foundFirstKthKey && !cmp_eq(keys_ref[i], trueKthKey)) {
141149
break;
142150
}
143151
}
144152

145-
if (notFoundFirst) return false;
153+
// kth key is somehow missing? Since we got that value from keys_ref, should be impossible
154+
if (!foundFirstKthKey) { return false; }
146155

156+
// Check that the values in the kth key block match, so they are equivalent
157+
// up to permutation, which is allowed since the sort is not stable
147158
if (!same_values(vals_ref + i_start, vals_comp + i_start, i - i_start)) {
148159
return false;
149160
}
@@ -156,7 +167,7 @@ TYPED_TEST_P(simdkvsort, test_kvsort_ascending)
156167
using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type;
157168
using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type;
158169
for (auto type : this->arrtype) {
159-
bool hasnan = (type == "rand_with_nan") ? true : false;
170+
bool hasnan = is_nan_test(type);
160171
for (auto size : this->arrsize) {
161172
std::vector<T1> key = get_array<T1>(type, size);
162173
std::vector<T2> val = get_array<T2>(type, size);
@@ -187,7 +198,7 @@ TYPED_TEST_P(simdkvsort, test_kvsort_descending)
187198
using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type;
188199
using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type;
189200
for (auto type : this->arrtype) {
190-
bool hasnan = (type == "rand_with_nan") ? true : false;
201+
bool hasnan = is_nan_test(type);
191202
for (auto size : this->arrsize) {
192203
std::vector<T1> key = get_array<T1>(type, size);
193204
std::vector<T2> val = get_array<T2>(type, size);
@@ -217,8 +228,9 @@ TYPED_TEST_P(simdkvsort, test_kvselect_ascending)
217228
{
218229
using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type;
219230
using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type;
231+
auto cmp_eq = compare<T1, std::equal_to<T1>>();
220232
for (auto type : this->arrtype) {
221-
bool hasnan = (type == "rand_with_nan") ? true : false;
233+
bool hasnan = is_nan_test(type);
222234
for (auto size : this->arrsize) {
223235
size_t k = rand() % size;
224236

@@ -237,7 +249,7 @@ TYPED_TEST_P(simdkvsort, test_kvselect_ascending)
237249
xss::scalar::keyvalue_qsort(
238250
key.data(), val.data(), k, hasnan, false);
239251

240-
ASSERT_EQ(key[k], key_bckp[k]);
252+
ASSERT_EQ(cmp_eq(key[k], key_bckp[k]), true);
241253

242254
bool is_kv_partialsorted_
243255
= is_kv_partialsorted<T1, T2>(key.data(),
@@ -260,8 +272,9 @@ TYPED_TEST_P(simdkvsort, test_kvselect_descending)
260272
{
261273
using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type;
262274
using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type;
275+
auto cmp_eq = compare<T1, std::equal_to<T1>>();
263276
for (auto type : this->arrtype) {
264-
bool hasnan = (type == "rand_with_nan") ? true : false;
277+
bool hasnan = is_nan_test(type);
265278
for (auto size : this->arrsize) {
266279
size_t k = rand() % size;
267280

@@ -280,7 +293,7 @@ TYPED_TEST_P(simdkvsort, test_kvselect_descending)
280293
xss::scalar::keyvalue_qsort(
281294
key.data(), val.data(), k, hasnan, true);
282295

283-
ASSERT_EQ(key[k], key_bckp[k]);
296+
ASSERT_EQ(cmp_eq(key[k], key_bckp[k]), true);
284297

285298
bool is_kv_partialsorted_
286299
= is_kv_partialsorted<T1, T2>(key.data(),
@@ -304,7 +317,7 @@ TYPED_TEST_P(simdkvsort, test_kvpartial_sort_ascending)
304317
using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type;
305318
using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type;
306319
for (auto type : this->arrtype) {
307-
bool hasnan = (type == "rand_with_nan") ? true : false;
320+
bool hasnan = is_nan_test(type);
308321
for (auto size : this->arrsize) {
309322
size_t k = rand() % size;
310323

@@ -341,7 +354,7 @@ TYPED_TEST_P(simdkvsort, test_kvpartial_sort_descending)
341354
using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type;
342355
using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type;
343356
for (auto type : this->arrtype) {
344-
bool hasnan = (type == "rand_with_nan") ? true : false;
357+
bool hasnan = is_nan_test(type);
345358
for (auto size : this->arrsize) {
346359
size_t k = rand() % size;
347360

tests/test-qsort-common.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@
2020
ASSERT_TRUE(false) << msg << ". arr size = " << size \
2121
<< ", type = " << type << ", k = " << k;
2222

23+
// Report whether the named test array type is one that uses NaNs.
// Currently this is determined just by checking whether "nan" appears in the
// name, e.g. "rand_with_nan" or "rand_with_max_and_nan".
// The name is taken by const reference so calling this does not copy it.
inline bool is_nan_test(const std::string &type)
{
    return type.find("nan") != std::string::npos;
}
28+
2329
template <typename T>
2430
void IS_SORTED(std::vector<T> sorted, std::vector<T> arr, std::string type)
2531
{

tests/test-qsort.cpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ class simdsort : public ::testing::Test {
1919
"max_at_the_end",
2020
"random_5d",
2121
"rand_max",
22-
"rand_with_nan"};
22+
"rand_with_nan",
23+
"rand_with_max_and_nan"};
2324
}
2425
std::vector<std::string> arrtype;
2526
std::vector<size_t> arrsize = std::vector<size_t>(1024);
@@ -30,7 +31,7 @@ TYPED_TEST_SUITE_P(simdsort);
3031
TYPED_TEST_P(simdsort, test_qsort_ascending)
3132
{
3233
for (auto type : this->arrtype) {
33-
bool hasnan = (type == "rand_with_nan") ? true : false;
34+
bool hasnan = is_nan_test(type);
3435
for (auto size : this->arrsize) {
3536
std::vector<TypeParam> basearr = get_array<TypeParam>(type, size);
3637

@@ -52,7 +53,7 @@ TYPED_TEST_P(simdsort, test_qsort_ascending)
5253
TYPED_TEST_P(simdsort, test_qsort_descending)
5354
{
5455
for (auto type : this->arrtype) {
55-
bool hasnan = (type == "rand_with_nan") ? true : false;
56+
bool hasnan = is_nan_test(type);
5657
for (auto size : this->arrsize) {
5758
std::vector<TypeParam> basearr = get_array<TypeParam>(type, size);
5859

@@ -74,7 +75,7 @@ TYPED_TEST_P(simdsort, test_qsort_descending)
7475
TYPED_TEST_P(simdsort, test_argsort_ascending)
7576
{
7677
for (auto type : this->arrtype) {
77-
bool hasnan = (type == "rand_with_nan") ? true : false;
78+
bool hasnan = is_nan_test(type);
7879
for (auto size : this->arrsize) {
7980
std::vector<TypeParam> arr = get_array<TypeParam>(type, size);
8081
std::vector<TypeParam> sortedarr = arr;
@@ -92,7 +93,7 @@ TYPED_TEST_P(simdsort, test_argsort_ascending)
9293
TYPED_TEST_P(simdsort, test_argsort_descending)
9394
{
9495
for (auto type : this->arrtype) {
95-
bool hasnan = (type == "rand_with_nan") ? true : false;
96+
bool hasnan = is_nan_test(type);
9697
for (auto size : this->arrsize) {
9798
std::vector<TypeParam> arr = get_array<TypeParam>(type, size);
9899
std::vector<TypeParam> sortedarr = arr;
@@ -111,7 +112,7 @@ TYPED_TEST_P(simdsort, test_argsort_descending)
111112
TYPED_TEST_P(simdsort, test_qselect_ascending)
112113
{
113114
for (auto type : this->arrtype) {
114-
bool hasnan = (type == "rand_with_nan") ? true : false;
115+
bool hasnan = is_nan_test(type);
115116
for (auto size : this->arrsize) {
116117
size_t k = rand() % size;
117118
std::vector<TypeParam> basearr = get_array<TypeParam>(type, size);
@@ -135,7 +136,7 @@ TYPED_TEST_P(simdsort, test_qselect_ascending)
135136
TYPED_TEST_P(simdsort, test_qselect_descending)
136137
{
137138
for (auto type : this->arrtype) {
138-
bool hasnan = (type == "rand_with_nan") ? true : false;
139+
bool hasnan = is_nan_test(type);
139140
for (auto size : this->arrsize) {
140141
size_t k = rand() % size;
141142
std::vector<TypeParam> basearr = get_array<TypeParam>(type, size);
@@ -159,7 +160,7 @@ TYPED_TEST_P(simdsort, test_qselect_descending)
159160
TYPED_TEST_P(simdsort, test_argselect)
160161
{
161162
for (auto type : this->arrtype) {
162-
bool hasnan = (type == "rand_with_nan") ? true : false;
163+
bool hasnan = is_nan_test(type);
163164
for (auto size : this->arrsize) {
164165
size_t k = rand() % size;
165166
std::vector<TypeParam> arr = get_array<TypeParam>(type, size);
@@ -179,7 +180,7 @@ TYPED_TEST_P(simdsort, test_argselect)
179180
TYPED_TEST_P(simdsort, test_partial_qsort_ascending)
180181
{
181182
for (auto type : this->arrtype) {
182-
bool hasnan = (type == "rand_with_nan") ? true : false;
183+
bool hasnan = is_nan_test(type);
183184
for (auto size : this->arrsize) {
184185
size_t k = rand() % size;
185186
std::vector<TypeParam> basearr = get_array<TypeParam>(type, size);
@@ -202,7 +203,7 @@ TYPED_TEST_P(simdsort, test_partial_qsort_ascending)
202203
TYPED_TEST_P(simdsort, test_partial_qsort_descending)
203204
{
204205
for (auto type : this->arrtype) {
205-
bool hasnan = (type == "rand_with_nan") ? true : false;
206+
bool hasnan = is_nan_test(type);
206207
for (auto size : this->arrsize) {
207208
// k should be at least 1
208209
size_t k = std::max((size_t)1, rand() % size);

0 commit comments

Comments
 (0)