Skip to content

Commit 93b9c99

Browse files
committed
Adds descending sort for kvsort, kvselect, and kvpartial_sort and related tests
1 parent c2fa38c commit 93b9c99

11 files changed

+183
-76
lines changed

benchmarks/bench-keyvalue.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ static void scalarkvsort(benchmark::State &state, Args &&...args)
1313
std::vector<T> key_bkp = key;
1414
// benchmark
1515
for (auto _ : state) {
16-
xss::scalar::keyvalue_qsort(key.data(), val.data(), arrsize, false);
16+
xss::scalar::keyvalue_qsort(key.data(), val.data(), arrsize, false, false);
1717
state.PauseTiming();
1818
key = key_bkp;
1919
state.ResumeTiming();

lib/x86simdsort-avx2.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,19 +36,19 @@
3636

3737
#define DEFINE_KEYVALUE_METHODS_BASE(type1, type2) \
3838
template <> \
39-
void keyvalue_qsort(type1 *key, type2 *val, size_t arrsize, bool hasnan) \
39+
void keyvalue_qsort(type1 *key, type2 *val, size_t arrsize, bool hasnan, bool descending) \
4040
{ \
41-
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
41+
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan, descending); \
4242
} \
4343
template <> \
44-
void keyvalue_select(type1 *key, type2 *val, size_t k, size_t arrsize, bool hasnan) \
44+
void keyvalue_select(type1 *key, type2 *val, size_t k, size_t arrsize, bool hasnan, bool descending) \
4545
{ \
46-
x86simdsortStatic::keyvalue_select(key, val, k, arrsize, hasnan); \
46+
x86simdsortStatic::keyvalue_select(key, val, k, arrsize, hasnan, descending); \
4747
} \
4848
template <> \
49-
void keyvalue_partial_sort(type1 *key, type2 *val, size_t k, size_t arrsize, bool hasnan) \
49+
void keyvalue_partial_sort(type1 *key, type2 *val, size_t k, size_t arrsize, bool hasnan, bool descending) \
5050
{ \
51-
x86simdsortStatic::keyvalue_partial_sort(key, val, k, arrsize, hasnan); \
51+
x86simdsortStatic::keyvalue_partial_sort(key, val, k, arrsize, hasnan, descending); \
5252
}
5353

5454
#define DEFINE_KEYVALUE_METHODS(type) \

lib/x86simdsort-internal.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ namespace avx512 {
1313
// key-value quicksort
1414
template <typename T1, typename T2>
1515
XSS_HIDE_SYMBOL void
16-
keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false);
16+
keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false, bool descending = false);
1717
// quickselect
1818
template <typename T>
1919
XSS_HIDE_SYMBOL void qselect(T *arr,
@@ -24,7 +24,7 @@ namespace avx512 {
2424
// key-value select
2525
template <typename T1, typename T2>
2626
XSS_EXPORT_SYMBOL void
27-
keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false);
27+
keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false);
2828
// partial sort
2929
template <typename T>
3030
XSS_HIDE_SYMBOL void partial_qsort(T *arr,
@@ -35,7 +35,7 @@ namespace avx512 {
3535
// key-value partial sort
3636
template <typename T1, typename T2>
3737
XSS_EXPORT_SYMBOL void
38-
keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false);
38+
keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false);
3939
// argsort
4040
template <typename T>
4141
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr,
@@ -66,7 +66,7 @@ namespace avx2 {
6666
// key-value select
6767
template <typename T1, typename T2>
6868
XSS_EXPORT_SYMBOL void
69-
keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false);
69+
keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false);
7070
// partial sort
7171
template <typename T>
7272
XSS_HIDE_SYMBOL void partial_qsort(T *arr,
@@ -77,7 +77,7 @@ namespace avx2 {
7777
// key-value partial sort
7878
template <typename T1, typename T2>
7979
XSS_EXPORT_SYMBOL void
80-
keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false);
80+
keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false);
8181
// argsort
8282
template <typename T>
8383
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr,
@@ -97,7 +97,7 @@ namespace scalar {
9797
// key-value quicksort
9898
template <typename T1, typename T2>
9999
XSS_HIDE_SYMBOL void
100-
keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false);
100+
keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false, bool descending = false);
101101
// quickselect
102102
template <typename T>
103103
XSS_HIDE_SYMBOL void qselect(T *arr,
@@ -108,7 +108,7 @@ namespace scalar {
108108
// key-value select
109109
template <typename T1, typename T2>
110110
XSS_EXPORT_SYMBOL void
111-
keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false);
111+
keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false);
112112
// partial sort
113113
template <typename T>
114114
XSS_HIDE_SYMBOL void partial_qsort(T *arr,
@@ -119,7 +119,7 @@ namespace scalar {
119119
// key-value partial sort
120120
template <typename T1, typename T2>
121121
XSS_EXPORT_SYMBOL void
122-
keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false);
122+
keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false);
123123
// argsort
124124
template <typename T>
125125
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr,

