|
11 | 11 | #include "avx512-common-argsort.h"
|
12 | 12 | #include "avx512-64bit-keyvalue-networks.hpp"
|
13 | 13 |
|
| 14 | +template <typename T> |
| 15 | +void std_argselect_withnan(T *arr, int64_t *arg, int64_t k, int64_t left, int64_t right) |
| 16 | +{ |
| 17 | + std::nth_element(arg + left, |
| 18 | + arg + k, |
| 19 | + arg + right, |
| 20 | + [arr](int64_t a, int64_t b) -> bool { |
| 21 | + if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) {return arr[a] < arr[b];} |
| 22 | + else if (std::isnan(arr[a])) {return false;} |
| 23 | + else {return true;} |
| 24 | + }); |
| 25 | +} |
| 26 | + |
| 27 | + |
14 | 28 | /* argsort using std::sort */
|
15 | 29 | template <typename T>
|
16 | 30 | void std_argsort_withnan(T *arr, int64_t *arg, int64_t left, int64_t right)
|
@@ -425,8 +439,7 @@ void avx512_argselect(double* arr, int64_t *arg, int64_t k, int64_t arrsize)
|
425 | 439 | {
|
426 | 440 | if (arrsize > 1) {
|
427 | 441 | if (has_nan<zmm_vector<double>>(arr, arrsize)) {
|
428 |
| - /* FIXME: no need to do a full argsort */ |
429 |
| - std_argsort_withnan(arr, arg, 0, arrsize); |
| 442 | + std_argselect_withnan(arr, arg, 0, arrsize); |
430 | 443 | }
|
431 | 444 | else {
|
432 | 445 | argselect_64bit_<zmm_vector<double>>(
|
@@ -458,8 +471,7 @@ void avx512_argselect(float* arr, int64_t *arg, int64_t k, int64_t arrsize)
|
458 | 471 | {
|
459 | 472 | if (arrsize > 1) {
|
460 | 473 | if (has_nan<ymm_vector<float>>(arr, arrsize)) {
|
461 |
| - /* FIXME: no need to do a full argsort */ |
462 |
| - std_argsort_withnan(arr, arg, 0, arrsize); |
| 474 | + std_argselect_withnan(arr, arg, 0, arrsize); |
463 | 475 | }
|
464 | 476 | else {
|
465 | 477 | argselect_64bit_<ymm_vector<float>>(
|
|
0 commit comments