From a55a65557a986c4a53221072648f53443596b5e3 Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Tue, 14 Jan 2025 12:43:48 -0800 Subject: [PATCH 1/3] Fix kv-sort and kv-select NAN handling --- src/xss-common-keyvaluesort.hpp | 67 +++++++++++++++++++++++---------- tests/test-keyvalue.cpp | 43 +++++++++++++-------- tests/test-qsort-common.h | 6 +++ tests/test-qsort.cpp | 21 ++++++----- utils/rand_array.h | 20 ++++++++++ 5 files changed, 113 insertions(+), 44 deletions(-) diff --git a/src/xss-common-keyvaluesort.hpp b/src/xss-common-keyvaluesort.hpp index 79b2af7d..9e870c87 100644 --- a/src/xss-common-keyvaluesort.hpp +++ b/src/xss-common-keyvaluesort.hpp @@ -16,6 +16,34 @@ #include #endif +/* + * Sort all the NAN's to end of the array and return the index of the last elem + * in the array which is not a nan + */ +template +X86_SIMD_SORT_INLINE arrsize_t move_nans_to_end_of_array(T1 *keys, + T2 *vals, + arrsize_t size) +{ + arrsize_t jj = size - 1; + arrsize_t ii = 0; + arrsize_t count = 0; + while (ii < jj) { + if (is_a_nan(keys[ii])) { + std::swap(keys[ii], keys[jj]); + std::swap(vals[ii], vals[jj]); + jj -= 1; + count++; + } + else { + ii += 1; + } + } + /* Haven't checked for nan when ii == jj */ + if (is_a_nan(keys[ii])) { count++; } + return size - count - 1; +} + /* * Parition one ZMM register based on the pivot and returns the index of the * last element that is less than equal to the pivot. @@ -538,11 +566,11 @@ X86_SIMD_SORT_INLINE void xss_qsort_kv( #endif // XSS_TEST_KEYVALUE_BASE_CASE if (minarrsize) { - arrsize_t nan_count = 0; + arrsize_t index_last_elem = arrsize - 1; if constexpr (xss::fp::is_floating_point_v) { if (UNLIKELY(hasnan)) { - nan_count - = replace_nan_with_inf>(keys, arrsize); + index_last_elem + = move_nans_to_end_of_array(keys, indexes, arrsize); } } else { @@ -565,24 +593,27 @@ X86_SIMD_SORT_INLINE void xss_qsort_kv( // Note that we do not use the if(...) clause built into OpenMP, because it causes a performance regression for small arrays #pragma omp parallel num_threads(thread_count) #pragma omp single - kvsort_( - keys, indexes, 0, arrsize - 1, maxiters, task_threshold); + kvsort_(keys, + indexes, + 0, + index_last_elem, + maxiters, + task_threshold); } else { kvsort_(keys, indexes, 0, - arrsize - 1, + index_last_elem, maxiters, std::numeric_limits::max()); } #pragma omp taskwait #else - kvsort_(keys, indexes, 0, arrsize - 1, maxiters, 0); + kvsort_( + keys, indexes, 0, index_last_elem, maxiters, 0); #endif - replace_inf_with_nan(keys, arrsize, nan_count); - if (descending) { std::reverse(keys, keys + arrsize); std::reverse(indexes, indexes + arrsize); @@ -625,20 +656,18 @@ X86_SIMD_SORT_INLINE void xss_select_kv(T1 *keys, if (minarrsize) { if (descending) { k = arrsize - 1 - k; } - if constexpr (std::is_floating_point_v) { - arrsize_t nan_count = 0; + arrsize_t index_last_elem = arrsize - 1; + if constexpr (xss::fp::is_floating_point_v) { if (UNLIKELY(hasnan)) { - nan_count - = replace_nan_with_inf>(keys, arrsize); + index_last_elem + = move_nans_to_end_of_array(keys, indexes, arrsize); } - kvselect_( - keys, indexes, k, 0, arrsize - 1, maxiters); - replace_inf_with_nan(keys, arrsize, nan_count); } - else { - UNUSED(hasnan); + + UNUSED(hasnan); + if (index_last_elem >= k) { kvselect_( - keys, indexes, k, 0, arrsize - 1, maxiters); + keys, indexes, k, 0, index_last_elem, maxiters); } if (descending) { diff --git a/tests/test-keyvalue.cpp b/tests/test-keyvalue.cpp index db1985dc..c0e683c0 100644 --- a/tests/test-keyvalue.cpp +++ b/tests/test-keyvalue.cpp @@ -26,7 +26,9 @@ class simdkvsort : public ::testing::Test { "smallrange", "max_at_the_end", "random_5d", - "rand_max"}; + "rand_max", + "rand_with_nan", + "rand_with_max_and_nan"}; } std::vector arrtype; std::vector arrsize = std::vector(1024); @@ -123,27 +125,36 @@ bool is_kv_partialsorted(T1 *keys_comp, } // Now, we need to do some more work to handle keys exactly equal to the true kth + // There may be more values after the kth element with the same key, + // and thus we can find that the values of the kth elements do not match, + // even though the sort is correct. + // First, fully kvsort both arrays xss::scalar::keyvalue_qsort(keys_ref, vals_ref, size, true, false); xss::scalar::keyvalue_qsort( keys_comp, vals_comp, size, true, false); - auto trueKth = keys_ref[k]; - bool notFoundFirst = true; + auto trueKthKey = keys_ref[k]; + bool foundFirstKthKey = false; size_t i = 0; + // Search forwards until we find the block of keys that match the kth key, + // then find where it ends for (; i < size; i++) { - if (notFoundFirst && cmp_eq(keys_ref[i], trueKth)) { - notFoundFirst = false; + if (!foundFirstKthKey && cmp_eq(keys_ref[i], trueKthKey)) { + foundFirstKthKey = true; i_start = i; } - else if (!notFoundFirst && !cmp_eq(keys_ref[i], trueKth)) { + else if (foundFirstKthKey && !cmp_eq(keys_ref[i], trueKthKey)) { break; } } - if (notFoundFirst) return false; + // kth key is somehow missing? Since we got that value from keys_ref, should be impossible + if (!foundFirstKthKey) { return false; } + // Check that the values in the kth key block match, so they are equivalent + // up to permutation, which is allowed since the sort is not stable if (!same_values(vals_ref + i_start, vals_comp + i_start, i - i_start)) { return false; } @@ -156,7 +167,7 @@ TYPED_TEST_P(simdkvsort, test_kvsort_ascending) using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type; using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type; for (auto type : this->arrtype) { - bool hasnan = (type == "rand_with_nan") ? true : false; + bool hasnan = is_nan_test(type); for (auto size : this->arrsize) { std::vector key = get_array(type, size); std::vector val = get_array(type, size); @@ -187,7 +198,7 @@ TYPED_TEST_P(simdkvsort, test_kvsort_descending) using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type; using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type; for (auto type : this->arrtype) { - bool hasnan = (type == "rand_with_nan") ? true : false; + bool hasnan = is_nan_test(type); for (auto size : this->arrsize) { std::vector key = get_array(type, size); std::vector val = get_array(type, size); @@ -217,8 +228,9 @@ TYPED_TEST_P(simdkvsort, test_kvselect_ascending) { using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type; using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type; + auto cmp_eq = compare>(); for (auto type : this->arrtype) { - bool hasnan = (type == "rand_with_nan") ? true : false; + bool hasnan = is_nan_test(type); for (auto size : this->arrsize) { size_t k = rand() % size; @@ -237,7 +249,7 @@ TYPED_TEST_P(simdkvsort, test_kvselect_ascending) xss::scalar::keyvalue_qsort( key.data(), val.data(), k, hasnan, false); - ASSERT_EQ(key[k], key_bckp[k]); + ASSERT_EQ(cmp_eq(key[k], key_bckp[k]), true); bool is_kv_partialsorted_ = is_kv_partialsorted(key.data(), @@ -260,8 +272,9 @@ TYPED_TEST_P(simdkvsort, test_kvselect_descending) { using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type; using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type; + auto cmp_eq = compare>(); for (auto type : this->arrtype) { - bool hasnan = (type == "rand_with_nan") ? true : false; + bool hasnan = is_nan_test(type); for (auto size : this->arrsize) { size_t k = rand() % size; @@ -280,7 +293,7 @@ TYPED_TEST_P(simdkvsort, test_kvselect_descending) xss::scalar::keyvalue_qsort( key.data(), val.data(), k, hasnan, true); - ASSERT_EQ(key[k], key_bckp[k]); + ASSERT_EQ(cmp_eq(key[k], key_bckp[k]), true); bool is_kv_partialsorted_ = is_kv_partialsorted(key.data(), @@ -304,7 +317,7 @@ TYPED_TEST_P(simdkvsort, test_kvpartial_sort_ascending) using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type; using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type; for (auto type : this->arrtype) { - bool hasnan = (type == "rand_with_nan") ? true : false; + bool hasnan = is_nan_test(type); for (auto size : this->arrsize) { size_t k = rand() % size; @@ -341,7 +354,7 @@ TYPED_TEST_P(simdkvsort, test_kvpartial_sort_descending) using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type; using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type; for (auto type : this->arrtype) { - bool hasnan = (type == "rand_with_nan") ? true : false; + bool hasnan = is_nan_test(type); for (auto size : this->arrsize) { size_t k = rand() % size; diff --git a/tests/test-qsort-common.h b/tests/test-qsort-common.h index 4fdb87fc..0d67b37d 100644 --- a/tests/test-qsort-common.h +++ b/tests/test-qsort-common.h @@ -20,6 +20,12 @@ ASSERT_TRUE(false) << msg << ". arr size = " << size \ << ", type = " << type << ", k = " << k; +inline bool is_nan_test(std::string type) +{ + // Currently, determine whether the test uses nan just be checking if nan is in its name + return type.find("nan") != std::string::npos; +} + template void IS_SORTED(std::vector sorted, std::vector arr, std::string type) { diff --git a/tests/test-qsort.cpp b/tests/test-qsort.cpp index 0df7addf..7eef83ec 100644 --- a/tests/test-qsort.cpp +++ b/tests/test-qsort.cpp @@ -19,7 +19,8 @@ class simdsort : public ::testing::Test { "max_at_the_end", "random_5d", "rand_max", - "rand_with_nan"}; + "rand_with_nan", + "rand_with_max_and_nan"}; } std::vector arrtype; std::vector arrsize = std::vector(1024); @@ -30,7 +31,7 @@ TYPED_TEST_SUITE_P(simdsort); TYPED_TEST_P(simdsort, test_qsort_ascending) { for (auto type : this->arrtype) { - bool hasnan = (type == "rand_with_nan") ? true : false; + bool hasnan = is_nan_test(type); for (auto size : this->arrsize) { std::vector basearr = get_array(type, size); @@ -52,7 +53,7 @@ TYPED_TEST_P(simdsort, test_qsort_ascending) TYPED_TEST_P(simdsort, test_qsort_descending) { for (auto type : this->arrtype) { - bool hasnan = (type == "rand_with_nan") ? true : false; + bool hasnan = is_nan_test(type); for (auto size : this->arrsize) { std::vector basearr = get_array(type, size); @@ -74,7 +75,7 @@ TYPED_TEST_P(simdsort, test_qsort_descending) TYPED_TEST_P(simdsort, test_argsort_ascending) { for (auto type : this->arrtype) { - bool hasnan = (type == "rand_with_nan") ? true : false; + bool hasnan = is_nan_test(type); for (auto size : this->arrsize) { std::vector arr = get_array(type, size); std::vector sortedarr = arr; @@ -92,7 +93,7 @@ TYPED_TEST_P(simdsort, test_argsort_ascending) TYPED_TEST_P(simdsort, test_argsort_descending) { for (auto type : this->arrtype) { - bool hasnan = (type == "rand_with_nan") ? true : false; + bool hasnan = is_nan_test(type); for (auto size : this->arrsize) { std::vector arr = get_array(type, size); std::vector sortedarr = arr; @@ -111,7 +112,7 @@ TYPED_TEST_P(simdsort, test_argsort_descending) TYPED_TEST_P(simdsort, test_qselect_ascending) { for (auto type : this->arrtype) { - bool hasnan = (type == "rand_with_nan") ? true : false; + bool hasnan = is_nan_test(type); for (auto size : this->arrsize) { size_t k = rand() % size; std::vector basearr = get_array(type, size); @@ -135,7 +136,7 @@ TYPED_TEST_P(simdsort, test_qselect_ascending) TYPED_TEST_P(simdsort, test_qselect_descending) { for (auto type : this->arrtype) { - bool hasnan = (type == "rand_with_nan") ? true : false; + bool hasnan = is_nan_test(type); for (auto size : this->arrsize) { size_t k = rand() % size; std::vector basearr = get_array(type, size); @@ -159,7 +160,7 @@ TYPED_TEST_P(simdsort, test_qselect_descending) TYPED_TEST_P(simdsort, test_argselect) { for (auto type : this->arrtype) { - bool hasnan = (type == "rand_with_nan") ? true : false; + bool hasnan = is_nan_test(type); for (auto size : this->arrsize) { size_t k = rand() % size; std::vector arr = get_array(type, size); @@ -179,7 +180,7 @@ TYPED_TEST_P(simdsort, test_argselect) TYPED_TEST_P(simdsort, test_partial_qsort_ascending) { for (auto type : this->arrtype) { - bool hasnan = (type == "rand_with_nan") ? true : false; + bool hasnan = is_nan_test(type); for (auto size : this->arrsize) { size_t k = rand() % size; std::vector basearr = get_array(type, size); @@ -202,7 +203,7 @@ TYPED_TEST_P(simdsort, test_partial_qsort_ascending) TYPED_TEST_P(simdsort, test_partial_qsort_descending) { for (auto type : this->arrtype) { - bool hasnan = (type == "rand_with_nan") ? true : false; + bool hasnan = is_nan_test(type); for (auto size : this->arrsize) { // k should be at least 1 size_t k = std::max((size_t)1, rand() % size); diff --git a/utils/rand_array.h b/utils/rand_array.h index dccbacdf..dcb0d018 100644 --- a/utils/rand_array.h +++ b/utils/rand_array.h @@ -140,6 +140,26 @@ static std::vector get_array(std::string arrtype, if (rand() & 0x1) { arr[ii] = val; } } } + else if (arrtype == "rand_with_max_and_nan") { + arr = get_uniform_rand_array(arrsize, max, min); + T max_val; + T nan_val; + if constexpr (xss::fp::is_floating_point_v) { + max_val = xss::fp::infinity(); + nan_val = xss::fp::quiet_NaN(); + } + else { + max_val = std::numeric_limits::max(); + nan_val = std::numeric_limits::max(); + } + for (size_t ii = 0; ii < arrsize; ++ii) { + int res = rand() % 4; + if (res == 2) { arr[ii] = max_val; } + else if (res == 3) { + arr[ii] = nan_val; + } + } + } else { std::cout << "Warning: unrecognized array type " << arrtype << std::endl; From b383a5b77ff509bbcdfb5187fa3d7ae157fc3911 Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Tue, 14 Jan 2025 15:51:26 -0800 Subject: [PATCH 2/3] Make NAN moving logic for kvsort use simd when possible --- src/xss-common-keyvaluesort.hpp | 37 ++++++++++++++++++++++++++++++--- 1 file changed, 34 insertions(+), 3 deletions(-) diff --git a/src/xss-common-keyvaluesort.hpp b/src/xss-common-keyvaluesort.hpp index 9e870c87..02e1985b 100644 --- a/src/xss-common-keyvaluesort.hpp +++ b/src/xss-common-keyvaluesort.hpp @@ -20,14 +20,42 @@ * Sort all the NAN's to end of the array and return the index of the last elem * in the array which is not a nan */ -template +template X86_SIMD_SORT_INLINE arrsize_t move_nans_to_end_of_array(T1 *keys, T2 *vals, arrsize_t size) { + using reg_t = typename vtype::reg_t; + arrsize_t jj = size - 1; arrsize_t ii = 0; arrsize_t count = 0; + + while (ii + vtype::numlanes < jj) { + reg_t in = vtype::loadu(keys + ii); + auto nanmask = vtype::convert_mask_to_int( + vtype::template fpclass<0x01 | 0x80>(in)); + + // Check if there are any nans in this vector, and process them if so + if (nanmask != 0x00) { + for (size_t offset = 0; offset < vtype::numlanes; offset++) { + if (is_a_nan(keys[ii])) { + std::swap(keys[ii], keys[jj]); + std::swap(vals[ii], vals[jj]); + jj -= 1; + count++; + } + else { + ii += 1; + } + } + } + else { + ii += vtype::numlanes; + } + } + + // Handle the remainders once we have less than 1 vector worth while (ii < jj) { if (is_a_nan(keys[ii])) { std::swap(keys[ii], keys[jj]); @@ -39,6 +67,7 @@ X86_SIMD_SORT_INLINE arrsize_t move_nans_to_end_of_array(T1 *keys, ii += 1; } } + /* Haven't checked for nan when ii == jj */ if (is_a_nan(keys[ii])) { count++; } return size - count - 1; @@ -570,7 +599,8 @@ X86_SIMD_SORT_INLINE void xss_qsort_kv( if constexpr (xss::fp::is_floating_point_v) { if (UNLIKELY(hasnan)) { index_last_elem - = move_nans_to_end_of_array(keys, indexes, arrsize); + = move_nans_to_end_of_array>( + keys, indexes, arrsize); } } else { @@ -660,7 +690,8 @@ X86_SIMD_SORT_INLINE void xss_select_kv(T1 *keys, if constexpr (xss::fp::is_floating_point_v) { if (UNLIKELY(hasnan)) { index_last_elem - = move_nans_to_end_of_array(keys, indexes, arrsize); + = move_nans_to_end_of_array>( + keys, indexes, arrsize); } } From 89de98db4a0c417d27426060c5a6d1e06fb13e2d Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Thu, 16 Jan 2025 08:48:58 -0800 Subject: [PATCH 3/3] Exit early if no work would be done to avoid crashes --- src/xss-common-keyvaluesort.hpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/xss-common-keyvaluesort.hpp b/src/xss-common-keyvaluesort.hpp index 02e1985b..43f64eca 100644 --- a/src/xss-common-keyvaluesort.hpp +++ b/src/xss-common-keyvaluesort.hpp @@ -586,6 +586,9 @@ X86_SIMD_SORT_INLINE void xss_qsort_kv( half_vector, full_vector>::type; + // Exit early if no work would be done + if (arrsize <= 1) return; + #ifdef XSS_TEST_KEYVALUE_BASE_CASE int maxiters = -1; bool minarrsize = true; @@ -675,6 +678,9 @@ X86_SIMD_SORT_INLINE void xss_select_kv(T1 *keys, half_vector, full_vector>::type; + // Exit early if no work would be done + if (arrsize <= 1) return; + #ifdef XSS_TEST_KEYVALUE_BASE_CASE int maxiters = -1; bool minarrsize = true;