Skip to content

Commit a876f96

Browse files
author
Raghuveer Devulapalli
committed
Enable use of zmm registers for (32bit,32bit) key-value sort
1 parent 9b978ec commit a876f96

File tree

3 files changed

+246
-13
lines changed

3 files changed

+246
-13
lines changed

src/avx512-32bit-qsort.hpp

Lines changed: 117 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@ template <>
3232
struct zmm_vector<int32_t> {
3333
using type_t = int32_t;
3434
using reg_t = __m512i;
35+
using regi_t = __m512i;
3536
using halfreg_t = __m256i;
3637
using opmask_t = __mmask16;
3738
static const uint8_t numlanes = 16;
@@ -65,6 +66,10 @@ struct zmm_vector<int32_t> {
6566
{
6667
return _mm512_cmp_epi32_mask(x, y, _MM_CMPINT_NLT);
6768
}
69+
// Lane-wise equality test on two registers of 32-bit signed integers.
// Returns the mask produced by _mm512_cmpeq_epi32_mask(x, y): one bit per
// lane, set where the corresponding lanes of x and y are equal.
static opmask_t eq(reg_t x, reg_t y)
70+
{
71+
return _mm512_cmpeq_epi32_mask(x, y);
72+
}
6873
static opmask_t get_partial_loadmask(uint64_t num_to_read)
6974
{
7075
return ((0x1ull << num_to_read) - 0x1ull);
@@ -123,6 +128,40 @@ struct zmm_vector<int32_t> {
123128
{
124129
return _mm512_set1_epi32(v);
125130
}
131+
// Builds an integer register (regi_t) from sixteen 32-bit values by
// forwarding them, in the same order, to _mm512_set_epi32.
// Intel's intrinsic reference lists the arguments of _mm512_set_epi32 from
// the highest lane down to the lowest, so v1 lands in lane 15 and v16 in
// lane 0.
static regi_t seti(int v1,
132+
int v2,
133+
int v3,
134+
int v4,
135+
int v5,
136+
int v6,
137+
int v7,
138+
int v8,
139+
int v9,
140+
int v10,
141+
int v11,
142+
int v12,
143+
int v13,
144+
int v14,
145+
int v15,
146+
int v16)
147+
{
148+
return _mm512_set_epi32(v1,
149+
v2,
150+
v3,
151+
v4,
152+
v5,
153+
v6,
154+
v7,
155+
v8,
156+
v9,
157+
v10,
158+
v11,
159+
v12,
160+
v13,
161+
v14,
162+
v15,
163+
v16);
164+
}
126165
template <uint8_t mask>
127166
static reg_t shuffle(reg_t zmm)
128167
{
@@ -171,6 +210,7 @@ template <>
171210
struct zmm_vector<uint32_t> {
172211
using type_t = uint32_t;
173212
using reg_t = __m512i;
213+
using regi_t = __m512i;
174214
using halfreg_t = __m256i;
175215
using opmask_t = __mmask16;
176216
static const uint8_t numlanes = 16;
@@ -214,6 +254,10 @@ struct zmm_vector<uint32_t> {
214254
{
215255
return _mm512_cmp_epu32_mask(x, y, _MM_CMPINT_NLT);
216256
}
257+
// Lane-wise equality test on two registers of 32-bit unsigned integers.
// Returns the mask produced by _mm512_cmpeq_epu32_mask(x, y): one bit per
// lane, set where the corresponding lanes of x and y are equal.
static opmask_t eq(reg_t x, reg_t y)
258+
{
259+
return _mm512_cmpeq_epu32_mask(x, y);
260+
}
217261
static opmask_t get_partial_loadmask(uint64_t num_to_read)
218262
{
219263
return ((0x1ull << num_to_read) - 0x1ull);
@@ -262,6 +306,40 @@ struct zmm_vector<uint32_t> {
262306
{
263307
return _mm512_set1_epi32(v);
264308
}
309+
// Builds an integer register (regi_t) from sixteen 32-bit values by
// forwarding them, in the same order, to _mm512_set_epi32.
// Intel's intrinsic reference lists the arguments of _mm512_set_epi32 from
// the highest lane down to the lowest, so v1 lands in lane 15 and v16 in
// lane 0.
static regi_t seti(int v1,
310+
int v2,
311+
int v3,
312+
int v4,
313+
int v5,
314+
int v6,
315+
int v7,
316+
int v8,
317+
int v9,
318+
int v10,
319+
int v11,
320+
int v12,
321+
int v13,
322+
int v14,
323+
int v15,
324+
int v16)
325+
{
326+
return _mm512_set_epi32(v1,
327+
v2,
328+
v3,
329+
v4,
330+
v5,
331+
v6,
332+
v7,
333+
v8,
334+
v9,
335+
v10,
336+
v11,
337+
v12,
338+
v13,
339+
v14,
340+
v15,
341+
v16);
342+
}
265343
template <uint8_t mask>
266344
static reg_t shuffle(reg_t zmm)
267345
{
@@ -310,6 +388,7 @@ template <>
310388
struct zmm_vector<float> {
311389
using type_t = float;
312390
using reg_t = __m512;
391+
using regi_t = __m512i;
313392
using halfreg_t = __m256;
314393
using opmask_t = __mmask16;
315394
static const uint8_t numlanes = 16;
@@ -343,6 +422,10 @@ struct zmm_vector<float> {
343422
{
344423
return _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ);
345424
}
425+
// Lane-wise equality test on two registers of single-precision floats.
// Returns the mask produced by _mm512_cmpeq_ps_mask(x, y), one bit per lane.
// NOTE(review): Intel documents this intrinsic as an ordered, non-signaling
// compare, so a lane holding NaN should never report equal - confirm that
// callers expect this.
static opmask_t eq(reg_t x, reg_t y)
426+
{
427+
return _mm512_cmpeq_ps_mask(x, y);
428+
}
346429
static opmask_t get_partial_loadmask(uint64_t num_to_read)
347430
{
348431
return ((0x1ull << num_to_read) - 0x1ull);
@@ -415,6 +498,40 @@ struct zmm_vector<float> {
415498
{
416499
return _mm512_set1_ps(v);
417500
}
501+
// Builds an integer register (regi_t, i.e. __m512i rather than this
// struct's float reg_t) from sixteen 32-bit values by forwarding them, in
// the same order, to _mm512_set_epi32.
// Intel's intrinsic reference lists the arguments of _mm512_set_epi32 from
// the highest lane down to the lowest, so v1 lands in lane 15 and v16 in
// lane 0.
static regi_t seti(int v1,
502+
int v2,
503+
int v3,
504+
int v4,
505+
int v5,
506+
int v6,
507+
int v7,
508+
int v8,
509+
int v9,
510+
int v10,
511+
int v11,
512+
int v12,
513+
int v13,
514+
int v14,
515+
int v15,
516+
int v16)
517+
{
518+
return _mm512_set_epi32(v1,
519+
v2,
520+
v3,
521+
v4,
522+
v5,
523+
v6,
524+
v7,
525+
v8,
526+
v9,
527+
v10,
528+
v11,
529+
v12,
530+
v13,
531+
v14,
532+
v15,
533+
v16);
534+
}
418535
template <uint8_t mask>
419536
static reg_t shuffle(reg_t zmm)
420537
{

src/avx512-64bit-keyvaluesort.hpp

Lines changed: 11 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -267,12 +267,17 @@ template <typename T1, typename T2>
267267
X86_SIMD_SORT_INLINE void
268268
avx512_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan = false)
269269
{
270-
using keytype = typename std::conditional<sizeof(T1) == sizeof(int32_t),
271-
ymm_vector<T1>,
272-
zmm_vector<T1>>::type;
273-
using valtype = typename std::conditional<sizeof(T2) == sizeof(int32_t),
274-
ymm_vector<T2>,
275-
zmm_vector<T2>>::type;
270+
using keytype =
271+
typename std::conditional<sizeof(T1) != sizeof(T2)
272+
&& sizeof(T1) == sizeof(int32_t),
273+
ymm_vector<T1>,
274+
zmm_vector<T1>>::type;
275+
using valtype =
276+
typename std::conditional<sizeof(T1) != sizeof(T2)
277+
&& sizeof(T2) == sizeof(int32_t),
278+
ymm_vector<T2>,
279+
zmm_vector<T2>>::type;
280+
276281
if (arrsize > 1) {
277282
if constexpr (std::is_floating_point_v<T1>) {
278283
arrsize_t nan_count = 0;

src/xss-network-keyvaluesort.hpp

Lines changed: 118 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,112 @@ template <typename vtype1,
4545
typename vtype2,
4646
typename reg_t = typename vtype1::reg_t,
4747
typename index_type = typename vtype2::reg_t>
48-
X86_SIMD_SORT_INLINE reg_t sort_zmm_64bit(reg_t key_zmm, index_type &index_zmm)
48+
// Runs a fixed sequence of ten cmp_merge stages over one key register and its
// companion index register. Each stage pairs both registers with a permuted
// copy of themselves (the same permutation is applied to keys and indexes)
// and hands cmp_merge a 16-bit lane mask.
// NOTE(review): cmp_merge is defined outside this view. Judging by the
// function name this is a 16-lane sorting network that leaves key_zmm sorted
// and moves each index lane along with its key - confirm against cmp_merge.
// Returns the final key register. index_zmm is taken by reference and is
// never assigned in this body, so any update to it has to happen inside
// cmp_merge.
X86_SIMD_SORT_INLINE reg_t sort_reg_16lanes(reg_t key_zmm,
49+
index_type &index_zmm)
50+
{
51+
// Mask 0xAAAA has the bit of every odd-numbered lane set.
key_zmm = cmp_merge<vtype1, vtype2>(
52+
key_zmm,
53+
vtype1::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(key_zmm),
54+
index_zmm,
55+
vtype2::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(index_zmm),
56+
0xAAAA);
57+
// Mask 0xCCCC has the bits of lanes 2 and 3 of every group of four set.
key_zmm = cmp_merge<vtype1, vtype2>(
58+
key_zmm,
59+
vtype1::template shuffle<SHUFFLE_MASK(0, 1, 2, 3)>(key_zmm),
60+
index_zmm,
61+
vtype2::template shuffle<SHUFFLE_MASK(0, 1, 2, 3)>(index_zmm),
62+
0xCCCC);
63+
key_zmm = cmp_merge<vtype1, vtype2>(
64+
key_zmm,
65+
vtype1::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(key_zmm),
66+
index_zmm,
67+
vtype2::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(index_zmm),
68+
0xAAAA);
69+
// The permutation indexes come from NETWORK_32BIT_3, which is defined
// elsewhere (presumably a list of sixteen ints matching seti's arity).
// Mask 0xF0F0 has the bits of the upper four lanes of every eight set.
key_zmm = cmp_merge<vtype1, vtype2>(
70+
key_zmm,
71+
vtype1::permutexvar(vtype1::seti(NETWORK_32BIT_3), key_zmm),
72+
index_zmm,
73+
vtype2::permutexvar(vtype2::seti(NETWORK_32BIT_3), index_zmm),
74+
0xF0F0);
75+
key_zmm = cmp_merge<vtype1, vtype2>(
76+
key_zmm,
77+
vtype1::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(key_zmm),
78+
index_zmm,
79+
vtype2::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(index_zmm),
80+
0xCCCC);
81+
key_zmm = cmp_merge<vtype1, vtype2>(
82+
key_zmm,
83+
vtype1::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(key_zmm),
84+
index_zmm,
85+
vtype2::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(index_zmm),
86+
0xAAAA);
87+
// Mask 0xFF00 has the bits of the upper eight lanes set.
key_zmm = cmp_merge<vtype1, vtype2>(
88+
key_zmm,
89+
vtype1::permutexvar(vtype1::seti(NETWORK_32BIT_5), key_zmm),
90+
index_zmm,
91+
vtype2::permutexvar(vtype2::seti(NETWORK_32BIT_5), index_zmm),
92+
0xFF00);
93+
// The last three stages use the same permutations and masks as the last
// three stages of bitonic_merge_reg_16lanes.
key_zmm = cmp_merge<vtype1, vtype2>(
94+
key_zmm,
95+
vtype1::permutexvar(vtype1::seti(NETWORK_32BIT_6), key_zmm),
96+
index_zmm,
97+
vtype2::permutexvar(vtype2::seti(NETWORK_32BIT_6), index_zmm),
98+
0xF0F0);
99+
key_zmm = cmp_merge<vtype1, vtype2>(
100+
key_zmm,
101+
vtype1::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(key_zmm),
102+
index_zmm,
103+
vtype2::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(index_zmm),
104+
0xCCCC);
105+
key_zmm = cmp_merge<vtype1, vtype2>(
106+
key_zmm,
107+
vtype1::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(key_zmm),
108+
index_zmm,
109+
vtype2::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(index_zmm),
110+
0xAAAA);
111+
return key_zmm;
112+
}
113+
114+
// Assumes key_zmm is bitonic and performs a recursive half cleaner on it,
// applying the same permutations to index_zmm.
// The cleaner is four cmp_merge stages whose lane masks are 0xFF00 (upper
// eight lanes), 0xF0F0 (upper four of every eight), 0xCCCC (upper two of
// every four) and 0xAAAA (odd lanes).
// The permutation index lists NETWORK_32BIT_7 and NETWORK_32BIT_6 and the
// cmp_merge helper are defined outside this view.
// Returns the final key register. index_zmm is taken by reference and is
// never assigned in this body, so any update to it has to happen inside
// cmp_merge.
115+
template <typename vtype1,
116+
typename vtype2,
117+
typename reg_t = typename vtype1::reg_t,
118+
typename index_type = typename vtype2::reg_t>
119+
X86_SIMD_SORT_INLINE reg_t bitonic_merge_reg_16lanes(reg_t key_zmm,
120+
index_type &index_zmm)
121+
{
122+
key_zmm = cmp_merge<vtype1, vtype2>(
123+
key_zmm,
124+
vtype1::permutexvar(vtype1::seti(NETWORK_32BIT_7), key_zmm),
125+
index_zmm,
126+
vtype2::permutexvar(vtype2::seti(NETWORK_32BIT_7), index_zmm),
127+
0xFF00);
128+
key_zmm = cmp_merge<vtype1, vtype2>(
129+
key_zmm,
130+
vtype1::permutexvar(vtype1::seti(NETWORK_32BIT_6), key_zmm),
131+
index_zmm,
132+
vtype2::permutexvar(vtype2::seti(NETWORK_32BIT_6), index_zmm),
133+
0xF0F0);
134+
key_zmm = cmp_merge<vtype1, vtype2>(
135+
key_zmm,
136+
vtype1::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(key_zmm),
137+
index_zmm,
138+
vtype2::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(index_zmm),
139+
0xCCCC);
140+
key_zmm = cmp_merge<vtype1, vtype2>(
141+
key_zmm,
142+
vtype1::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(key_zmm),
143+
index_zmm,
144+
vtype2::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(index_zmm),
145+
0xAAAA);
146+
return key_zmm;
147+
}
148+
149+
template <typename vtype1,
150+
typename vtype2,
151+
typename reg_t = typename vtype1::reg_t,
152+
typename index_type = typename vtype2::reg_t>
153+
X86_SIMD_SORT_INLINE reg_t sort_reg_8lanes(reg_t key_zmm, index_type &index_zmm)
49154
{
50155
const typename vtype1::regi_t rev_index1 = vtype1::seti(NETWORK_64BIT_2);
51156
const typename vtype2::regi_t rev_index2 = vtype2::seti(NETWORK_64BIT_2);
@@ -93,8 +198,8 @@ template <typename vtype1,
93198
typename vtype2,
94199
typename reg_t = typename vtype1::reg_t,
95200
typename index_type = typename vtype2::reg_t>
96-
X86_SIMD_SORT_INLINE reg_t bitonic_merge_zmm_64bit(reg_t key_zmm,
97-
index_type &index_zmm)
201+
X86_SIMD_SORT_INLINE reg_t bitonic_merge_reg_8lanes(reg_t key_zmm,
202+
index_type &index_zmm)
98203
{
99204

100205
// 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7
@@ -128,10 +233,13 @@ bitonic_merge_dispatch(typename keyType::reg_t &key,
128233
{
129234
constexpr int numlanes = keyType::numlanes;
130235
if constexpr (numlanes == 8) {
131-
key = bitonic_merge_zmm_64bit<keyType, valueType>(key, value);
236+
key = bitonic_merge_reg_8lanes<keyType, valueType>(key, value);
237+
}
238+
else if constexpr (numlanes == 16) {
239+
key = bitonic_merge_reg_16lanes<keyType, valueType>(key, value);
132240
}
133241
else {
134-
static_assert(numlanes == -1, "should not reach here");
242+
static_assert(numlanes == -1, "No implementation");
135243
UNUSED(key);
136244
UNUSED(value);
137245
}
@@ -143,10 +251,13 @@ X86_SIMD_SORT_INLINE void sort_vec_dispatch(typename keyType::reg_t &key,
143251
{
144252
constexpr int numlanes = keyType::numlanes;
145253
if constexpr (numlanes == 8) {
146-
key = sort_zmm_64bit<keyType, valueType>(key, value);
254+
key = sort_reg_8lanes<keyType, valueType>(key, value);
255+
}
256+
else if constexpr (numlanes == 16) {
257+
key = sort_reg_16lanes<keyType, valueType>(key, value);
147258
}
148259
else {
149-
static_assert(numlanes == -1, "should not reach here");
260+
static_assert(numlanes == -1, "No implementation");
150261
UNUSED(key);
151262
UNUSED(value);
152263
}

0 commit comments

Comments
 (0)