@@ -26,7 +26,9 @@ class simdkvsort : public ::testing::Test {
26
26
" smallrange" ,
27
27
" max_at_the_end" ,
28
28
" random_5d" ,
29
- " rand_max" };
29
+ " rand_max" ,
30
+ " rand_with_nan" ,
31
+ " rand_with_max_and_nan" };
30
32
}
31
33
std::vector<std::string> arrtype;
32
34
std::vector<size_t > arrsize = std::vector<size_t >(1024 );
@@ -123,27 +125,36 @@ bool is_kv_partialsorted(T1 *keys_comp,
123
125
}
124
126
125
127
// Now, we need to do some more work to handle keys exactly equal to the true kth
128
+ // There may be more values after the kth element with the same key,
129
+ // and thus we can find that the values of the kth elements do not match,
130
+ // even though the sort is correct.
131
+
126
132
// First, fully kvsort both arrays
127
133
xss::scalar::keyvalue_qsort<T1, T2>(keys_ref, vals_ref, size, true , false );
128
134
xss::scalar::keyvalue_qsort<T1, T2>(
129
135
keys_comp, vals_comp, size, true , false );
130
136
131
- auto trueKth = keys_ref[k];
132
- bool notFoundFirst = true ;
137
+ auto trueKthKey = keys_ref[k];
138
+ bool foundFirstKthKey = false ;
133
139
size_t i = 0 ;
134
140
141
+ // Search forwards until we find the block of keys that match the kth key,
142
+ // then find where it ends
135
143
for (; i < size; i++) {
136
- if (notFoundFirst && cmp_eq (keys_ref[i], trueKth )) {
137
- notFoundFirst = false ;
144
+ if (!foundFirstKthKey && cmp_eq (keys_ref[i], trueKthKey )) {
145
+ foundFirstKthKey = true ;
138
146
i_start = i;
139
147
}
140
- else if (!notFoundFirst && !cmp_eq (keys_ref[i], trueKth )) {
148
+ else if (foundFirstKthKey && !cmp_eq (keys_ref[i], trueKthKey )) {
141
149
break ;
142
150
}
143
151
}
144
152
145
- if (notFoundFirst) return false ;
153
+ // kth key is somehow missing? Since we got that value from keys_ref, should be impossible
154
+ if (!foundFirstKthKey) { return false ; }
146
155
156
+ // Check that the values in the kth key block match, so they are equivalent
157
+ // up to permutation, which is allowed since the sort is not stable
147
158
if (!same_values (vals_ref + i_start, vals_comp + i_start, i - i_start)) {
148
159
return false ;
149
160
}
@@ -156,7 +167,7 @@ TYPED_TEST_P(simdkvsort, test_kvsort_ascending)
156
167
using T1 = typename std::tuple_element<0 , decltype (TypeParam ())>::type;
157
168
using T2 = typename std::tuple_element<1 , decltype (TypeParam ())>::type;
158
169
for (auto type : this ->arrtype ) {
159
- bool hasnan = (type == " rand_with_nan " ) ? true : false ;
170
+ bool hasnan = is_nan_test (type) ;
160
171
for (auto size : this ->arrsize ) {
161
172
std::vector<T1> key = get_array<T1>(type, size);
162
173
std::vector<T2> val = get_array<T2>(type, size);
@@ -187,7 +198,7 @@ TYPED_TEST_P(simdkvsort, test_kvsort_descending)
187
198
using T1 = typename std::tuple_element<0 , decltype (TypeParam ())>::type;
188
199
using T2 = typename std::tuple_element<1 , decltype (TypeParam ())>::type;
189
200
for (auto type : this ->arrtype ) {
190
- bool hasnan = (type == " rand_with_nan " ) ? true : false ;
201
+ bool hasnan = is_nan_test (type) ;
191
202
for (auto size : this ->arrsize ) {
192
203
std::vector<T1> key = get_array<T1>(type, size);
193
204
std::vector<T2> val = get_array<T2>(type, size);
@@ -217,8 +228,9 @@ TYPED_TEST_P(simdkvsort, test_kvselect_ascending)
217
228
{
218
229
using T1 = typename std::tuple_element<0 , decltype (TypeParam ())>::type;
219
230
using T2 = typename std::tuple_element<1 , decltype (TypeParam ())>::type;
231
+ auto cmp_eq = compare<T1, std::equal_to<T1>>();
220
232
for (auto type : this ->arrtype ) {
221
- bool hasnan = (type == " rand_with_nan " ) ? true : false ;
233
+ bool hasnan = is_nan_test (type) ;
222
234
for (auto size : this ->arrsize ) {
223
235
size_t k = rand () % size;
224
236
@@ -237,7 +249,7 @@ TYPED_TEST_P(simdkvsort, test_kvselect_ascending)
237
249
xss::scalar::keyvalue_qsort (
238
250
key.data (), val.data (), k, hasnan, false );
239
251
240
- ASSERT_EQ (key[k], key_bckp[k]);
252
+ ASSERT_EQ (cmp_eq ( key[k], key_bckp[k]), true );
241
253
242
254
bool is_kv_partialsorted_
243
255
= is_kv_partialsorted<T1, T2>(key.data (),
@@ -260,8 +272,9 @@ TYPED_TEST_P(simdkvsort, test_kvselect_descending)
260
272
{
261
273
using T1 = typename std::tuple_element<0 , decltype (TypeParam ())>::type;
262
274
using T2 = typename std::tuple_element<1 , decltype (TypeParam ())>::type;
275
+ auto cmp_eq = compare<T1, std::equal_to<T1>>();
263
276
for (auto type : this ->arrtype ) {
264
- bool hasnan = (type == " rand_with_nan " ) ? true : false ;
277
+ bool hasnan = is_nan_test (type) ;
265
278
for (auto size : this ->arrsize ) {
266
279
size_t k = rand () % size;
267
280
@@ -280,7 +293,7 @@ TYPED_TEST_P(simdkvsort, test_kvselect_descending)
280
293
xss::scalar::keyvalue_qsort (
281
294
key.data (), val.data (), k, hasnan, true );
282
295
283
- ASSERT_EQ (key[k], key_bckp[k]);
296
+ ASSERT_EQ (cmp_eq ( key[k], key_bckp[k]), true );
284
297
285
298
bool is_kv_partialsorted_
286
299
= is_kv_partialsorted<T1, T2>(key.data (),
@@ -304,7 +317,7 @@ TYPED_TEST_P(simdkvsort, test_kvpartial_sort_ascending)
304
317
using T1 = typename std::tuple_element<0 , decltype (TypeParam ())>::type;
305
318
using T2 = typename std::tuple_element<1 , decltype (TypeParam ())>::type;
306
319
for (auto type : this ->arrtype ) {
307
- bool hasnan = (type == " rand_with_nan " ) ? true : false ;
320
+ bool hasnan = is_nan_test (type) ;
308
321
for (auto size : this ->arrsize ) {
309
322
size_t k = rand () % size;
310
323
@@ -341,7 +354,7 @@ TYPED_TEST_P(simdkvsort, test_kvpartial_sort_descending)
341
354
using T1 = typename std::tuple_element<0 , decltype (TypeParam ())>::type;
342
355
using T2 = typename std::tuple_element<1 , decltype (TypeParam ())>::type;
343
356
for (auto type : this ->arrtype ) {
344
- bool hasnan = (type == " rand_with_nan " ) ? true : false ;
357
+ bool hasnan = is_nan_test (type) ;
345
358
for (auto size : this ->arrsize ) {
346
359
size_t k = rand () % size;
347
360
0 commit comments