lib/x86simdsort-scalar.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,28 +100,28 @@ namespace scalar {
100100
return arg;
101101
}
102102
template <typename T1, typename T2>
103-
void keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan)
103+
void keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan, bool descending)
104104
{
105-
std::vector<size_t> arg = argsort(key, arrsize, hasnan, false);
105+
std::vector<size_t> arg = argsort(key, arrsize, hasnan, descending);
106106
utils::apply_permutation_in_place(key, arg);
107107
utils::apply_permutation_in_place(val, arg);
108108
}
109109
template <typename T1, typename T2>
110-
void keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan)
110+
void keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan, bool descending)
111111
{
112112
if (k == 0) return;
113113
// Note that this does a full partial sort, not just a select
114-
std::vector<size_t> arg = argsort(key, arrsize, hasnan, false);
114+
std::vector<size_t> arg = argsort(key, arrsize, hasnan, descending);
115115
//arg.resize(k);
116116

117117
utils::apply_permutation_in_place(key, arg);
118118
utils::apply_permutation_in_place(val, arg);
119119
}
120120
template <typename T1, typename T2>
121-
void keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan)
121+
void keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan, bool descending)
122122
{
123123
if (k == 0) return;
124-
std::vector<size_t> arg = argsort(key, arrsize, hasnan, false);
124+
std::vector<size_t> arg = argsort(key, arrsize, hasnan, descending);
125125
//arg.resize(k);
126126

127127
utils::apply_permutation_in_place(key, arg);

lib/x86simdsort-skx.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,19 +36,19 @@
3636

3737
#define DEFINE_KEYVALUE_METHODS_BASE(type1, type2) \
3838
template <> \
39-
void keyvalue_qsort(type1 *key, type2 *val, size_t arrsize, bool hasnan) \
39+
void keyvalue_qsort(type1 *key, type2 *val, size_t arrsize, bool hasnan, bool descending) \
4040
{ \
41-
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
41+
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan, descending); \
4242
} \
4343
template <> \
44-
void keyvalue_select(type1 *key, type2 *val, size_t k, size_t arrsize, bool hasnan) \
44+
void keyvalue_select(type1 *key, type2 *val, size_t k, size_t arrsize, bool hasnan, bool descending) \
4545
{ \
46-
x86simdsortStatic::keyvalue_select(key, val, k, arrsize, hasnan); \
46+
x86simdsortStatic::keyvalue_select(key, val, k, arrsize, hasnan, descending); \
4747
} \
4848
template <> \
49-
void keyvalue_partial_sort(type1 *key, type2 *val, size_t k, size_t arrsize, bool hasnan) \
49+
void keyvalue_partial_sort(type1 *key, type2 *val, size_t k, size_t arrsize, bool hasnan, bool descending) \
5050
{ \
51-
x86simdsortStatic::keyvalue_partial_sort(key, val, k, arrsize, hasnan); \
51+
x86simdsortStatic::keyvalue_partial_sort(key, val, k, arrsize, hasnan, descending); \
5252
}
5353

5454
#define DEFINE_KEYVALUE_METHODS(type) \

lib/x86simdsort.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -131,13 +131,13 @@ namespace x86simdsort {
131131

132132
#define DISPATCH_KEYVALUE_SORT(TYPE1, TYPE2, ISA) \
133133
static void(CAT(CAT(*internal_kv_qsort_, TYPE1), TYPE2))( \
134-
TYPE1 *, TYPE2 *, size_t, bool) \
134+
TYPE1 *, TYPE2 *, size_t, bool, bool) \
135135
= NULL; \
136136
template <> \
137-
void keyvalue_qsort(TYPE1 *key, TYPE2 *val, size_t arrsize, bool hasnan) \
137+
void keyvalue_qsort(TYPE1 *key, TYPE2 *val, size_t arrsize, bool hasnan, bool descending) \
138138
{ \
139139
(CAT(CAT(*internal_kv_qsort_, TYPE1), TYPE2))( \
140-
key, val, arrsize, hasnan); \
140+
key, val, arrsize, hasnan, descending); \
141141
} \
142142
static __attribute__((constructor)) void CAT( \
143143
CAT(resolve_keyvalue_qsort_, TYPE1), TYPE2)(void) \
@@ -162,13 +162,13 @@ namespace x86simdsort {
162162
} \
163163
}\
164164
static void(CAT(CAT(*internal_kv_select_, TYPE1), TYPE2))( \
165-
TYPE1 *, TYPE2 *, size_t, size_t, bool) \
165+
TYPE1 *, TYPE2 *, size_t, size_t, bool, bool) \
166166
= NULL; \
167167
template <> \
168-
void keyvalue_select(TYPE1 *key, TYPE2 *val, size_t k, size_t arrsize, bool hasnan) \
168+
void keyvalue_select(TYPE1 *key, TYPE2 *val, size_t k, size_t arrsize, bool hasnan, bool descending) \
169169
{ \
170170
(CAT(CAT(*internal_kv_select_, TYPE1), TYPE2))( \
171-
key, val, k, arrsize, hasnan); \
171+
key, val, k, arrsize, hasnan, descending); \
172172
} \
173173
static __attribute__((constructor)) void CAT( \
174174
CAT(resolve_keyvalue_select_, TYPE1), TYPE2)(void) \
@@ -193,13 +193,13 @@ namespace x86simdsort {
193193
} \
194194
} \
195195
static void(CAT(CAT(*internal_kv_partial_sort_, TYPE1), TYPE2))( \
196-
TYPE1 *, TYPE2 *, size_t, size_t, bool) \
196+
TYPE1 *, TYPE2 *, size_t, size_t, bool, bool) \
197197
= NULL; \
198198
template <> \
199-
void keyvalue_partial_sort(TYPE1 *key, TYPE2 *val, size_t k, size_t arrsize, bool hasnan) \
199+
void keyvalue_partial_sort(TYPE1 *key, TYPE2 *val, size_t k, size_t arrsize, bool hasnan, bool descending) \
200200
{ \
201201
(CAT(CAT(*internal_kv_partial_sort_, TYPE1), TYPE2))( \
202-
key, val, k, arrsize, hasnan); \
202+
key, val, k, arrsize, hasnan, descending); \
203203
} \
204204
static __attribute__((constructor)) void CAT( \
205205
CAT(resolve_keyvalue_partial_sort_, TYPE1), TYPE2)(void) \

