Skip to content

Commit 9c1b20f

Browse files
committed
Fixed testing logic
1 parent 93b9c99 commit 9c1b20f

File tree

2 files changed

+89
-28
lines changed

2 files changed

+89
-28
lines changed

lib/x86simdsort-internal.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ namespace avx512 {
2323
bool descending = false);
2424
// key-value select
2525
template <typename T1, typename T2>
26-
XSS_EXPORT_SYMBOL void
26+
XSS_HIDE_SYMBOL void
2727
keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false);
2828
// partial sort
2929
template <typename T>
@@ -34,7 +34,7 @@ namespace avx512 {
3434
bool descending = false);
3535
// key-value partial sort
3636
template <typename T1, typename T2>
37-
XSS_EXPORT_SYMBOL void
37+
XSS_HIDE_SYMBOL void
3838
keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false);
3939
// argsort
4040
template <typename T>
@@ -55,7 +55,7 @@ namespace avx2 {
5555
// key-value quicksort
5656
template <typename T1, typename T2>
5757
XSS_HIDE_SYMBOL void
58-
keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false);
58+
keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false, bool descending = false);
5959
// quickselect
6060
template <typename T>
6161
XSS_HIDE_SYMBOL void qselect(T *arr,
@@ -65,7 +65,7 @@ namespace avx2 {
6565
bool descending = false);
6666
// key-value select
6767
template <typename T1, typename T2>
68-
XSS_EXPORT_SYMBOL void
68+
XSS_HIDE_SYMBOL void
6969
keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false);
7070
// partial sort
7171
template <typename T>
@@ -76,7 +76,7 @@ namespace avx2 {
7676
bool descending = false);
7777
// key-value partial sort
7878
template <typename T1, typename T2>
79-
XSS_EXPORT_SYMBOL void
79+
XSS_HIDE_SYMBOL void
8080
keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false);
8181
// argsort
8282
template <typename T>
@@ -107,7 +107,7 @@ namespace scalar {
107107
bool descending = false);
108108
// key-value select
109109
template <typename T1, typename T2>
110-
XSS_EXPORT_SYMBOL void
110+
XSS_HIDE_SYMBOL void
111111
keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false);
112112
// partial sort
113113
template <typename T>
@@ -118,7 +118,7 @@ namespace scalar {
118118
bool descending = false);
119119
// key-value partial sort
120120
template <typename T1, typename T2>
121-
XSS_EXPORT_SYMBOL void
121+
XSS_HIDE_SYMBOL void
122122
keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false);
123123
// argsort
124124
template <typename T>

tests/test-keyvalue.cpp

Lines changed: 82 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,9 @@ TYPED_TEST_SUITE_P(simdkvsort);
3232

