Skip to content

Commit a55a655

Browse files
committed
Fix kv-sort and kv-select NAN handling
1 parent 59e298d commit a55a655

File tree

5 files changed

+113
-44
lines changed

5 files changed

+113
-44
lines changed

src/xss-common-keyvaluesort.hpp

Lines changed: 48 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,34 @@
1616
#include <omp.h>
1717
#endif
1818

19+
/*
20+
* Sort all the NAN's to end of the array and return the index of the last elem
21+
* in the array which is not a nan
22+
*/
23+
template <typename T1, typename T2>
24+
X86_SIMD_SORT_INLINE arrsize_t move_nans_to_end_of_array(T1 *keys,
25+
T2 *vals,
26+
arrsize_t size)
27+
{
28+
arrsize_t jj = size - 1;
29+
arrsize_t ii = 0;
30+
arrsize_t count = 0;
31+
while (ii < jj) {
32+
if (is_a_nan(keys[ii])) {
33+
std::swap(keys[ii], keys[jj]);
34+
std::swap(vals[ii], vals[jj]);
35+
jj -= 1;
36+
count++;
37+
}
38+
else {
39+
ii += 1;
40+
}
41+
}
42+
/* Haven't checked for nan when ii == jj */
43+
if (is_a_nan(keys[ii])) { count++; }
44+
return size - count - 1;
45+
}
46+
1947
/*
2048
* Partition one ZMM register based on the pivot and return the index of the
2149
* last element that is less than or equal to the pivot.
@@ -538,11 +566,11 @@ X86_SIMD_SORT_INLINE void xss_qsort_kv(
538566
#endif // XSS_TEST_KEYVALUE_BASE_CASE
539567

540568
if (minarrsize) {
541-
arrsize_t nan_count = 0;
569+
arrsize_t index_last_elem = arrsize - 1;
542570
if constexpr (xss::fp::is_floating_point_v<T1>) {
543571
if (UNLIKELY(hasnan)) {
544-
nan_count
545-
= replace_nan_with_inf<full_vector<T1>>(keys, arrsize);
572+
index_last_elem
573+
= move_nans_to_end_of_array(keys, indexes, arrsize);
546574
}
547575
}
548576
else {
@@ -565,24 +593,27 @@ X86_SIMD_SORT_INLINE void xss_qsort_kv(
565593
// Note that we do not use the if(...) clause built into OpenMP, because it causes a performance regression for small arrays
566594
#pragma omp parallel num_threads(thread_count)
567595
#pragma omp single
568-
kvsort_<keytype, valtype>(
569-
keys, indexes, 0, arrsize - 1, maxiters, task_threshold);
596+
kvsort_<keytype, valtype>(keys,
597+
indexes,
598+
0,
599+
index_last_elem,
600+
maxiters,
601+
task_threshold);
570602
}
571603
else {
572604
kvsort_<keytype, valtype>(keys,
573605
indexes,
574606
0,
575-
arrsize - 1,
607+
index_last_elem,
576608
maxiters,
577609
std::numeric_limits<arrsize_t>::max());
578610
}
579611
#pragma omp taskwait
580612
#else
581-
kvsort_<keytype, valtype>(keys, indexes, 0, arrsize - 1, maxiters, 0);
613+
kvsort_<keytype, valtype>(
614+
keys, indexes, 0, index_last_elem, maxiters, 0);
582615
#endif
583616

584-
replace_inf_with_nan(keys, arrsize, nan_count);
585-
586617
if (descending) {
587618
std::reverse(keys, keys + arrsize);
588619
std::reverse(indexes, indexes + arrsize);
@@ -625,20 +656,18 @@ X86_SIMD_SORT_INLINE void xss_select_kv(T1 *keys,
625656
if (minarrsize) {
626657
if (descending) { k = arrsize - 1 - k; }
627658

628-
if constexpr (std::is_floating_point_v<T1>) {
629-
arrsize_t nan_count = 0;
659+
arrsize_t index_last_elem = arrsize - 1;
660+
if constexpr (xss::fp::is_floating_point_v<T1>) {
630661
if (UNLIKELY(hasnan)) {
631-
nan_count
632-
= replace_nan_with_inf<full_vector<T1>>(keys, arrsize);
662+
index_last_elem
663+
= move_nans_to_end_of_array(keys, indexes, arrsize);
633664
}
634-
kvselect_<keytype, valtype>(
635-
keys, indexes, k, 0, arrsize - 1, maxiters);
636-
replace_inf_with_nan(keys, arrsize, nan_count);
637665
}
638-
else {
639-
UNUSED(hasnan);
666+
667+
UNUSED(hasnan);
668+
if (index_last_elem >= k) {
640669
kvselect_<keytype, valtype>(
641-
keys, indexes, k, 0, arrsize - 1, maxiters);
670+
keys, indexes, k, 0, index_last_elem, maxiters);
642671
}
643672

644673
if (descending) {

tests/test-keyvalue.cpp

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ class simdkvsort : public ::testing::Test {
2626
"smallrange",
2727
"max_at_the_end",
2828
"random_5d",
29-
"rand_max"};
29+
"rand_max",
30+
"rand_with_nan",
31+
"rand_with_max_and_nan"};
3032
}
3133
std::vector<std::string> arrtype;
3234
std::vector<size_t> arrsize = std::vector<size_t>(1024);
@@ -123,27 +125,36 @@ bool is_kv_partialsorted(T1 *keys_comp,
123125
}
124126

125127
// Now, we need to do some more work to handle keys exactly equal to the true kth
128+
// There may be more values after the kth element with the same key,
129+
// and thus we can find that the values of the kth elements do not match,
130+
// even though the sort is correct.
131+
126132
// First, fully kvsort both arrays
127133
xss::scalar::keyvalue_qsort<T1, T2>(keys_ref, vals_ref, size, true, false);
128134
xss::scalar::keyvalue_qsort<T1, T2>(
129135
keys_comp, vals_comp, size, true, false);
130136

131-
auto trueKth = keys_ref[k];
132-
bool notFoundFirst = true;
137+
auto trueKthKey = keys_ref[k];
138+
bool foundFirstKthKey = false;
133139
size_t i = 0;
134140

141+
// Search forwards until we find the block of keys that match the kth key,
142+
// then find where it ends
135143
for (; i < size; i++) {
136-
if (notFoundFirst && cmp_eq(keys_ref[i], trueKth)) {
137-
notFoundFirst = false;
144+
if (!foundFirstKthKey && cmp_eq(keys_ref[i], trueKthKey)) {
145+
foundFirstKthKey = true;
138146
i_start = i;
139147
}
140-
else if (!notFoundFirst && !cmp_eq(keys_ref[i], trueKth)) {
148+
else if (foundFirstKthKey && !cmp_eq(keys_ref[i], trueKthKey)) {
141149
break;
142150
}
143151
}
144152

145-
if (notFoundFirst) return false;
153+
// kth key is somehow missing? Since we got that value from keys_ref, should be impossible
154+
if (!foundFirstKthKey) { return false; }
146155

156+
// Check that the values in the kth key block match, so they are equivalent
157+
// up to permutation, which is allowed since the sort is not stable
147158
if (!same_values(vals_ref + i_start, vals_comp + i_start, i - i_start)) {
148159
return false;
149160
}
@@ -156,7 +167,7 @@ TYPED_TEST_P(simdkvsort, test_kvsort_ascending)
156167
using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type;
157168
using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type;
158169
for (auto type : this->arrtype) {
159-
bool hasnan = (type == "rand_with_nan") ? true : false;
170+
bool hasnan = is_nan_test(type);
160171
for (auto size : this->arrsize) {
161172
std::vector<T1> key = get_array<T1>(type, size);
162173
std::vector<T2> val = get_array<T2>(type, size);
@@ -187,7 +198,7 @@ TYPED_TEST_P(simdkvsort, test_kvsort_descending)
187198
using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type;
188199
using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type;
189200
for (auto type : this->arrtype) {
190-
bool hasnan = (type == "rand_with_nan") ? true : false;
201+
bool hasnan = is_nan_test(type);
191202
for (auto size : this->arrsize) {
192203
std::vector<T1> key = get_array<T1>(type, size);
193204
std::vector<T2> val = get_array<T2>(type, size);
@@ -217,8 +228,9 @@ TYPED_TEST_P(simdkvsort, test_kvselect_ascending)
217228
{
218229
using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type;
219230
using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type;
231+
auto cmp_eq = compare<T1, std::equal_to<T1>>();
220232
for (auto type : this->arrtype) {
221-
bool hasnan = (type == "rand_with_nan") ? true : false;
233+
bool hasnan = is_nan_test(type);
222234
for (auto size : this->arrsize) {
223235
size_t k = rand() % size;
224236

@@ -237,7 +249,7 @@ TYPED_TEST_P(simdkvsort, test_kvselect_ascending)
237249
xss::scalar::keyvalue_qsort(
238250
key.data(), val.data(), k, hasnan, false);
239251

240-
ASSERT_EQ(key[k], key_bckp[k]);
252+
ASSERT_EQ(cmp_eq(key[k], key_bckp[k]), true);
241253

242254
bool is_kv_partialsorted_
243255
= is_kv_partialsorted<T1, T2>(key.data(),
@@ -260,8 +272,9 @@ TYPED_TEST_P(simdkvsort, test_kvselect_descending)
260272
{
261273
using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type;
262274
using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type;
275+
auto cmp_eq = compare<T1, std::equal_to<T1>>();
263276
for (auto type : this->arrtype) {
264-
bool hasnan = (type == "rand_with_nan") ? true : false;
277+
bool hasnan = is_nan_test(type);
265278
for (auto size : this->arrsize) {
266279
size_t k = rand() % size;
267280

@@ -280,7 +293,7 @@ TYPED_TEST_P(simdkvsort, test_kvselect_descending)
280293
xss::scalar::keyvalue_qsort(
281294
key.data(), val.data(), k, hasnan, true);
282295

283-
ASSERT_EQ(key[k], key_bckp[k]);
296+
ASSERT_EQ(cmp_eq(key[k], key_bckp[k]), true);
284297

285298
bool is_kv_partialsorted_
286299
= is_kv_partialsorted<T1, T2>(key.data(),
@@ -304,7 +317,7 @@ TYPED_TEST_P(simdkvsort, test_kvpartial_sort_ascending)
304317
using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type;
305318
using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type;
306319
for (auto type : this->arrtype) {
307-
bool hasnan = (type == "rand_with_nan") ? true : false;
320+
bool hasnan = is_nan_test(type);
308321
for (auto size : this->arrsize) {
309322
size_t k = rand() % size;
310323

@@ -341,7 +354,7 @@ TYPED_TEST_P(simdkvsort, test_kvpartial_sort_descending)
341354
using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type;
342355
using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type;
343356
for (auto type : this->arrtype) {
344-
bool hasnan = (type == "rand_with_nan") ? true : false;
357+
bool hasnan = is_nan_test(type);
345358
for (auto size : this->arrsize) {
346359
size_t k = rand() % size;
347360

tests/test-qsort-common.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@
2020
ASSERT_TRUE(false) << msg << ". arr size = " << size \
2121
<< ", type = " << type << ", k = " << k;
2222

23+
// Returns true when the named test array type is treated as containing NaNs.
// Currently this is determined just by checking whether "nan" appears anywhere
// in the name (e.g. "rand_with_nan", "rand_with_max_and_nan").
// The name is taken by const reference since it is only read; this avoids
// copying the string on every call.
inline bool is_nan_test(const std::string &type)
{
    return type.find("nan") != std::string::npos;
}
28+
2329
template <typename T>
2430
void IS_SORTED(std::vector<T> sorted, std::vector<T> arr, std::string type)
2531
{

tests/test-qsort.cpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ class simdsort : public ::testing::Test {
1919
"max_at_the_end",
2020
"random_5d",
2121
"rand_max",
22-
"rand_with_nan"};
22+
"rand_with_nan",
23+
"rand_with_max_and_nan"};
2324
}
2425
std::vector<std::string> arrtype;
2526
std::vector<size_t> arrsize = std::vector<size_t>(1024);
@@ -30,7 +31,7 @@ TYPED_TEST_SUITE_P(simdsort);
3031
TYPED_TEST_P(simdsort, test_qsort_ascending)
3132
{
3233
for (auto type : this->arrtype) {
33-
bool hasnan = (type == "rand_with_nan") ? true : false;
34+
bool hasnan = is_nan_test(type);
3435
for (auto size : this->arrsize) {
3536
std::vector<TypeParam> basearr = get_array<TypeParam>(type, size);
3637

@@ -52,7 +53,7 @@ TYPED_TEST_P(simdsort, test_qsort_ascending)
5253
TYPED_TEST_P(simdsort, test_qsort_descending)
5354
{
5455
for (auto type : this->arrtype) {
55-
bool hasnan = (type == "rand_with_nan") ? true : false;
56+
bool hasnan = is_nan_test(type);
5657
for (auto size : this->arrsize) {
5758
std::vector<TypeParam> basearr = get_array<TypeParam>(type, size);
5859

@@ -74,7 +75,7 @@ TYPED_TEST_P(simdsort, test_qsort_descending)
7475
TYPED_TEST_P(simdsort, test_argsort_ascending)
7576
{
7677
for (auto type : this->arrtype) {
77-
bool hasnan = (type == "rand_with_nan") ? true : false;
78+
bool hasnan = is_nan_test(type);
7879
for (auto size : this->arrsize) {
7980
std::vector<TypeParam> arr = get_array<TypeParam>(type, size);
8081
std::vector<TypeParam> sortedarr = arr;
@@ -92,7 +93,7 @@ TYPED_TEST_P(simdsort, test_argsort_ascending)
9293
TYPED_TEST_P(simdsort, test_argsort_descending)
9394
{
9495
for (auto type : this->arrtype) {
95-
bool hasnan = (type == "rand_with_nan") ? true : false;
96+
bool hasnan = is_nan_test(type);
9697
for (auto size : this->arrsize) {
9798
std::vector<TypeParam> arr = get_array<TypeParam>(type, size);
9899
std::vector<TypeParam> sortedarr = arr;
@@ -111,7 +112,7 @@ TYPED_TEST_P(simdsort, test_argsort_descending)
111112
TYPED_TEST_P(simdsort, test_qselect_ascending)
112113
{
113114
for (auto type : this->arrtype) {
114-
bool hasnan = (type == "rand_with_nan") ? true : false;
115+
bool hasnan = is_nan_test(type);
115116
for (auto size : this->arrsize) {
116117
size_t k = rand() % size;
117118
std::vector<TypeParam> basearr = get_array<TypeParam>(type, size);
@@ -135,7 +136,7 @@ TYPED_TEST_P(simdsort, test_qselect_ascending)
135136
TYPED_TEST_P(simdsort, test_qselect_descending)
136137
{
137138
for (auto type : this->arrtype) {
138-
bool hasnan = (type == "rand_with_nan") ? true : false;
139+
bool hasnan = is_nan_test(type);
139140
for (auto size : this->arrsize) {
140141
size_t k = rand() % size;
141142
std::vector<TypeParam> basearr = get_array<TypeParam>(type, size);
@@ -159,7 +160,7 @@ TYPED_TEST_P(simdsort, test_qselect_descending)
159160
TYPED_TEST_P(simdsort, test_argselect)
160161
{
161162
for (auto type : this->arrtype) {
162-
bool hasnan = (type == "rand_with_nan") ? true : false;
163+
bool hasnan = is_nan_test(type);
163164
for (auto size : this->arrsize) {
164165
size_t k = rand() % size;
165166
std::vector<TypeParam> arr = get_array<TypeParam>(type, size);
@@ -179,7 +180,7 @@ TYPED_TEST_P(simdsort, test_argselect)
179180
TYPED_TEST_P(simdsort, test_partial_qsort_ascending)
180181
{
181182
for (auto type : this->arrtype) {
182-
bool hasnan = (type == "rand_with_nan") ? true : false;
183+
bool hasnan = is_nan_test(type);
183184
for (auto size : this->arrsize) {
184185
size_t k = rand() % size;
185186
std::vector<TypeParam> basearr = get_array<TypeParam>(type, size);
@@ -202,7 +203,7 @@ TYPED_TEST_P(simdsort, test_partial_qsort_ascending)
202203
TYPED_TEST_P(simdsort, test_partial_qsort_descending)
203204
{
204205
for (auto type : this->arrtype) {
205-
bool hasnan = (type == "rand_with_nan") ? true : false;
206+
bool hasnan = is_nan_test(type);
206207
for (auto size : this->arrsize) {
207208
// k should be at least 1
208209
size_t k = std::max((size_t)1, rand() % size);

utils/rand_array.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,26 @@ static std::vector<T> get_array(std::string arrtype,
140140
if (rand() & 0x1) { arr[ii] = val; }
141141
}
142142
}
143+
else if (arrtype == "rand_with_max_and_nan") {
144+
arr = get_uniform_rand_array<T>(arrsize, max, min);
145+
T max_val;
146+
T nan_val;
147+
if constexpr (xss::fp::is_floating_point_v<T>) {
148+
max_val = xss::fp::infinity<T>();
149+
nan_val = xss::fp::quiet_NaN<T>();
150+
}
151+
else {
152+
max_val = std::numeric_limits<T>::max();
153+
nan_val = std::numeric_limits<T>::max();
154+
}
155+
for (size_t ii = 0; ii < arrsize; ++ii) {
156+
int res = rand() % 4;
157+
if (res == 2) { arr[ii] = max_val; }
158+
else if (res == 3) {
159+
arr[ii] = nan_val;
160+
}
161+
}
162+
}
143163
else {
144164
std::cout << "Warning: unrecognized array type " << arrtype
145165
<< std::endl;

0 commit comments

Comments
 (0)