Skip to content

Commit 720e1f7

Browse files
author
Raghuveer Devulapalli
committed
Remove template specializations for quicksort methods
1 parent dfbcb09 commit 720e1f7

7 files changed

+113
-165
lines changed

src/avx512-16bit-qsort.hpp

Lines changed: 8 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -377,8 +377,9 @@ bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
377377
//return npy_half_to_float(a) < npy_half_to_float(b);
378378
}
379379

380-
X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(uint16_t *arr,
381-
int64_t arrsize)
380+
template<>
381+
int64_t
382+
replace_nan_with_inf<zmm_vector<float16>>(uint16_t *arr, int64_t arrsize)
382383
{
383384
int64_t nan_count = 0;
384385
__mmask16 loadmask = 0xFFFF;
@@ -396,15 +397,6 @@ X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(uint16_t *arr,
396397
return nan_count;
397398
}
398399

399-
X86_SIMD_SORT_INLINE void
400-
replace_inf_with_nan(uint16_t *arr, int64_t arrsize, int64_t nan_count)
401-
{
402-
for (int64_t ii = arrsize - 1; nan_count > 0; --ii) {
403-
arr[ii] = 0xFFFF;
404-
nan_count -= 1;
405-
}
406-
}
407-
408400
template <>
409401
bool is_a_nan<uint16_t>(uint16_t elem)
410402
{
@@ -442,27 +434,21 @@ void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
442434
}
443435

444436
template <>
445-
void avx512_qsort(int16_t *arr, int64_t arrsize)
437+
void qsort_<zmm_vector<int16_t>>(int16_t* arr, int64_t left, int64_t right, int64_t maxiters)
446438
{
447-
if (arrsize > 1) {
448-
qsort_16bit_<zmm_vector<int16_t>, int16_t>(
449-
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
450-
}
439+
qsort_16bit_<zmm_vector<int16_t>>(arr, left, right, maxiters);
451440
}
452441

453442
template <>
454-
void avx512_qsort(uint16_t *arr, int64_t arrsize)
443+
void qsort_<zmm_vector<uint16_t>>(uint16_t* arr, int64_t left, int64_t right, int64_t maxiters)
455444
{
456-
if (arrsize > 1) {
457-
qsort_16bit_<zmm_vector<uint16_t>, uint16_t>(
458-
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
459-
}
445+
qsort_16bit_<zmm_vector<uint16_t>>(arr, left, right, maxiters);
460446
}
461447

462448
void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize)
463449
{
464450
if (arrsize > 1) {
465-
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
451+
int64_t nan_count = replace_nan_with_inf<zmm_vector<float16>, uint16_t>(arr, arrsize);
466452
qsort_16bit_<zmm_vector<float16>, uint16_t>(
467453
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
468454
replace_inf_with_nan(arr, arrsize, nan_count);

src/avx512-32bit-qsort.hpp

Lines changed: 15 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -256,6 +256,11 @@ struct zmm_vector<float> {
256256
{
257257
return _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ);
258258
}
259+
template <int type>
260+
static opmask_t fpclass(zmm_t x)
261+
{
262+
return _mm512_fpclass_ps_mask(x, type);
263+
}
259264
template <int scale>
260265
static ymm_t i64gather(__m512i index, void const *base)
261266
{
@@ -279,6 +284,10 @@ struct zmm_vector<float> {
279284
{
280285
return _mm512_mask_compressstoreu_ps(mem, mask, x);
281286
}
287+
static zmm_t maskz_loadu(opmask_t mask, void const *mem)
288+
{
289+
return _mm512_maskz_loadu_ps(mask, mem);
290+
}
282291
static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem)
283292
{
284293
return _mm512_mask_loadu_ps(x, mask, mem);
@@ -689,31 +698,6 @@ static void qselect_32bit_(type_t *arr,
689698
qselect_32bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1);
690699
}
691700

692-
X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(float *arr, int64_t arrsize)
693-
{
694-
int64_t nan_count = 0;
695-
__mmask16 loadmask = 0xFFFF;
696-
while (arrsize > 0) {
697-
if (arrsize < 16) { loadmask = (0x0001 << arrsize) - 0x0001; }
698-
__m512 in_zmm = _mm512_maskz_loadu_ps(loadmask, arr);
699-
__mmask16 nanmask = _mm512_cmp_ps_mask(in_zmm, in_zmm, _CMP_NEQ_UQ);
700-
nan_count += _mm_popcnt_u32((int32_t)nanmask);
701-
_mm512_mask_storeu_ps(arr, nanmask, ZMM_MAX_FLOAT);
702-
arr += 16;
703-
arrsize -= 16;
704-
}
705-
return nan_count;
706-
}
707-
708-
X86_SIMD_SORT_INLINE void
709-
replace_inf_with_nan(float *arr, int64_t arrsize, int64_t nan_count)
710-
{
711-
for (int64_t ii = arrsize - 1; nan_count > 0; --ii) {
712-
arr[ii] = std::nanf("1");
713-
nan_count -= 1;
714-
}
715-
}
716-
717701
template <>
718702
void avx512_qselect<int32_t>(int32_t *arr,
719703
int64_t k,
@@ -752,32 +736,20 @@ void avx512_qselect<float>(float *arr, int64_t k, int64_t arrsize, bool hasnan)
752736
}
753737

754738
template <>
755-
void avx512_qsort<int32_t>(int32_t *arr, int64_t arrsize)
739+
void qsort_<zmm_vector<int32_t>>(int32_t* arr, int64_t left, int64_t right, int64_t maxiters)
756740
{
757-
if (arrsize > 1) {
758-
qsort_32bit_<zmm_vector<int32_t>, int32_t>(
759-
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
760-
}
741+
qsort_32bit_<zmm_vector<int32_t>>(arr, left, right, maxiters);
761742
}
762743

763744
template <>
764-
void avx512_qsort<uint32_t>(uint32_t *arr, int64_t arrsize)
745+
void qsort_<zmm_vector<uint32_t>>(uint32_t* arr, int64_t left, int64_t right, int64_t maxiters)
765746
{
766-
if (arrsize > 1) {
767-
qsort_32bit_<zmm_vector<uint32_t>, uint32_t>(
768-
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
769-
}
747+
qsort_32bit_<zmm_vector<uint32_t>>(arr, left, right, maxiters);
770748
}
771749

772750
template <>
773-
void avx512_qsort<float>(float *arr, int64_t arrsize)
751+
void qsort_<zmm_vector<float>>(float* arr, int64_t left, int64_t right, int64_t maxiters)
774752
{
775-
if (arrsize > 1) {
776-
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
777-
qsort_32bit_<zmm_vector<float>, float>(
778-
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
779-
replace_inf_with_nan(arr, arrsize, nan_count);
780-
}
753+
qsort_32bit_<zmm_vector<float>>(arr, left, right, maxiters);
781754
}
782-
783755
#endif //AVX512_QSORT_32BIT

src/avx512-64bit-common.h

Lines changed: 0 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -773,30 +773,7 @@ struct zmm_vector<double> {
773773
_mm512_storeu_pd(mem, x);
774774
}
775775
};
776-
X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(double *arr, int64_t arrsize)
777-
{
778-
int64_t nan_count = 0;
779-
__mmask8 loadmask = 0xFF;
780-
while (arrsize > 0) {
781-
if (arrsize < 8) { loadmask = (0x01 << arrsize) - 0x01; }
782-
__m512d in_zmm = _mm512_maskz_loadu_pd(loadmask, arr);
783-
__mmask8 nanmask = _mm512_cmp_pd_mask(in_zmm, in_zmm, _CMP_NEQ_UQ);
784-
nan_count += _mm_popcnt_u32((int32_t)nanmask);
785-
_mm512_mask_storeu_pd(arr, nanmask, ZMM_MAX_DOUBLE);
786-
arr += 8;
787-
arrsize -= 8;
788-
}
789-
return nan_count;
790-
}
791776

792-
X86_SIMD_SORT_INLINE void
793-
replace_inf_with_nan(double *arr, int64_t arrsize, int64_t nan_count)
794-
{
795-
for (int64_t ii = arrsize - 1; nan_count > 0; --ii) {
796-
arr[ii] = std::nan("1");
797-
nan_count -= 1;
798-
}
799-
}
800777
/*
801778
* Assumes zmm is random and performs a full sorting network defined in
802779
* https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg

src/avx512-64bit-keyvaluesort.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -463,7 +463,7 @@ template <>
463463
void avx512_qsort_kv<double>(double *keys, uint64_t *indexes, int64_t arrsize)
464464
{
465465
if (arrsize > 1) {
466-
int64_t nan_count = replace_nan_with_inf(keys, arrsize);
466+
int64_t nan_count = replace_nan_with_inf<zmm_vector<double>>(keys, arrsize);
467467
qsort_64bit_<zmm_vector<double>, zmm_vector<uint64_t>>(
468468
keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
469469
replace_inf_with_nan(keys, arrsize, nan_count);

src/avx512-64bit-qsort.hpp

Lines changed: 6 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -824,31 +824,20 @@ void avx512_qselect<double>(double *arr,
824824
}
825825

826826
template <>
827-
void avx512_qsort<int64_t>(int64_t *arr, int64_t arrsize)
827+
void qsort_<zmm_vector<int64_t>>(int64_t* arr, int64_t left, int64_t right, int64_t maxiters)
828828
{
829-
if (arrsize > 1) {
830-
qsort_64bit_<zmm_vector<int64_t>, int64_t>(
831-
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
832-
}
829+
qsort_64bit_<zmm_vector<int64_t>>(arr, left, right, maxiters);
833830
}
834831

835832
template <>
836-
void avx512_qsort<uint64_t>(uint64_t *arr, int64_t arrsize)
833+
void qsort_<zmm_vector<uint64_t>>(uint64_t* arr, int64_t left, int64_t right, int64_t maxiters)
837834
{
838-
if (arrsize > 1) {
839-
qsort_64bit_<zmm_vector<uint64_t>, uint64_t>(
840-
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
841-
}
835+
qsort_64bit_<zmm_vector<uint64_t>>(arr, left, right, maxiters);
842836
}
843837

844838
template <>
845-
void avx512_qsort<double>(double *arr, int64_t arrsize)
839+
void qsort_<zmm_vector<double>>(double* arr, int64_t left, int64_t right, int64_t maxiters)
846840
{
847-
if (arrsize > 1) {
848-
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
849-
qsort_64bit_<zmm_vector<double>, double>(
850-
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
851-
replace_inf_with_nan(arr, arrsize, nan_count);
852-
}
841+
qsort_64bit_<zmm_vector<double>>(arr, left, right, maxiters);
853842
}
854843
#endif // AVX512_QSORT_64BIT

src/avx512-common-qsort.h

Lines changed: 81 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -94,35 +94,50 @@ struct zmm_vector;
9494
template <typename type>
9595
struct ymm_vector;
9696

97-
// Regular quicksort routines:
9897
template <typename T>
99-
void avx512_qsort(T *arr, int64_t arrsize);
100-
void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize);
101-
102-
template <typename T>
103-
void avx512_qselect(T *arr, int64_t k, int64_t arrsize, bool hasnan = false);
104-
void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false);
105-
106-
template <typename T>
107-
inline void avx512_partial_qsort(T *arr, int64_t k, int64_t arrsize, bool hasnan = false)
98+
bool is_a_nan(T elem)
10899
{
109-
avx512_qselect<T>(arr, k - 1, arrsize, hasnan);
110-
avx512_qsort<T>(arr, k - 1);
100+
return std::isnan(elem);
111101
}
112-
inline void avx512_partial_qsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false)
102+
103+
template <typename vtype, typename type_t>
104+
int64_t replace_nan_with_inf(type_t *arr, int64_t arrsize)
113105
{
114-
avx512_qselect_fp16(arr, k - 1, arrsize, hasnan);
115-
avx512_qsort_fp16(arr, k - 1);
106+
int64_t nan_count = 0;
107+
using opmask_t = typename vtype::opmask_t;
108+
using zmm_t = typename vtype::zmm_t;
109+
bool found_nan = false;
110+
opmask_t loadmask = 0xFF;
111+
zmm_t in;
112+
while (arrsize > 0) {
113+
if (arrsize < vtype::numlanes) {
114+
loadmask = (0x01 << arrsize) - 0x01;
115+
in = vtype::maskz_loadu(loadmask, arr);
116+
}
117+
else {
118+
in = vtype::loadu(arr);
119+
}
120+
opmask_t nanmask = vtype::template fpclass<0x01 | 0x80>(in);
121+
nan_count += _mm_popcnt_u32((int32_t)nanmask);
122+
vtype::mask_storeu(arr, nanmask, vtype::zmm_max());
123+
arr += vtype::numlanes;
124+
arrsize -= vtype::numlanes;
125+
}
126+
return nan_count;
116127
}
117128

118-
// key-value sort routines
119-
template <typename T>
120-
void avx512_qsort_kv(T *keys, uint64_t *indexes, int64_t arrsize);
121-
122-
template <typename T>
123-
bool is_a_nan(T elem)
129+
template<typename type_t>
130+
void replace_inf_with_nan(type_t *arr, int64_t arrsize, int64_t nan_count)
124131
{
125-
return std::isnan(elem);
132+
for (int64_t ii = arrsize - 1; nan_count > 0; --ii) {
133+
if constexpr (std::is_floating_point_v<type_t>) {
134+
arr[ii] = std::numeric_limits<type_t>::quiet_NaN();
135+
}
136+
else {
137+
arr[ii] = 0xFFFF;
138+
}
139+
nan_count -= 1;
140+
}
126141
}
127142

128143
/*
@@ -628,4 +643,48 @@ static inline int64_t partition_avx512(type_t1 *keys,
628643
*biggest = vtype1::reducemax(max_vec);
629644
return l_store;
630645
}
646+
647+
template <typename vtype, typename type_t>
648+
void qsort_(type_t* arr, int64_t left, int64_t right, int64_t maxiters);
649+
650+
// Regular quicksort routines:
651+
template <typename T>
652+
void avx512_qsort(T *arr, int64_t arrsize)
653+
{
654+
if (arrsize > 1) {
655+
if constexpr (std::is_floating_point_v<T>) {
656+
int64_t nan_count = replace_nan_with_inf<zmm_vector<T>>(arr, arrsize);
657+
qsort_<zmm_vector<T>, T>(
658+
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
659+
replace_inf_with_nan(arr, arrsize, nan_count);
660+
}
661+
else {
662+
qsort_<zmm_vector<T>, T>(
663+
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
664+
}
665+
}
666+
}
667+
668+
void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize);
669+
670+
template <typename T>
671+
void avx512_qselect(T *arr, int64_t k, int64_t arrsize, bool hasnan = false);
672+
void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false);
673+
674+
template <typename T>
675+
inline void avx512_partial_qsort(T *arr, int64_t k, int64_t arrsize, bool hasnan = false)
676+
{
677+
avx512_qselect<T>(arr, k - 1, arrsize, hasnan);
678+
avx512_qsort<T>(arr, k - 1);
679+
}
680+
inline void avx512_partial_qsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false)
681+
{
682+
avx512_qselect_fp16(arr, k - 1, arrsize, hasnan);
683+
avx512_qsort_fp16(arr, k - 1);
684+
}
685+
686+
// key-value sort routines
687+
template <typename T>
688+
void avx512_qsort_kv(T *keys, uint64_t *indexes, int64_t arrsize);
689+
631690
#endif // AVX512_QSORT_COMMON

0 commit comments

Comments
 (0)