@@ -94,35 +94,50 @@ struct zmm_vector;
94
94
template <typename type>
95
95
struct ymm_vector ;
96
96
97
- // Regular quicksort routines:
98
97
/*
 * Tells whether `elem` is a NaN.
 *
 * The classification is delegated entirely to std::isnan, so T must be a
 * type for which std::isnan has a usable overload.
 */
template <typename T>
bool is_a_nan(T elem)
{
    const bool elem_is_nan = std::isnan(elem);
    return elem_is_nan;
}
112
- inline void avx512_partial_qsort_fp16 (uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false )
102
+
103
+ template <typename vtype, typename type_t >
104
+ int64_t replace_nan_with_inf (type_t *arr, int64_t arrsize)
113
105
{
114
- avx512_qselect_fp16 (arr, k - 1 , arrsize, hasnan);
115
- avx512_qsort_fp16 (arr, k - 1 );
106
+ int64_t nan_count = 0 ;
107
+ using opmask_t = typename vtype::opmask_t ;
108
+ using zmm_t = typename vtype::zmm_t ;
109
+ bool found_nan = false ;
110
+ opmask_t loadmask = 0xFF ;
111
+ zmm_t in;
112
+ while (arrsize > 0 ) {
113
+ if (arrsize < vtype::numlanes) {
114
+ loadmask = (0x01 << arrsize) - 0x01 ;
115
+ in = vtype::maskz_loadu (loadmask, arr);
116
+ }
117
+ else {
118
+ in = vtype::loadu (arr);
119
+ }
120
+ opmask_t nanmask = vtype::template fpclass<0x01 | 0x80 >(in);
121
+ nan_count += _mm_popcnt_u32 ((int32_t )nanmask);
122
+ vtype::mask_storeu (arr, nanmask, vtype::zmm_max ());
123
+ arr += vtype::numlanes;
124
+ arrsize -= vtype::numlanes;
125
+ }
126
+ return nan_count;
116
127
}
117
128
118
- // key-value sort routines
119
- template <typename T>
120
- void avx512_qsort_kv (T *keys, uint64_t *indexes, int64_t arrsize);
121
-
122
- template <typename T>
123
- bool is_a_nan (T elem)
129
/*
 * Overwrites the last `nan_count` elements of arr[0 .. arrsize-1] with NaN.
 *
 * arr       : buffer to modify in place.
 * arrsize   : number of elements in arr.
 * nan_count : how many trailing elements to overwrite; values <= 0 are a
 *             no-op, values larger than arrsize are clamped to arrsize so the
 *             loop can never write before the start of the buffer.
 *
 * Floating-point element types receive std::numeric_limits<>::quiet_NaN();
 * every other element type receives the constant 0xFFFF.
 * NOTE(review): 0xFFFF is presumably a NaN bit pattern for half-precision
 * values stored as uint16_t -- confirm against the fp16 callers.
 */
template <typename type_t>
void replace_inf_with_nan(type_t *arr, int64_t arrsize, int64_t nan_count)
{
    // Walk backwards from the last element; stop when either the requested
    // number of NaNs has been written or the front of the buffer is reached.
    for (int64_t ii = arrsize - 1; nan_count > 0 && ii >= 0; --ii, --nan_count) {
        if constexpr (std::is_floating_point_v<type_t>) {
            arr[ii] = std::numeric_limits<type_t>::quiet_NaN();
        }
        else {
            arr[ii] = 0xFFFF;
        }
    }
}
127
142
128
143
/*
@@ -628,4 +643,48 @@ static inline int64_t partition_avx512(type_t1 *keys,
628
643
*biggest = vtype1::reducemax (max_vec);
629
644
return l_store;
630
645
}
646
+
647
+ template <typename vtype, typename type_t >
648
+ void qsort_ (type_t * arr, int64_t left, int64_t right, int64_t maxiters);
649
+
650
+ // Regular quicksort routines:
651
+ template <typename T>
652
+ void avx512_qsort (T *arr, int64_t arrsize)
653
+ {
654
+ if (arrsize > 1 ) {
655
+ if constexpr (std::is_floating_point_v<T>) {
656
+ int64_t nan_count = replace_nan_with_inf<zmm_vector<T>>(arr, arrsize);
657
+ qsort_<zmm_vector<T>, T>(
658
+ arr, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
659
+ replace_inf_with_nan (arr, arrsize, nan_count);
660
+ }
661
+ else {
662
+ qsort_<zmm_vector<T>, T>(
663
+ arr, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
664
+ }
665
+ }
666
+ }
667
+
668
+ void avx512_qsort_fp16 (uint16_t *arr, int64_t arrsize);
669
+
670
+ template <typename T>
671
+ void avx512_qselect (T *arr, int64_t k, int64_t arrsize, bool hasnan = false );
672
+ void avx512_qselect_fp16 (uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false );
673
+
674
+ template <typename T>
675
+ inline void avx512_partial_qsort (T *arr, int64_t k, int64_t arrsize, bool hasnan = false )
676
+ {
677
+ avx512_qselect<T>(arr, k - 1 , arrsize, hasnan);
678
+ avx512_qsort<T>(arr, k - 1 );
679
+ }
680
+ inline void avx512_partial_qsort_fp16 (uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false )
681
+ {
682
+ avx512_qselect_fp16 (arr, k - 1 , arrsize, hasnan);
683
+ avx512_qsort_fp16 (arr, k - 1 );
684
+ }
685
+
686
+ // key-value sort routines
687
+ template <typename T>
688
+ void avx512_qsort_kv (T *keys, uint64_t *indexes, int64_t arrsize);
689
+
631
690
#endif // AVX512_QSORT_COMMON
0 commit comments