Skip to content

Commit 0792fbd

Browse files
author
Raghuveer Devulapalli
authored
Merge pull request #140 from sterrettm2/reverse_sort
Adds descending sort order to quicksort, quickselect, and partial sort
2 parents 06d31e7 + b0d0929 commit 0792fbd

22 files changed

+1187
-606
lines changed

benchmarks/bench-qsort.hpp

Lines changed: 41 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,9 +36,49 @@ static void simdsort(benchmark::State &state, Args &&...args)
3636
}
3737
}
3838

39+
// Benchmark that times std::sort applied over reverse iterators, which
// leaves `arr` in descending order when read front to back. It is registered
// next to simd_revsort by BENCH_BOTH_QSORT.
// The first two forwarded arguments are read as the array size and the
// array-type string handed to get_array<T>.
template <typename T, class... Args>
40+
static void scalar_revsort(benchmark::State &state, Args &&...args)
41+
{
42+
// Unpack the benchmark arguments: element 0 is the size, element 1 the type string.
43+
// NOTE(review): args are std::move-d rather than std::forward-ed, so lvalue
// arguments would be moved-from — confirm callers only pass temporaries.
auto args_tuple = std::make_tuple(std::move(args)...);
44+
size_t arrsize = std::get<0>(args_tuple);
45+
std::string arrtype = std::get<1>(args_tuple);
46+
// Build the input once and keep an untouched copy to restore from.
47+
std::vector<T> arr = get_array<T>(arrtype, arrsize);
48+
std::vector<T> arr_bkp = arr;
49+
// Timed loop: only the sort call runs while timing is active.
50+
for (auto _ : state) {
51+
// Ascending sort over the reversed view == descending order in `arr`.
std::sort(arr.rbegin(), arr.rend());
52+
// Restore the original input with timing paused so that every
// iteration sorts the same data.
state.PauseTiming();
53+
arr = arr_bkp;
54+
state.ResumeTiming();
55+
}
56+
}
57+
58+
// Benchmark that times x86simdsort::qsort called with its fourth argument
// set to true. It is registered next to scalar_revsort by BENCH_BOTH_QSORT.
// The first two forwarded arguments are read as the array size and the
// array-type string handed to get_array<T>.
template <typename T, class... Args>
59+
static void simd_revsort(benchmark::State &state, Args &&...args)
60+
{
61+
// Unpack the benchmark arguments: element 0 is the size, element 1 the type string.
62+
// NOTE(review): args are std::move-d rather than std::forward-ed, so lvalue
// arguments would be moved-from — confirm callers only pass temporaries.
auto args_tuple = std::make_tuple(std::move(args)...);
63+
size_t arrsize = std::get<0>(args_tuple);
64+
std::string arrtype = std::get<1>(args_tuple);
65+
// Build the input once and keep an untouched copy to restore from.
66+
std::vector<T> arr = get_array<T>(arrtype, arrsize);
67+
std::vector<T> arr_bkp = arr;
68+
// Timed loop: only the sort call runs while timing is active.
69+
for (auto _ : state) {
70+
// Presumably (hasnan = false, descending = true), matching the parameter
// order of the internal qsort declarations in this commit — confirm
// against the public x86simdsort header.
x86simdsort::qsort(arr.data(), arrsize, false, true);
71+
// Restore the original input with timing paused so that every
// iteration sorts the same data.
state.PauseTiming();
72+
arr = arr_bkp;
73+
state.ResumeTiming();
74+
}
75+
}
76+
3977
#define BENCH_BOTH_QSORT(type) \
4078
BENCH_SORT(simdsort, type) \
41-
BENCH_SORT(scalarsort, type)
79+
BENCH_SORT(scalarsort, type) \
80+
BENCH_SORT(simd_revsort, type) \
81+
BENCH_SORT(scalar_revsort, type)
4282

4383
BENCH_BOTH_QSORT(uint64_t)
4484
BENCH_BOTH_QSORT(int64_t)

lib/x86simdsort-avx2.cpp

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -7,19 +7,21 @@
77

