Skip to content

Commit c2fa38c

Browse files
committed
Support for key-value select and partial sort
1 parent 6175892 commit c2fa38c

8 files changed

+290
-42
lines changed

lib/x86simdsort-avx2.cpp

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -34,37 +34,30 @@
3434
return x86simdsortStatic::argselect(arr, k, arrsize, hasnan); \
3535
}
3636

37-
#define DEFINE_KEYVALUE_METHODS(type) \
38-
template <> \
39-
void keyvalue_qsort(type *key, uint64_t *val, size_t arrsize, bool hasnan) \
40-
{ \
41-
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
42-
} \
43-
template <> \
44-
void keyvalue_qsort(type *key, int64_t *val, size_t arrsize, bool hasnan) \
45-
{ \
46-
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
47-
} \
48-
template <> \
49-
void keyvalue_qsort(type *key, double *val, size_t arrsize, bool hasnan) \
50-
{ \
51-
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
52-
} \
37+
#define DEFINE_KEYVALUE_METHODS_BASE(type1, type2) \
5338
template <> \
54-
void keyvalue_qsort(type *key, uint32_t *val, size_t arrsize, bool hasnan) \
39+
void keyvalue_qsort(type1 *key, type2 *val, size_t arrsize, bool hasnan) \
5540
{ \
5641
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
5742
} \
5843
template <> \
59-
void keyvalue_qsort(type *key, int32_t *val, size_t arrsize, bool hasnan) \
44+
void keyvalue_select(type1 *key, type2 *val, size_t k, size_t arrsize, bool hasnan) \
6045
{ \
61-
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
46+
x86simdsortStatic::keyvalue_select(key, val, k, arrsize, hasnan); \
6247
} \
6348
template <> \
64-
void keyvalue_qsort(type *key, float *val, size_t arrsize, bool hasnan) \
49+
void keyvalue_partial_sort(type1 *key, type2 *val, size_t k, size_t arrsize, bool hasnan) \
6550
{ \
66-
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
51+
x86simdsortStatic::keyvalue_partial_sort(key, val, k, arrsize, hasnan); \
6752
}
53+
54+
#define DEFINE_KEYVALUE_METHODS(type) \
55+
DEFINE_KEYVALUE_METHODS_BASE(type, uint64_t) \
56+
DEFINE_KEYVALUE_METHODS_BASE(type, int64_t) \
57+
DEFINE_KEYVALUE_METHODS_BASE(type, double) \
58+
DEFINE_KEYVALUE_METHODS_BASE(type, uint32_t) \
59+
DEFINE_KEYVALUE_METHODS_BASE(type, int32_t) \
60+
DEFINE_KEYVALUE_METHODS_BASE(type, float)
6861

6962
namespace xss {
7063
namespace avx2 {

lib/x86simdsort-internal.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,21 @@ namespace avx512 {
2121
size_t arrsize,
2222
bool hasnan = false,
2323
bool descending = false);
24+
// key-value select
25+
template <typename T1, typename T2>
26+
XSS_EXPORT_SYMBOL void
27+
keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false);
2428
// partial sort
2529
template <typename T>
2630
XSS_HIDE_SYMBOL void partial_qsort(T *arr,
2731
size_t k,
2832
size_t arrsize,
2933
bool hasnan = false,
3034
bool descending = false);
35+
// key-value partial sort
36+
template <typename T1, typename T2>
37+
XSS_EXPORT_SYMBOL void
38+
keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false);
3139
// argsort
3240
template <typename T>
3341
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr,
@@ -55,13 +63,21 @@ namespace avx2 {
5563
size_t arrsize,
5664
bool hasnan = false,
5765
bool descending = false);
66+
// key-value select
67+
template <typename T1, typename T2>
68+
XSS_EXPORT_SYMBOL void
69+
keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false);
5870
// partial sort
5971
template <typename T>
6072
XSS_HIDE_SYMBOL void partial_qsort(T *arr,
6173
size_t k,
6274
size_t arrsize,
6375
bool hasnan = false,
6476
bool descending = false);
77+
// key-value partial sort
78+
template <typename T1, typename T2>
79+
XSS_EXPORT_SYMBOL void
80+
keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false);
6581
// argsort
6682
template <typename T>
6783
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr,
@@ -89,13 +105,21 @@ namespace scalar {
89105
size_t arrsize,
90106
bool hasnan = false,
91107
bool descending = false);
108+
// key-value select
109+
template <typename T1, typename T2>
110+
XSS_EXPORT_SYMBOL void
111+
keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false);
92112
// partial sort
93113
template <typename T>
94114
XSS_HIDE_SYMBOL void partial_qsort(T *arr,
95115
size_t k,
96116
size_t arrsize,
97117
bool hasnan = false,
98118
bool descending = false);
119+
// key-value partial sort
120+
template <typename T1, typename T2>
121+
XSS_EXPORT_SYMBOL void
122+
keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false);
99123
// argsort
100124
template <typename T>
101125
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr,

