Skip to content

Commit 8bd9c42

Browse files
author
Raghuveer Devulapalli
committed
Use std_argselect_withnan
1 parent 608538d commit 8bd9c42

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

src/avx512-64bit-argsort.hpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,20 @@
1111
#include "avx512-common-argsort.h"
1212
#include "avx512-64bit-keyvalue-networks.hpp"
1313

14+
template <typename T>
15+
void std_argselect_withnan(T *arr, int64_t *arg, int64_t k, int64_t left, int64_t right)
16+
{
17+
std::nth_element(arg + left,
18+
arg + k,
19+
arg + right,
20+
[arr](int64_t a, int64_t b) -> bool {
21+
if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) {return arr[a] < arr[b];}
22+
else if (std::isnan(arr[a])) {return false;}
23+
else {return true;}
24+
});
25+
}
26+
27+
1428
/* argsort using std::sort */
1529
template <typename T>
1630
void std_argsort_withnan(T *arr, int64_t *arg, int64_t left, int64_t right)
@@ -425,8 +439,7 @@ void avx512_argselect(double* arr, int64_t *arg, int64_t k, int64_t arrsize)
425439
{
426440
if (arrsize > 1) {
427441
if (has_nan<zmm_vector<double>>(arr, arrsize)) {
428-
/* FIXME: no need to do a full argsort */
429-
std_argsort_withnan(arr, arg, 0, arrsize);
442+
std_argselect_withnan(arr, arg, 0, arrsize);
430443
}
431444
else {
432445
argselect_64bit_<zmm_vector<double>>(
@@ -458,8 +471,7 @@ void avx512_argselect(float* arr, int64_t *arg, int64_t k, int64_t arrsize)
458471
{
459472
if (arrsize > 1) {
460473
if (has_nan<ymm_vector<float>>(arr, arrsize)) {
461-
/* FIXME: no need to do a full argsort */
462-
std_argsort_withnan(arr, arg, 0, arrsize);
474+
std_argselect_withnan(arr, arg, 0, arrsize);
463475
}
464476
else {
465477
argselect_64bit_<ymm_vector<float>>(

0 commit comments

Comments
 (0)