@@ -542,7 +542,7 @@ X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr,
542542/* argsort methods for 32-bit and 64-bit dtypes */
/*
 * Argsort entry point that reorders a caller-supplied index buffer.
 *
 * arr        : pointer to the keys being ranked.
 * arg        : pointer to the index buffer; the range [arg, arg + arrsize)
 *              is what gets reordered (and reversed when descending is set).
 * arrsize    : number of elements in arr / arg.
 * hasnan     : when true and T is a floating-point type, the keys are
 *              scanned with array_has_nan and, if a NaN is reported, the
 *              work is handed to std_argsort_withnan instead of the
 *              vectorized path.
 * descending : when true, the index range is reversed after sorting.
 *              Presumably the underlying sort is ascending, so the reversal
 *              yields descending order -- confirm against argsort_64bit_.
 */
543543template <typename T>
544544X86_SIMD_SORT_INLINE void
545- avx512_argsort (T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false )
545+ avx512_argsort (T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false , bool descending = false )
546546{
547547 /* TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */
548548 using vectype = typename std::conditional<sizeof (T) == sizeof (int32_t ),
@@ -558,29 +558,38 @@ avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false)
/* The NaN handling below is compiled only for floating-point T. */
558558 if constexpr (std::is_floating_point_v<T>) {
/* array_has_nan is only evaluated when the caller passed hasnan = true. */
559559 if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
560560 std_argsort_withnan (arr, arg, 0 , arrsize);
561+
/* Apply the requested order to the fallback result before the early
 * return. NOTE(review): where NaN indices end up after the reversal
 * depends on std_argsort_withnan's ordering -- verify. */
562+ if (descending){
563+ std::reverse (arg, arg + arrsize);
564+ }
565+
561566 return ;
562567 }
563568 }
/* Presumably silences an unused-parameter warning when the floating-point
 * branch above is discarded -- confirm against the UNUSED macro. */
564569 UNUSED (hasnan);
/* Vectorized path. Note the upper bound passed here is arrsize - 1, whereas
 * the NaN fallback above receives arrsize. The last argument,
 * 2 * log2(arrsize), looks like a depth/iteration budget -- TODO confirm. */
565570 argsort_64bit_<vectype, argtype>(
566571 arr, arg, 0 , arrsize - 1 , 2 * (arrsize_t )log2 (arrsize));
572+
/* Flip the whole index range when descending output was requested. */
573+ if (descending){
574+ std::reverse (arg, arg + arrsize);
575+ }
567576 }
568577}
569578
/*
 * Convenience overload that allocates and returns the index vector.
 *
 * arr        : pointer to the keys being ranked.
 * arrsize    : number of elements in arr; also the size of the result.
 * hasnan     : forwarded unchanged to the pointer-based overload.
 * descending : forwarded unchanged to the pointer-based overload.
 *
 * Returns a std::vector<arrsize_t> of arrsize entries holding the indices
 * as reordered by the pointer-based avx512_argsort.
 */
570579template <typename T>
571580X86_SIMD_SORT_INLINE std::vector<arrsize_t >
572- avx512_argsort (T *arr, arrsize_t arrsize, bool hasnan = false )
581+ avx512_argsort (T *arr, arrsize_t arrsize, bool hasnan = false , bool descending = false )
573582{
574583 std::vector<arrsize_t > indices (arrsize);
/* Seed the buffer with the identity permutation 0, 1, ..., arrsize - 1. */
575584 std::iota (indices.begin (), indices.end (), 0 );
/* Delegate to the in-place overload, which reorders indices.data(). */
576- avx512_argsort<T>(arr, indices.data (), arrsize, hasnan);
585+ avx512_argsort<T>(arr, indices.data (), arrsize, hasnan, descending );
577586 return indices;
578587}
579588
580589/* argsort methods for 32-bit and 64-bit dtypes */
/*
 * AVX2 counterpart of the pointer-based argsort entry point; it reorders a
 * caller-supplied index buffer.
 *
 * arr        : pointer to the keys being ranked.
 * arg        : pointer to the index buffer; the range [arg, arg + arrsize)
 *              is what gets reordered (and reversed when descending is set).
 * arrsize    : number of elements in arr / arg.
 * hasnan     : when true and T is a floating-point type, the keys are
 *              scanned with array_has_nan and, if a NaN is reported, the
 *              work is handed to std_argsort_withnan instead of the
 *              vectorized path.
 * descending : when true, the index range is reversed after sorting.
 *              Presumably the underlying sort is ascending, so the reversal
 *              yields descending order -- confirm against argsort_64bit_.
 */