88
#define DEFINE_ALL_METHODS(type) \
99
template <> \
10-
void qsort(type *arr, size_t arrsize, bool hasnan) \
10+
void qsort(type *arr, size_t arrsize, bool hasnan, bool descending) \
1111
{ \
12-
avx2_qsort(arr, arrsize, hasnan); \
12+
avx2_qsort(arr, arrsize, hasnan, descending); \
1313
} \
1414
template <> \
15-
void qselect(type *arr, size_t k, size_t arrsize, bool hasnan) \
15+
void qselect( \
16+
type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \
1617
{ \
17-
avx2_qselect(arr, k, arrsize, hasnan); \
18+
avx2_qselect(arr, k, arrsize, hasnan, descending); \
1819
} \
1920
template <> \
20-
void partial_qsort(type *arr, size_t k, size_t arrsize, bool hasnan) \
21+
void partial_qsort( \
22+
type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \
2123
{ \
22-
avx2_partial_qsort(arr, k, arrsize, hasnan); \
24+
avx2_partial_qsort(arr, k, arrsize, hasnan, descending); \
2325
} \
2426
template <> \
2527
std::vector<size_t> argsort(type *arr, size_t arrsize, bool hasnan) \

lib/x86simdsort-icl.cpp

Lines changed: 28 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -5,34 +5,50 @@
55
namespace xss {
66
namespace avx512 {
77
template <>
8-
void qsort(uint16_t *arr, size_t size, bool hasnan)
8+
void qsort(uint16_t *arr, size_t size, bool hasnan, bool descending)
99
{
10-
avx512_qsort(arr, size, hasnan);
10+
avx512_qsort(arr, size, hasnan, descending);
1111
}
1212
template <>
13-
void qselect(uint16_t *arr, size_t k, size_t arrsize, bool hasnan)
13+
void qselect(uint16_t *arr,
14+
size_t k,
15+
size_t arrsize,
16+
bool hasnan,
17+
bool descending)
1418
{
15-
avx512_qselect(arr, k, arrsize, hasnan);
19+
avx512_qselect(arr, k, arrsize, hasnan, descending);
1620
}
1721
template <>
18-
void partial_qsort(uint16_t *arr, size_t k, size_t arrsize, bool hasnan)
22+
void partial_qsort(uint16_t *arr,
23+
size_t k,
24+
size_t arrsize,
25+
bool hasnan,
26+
bool descending)
1927
{
20-
avx512_partial_qsort(arr, k, arrsize, hasnan);
28+
avx512_partial_qsort(arr, k, arrsize, hasnan, descending);
2129
}
2230
template <>
23-
void qsort(int16_t *arr, size_t size, bool hasnan)
31+
void qsort(int16_t *arr, size_t size, bool hasnan, bool descending)
2432
{
25-
avx512_qsort(arr, size, hasnan);
33+
avx512_qsort(arr, size, hasnan, descending);
2634
}
2735
template <>
28-
void qselect(int16_t *arr, size_t k, size_t arrsize, bool hasnan)
36+
void qselect(int16_t *arr,
37+
size_t k,
38+
size_t arrsize,
39+
bool hasnan,
40+
bool descending)
2941
{
30-
avx512_qselect(arr, k, arrsize, hasnan);
42+
avx512_qselect(arr, k, arrsize, hasnan, descending);
3143
}
3244
template <>
33-
void partial_qsort(int16_t *arr, size_t k, size_t arrsize, bool hasnan)
45+
void partial_qsort(int16_t *arr,
46+
size_t k,
47+
size_t arrsize,
48+
bool hasnan,
49+
bool descending)
3450
{
35-
avx512_partial_qsort(arr, k, arrsize, hasnan);
51+
avx512_partial_qsort(arr, k, arrsize, hasnan, descending);
3652
}
3753
} // namespace avx512
3854
} // namespace xss

lib/x86simdsort-internal.h

Lines changed: 36 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -8,19 +8,26 @@ namespace xss {
88
namespace avx512 {
99
// quicksort
1010
template <typename T>
11-
XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false);
11+
XSS_HIDE_SYMBOL void
12+
qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false);
1213
// key-value quicksort
1314
template <typename T1, typename T2>
1415
XSS_EXPORT_SYMBOL void
1516
keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false);
1617
// quickselect
1718
template <typename T>
18-
XSS_HIDE_SYMBOL void
19-
qselect(T *arr, size_t k, size_t arrsize, bool hasnan = false);
19+
XSS_HIDE_SYMBOL void qselect(T *arr,
20+
size_t k,
21+
size_t arrsize,
22+
bool hasnan = false,
23+
bool descending = false);
2024
// partial sort
2125
template <typename T>
22-
XSS_HIDE_SYMBOL void
23-
partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false);
26+
XSS_HIDE_SYMBOL void partial_qsort(T *arr,
27+
size_t k,
28+
size_t arrsize,
29+
bool hasnan = false,
30+
bool descending = false);
2431
// argsort
2532
template <typename T>
2633
XSS_HIDE_SYMBOL std::vector<size_t>
@@ -33,19 +40,26 @@ namespace avx512 {
3340
namespace avx2 {
3441
// quicksort
3542
template <typename T>
36-
XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false);
43+
XSS_HIDE_SYMBOL void
44+
qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false);
3745
// key-value quicksort
3846
template <typename T1, typename T2>
3947
XSS_EXPORT_SYMBOL void
4048
keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false);
4149
// quickselect
4250
template <typename T>
43-
XSS_HIDE_SYMBOL void
44-
qselect(T *arr, size_t k, size_t arrsize, bool hasnan = false);
51+
XSS_HIDE_SYMBOL void qselect(T *arr,
52+
size_t k,
53+
size_t arrsize,
54+
bool hasnan = false,
55+
bool descending = false);
4556
// partial sort
4657
template <typename T>
47-
XSS_HIDE_SYMBOL void
48-
partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false);
58+
XSS_HIDE_SYMBOL void partial_qsort(T *arr,
59+
size_t k,
60+
size_t arrsize,
61+
bool hasnan = false,
62+
bool descending = false);
4963
// argsort
5064
template <typename T>
5165
XSS_HIDE_SYMBOL std::vector<size_t>
@@ -58,19 +72,26 @@ namespace avx2 {
5872
namespace scalar {
5973
// quicksort
6074
template <typename T>
61-
XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false);
75+
XSS_HIDE_SYMBOL void
76+
qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false);
6277
// key-value quicksort
6378
template <typename T1, typename T2>
6479
XSS_EXPORT_SYMBOL void
6580
keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false);
6681
// quickselect
6782
template <typename T>
68-
XSS_HIDE_SYMBOL void
69-
qselect(T *arr, size_t k, size_t arrsize, bool hasnan = false);
83+
XSS_HIDE_SYMBOL void qselect(T *arr,
84+
size_t k,
85+
size_t arrsize,
86+
bool hasnan = false,
87+
bool descending = false);
7088
// partial sort
7189
template <typename T>
72-
XSS_HIDE_SYMBOL void
73-
partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false);
90+
XSS_HIDE_SYMBOL void partial_qsort(T *arr,
91+
size_t k,
92+
size_t arrsize,
93+
bool hasnan = false,
94+
bool descending = false);
7495
// argsort
7596
template <typename T>
7697
XSS_HIDE_SYMBOL std::vector<size_t>

