Skip to content

Commit 33e30e6

Browse files
author
Raghuveer Devulapalli
committed
Detect NAN in float/double array and skip avx512_argsort
1 parent 5aca459 commit 33e30e6

File tree

2 files changed

+76
-2
lines changed

2 files changed

+76
-2
lines changed

src/avx512-64bit-argsort.hpp

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,19 @@
1111
#include "avx512-common-argsort.h"
1212
#include "avx512-64bit-keyvalue-networks.hpp"
1313

14+
/* argsort using std::sort */
15+
template <typename T>
16+
void std_argsort_withnan(T *arr, int64_t *arg, int64_t left, int64_t right)
17+
{
18+
std::sort(arg + left,
19+
arg + right,
20+
[arr](int64_t left, int64_t right) -> bool {
21+
if ((!std::isnan(arr[left])) && (!std::isnan(arr[right]))) {return arr[left] < arr[right];}
22+
else if (std::isnan(arr[left])) {return false;}
23+
else {return true;}
24+
});
25+
}
26+
1427
/* argsort using std::sort */
1528
template <typename T>
1629
void std_argsort(T *arr, int64_t *arg, int64_t left, int64_t right)
@@ -270,6 +283,33 @@ inline void argsort_64bit_(type_t *arr,
270283
argsort_64bit_<vtype>(arr, arg, pivot_index, right, max_iters - 1);
271284
}
272285

286+
template <typename vtype, typename type_t>
287+
bool has_nan(type_t* arr, int64_t arrsize)
288+
{
289+
using opmask_t = typename vtype::opmask_t;
290+
using zmm_t = typename vtype::zmm_t;
291+
bool found_nan = false;
292+
opmask_t loadmask = 0xFF;
293+
zmm_t in;
294+
while (arrsize > 0) {
295+
if (arrsize < vtype::numlanes) {
296+
loadmask = (0x01 << arrsize) - 0x01;
297+
in = vtype::maskz_loadu(loadmask, arr);
298+
}
299+
else {
300+
in = vtype::loadu(arr);
301+
}
302+
opmask_t nanmask = vtype::template fpclass<0x01|0x80>(in);
303+
arr += vtype::numlanes;
304+
arrsize -= vtype::numlanes;
305+
if (nanmask != 0x00) {
306+
found_nan = true;
307+
break;
308+
}
309+
}
310+
return found_nan;
311+
}
312+
273313
template <typename T>
274314
void avx512_argsort(T* arr, int64_t *arg, int64_t arrsize)
275315
{
@@ -279,6 +319,21 @@ void avx512_argsort(T* arr, int64_t *arg, int64_t arrsize)
279319
}
280320
}
281321

322+
template <>
323+
void avx512_argsort(double* arr, int64_t *arg, int64_t arrsize)
324+
{
325+
if (arrsize > 1) {
326+
if (has_nan<zmm_vector<double>>(arr, arrsize)) {
327+
std_argsort_withnan(arr, arg, 0, arrsize);
328+
}
329+
else {
330+
argsort_64bit_<zmm_vector<double>>(
331+
arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
332+
}
333+
}
334+
}
335+
336+
282337
template <>
283338
void avx512_argsort(int32_t* arr, int64_t *arg, int64_t arrsize)
284339
{
@@ -301,8 +356,13 @@ template <>
301356
void avx512_argsort(float* arr, int64_t *arg, int64_t arrsize)
302357
{
303358
if (arrsize > 1) {
304-
argsort_64bit_<ymm_vector<float>>(
305-
arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
359+
if (has_nan<ymm_vector<float>>(arr, arrsize)) {
360+
std_argsort_withnan(arr, arg, 0, arrsize);
361+
}
362+
else {
363+
argsort_64bit_<ymm_vector<float>>(
364+
arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
365+
}
306366
}
307367
}
308368

src/avx512-64bit-common.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ struct ymm_vector<float> {
7171
{
7272
return _mm256_cmp_ps_mask(x, y, _CMP_EQ_OQ);
7373
}
74+
template <int type>
75+
static opmask_t fpclass(zmm_t x)
76+
{
77+
return _mm256_fpclass_ps_mask(x, type);
78+
}
7479
template <int scale>
7580
static zmm_t
7681
mask_i64gather(zmm_t src, opmask_t mask, __m512i index, void const *base)
@@ -682,6 +687,10 @@ struct zmm_vector<double> {
682687
return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8);
683688
}
684689

690+
static zmm_t maskz_loadu(opmask_t mask, void const *mem)
691+
{
692+
return _mm512_maskz_loadu_pd(mask, mem);
693+
}
685694
static opmask_t knot_opmask(opmask_t x)
686695
{
687696
return _knot_mask8(x);
@@ -694,6 +703,11 @@ struct zmm_vector<double> {
694703
{
695704
return _mm512_cmp_pd_mask(x, y, _CMP_EQ_OQ);
696705
}
706+
template <int type>
707+
static opmask_t fpclass(zmm_t x)
708+
{
709+
return _mm512_fpclass_pd_mask(x, type);
710+
}
697711
template <int scale>
698712
static zmm_t
699713
mask_i64gather(zmm_t src, opmask_t mask, __m512i index, void const *base)

0 commit comments

Comments
 (0)