@@ -95,7 +95,8 @@ void avx512_qselect(T *arr, int64_t k, int64_t arrsize);
95
95
void avx512_qselect_fp16 (uint16_t *arr, int64_t k, int64_t arrsize);
96
96
97
97
/*
 * Partial-sort entry point built from two calls, in this order:
 *   1. avx512_qselect<T>(arr, k - 1, arrsize)
 *   2. avx512_qsort<T>(arr, k - 1)
 *
 * arr     : pointer to the data, forwarded unchanged to both calls
 * k       : both calls receive k - 1
 * arrsize : forwarded only to avx512_qselect
 *
 * NOTE(review): presumably this leaves the smallest elements of arr in
 * sorted order at the front of the array -- confirm against the contracts
 * of avx512_qselect and avx512_qsort, which are not defined here.
 */
template <typename T>
98
- inline void avx512_partial_qsort (T *arr, int64_t k, int64_t arrsize) {
98
+ inline void avx512_partial_qsort (T *arr, int64_t k, int64_t arrsize)
99
+ {
99
100
// select step: runs over the whole array (arrsize) with index k - 1
avx512_qselect<T>(arr, k - 1 , arrsize);
100
101
// sort step: only given k - 1, not arrsize
avx512_qsort<T>(arr, k - 1 );
101
102
}
@@ -259,4 +260,123 @@ static inline int64_t partition_avx512(type_t *arr,
259
260
*biggest = vtype::reducemax (max_vec);
260
261
return l_store;
261
262
}
263
+
264
/*
 * Unrolled variant of partition_avx512: partitions the half-open range
 * arr[left, right) around pivot, handling num_unroll vectors of
 * vtype::numlanes elements per loop iteration.
 *
 * Template parameters:
 *   vtype      : vector traits type; this function uses vtype::numlanes,
 *                vtype::zmm_t, vtype::set1, vtype::loadu, vtype::reducemin
 *                and vtype::reducemax
 *   num_unroll : number of vectors loaded and partitioned per iteration
 *   type_t     : element type, defaults to vtype::type_t
 *
 * Parameters:
 *   arr      : array being partitioned in place
 *   left     : index of the first element of the range
 *   right    : index one past the last element (the code reads arr[--right]
 *              and loads vectors that end at arr + right)
 *   pivot    : value each element is compared against
 *   smallest : in/out; seeds the running minimum and receives the final one
 *   biggest  : in/out; seeds the running maximum and receives the final one
 *
 * Returns l_store, the left store index after every vector has been written.
 * NOTE(review): presumably this is the index of the first element that is
 * not less than pivot -- confirm against partition_vec / comparison_func.
 *
 * Ranges of at most 2 * num_unroll * vtype::numlanes elements are delegated
 * to partition_avx512<vtype>.
 *
 * NOTE(review): every "#pragma GCC unroll 8" below is hard-coded to 8 and
 * does not follow num_unroll.
 */
+ template <typename vtype,
265
+ int num_unroll,
266
+ typename type_t = typename vtype::type_t >
267
+ static inline int64_t partition_avx512_unrolled (type_t *arr,
268
+ int64_t left,
269
+ int64_t right,
270
+ type_t pivot,
271
+ type_t *smallest,
272
+ type_t *biggest)
273
+ {
274
// too small for two full unrolled batches (one per side): use the plain version
+ if (right - left <= 2 * num_unroll * vtype::numlanes) {
275
+ return partition_avx512<vtype>(
276
+ arr, left, right, pivot, smallest, biggest);
277
+ }
278
+ /* make array length divisible by num_unroll*vtype::numlanes, shortening the array */
279
// scalar pass: each iteration shrinks [left, right) by exactly one element
+ for (int32_t i = ((right - left) % (num_unroll * vtype::numlanes)); i > 0 ;
280
+ --i) {
281
+ *smallest = std::min (*smallest, arr[left], comparison_func<vtype>);
282
+ *biggest = std::max (*biggest, arr[left], comparison_func<vtype>);
283
// comparison_func false: move arr[left] to the right end and re-examine the swapped-in value
+ if (!comparison_func<vtype>(arr[left], pivot)) {
284
+ std::swap (arr[left], arr[--right]);
285
+ }
286
+ else {
287
+ ++left;
288
+ }
289
+ }
290
+
291
+ if (left == right)
292
+ return left; /* NOTE(review): after the size guard above at least 2*num_unroll*vtype::numlanes elements remain, so this looks unreachable -- confirm */
293
+
294
+ using zmm_t = typename vtype::zmm_t ;
295
+ zmm_t pivot_vec = vtype::set1 (pivot);
296
// running min/max vectors start from the caller-supplied (and scalar-updated) values
+ zmm_t min_vec = vtype::set1 (*smallest);
297
+ zmm_t max_vec = vtype::set1 (*biggest);
298
+
299
+ // We will now have at least 2*num_unroll registers worth of data to process:
300
+ // the leftmost and rightmost num_unroll vectors are loaded here and partitioned at the end
301
+ zmm_t vec_left[num_unroll], vec_right[num_unroll];
302
+ #pragma GCC unroll 8
303
+ for (int ii = 0 ; ii < num_unroll; ++ii) {
304
+ vec_left[ii] = vtype::loadu (arr + left + vtype::numlanes * ii);
305
+ vec_right[ii] = vtype::loadu (
306
+ arr + (right - vtype::numlanes * (num_unroll - ii)));
307
+ }
308
+ // store points of the vectors: r_store is the start of the last vector slot, l_store the first slot
309
+ int64_t r_store = right - vtype::numlanes;
310
+ int64_t l_store = left;
311
+ // indices for loading the elements: skip the 2*num_unroll vectors already held in registers
312
+ left += num_unroll * vtype::numlanes;
313
+ right -= num_unroll * vtype::numlanes;
// remaining length is a multiple of num_unroll*vtype::numlanes, so the indices meet exactly
314
+ while (right - left != 0 ) {
315
+ zmm_t curr_vec[num_unroll];
316
+ /*
317
+ * if the gap between the right load index and the right store end is
318
+ * smaller than the gap between the left store index and the left load
319
+ * index, the next batch is loaded from the right side, otherwise from the left side
320
+ */
321
+ if ((r_store + vtype::numlanes) - right < left - l_store) {
322
+ right -= num_unroll * vtype::numlanes;
323
+ #pragma GCC unroll 8
324
+ for (int ii = 0 ; ii < num_unroll; ++ii) {
325
+ curr_vec[ii] = vtype::loadu (arr + right + ii * vtype::numlanes);
326
+ }
327
+ }
328
+ else {
329
+ #pragma GCC unroll 8
330
+ for (int ii = 0 ; ii < num_unroll; ++ii) {
331
+ curr_vec[ii] = vtype::loadu (arr + left + ii * vtype::numlanes);
332
+ }
333
+ left += num_unroll * vtype::numlanes;
334
+ }
335
+ // partition each loaded vector and save it on both sides of the array
336
+ #pragma GCC unroll 8
337
+ for (int ii = 0 ; ii < num_unroll; ++ii) {
338
// NOTE(review): presumably partition_vec returns how many lanes went to the right side -- confirm
+ int32_t amount_ge_pivot
339
+ = partition_vec<vtype>(arr,
340
+ l_store,
341
+ r_store + vtype::numlanes,
342
+ curr_vec[ii],
343
+ pivot_vec,
344
+ &min_vec,
345
+ &max_vec);
346
// the two store indices together always advance by exactly vtype::numlanes
+ l_store += (vtype::numlanes - amount_ge_pivot);
347
+ r_store -= amount_ge_pivot;
348
+ }
349
+ }
350
+
351
+ /* partition and save the num_unroll vectors held in vec_left and vec_right */
352
+ #pragma GCC unroll 8
353
+ for (int ii = 0 ; ii < num_unroll; ++ii) {
354
+ int32_t amount_ge_pivot
355
+ = partition_vec<vtype>(arr,
356
+ l_store,
357
+ r_store + vtype::numlanes,
358
+ vec_left[ii],
359
+ pivot_vec,
360
+ &min_vec,
361
+ &max_vec);
362
+ l_store += (vtype::numlanes - amount_ge_pivot);
363
+ r_store -= amount_ge_pivot;
364
+ }
365
+ #pragma GCC unroll 8
366
+ for (int ii = 0 ; ii < num_unroll; ++ii) {
367
+ int32_t amount_ge_pivot
368
+ = partition_vec<vtype>(arr,
369
+ l_store,
370
+ r_store + vtype::numlanes,
371
+ vec_right[ii],
372
+ pivot_vec,
373
+ &min_vec,
374
+ &max_vec);
375
+ l_store += (vtype::numlanes - amount_ge_pivot);
376
+ r_store -= amount_ge_pivot;
377
+ }
// collapse the per-lane running extrema into the caller's scalar outputs
378
+ *smallest = vtype::reducemin (min_vec);
379
+ *biggest = vtype::reducemax (max_vec);
380
+ return l_store;
381
+ }
262
382
#endif // AVX512_QSORT_COMMON
0 commit comments