@@ -29,10 +29,156 @@ static const uint16_t network[6][32]
         {16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
          0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}};

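+// Tag type for IEEE 754 binary16 ("half") data. Values are handled as their
+// raw uint16_t bit patterns; the struct exists only to select the fp16
+// specialization of zmm_vector below.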
+struct float16 {
+    uint16_t val;
+};
+
+template <>
+struct zmm_vector<float16> {
+    using type_t = uint16_t;
+    using zmm_t = __m512i;
+    using ymm_t = __m256i;
+    using opmask_t = __mmask32;
+    static const uint8_t numlanes = 32;
+
+    static zmm_t get_network(int index)
+    {
+        return _mm512_loadu_si512(&network[index - 1][0]);
+    }
+    static type_t type_max()
+    {
+        return X86_SIMD_SORT_INFINITYH;
+    }
+    static type_t type_min()
+    {
+        return X86_SIMD_SORT_NEGINFINITYH;
+    }
+    static zmm_t zmm_max()
+    {
+        return _mm512_set1_epi16(type_max());
+    }
+    static opmask_t knot_opmask(opmask_t x)
+    {
+        return _knot_mask32(x);
+    }
+
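+    // ge() compares the packed binary16 bit patterns field by field (sign,
+    // then biased exponent, then mantissa) without converting to float. For
+    // lanes where both inputs are negative, larger exponent/mantissa fields
+    // mean a numerically smaller value, so the final _kxor_mask32 with the
+    // "neg" mask inverts the result for exactly those lanes.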
+    static opmask_t ge(zmm_t x, zmm_t y)
+    {
+        zmm_t sign_x = _mm512_and_si512(x, _mm512_set1_epi16(0x8000));
+        zmm_t sign_y = _mm512_and_si512(y, _mm512_set1_epi16(0x8000));
+        zmm_t exp_x = _mm512_and_si512(x, _mm512_set1_epi16(0x7c00));
+        zmm_t exp_y = _mm512_and_si512(y, _mm512_set1_epi16(0x7c00));
+        zmm_t mant_x = _mm512_and_si512(x, _mm512_set1_epi16(0x3ff));
+        zmm_t mant_y = _mm512_and_si512(y, _mm512_set1_epi16(0x3ff));
+
+        __mmask32 mask_ge = _mm512_cmp_epu16_mask(
+                sign_x, sign_y, _MM_CMPINT_LT); // only greater than
+        __mmask32 sign_eq = _mm512_cmpeq_epu16_mask(sign_x, sign_y);
+        __mmask32 neg = _mm512_mask_cmpeq_epu16_mask(
+                sign_eq,
+                sign_x,
+                _mm512_set1_epi16(0x8000)); // both numbers are -ve
+
+        // compare exponents only if signs are equal:
+        mask_ge = mask_ge
+                | _mm512_mask_cmp_epu16_mask(
+                        sign_eq, exp_x, exp_y, _MM_CMPINT_NLE);
+        // get mask for elements for which both sign and exponents are equal:
+        __mmask32 exp_eq = _mm512_mask_cmpeq_epu16_mask(sign_eq, exp_x, exp_y);
+
+        // compare mantissa for elements for which both sign and exponent are equal:
+        mask_ge = mask_ge
+                | _mm512_mask_cmp_epu16_mask(
+                        exp_eq, mant_x, mant_y, _MM_CMPINT_NLT);
+        return _kxor_mask32(mask_ge, neg);
+    }
+    static zmm_t loadu(void const *mem)
+    {
+        return _mm512_loadu_si512(mem);
+    }
+    static zmm_t max(zmm_t x, zmm_t y)
+    {
+        return _mm512_mask_mov_epi16(y, ge(x, y), x);
+    }
+    static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x)
+    {
+        // AVX512_VBMI2
+        return _mm512_mask_compressstoreu_epi16(mem, mask, x);
+    }
+    static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem)
+    {
+        // AVX512BW
+        return _mm512_mask_loadu_epi16(x, mask, mem);
+    }
+    static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y)
+    {
+        return _mm512_mask_mov_epi16(x, mask, y);
+    }
+    static void mask_storeu(void *mem, opmask_t mask, zmm_t x)
+    {
+        return _mm512_mask_storeu_epi16(mem, mask, x);
+    }
+    static zmm_t min(zmm_t x, zmm_t y)
+    {
+        return _mm512_mask_mov_epi16(x, ge(x, y), y);
+    }
+    static zmm_t permutexvar(__m512i idx, zmm_t zmm)
+    {
+        return _mm512_permutexvar_epi16(idx, zmm);
+    }
+    // Apparently this is terrible for perf; npy_half_to_float seems to work
+    // better
+    //static float uint16_to_float(uint16_t val)
+    //{
+    //    // Ideally use _mm_loadu_si16, but it's only in gcc > 11.x
+    //    // TODO: use inline ASM? https://godbolt.org/z/aGYvh7fMM
+    //    __m128i xmm = _mm_maskz_loadu_epi16(0x01, &val);
+    //    __m128 xmm2 = _mm_cvtph_ps(xmm);
+    //    return _mm_cvtss_f32(xmm2);
+    //}
+    static type_t float_to_uint16(float val)
+    {
+        __m128 xmm = _mm_load_ss(&val);
+        __m128i xmm2 = _mm_cvtps_ph(xmm, _MM_FROUND_NO_EXC);
+        return _mm_extract_epi16(xmm2, 0);
+    }
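+    // The horizontal reductions widen each 256-bit half of the register to
+    // fp32 with _mm512_cvtph_ps, reduce in fp32, and convert the scalar
+    // result back to a half-precision bit pattern via float_to_uint16().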
+    static type_t reducemax(zmm_t v)
+    {
+        __m512 lo = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 0));
+        __m512 hi = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 1));
+        float lo_max = _mm512_reduce_max_ps(lo);
+        float hi_max = _mm512_reduce_max_ps(hi);
+        return float_to_uint16(std::max(lo_max, hi_max));
+    }
+    static type_t reducemin(zmm_t v)
+    {
+        __m512 lo = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 0));
+        __m512 hi = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 1));
+        float lo_min = _mm512_reduce_min_ps(lo);
+        float hi_min = _mm512_reduce_min_ps(hi);
+        return float_to_uint16(std::min(lo_min, hi_min));
+    }
+    static zmm_t set1(type_t v)
+    {
+        return _mm512_set1_epi16(v);
+    }
+    template <uint8_t mask>
+    static zmm_t shuffle(zmm_t zmm)
+    {
+        zmm = _mm512_shufflehi_epi16(zmm, (_MM_PERM_ENUM)mask);
+        return _mm512_shufflelo_epi16(zmm, (_MM_PERM_ENUM)mask);
+    }
+    static void storeu(void *mem, zmm_t x)
+    {
+        return _mm512_storeu_si512(mem, x);
+    }
+};
+
 template <>
 struct zmm_vector<int16_t> {
     using type_t = int16_t;
     using zmm_t = __m512i;
+    using ymm_t = __m256i;
     using opmask_t = __mmask32;
     static const uint8_t numlanes = 32;

@@ -130,6 +276,7 @@ template <>
 struct zmm_vector<uint16_t> {
     using type_t = uint16_t;
     using zmm_t = __m512i;
+    using ymm_t = __m256i;
     using opmask_t = __mmask32;
     static const uint8_t numlanes = 32;

@@ -227,7 +374,7 @@ struct zmm_vector<uint16_t> {
  * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
  */
 template <typename vtype, typename zmm_t = typename vtype::zmm_t>
-X86_SIMD_SORT_FORCEINLINE zmm_t sort_zmm_16bit(zmm_t zmm)
+X86_SIMD_SORT_FINLINE zmm_t sort_zmm_16bit(zmm_t zmm)
 {
     // Level 1
     zmm = cmp_merge<vtype>(
@@ -287,7 +434,7 @@ X86_SIMD_SORT_FORCEINLINE zmm_t sort_zmm_16bit(zmm_t zmm)

 // Assumes zmm is bitonic and performs a recursive half cleaner
 template <typename vtype, typename zmm_t = typename vtype::zmm_t>
-X86_SIMD_SORT_FORCEINLINE zmm_t bitonic_merge_zmm_16bit(zmm_t zmm)
+X86_SIMD_SORT_FINLINE zmm_t bitonic_merge_zmm_16bit(zmm_t zmm)
 {
     // 1) half_cleaner[32]: compare 1-17, 2-18, 3-19 etc ..
     zmm = cmp_merge<vtype>(
@@ -313,8 +460,7 @@ X86_SIMD_SORT_FORCEINLINE zmm_t bitonic_merge_zmm_16bit(zmm_t zmm)

 // Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner
 template <typename vtype, typename zmm_t = typename vtype::zmm_t>
-X86_SIMD_SORT_FORCEINLINE void bitonic_merge_two_zmm_16bit(zmm_t &zmm1,
-                                                           zmm_t &zmm2)
+X86_SIMD_SORT_FINLINE void bitonic_merge_two_zmm_16bit(zmm_t &zmm1, zmm_t &zmm2)
 {
     // 1) First step of a merging network: coex of zmm1 and zmm2 reversed
     zmm2 = vtype::permutexvar(vtype::get_network(4), zmm2);
@@ -328,7 +474,7 @@ X86_SIMD_SORT_FORCEINLINE void bitonic_merge_two_zmm_16bit(zmm_t &zmm1,
 // Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive
 // half cleaner
 template <typename vtype, typename zmm_t = typename vtype::zmm_t>
-X86_SIMD_SORT_FORCEINLINE void bitonic_merge_four_zmm_16bit(zmm_t *zmm)
+X86_SIMD_SORT_FINLINE void bitonic_merge_four_zmm_16bit(zmm_t *zmm)
 {
     zmm_t zmm2r = vtype::permutexvar(vtype::get_network(4), zmm[2]);
     zmm_t zmm3r = vtype::permutexvar(vtype::get_network(4), zmm[3]);
@@ -349,7 +495,7 @@ X86_SIMD_SORT_FORCEINLINE void bitonic_merge_four_zmm_16bit(zmm_t *zmm)
 }

 template <typename vtype, typename type_t>
-X86_SIMD_SORT_FORCEINLINE void sort_32_16bit(type_t *arr, int32_t N)
+X86_SIMD_SORT_FINLINE void sort_32_16bit(type_t *arr, int32_t N)
 {
     typename vtype::opmask_t load_mask = ((0x1ull << N) - 0x1ull) & 0xFFFFFFFF;
     typename vtype::zmm_t zmm
@@ -358,7 +504,7 @@ X86_SIMD_SORT_FORCEINLINE void sort_32_16bit(type_t *arr, int32_t N)
 }

 template <typename vtype, typename type_t>
-X86_SIMD_SORT_FORCEINLINE void sort_64_16bit(type_t *arr, int32_t N)
+X86_SIMD_SORT_FINLINE void sort_64_16bit(type_t *arr, int32_t N)
 {
     if (N <= 32) {
         sort_32_16bit<vtype>(arr, N);
@@ -377,7 +523,7 @@ X86_SIMD_SORT_FORCEINLINE void sort_64_16bit(type_t *arr, int32_t N)
 }

 template <typename vtype, typename type_t>
-X86_SIMD_SORT_FORCEINLINE void sort_128_16bit(type_t *arr, int32_t N)
+X86_SIMD_SORT_FINLINE void sort_128_16bit(type_t *arr, int32_t N)
 {
     if (N <= 64) {
         sort_64_16bit<vtype>(arr, N);
@@ -410,9 +556,9 @@ X86_SIMD_SORT_FORCEINLINE void sort_128_16bit(type_t *arr, int32_t N)
 }

 template <typename vtype, typename type_t>
-X86_SIMD_SORT_FORCEINLINE type_t get_pivot_16bit(type_t *arr,
-                                                 const int64_t left,
-                                                 const int64_t right)
+X86_SIMD_SORT_FINLINE type_t get_pivot_16bit(type_t *arr,
+                                             const int64_t left,
+                                             const int64_t right)
 {
     // median of 32
     int64_t size = (right - left) / 32;
@@ -453,6 +599,34 @@ X86_SIMD_SORT_FORCEINLINE type_t get_pivot_16bit(type_t *arr,
     return ((type_t *)&sort)[16];
 }

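+// Scalar comparator for float16 bit patterns, mirroring the ordering used by
+// zmm_vector<float16>::ge() above; qsort_16bit_ passes it to the std::sort
+// fallback when the recursion budget is exhausted.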
+template <>
+bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
+{
+    uint16_t signa = a & 0x8000, signb = b & 0x8000;
+    uint16_t expa = a & 0x7c00, expb = b & 0x7c00;
+    uint16_t manta = a & 0x3ff, mantb = b & 0x3ff;
+    if (signa != signb) {
+        // opposite signs
+        return a > b;
+    }
+    else if (signa > 0) {
+        // both -ve
+        if (expa != expb) { return expa > expb; }
+        else {
+            return manta > mantb;
+        }
+    }
+    else {
+        // both +ve
+        if (expa != expb) { return expa < expb; }
+        else {
+            return manta < mantb;
+        }
+    }
+
+    // return npy_half_to_float(a) < npy_half_to_float(b);
+}
+
 template <typename vtype, typename type_t>
 static void
 qsort_16bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
@@ -461,7 +635,7 @@ qsort_16bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
      * Resort to std::sort if quicksort isnt making any progress
      */
     if (max_iters <= 0) {
-        std::sort(arr + left, arr + right + 1);
+        std::sort(arr + left, arr + right + 1, comparison_func<vtype>);
         return;
     }
     /*
@@ -483,12 +657,40 @@ qsort_16bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
         qsort_16bit_<vtype>(arr, pivot_index, right, max_iters - 1);
 }

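+// NaN handling for the fp16 path: before sorting, every NaN lane is
+// overwritten with YMM_MAX_HALF (+infinity, per the function name) so the
+// bit-pattern comparisons above stay well ordered, and the NaN count is
+// returned so that replace_inf_with_nan() can rewrite that many trailing
+// elements as NaN (0xFFFF) once the sort has finished.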
+X86_SIMD_SORT_FINLINE int64_t replace_nan_with_inf(uint16_t *arr,
+                                                   int64_t arrsize)
+{
+    int64_t nan_count = 0;
+    __mmask16 loadmask = 0xFFFF;
+    while (arrsize > 0) {
+        if (arrsize < 16) { loadmask = (0x0001 << arrsize) - 0x0001; }
+        __m256i in_zmm = _mm256_maskz_loadu_epi16(loadmask, arr);
+        __m512 in_zmm_asfloat = _mm512_cvtph_ps(in_zmm);
+        __mmask16 nanmask = _mm512_cmp_ps_mask(
+                in_zmm_asfloat, in_zmm_asfloat, _CMP_NEQ_UQ);
+        nan_count += _mm_popcnt_u32((int32_t)nanmask);
+        _mm256_mask_storeu_epi16(arr, nanmask, YMM_MAX_HALF);
+        arr += 16;
+        arrsize -= 16;
+    }
+    return nan_count;
+}
+
+X86_SIMD_SORT_FINLINE void
+replace_inf_with_nan(uint16_t *arr, int64_t arrsize, int64_t nan_count)
+{
+    for (int64_t ii = arrsize - 1; nan_count > 0; --ii) {
+        arr[ii] = 0xFFFF;
+        nan_count -= 1;
+    }
+}
+
 template <>
 void avx512_qsort(int16_t *arr, int64_t arrsize)
 {
     if (arrsize > 1) {
         qsort_16bit_<zmm_vector<int16_t>, int16_t>(
-                arr, 0, arrsize - 1, 2 * (63 - __builtin_clzll(arrsize)));
+                arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
     }
 }

@@ -497,7 +699,17 @@ void avx512_qsort(uint16_t *arr, int64_t arrsize)
 {
     if (arrsize > 1) {
         qsort_16bit_<zmm_vector<uint16_t>, uint16_t>(
-                arr, 0, arrsize - 1, 2 * (63 - __builtin_clzll(arrsize)));
+                arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
+    }
+}
+
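+// Sorts an array of IEEE 754 binary16 values passed as their raw uint16_t
+// bit patterns; NaN entries end up at the tail of the array. A minimal
+// caller sketch (names are illustrative, assuming a raw half-precision
+// buffer is available):
+//
+//     uint16_t *bits = reinterpret_cast<uint16_t *>(half_buffer);
+//     avx512_qsort_fp16(bits, buffer_len);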
+void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize)
+{
+    if (arrsize > 1) {
+        int64_t nan_count = replace_nan_with_inf(arr, arrsize);
+        qsort_16bit_<zmm_vector<float16>, uint16_t>(
+                arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
+        replace_inf_with_nan(arr, arrsize, nan_count);
     }
 }
 #endif // AVX512_QSORT_16BIT