Skip to content

Commit 9029f61

Browse files
committed
Align qselect parameter meaning with nth_element
The QuickSelect internal method is now phrased such that the position to be sorted is given as an offset (in the same way that left points to the first element and right points to the last element). Similarly, the avx512_qselect method also now uses this interpretation.
1 parent 1204561 commit 9029f61

File tree

8 files changed

+59
-41
lines changed

8 files changed

+59
-41
lines changed

benchmarks/bench_qselect.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ static void avx512_qselect(benchmark::State& state) {
1818
arr_bkp = arr;
1919

2020
/* Choose random index to make sorted */
21-
int k = get_uniform_rand_array<int64_t>(1, ARRSIZE, 1).front();
21+
int k = get_uniform_rand_array<int64_t>(1, ARRSIZE - 1, 0).front();
2222

2323
/* call avx512 quickselect */
2424
for (auto _ : state) {
@@ -42,7 +42,7 @@ static void stdnthelement(benchmark::State& state) {
4242
arr_bkp = arr;
4343

4444
/* Choose random index to make sorted */
45-
int k = get_uniform_rand_array<int64_t>(1, ARRSIZE, 1).front();
45+
int k = get_uniform_rand_array<int64_t>(1, ARRSIZE - 1, 0).front();
4646

4747
/* call std::nth_element */
4848
for (auto _ : state) {

benchmarks/bench_qsortfp16.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ static void avx512_qselect(benchmark::State& state) {
7878
arr_bkp = arr;
7979

8080
/* Choose random index to make sorted */
81-
int k = get_uniform_rand_array<int64_t>(1, ARRSIZE, 1).front();
81+
int k = get_uniform_rand_array<int64_t>(1, ARRSIZE - 1, 0).front();
8282

8383
/* call avx512 quickselect */
8484
for (auto _ : state) {
@@ -110,7 +110,7 @@ static void stdnthelement(benchmark::State& state) {
110110
arr_bkp = arr;
111111

112112
/* Choose random index to sort until */
113-
int k = get_uniform_rand_array<int64_t>(1, ARRSIZE, 1).front();
113+
int k = get_uniform_rand_array<int64_t>(1, ARRSIZE - 1, 0).front();
114114

115115
/* call std::nth_element */
116116
for (auto _ : state) {

src/avx512-16bit-common.h

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -261,9 +261,7 @@ X86_SIMD_SORT_INLINE type_t get_pivot_16bit(type_t *arr,
261261

262262
template <typename vtype, typename type_t>
263263
static void
264-
qselect_16bit_(type_t *arr, int64_t k,
265-
int64_t left, int64_t right,
266-
int64_t max_iters)
264+
qsort_16bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
267265
{
268266
/*
269267
* Resort to std::sort if quicksort isnt making any progress
@@ -285,15 +283,17 @@ qselect_16bit_(type_t *arr, int64_t k,
285283
type_t biggest = vtype::type_min();
286284
int64_t pivot_index = partition_avx512<vtype>(
287285
arr, left, right + 1, pivot, &smallest, &biggest);
288-
if ((pivot != smallest) && (k <= pivot_index))
289-
qselect_16bit_<vtype>(arr, k, left, pivot_index - 1, max_iters - 1);
290-
else if ((pivot != biggest) && (k > pivot_index))
291-
qselect_16bit_<vtype>(arr, k, pivot_index, right, max_iters - 1);
286+
if (pivot != smallest)
287+
qsort_16bit_<vtype>(arr, left, pivot_index - 1, max_iters - 1);
288+
if (pivot != biggest)
289+
qsort_16bit_<vtype>(arr, pivot_index, right, max_iters - 1);
292290
}
293291

294292
template <typename vtype, typename type_t>
295293
static void
296-
qsort_16bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
294+
qselect_16bit_(type_t *arr, int64_t pos,
295+
int64_t left, int64_t right,
296+
int64_t max_iters)
297297
{
298298
/*
299299
* Resort to std::sort if quicksort isnt making any progress
@@ -315,10 +315,10 @@ qsort_16bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
315315
type_t biggest = vtype::type_min();
316316
int64_t pivot_index = partition_avx512<vtype>(
317317
arr, left, right + 1, pivot, &smallest, &biggest);
318-
if (pivot != smallest)
319-
qsort_16bit_<vtype>(arr, left, pivot_index - 1, max_iters - 1);
320-
if (pivot != biggest)
321-
qsort_16bit_<vtype>(arr, pivot_index, right, max_iters - 1);
318+
if ((pivot != smallest) && (pos < pivot_index))
319+
qselect_16bit_<vtype>(arr, pos, left, pivot_index - 1, max_iters - 1);
320+
else if ((pivot != biggest) && (pos >= pivot_index))
321+
qselect_16bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1);
322322
}
323323

324324
#endif // AVX512_16BIT_COMMON

src/avx512-32bit-qsort.hpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -628,9 +628,7 @@ X86_SIMD_SORT_INLINE type_t get_pivot_32bit(type_t *arr,
628628

629629
template <typename vtype, typename type_t>
630630
static void
631-
qselect_32bit_(type_t *arr, int64_t k,
632-
int64_t left, int64_t right,
633-
int64_t max_iters)
631+
qsort_32bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
634632
{
635633
/*
636634
* Resort to std::sort if quicksort isnt making any progress
@@ -652,15 +650,17 @@ qselect_32bit_(type_t *arr, int64_t k,
652650
type_t biggest = vtype::type_min();
653651
int64_t pivot_index = partition_avx512<vtype>(
654652
arr, left, right + 1, pivot, &smallest, &biggest);
655-
if ((pivot != smallest) && (k <= pivot_index))
656-
qselect_32bit_<vtype>(arr, k, left, pivot_index - 1, max_iters - 1);
657-
else if ((pivot != biggest) && (k > pivot_index))
658-
qselect_32bit_<vtype>(arr, k, pivot_index, right, max_iters - 1);
653+
if (pivot != smallest)
654+
qsort_32bit_<vtype>(arr, left, pivot_index - 1, max_iters - 1);
655+
if (pivot != biggest)
656+
qsort_32bit_<vtype>(arr, pivot_index, right, max_iters - 1);
659657
}
660658

661659
template <typename vtype, typename type_t>
662660
static void
663-
qsort_32bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
661+
qselect_32bit_(type_t *arr, int64_t pos,
662+
int64_t left, int64_t right,
663+
int64_t max_iters)
664664
{
665665
/*
666666
* Resort to std::sort if quicksort isnt making any progress
@@ -682,10 +682,10 @@ qsort_32bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
682682
type_t biggest = vtype::type_min();
683683
int64_t pivot_index = partition_avx512<vtype>(
684684
arr, left, right + 1, pivot, &smallest, &biggest);
685-
if (pivot != smallest)
686-
qsort_32bit_<vtype>(arr, left, pivot_index - 1, max_iters - 1);
687-
if (pivot != biggest)
688-
qsort_32bit_<vtype>(arr, pivot_index, right, max_iters - 1);
685+
if ((pivot != smallest) && (pos < pivot_index))
686+
qselect_32bit_<vtype>(arr, pos, left, pivot_index - 1, max_iters - 1);
687+
else if ((pivot != biggest) && (pos >= pivot_index))
688+
qselect_32bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1);
689689
}
690690

691691
X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(float *arr, int64_t arrsize)

src/avx512-64bit-qsort.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ qsort_64bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
403403

404404
template <typename vtype, typename type_t>
405405
static void
406-
qselect_64bit_(type_t *arr, int64_t k,
406+
qselect_64bit_(type_t *arr, int64_t pos,
407407
int64_t left, int64_t right,
408408
int64_t max_iters)
409409
{
@@ -427,18 +427,18 @@ qselect_64bit_(type_t *arr, int64_t k,
427427
type_t biggest = vtype::type_min();
428428
int64_t pivot_index = partition_avx512<vtype>(
429429
arr, left, right + 1, pivot, &smallest, &biggest);
430-
if ((pivot != smallest) && (k <= pivot_index))
431-
qselect_64bit_<vtype>(arr, k, left, pivot_index - 1, max_iters - 1);
432-
else if ((pivot != biggest) && (k > pivot_index))
433-
qselect_64bit_<vtype>(arr, k, pivot_index, right, max_iters - 1);
430+
if ((pivot != smallest) && (pos < pivot_index))
431+
qselect_64bit_<vtype>(arr, pos, left, pivot_index - 1, max_iters - 1);
432+
else if ((pivot != biggest) && (pos >= pivot_index))
433+
qselect_64bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1);
434434
}
435435

436436
template <>
437437
void avx512_qselect<int64_t>(int64_t *arr, int64_t k, int64_t arrsize)
438438
{
439439
if (arrsize > 1) {
440440
qselect_64bit_<zmm_vector<int64_t>, int64_t>(
441-
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
441+
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
442442
}
443443
}
444444

src/avx512-common-qsort.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,14 +95,14 @@ void avx512_qselect(T *arr, int64_t k, int64_t arrsize);
9595
void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize);
9696

9797
template <typename T>
98-
void avx512_partial_qsort(T *arr, int64_t k, int64_t arrsize) {
99-
avx512_qselect<T>(arr, k, arrsize);
100-
avx512_qsort<T>(arr, k);
98+
inline void avx512_partial_qsort(T *arr, int64_t k, int64_t arrsize) {
99+
avx512_qselect<T>(arr, k - 1, arrsize);
100+
avx512_qsort<T>(arr, k - 1);
101101
}
102102
inline void avx512_partial_qsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize)
103103
{
104-
avx512_qselect_fp16(arr, k, arrsize);
105-
avx512_qsort_fp16(arr, k);
104+
avx512_qselect_fp16(arr, k - 1, arrsize);
105+
avx512_qsort_fp16(arr, k - 1);
106106
}
107107

108108
template <typename vtype, typename T = typename vtype::type_t>

tests/test_qselect.hpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,17 @@ TYPED_TEST_P(avx512_select, test_arrsizes)
2626
std::sort(sortedarr.begin(), sortedarr.end());
2727
for (size_t k = 0; k < arr.size(); ++k) {
2828
psortedarr = arr;
29-
avx512_qselect<TypeParam>(psortedarr.data(), k+1, psortedarr.size());
29+
avx512_qselect<TypeParam>(psortedarr.data(), k, psortedarr.size());
30+
/* index k is correct */
3031
ASSERT_EQ(sortedarr[k], psortedarr[k]);
32+
/* Check left partition */
33+
for (size_t jj = 0; jj < k; jj++) {
34+
ASSERT_LE(psortedarr[jj], psortedarr[k]);
35+
}
36+
/* Check right partition */
37+
for (size_t jj = k+1; jj < arr.size(); jj++) {
38+
ASSERT_GE(psortedarr[jj], psortedarr[k]);
39+
}
3140
psortedarr.clear();
3241
}
3342
arr.clear();

tests/test_qsortfp16.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,17 @@ TEST(avx512_qselect_float16, test_arrsizes)
9595
std::sort(sortedarr.begin(), sortedarr.end());
9696
for (size_t k = 0; k < arr.size(); ++k) {
9797
psortedarr = arr;
98-
avx512_qselect<_Float16>(psortedarr.data(), k+1, psortedarr.size());
98+
avx512_qselect<_Float16>(psortedarr.data(), k, psortedarr.size());
99+
/* index k is correct */
99100
ASSERT_EQ(sortedarr[k], psortedarr[k]);
101+
/* Check left partition */
102+
for (size_t jj = 0; jj < k; jj++) {
103+
ASSERT_LE(psortedarr[jj], psortedarr[k]);
104+
}
105+
/* Check right partition */
106+
for (size_t jj = k+1; jj < arr.size(); jj++) {
107+
ASSERT_GE(psortedarr[jj], psortedarr[k]);
108+
}
100109
psortedarr.clear();
101110
}
102111
arr.clear();

0 commit comments

Comments
 (0)