Commit 28cfa14

Author: Raghuveer Devulapalli
Merge pull request #56 from r-devulap/argselect
Add avx512_argselect for 32-bit and 64-bit dtypes
2 parents f22807a + 0ae431a commit 28cfa14

13 files changed: +617 −365 lines
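
The headline addition is a new avx512_argselect API that mirrors the existing avx512_argsort entry points: it partially reorders an index array so that arg[k] ends up holding the index of the k-th smallest element. A minimal usage sketch, under assumptions not in the diff (an AVX512-capable build with src/avx512-64bit-argsort.hpp on the include path; the main() and test data are illustrative only):

// Usage sketch (illustrative, not part of the commit): find the index of
// the k-th smallest element without fully sorting the array.
#include "avx512-64bit-argsort.hpp"

#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
    std::vector<double> data {9.0, 2.0, 7.0, 4.0, 1.0, 8.0};
    int64_t k = 2;
    // indices[k] now holds the position of the k-th smallest value;
    // entries before/after k are only partitioned, not sorted.
    std::vector<int64_t> indices
            = avx512_argselect(data.data(), k, (int64_t)data.size());
    std::printf("k-th smallest: %f\n", data[indices[k]]); // prints 4.000000
    return 0;
}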

src/avx512-16bit-qsort.hpp

Lines changed: 2 additions & 2 deletions (whitespace-only re-indentation)

@@ -433,11 +433,11 @@ void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
 {
     int64_t indx_last_elem = arrsize - 1;
     if (UNLIKELY(hasnan)) {
-        indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
+        indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
     }
     if (indx_last_elem >= k) {
         qselect_16bit_<zmm_vector<float16>, uint16_t>(
-                arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
+                arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
     }
 }

src/avx512-32bit-qsort.hpp

Lines changed: 10 additions & 4 deletions

@@ -715,7 +715,10 @@ replace_inf_with_nan(float *arr, int64_t arrsize, int64_t nan_count)
 }
 
 template <>
-void avx512_qselect<int32_t>(int32_t *arr, int64_t k, int64_t arrsize, bool hasnan)
+void avx512_qselect<int32_t>(int32_t *arr,
+                             int64_t k,
+                             int64_t arrsize,
+                             bool hasnan)
 {
     if (arrsize > 1) {
         qselect_32bit_<zmm_vector<int32_t>, int32_t>(
@@ -724,7 +727,10 @@ void avx512_qselect<int32_t>(int32_t *arr, int64_t k, int64_t arrsize, bool hasn
 }
 
 template <>
-void avx512_qselect<uint32_t>(uint32_t *arr, int64_t k, int64_t arrsize, bool hasnan)
+void avx512_qselect<uint32_t>(uint32_t *arr,
+                              int64_t k,
+                              int64_t arrsize,
+                              bool hasnan)
 {
     if (arrsize > 1) {
         qselect_32bit_<zmm_vector<uint32_t>, uint32_t>(
@@ -737,11 +743,11 @@ void avx512_qselect<float>(float *arr, int64_t k, int64_t arrsize, bool hasnan)
 {
     int64_t indx_last_elem = arrsize - 1;
     if (UNLIKELY(hasnan)) {
-        indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
+        indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
     }
     if (indx_last_elem >= k) {
         qselect_32bit_<zmm_vector<float>, float>(
-                arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
+                arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
     }
 }

src/avx512-64bit-argsort.hpp

Lines changed: 139 additions & 13 deletions

@@ -8,8 +8,28 @@
 #define AVX512_ARGSORT_64BIT
 
 #include "avx512-64bit-common.h"
-#include "avx512-common-argsort.h"
 #include "avx512-64bit-keyvalue-networks.hpp"
+#include "avx512-common-argsort.h"
+
+template <typename T>
+void std_argselect_withnan(
+        T *arr, int64_t *arg, int64_t k, int64_t left, int64_t right)
+{
+    std::nth_element(arg + left,
+                     arg + k,
+                     arg + right,
+                     [arr](int64_t a, int64_t b) -> bool {
+                         if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) {
+                             return arr[a] < arr[b];
+                         }
+                         else if (std::isnan(arr[a])) {
+                             return false;
+                         }
+                         else {
+                             return true;
+                         }
+                     });
+}
 
 /* argsort using std::sort */
 template <typename T>
@@ -18,9 +38,15 @@ void std_argsort_withnan(T *arr, int64_t *arg, int64_t left, int64_t right)
     std::sort(arg + left,
               arg + right,
               [arr](int64_t left, int64_t right) -> bool {
-                  if ((!std::isnan(arr[left])) && (!std::isnan(arr[right]))) {return arr[left] < arr[right];}
-                  else if (std::isnan(arr[left])) {return false;}
-                  else {return true;}
+                  if ((!std::isnan(arr[left])) && (!std::isnan(arr[right]))) {
+                      return arr[left] < arr[right];
+                  }
+                  else if (std::isnan(arr[left])) {
+                      return false;
+                  }
+                  else {
+                      return true;
+                  }
               });
 }
 
@@ -284,7 +310,42 @@ inline void argsort_64bit_(type_t *arr,
 }
 
 template <typename vtype, typename type_t>
-bool has_nan(type_t* arr, int64_t arrsize)
+static void argselect_64bit_(type_t *arr,
+                             int64_t *arg,
+                             int64_t pos,
+                             int64_t left,
+                             int64_t right,
+                             int64_t max_iters)
+{
+    /*
+     * Resort to std::sort if quicksort isn't making any progress
+     */
+    if (max_iters <= 0) {
+        std_argsort(arr, arg, left, right + 1);
+        return;
+    }
+    /*
+     * Base case: use bitonic networks to sort arrays <= 64
+     */
+    if (right + 1 - left <= 64) {
+        argsort_64_64bit<vtype>(arr, arg + left, (int32_t)(right + 1 - left));
+        return;
+    }
+    type_t pivot = get_pivot_64bit<vtype>(arr, arg, left, right);
+    type_t smallest = vtype::type_max();
+    type_t biggest = vtype::type_min();
+    int64_t pivot_index = partition_avx512_unrolled<vtype, 4>(
+            arr, arg, left, right + 1, pivot, &smallest, &biggest);
+    if ((pivot != smallest) && (pos < pivot_index))
+        argselect_64bit_<vtype>(
+                arr, arg, pos, left, pivot_index - 1, max_iters - 1);
+    else if ((pivot != biggest) && (pos >= pivot_index))
+        argselect_64bit_<vtype>(
+                arr, arg, pos, pivot_index, right, max_iters - 1);
+}
+
+template <typename vtype, typename type_t>
+bool has_nan(type_t *arr, int64_t arrsize)
 {
     using opmask_t = typename vtype::opmask_t;
     using zmm_t = typename vtype::zmm_t;
@@ -299,7 +360,7 @@ bool has_nan(type_t* arr, int64_t arrsize)
         else {
             in = vtype::loadu(arr);
         }
-        opmask_t nanmask = vtype::template fpclass<0x01|0x80>(in);
+        opmask_t nanmask = vtype::template fpclass<0x01 | 0x80>(in);
         arr += vtype::numlanes;
         arrsize -= vtype::numlanes;
         if (nanmask != 0x00) {
@@ -310,8 +371,9 @@ bool has_nan(type_t* arr, int64_t arrsize)
     return found_nan;
 }
 
+/* argsort methods for 32-bit and 64-bit dtypes */
 template <typename T>
-void avx512_argsort(T* arr, int64_t *arg, int64_t arrsize)
+void avx512_argsort(T *arr, int64_t *arg, int64_t arrsize)
 {
     if (arrsize > 1) {
         argsort_64bit_<zmm_vector<T>>(
@@ -320,7 +382,7 @@ void avx512_argsort(T* arr, int64_t *arg, int64_t arrsize)
 }
 
 template <>
-void avx512_argsort(double* arr, int64_t *arg, int64_t arrsize)
+void avx512_argsort(double *arr, int64_t *arg, int64_t arrsize)
 {
     if (arrsize > 1) {
         if (has_nan<zmm_vector<double>>(arr, arrsize)) {
@@ -333,9 +395,8 @@ void avx512_argsort(double* arr, int64_t *arg, int64_t arrsize)
     }
 }
 
-
 template <>
-void avx512_argsort(int32_t* arr, int64_t *arg, int64_t arrsize)
+void avx512_argsort(int32_t *arr, int64_t *arg, int64_t arrsize)
 {
     if (arrsize > 1) {
         argsort_64bit_<ymm_vector<int32_t>>(
@@ -344,7 +405,7 @@ void avx512_argsort(int32_t* arr, int64_t *arg, int64_t arrsize)
 }
 
 template <>
-void avx512_argsort(uint32_t* arr, int64_t *arg, int64_t arrsize)
+void avx512_argsort(uint32_t *arr, int64_t *arg, int64_t arrsize)
 {
     if (arrsize > 1) {
         argsort_64bit_<ymm_vector<uint32_t>>(
@@ -353,7 +414,7 @@ void avx512_argsort(uint32_t* arr, int64_t *arg, int64_t arrsize)
 }
 
 template <>
-void avx512_argsort(float* arr, int64_t *arg, int64_t arrsize)
+void avx512_argsort(float *arr, int64_t *arg, int64_t arrsize)
 {
     if (arrsize > 1) {
         if (has_nan<ymm_vector<float>>(arr, arrsize)) {
@@ -367,12 +428,77 @@ void avx512_argsort(float* arr, int64_t *arg, int64_t arrsize)
 }
 
 template <typename T>
-std::vector<int64_t> avx512_argsort(T* arr, int64_t arrsize)
+std::vector<int64_t> avx512_argsort(T *arr, int64_t arrsize)
 {
     std::vector<int64_t> indices(arrsize);
     std::iota(indices.begin(), indices.end(), 0);
     avx512_argsort<T>(arr, indices.data(), arrsize);
     return indices;
 }
 
+/* argselect methods for 32-bit and 64-bit dtypes */
+template <typename T>
+void avx512_argselect(T *arr, int64_t *arg, int64_t k, int64_t arrsize)
+{
+    if (arrsize > 1) {
+        argselect_64bit_<zmm_vector<T>>(
+                arr, arg, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
+    }
+}
+
+template <>
+void avx512_argselect(double *arr, int64_t *arg, int64_t k, int64_t arrsize)
+{
+    if (arrsize > 1) {
+        if (has_nan<zmm_vector<double>>(arr, arrsize)) {
+            std_argselect_withnan(arr, arg, k, 0, arrsize);
+        }
+        else {
+            argselect_64bit_<zmm_vector<double>>(
+                    arr, arg, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
+        }
+    }
+}
+
+template <>
+void avx512_argselect(int32_t *arr, int64_t *arg, int64_t k, int64_t arrsize)
+{
+    if (arrsize > 1) {
+        argselect_64bit_<ymm_vector<int32_t>>(
+                arr, arg, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
+    }
+}
+
+template <>
+void avx512_argselect(uint32_t *arr, int64_t *arg, int64_t k, int64_t arrsize)
+{
+    if (arrsize > 1) {
+        argselect_64bit_<ymm_vector<uint32_t>>(
+                arr, arg, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
+    }
+}
+
+template <>
+void avx512_argselect(float *arr, int64_t *arg, int64_t k, int64_t arrsize)
+{
+    if (arrsize > 1) {
+        if (has_nan<ymm_vector<float>>(arr, arrsize)) {
+            std_argselect_withnan(arr, arg, k, 0, arrsize);
+        }
+        else {
+            argselect_64bit_<ymm_vector<float>>(
+                    arr, arg, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
+        }
+    }
+}
+
+template <typename T>
+std::vector<int64_t> avx512_argselect(T *arr, int64_t k, int64_t arrsize)
+{
+    std::vector<int64_t> indices(arrsize);
+    std::iota(indices.begin(), indices.end(), 0);
+    avx512_argselect<T>(arr, indices.data(), k, arrsize);
+    return indices;
+}
+
 #endif // AVX512_ARGSORT_64BIT
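
One subtlety worth noting in the float and double overloads above: when has_nan detects NaNs, avx512_argselect falls back to std_argselect_withnan, whose comparator treats every NaN key as greater than any non-NaN key, so NaN indices collect at the tail. A self-contained sketch of just that comparator behavior (the test values are hypothetical; only std::nth_element and the comparator logic from the diff are exercised):

// Standalone check of the NaN-aware comparator used by std_argselect_withnan:
// NaN keys compare as "greater", so they gather at the end of the index range.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>
#include <vector>

int main()
{
    double nan = std::numeric_limits<double>::quiet_NaN();
    std::vector<double> arr {3.0, nan, 1.0, 2.0, nan};
    std::vector<int64_t> arg {0, 1, 2, 3, 4};
    int64_t k = 1; // select the index of the 2nd smallest non-NaN key
    std::nth_element(arg.begin(), arg.begin() + k, arg.end(),
                     [&arr](int64_t a, int64_t b) -> bool {
                         if (!std::isnan(arr[a]) && !std::isnan(arr[b]))
                             return arr[a] < arr[b];
                         else if (std::isnan(arr[a]))
                             return false; // a NaN key is never "less"
                         else
                             return true; // a non-NaN key beats a NaN key
                     });
    // Prints: arg[1] = 3 (value 2.000000)
    std::printf("arg[%lld] = %lld (value %f)\n",
                (long long)k, (long long)arg[k], arr[arg[k]]);
    return 0;
}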

src/avx512-64bit-keyvalue-networks.hpp

Lines changed: 32 additions & 32 deletions

@@ -136,14 +136,14 @@ X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm,
     typename vtype1::opmask_t movmask1 = vtype1::eq(key_zmm_t1, key_zmm[0]);
     typename vtype1::opmask_t movmask2 = vtype1::eq(key_zmm_t2, key_zmm[1]);
 
-    index_type index_zmm_t1 = vtype2::mask_mov(
-            index_zmm3r, movmask1, index_zmm[0]);
-    index_type index_zmm_m1 = vtype2::mask_mov(
-            index_zmm[0], movmask1, index_zmm3r);
-    index_type index_zmm_t2 = vtype2::mask_mov(
-            index_zmm2r, movmask2, index_zmm[1]);
-    index_type index_zmm_m2 = vtype2::mask_mov(
-            index_zmm[1], movmask2, index_zmm2r);
+    index_type index_zmm_t1
+            = vtype2::mask_mov(index_zmm3r, movmask1, index_zmm[0]);
+    index_type index_zmm_m1
+            = vtype2::mask_mov(index_zmm[0], movmask1, index_zmm3r);
+    index_type index_zmm_t2
+            = vtype2::mask_mov(index_zmm2r, movmask2, index_zmm[1]);
+    index_type index_zmm_m2
+            = vtype2::mask_mov(index_zmm[1], movmask2, index_zmm2r);
 
     // 2) Recursive half clearer: 16
     zmm_t key_zmm_t3 = vtype1::permutexvar(rev_index1, key_zmm_m2);
@@ -159,14 +159,14 @@ X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm,
     movmask1 = vtype1::eq(key_zmm0, key_zmm_t1);
     movmask2 = vtype1::eq(key_zmm2, key_zmm_t3);
 
-    index_type index_zmm0 = vtype2::mask_mov(
-            index_zmm_t2, movmask1, index_zmm_t1);
-    index_type index_zmm1 = vtype2::mask_mov(
-            index_zmm_t1, movmask1, index_zmm_t2);
-    index_type index_zmm2 = vtype2::mask_mov(
-            index_zmm_t4, movmask2, index_zmm_t3);
-    index_type index_zmm3 = vtype2::mask_mov(
-            index_zmm_t3, movmask2, index_zmm_t4);
+    index_type index_zmm0
+            = vtype2::mask_mov(index_zmm_t2, movmask1, index_zmm_t1);
+    index_type index_zmm1
+            = vtype2::mask_mov(index_zmm_t1, movmask1, index_zmm_t2);
+    index_type index_zmm2
+            = vtype2::mask_mov(index_zmm_t4, movmask2, index_zmm_t3);
+    index_type index_zmm3
+            = vtype2::mask_mov(index_zmm_t3, movmask2, index_zmm_t4);
 
     key_zmm[0] = bitonic_merge_zmm_64bit<vtype1, vtype2>(key_zmm0, index_zmm0);
     key_zmm[1] = bitonic_merge_zmm_64bit<vtype1, vtype2>(key_zmm1, index_zmm1);
@@ -212,22 +212,22 @@ X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *key_zmm,
     typename vtype1::opmask_t movmask3 = vtype1::eq(key_zmm_t3, key_zmm[2]);
     typename vtype1::opmask_t movmask4 = vtype1::eq(key_zmm_t4, key_zmm[3]);
 
-    index_type index_zmm_t1 = vtype2::mask_mov(
-            index_zmm7r, movmask1, index_zmm[0]);
-    index_type index_zmm_m1 = vtype2::mask_mov(
-            index_zmm[0], movmask1, index_zmm7r);
-    index_type index_zmm_t2 = vtype2::mask_mov(
-            index_zmm6r, movmask2, index_zmm[1]);
-    index_type index_zmm_m2 = vtype2::mask_mov(
-            index_zmm[1], movmask2, index_zmm6r);
-    index_type index_zmm_t3 = vtype2::mask_mov(
-            index_zmm5r, movmask3, index_zmm[2]);
-    index_type index_zmm_m3 = vtype2::mask_mov(
-            index_zmm[2], movmask3, index_zmm5r);
-    index_type index_zmm_t4 = vtype2::mask_mov(
-            index_zmm4r, movmask4, index_zmm[3]);
-    index_type index_zmm_m4 = vtype2::mask_mov(
-            index_zmm[3], movmask4, index_zmm4r);
+    index_type index_zmm_t1
+            = vtype2::mask_mov(index_zmm7r, movmask1, index_zmm[0]);
+    index_type index_zmm_m1
+            = vtype2::mask_mov(index_zmm[0], movmask1, index_zmm7r);
+    index_type index_zmm_t2
+            = vtype2::mask_mov(index_zmm6r, movmask2, index_zmm[1]);
+    index_type index_zmm_m2
+            = vtype2::mask_mov(index_zmm[1], movmask2, index_zmm6r);
+    index_type index_zmm_t3
+            = vtype2::mask_mov(index_zmm5r, movmask3, index_zmm[2]);
+    index_type index_zmm_m3
+            = vtype2::mask_mov(index_zmm[2], movmask3, index_zmm5r);
+    index_type index_zmm_t4
+            = vtype2::mask_mov(index_zmm4r, movmask4, index_zmm[3]);
+    index_type index_zmm_m4
+            = vtype2::mask_mov(index_zmm[3], movmask4, index_zmm4r);
 
     zmm_t key_zmm_t5 = vtype1::permutexvar(rev_index1, key_zmm_m4);
     zmm_t key_zmm_t6 = vtype1::permutexvar(rev_index1, key_zmm_m3);
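
For readers unfamiliar with the masked-move pattern being reflowed in this file: mask_mov is a thin wrapper that presumably maps to the AVX512 masked blend, which picks each lane from one vector where the corresponding mask bit is set and from the other vector elsewhere. A small standalone sketch using the raw intrinsic (hypothetical values, not from the commit; assumes an AVX512F-capable CPU and a flag like -mavx512f):

// Sketch of the masked blend assumed to underlie vtype2::mask_mov
// (illustrative only; requires AVX512F hardware to run).
#include <immintrin.h>
#include <cstdio>

int main()
{
    __m512i a = _mm512_set1_epi64(1); // lanes kept where mask bit is 0
    __m512i b = _mm512_set1_epi64(2); // lanes taken where mask bit is 1
    __mmask8 m = 0b10101010;
    // result lane i = ((m >> i) & 1) ? b[i] : a[i]
    __m512i r = _mm512_mask_mov_epi64(a, m, b);
    alignas(64) long long out[8];
    _mm512_store_epi64(out, r);
    for (int i = 0; i < 8; ++i)
        std::printf("%lld ", out[i]); // prints: 1 2 1 2 1 2 1 2
    std::printf("\n");
    return 0;
}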
