@@ -162,73 +162,41 @@ X86_SIMD_SORT_INLINE reg_t cmp_merge(reg_t in1, reg_t in2, opmask_t mask)
162
162
reg_t max = vtype::max (in2, in1);
163
163
return vtype::mask_mov (min, mask, max); // 0 -> min, 1 -> max
164
164
}
165
- /*
166
- * Parition one ZMM register based on the pivot and returns the
167
- * number of elements that are greater than or equal to the pivot.
168
- */
165
+
169
166
template <typename vtype, typename type_t , typename reg_t >
170
- X86_SIMD_SORT_INLINE arrsize_t partition_vec_avx512 (type_t *l_store,
171
- type_t *r_store,
172
- const reg_t curr_vec,
173
- const reg_t pivot_vec,
174
- reg_t &smallest_vec,
175
- reg_t &biggest_vec)
167
+ int avx512_double_compressstore (type_t *left_addr,
168
+ type_t *right_addr,
169
+ typename vtype::opmask_t k,
170
+ reg_t reg)
176
171
{
177
- typename vtype::opmask_t ge_mask = vtype::ge (curr_vec, pivot_vec);
178
- int amount_ge_pivot = _mm_popcnt_u32 ((int )ge_mask);
172
+ int amount_ge_pivot = _mm_popcnt_u32 ((int )k);
179
173
180
- vtype::mask_compressstoreu (l_store , vtype::knot_opmask (ge_mask ), curr_vec );
174
+ vtype::mask_compressstoreu (left_addr , vtype::knot_opmask (k ), reg );
181
175
vtype::mask_compressstoreu (
182
- r_store + vtype::numlanes - amount_ge_pivot, ge_mask, curr_vec);
183
-
184
- smallest_vec = vtype::min (curr_vec, smallest_vec);
185
- biggest_vec = vtype::max (curr_vec, biggest_vec);
186
-
176
+ right_addr + vtype::numlanes - amount_ge_pivot, k, reg);
177
+
187
178
return amount_ge_pivot;
188
179
}
189
- /*
190
- * Parition one YMM register based on the pivot and returns the
191
- * number of elements that are greater than or equal to the pivot.
192
- */
180
+
181
+ // Generic function dispatches to AVX2 or AVX512 code
193
182
template <typename vtype, typename type_t , typename reg_t = typename vtype::reg_t >
194
- X86_SIMD_SORT_INLINE arrsize_t partition_vec_avx2 (type_t *l_store,
183
+ X86_SIMD_SORT_INLINE arrsize_t partition_vec (type_t *l_store,
195
184
type_t *r_store,
196
185
const reg_t curr_vec,
197
186
const reg_t pivot_vec,
198
187
reg_t &smallest_vec,
199
188
reg_t &biggest_vec)
200
189
{
201
- /* which elements are larger than or equal to the pivot */
202
190
typename vtype::opmask_t ge_mask = vtype::ge (curr_vec, pivot_vec);
203
-
204
- int32_t amount_ge_pivot = vtype::double_compressstore (
205
- l_store, r_store, ge_mask, curr_vec);
191
+
192
+ int amount_ge_pivot = vtype::double_compressstore (l_store, r_store, ge_mask, curr_vec);
206
193
207
194
smallest_vec = vtype::min (curr_vec, smallest_vec);
208
195
biggest_vec = vtype::max (curr_vec, biggest_vec);
209
196
210
197
return amount_ge_pivot;
211
198
}
212
199
213
- // Generic function dispatches to AVX2 or AVX512 code
214
- template <typename vtype, typename type_t , typename reg_t = typename vtype::reg_t >
215
- X86_SIMD_SORT_INLINE arrsize_t partition_vec (type_t *l_store,
216
- type_t *r_store,
217
- const reg_t curr_vec,
218
- const reg_t pivot_vec,
219
- reg_t &smallest_vec,
220
- reg_t &biggest_vec)
221
- {
222
- if constexpr (sizeof (reg_t ) == 64 ){
223
- return partition_vec_avx512<vtype>(l_store, r_store, curr_vec, pivot_vec, smallest_vec, biggest_vec);
224
- }else if constexpr (sizeof (reg_t ) == 32 ){
225
- return partition_vec_avx2<vtype>(l_store, r_store, curr_vec, pivot_vec, smallest_vec, biggest_vec);
226
- }else {
227
- static_assert (sizeof (reg_t ) == -1 , " should not reach here" );
228
- return 0 ;
229
- }
230
- }
231
-
232
200
/*
233
201
* Parition an array based on the pivot and returns the index of the
234
202
* first element that is greater than or equal to the pivot.
0 commit comments