Skip to content

Commit b0d0929

Browse files
author
Raghuveer Devulapalli
committed
Condense AscendingComparator and DescendingComparator into one class
1 parent b629ba5 commit b0d0929

File tree

5 files changed

+56
-86
lines changed

5 files changed

+56
-86
lines changed

src/avx512-16bit-qsort.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -574,11 +574,11 @@ X86_SIMD_SORT_INLINE void avx512_qsort_fp16(uint16_t *arr,
574574
arr, arrsize);
575575
}
576576
if (descending) {
577-
qsort_<vtype, DescendingComparator<vtype>, uint16_t>(
577+
qsort_<vtype, Comparator<vtype, true>, uint16_t>(
578578
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
579579
}
580580
else {
581-
qsort_<vtype, AscendingComparator<vtype>, uint16_t>(
581+
qsort_<vtype, Comparator<vtype, false>, uint16_t>(
582582
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
583583
}
584584
replace_inf_with_nan(arr, arrsize, nan_count, descending);
@@ -599,15 +599,15 @@ X86_SIMD_SORT_INLINE void avx512_qselect_fp16(uint16_t *arr,
599599
}
600600
if (indx_last_elem >= k) {
601601
if (descending) {
602-
qselect_<vtype, DescendingComparator<vtype>, uint16_t>(
602+
qselect_<vtype, Comparator<vtype, true>, uint16_t>(
603603
arr,
604604
k,
605605
0,
606606
indx_last_elem,
607607
2 * (arrsize_t)log2(indx_last_elem));
608608
}
609609
else {
610-
qselect_<vtype, AscendingComparator<vtype>, uint16_t>(
610+
qselect_<vtype, Comparator<vtype, false>, uint16_t>(
611611
arr,
612612
k,
613613
0,

src/avx512fp16-16bit-qsort.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,8 +208,8 @@ avx512_qsort(_Float16 *arr, arrsize_t arrsize, bool hasnan)
208208
using vtype = zmm_vector<_Float16>;
209209
using comparator =
210210
typename std::conditional<descending,
211-
DescendingComparator<vtype>,
212-
AscendingComparator<vtype>>::type;
211+
Comparator<vtype, true>,
212+
Comparator<vtype, false>>::type;
213213

214214
if (arrsize > 1) {
215215
arrsize_t nan_count = 0;
@@ -231,8 +231,8 @@ avx512_qselect(_Float16 *arr, arrsize_t k, arrsize_t arrsize, bool hasnan)
231231
using vtype = zmm_vector<_Float16>;
232232
using comparator =
233233
typename std::conditional<descending,
234-
DescendingComparator<vtype>,
235-
AscendingComparator<vtype>>::type;
234+
Comparator<vtype, true>,
235+
Comparator<vtype, false>>::type;
236236

237237
arrsize_t index_first_elem = 0;
238238
arrsize_t index_last_elem = arrsize - 1;

src/xss-common-comparators.hpp

Lines changed: 43 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -34,124 +34,94 @@ type_t next_value(type_t value)
3434
template <typename vtype, typename mm_t>
3535
X86_SIMD_SORT_INLINE void COEX(mm_t &a, mm_t &b);
3636

37-
template <typename vtype>
38-
struct AscendingComparator {
37+
template <typename vtype, bool descend>
38+
struct Comparator {
3939
using reg_t = typename vtype::reg_t;
4040
using opmask_t = typename vtype::opmask_t;
4141
using type_t = typename vtype::type_t;
4242

4343
X86_SIMD_SORT_FINLINE bool STDSortComparator(const type_t &a,
4444
const type_t &b)
4545
{
46-
return comparison_func<vtype>(a, b);
46+
if constexpr (descend) { return comparison_func<vtype>(b, a); }
47+
else {
48+
return comparison_func<vtype>(a, b);
49+
}
4750
}
4851

4952
X86_SIMD_SORT_FINLINE opmask_t PartitionComparator(reg_t a, reg_t b)
5053
{
51-
return vtype::ge(a, b);
54+
if constexpr (descend) { return vtype::ge(b, a); }
55+
else {
56+
return vtype::ge(a, b);
57+
}
5258
}
5359

5460
X86_SIMD_SORT_FINLINE void COEX(reg_t &a, reg_t &b)
5561
{
56-
::COEX<vtype, reg_t>(a, b);
62+
if constexpr (descend) { ::COEX<vtype, reg_t>(b, a); }
63+
else {
64+
::COEX<vtype, reg_t>(a, b);
65+
}
5766
}
5867

5968
// Returns a vector of values that would be sorted as far right as possible
6069
// For ascending order, this is the maximum possible value
6170
X86_SIMD_SORT_FINLINE reg_t rightmostPossibleVec()
6271
{
63-
return vtype::zmm_max();
72+
if constexpr (descend) { return vtype::zmm_min(); }
73+
else {
74+
return vtype::zmm_max();
75+
}
6476
}
6577

6678
// Returns the value that would be leftmost of the two when sorted
6779
// For ascending order, that is the smaller value
6880
X86_SIMD_SORT_FINLINE type_t leftmost(type_t smaller, type_t larger)
6981
{
70-
UNUSED(larger);
71-
return smaller;
82+
if constexpr (descend) {
83+
UNUSED(smaller);
84+
return larger;
85+
}
86+
else {
87+
UNUSED(larger);
88+
return smaller;
89+
}
7290
}
7391

7492
// Returns the value that would be rightmost of the two when sorted
7593
// For ascending order, that is the larger value
7694
X86_SIMD_SORT_FINLINE type_t rightmost(type_t smaller, type_t larger)
7795
{
78-
UNUSED(smaller);
79-
return larger;
96+
if constexpr (descend) {
97+
UNUSED(larger);
98+
return smaller;
99+
}
100+
else {
101+
UNUSED(smaller);
102+
return larger;
103+
}
80104
}
81105

82106
// If median == smallest, that implies approximately half the array is equal to smallest, unless we were very unlucky with our sample
83107
// Try just doing the next largest value greater than this seemingly very common value to seperate them out
84108
X86_SIMD_SORT_FINLINE type_t choosePivotMedianIsSmallest(type_t median)
85109
{
86-
return next_value<type_t>(median);
110+
if constexpr (descend) { return median; }
111+
else {
112+
return next_value<type_t>(median);
113+
}
87114
}
88115

89116
// If median == largest, that implies approximately half the array is equal to largest, unless we were very unlucky with our sample
90117
// Thus, median probably is a fine pivot, since it will move all of this common value into its own partition
91118
X86_SIMD_SORT_FINLINE type_t choosePivotMedianIsLargest(type_t median)
92119
{
93-
return median;
94-
}
95-
};
96-
97-
template <typename vtype>
98-
struct DescendingComparator {
99-
using reg_t = typename vtype::reg_t;
100-
using opmask_t = typename vtype::opmask_t;
101-
using type_t = typename vtype::type_t;
102-
103-
X86_SIMD_SORT_FINLINE bool STDSortComparator(const type_t &a,
104-
const type_t &b)
105-
{
106-
return comparison_func<vtype>(b, a);
107-
}
108-
109-
X86_SIMD_SORT_FINLINE opmask_t PartitionComparator(reg_t a, reg_t b)
110-
{
111-
return vtype::ge(b, a);
112-
}
113-
114-
X86_SIMD_SORT_FINLINE void COEX(reg_t &a, reg_t &b)
115-
{
116-
::COEX<vtype, reg_t>(b, a);
117-
}
118-
119-
// Returns a vector of values that would be sorted as far right as possible
120-
// For descending order, this is the minimum possible value
121-
X86_SIMD_SORT_FINLINE reg_t rightmostPossibleVec()
122-
{
123-
return vtype::zmm_min();
124-
}
125-
126-
// Returns the value that would be leftmost of the two when sorted
127-
// For descending order, that is the larger value
128-
X86_SIMD_SORT_FINLINE type_t leftmost(type_t smaller, type_t bigger)
129-
{
130-
UNUSED(smaller);
131-
return bigger;
132-
}
133-
134-
// Returns the value that would be rightmost of the two when sorted
135-
// For descending order, that is the smaller value
136-
X86_SIMD_SORT_FINLINE type_t rightmost(type_t smaller, type_t bigger)
137-
{
138-
UNUSED(bigger);
139-
return smaller;
140-
}
141-
142-
// If median == smallest, that implies approximately half the array is equal to smallest, unless we were very unlucky with our sample
143-
// Thus, median probably is a fine pivot, since it will move all of this common value into its own partition
144-
X86_SIMD_SORT_FINLINE type_t choosePivotMedianIsSmallest(type_t median)
145-
{
146-
return median;
147-
}
148-
149-
// If median == largest, that implies approximately half the array is equal to largest, unless we were very unlucky with our sample
150-
// Try just doing the next smallest value less than this seemingly very common value to seperate them out
151-
X86_SIMD_SORT_FINLINE type_t choosePivotMedianIsLargest(type_t median)
152-
{
153-
return prev_value<type_t>(median);
120+
if constexpr (descend) { return prev_value<type_t>(median); }
121+
else {
122+
return median;
123+
}
154124
}
155125
};
156126

157-
#endif // XSS_COMMON_COMPARATORS
127+
#endif // XSS_COMMON_COMPARATORS

src/xss-common-qsort.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -615,8 +615,8 @@ X86_SIMD_SORT_INLINE void xss_qsort(T *arr, arrsize_t arrsize, bool hasnan)
615615
{
616616
using comparator =
617617
typename std::conditional<descending,
618-
DescendingComparator<vtype>,
619-
AscendingComparator<vtype>>::type;
618+
Comparator<vtype, true>,
619+
Comparator<vtype, false>>::type;
620620

621621
if (arrsize > 1) {
622622
arrsize_t nan_count = 0;
@@ -641,8 +641,8 @@ xss_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan)
641641
{
642642
using comparator =
643643
typename std::conditional<descending,
644-
DescendingComparator<vtype>,
645-
AscendingComparator<vtype>>::type;
644+
Comparator<vtype, true>,
645+
Comparator<vtype, false>>::type;
646646

647647
arrsize_t index_first_elem = 0;
648648
arrsize_t index_last_elem = arrsize - 1;

src/xss-pivot-selection.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ get_pivot_smart(type_t *arr, const arrsize_t left, const arrsize_t right)
115115
// Sort the samples
116116
// Note that this intentionally uses the AscendingComparator
117117
// instead of the provided comparator
118-
sort_vectors<vtype, AscendingComparator<vtype>, numVecs>(vecs);
118+
sort_vectors<vtype, Comparator<vtype, false>, numVecs>(vecs);
119119

120120
type_t samples[N];
121121
for (int i = 0; i < numVecs; i++) {

0 commit comments

Comments
 (0)