581590template <typename T>
582591X86_SIMD_SORT_INLINE void
583- avx2_argsort (T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false )
592+ avx2_argsort (T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false , bool descending = false )
584593{
/* Key types with the same size as int32_t select avx2_half_vector<T>. */
585594 using vectype = typename std::conditional<sizeof (T) == sizeof (int32_t ),
586595 avx2_half_vector<T>,
@@ -594,22 +603,31 @@ avx2_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false)
/* The NaN handling below is compiled only for floating-point T. */
594603 if constexpr (std::is_floating_point_v<T>) {
/* array_has_nan is only evaluated when the caller passed hasnan = true. */
595604 if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
596605 std_argsort_withnan (arr, arg, 0 , arrsize);
606+
/* Apply the requested order to the fallback result before the early
 * return. NOTE(review): where NaN indices end up after the reversal
 * depends on std_argsort_withnan's ordering -- verify. */
607+ if (descending){
608+ std::reverse (arg, arg + arrsize);
609+ }
610+
597611 return ;
598612 }
599613 }
/* Presumably silences an unused-parameter warning when the floating-point
 * branch above is discarded -- confirm against the UNUSED macro. */
600614 UNUSED (hasnan);
/* Vectorized path. Note the upper bound passed here is arrsize - 1, whereas
 * the NaN fallback above receives arrsize. The last argument,
 * 2 * log2(arrsize), looks like a depth/iteration budget -- TODO confirm. */
601615 argsort_64bit_<vectype, argtype>(
602616 arr, arg, 0 , arrsize - 1 , 2 * (arrsize_t )log2 (arrsize));
617+
/* Flip the whole index range when descending output was requested. */
618+ if (descending){
619+ std::reverse (arg, arg + arrsize);
620+ }
603621 }
604622}
605623
/*
 * Convenience overload that allocates and returns the index vector.
 *
 * arr        : pointer to the keys being ranked.
 * arrsize    : number of elements in arr; also the size of the result.
 * hasnan     : forwarded unchanged to the pointer-based overload.
 * descending : forwarded unchanged to the pointer-based overload.
 *
 * Returns a std::vector<arrsize_t> of arrsize entries holding the indices
 * as reordered by the pointer-based avx2_argsort.
 */
606624template <typename T>
607625X86_SIMD_SORT_INLINE std::vector<arrsize_t >
608- avx2_argsort (T *arr, arrsize_t arrsize, bool hasnan = false )
626+ avx2_argsort (T *arr, arrsize_t arrsize, bool hasnan = false , bool descending = false )
609627{
610628 std::vector<arrsize_t > indices (arrsize);
/* Seed the buffer with the identity permutation 0, 1, ..., arrsize - 1. */
611629 std::iota (indices.begin (), indices.end (), 0 );
/* Delegate to the in-place overload, which reorders indices.data(). */
612- avx2_argsort<T>(arr, indices.data (), arrsize, hasnan);
630+ avx2_argsort<T>(arr, indices.data (), arrsize, hasnan, descending );
613631 return indices;
614632}
615633
@@ -631,7 +649,7 @@ X86_SIMD_SORT_INLINE void avx512_argselect(T *arr,
631649 ymm_vector<arrsize_t >,
632650 zmm_vector<arrsize_t >>::type;
633651
634- if (arrsize > 1 ) {
652+ if (arrsize > 1 ) {
635653 if constexpr (std::is_floating_point_v<T>) {
636654 if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
637655 std_argselect_withnan (arr, arg, k, 0 , arrsize);
0 commit comments