Skip to content

Commit 3ddc914

Browse files
author
Raghuveer Devulapalli
committed
Fix bug in avx512fp16 nan processing
1 parent 92a628a commit 3ddc914

File tree

5 files changed

+74
-33
lines changed

5 files changed

+74
-33
lines changed

src/avx512-32bit-qsort.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,10 @@ struct zmm_vector<float> {
256256
{
257257
return _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ);
258258
}
259+
/* Builds the mask used when fewer than a full vector of elements remains:
 * the lowest `size` bits are set and every higher bit is clear. */
static opmask_t get_partial_loadmask(int size)
{
    const int next_lane_bit = 0x0001 << size;
    return next_lane_bit - 0x0001;
}
259263
template <int type>
260264
static opmask_t fpclass(zmm_t x)
261265
{

src/avx512-64bit-argsort.hpp

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -344,33 +344,6 @@ static void argselect_64bit_(type_t *arr,
344344
arr, arg, pos, pivot_index, right, max_iters - 1);
345345
}
346346

347-
/*
 * Walks `arr` one vector (vtype::numlanes elements) at a time and reports
 * whether any element matches the fpclass<0x01 | 0x80> category mask
 * (NOTE(review): presumably the quiet- and signalling-NaN categories --
 * confirm against the fpclass intrinsic documentation).
 *
 * When fewer than vtype::numlanes elements remain, only the low `arrsize`
 * lanes are loaded through a masked load.
 *
 * Returns true as soon as a matching element is seen, false if the whole
 * array is scanned without a match (including arrsize <= 0).
 */
template <typename vtype, typename type_t>
bool has_nan(type_t *arr, int64_t arrsize)
{
    using opmask_t = typename vtype::opmask_t;
    using zmm_t = typename vtype::zmm_t;
    while (arrsize > 0) {
        zmm_t lanes;
        if (arrsize < vtype::numlanes) {
            // Trailing partial vector: enable only the lanes that exist.
            opmask_t tailmask = (0x01 << arrsize) - 0x01;
            lanes = vtype::maskz_loadu(tailmask, arr);
        }
        else {
            lanes = vtype::loadu(arr);
        }
        const opmask_t hits = vtype::template fpclass<0x01 | 0x80>(lanes);
        if (hits != 0x00) { return true; }
        arr += vtype::numlanes;
        arrsize -= vtype::numlanes;
    }
    return false;
}
373-
374347
/* argsort methods for 32-bit and 64-bit dtypes */
375348
template <typename T>
376349
void avx512_argsort(T *arr, int64_t *arg, int64_t arrsize)