lib/x86simdsort.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,17 +46,17 @@ argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false);
4646
// keyvalue sort
4747
template <typename T1, typename T2>
4848
XSS_EXPORT_SYMBOL void
49-
keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false);
49+
keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false, bool descending = false);
5050

5151
// keyvalue select
5252
template <typename T1, typename T2>
5353
XSS_EXPORT_SYMBOL void
54-
keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false);
54+
keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false);
5555

5656
// keyvalue partial sort
5757
template <typename T1, typename T2>
5858
XSS_EXPORT_SYMBOL void
59-
keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false);
59+
keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false);
6060

6161
// sort an object
6262
template <typename T, typename Func>

src/x86simdsort-static-incl.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,15 @@ argselect(T *arr, size_t *arg, size_t k, size_t size, bool hasnan = false);
4747

4848
template <typename T1, typename T2>
4949
X86_SIMD_SORT_FINLINE void
50-
keyvalue_qsort(T1 *key, T2 *val, size_t size, bool hasnan = false);
50+
keyvalue_qsort(T1 *key, T2 *val, size_t size, bool hasnan = false, bool descending = false);
5151

5252
template <typename T1, typename T2>
5353
X86_SIMD_SORT_FINLINE void
54-
keyvalue_select(T1 *key, T2 *val, size_t k, size_t size, bool hasnan = false);
54+
keyvalue_select(T1 *key, T2 *val, size_t k, size_t size, bool hasnan = false, bool descending = false);
5555

5656
template <typename T1, typename T2>
5757
X86_SIMD_SORT_FINLINE void
58-
keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t size, bool hasnan = false);
58+
keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t size, bool hasnan = false, bool descending = false);
5959

6060
} // namespace x86simdsortStatic
6161

@@ -111,21 +111,21 @@ keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t size, bool hasnan = fal
111111
} \
112112
template <typename T1, typename T2> \
113113
X86_SIMD_SORT_FINLINE void x86simdsortStatic::keyvalue_qsort( \
114-
T1 *key, T2 *val, size_t size, bool hasnan) \
114+
T1 *key, T2 *val, size_t size, bool hasnan, bool descending) \
115115
{ \
116-
ISA##_qsort_kv(key, val, size, hasnan); \
116+
ISA##_qsort_kv(key, val, size, hasnan, descending); \
117117
} \
118118
template <typename T1, typename T2> \
119119
X86_SIMD_SORT_FINLINE void x86simdsortStatic::keyvalue_select( \
120-
T1 *key, T2 *val, size_t k, size_t size, bool hasnan) \
120+
T1 *key, T2 *val, size_t k, size_t size, bool hasnan, bool descending) \
121121
{ \
122-
ISA##_select_kv(key, val, k, size, hasnan); \
122+
ISA##_select_kv(key, val, k, size, hasnan, descending); \
123123
} \
124124
template <typename T1, typename T2> \
125125
X86_SIMD_SORT_FINLINE void x86simdsortStatic::keyvalue_partial_sort( \
126-
T1 *key, T2 *val, size_t k, size_t size, bool hasnan) \
126+
T1 *key, T2 *val, size_t k, size_t size, bool hasnan, bool descending) \
127127
{ \
128-
ISA##_partial_sort_kv(key, val, k, size, hasnan); \
128+
ISA##_partial_sort_kv(key, val, k, size, hasnan, descending); \
129129
}
130130

131131
/*

0 commit comments

Comments
 (0)