// Patch summary (commentary placed ahead of the first `diff --git` header).
//
// What the hunks below change:
//  * src/avx512-16bit-qsort.hpp (three `reverse` functions) and
//    src/avx512fp16-16bit-qsort.hpp (one `reverse`, in `zmm_vector<_Float16>`):
//    `rev_index` is no longer built with
//    `_mm512_set_epi16(NETWORK_REVERSE_32LANES)`. It is now loaded with
//    `_mm512_loadu_si512` from a function-local
//    `constexpr static uint16_t arr[]` initialised from the same macro.
//    The following `permutexvar(rev_index, zmm)` call is unchanged context.
//  * src/xss-common-includes.h: `NETWORK_REVERSE_32LANES` changes from the
//    ascending list 0..31 to the descending list 31..0.
//    `NETWORK_REVERSE_16LANES` appears only as context and stays 0..15.
//
// NOTE(review): the order flip of the macro presumably compensates for
//   `_mm512_set_epi16` taking its arguments highest-lane-first while an array
//   load fills lanes lowest-first -- confirm against the Intel Intrinsics Guide.
// NOTE(review): any use of NETWORK_REVERSE_32LANES that is not shown in this
//   patch would see the new 31..0 order -- verify there are none left that
//   still pass it to a `set`-style intrinsic.
// NOTE(review): this copy of the patch has lost its original line breaks and
//   indentation, and three hunk headers read `struct zmm_vector {` with no
//   template argument shown, so it is unlikely to apply as-is; take the
//   patch from the upstream commit before using it.
diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index 1ed829b..e05027d 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -175,7 +175,8 @@ struct zmm_vector { } static reg_t reverse(reg_t zmm) { - const auto rev_index = _mm512_set_epi16(NETWORK_REVERSE_32LANES); + constexpr static uint16_t arr[] = {NETWORK_REVERSE_32LANES}; + const auto rev_index = _mm512_loadu_si512(arr); return permutexvar(rev_index, zmm); } static reg_t sort_vec(reg_t x) @@ -320,7 +321,8 @@ struct zmm_vector { } static reg_t reverse(reg_t zmm) { - const auto rev_index = _mm512_set_epi16(NETWORK_REVERSE_32LANES); + constexpr static uint16_t arr[] = {NETWORK_REVERSE_32LANES}; + const auto rev_index = _mm512_loadu_si512(arr); return permutexvar(rev_index, zmm); } static reg_t sort_vec(reg_t x) @@ -462,7 +464,8 @@ struct zmm_vector { } static reg_t reverse(reg_t zmm) { - const auto rev_index = _mm512_set_epi16(NETWORK_REVERSE_32LANES); + constexpr static uint16_t arr[] = {NETWORK_REVERSE_32LANES}; + const auto rev_index = _mm512_loadu_si512(arr); return permutexvar(rev_index, zmm); } static reg_t sort_vec(reg_t x) diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index f93cf68..8f85e59 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ b/src/avx512fp16-16bit-qsort.hpp @@ -139,7 +139,8 @@ struct zmm_vector<_Float16> { } static reg_t reverse(reg_t zmm) { - const auto rev_index = _mm512_set_epi16(NETWORK_REVERSE_32LANES); + constexpr static uint16_t arr[] = {NETWORK_REVERSE_32LANES}; + const auto rev_index = _mm512_loadu_si512(arr); return permutexvar(rev_index, zmm); } static reg_t sort_vec(reg_t x) diff --git a/src/xss-common-includes.h b/src/xss-common-includes.h index 27d6c36..7408571 100644 --- a/src/xss-common-includes.h +++ b/src/xss-common-includes.h @@ -79,8 +79,8 @@ #define NETWORK_REVERSE_16LANES \ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 #define NETWORK_REVERSE_32LANES \ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 
12, 13, 14, 15, 16, 17, 18, 19, 20, \ - 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31 + 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, \ + 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 #if defined(XSS_USE_OPENMP) && defined(_OPENMP) #define XSS_COMPILE_OPENMP