lib/x86simdsort-scalar.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,27 @@ namespace scalar {
106106
utils::apply_permutation_in_place(key, arg);
107107
utils::apply_permutation_in_place(val, arg);
108108
}
109+
template <typename T1, typename T2>
110+
void keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan)
111+
{
112+
if (k == 0) return;
113+
// Note that this does a full partial sort, not just a select
114+
std::vector<size_t> arg = argsort(key, arrsize, hasnan, false);
115+
//arg.resize(k);
116+
117+
utils::apply_permutation_in_place(key, arg);
118+
utils::apply_permutation_in_place(val, arg);
119+
}
120+
template <typename T1, typename T2>
121+
void keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan)
122+
{
123+
if (k == 0) return;
124+
std::vector<size_t> arg = argsort(key, arrsize, hasnan, false);
125+
//arg.resize(k);
126+
127+
utils::apply_permutation_in_place(key, arg);
128+
utils::apply_permutation_in_place(val, arg);
129+
}
109130

110131
} // namespace scalar
111132
} // namespace xss

lib/x86simdsort-skx.cpp

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -34,37 +34,30 @@
3434
return x86simdsortStatic::argselect(arr, k, arrsize, hasnan); \
3535
}
3636

37-
#define DEFINE_KEYVALUE_METHODS(type) \
38-
template <> \
39-
void keyvalue_qsort(type *key, uint64_t *val, size_t arrsize, bool hasnan) \
40-
{ \
41-
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
42-
} \
43-
template <> \
44-
void keyvalue_qsort(type *key, int64_t *val, size_t arrsize, bool hasnan) \
45-
{ \
46-
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
47-
} \
48-
template <> \
49-
void keyvalue_qsort(type *key, double *val, size_t arrsize, bool hasnan) \
50-
{ \
51-
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
52-
} \
37+
#define DEFINE_KEYVALUE_METHODS_BASE(type1, type2) \
5338
template <> \
54-
void keyvalue_qsort(type *key, uint32_t *val, size_t arrsize, bool hasnan) \
39+
void keyvalue_qsort(type1 *key, type2 *val, size_t arrsize, bool hasnan) \
5540
{ \
5641
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
5742
} \
5843
template <> \
59-
void keyvalue_qsort(type *key, int32_t *val, size_t arrsize, bool hasnan) \
44+
void keyvalue_select(type1 *key, type2 *val, size_t k, size_t arrsize, bool hasnan) \
6045
{ \
61-
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
46+
x86simdsortStatic::keyvalue_select(key, val, k, arrsize, hasnan); \
6247
} \
6348
template <> \
64-
void keyvalue_qsort(type *key, float *val, size_t arrsize, bool hasnan) \
49+
void keyvalue_partial_sort(type1 *key, type2 *val, size_t k, size_t arrsize, bool hasnan) \
6550
{ \
66-
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
51+
x86simdsortStatic::keyvalue_partial_sort(key, val, k, arrsize, hasnan); \
6752
}
53+
54+
#define DEFINE_KEYVALUE_METHODS(type) \
55+
DEFINE_KEYVALUE_METHODS_BASE(type, uint64_t) \
56+
DEFINE_KEYVALUE_METHODS_BASE(type, int64_t) \
57+
DEFINE_KEYVALUE_METHODS_BASE(type, double) \
58+
DEFINE_KEYVALUE_METHODS_BASE(type, uint32_t) \
59+
DEFINE_KEYVALUE_METHODS_BASE(type, int32_t) \
60+
DEFINE_KEYVALUE_METHODS_BASE(type, float)
6861