lib/x86simdsort-scalar.h

Lines changed: 37 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -4,8 +4,10 @@
44

55
namespace xss {
66
namespace utils {
7-
/* O(1) permute array in place: stolen from
8-
* http://www.davidespataro.it/apply-a-permutation-to-a-vector */
7+
/*
8+
* O(1) permute array in place: stolen from
9+
* http://www.davidespataro.it/apply-a-permutation-to-a-vector
10+
*/
911
template <typename T>
1012
void apply_permutation_in_place(T *arr, std::vector<size_t> arg)
1113
{
@@ -21,40 +23,51 @@ namespace utils {
2123
arg[curr] = curr;
2224
}
2325
}
24-
} // namespace utils
25-
26-
namespace scalar {
2726
template <typename T>
28-
void qsort(T *arr, size_t arrsize, bool hasnan)
27+
decltype(auto) get_cmp_func(bool hasnan, bool reverse)
2928
{
29+
std::function<bool(T, T)> cmp;
3030
if (hasnan) {
31-
std::sort(arr, arr + arrsize, compare<T, std::less<T>>());
31+
if (reverse == true) { cmp = compare<T, std::greater<T>>(); }
32+
else {
33+
cmp = compare<T, std::less<T>>();
34+
}
3235
}
3336
else {
34-
std::sort(arr, arr + arrsize);
37+
if (reverse == true) { cmp = std::greater<T>(); }
38+
else {
39+
cmp = std::less<T>();
40+
}
3541
}
42+
return cmp;
3643
}
44+
} // namespace utils
45+
46+
namespace scalar {
3747
template <typename T>
38-
void qselect(T *arr, size_t k, size_t arrsize, bool hasnan)
48+
void qsort(T *arr, size_t arrsize, bool hasnan, bool reversed)
3949
{
40-
if (hasnan) {
41-
std::nth_element(
42-
arr, arr + k, arr + arrsize, compare<T, std::less<T>>());
43-
}
44-
else {
45-
std::nth_element(arr, arr + k, arr + arrsize);
46-
}
50+
std::sort(arr,
51+
arr + arrsize,
52+
xss::utils::get_cmp_func<T>(hasnan, reversed));
4753
}
54+
4855
template <typename T>
49-
void partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan)
56+
void qselect(T *arr, size_t k, size_t arrsize, bool hasnan, bool reversed)
5057
{
51-
if (hasnan) {
52-
std::partial_sort(
53-
arr, arr + k, arr + arrsize, compare<T, std::less<T>>());
54-
}
55-
else {
56-
std::partial_sort(arr, arr + k, arr + arrsize);
57-
}
58+
std::nth_element(arr,
59+
arr + k,
60+
arr + arrsize,
61+
xss::utils::get_cmp_func<T>(hasnan, reversed));
62+
}
63+
template <typename T>
64+
void
65+
partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan, bool reversed)
66+
{
67+
std::partial_sort(arr,
68+
arr + k,
69+
arr + arrsize,
70+
xss::utils::get_cmp_func<T>(hasnan, reversed));
5871
}
5972
template <typename T>
6073
std::vector<size_t> argsort(T *arr, size_t arrsize, bool hasnan)

