@@ -348,113 +348,57 @@ static void argselect_64bit_(type_t *arr,
348
348
template <typename T>
349
349
void avx512_argsort (T *arr, int64_t *arg, int64_t arrsize)
350
350
{
351
+ using vectype = typename std::conditional<sizeof (T) == sizeof (int32_t ),
352
+ ymm_vector<T>,
353
+ zmm_vector<T>>::type;
351
354
if (arrsize > 1 ) {
352
- argsort_64bit_<zmm_vector<T>>(
353
- arr, arg, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
354
- }
355
- }
356
-
357
- template <>
358
- void avx512_argsort (double *arr, int64_t *arg, int64_t arrsize)
359
- {
360
- if (arrsize > 1 ) {
361
- if (has_nan<zmm_vector<double >>(arr, arrsize)) {
362
- std_argsort_withnan (arr, arg, 0 , arrsize);
355
+ if constexpr (std::is_floating_point_v<T>) {
356
+ if (has_nan<vectype>(arr, arrsize)) {
357
+ std_argsort_withnan (arr, arg, 0 , arrsize);
358
+ return ;
359
+ }
363
360
}
364
- else {
365
- argsort_64bit_<zmm_vector<double >>(
366
- arr, arg, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
367
- }
368
- }
369
- }
370
-
371
- template <>
372
- void avx512_argsort (int32_t *arr, int64_t *arg, int64_t arrsize)
373
- {
374
- if (arrsize > 1 ) {
375
- argsort_64bit_<ymm_vector<int32_t >>(
361
+ argsort_64bit_<vectype>(
376
362
arr, arg, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
377
363
}
378
364
}
379
365
380
/*
 * Convenience overload of avx512_argsort that owns the index buffer.
 *
 * Builds a vector holding 0, 1, ..., arrsize - 1, passes it to the
 * pointer-based avx512_argsort<T>(arr, arg, arrsize) overload and returns it.
 *
 * arr     : pointer to the input elements (forwarded unchanged).
 * arrsize : number of elements in arr.
 * returns : the index vector after the pointer-based overload has run; an
 *           empty vector when arrsize <= 0.
 */
template <typename T>
std::vector<int64_t> avx512_argsort(T *arr, int64_t arrsize)
{
    // A negative count would be implicitly converted to a huge unsigned size
    // by the vector constructor and throw. Return an empty result instead,
    // in line with the pointer overload doing nothing for arrsize <= 1.
    if (arrsize <= 0) { return {}; }

    std::vector<int64_t> indices(
            static_cast<std::vector<int64_t>::size_type>(arrsize));
    // Seed with int64_t so the running counter cannot overflow `int` when
    // the array holds more than INT_MAX elements.
    std::iota(indices.begin(), indices.end(), int64_t {0});
    avx512_argsort<T>(arr, indices.data(), arrsize);
    return indices;
}
402
374
403
375
/* argselect methods for 32-bit and 64-bit dtypes */
404
376
template <typename T>
405
377
void avx512_argselect (T *arr, int64_t *arg, int64_t k, int64_t arrsize)
406
378
{
407
- if (arrsize > 1 ) {
408
- argselect_64bit_<zmm_vector<T>>(
409
- arr, arg, k, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
410
- }
411
- }
379
+ using vectype = typename std::conditional<sizeof (T) == sizeof (int32_t ),
380
+ ymm_vector<T>,
381
+ zmm_vector<T>>::type;
412
382
413
- template <>
414
- void avx512_argselect (double *arr, int64_t *arg, int64_t k, int64_t arrsize)
415
- {
416
383
if (arrsize > 1 ) {
417
- if (has_nan<zmm_vector<double >>(arr, arrsize)) {
418
- std_argselect_withnan (arr, arg, k, 0 , arrsize);
419
- }
420
- else {
421
- argselect_64bit_<zmm_vector<double >>(
422
- arr, arg, k, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
384
+ if constexpr (std::is_floating_point_v<T>) {
385
+ if (has_nan<vectype>(arr, arrsize)) {
386
+ std_argselect_withnan (arr, arg, k, 0 , arrsize);
387
+ return ;
388
+ }
423
389
}
424
- }
425
- }
426
-
427
- template <>
428
- void avx512_argselect (int32_t *arr, int64_t *arg, int64_t k, int64_t arrsize)
429
- {
430
- if (arrsize > 1 ) {
431
- argselect_64bit_<ymm_vector<int32_t >>(
390
+ argselect_64bit_<vectype>(
432
391
arr, arg, k, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
433
392
}
434
393
}
435
394
436
/*
 * Convenience overload of avx512_argselect that owns the index buffer.
 *
 * Builds a vector holding 0, 1, ..., arrsize - 1, passes it to the
 * pointer-based avx512_argselect<T>(arr, arg, k, arrsize) overload and
 * returns it.
 *
 * arr     : pointer to the input elements (forwarded unchanged).
 * k       : selection position, forwarded unchanged (not validated here).
 * arrsize : number of elements in arr.
 * returns : the index vector after the pointer-based overload has run; an
 *           empty vector when arrsize <= 0.
 */
template <typename T>
std::vector<int64_t> avx512_argselect(T *arr, int64_t k, int64_t arrsize)
{
    // A negative count would be implicitly converted to a huge unsigned size
    // by the vector constructor and throw. Return an empty result instead,
    // in line with the pointer overload doing nothing for arrsize <= 1.
    if (arrsize <= 0) { return {}; }

    std::vector<int64_t> indices(
            static_cast<std::vector<int64_t>::size_type>(arrsize));
    // Seed with int64_t so the running counter cannot overflow `int` when
    // the array holds more than INT_MAX elements.
    std::iota(indices.begin(), indices.end(), int64_t {0});
    avx512_argselect<T>(arr, indices.data(), k, arrsize);
    return indices;
}
458
403
459
-
460
404
#endif // AVX512_ARGSORT_64BIT
0 commit comments