@@ -45,7 +45,112 @@ template <typename vtype1,
45
45
typename vtype2,
46
46
typename reg_t = typename vtype1::reg_t ,
47
47
typename index_type = typename vtype2::reg_t >
48
- X86_SIMD_SORT_INLINE reg_t sort_zmm_64bit (reg_t key_zmm, index_type &index_zmm)
48
+ X86_SIMD_SORT_INLINE reg_t sort_reg_16lanes (reg_t key_zmm,
49
+ index_type &index_zmm)
50
+ {
51
+ key_zmm = cmp_merge<vtype1, vtype2>(
52
+ key_zmm,
53
+ vtype1::template shuffle<SHUFFLE_MASK (2 , 3 , 0 , 1 )>(key_zmm),
54
+ index_zmm,
55
+ vtype2::template shuffle<SHUFFLE_MASK (2 , 3 , 0 , 1 )>(index_zmm),
56
+ 0xAAAA );
57
+ key_zmm = cmp_merge<vtype1, vtype2>(
58
+ key_zmm,
59
+ vtype1::template shuffle<SHUFFLE_MASK (0 , 1 , 2 , 3 )>(key_zmm),
60
+ index_zmm,
61
+ vtype2::template shuffle<SHUFFLE_MASK (0 , 1 , 2 , 3 )>(index_zmm),
62
+ 0xCCCC );
63
+ key_zmm = cmp_merge<vtype1, vtype2>(
64
+ key_zmm,
65
+ vtype1::template shuffle<SHUFFLE_MASK (2 , 3 , 0 , 1 )>(key_zmm),
66
+ index_zmm,
67
+ vtype2::template shuffle<SHUFFLE_MASK (2 , 3 , 0 , 1 )>(index_zmm),
68
+ 0xAAAA );
69
+ key_zmm = cmp_merge<vtype1, vtype2>(
70
+ key_zmm,
71
+ vtype1::permutexvar (vtype1::seti (NETWORK_32BIT_3), key_zmm),
72
+ index_zmm,
73
+ vtype2::permutexvar (vtype2::seti (NETWORK_32BIT_3), index_zmm),
74
+ 0xF0F0 );
75
+ key_zmm = cmp_merge<vtype1, vtype2>(
76
+ key_zmm,
77
+ vtype1::template shuffle<SHUFFLE_MASK (1 , 0 , 3 , 2 )>(key_zmm),
78
+ index_zmm,
79
+ vtype2::template shuffle<SHUFFLE_MASK (1 , 0 , 3 , 2 )>(index_zmm),
80
+ 0xCCCC );
81
+ key_zmm = cmp_merge<vtype1, vtype2>(
82
+ key_zmm,
83
+ vtype1::template shuffle<SHUFFLE_MASK (2 , 3 , 0 , 1 )>(key_zmm),
84
+ index_zmm,
85
+ vtype2::template shuffle<SHUFFLE_MASK (2 , 3 , 0 , 1 )>(index_zmm),
86
+ 0xAAAA );
87
+ key_zmm = cmp_merge<vtype1, vtype2>(
88
+ key_zmm,
89
+ vtype1::permutexvar (vtype1::seti (NETWORK_32BIT_5), key_zmm),
90
+ index_zmm,
91
+ vtype2::permutexvar (vtype2::seti (NETWORK_32BIT_5), index_zmm),
92
+ 0xFF00 );
93
+ key_zmm = cmp_merge<vtype1, vtype2>(
94
+ key_zmm,
95
+ vtype1::permutexvar (vtype1::seti (NETWORK_32BIT_6), key_zmm),
96
+ index_zmm,
97
+ vtype2::permutexvar (vtype2::seti (NETWORK_32BIT_6), index_zmm),
98
+ 0xF0F0 );
99
+ key_zmm = cmp_merge<vtype1, vtype2>(
100
+ key_zmm,
101
+ vtype1::template shuffle<SHUFFLE_MASK (1 , 0 , 3 , 2 )>(key_zmm),
102
+ index_zmm,
103
+ vtype2::template shuffle<SHUFFLE_MASK (1 , 0 , 3 , 2 )>(index_zmm),
104
+ 0xCCCC );
105
+ key_zmm = cmp_merge<vtype1, vtype2>(
106
+ key_zmm,
107
+ vtype1::template shuffle<SHUFFLE_MASK (2 , 3 , 0 , 1 )>(key_zmm),
108
+ index_zmm,
109
+ vtype2::template shuffle<SHUFFLE_MASK (2 , 3 , 0 , 1 )>(index_zmm),
110
+ 0xAAAA );
111
+ return key_zmm;
112
+ }
113
+
114
+ // Assumes zmm is bitonic and performs a recursive half cleaner
115
+ template <typename vtype1,
116
+ typename vtype2,
117
+ typename reg_t = typename vtype1::reg_t ,
118
+ typename index_type = typename vtype2::reg_t >
119
+ X86_SIMD_SORT_INLINE reg_t bitonic_merge_reg_16lanes (reg_t key_zmm,
120
+ index_type &index_zmm)
121
+ {
122
+ key_zmm = cmp_merge<vtype1, vtype2>(
123
+ key_zmm,
124
+ vtype1::permutexvar (vtype1::seti (NETWORK_32BIT_7), key_zmm),
125
+ index_zmm,
126
+ vtype2::permutexvar (vtype2::seti (NETWORK_32BIT_7), index_zmm),
127
+ 0xFF00 );
128
+ key_zmm = cmp_merge<vtype1, vtype2>(
129
+ key_zmm,
130
+ vtype1::permutexvar (vtype1::seti (NETWORK_32BIT_6), key_zmm),
131
+ index_zmm,
132
+ vtype2::permutexvar (vtype2::seti (NETWORK_32BIT_6), index_zmm),
133
+ 0xF0F0 );
134
+ key_zmm = cmp_merge<vtype1, vtype2>(
135
+ key_zmm,
136
+ vtype1::template shuffle<SHUFFLE_MASK (1 , 0 , 3 , 2 )>(key_zmm),
137
+ index_zmm,
138
+ vtype2::template shuffle<SHUFFLE_MASK (1 , 0 , 3 , 2 )>(index_zmm),
139
+ 0xCCCC );
140
+ key_zmm = cmp_merge<vtype1, vtype2>(
141
+ key_zmm,
142
+ vtype1::template shuffle<SHUFFLE_MASK (2 , 3 , 0 , 1 )>(key_zmm),
143
+ index_zmm,
144
+ vtype2::template shuffle<SHUFFLE_MASK (2 , 3 , 0 , 1 )>(index_zmm),
145
+ 0xAAAA );
146
+ return key_zmm;
147
+ }
148
+
149
+ template <typename vtype1,
150
+ typename vtype2,
151
+ typename reg_t = typename vtype1::reg_t ,
152
+ typename index_type = typename vtype2::reg_t >
153
+ X86_SIMD_SORT_INLINE reg_t sort_reg_8lanes (reg_t key_zmm, index_type &index_zmm)
49
154
{
50
155
const typename vtype1::regi_t rev_index1 = vtype1::seti (NETWORK_64BIT_2);
51
156
const typename vtype2::regi_t rev_index2 = vtype2::seti (NETWORK_64BIT_2);
@@ -93,8 +198,8 @@ template <typename vtype1,
93
198
typename vtype2,
94
199
typename reg_t = typename vtype1::reg_t ,
95
200
typename index_type = typename vtype2::reg_t >
96
- X86_SIMD_SORT_INLINE reg_t bitonic_merge_zmm_64bit (reg_t key_zmm,
97
- index_type &index_zmm)
201
+ X86_SIMD_SORT_INLINE reg_t bitonic_merge_reg_8lanes (reg_t key_zmm,
202
+ index_type &index_zmm)
98
203
{
99
204
100
205
// 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7
@@ -128,10 +233,13 @@ bitonic_merge_dispatch(typename keyType::reg_t &key,
128
233
{
129
234
constexpr int numlanes = keyType::numlanes;
130
235
if constexpr (numlanes == 8 ) {
131
- key = bitonic_merge_zmm_64bit<keyType, valueType>(key, value);
236
+ key = bitonic_merge_reg_8lanes<keyType, valueType>(key, value);
237
+ }
238
+ else if constexpr (numlanes == 16 ) {
239
+ key = bitonic_merge_reg_16lanes<keyType, valueType>(key, value);
132
240
}
133
241
else {
134
- static_assert (numlanes == -1 , " should not reach here " );
242
+ static_assert (numlanes == -1 , " No implementation " );
135
243
UNUSED (key);
136
244
UNUSED (value);
137
245
}
@@ -143,10 +251,13 @@ X86_SIMD_SORT_INLINE void sort_vec_dispatch(typename keyType::reg_t &key,
143
251
{
144
252
constexpr int numlanes = keyType::numlanes;
145
253
if constexpr (numlanes == 8 ) {
146
- key = sort_zmm_64bit<keyType, valueType>(key, value);
254
+ key = sort_reg_8lanes<keyType, valueType>(key, value);
255
+ }
256
+ else if constexpr (numlanes == 16 ) {
257
+ key = sort_reg_16lanes<keyType, valueType>(key, value);
147
258
}
148
259
else {
149
- static_assert (numlanes == -1 , " should not reach here " );
260
+ static_assert (numlanes == -1 , " No implementation " );
150
261
UNUSED (key);
151
262
UNUSED (value);
152
263
}
0 commit comments