lib/x86simdsort-skx.cpp

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -7,19 +7,21 @@
77

88
#define DEFINE_ALL_METHODS(type) \
99
template <> \
10-
void qsort(type *arr, size_t arrsize, bool hasnan) \
10+
void qsort(type *arr, size_t arrsize, bool hasnan, bool descending) \
1111
{ \
12-
avx512_qsort(arr, arrsize, hasnan); \
12+
avx512_qsort(arr, arrsize, hasnan, descending); \
1313
} \
1414
template <> \
15-
void qselect(type *arr, size_t k, size_t arrsize, bool hasnan) \
15+
void qselect( \
16+
type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \
1617
{ \
17-
avx512_qselect(arr, k, arrsize, hasnan); \
18+
avx512_qselect(arr, k, arrsize, hasnan, descending); \
1819
} \
1920
template <> \
20-
void partial_qsort(type *arr, size_t k, size_t arrsize, bool hasnan) \
21+
void partial_qsort( \
22+
type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \
2123
{ \
22-
avx512_partial_qsort(arr, k, arrsize, hasnan); \
24+
avx512_partial_qsort(arr, k, arrsize, hasnan, descending); \
2325
} \
2426
template <> \
2527
std::vector<size_t> argsort(type *arr, size_t arrsize, bool hasnan) \

lib/x86simdsort-spr.cpp

Lines changed: 23 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -5,19 +5,36 @@
55
namespace xss {
66
namespace avx512 {
77
template <>
8-
void qsort(_Float16 *arr, size_t size, bool hasnan)
8+
void qsort(_Float16 *arr, size_t size, bool hasnan, bool descending)
99
{
10-
avx512_qsort(arr, size, hasnan);
10+
if (descending) { avx512_qsort<true>(arr, size, hasnan); }
11+
else {
12+
avx512_qsort<false>(arr, size, hasnan);
13+
}
1114
}
1215
template <>
13-
void qselect(_Float16 *arr, size_t k, size_t arrsize, bool hasnan)
16+
void qselect(_Float16 *arr,
17+
size_t k,
18+
size_t arrsize,
19+
bool hasnan,
20+
bool descending)
1421
{
15-
avx512_qselect(arr, k, arrsize, hasnan);
22+
if (descending) { avx512_qselect<true>(arr, k, arrsize, hasnan); }
23+
else {
24+
avx512_qselect<false>(arr, k, arrsize, hasnan);
25+
}
1626
}
1727
template <>
18-
void partial_qsort(_Float16 *arr, size_t k, size_t arrsize, bool hasnan)
28+
void partial_qsort(_Float16 *arr,
29+
size_t k,
30+
size_t arrsize,
31+
bool hasnan,
32+
bool descending)
1933
{
20-
avx512_partial_qsort(arr, k, arrsize, hasnan);
34+
if (descending) { avx512_partial_qsort<true>(arr, k, arrsize, hasnan); }
35+
else {
36+
avx512_partial_qsort<false>(arr, k, arrsize, hasnan);
37+
}
2138
}
2239
} // namespace avx512
2340
} // namespace xss

0 commit comments

Comments
 (0)