Skip to content

Commit eb581ce

Browse files
author
Raghuveer Devulapalli
authored
Merge pull request #38 from r-devulap/handle-nan
Skip avx-512 argsort for arrays with NAN
2 parents 5aca459 + af4ebe7 commit eb581ce

File tree

3 files changed

+111
-3
lines changed

3 files changed

+111
-3
lines changed

src/avx512-64bit-argsort.hpp

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,19 @@
1111
#include "avx512-common-argsort.h"
1212
#include "avx512-64bit-keyvalue-networks.hpp"
1313

14+
/* argsort using std::sort */
15+
template <typename T>
16+
void std_argsort_withnan(T *arr, int64_t *arg, int64_t left, int64_t right)
17+
{
18+
std::sort(arg + left,
19+
arg + right,
20+
[arr](int64_t left, int64_t right) -> bool {
21+
if ((!std::isnan(arr[left])) && (!std::isnan(arr[right]))) {return arr[left] < arr[right];}
22+
else if (std::isnan(arr[left])) {return false;}
23+
else {return true;}
24+
});
25+
}
26+
1427
/* argsort using std::sort */
1528
template <typename T>
1629
void std_argsort(T *arr, int64_t *arg, int64_t left, int64_t right)
@@ -270,6 +283,33 @@ inline void argsort_64bit_(type_t *arr,
270283
argsort_64bit_<vtype>(arr, arg, pivot_index, right, max_iters - 1);
271284
}
272285

286+
template <typename vtype, typename type_t>
287+
bool has_nan(type_t* arr, int64_t arrsize)
288+
{
289+
using opmask_t = typename vtype::opmask_t;
290+
using zmm_t = typename vtype::zmm_t;
291+
bool found_nan = false;
292+
opmask_t loadmask = 0xFF;
293+
zmm_t in;
294+
while (arrsize > 0) {
295+
if (arrsize < vtype::numlanes) {
296+
loadmask = (0x01 << arrsize) - 0x01;
297+
in = vtype::maskz_loadu(loadmask, arr);
298+
}
299+
else {
300+
in = vtype::loadu(arr);
301+
}
302+
opmask_t nanmask = vtype::template fpclass<0x01|0x80>(in);
303+
arr += vtype::numlanes;
304+
arrsize -= vtype::numlanes;
305+
if (nanmask != 0x00) {
306+
found_nan = true;
307+
break;
308+
}
309+
}
310+
return found_nan;
311+
}
312+
273313
template <typename T>
274314
void avx512_argsort(T* arr, int64_t *arg, int64_t arrsize)
275315
{
@@ -279,6 +319,21 @@ void avx512_argsort(T* arr, int64_t *arg, int64_t arrsize)
279319
}
280320
}
281321

322+
template <>
323+
void avx512_argsort(double* arr, int64_t *arg, int64_t arrsize)
324+
{
325+
if (arrsize > 1) {
326+
if (has_nan<zmm_vector<double>>(arr, arrsize)) {
327+
std_argsort_withnan(arr, arg, 0, arrsize);
328+
}
329+
else {
330+
argsort_64bit_<zmm_vector<double>>(
331+
arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
332+
}
333+
}
334+
}
335+
336+
282337
template <>
283338
void avx512_argsort(int32_t* arr, int64_t *arg, int64_t arrsize)
284339
{
@@ -301,8 +356,13 @@ template <>
301356
void avx512_argsort(float* arr, int64_t *arg, int64_t arrsize)
302357
{
303358
if (arrsize > 1) {
304-
argsort_64bit_<ymm_vector<float>>(
305-
arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
359+
if (has_nan<ymm_vector<float>>(arr, arrsize)) {
360+
std_argsort_withnan(arr, arg, 0, arrsize);
361+
}
362+
else {
363+
argsort_64bit_<ymm_vector<float>>(
364+
arr, arg, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
365+
}
306366
}
307367
}
308368

src/avx512-64bit-common.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ struct ymm_vector<float> {
7171
{
7272
return _mm256_cmp_ps_mask(x, y, _CMP_EQ_OQ);
7373
}
74+
template <int type>
75+
static opmask_t fpclass(zmm_t x)
76+
{
77+
return _mm256_fpclass_ps_mask(x, type);
78+
}
7479
template <int scale>
7580
static zmm_t
7681
mask_i64gather(zmm_t src, opmask_t mask, __m512i index, void const *base)
@@ -682,6 +687,10 @@ struct zmm_vector<double> {
682687
return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8);
683688
}
684689

690+
static zmm_t maskz_loadu(opmask_t mask, void const *mem)
691+
{
692+
return _mm512_maskz_loadu_pd(mask, mem);
693+
}
685694
static opmask_t knot_opmask(opmask_t x)
686695
{
687696
return _knot_mask8(x);
@@ -694,6 +703,11 @@ struct zmm_vector<double> {
694703
{
695704
return _mm512_cmp_pd_mask(x, y, _CMP_EQ_OQ);
696705
}
706+
template <int type>
707+
static opmask_t fpclass(zmm_t x)
708+
{
709+
return _mm512_fpclass_pd_mask(x, type);
710+
}
697711
template <int scale>
698712
static zmm_t
699713
mask_i64gather(zmm_t src, opmask_t mask, __m512i index, void const *base)

tests/test_argsort.cpp

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,12 +174,46 @@ TYPED_TEST_P(avx512argsort, test_reverse)
174174
}
175175
}
176176

177+
TYPED_TEST_P(avx512argsort, test_array_with_nan)
178+
{
179+
if (!cpu_has_avx512bw()) {
180+
GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA";
181+
}
182+
if (!std::is_floating_point<TypeParam>::value) {
183+
GTEST_SKIP() << "Skipping this test, it is meant for float/double";
184+
}
185+
std::vector<int64_t> arrsizes;
186+
for (int64_t ii = 2; ii <= 1024; ++ii) {
187+
arrsizes.push_back(ii);
188+
}
189+
std::vector<TypeParam> arr;
190+
for (auto &size : arrsizes) {
191+
arr = get_uniform_rand_array<TypeParam>(size);
192+
arr[0] = std::numeric_limits<TypeParam>::quiet_NaN();
193+
arr[1] = std::numeric_limits<TypeParam>::quiet_NaN();
194+
std::vector<int64_t> inx
195+
= avx512_argsort<TypeParam>(arr.data(), arr.size());
196+
std::vector<TypeParam> sort1;
197+
for (size_t jj = 0; jj < size; ++jj) {
198+
sort1.push_back(arr[inx[jj]]);
199+
}
200+
if ((!std::isnan(sort1[size-1])) || (!std::isnan(sort1[size-2]))) {
201+
FAIL() << "NAN's aren't sorted to the end";
202+
}
203+
if (!std::is_sorted(sort1.begin(), sort1.end() - 2)) {
204+
FAIL() << "Array isn't sorted";
205+
}
206+
arr.clear();
207+
}
208+
}
209+
177210
REGISTER_TYPED_TEST_SUITE_P(avx512argsort,
178211
test_random,
179212
test_reverse,
180213
test_constant,
181214
test_sorted,
182-
test_small_range);
215+
test_small_range,
216+
test_array_with_nan);
183217

184218
using ArgSortTestTypes = testing::Types<int32_t,
185219
uint32_t,

0 commit comments

Comments
 (0)