@@ -24,6 +24,7 @@ template <typename vtype, typename reg_t>
2424X86_SIMD_SORT_INLINE reg_t sort_zmm_64bit (reg_t zmm);
2525
// Forward declarations of the swizzle-op structs. Both are defined near the
// end of this file; avx512_ymm_64bit_swizzle_ops is already named earlier by
// the swizzle_ops aliases inside the ymm_vector specialisations, so it has
// to be declared before them.
2626struct avx512_64bit_swizzle_ops ;
27+ struct avx512_ymm_64bit_swizzle_ops ;
2728
2829template <>
2930struct ymm_vector <float > {
@@ -34,6 +35,8 @@ struct ymm_vector<float> {
3435 static const uint8_t numlanes = 8 ;
3536 static constexpr simd_type vec_type = simd_type::AVX512;
3637
// Associates this 8-lane vector type with the YMM swizzle helpers
// (swap_n / reverse_n / merge_n) defined in avx512_ymm_64bit_swizzle_ops.
38+ using swizzle_ops = avx512_ymm_64bit_swizzle_ops;
39+
3740 static type_t type_max ()
3841 {
3942 return X86_SIMD_SORT_INFINITYF;
@@ -212,6 +215,14 @@ struct ymm_vector<float> {
212215 const __m256i rev_index = _mm256_set_epi32 (NETWORK_32BIT_AVX2_2);
213216 return permutexvar (rev_index, ymm);
214217 }
// Thin forwarding wrapper: hands both destination pointers, the mask k and
// the register reg unchanged to the generic avx512_double_compressstore
// helper, instantiated with ymm_vector<type_t>, and returns that helper's
// int result as-is.
// NOTE(review): presumably the helper compress-stores the lanes selected by
// k to one address and the remaining lanes to the other, returning a lane
// count -- the helper is defined outside this view, confirm there.
218+ static int double_compressstore (type_t *left_addr,
219+ type_t *right_addr,
220+ opmask_t k,
221+ reg_t reg)
222+ {
223+ return avx512_double_compressstore<ymm_vector<type_t >>(
224+ left_addr, right_addr, k, reg);
225+ }
215226};
216227template <>
217228struct ymm_vector <uint32_t > {
@@ -222,6 +233,8 @@ struct ymm_vector<uint32_t> {
222233 static const uint8_t numlanes = 8 ;
223234 static constexpr simd_type vec_type = simd_type::AVX512;
224235
// Associates this 8-lane vector type with the YMM swizzle helpers
// (swap_n / reverse_n / merge_n) defined in avx512_ymm_64bit_swizzle_ops.
236+ using swizzle_ops = avx512_ymm_64bit_swizzle_ops;
237+
225238 static type_t type_max ()
226239 {
227240 return X86_SIMD_SORT_MAX_UINT32;
@@ -386,6 +399,14 @@ struct ymm_vector<uint32_t> {
386399 const __m256i rev_index = _mm256_set_epi32 (NETWORK_32BIT_AVX2_2);
387400 return permutexvar (rev_index, ymm);
388401 }
// Thin forwarding wrapper: hands both destination pointers, the mask k and
// the register reg unchanged to the generic avx512_double_compressstore
// helper, instantiated with ymm_vector<type_t>, and returns that helper's
// int result as-is.
// NOTE(review): presumably the helper compress-stores the lanes selected by
// k to one address and the remaining lanes to the other, returning a lane
// count -- the helper is defined outside this view, confirm there.
402+ static int double_compressstore (type_t *left_addr,
403+ type_t *right_addr,
404+ opmask_t k,
405+ reg_t reg)
406+ {
407+ return avx512_double_compressstore<ymm_vector<type_t >>(
408+ left_addr, right_addr, k, reg);
409+ }
389410};
390411template <>
391412struct ymm_vector <int32_t > {
@@ -396,6 +417,8 @@ struct ymm_vector<int32_t> {
396417 static const uint8_t numlanes = 8 ;
397418 static constexpr simd_type vec_type = simd_type::AVX512;
398419
// Associates this 8-lane vector type with the YMM swizzle helpers
// (swap_n / reverse_n / merge_n) defined in avx512_ymm_64bit_swizzle_ops.
420+ using swizzle_ops = avx512_ymm_64bit_swizzle_ops;
421+
399422 static type_t type_max ()
400423 {
401424 return X86_SIMD_SORT_MAX_INT32;
@@ -560,6 +583,14 @@ struct ymm_vector<int32_t> {
560583 const __m256i rev_index = _mm256_set_epi32 (NETWORK_32BIT_AVX2_2);
561584 return permutexvar (rev_index, ymm);
562585 }
// Thin forwarding wrapper: hands both destination pointers, the mask k and
// the register reg unchanged to the generic avx512_double_compressstore
// helper, instantiated with ymm_vector<type_t>, and returns that helper's
// int result as-is.
// NOTE(review): presumably the helper compress-stores the lanes selected by
// k to one address and the remaining lanes to the other, returning a lane
// count -- the helper is defined outside this view, confirm there.
586+ static int double_compressstore (type_t *left_addr,
587+ type_t *right_addr,
588+ opmask_t k,
589+ reg_t reg)
590+ {
591+ return avx512_double_compressstore<ymm_vector<type_t >>(
592+ left_addr, right_addr, k, reg);
593+ }
563594};
564595template <>
565596struct zmm_vector <int64_t > {
@@ -1215,4 +1246,77 @@ struct avx512_64bit_swizzle_ops {
12151246 }
12161247};
12171248
// Element-shuffle helpers for 256-bit (YMM) registers; the ymm_vector
// specialisations in this file select them through their swizzle_ops alias.
//
// Each helper is parameterised on vtype (supplies reg_t, cast_to, cast_from
// and reverse) and on scale, the group size in elements: 2, 4 or 8. Any
// other scale lands in the final else branch, whose static_assert compares
// scale against -1; because the condition depends on the template argument
// it is only evaluated when that branch is actually instantiated.
//
// The intrinsics used below (_mm256_permute_ps, _mm256_blend_epi32) operate
// on 32-bit elements, so the register is handled as eight 32-bit elements
// split into two 128-bit lanes.
1249+ struct avx512_ymm_64bit_swizzle_ops {
// swap_n: within every group of scale elements, exchanges the lower half of
// the group with the upper half and returns the result as reg_t.
1250+ template <typename vtype, int scale>
1251+ X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n (typename vtype::reg_t reg)
1252+ {
// vtype::cast_to yields the __m256i the intrinsics operate on; the result
// is converted back with vtype::cast_from on return.
1253+ __m256i v = vtype::cast_to (reg);
1254+
// scale 2: immediate 0b10110001 selects elements [1,0,3,2] in each 128-bit
// lane, i.e. adjacent 32-bit elements trade places. The round trip through
// __m256 exists only because _mm256_permute_ps takes a float vector.
1255+ if constexpr (scale == 2 ) {
1256+ __m256 vf = _mm256_castsi256_ps (v);
1257+ vf = _mm256_permute_ps (vf, 0b10110001 );
1258+ v = _mm256_castps_si256 (vf);
1259+ }
// scale 4: immediate 0b01001110 selects elements [2,3,0,1] in each 128-bit
// lane, i.e. the two element pairs of the lane trade places.
1260+ else if constexpr (scale == 4 ) {
1261+ __m256 vf = _mm256_castsi256_ps (v);
1262+ vf = _mm256_permute_ps (vf, 0b01001110 );
1263+ v = _mm256_castps_si256 (vf);
1264+ }
// scale 8: with both sources equal to v, immediate 0x01 puts the high
// 128-bit half into the low half and the low half into the high half.
1265+ else if constexpr (scale == 8 ) {
1266+ v = _mm256_permute2x128_si256 (v, v, 0b00000001 );
1267+ }
// Unsupported scale: fail at compile time.
1268+ else {
1269+ static_assert (scale == -1 , " should not be reached" );
1270+ }
1271+
1272+ return vtype::cast_from (v);
1273+ }
1274+
// reverse_n: reverses the element order inside every group of scale
// elements and returns the result as reg_t.
1275+ template <typename vtype, int scale>
1276+ X86_SIMD_SORT_INLINE typename vtype::reg_t
1277+ reverse_n (typename vtype::reg_t reg)
1278+ {
// Only the scale 4 branch consumes v; the scale 2 and scale 8 branches
// return early from reg directly.
1279+ __m256i v = vtype::cast_to (reg);
1280+
// scale 2: reversing a pair is the same as swapping it, so reuse swap_n.
1281+ if constexpr (scale == 2 ) { return swap_n<vtype, 2 >(reg); }
// scale 4: immediate 0b00011011 selects elements [3,2,1,0] in each 128-bit
// lane, reversing the four 32-bit elements of the lane.
1282+ else if constexpr (scale == 4 ) {
1283+ constexpr uint64_t mask = 0b00011011 ;
1284+ __m256 vf = _mm256_castsi256_ps (v);
1285+ vf = _mm256_permute_ps (vf, mask);
1286+ v = _mm256_castps_si256 (vf);
1287+ }
// scale 8: delegated to vtype::reverse.
// NOTE(review): presumably a full eight-element reversal -- vtype::reverse
// is defined by the vector type, confirm there.
1288+ else if constexpr (scale == 8 ) {
1289+ return vtype::reverse (reg);
1290+ }
// Unsupported scale: fail at compile time.
1291+ else {
1292+ static_assert (scale == -1 , " should not be reached" );
1293+ }
1294+
1295+ return vtype::cast_from (v);
1296+ }
1297+
// merge_n: builds a register in which the lower half of every group of
// scale elements comes from other and the upper half comes from reg.
// _mm256_blend_epi32 takes element i from its second source (v2, i.e.
// other) when bit i of the immediate is set, otherwise from the first
// source (v1, i.e. reg).
1298+ template <typename vtype, int scale>
1299+ X86_SIMD_SORT_INLINE typename vtype::reg_t
1300+ merge_n (typename vtype::reg_t reg, typename vtype::reg_t other)
1301+ {
1302+ __m256i v1 = vtype::cast_to (reg);
1303+ __m256i v2 = vtype::cast_to (other);
1304+
// scale 2: bits 0,2,4,6 set -> even-indexed elements from other, odd-indexed
// elements from reg.
1305+ if constexpr (scale == 2 ) {
1306+ v1 = _mm256_blend_epi32 (v1, v2, 0b01010101 );
1307+ }
// scale 4: bits 0,1,4,5 set -> elements 0,1,4,5 from other, 2,3,6,7 from reg.
1308+ else if constexpr (scale == 4 ) {
1309+ v1 = _mm256_blend_epi32 (v1, v2, 0b00110011 );
1310+ }
// scale 8: bits 0-3 set -> elements 0-3 from other, elements 4-7 from reg.
1311+ else if constexpr (scale == 8 ) {
1312+ v1 = _mm256_blend_epi32 (v1, v2, 0b00001111 );
1313+ }
// Unsupported scale: fail at compile time.
1314+ else {
1315+ static_assert (scale == -1 , " should not be reached" );
1316+ }
1317+
1318+ return vtype::cast_from (v1);
1319+ }
1320+ };
1321+
12181322#endif
0 commit comments