6962
namespace xss {
7063
namespace avx512 {

lib/x86simdsort.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,68 @@ namespace x86simdsort {
160160
return; \
161161
} \
162162
} \
163+
}\
164+
static void(CAT(CAT(*internal_kv_select_, TYPE1), TYPE2))( \
165+
TYPE1 *, TYPE2 *, size_t, size_t, bool) \
166+
= NULL; \
167+
template <> \
168+
void keyvalue_select(TYPE1 *key, TYPE2 *val, size_t k, size_t arrsize, bool hasnan) \
169+
{ \
170+
(CAT(CAT(*internal_kv_select_, TYPE1), TYPE2))( \
171+
key, val, k, arrsize, hasnan); \
172+
} \
173+
static __attribute__((constructor)) void CAT( \
174+
CAT(resolve_keyvalue_select_, TYPE1), TYPE2)(void) \
175+
{ \
176+
CAT(CAT(internal_kv_select_, TYPE1), TYPE2) \
177+
= &xss::scalar::keyvalue_select<TYPE1, TYPE2>; \
178+
__builtin_cpu_init(); \
179+
std::string_view preferred_cpu = find_preferred_cpu(ISA); \
180+
if constexpr (dispatch_requested("avx512", ISA)) { \
181+
if (preferred_cpu.find("avx512") != std::string_view::npos) { \
182+
CAT(CAT(internal_kv_select_, TYPE1), TYPE2) \
183+
= &xss::avx512::keyvalue_select<TYPE1, TYPE2>; \
184+
return; \
185+
} \
186+
} \
187+
if constexpr (dispatch_requested("avx2", ISA)) { \
188+
if (preferred_cpu.find("avx2") != std::string_view::npos) { \
189+
CAT(CAT(internal_kv_select_, TYPE1), TYPE2) \
190+
= &xss::avx2::keyvalue_select<TYPE1, TYPE2>; \
191+
return; \
192+
} \
193+
} \
194+
} \
195+
static void(CAT(CAT(*internal_kv_partial_sort_, TYPE1), TYPE2))( \
196+
TYPE1 *, TYPE2 *, size_t, size_t, bool) \
197+
= NULL; \
198+
template <> \
199+
void keyvalue_partial_sort(TYPE1 *key, TYPE2 *val, size_t k, size_t arrsize, bool hasnan) \
200+
{ \
201+
(CAT(CAT(*internal_kv_partial_sort_, TYPE1), TYPE2))( \
202+
key, val, k, arrsize, hasnan); \
203+
} \
204+
static __attribute__((constructor)) void CAT( \
205+
CAT(resolve_keyvalue_partial_sort_, TYPE1), TYPE2)(void) \
206+
{ \
207+
CAT(CAT(internal_kv_partial_sort_, TYPE1), TYPE2) \
208+
= &xss::scalar::keyvalue_partial_sort<TYPE1, TYPE2>; \
209+
__builtin_cpu_init(); \
210+
std::string_view preferred_cpu = find_preferred_cpu(ISA); \
211+
if constexpr (dispatch_requested("avx512", ISA)) { \
212+
if (preferred_cpu.find("avx512") != std::string_view::npos) { \
213+
CAT(CAT(internal_kv_partial_sort_, TYPE1), TYPE2) \
214+
= &xss::avx512::keyvalue_partial_sort<TYPE1, TYPE2>; \
215+
return; \
216+
} \
217+
} \
218+
if constexpr (dispatch_requested("avx2", ISA)) { \
219+
if (preferred_cpu.find("avx2") != std::string_view::npos) { \
220+
CAT(CAT(internal_kv_partial_sort_, TYPE1), TYPE2) \
221+
= &xss::avx2::keyvalue_partial_sort<TYPE1, TYPE2>; \
222+
return; \
223+
} \
224+
} \
163225
}
164226

165227
#define ISA_LIST(...) \

lib/x86simdsort.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,16 @@ template <typename T1, typename T2>
4848
XSS_EXPORT_SYMBOL void
4949
keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false);
5050

51+
// keyvalue select
52+
template <typename T1, typename T2>
53+
XSS_EXPORT_SYMBOL void
54+
keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false);
55+
56+
// keyvalue partial sort
57+
template <typename T1, typename T2>
58+
XSS_EXPORT_SYMBOL void
59+
keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false);
60+
5161
// sort an object
5262
template <typename T, typename Func>
5363
XSS_EXPORT_SYMBOL void object_qsort(T *arr, uint32_t arrsize, Func key_func)

src/x86simdsort-static-incl.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,14 @@ template <typename T1, typename T2>
4949
X86_SIMD_SORT_FINLINE void
5050
keyvalue_qsort(T1 *key, T2 *val, size_t size, bool hasnan = false);
5151

52+
template <typename T1, typename T2>
53+
X86_SIMD_SORT_FINLINE void
54+
keyvalue_select(T1 *key, T2 *val, size_t k, size_t size, bool hasnan = false);
55+
56+
template <typename T1, typename T2>
57+
X86_SIMD_SORT_FINLINE void
58+
keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t size, bool hasnan = false);
59+
5260
} // namespace x86simdsortStatic
5361

5462
#define XSS_METHODS(ISA) \
@@ -106,6 +114,18 @@ keyvalue_qsort(T1 *key, T2 *val, size_t size, bool hasnan = false);
106114
T1 *key, T2 *val, size_t size, bool hasnan) \
107115
{ \
108116
ISA##_qsort_kv(key, val, size, hasnan); \
117+
} \
118+
template <typename T1, typename T2> \
119+
X86_SIMD_SORT_FINLINE void x86simdsortStatic::keyvalue_select( \
120+
T1 *key, T2 *val, size_t k, size_t size, bool hasnan) \
121+
{ \
122+
ISA##_select_kv(key, val, k, size, hasnan); \
123+
} \
124+
template <typename T1, typename T2> \
125+
X86_SIMD_SORT_FINLINE void x86simdsortStatic::keyvalue_partial_sort( \
126+
T1 *key, T2 *val, size_t k, size_t size, bool hasnan) \
127+
{ \
128+
ISA##_partial_sort_kv(key, val, k, size, hasnan); \
109129
}
110130

111131
/*

0 commit comments

Comments
 (0)