Commit 0f1023b

Author: Raghuveer Devulapalli
Merge pull request #3 from r-devulap/float16
Add AVX-512 sort for float16 dtype
2 parents eda9e45 + 84a2065 commit 0f1023b

File tree: 4 files changed, +288 -66 lines changed


src/avx512-16bit-qsort.hpp

Lines changed: 226 additions & 14 deletions
@@ -29,10 +29,156 @@ static const uint16_t network[6][32]
             {16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
              0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}};
 
+struct float16 {
+    uint16_t val;
+};
+
+template <>
+struct zmm_vector<float16> {
+    using type_t = uint16_t;
+    using zmm_t = __m512i;
+    using ymm_t = __m256i;
+    using opmask_t = __mmask32;
+    static const uint8_t numlanes = 32;
+
+    static zmm_t get_network(int index)
+    {
+        return _mm512_loadu_si512(&network[index - 1][0]);
+    }
+    static type_t type_max()
+    {
+        return X86_SIMD_SORT_INFINITYH;
+    }
+    static type_t type_min()
+    {
+        return X86_SIMD_SORT_NEGINFINITYH;
+    }
+    static zmm_t zmm_max()
+    {
+        return _mm512_set1_epi16(type_max());
+    }
+    static opmask_t knot_opmask(opmask_t x)
+    {
+        return _knot_mask32(x);
+    }
+
+    static opmask_t ge(zmm_t x, zmm_t y)
+    {
+        zmm_t sign_x = _mm512_and_si512(x, _mm512_set1_epi16(0x8000));
+        zmm_t sign_y = _mm512_and_si512(y, _mm512_set1_epi16(0x8000));
+        zmm_t exp_x = _mm512_and_si512(x, _mm512_set1_epi16(0x7c00));
+        zmm_t exp_y = _mm512_and_si512(y, _mm512_set1_epi16(0x7c00));
+        zmm_t mant_x = _mm512_and_si512(x, _mm512_set1_epi16(0x3ff));
+        zmm_t mant_y = _mm512_and_si512(y, _mm512_set1_epi16(0x3ff));
+
+        __mmask32 mask_ge = _mm512_cmp_epu16_mask(
+                sign_x, sign_y, _MM_CMPINT_LT); // only greater than
+        __mmask32 sign_eq = _mm512_cmpeq_epu16_mask(sign_x, sign_y);
+        __mmask32 neg = _mm512_mask_cmpeq_epu16_mask(
+                sign_eq,
+                sign_x,
+                _mm512_set1_epi16(0x8000)); // both numbers are -ve
+
+        // compare exponents only if signs are equal:
+        mask_ge = mask_ge
+                | _mm512_mask_cmp_epu16_mask(
+                        sign_eq, exp_x, exp_y, _MM_CMPINT_NLE);
+        // get mask for elements for which both sign and exponents are equal:
+        __mmask32 exp_eq = _mm512_mask_cmpeq_epu16_mask(sign_eq, exp_x, exp_y);
+
+        // compare mantissa for elements for which both sign and exponent are equal:
+        mask_ge = mask_ge
+                | _mm512_mask_cmp_epu16_mask(
+                        exp_eq, mant_x, mant_y, _MM_CMPINT_NLT);
+        return _kxor_mask32(mask_ge, neg);
+    }
+    static zmm_t loadu(void const *mem)
+    {
+        return _mm512_loadu_si512(mem);
+    }
+    static zmm_t max(zmm_t x, zmm_t y)
+    {
+        return _mm512_mask_mov_epi16(y, ge(x, y), x);
+    }
+    static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x)
+    {
+        // AVX512_VBMI2
+        return _mm512_mask_compressstoreu_epi16(mem, mask, x);
+    }
+    static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem)
+    {
+        // AVX512BW
+        return _mm512_mask_loadu_epi16(x, mask, mem);
+    }
+    static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y)
+    {
+        return _mm512_mask_mov_epi16(x, mask, y);
+    }
+    static void mask_storeu(void *mem, opmask_t mask, zmm_t x)
+    {
+        return _mm512_mask_storeu_epi16(mem, mask, x);
+    }
+    static zmm_t min(zmm_t x, zmm_t y)
+    {
+        return _mm512_mask_mov_epi16(x, ge(x, y), y);
+    }
+    static zmm_t permutexvar(__m512i idx, zmm_t zmm)
+    {
+        return _mm512_permutexvar_epi16(idx, zmm);
+    }
+    // Apparently this is terrible for perf; npy_half_to_float seems to work
+    // better
+    //static float uint16_to_float(uint16_t val)
+    //{
+    //    // Ideally use _mm_loadu_si16, but it's only available in gcc > 11.x
+    //    // TODO: use inline ASM? https://godbolt.org/z/aGYvh7fMM
+    //    __m128i xmm = _mm_maskz_loadu_epi16(0x01, &val);
+    //    __m128 xmm2 = _mm_cvtph_ps(xmm);
+    //    return _mm_cvtss_f32(xmm2);
+    //}
+    static type_t float_to_uint16(float val)
+    {
+        __m128 xmm = _mm_load_ss(&val);
+        __m128i xmm2 = _mm_cvtps_ph(xmm, _MM_FROUND_NO_EXC);
+        return _mm_extract_epi16(xmm2, 0);
+    }
+    static type_t reducemax(zmm_t v)
+    {
+        __m512 lo = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 0));
+        __m512 hi = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 1));
+        float lo_max = _mm512_reduce_max_ps(lo);
+        float hi_max = _mm512_reduce_max_ps(hi);
+        return float_to_uint16(std::max(lo_max, hi_max));
+    }
+    static type_t reducemin(zmm_t v)
+    {
+        __m512 lo = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 0));
+        __m512 hi = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 1));
+        float lo_min = _mm512_reduce_min_ps(lo);
+        float hi_min = _mm512_reduce_min_ps(hi);
+        return float_to_uint16(std::min(lo_min, hi_min));
+    }
+    static zmm_t set1(type_t v)
+    {
+        return _mm512_set1_epi16(v);
+    }
+    template <uint8_t mask>
+    static zmm_t shuffle(zmm_t zmm)
+    {
+        zmm = _mm512_shufflehi_epi16(zmm, (_MM_PERM_ENUM)mask);
+        return _mm512_shufflelo_epi16(zmm, (_MM_PERM_ENUM)mask);
+    }
+    static void storeu(void *mem, zmm_t x)
+    {
+        return _mm512_storeu_si512(mem, x);
+    }
+};
+
 template <>
 struct zmm_vector<int16_t> {
     using type_t = int16_t;
     using zmm_t = __m512i;
+    using ymm_t = __m256i;
     using opmask_t = __mmask32;
     static const uint8_t numlanes = 32;
 
@@ -130,6 +276,7 @@ template <>
 struct zmm_vector<uint16_t> {
     using type_t = uint16_t;
     using zmm_t = __m512i;
+    using ymm_t = __m256i;
     using opmask_t = __mmask32;
     static const uint8_t numlanes = 32;
 
@@ -227,7 +374,7 @@ struct zmm_vector<uint16_t> {
  * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
  */
 template <typename vtype, typename zmm_t = typename vtype::zmm_t>
-X86_SIMD_SORT_FORCEINLINE zmm_t sort_zmm_16bit(zmm_t zmm)
+X86_SIMD_SORT_FINLINE zmm_t sort_zmm_16bit(zmm_t zmm)
 {
     // Level 1
     zmm = cmp_merge<vtype>(
@@ -287,7 +434,7 @@ X86_SIMD_SORT_FORCEINLINE zmm_t sort_zmm_16bit(zmm_t zmm)
 
 // Assumes zmm is bitonic and performs a recursive half cleaner
 template <typename vtype, typename zmm_t = typename vtype::zmm_t>
-X86_SIMD_SORT_FORCEINLINE zmm_t bitonic_merge_zmm_16bit(zmm_t zmm)
+X86_SIMD_SORT_FINLINE zmm_t bitonic_merge_zmm_16bit(zmm_t zmm)
 {
     // 1) half_cleaner[32]: compare 1-17, 2-18, 3-19 etc ..
     zmm = cmp_merge<vtype>(
@@ -313,8 +460,7 @@ X86_SIMD_SORT_FORCEINLINE zmm_t bitonic_merge_zmm_16bit(zmm_t zmm)
 
 // Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner
 template <typename vtype, typename zmm_t = typename vtype::zmm_t>
-X86_SIMD_SORT_FORCEINLINE void bitonic_merge_two_zmm_16bit(zmm_t &zmm1,
-                                                           zmm_t &zmm2)
+X86_SIMD_SORT_FINLINE void bitonic_merge_two_zmm_16bit(zmm_t &zmm1, zmm_t &zmm2)
 {
     // 1) First step of a merging network: coex of zmm1 and zmm2 reversed
     zmm2 = vtype::permutexvar(vtype::get_network(4), zmm2);
@@ -328,7 +474,7 @@ X86_SIMD_SORT_FORCEINLINE void bitonic_merge_two_zmm_16bit(zmm_t &zmm1,
 // Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive
 // half cleaner
 template <typename vtype, typename zmm_t = typename vtype::zmm_t>
-X86_SIMD_SORT_FORCEINLINE void bitonic_merge_four_zmm_16bit(zmm_t *zmm)
+X86_SIMD_SORT_FINLINE void bitonic_merge_four_zmm_16bit(zmm_t *zmm)
 {
     zmm_t zmm2r = vtype::permutexvar(vtype::get_network(4), zmm[2]);
     zmm_t zmm3r = vtype::permutexvar(vtype::get_network(4), zmm[3]);
@@ -349,7 +495,7 @@ X86_SIMD_SORT_FORCEINLINE void bitonic_merge_four_zmm_16bit(zmm_t *zmm)
 }
 
 template <typename vtype, typename type_t>
-X86_SIMD_SORT_FORCEINLINE void sort_32_16bit(type_t *arr, int32_t N)
+X86_SIMD_SORT_FINLINE void sort_32_16bit(type_t *arr, int32_t N)
 {
     typename vtype::opmask_t load_mask = ((0x1ull << N) - 0x1ull) & 0xFFFFFFFF;
     typename vtype::zmm_t zmm
@@ -358,7 +504,7 @@ X86_SIMD_SORT_FORCEINLINE void sort_32_16bit(type_t *arr, int32_t N)
 }
 
 template <typename vtype, typename type_t>
-X86_SIMD_SORT_FORCEINLINE void sort_64_16bit(type_t *arr, int32_t N)
+X86_SIMD_SORT_FINLINE void sort_64_16bit(type_t *arr, int32_t N)
 {
     if (N <= 32) {
         sort_32_16bit<vtype>(arr, N);
@@ -377,7 +523,7 @@ X86_SIMD_SORT_FORCEINLINE void sort_64_16bit(type_t *arr, int32_t N)
 }
 
 template <typename vtype, typename type_t>
-X86_SIMD_SORT_FORCEINLINE void sort_128_16bit(type_t *arr, int32_t N)
+X86_SIMD_SORT_FINLINE void sort_128_16bit(type_t *arr, int32_t N)
 {
     if (N <= 64) {
         sort_64_16bit<vtype>(arr, N);
@@ -410,9 +556,9 @@ X86_SIMD_SORT_FORCEINLINE void sort_128_16bit(type_t *arr, int32_t N)
 }
 
 template <typename vtype, typename type_t>
-X86_SIMD_SORT_FORCEINLINE type_t get_pivot_16bit(type_t *arr,
-                                                 const int64_t left,
-                                                 const int64_t right)
+X86_SIMD_SORT_FINLINE type_t get_pivot_16bit(type_t *arr,
+                                             const int64_t left,
+                                             const int64_t right)
 {
     // median of 32
     int64_t size = (right - left) / 32;
@@ -453,6 +599,34 @@ X86_SIMD_SORT_FORCEINLINE type_t get_pivot_16bit(type_t *arr,
     return ((type_t *)&sort)[16];
 }
 
+template <>
+bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
+{
+    uint16_t signa = a & 0x8000, signb = b & 0x8000;
+    uint16_t expa = a & 0x7c00, expb = b & 0x7c00;
+    uint16_t manta = a & 0x3ff, mantb = b & 0x3ff;
+    if (signa != signb) {
+        // opposite signs
+        return a > b;
+    }
+    else if (signa > 0) {
+        // both -ve
+        if (expa != expb) { return expa > expb; }
+        else {
+            return manta > mantb;
+        }
+    }
+    else {
+        // both +ve
+        if (expa != expb) { return expa < expb; }
+        else {
+            return manta < mantb;
+        }
+    }
+
+    //return npy_half_to_float(a) < npy_half_to_float(b);
+}
+
 template <typename vtype, typename type_t>
 static void
 qsort_16bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
@@ -461,7 +635,7 @@ qsort_16bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
      * Resort to std::sort if quicksort isn't making any progress
      */
     if (max_iters <= 0) {
-        std::sort(arr + left, arr + right + 1);
+        std::sort(arr + left, arr + right + 1, comparison_func<vtype>);
         return;
     }
     /*
@@ -483,12 +657,40 @@ qsort_16bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
     qsort_16bit_<vtype>(arr, pivot_index, right, max_iters - 1);
 }
 
+X86_SIMD_SORT_FINLINE int64_t replace_nan_with_inf(uint16_t *arr,
+                                                   int64_t arrsize)
+{
+    int64_t nan_count = 0;
+    __mmask16 loadmask = 0xFFFF;
+    while (arrsize > 0) {
+        if (arrsize < 16) { loadmask = (0x0001 << arrsize) - 0x0001; }
+        __m256i in_zmm = _mm256_maskz_loadu_epi16(loadmask, arr);
+        __m512 in_zmm_asfloat = _mm512_cvtph_ps(in_zmm);
+        __mmask16 nanmask = _mm512_cmp_ps_mask(
+                in_zmm_asfloat, in_zmm_asfloat, _CMP_NEQ_UQ);
+        nan_count += _mm_popcnt_u32((int32_t)nanmask);
+        _mm256_mask_storeu_epi16(arr, nanmask, YMM_MAX_HALF);
+        arr += 16;
+        arrsize -= 16;
+    }
+    return nan_count;
+}
+
+X86_SIMD_SORT_FINLINE void
+replace_inf_with_nan(uint16_t *arr, int64_t arrsize, int64_t nan_count)
+{
+    for (int64_t ii = arrsize - 1; nan_count > 0; --ii) {
+        arr[ii] = 0xFFFF;
+        nan_count -= 1;
+    }
+}
+
 template <>
 void avx512_qsort(int16_t *arr, int64_t arrsize)
 {
     if (arrsize > 1) {
         qsort_16bit_<zmm_vector<int16_t>, int16_t>(
-                arr, 0, arrsize - 1, 2 * (63 - __builtin_clzll(arrsize)));
+                arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
     }
 }
 
@@ -497,7 +699,17 @@ void avx512_qsort(uint16_t *arr, int64_t arrsize)
 {
     if (arrsize > 1) {
         qsort_16bit_<zmm_vector<uint16_t>, uint16_t>(
-                arr, 0, arrsize - 1, 2 * (63 - __builtin_clzll(arrsize)));
+                arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
+    }
+}
+
+void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize)
+{
+    if (arrsize > 1) {
+        int64_t nan_count = replace_nan_with_inf(arr, arrsize);
+        qsort_16bit_<zmm_vector<float16>, uint16_t>(
+                arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
+        replace_inf_with_nan(arr, arrsize, nan_count);
     }
 }
 #endif // AVX512_QSORT_16BIT
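
Usage note: the entry point added by this commit is avx512_qsort_fp16(uint16_t *arr, int64_t arrsize), which sorts IEEE 754 half-precision values passed as their raw 16-bit patterns and leaves NaNs at the end. A minimal caller sketch follows; the include path and literal bit patterns are illustrative only, and it assumes a compiler and CPU with AVX-512 BW/VBMI2 support.

// Illustrative caller (not part of this commit): sort four half-precision
// values stored as raw bit patterns.
// 0x3C00 = 1.0, 0xC000 = -2.0, 0x7C00 = +inf, 0x3800 = 0.5
#include <cstdint>
#include <vector>
#include "avx512-16bit-qsort.hpp"

int main()
{
    std::vector<uint16_t> halves = {0x3C00, 0xC000, 0x7C00, 0x3800};
    avx512_qsort_fp16(halves.data(), static_cast<int64_t>(halves.size()));
    // Expected order: -2.0, 0.5, 1.0, +inf
    // i.e. halves == {0xC000, 0x3800, 0x3C00, 0x7C00}
    return 0;
}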
