11
11
#include " avx512-common-argsort.h"
12
12
#include " avx512-64bit-keyvalue-networks.hpp"
13
13
14
+ /* argsort using std::sort */
15
+ template <typename T>
16
+ void std_argsort_withnan (T *arr, int64_t *arg, int64_t left, int64_t right)
17
+ {
18
+ std::sort (arg + left,
19
+ arg + right,
20
+ [arr](int64_t left, int64_t right) -> bool {
21
+ if ((!std::isnan (arr[left])) && (!std::isnan (arr[right]))) {return arr[left] < arr[right];}
22
+ else if (std::isnan (arr[left])) {return false ;}
23
+ else {return true ;}
24
+ });
25
+ }
26
+
14
27
/* argsort using std::sort */
15
28
template <typename T>
16
29
void std_argsort (T *arr, int64_t *arg, int64_t left, int64_t right)
@@ -270,6 +283,33 @@ inline void argsort_64bit_(type_t *arr,
270
283
argsort_64bit_<vtype>(arr, arg, pivot_index, right, max_iters - 1 );
271
284
}
272
285
286
+ template <typename vtype, typename type_t >
287
+ bool has_nan (type_t * arr, int64_t arrsize)
288
+ {
289
+ using opmask_t = typename vtype::opmask_t ;
290
+ using zmm_t = typename vtype::zmm_t ;
291
+ bool found_nan = false ;
292
+ opmask_t loadmask = 0xFF ;
293
+ zmm_t in;
294
+ while (arrsize > 0 ) {
295
+ if (arrsize < vtype::numlanes) {
296
+ loadmask = (0x01 << arrsize) - 0x01 ;
297
+ in = vtype::maskz_loadu (loadmask, arr);
298
+ }
299
+ else {
300
+ in = vtype::loadu (arr);
301
+ }
302
+ opmask_t nanmask = vtype::template fpclass<0x01 |0x80 >(in);
303
+ arr += vtype::numlanes;
304
+ arrsize -= vtype::numlanes;
305
+ if (nanmask != 0x00 ) {
306
+ found_nan = true ;
307
+ break ;
308
+ }
309
+ }
310
+ return found_nan;
311
+ }
312
+
273
313
template <typename T>
274
314
void avx512_argsort (T* arr, int64_t *arg, int64_t arrsize)
275
315
{
@@ -279,6 +319,21 @@ void avx512_argsort(T* arr, int64_t *arg, int64_t arrsize)
279
319
}
280
320
}
281
321
322
+ template <>
323
+ void avx512_argsort (double * arr, int64_t *arg, int64_t arrsize)
324
+ {
325
+ if (arrsize > 1 ) {
326
+ if (has_nan<zmm_vector<double >>(arr, arrsize)) {
327
+ std_argsort_withnan (arr, arg, 0 , arrsize);
328
+ }
329
+ else {
330
+ argsort_64bit_<zmm_vector<double >>(
331
+ arr, arg, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
332
+ }
333
+ }
334
+ }
335
+
336
+
282
337
template <>
283
338
void avx512_argsort (int32_t * arr, int64_t *arg, int64_t arrsize)
284
339
{
@@ -301,8 +356,13 @@ template <>
301
356
void avx512_argsort (float * arr, int64_t *arg, int64_t arrsize)
302
357
{
303
358
if (arrsize > 1 ) {
304
- argsort_64bit_<ymm_vector<float >>(
305
- arr, arg, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
359
+ if (has_nan<ymm_vector<float >>(arr, arrsize)) {
360
+ std_argsort_withnan (arr, arg, 0 , arrsize);
361
+ }
362
+ else {
363
+ argsort_64bit_<ymm_vector<float >>(
364
+ arr, arg, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
365
+ }
306
366
}
307
367
}
308
368
0 commit comments