@@ -468,7 +468,7 @@ X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr,
468
468
}
469
469
470
470
template <typename vtype, typename argtype, typename type_t >
471
- X86_SIMD_SORT_INLINE void argsort_64bit_ (type_t *arr,
471
+ X86_SIMD_SORT_INLINE void argsort_ (type_t *arr,
472
472
arrsize_t *arg,
473
473
arrsize_t left,
474
474
arrsize_t right,
@@ -495,15 +495,15 @@ X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr,
495
495
arrsize_t pivot_index = argpartition_unrolled<vtype, argtype, 4 >(
496
496
arr, arg, left, right + 1 , pivot, &smallest, &biggest);
497
497
if (pivot != smallest)
498
- argsort_64bit_ <vtype, argtype>(
498
+ argsort_ <vtype, argtype>(
499
499
arr, arg, left, pivot_index - 1 , max_iters - 1 );
500
500
if (pivot != biggest)
501
- argsort_64bit_ <vtype, argtype>(
501
+ argsort_ <vtype, argtype>(
502
502
arr, arg, pivot_index, right, max_iters - 1 );
503
503
}
504
504
505
505
template <typename vtype, typename argtype, typename type_t >
506
- X86_SIMD_SORT_INLINE void argselect_64bit_ (type_t *arr,
506
+ X86_SIMD_SORT_INLINE void argselect_ (type_t *arr,
507
507
arrsize_t *arg,
508
508
arrsize_t pos,
509
509
arrsize_t left,
@@ -531,30 +531,34 @@ X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr,
531
531
arrsize_t pivot_index = argpartition_unrolled<vtype, argtype, 4 >(
532
532
arr, arg, left, right + 1 , pivot, &smallest, &biggest);
533
533
if ((pivot != smallest) && (pos < pivot_index))
534
- argselect_64bit_ <vtype, argtype>(
534
+ argselect_ <vtype, argtype>(
535
535
arr, arg, pos, left, pivot_index - 1 , max_iters - 1 );
536
536
else if ((pivot != biggest) && (pos >= pivot_index))
537
- argselect_64bit_ <vtype, argtype>(
537
+ argselect_ <vtype, argtype>(
538
538
arr, arg, pos, pivot_index, right, max_iters - 1 );
539
539
}
540
540
541
541
/* argsort methods for 32-bit and 64-bit dtypes */
542
- template <typename T>
543
- X86_SIMD_SORT_INLINE void avx512_argsort (T *arr,
542
+ template <typename T,
543
+ template <typename ...>
544
+ typename full_vector,
545
+ template <typename ...>
546
+ typename half_vector>
547
+ X86_SIMD_SORT_INLINE void xss_argsort (T *arr,
544
548
arrsize_t *arg,
545
549
arrsize_t arrsize,
546
550
bool hasnan = false ,
547
551
bool descending = false )
548
552
{
549
- /* TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */
553
+ /* TODO optimization: on 32-bit, use full_vector for 32-bit dtype */
550
554
using vectype = typename std::conditional<sizeof (T) == sizeof (int32_t ),
551
- ymm_vector <T>,
552
- zmm_vector <T>>::type;
555
+ half_vector <T>,
556
+ full_vector <T>>::type;
553
557
554
558
using argtype =
555
559
typename std::conditional<sizeof (arrsize_t ) == sizeof (int32_t ),
556
- ymm_vector <arrsize_t >,
557
- zmm_vector <arrsize_t >>::type;
560
+ half_vector <arrsize_t >,
561
+ full_vector <arrsize_t >>::type;
558
562
559
563
if (arrsize > 1 ) {
560
564
if constexpr (xss::fp::is_floating_point_v<T>) {
@@ -567,64 +571,54 @@ X86_SIMD_SORT_INLINE void avx512_argsort(T *arr,
567
571
}
568
572
}
569
573
UNUSED (hasnan);
570
- argsort_64bit_ <vectype, argtype>(
574
+ argsort_ <vectype, argtype>(
571
575
arr, arg, 0 , arrsize - 1 , 2 * (arrsize_t )log2 (arrsize));
572
576
573
577
if (descending) { std::reverse (arg, arg + arrsize); }
574
578
}
575
579
}
576
580
577
- /* argsort methods for 32-bit and 64-bit dtypes */
578
581
template <typename T>
579
- X86_SIMD_SORT_INLINE void avx2_argsort (T *arr,
580
- arrsize_t *arg,
581
- arrsize_t arrsize,
582
- bool hasnan = false ,
583
- bool descending = false )
582
+ X86_SIMD_SORT_INLINE void avx512_argsort (T *arr,
583
+ arrsize_t *arg,
584
+ arrsize_t arrsize,
585
+ bool hasnan = false ,
586
+ bool descending = false )
584
587
{
585
- using vectype = typename std::conditional<sizeof (T) == sizeof (int32_t ),
586
- avx2_half_vector<T>,
587
- avx2_vector<T>>::type;
588
-
589
- using argtype =
590
- typename std::conditional<sizeof (arrsize_t ) == sizeof (int32_t ),
591
- avx2_half_vector<arrsize_t >,
592
- avx2_vector<arrsize_t >>::type;
593
- if (arrsize > 1 ) {
594
- if constexpr (xss::fp::is_floating_point_v<T>) {
595
- if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
596
- std_argsort_withnan (arr, arg, 0 , arrsize);
597
-
598
- if (descending) { std::reverse (arg, arg + arrsize); }
599
-
600
- return ;
601
- }
602
- }
603
- UNUSED (hasnan);
604
- argsort_64bit_<vectype, argtype>(
605
- arr, arg, 0 , arrsize - 1 , 2 * (arrsize_t )log2 (arrsize));
588
+ xss_argsort<T, zmm_vector, ymm_vector>(arr, arg, arrsize, hasnan, descending);
589
+ }
606
590
607
- if (descending) { std::reverse (arg, arg + arrsize); }
608
- }
591
+ template <typename T>
592
+ X86_SIMD_SORT_INLINE void avx2_argsort (T *arr,
593
+ arrsize_t *arg,
594
+ arrsize_t arrsize,
595
+ bool hasnan = false ,
596
+ bool descending = false )
597
+ {
598
+ xss_argsort<T, avx2_vector, avx2_half_vector>(arr, arg, arrsize, hasnan, descending);
609
599
}
610
600
611
601
/* argselect methods for 32-bit and 64-bit dtypes */
612
- template <typename T>
613
- X86_SIMD_SORT_INLINE void avx512_argselect (T *arr,
602
+ template <typename T,
603
+ template <typename ...>
604
+ typename full_vector,
605
+ template <typename ...>
606
+ typename half_vector>
607
+ X86_SIMD_SORT_INLINE void xss_argselect (T *arr,
614
608
arrsize_t *arg,
615
609
arrsize_t k,
616
610
arrsize_t arrsize,
617
611
bool hasnan = false )
618
612
{
619
- /* TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */
613
+ /* TODO optimization: on 32-bit, use full_vector for 32-bit dtype */
620
614
using vectype = typename std::conditional<sizeof (T) == sizeof (int32_t ),
621
- ymm_vector <T>,
622
- zmm_vector <T>>::type;
615
+ half_vector <T>,
616
+ full_vector <T>>::type;
623
617
624
618
using argtype =
625
619
typename std::conditional<sizeof (arrsize_t ) == sizeof (int32_t ),
626
- ymm_vector <arrsize_t >,
627
- zmm_vector <arrsize_t >>::type;
620
+ half_vector <arrsize_t >,
621
+ full_vector <arrsize_t >>::type;
628
622
629
623
if (arrsize > 1 ) {
630
624
if constexpr (xss::fp::is_floating_point_v<T>) {
@@ -634,38 +628,29 @@ X86_SIMD_SORT_INLINE void avx512_argselect(T *arr,
634
628
}
635
629
}
636
630
UNUSED (hasnan);
637
- argselect_64bit_ <vectype, argtype>(
631
+ argselect_ <vectype, argtype>(
638
632
arr, arg, k, 0 , arrsize - 1 , 2 * (arrsize_t )log2 (arrsize));
639
633
}
640
634
}
641
635
642
- /* argselect methods for 32-bit and 64-bit dtypes */
643
636
template <typename T>
644
- X86_SIMD_SORT_INLINE void avx2_argselect (T *arr,
637
+ X86_SIMD_SORT_INLINE void avx512_argselect (T *arr,
645
638
arrsize_t *arg,
646
639
arrsize_t k,
647
640
arrsize_t arrsize,
648
641
bool hasnan = false )
649
642
{
650
- using vectype = typename std::conditional<sizeof (T) == sizeof (int32_t ),
651
- avx2_half_vector<T>,
652
- avx2_vector<T>>::type;
653
-
654
- using argtype =
655
- typename std::conditional<sizeof (arrsize_t ) == sizeof (int32_t ),
656
- avx2_half_vector<arrsize_t >,
657
- avx2_vector<arrsize_t >>::type;
643
+ xss_argselect<T, zmm_vector, ymm_vector>(arr, arg, k, arrsize, hasnan);
644
+ }
658
645
659
- if (arrsize > 1 ) {
660
- if constexpr (xss::fp::is_floating_point_v<T>) {
661
- if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
662
- std_argselect_withnan (arr, arg, k, 0 , arrsize);
663
- return ;
664
- }
665
- }
666
- UNUSED (hasnan);
667
- argselect_64bit_<vectype, argtype>(
668
- arr, arg, k, 0 , arrsize - 1 , 2 * (arrsize_t )log2 (arrsize));
669
- }
646
+ template <typename T>
647
+ X86_SIMD_SORT_INLINE void avx2_argselect (T *arr,
648
+ arrsize_t *arg,
649
+ arrsize_t k,
650
+ arrsize_t arrsize,
651
+ bool hasnan = false )
652
+ {
653
+ xss_argselect<T, avx2_vector, avx2_half_vector>(arr, arg, k, arrsize, hasnan);
670
654
}
655
+
671
656
#endif // XSS_COMMON_ARGSORT
0 commit comments