3333
template <typename T>
3434
bool same_values(T* v1, T* v2, size_t size){
35-
// Checks that the values are the same except (maybe) their ordering
35+
// Checks that the values are the same except ordering
3636
auto cmp_eq = compare<T, std::equal_to<T>>();
3737

38-
// TODO hardcoding hasnan to true doesn't break anything right?
3938
x86simdsort::qsort(v1, size, true);
4039
x86simdsort::qsort(v2, size, true);
4140

@@ -49,7 +48,7 @@ bool same_values(T* v1, T* v2, size_t size){
4948
}
5049

5150
template <typename T1, typename T2>
52-
bool kv_equivalent(T1* keys_comp, T2* vals_comp, T1* keys_ref, T2* vals_ref, size_t size){
51+
bool is_kv_sorted(T1* keys_comp, T2* vals_comp, T1* keys_ref, T2* vals_ref, size_t size){
5352
auto cmp_eq = compare<T1, std::equal_to<T1>>();
5453

5554
// First check keys are exactly identical
@@ -66,7 +65,7 @@ bool kv_equivalent(T1* keys_comp, T2* vals_comp, T1* keys_ref, T2* vals_ref, siz
6665
size_t i = 0;
6766
for (; i < size; i++){
6867
if (!cmp_eq(keys_comp[i], key_start)){
69-
// Check that every value in
68+
// Check that every value in this block of constant keys
7069

7170
if (!same_values(vals_ref + i_start, vals_comp + i_start, i - i_start)){
7271
return false;
@@ -78,6 +77,66 @@ bool kv_equivalent(T1* keys_comp, T2* vals_comp, T1* keys_ref, T2* vals_ref, siz
7877
}
7978
}
8079

80+
// Handle the last group
81+
if (!same_values(vals_ref + i_start, vals_comp + i_start, i - i_start)){
82+
return false;
83+
}
84+
85+
return true;
86+
}
87+
88+
template <typename T1, typename T2>
89+
bool is_kv_partialsorted(T1* keys_comp, T2* vals_comp, T1* keys_ref, T2* vals_ref, size_t size, size_t k){
90+
auto cmp_eq = compare<T1, std::equal_to<T1>>();
91+
92+
// First check keys are exactly identical (up to k)
93+
for (size_t i = 0; i < k; i++){
94+
if (!cmp_eq(keys_comp[i], keys_ref[i])){
95+
return false;
96+
}
97+
}
98+
99+
size_t i_start = 0;
100+
T1 key_start = keys_comp[0];
101+
// Loop through all identical keys in a block, then compare the sets of values to make sure they are identical
102+
for (size_t i = 0; i < k; i++){
103+
if (!cmp_eq(keys_comp[i], key_start)){
104+
// Check that every value in this block of constant keys
105+
106+
if (!same_values(vals_ref + i_start, vals_comp + i_start, i - i_start)){
107+
return false;
108+
}
109+
110+
// Now setup the start variables to begin gathering keys for the next group
111+
i_start = i;
112+
key_start = keys_comp[i];
113+
}
114+
}
115+
116+
// Now, we need to do some more work to handle keys exactly equal to the true kth
117+
// First, fully kvsort both arrays
118+
xss::scalar::keyvalue_qsort<T1, T2>(keys_ref, vals_ref, size, true, false);
119+
xss::scalar::keyvalue_qsort<T1, T2>(keys_comp, vals_comp, size, true, false);
120+
121+
auto trueKth = keys_ref[k];
122+
bool notFoundFirst = true;
123+
size_t i = 0;
124+
125+
for (; i < size; i++){
126+
if (notFoundFirst && cmp_eq(keys_ref[i], trueKth)){
127+
notFoundFirst = false;
128+
i_start = i;
129+
}else if (!notFoundFirst && !cmp_eq(keys_ref[i], trueKth)){
130+
break;
131+
}
132+
}
133+
134+
if (notFoundFirst) return false;
135+
136+
if (!same_values(vals_ref + i_start, vals_comp + i_start, i - i_start)){
137+
return false;
138+
}
139+
81140
return true;
82141
}
83142

@@ -96,8 +155,8 @@ TYPED_TEST_P(simdkvsort, test_kvsort_ascending)
96155
xss::scalar::keyvalue_qsort(
97156
key_bckp.data(), val_bckp.data(), size, hasnan, false);
98157

99-
bool is_kv_equivalent = kv_equivalent<T1, T2>(key.data(), val.data(), key_bckp.data(), val_bckp.data(), size);
100-
ASSERT_EQ(is_kv_equivalent, true);
158+
bool is_kv_sorted_ = is_kv_sorted<T1, T2>(key.data(), val.data(), key_bckp.data(), val_bckp.data(), size);
159+
ASSERT_EQ(is_kv_sorted_, true);
101160

102161
key.clear();
103162
val.clear();
@@ -122,8 +181,8 @@ TYPED_TEST_P(simdkvsort, test_kvsort_descending)
122181
xss::scalar::keyvalue_qsort(
123182
key_bckp.data(), val_bckp.data(), size, hasnan, true);
124183

125-
bool is_kv_equivalent = kv_equivalent<T1, T2>(key.data(), val.data(), key_bckp.data(), val_bckp.data(), size);
126-
ASSERT_EQ(is_kv_equivalent, true);
184+
bool is_kv_sorted_ = is_kv_sorted<T1, T2>(key.data(), val.data(), key_bckp.data(), val_bckp.data(), size);
185+
ASSERT_EQ(is_kv_sorted_, true);
127186

128187
key.clear();
129188
val.clear();
@@ -155,9 +214,10 @@ TYPED_TEST_P(simdkvsort, test_kvselect_ascending)
155214
IS_ARR_PARTITIONED<T1>(key, k, key_bckp[k], type);
156215
xss::scalar::keyvalue_qsort(key.data(), val.data(), k, hasnan, false);
157216

217+
ASSERT_EQ(key[k], key_bckp[k]);
158218

159-
bool is_kv_equivalent = kv_equivalent<T1, T2>(key.data(), val.data(), key_bckp.data(), val_bckp.data(), k);
160-
ASSERT_EQ(is_kv_equivalent, true);
219+
bool is_kv_partialsorted_ = is_kv_partialsorted<T1, T2>(key.data(), val.data(), key_bckp.data(), val_bckp.data(), size, k);
220+
ASSERT_EQ(is_kv_partialsorted_, true);
161221

162222
key.clear();
163223
val.clear();
@@ -189,9 +249,10 @@ TYPED_TEST_P(simdkvsort, test_kvselect_descending)
189249
IS_ARR_PARTITIONED<T1>(key, k, key_bckp[k], type, true);
190250
xss::scalar::keyvalue_qsort(key.data(), val.data(), k, hasnan, true);
191251

252+
ASSERT_EQ(key[k], key_bckp[k]);
192253

193-
bool is_kv_equivalent = kv_equivalent<T1, T2>(key.data(), val.data(), key_bckp.data(), val_bckp.data(), k);
194-
ASSERT_EQ(is_kv_equivalent, true);
254+
bool is_kv_partialsorted_ = is_kv_partialsorted<T1, T2>(key.data(), val.data(), key_bckp.data(), val_bckp.data(), size, k);
255+
ASSERT_EQ(is_kv_partialsorted_, true);
195256

196257
key.clear();
197258
val.clear();
@@ -220,8 +281,8 @@ TYPED_TEST_P(simdkvsort, test_kvpartial_sort_ascending)
220281

221282
IS_ARR_PARTIALSORTED<T1>(key, k, key_bckp, type);
222283

223-
bool is_kv_equivalent = kv_equivalent<T1, T2>(key.data(), val.data(), key_bckp.data(), val_bckp.data(), k);
224-
ASSERT_EQ(is_kv_equivalent, true);
284+
bool is_kv_partialsorted_ = is_kv_partialsorted<T1, T2>(key.data(), val.data(), key_bckp.data(), val_bckp.data(), size, k);
285+
ASSERT_EQ(is_kv_partialsorted_, true);
225286

226287
key.clear();
227288
val.clear();
@@ -250,8 +311,8 @@ TYPED_TEST_P(simdkvsort, test_kvpartial_sort_descending)
250311

251312
IS_ARR_PARTIALSORTED<T1>(key, k, key_bckp, type);
252313

253-
bool is_kv_equivalent = kv_equivalent<T1, T2>(key.data(), val.data(), key_bckp.data(), val_bckp.data(), k);
254-
ASSERT_EQ(is_kv_equivalent, true);
314+
bool is_kv_partialsorted_ = is_kv_partialsorted<T1, T2>(key.data(), val.data(), key_bckp.data(), val_bckp.data(), size, k);
315+
ASSERT_EQ(is_kv_partialsorted_, true);
255316

256317
key.clear();
257318
val.clear();
@@ -275,26 +336,26 @@ TYPED_TEST_P(simdkvsort, test_validator)
275336
std::vector<T2> val_bckp = val;
276337

277338
// Duplicate keys, but otherwise exactly identical
278-
is_kv_equivalent = kv_equivalent<T1, T2>(key.data(), val.data(), key_bckp.data(), val_bckp.data(), key.size());
339+
is_kv_equivalent = is_kv_sorted<T1, T2>(key.data(), val.data(), key_bckp.data(), val_bckp.data(), key.size());
279340
ASSERT_EQ(is_kv_equivalent, true);
280341

281342
val = {2,1,4,3};
282343

283344
// Now values are backwards, but this is still fine
284-
is_kv_equivalent = kv_equivalent<T1, T2>(key.data(), val.data(), key_bckp.data(), val_bckp.data(), key.size());
345+
is_kv_equivalent = is_kv_sorted<T1, T2>(key.data(), val.data(), key_bckp.data(), val_bckp.data(), key.size());
285346
ASSERT_EQ(is_kv_equivalent, true);
286347

287348
val = {1,3,2,4};
288349

289350
// Now values are mixed up, should fail
290-
is_kv_equivalent = kv_equivalent<T1, T2>(key.data(), val.data(), key_bckp.data(), val_bckp.data(), key.size());
351+
is_kv_equivalent = is_kv_sorted<T1, T2>(key.data(), val.data(), key_bckp.data(), val_bckp.data(), key.size());
291352
ASSERT_EQ(is_kv_equivalent, false);
292353

293354
val = {1,2,3,4};
294355
key = {0,0,0,0};
295356

296357
// Now keys are messed up, should fail
297-
is_kv_equivalent = kv_equivalent<T1, T2>(key.data(), val.data(), key_bckp.data(), val_bckp.data(), key.size());
358+
is_kv_equivalent = is_kv_sorted<T1, T2>(key.data(), val.data(), key_bckp.data(), val_bckp.data(), key.size());
298359
ASSERT_EQ(is_kv_equivalent, false);
299360

300361
key = {0,0,0,0,0,0};
@@ -303,7 +364,7 @@ TYPED_TEST_P(simdkvsort, test_validator)
303364
val = {4,3,1,6,5,2};
304365

305366
// All keys identical, simply reordered values
306-
is_kv_equivalent = kv_equivalent<T1, T2>(key.data(), val.data(), key_bckp.data(), val_bckp.data(), key.size());
367+
is_kv_equivalent = is_kv_sorted<T1, T2>(key.data(), val.data(), key_bckp.data(), val_bckp.data(), key.size());
307368
ASSERT_EQ(is_kv_equivalent, true);
308369
}
309370

0 commit comments

Comments
 (0)