Skip to content

Commit dee9505

Browse files
author
Raghuveer Devulapalli
committed
Fix NAN check for _Float16
1 parent 336c998 commit dee9505

File tree

3 files changed

+28
-8
lines changed

3 files changed

+28
-8
lines changed

src/avx512-16bit-qsort.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,8 @@ arrsize_t replace_nan_with_inf<zmm_vector<float16>>(uint16_t *arr,
445445
template <>
446446
bool is_a_nan<uint16_t>(uint16_t elem)
447447
{
448-
return (elem & 0x7c00) == 0x7c00;
448+
return ((elem & 0x7c00u) == 0x7c00u) &&
449+
((elem & 0x03ffu) != 0);
449450
}
450451

451452
X86_SIMD_SORT_INLINE

src/avx512-common-qsort.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ X86_SIMD_SORT_INLINE arrsize_t move_nans_to_end_of_array(T *arr, arrsize_t size)
191191
arrsize_t jj = size - 1;
192192
arrsize_t ii = 0;
193193
arrsize_t count = 0;
194-
while (ii <= jj) {
194+
while (ii < jj) {
195195
if (is_a_nan(arr[ii])) {
196196
std::swap(arr[ii], arr[jj]);
197197
jj -= 1;
@@ -201,6 +201,10 @@ X86_SIMD_SORT_INLINE arrsize_t move_nans_to_end_of_array(T *arr, arrsize_t size)
201201
ii += 1;
202202
}
203203
}
204+
/* Haven't checked for nan when ii == jj */
205+
if (is_a_nan(arr[ii])) {
206+
count++;
207+
}
204208
return size - count - 1;
205209
}
206210

src/avx512fp16-16bit-qsort.hpp

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -145,17 +145,19 @@ struct zmm_vector<_Float16> {
145145
template <>
146146
bool is_a_nan<_Float16>(_Float16 elem)
147147
{
148-
Fp16Bits temp;
149-
temp.f_ = elem;
150-
return (temp.i_ & 0x7c00) == 0x7c00;
148+
return elem != elem;
151149
}
152150

153151
template <>
154-
void replace_inf_with_nan(_Float16 *arr, arrsize_t arrsize, arrsize_t nan_count)
152+
void replace_inf_with_nan(_Float16 *arr, arrsize_t size, arrsize_t nan_count)
155153
{
156-
memset(arr + arrsize - nan_count, 0xFF, nan_count * 2);
154+
Fp16Bits val;
155+
val.i_ = 0x7c01;
156+
for (arrsize_t ii = size - 1; nan_count > 0; --ii) {
157+
arr[ii] = val.f_;
158+
nan_count -= 1;
159+
}
157160
}
158-
159161
/* Specialized template function for _Float16 qsort_*/
160162
template <>
161163
void avx512_qsort(_Float16 *arr, arrsize_t arrsize)
@@ -169,4 +171,17 @@ void avx512_qsort(_Float16 *arr, arrsize_t arrsize)
169171
replace_inf_with_nan(arr, arrsize, nan_count);
170172
}
171173
}
174+
175+
template <>
176+
void avx512_qselect(_Float16 *arr, arrsize_t k, arrsize_t arrsize, bool hasnan)
177+
{
178+
arrsize_t indx_last_elem = arrsize - 1;
179+
if (UNLIKELY(hasnan)) {
180+
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
181+
}
182+
if (indx_last_elem >= k) {
183+
qselect_<zmm_vector<_Float16>, _Float16>(
184+
arr, k, 0, indx_last_elem, 2 * (arrsize_t)log2(indx_last_elem));
185+
}
186+
}
172187
#endif // AVX512FP16_QSORT_16BIT

0 commit comments

Comments
 (0)