@@ -468,7 +468,7 @@ X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr,
468468}
469469
470470template <typename vtype, typename argtype, typename type_t >
471- X86_SIMD_SORT_INLINE void argsort_64bit_ (type_t *arr,
471+ X86_SIMD_SORT_INLINE void argsort_ (type_t *arr,
472472 arrsize_t *arg,
473473 arrsize_t left,
474474 arrsize_t right,
@@ -495,15 +495,15 @@ X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr,
495495 arrsize_t pivot_index = argpartition_unrolled<vtype, argtype, 4 >(
496496 arr, arg, left, right + 1 , pivot, &smallest, &biggest);
497497 if (pivot != smallest)
498- argsort_64bit_ <vtype, argtype>(
498+ argsort_ <vtype, argtype>(
499499 arr, arg, left, pivot_index - 1 , max_iters - 1 );
500500 if (pivot != biggest)
501- argsort_64bit_ <vtype, argtype>(
501+ argsort_ <vtype, argtype>(
502502 arr, arg, pivot_index, right, max_iters - 1 );
503503}
504504
505505template <typename vtype, typename argtype, typename type_t >
506- X86_SIMD_SORT_INLINE void argselect_64bit_ (type_t *arr,
506+ X86_SIMD_SORT_INLINE void argselect_ (type_t *arr,
507507 arrsize_t *arg,
508508 arrsize_t pos,
509509 arrsize_t left,
@@ -531,30 +531,34 @@ X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr,
531531 arrsize_t pivot_index = argpartition_unrolled<vtype, argtype, 4 >(
532532 arr, arg, left, right + 1 , pivot, &smallest, &biggest);
533533 if ((pivot != smallest) && (pos < pivot_index))
534- argselect_64bit_ <vtype, argtype>(
534+ argselect_ <vtype, argtype>(
535535 arr, arg, pos, left, pivot_index - 1 , max_iters - 1 );
536536 else if ((pivot != biggest) && (pos >= pivot_index))
537- argselect_64bit_ <vtype, argtype>(
537+ argselect_ <vtype, argtype>(
538538 arr, arg, pos, pivot_index, right, max_iters - 1 );
539539}
540540
541541/* argsort methods for 32-bit and 64-bit dtypes */
542- template <typename T>
543- X86_SIMD_SORT_INLINE void avx512_argsort (T *arr,
542+ template <typename T,
543+ template <typename ...>
544+ typename full_vector,
545+ template <typename ...>
546+ typename half_vector>
547+ X86_SIMD_SORT_INLINE void xss_argsort (T *arr,
544548 arrsize_t *arg,
545549 arrsize_t arrsize,
546550 bool hasnan = false ,
547551 bool descending = false )
548552{
549- /* TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */
553+ /* TODO optimization: on 32-bit, use full_vector for 32-bit dtype */
550554 using vectype = typename std::conditional<sizeof (T) == sizeof (int32_t ),
551- ymm_vector <T>,
552- zmm_vector <T>>::type;
555+ half_vector <T>,
556+ full_vector <T>>::type;
553557
554558 using argtype =
555559 typename std::conditional<sizeof (arrsize_t ) == sizeof (int32_t ),
556- ymm_vector <arrsize_t >,
557- zmm_vector <arrsize_t >>::type;
560+ half_vector <arrsize_t >,
561+ full_vector <arrsize_t >>::type;
558562
559563 if (arrsize > 1 ) {
560564 if constexpr (xss::fp::is_floating_point_v<T>) {
@@ -567,64 +571,54 @@ X86_SIMD_SORT_INLINE void avx512_argsort(T *arr,
567571 }
568572 }
569573 UNUSED (hasnan);
570- argsort_64bit_ <vectype, argtype>(
574+ argsort_ <vectype, argtype>(
571575 arr, arg, 0 , arrsize - 1 , 2 * (arrsize_t )log2 (arrsize));
572576
573577 if (descending) { std::reverse (arg, arg + arrsize); }
574578 }
575579}
576580
577- /* argsort methods for 32-bit and 64-bit dtypes */
578581template <typename T>
579- X86_SIMD_SORT_INLINE void avx2_argsort (T *arr,
580- arrsize_t *arg,
581- arrsize_t arrsize,
582- bool hasnan = false ,
583- bool descending = false )
582+ X86_SIMD_SORT_INLINE void avx512_argsort (T *arr,
583+ arrsize_t *arg,
584+ arrsize_t arrsize,
585+ bool hasnan = false ,
586+ bool descending = false )
584587{
585- using vectype = typename std::conditional<sizeof (T) == sizeof (int32_t ),
586- avx2_half_vector<T>,
587- avx2_vector<T>>::type;
588-
589- using argtype =
590- typename std::conditional<sizeof (arrsize_t ) == sizeof (int32_t ),
591- avx2_half_vector<arrsize_t >,
592- avx2_vector<arrsize_t >>::type;
593- if (arrsize > 1 ) {
594- if constexpr (xss::fp::is_floating_point_v<T>) {
595- if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
596- std_argsort_withnan (arr, arg, 0 , arrsize);
597-
598- if (descending) { std::reverse (arg, arg + arrsize); }
599-
600- return ;
601- }
602- }
603- UNUSED (hasnan);
604- argsort_64bit_<vectype, argtype>(
605- arr, arg, 0 , arrsize - 1 , 2 * (arrsize_t )log2 (arrsize));
588+ xss_argsort<T, zmm_vector, ymm_vector>(arr, arg, arrsize, hasnan, descending);
589+ }
606590
607- if (descending) { std::reverse (arg, arg + arrsize); }
608- }
591+ template <typename T>
592+ X86_SIMD_SORT_INLINE void avx2_argsort (T *arr,
593+ arrsize_t *arg,
594+ arrsize_t arrsize,
595+ bool hasnan = false ,
596+ bool descending = false )
597+ {
598+ xss_argsort<T, avx2_vector, avx2_half_vector>(arr, arg, arrsize, hasnan, descending);
609599}
610600
611601/* argselect methods for 32-bit and 64-bit dtypes */
612- template <typename T>
613- X86_SIMD_SORT_INLINE void avx512_argselect (T *arr,
602+ template <typename T,
603+ template <typename ...>
604+ typename full_vector,
605+ template <typename ...>
606+ typename half_vector>
607+ X86_SIMD_SORT_INLINE void xss_argselect (T *arr,
614608 arrsize_t *arg,
615609 arrsize_t k,
616610 arrsize_t arrsize,
617611 bool hasnan = false )
618612{
619- /* TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */
613+ /* TODO optimization: on 32-bit, use full_vector for 32-bit dtype */
620614 using vectype = typename std::conditional<sizeof (T) == sizeof (int32_t ),
621- ymm_vector <T>,
622- zmm_vector <T>>::type;
615+ half_vector <T>,
616+ full_vector <T>>::type;
623617
624618 using argtype =
625619 typename std::conditional<sizeof (arrsize_t ) == sizeof (int32_t ),
626- ymm_vector <arrsize_t >,
627- zmm_vector <arrsize_t >>::type;
620+ half_vector <arrsize_t >,
621+ full_vector <arrsize_t >>::type;
628622
629623 if (arrsize > 1 ) {
630624 if constexpr (xss::fp::is_floating_point_v<T>) {
@@ -634,38 +628,29 @@ X86_SIMD_SORT_INLINE void avx512_argselect(T *arr,
634628 }
635629 }
636630 UNUSED (hasnan);
637- argselect_64bit_ <vectype, argtype>(
631+ argselect_ <vectype, argtype>(
638632 arr, arg, k, 0 , arrsize - 1 , 2 * (arrsize_t )log2 (arrsize));
639633 }
640634}
641635
642- /* argselect methods for 32-bit and 64-bit dtypes */
643636template <typename T>
644- X86_SIMD_SORT_INLINE void avx2_argselect (T *arr,
637+ X86_SIMD_SORT_INLINE void avx512_argselect (T *arr,
645638 arrsize_t *arg,
646639 arrsize_t k,
647640 arrsize_t arrsize,
648641 bool hasnan = false )
649642{
650- using vectype = typename std::conditional<sizeof (T) == sizeof (int32_t ),
651- avx2_half_vector<T>,
652- avx2_vector<T>>::type;
653-
654- using argtype =
655- typename std::conditional<sizeof (arrsize_t ) == sizeof (int32_t ),
656- avx2_half_vector<arrsize_t >,
657- avx2_vector<arrsize_t >>::type;
643+ xss_argselect<T, zmm_vector, ymm_vector>(arr, arg, k, arrsize, hasnan);
644+ }
658645
659- if (arrsize > 1 ) {
660- if constexpr (xss::fp::is_floating_point_v<T>) {
661- if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
662- std_argselect_withnan (arr, arg, k, 0 , arrsize);
663- return ;
664- }
665- }
666- UNUSED (hasnan);
667- argselect_64bit_<vectype, argtype>(
668- arr, arg, k, 0 , arrsize - 1 , 2 * (arrsize_t )log2 (arrsize));
669- }
646+ template <typename T>
647+ X86_SIMD_SORT_INLINE void avx2_argselect (T *arr,
648+ arrsize_t *arg,
649+ arrsize_t k,
650+ arrsize_t arrsize,
651+ bool hasnan = false )
652+ {
653+ xss_argselect<T, avx2_vector, avx2_half_vector>(arr, arg, k, arrsize, hasnan);
670654}
655+
671656#endif // XSS_COMMON_ARGSORT
0 commit comments