src/avx512-64bit-common.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ struct ymm_vector<float> {
7171
{
7272
return _mm256_cmp_ps_mask(x, y, _CMP_EQ_OQ);
7373
}
74+
/* Returns a mask whose lowest `size` bits are set, for loading a trailing
 * chunk that is shorter than a full vector. */
static opmask_t get_partial_loadmask(int size)
{
    const int one_past_top = 0x01 << size;
    return one_past_top - 0x01;
}
7478
template <int type>
7579
static opmask_t fpclass(zmm_t x)
7680
{
@@ -703,6 +707,10 @@ struct zmm_vector<double> {
703707
{
704708
return _mm512_cmp_pd_mask(x, y, _CMP_EQ_OQ);
705709
}
710+
/* Produces the partial-load mask for `size` remaining elements: bits
 * [0, size) are ones, all other bits are zero. */
static opmask_t get_partial_loadmask(int size)
{
    const int first_unused_bit = 0x01 << size;
    return first_unused_bit - 0x01;
}
706714
template <int type>
707715
static opmask_t fpclass(zmm_t x)
708716
{

src/avx512-common-qsort.h

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,18 +100,17 @@ bool is_a_nan(T elem)
100100
return std::isnan(elem);
101101
}
102102

103-
template <typename vtype, typename type_t>
104-
int64_t replace_nan_with_inf(type_t *arr, int64_t arrsize)
103+
template <typename vtype, typename T>
104+
int64_t replace_nan_with_inf(T *arr, int64_t arrsize)
105105
{
106106
int64_t nan_count = 0;
107107
using opmask_t = typename vtype::opmask_t;
108108
using zmm_t = typename vtype::zmm_t;
109-
bool found_nan = false;
110-
opmask_t loadmask = 0xFF;
109+
opmask_t loadmask;
111110
zmm_t in;
112111
while (arrsize > 0) {
113112
if (arrsize < vtype::numlanes) {
114-
loadmask = (0x01 << arrsize) - 0x01;
113+
loadmask = vtype::get_partial_loadmask(arrsize);
115114
in = vtype::maskz_loadu(loadmask, arr);
116115
}
117116
else {
@@ -126,6 +125,33 @@ int64_t replace_nan_with_inf(type_t *arr, int64_t arrsize)
126125
return nan_count;
127126
}
128127

128+
/*
 * Reports whether `arr` contains any element matched by
 * vtype::fpclass<0x01 | 0x80> (NOTE(review): presumably the quiet- and
 * signalling-NaN categories -- confirm against the fpclass intrinsic
 * documentation).
 *
 * The array is consumed vtype::numlanes elements per iteration. A trailing
 * chunk shorter than a full vector is read with a masked load whose mask
 * comes from vtype::get_partial_loadmask.
 *
 * Returns true on the first chunk that contains a match, otherwise false
 * (also for arrsize <= 0).
 */
template <typename vtype, typename type_t>
bool has_nan(type_t *arr, int64_t arrsize)
{
    using opmask_t = typename vtype::opmask_t;
    using zmm_t = typename vtype::zmm_t;
    for (; arrsize > 0; arr += vtype::numlanes, arrsize -= vtype::numlanes) {
        zmm_t chunk;
        if (arrsize >= vtype::numlanes) {
            chunk = vtype::loadu(arr);
        }
        else {
            // Fewer than numlanes elements left: load only those lanes.
            opmask_t tailmask = vtype::get_partial_loadmask(arrsize);
            chunk = vtype::maskz_loadu(tailmask, arr);
        }
        const opmask_t hits = vtype::template fpclass<0x01 | 0x80>(chunk);
        if (hits != 0x00) { return true; }
    }
    return false;
}
154+
129155
template<typename type_t>
130156
void replace_inf_with_nan(type_t *arr, int64_t arrsize, int64_t nan_count)
131157
{

src/avx512fp16-16bit-qsort.hpp

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,19 @@ struct zmm_vector<_Float16> {
4646
{
4747
return _knot_mask32(x);
4848
}
49-
5049
static opmask_t ge(zmm_t x, zmm_t y)
5150
{
5251
return _mm512_cmp_ph_mask(x, y, _CMP_GE_OQ);
5352
}
53+
/*
 * Returns a mask with the lowest `size` bits set, used to load a trailing
 * chunk that is shorter than a full vector.
 *
 * The arithmetic is carried out on unsigned operands. With signed int
 * literals, size == 31 evaluates `(1 << 31) - 1`, which overflows a signed
 * 32-bit int and is undefined behaviour; the unsigned form yields the
 * intended 0x7FFFFFFF.
 *
 * `size` must lie in [0, 31]: shifting a 32-bit value by 32 or more (or by
 * a negative amount) is undefined.
 */
static opmask_t get_partial_loadmask(int size)
{
    return (0x00000001u << size) - 0x00000001u;
}
57+
template <int type>
58+
static opmask_t fpclass(zmm_t x)
59+
{
60+
return _mm512_fpclass_ph_mask(x, type);
61+
}
5462
static zmm_t loadu(void const *mem)
5563
{
5664
return _mm512_loadu_ph(mem);
@@ -65,6 +73,11 @@ struct zmm_vector<_Float16> {
6573
// AVX512_VBMI2
6674
return _mm512_mask_compressstoreu_epi16(mem, mask, temp);
6775
}
76+
/* Masked unaligned load of 16-bit elements from `mem` through
 * _mm512_maskz_loadu_epi16; the integer vector is then reinterpreted as a
 * half-precision vector without changing its bits. */
static zmm_t maskz_loadu(opmask_t mask, void const *mem)
{
    const auto raw = _mm512_maskz_loadu_epi16(mask, mem);
    return _mm512_castsi512_ph(raw);
}
6881
static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem)
6982
{
7083
// AVX512BW
@@ -140,4 +153,21 @@ void qsort_<zmm_vector<_Float16>>(_Float16* arr, int64_t left, int64_t right, in
140153
{
141154
qsort_16bit_<zmm_vector<_Float16>>(arr, left, right, maxiters);
142155
}
156+
157+
template<>
158+
void replace_inf_with_nan(_Float16 *arr, int64_t arrsize, int64_t nan_count)
159+
{
160+
memset(arr + arrsize - nan_count, 0xFF, nan_count * 2);
161+
}
162+
163+
template<>
164+
void avx512_qsort(_Float16 *arr, int64_t arrsize)
165+
{
166+
if (arrsize > 1) {
167+
int64_t nan_count = replace_nan_with_inf<zmm_vector<_Float16>, _Float16>(arr, arrsize);
168+
qsort_16bit_<zmm_vector<_Float16>, _Float16>(
169+
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
170+
replace_inf_with_nan(arr, arrsize, nan_count);
171+
}
172+
}
143173
#endif // AVX512FP16_QSORT_16BIT

0 commit comments

Comments
 (0)