Skip to content

Commit 06d31e7

Browse files
authored
Merge pull request #142 from r-devulap/kv-sort-tests
Fix bug while processing NAN in key-value sort
2 parents 64b1e27 + 9fe775a commit 06d31e7

File tree

3 files changed

+51
-54
lines changed

3 files changed

+51
-54
lines changed

src/avx512-64bit-keyvaluesort.hpp

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,13 @@ template <typename vtype1,
5757
typename type_t2 = typename vtype2::type_t,
5858
typename reg_t1 = typename vtype1::reg_t,
5959
typename reg_t2 = typename vtype2::reg_t>
60-
X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t1 *keys,
61-
type_t2 *indexes,
62-
arrsize_t left,
63-
arrsize_t right,
64-
type_t1 pivot,
65-
type_t1 *smallest,
66-
type_t1 *biggest)
60+
X86_SIMD_SORT_INLINE arrsize_t kvpartition(type_t1 *keys,
61+
type_t2 *indexes,
62+
arrsize_t left,
63+
arrsize_t right,
64+
type_t1 pivot,
65+
type_t1 *smallest,
66+
type_t1 *biggest)
6767
{
6868
/* make array length divisible by vtype1::numlanes, shortening the array */
6969
for (int32_t i = (right - left) % vtype1::numlanes; i > 0; --i) {
@@ -189,16 +189,16 @@ template <typename vtype1,
189189
typename type_t2 = typename vtype2::type_t,
190190
typename reg_t1 = typename vtype1::reg_t,
191191
typename reg_t2 = typename vtype2::reg_t>
192-
X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t1 *keys,
193-
type_t2 *indexes,
194-
arrsize_t left,
195-
arrsize_t right,
196-
type_t1 pivot,
197-
type_t1 *smallest,
198-
type_t1 *biggest)
192+
X86_SIMD_SORT_INLINE arrsize_t kvpartition_unrolled(type_t1 *keys,
193+
type_t2 *indexes,
194+
arrsize_t left,
195+
arrsize_t right,
196+
type_t1 pivot,
197+
type_t1 *smallest,
198+
type_t1 *biggest)
199199
{
200200
if (right - left <= 8 * num_unroll * vtype1::numlanes) {
201-
return partition_avx512<vtype1, vtype2>(
201+
return kvpartition<vtype1, vtype2>(
202202
keys, indexes, left, right, pivot, smallest, biggest);
203203
}
204204
/* make array length divisible by vtype1::numlanes, shortening the array */
@@ -391,7 +391,7 @@ X86_SIMD_SORT_INLINE void qsort_64bit_(type1_t *keys,
391391
type1_t pivot = get_pivot_blocks<vtype1>(keys, left, right);
392392
type1_t smallest = vtype1::type_max();
393393
type1_t biggest = vtype1::type_min();
394-
arrsize_t pivot_index = partition_avx512_unrolled<vtype1, vtype2, 4>(
394+
arrsize_t pivot_index = kvpartition_unrolled<vtype1, vtype2, 4>(
395395
keys, indexes, left, right + 1, pivot, &smallest, &biggest);
396396
if (pivot != smallest) {
397397
qsort_64bit_<vtype1, vtype2>(
@@ -422,8 +422,7 @@ avx512_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan = false)
422422
if constexpr (std::is_floating_point_v<T1>) {
423423
arrsize_t nan_count = 0;
424424
if (UNLIKELY(hasnan)) {
425-
nan_count = replace_nan_with_inf<zmm_vector<double>>(keys,
426-
arrsize);
425+
nan_count = replace_nan_with_inf<zmm_vector<T1>>(keys, arrsize);
427426
}
428427
qsort_64bit_<keytype, valtype>(keys,
429428
indexes,

src/xss-common-argsort.h

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -173,13 +173,13 @@ X86_SIMD_SORT_INLINE int32_t partition_vec(type_t *arg,
173173
* last element that is less than equal to the pivot.
174174
*/
175175
template <typename vtype, typename argtype, typename type_t>
176-
X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t *arr,
177-
arrsize_t *arg,
178-
arrsize_t left,
179-
arrsize_t right,
180-
type_t pivot,
181-
type_t *smallest,
182-
type_t *biggest)
176+
X86_SIMD_SORT_INLINE arrsize_t argpartition(type_t *arr,
177+
arrsize_t *arg,
178+
arrsize_t left,
179+
arrsize_t right,
180+
type_t pivot,
181+
type_t *smallest,
182+
type_t *biggest)
183183
{
184184
/* make array length divisible by vtype::numlanes, shortening the array */
185185
for (int32_t i = (right - left) % vtype::numlanes; i > 0; --i) {
@@ -292,16 +292,16 @@ template <typename vtype,
292292
typename argtype,
293293
int num_unroll,
294294
typename type_t = typename vtype::type_t>
295-
X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr,
296-
arrsize_t *arg,
297-
arrsize_t left,
298-
arrsize_t right,
299-
type_t pivot,
300-
type_t *smallest,
301-
type_t *biggest)
295+
X86_SIMD_SORT_INLINE arrsize_t argpartition_unrolled(type_t *arr,
296+
arrsize_t *arg,
297+
arrsize_t left,
298+
arrsize_t right,
299+
type_t pivot,
300+
type_t *smallest,
301+
type_t *biggest)
302302
{
303303
if (right - left <= 8 * num_unroll * vtype::numlanes) {
304-
return partition_avx512<vtype, argtype>(
304+
return argpartition<vtype, argtype>(
305305
arr, arg, left, right, pivot, smallest, biggest);
306306
}
307307
/* make array length divisible by vtype::numlanes, shortening the array */
@@ -493,7 +493,7 @@ X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr,
493493
type_t pivot = get_pivot_64bit<vtype>(arr, arg, left, right);
494494
type_t smallest = vtype::type_max();
495495
type_t biggest = vtype::type_min();
496-
arrsize_t pivot_index = partition_avx512_unrolled<vtype, argtype, 4>(
496+
arrsize_t pivot_index = argpartition_unrolled<vtype, argtype, 4>(
497497
arr, arg, left, right + 1, pivot, &smallest, &biggest);
498498
if (pivot != smallest)
499499
argsort_64bit_<vtype, argtype>(
@@ -529,7 +529,7 @@ X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr,
529529
type_t pivot = get_pivot_64bit<vtype>(arr, arg, left, right);
530530
type_t smallest = vtype::type_max();
531531
type_t biggest = vtype::type_min();
532-
arrsize_t pivot_index = partition_avx512_unrolled<vtype, argtype, 4>(
532+
arrsize_t pivot_index = argpartition_unrolled<vtype, argtype, 4>(
533533
arr, arg, left, right + 1, pivot, &smallest, &biggest);
534534
if ((pivot != smallest) && (pos < pivot_index))
535535
argselect_64bit_<vtype, argtype>(

src/xss-common-qsort.h

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -206,12 +206,12 @@ X86_SIMD_SORT_INLINE arrsize_t partition_vec(type_t *l_store,
206206
* first element that is greater than or equal to the pivot.
207207
*/
208208
template <typename vtype, typename type_t>
209-
X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t *arr,
210-
arrsize_t left,
211-
arrsize_t right,
212-
type_t pivot,
213-
type_t *smallest,
214-
type_t *biggest)
209+
X86_SIMD_SORT_INLINE arrsize_t partition(type_t *arr,
210+
arrsize_t left,
211+
arrsize_t right,
212+
type_t pivot,
213+
type_t *smallest,
214+
type_t *biggest)
215215
{
216216
/* make array length divisible by vtype::numlanes, shortening the array */
217217
for (int32_t i = (right - left) % vtype::numlanes; i > 0; --i) {
@@ -316,22 +316,20 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t *arr,
316316
template <typename vtype,
317317
int num_unroll,
318318
typename type_t = typename vtype::type_t>
319-
X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr,
320-
arrsize_t left,
321-
arrsize_t right,
322-
type_t pivot,
323-
type_t *smallest,
324-
type_t *biggest)
319+
X86_SIMD_SORT_INLINE arrsize_t partition_unrolled(type_t *arr,
320+
arrsize_t left,
321+
arrsize_t right,
322+
type_t pivot,
323+
type_t *smallest,
324+
type_t *biggest)
325325
{
326326
if constexpr (num_unroll == 0) {
327-
return partition_avx512<vtype>(
328-
arr, left, right, pivot, smallest, biggest);
327+
return partition<vtype>(arr, left, right, pivot, smallest, biggest);
329328
}
330329

331-
/* Use regular partition_avx512 for smaller arrays */
330+
/* Use regular partition for smaller arrays */
332331
if (right - left < 3 * num_unroll * vtype::numlanes) {
333-
return partition_avx512<vtype>(
334-
arr, left, right, pivot, smallest, biggest);
332+
return partition<vtype>(arr, left, right, pivot, smallest, biggest);
335333
}
336334

337335
/* make array length divisible by vtype::numlanes, shortening the array */
@@ -509,7 +507,7 @@ qsort_(type_t *arr, arrsize_t left, arrsize_t right, arrsize_t max_iters)
509507
type_t biggest = vtype::type_min();
510508

511509
arrsize_t pivot_index
512-
= partition_avx512_unrolled<vtype, vtype::partition_unroll_factor>(
510+
= partition_unrolled<vtype, vtype::partition_unroll_factor>(
513511
arr, left, right + 1, pivot, &smallest, &biggest);
514512

515513
if (pivot_result.result == pivot_result_t::Only2Values) { return; }
@@ -547,7 +545,7 @@ X86_SIMD_SORT_INLINE void qselect_(type_t *arr,
547545
type_t biggest = vtype::type_min();
548546

549547
arrsize_t pivot_index
550-
= partition_avx512_unrolled<vtype, vtype::partition_unroll_factor>(
548+
= partition_unrolled<vtype, vtype::partition_unroll_factor>(
551549
arr, left, right + 1, pivot, &smallest, &biggest);
552550

553551
if ((pivot != smallest) && (pos < pivot_index))

0 commit comments

Comments
 (0)