@@ -46,11 +46,21 @@ struct zmm_vector<_Float16> {
     {
         return _knot_mask32(x);
     }
-
     static opmask_t ge(zmm_t x, zmm_t y)
     {
         return _mm512_cmp_ph_mask(x, y, _CMP_GE_OQ);
     }
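+    // Mask with the low `size` bits set: selects the first `size` of the 32 fp16 lanes for a partial load.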
+    static opmask_t get_partial_loadmask(int size)
+    {
+        return (0x00000001 << size) - 0x00000001;
+    }
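+    // Flags lanes matching the fpclass categories selected by `type` (e.g. 0x01 = QNaN, 0x80 = SNaN).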
+    template <int type>
+    static opmask_t fpclass(zmm_t x)
+    {
+        return _mm512_fpclass_ph_mask(x, type);
+    }
     static zmm_t loadu(void const *mem)
     {
         return _mm512_loadu_ph(mem);
@@ -65,6 +75,12 @@ struct zmm_vector<_Float16> {
         // AVX512_VBMI2
         return _mm512_mask_compressstoreu_epi16(mem, mask, temp);
     }
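+    // Zeroing masked load through the epi16 intrinsic, bit-cast back to fp16; the raw bits are unchanged.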
+    static zmm_t maskz_loadu(opmask_t mask, void const *mem)
+    {
+        return _mm512_castsi512_ph(
+                _mm512_maskz_loadu_epi16(mask, mem));
+    }
     static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem)
     {
         // AVX512BW
@@ -140,4 +156,24 @@ void qsort_<zmm_vector<_Float16>>(_Float16* arr, int64_t left, int64_t right, in
 {
     qsort_16bit_<zmm_vector<_Float16>>(arr, left, right, maxiters);
 }
+
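+// Rewrites the trailing nan_count elements as NaN: every byte becomes 0xFF, and 0xFFFF is a fp16 NaN bit pattern (2 bytes per element).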
+template <>
+void replace_inf_with_nan(_Float16 *arr, int64_t arrsize, int64_t nan_count)
+{
+    memset(arr + arrsize - nan_count, 0xFF, nan_count * 2);
+}
+
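+// NaNs do not order under comparisons, so replace_nan_with_inf turns them into
+// +inf (sorting them to the back); they are rewritten as NaN afterwards.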
+template <>
+void avx512_qsort(_Float16 *arr, int64_t arrsize)
+{
+    if (arrsize > 1) {
+        int64_t nan_count = replace_nan_with_inf<zmm_vector<_Float16>, _Float16>(arr, arrsize);
+        qsort_16bit_<zmm_vector<_Float16>, _Float16>(
+                arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
+        replace_inf_with_nan(arr, arrsize, nan_count);
+    }
+}
 #endif // AVX512FP16_QSORT_16BIT