8
8
#define AVX512_ARGSORT_64BIT
9
9
10
10
#include " avx512-64bit-common.h"
11
- #include " avx512-common-argsort.h"
12
11
#include " avx512-64bit-keyvalue-networks.hpp"
12
+ #include " avx512-common-argsort.h"
13
+
14
/*
 * argselect using std::nth_element, for input that may contain NaN.
 *
 * Partially orders the indices stored in arg[left, right) by the values they
 * refer to in arr, so that arg[k] ends up holding the index that would sit at
 * position k if the range were fully sorted. NaN values are ordered after
 * every non-NaN value.
 */
template <typename T>
void std_argselect_withnan(
        T *arr, int64_t *arg, int64_t k, int64_t left, int64_t right)
{
    // "less than" where NaN is never smaller than anything, and every
    // non-NaN value is smaller than NaN.
    auto nan_last_less = [arr](int64_t a, int64_t b) -> bool {
        if (std::isnan(arr[a])) { return false; }
        if (std::isnan(arr[b])) { return true; }
        return arr[a] < arr[b];
    };
    std::nth_element(arg + left, arg + k, arg + right, nan_last_less);
}
13
33
14
34
/* argsort using std::sort */
15
35
template <typename T>
@@ -18,9 +38,15 @@ void std_argsort_withnan(T *arr, int64_t *arg, int64_t left, int64_t right)
18
38
std::sort (arg + left,
19
39
arg + right,
20
40
[arr](int64_t left, int64_t right) -> bool {
21
- if ((!std::isnan (arr[left])) && (!std::isnan (arr[right]))) {return arr[left] < arr[right];}
22
- else if (std::isnan (arr[left])) {return false ;}
23
- else {return true ;}
41
+ if ((!std::isnan (arr[left])) && (!std::isnan (arr[right]))) {
42
+ return arr[left] < arr[right];
43
+ }
44
+ else if (std::isnan (arr[left])) {
45
+ return false ;
46
+ }
47
+ else {
48
+ return true ;
49
+ }
24
50
});
25
51
}
26
52
@@ -284,7 +310,42 @@ inline void argsort_64bit_(type_t *arr,
284
310
}
285
311
286
312
template <typename vtype, typename type_t >
287
- bool has_nan (type_t * arr, int64_t arrsize)
313
+ static void argselect_64bit_ (type_t *arr,
314
+ int64_t *arg,
315
+ int64_t pos,
316
+ int64_t left,
317
+ int64_t right,
318
+ int64_t max_iters)
319
+ {
320
+ /*
321
+ * Resort to std::sort if quicksort isnt making any progress
322
+ */
323
+ if (max_iters <= 0 ) {
324
+ std_argsort (arr, arg, left, right + 1 );
325
+ return ;
326
+ }
327
+ /*
328
+ * Base case: use bitonic networks to sort arrays <= 64
329
+ */
330
+ if (right + 1 - left <= 64 ) {
331
+ argsort_64_64bit<vtype>(arr, arg + left, (int32_t )(right + 1 - left));
332
+ return ;
333
+ }
334
+ type_t pivot = get_pivot_64bit<vtype>(arr, arg, left, right);
335
+ type_t smallest = vtype::type_max ();
336
+ type_t biggest = vtype::type_min ();
337
+ int64_t pivot_index = partition_avx512_unrolled<vtype, 4 >(
338
+ arr, arg, left, right + 1 , pivot, &smallest, &biggest);
339
+ if ((pivot != smallest) && (pos < pivot_index))
340
+ argselect_64bit_<vtype>(
341
+ arr, arg, pos, left, pivot_index - 1 , max_iters - 1 );
342
+ else if ((pivot != biggest) && (pos >= pivot_index))
343
+ argselect_64bit_<vtype>(
344
+ arr, arg, pos, pivot_index, right, max_iters - 1 );
345
+ }
346
+
347
+ template <typename vtype, typename type_t >
348
+ bool has_nan (type_t *arr, int64_t arrsize)
288
349
{
289
350
using opmask_t = typename vtype::opmask_t ;
290
351
using zmm_t = typename vtype::zmm_t ;
@@ -299,7 +360,7 @@ bool has_nan(type_t* arr, int64_t arrsize)
299
360
else {
300
361
in = vtype::loadu (arr);
301
362
}
302
- opmask_t nanmask = vtype::template fpclass<0x01 | 0x80 >(in);
363
+ opmask_t nanmask = vtype::template fpclass<0x01 | 0x80 >(in);
303
364
arr += vtype::numlanes;
304
365
arrsize -= vtype::numlanes;
305
366
if (nanmask != 0x00 ) {
@@ -310,8 +371,9 @@ bool has_nan(type_t* arr, int64_t arrsize)
310
371
return found_nan;
311
372
}
312
373
374
+ /* argsort methods for 32-bit and 64-bit dtypes */
313
375
template <typename T>
314
- void avx512_argsort (T* arr, int64_t *arg, int64_t arrsize)
376
+ void avx512_argsort (T * arr, int64_t *arg, int64_t arrsize)
315
377
{
316
378
if (arrsize > 1 ) {
317
379
argsort_64bit_<zmm_vector<T>>(
@@ -320,7 +382,7 @@ void avx512_argsort(T* arr, int64_t *arg, int64_t arrsize)
320
382
}
321
383
322
384
template <>
323
- void avx512_argsort (double * arr, int64_t *arg, int64_t arrsize)
385
+ void avx512_argsort (double * arr, int64_t *arg, int64_t arrsize)
324
386
{
325
387
if (arrsize > 1 ) {
326
388
if (has_nan<zmm_vector<double >>(arr, arrsize)) {
@@ -333,9 +395,8 @@ void avx512_argsort(double* arr, int64_t *arg, int64_t arrsize)
333
395
}
334
396
}
335
397
336
-
337
398
template <>
338
- void avx512_argsort (int32_t * arr, int64_t *arg, int64_t arrsize)
399
+ void avx512_argsort (int32_t * arr, int64_t *arg, int64_t arrsize)
339
400
{
340
401
if (arrsize > 1 ) {
341
402
argsort_64bit_<ymm_vector<int32_t >>(
@@ -344,7 +405,7 @@ void avx512_argsort(int32_t* arr, int64_t *arg, int64_t arrsize)
344
405
}
345
406
346
407
template <>
347
- void avx512_argsort (uint32_t * arr, int64_t *arg, int64_t arrsize)
408
+ void avx512_argsort (uint32_t * arr, int64_t *arg, int64_t arrsize)
348
409
{
349
410
if (arrsize > 1 ) {
350
411
argsort_64bit_<ymm_vector<uint32_t >>(
@@ -353,7 +414,7 @@ void avx512_argsort(uint32_t* arr, int64_t *arg, int64_t arrsize)
353
414
}
354
415
355
416
template <>
356
- void avx512_argsort (float * arr, int64_t *arg, int64_t arrsize)
417
+ void avx512_argsort (float * arr, int64_t *arg, int64_t arrsize)
357
418
{
358
419
if (arrsize > 1 ) {
359
420
if (has_nan<ymm_vector<float >>(arr, arrsize)) {
@@ -367,12 +428,77 @@ void avx512_argsort(float* arr, int64_t *arg, int64_t arrsize)
367
428
}
368
429
369
430
/*
 * Convenience overload: builds the identity index vector 0, 1, ...,
 * arrsize - 1, passes it to the pointer-based avx512_argsort<T> and returns
 * the resulting indices.
 */
template <typename T>
std::vector<int64_t> avx512_argsort(T *arr, int64_t arrsize)
{
    std::vector<int64_t> indices(arrsize);
    // Seed with an int64_t so the running counter has the element type;
    // a plain 0 makes std::iota count in int, which overflows once
    // arrsize exceeds INT_MAX.
    std::iota(indices.begin(), indices.end(), int64_t {0});
    avx512_argsort<T>(arr, indices.data(), arrsize);
    return indices;
}
377
438
439
+ /* argselect methods for 32-bit and 64-bit dtypes */
440
+ template <typename T>
441
+ void avx512_argselect (T *arr, int64_t *arg, int64_t k, int64_t arrsize)
442
+ {
443
+ if (arrsize > 1 ) {
444
+ argselect_64bit_<zmm_vector<T>>(
445
+ arr, arg, k, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
446
+ }
447
+ }
448
+
449
+ template <>
450
+ void avx512_argselect (double *arr, int64_t *arg, int64_t k, int64_t arrsize)
451
+ {
452
+ if (arrsize > 1 ) {
453
+ if (has_nan<zmm_vector<double >>(arr, arrsize)) {
454
+ std_argselect_withnan (arr, arg, k, 0 , arrsize);
455
+ }
456
+ else {
457
+ argselect_64bit_<zmm_vector<double >>(
458
+ arr, arg, k, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
459
+ }
460
+ }
461
+ }
462
+
463
+ template <>
464
+ void avx512_argselect (int32_t *arr, int64_t *arg, int64_t k, int64_t arrsize)
465
+ {
466
+ if (arrsize > 1 ) {
467
+ argselect_64bit_<ymm_vector<int32_t >>(
468
+ arr, arg, k, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
469
+ }
470
+ }
471
+
472
+ template <>
473
+ void avx512_argselect (uint32_t *arr, int64_t *arg, int64_t k, int64_t arrsize)
474
+ {
475
+ if (arrsize > 1 ) {
476
+ argselect_64bit_<ymm_vector<uint32_t >>(
477
+ arr, arg, k, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
478
+ }
479
+ }
480
+
481
+ template <>
482
+ void avx512_argselect (float *arr, int64_t *arg, int64_t k, int64_t arrsize)
483
+ {
484
+ if (arrsize > 1 ) {
485
+ if (has_nan<ymm_vector<float >>(arr, arrsize)) {
486
+ std_argselect_withnan (arr, arg, k, 0 , arrsize);
487
+ }
488
+ else {
489
+ argselect_64bit_<ymm_vector<float >>(
490
+ arr, arg, k, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
491
+ }
492
+ }
493
+ }
494
+
495
/*
 * Convenience overload: builds the identity index vector 0, 1, ...,
 * arrsize - 1, passes it to the pointer-based avx512_argselect<T> for
 * position k and returns the resulting indices.
 */
template <typename T>
std::vector<int64_t> avx512_argselect(T *arr, int64_t k, int64_t arrsize)
{
    std::vector<int64_t> indices(arrsize);
    // Seed with an int64_t so the running counter has the element type;
    // a plain 0 makes std::iota count in int, which overflows once
    // arrsize exceeds INT_MAX.
    std::iota(indices.begin(), indices.end(), int64_t {0});
    avx512_argselect<T>(arr, indices.data(), k, arrsize);
    return indices;
}
503
+
378
504
#endif // AVX512_ARGSORT_64BIT
0 commit comments