Skip to content

Commit 7b66816

Browse files
committed
Cleaning up some code and de-duplicating some logic
1 parent 5916cfe commit 7b66816

File tree

3 files changed

+47
-61
lines changed

3 files changed

+47
-61
lines changed

src/avx512-64bit-common.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ struct ymm_vector<float> {
3434
using opmask_t = __mmask8;
3535
static const uint8_t numlanes = 8;
3636
static constexpr simd_type vec_type = simd_type::AVX512;
37-
37+
3838
using swizzle_ops = avx512_ymm_64bit_swizzle_ops;
3939

4040
static type_t type_max()
@@ -232,7 +232,7 @@ struct ymm_vector<uint32_t> {
232232
using opmask_t = __mmask8;
233233
static const uint8_t numlanes = 8;
234234
static constexpr simd_type vec_type = simd_type::AVX512;
235-
235+
236236
using swizzle_ops = avx512_ymm_64bit_swizzle_ops;
237237

238238
static type_t type_max()
@@ -416,7 +416,7 @@ struct ymm_vector<int32_t> {
416416
using opmask_t = __mmask8;
417417
static const uint8_t numlanes = 8;
418418
static constexpr simd_type vec_type = simd_type::AVX512;
419-
419+
420420
using swizzle_ops = avx512_ymm_64bit_swizzle_ops;
421421

422422
static type_t type_max()

src/xss-common-keyvaluesort.hpp

Lines changed: 44 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -363,11 +363,11 @@ template <typename vtype1,
363363
typename vtype2,
364364
typename type1_t = typename vtype1::type_t,
365365
typename type2_t = typename vtype2::type_t>
366-
X86_SIMD_SORT_INLINE void qsort_64bit_(type1_t *keys,
367-
type2_t *indexes,
368-
arrsize_t left,
369-
arrsize_t right,
370-
int max_iters)
366+
X86_SIMD_SORT_INLINE void kvsort_(type1_t *keys,
367+
type2_t *indexes,
368+
arrsize_t left,
369+
arrsize_t right,
370+
int max_iters)
371371
{
372372
/*
373373
* Resort to std::sort if quicksort isn't making any progress
@@ -393,32 +393,35 @@ X86_SIMD_SORT_INLINE void qsort_64bit_(type1_t *keys,
393393
arrsize_t pivot_index = kvpartition_unrolled<vtype1, vtype2, 4>(
394394
keys, indexes, left, right + 1, pivot, &smallest, &biggest);
395395
if (pivot != smallest) {
396-
qsort_64bit_<vtype1, vtype2>(
396+
kvsort_<vtype1, vtype2>(
397397
keys, indexes, left, pivot_index - 1, max_iters - 1);
398398
}
399399
if (pivot != biggest) {
400-
qsort_64bit_<vtype1, vtype2>(
400+
kvsort_<vtype1, vtype2>(
401401
keys, indexes, pivot_index, right, max_iters - 1);
402402
}
403403
}
404404

405-
template <typename T1, typename T2>
405+
template <typename T1,
406+
typename T2,
407+
template <typename...>
408+
typename full_vector,
409+
template <typename...>
410+
typename half_vector>
406411
X86_SIMD_SORT_INLINE void
407-
avx512_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan = false)
412+
xss_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan)
408413
{
409414
using keytype =
410415
typename std::conditional<sizeof(T1) != sizeof(T2)
411416
&& sizeof(T1) == sizeof(int32_t),
412-
ymm_vector<T1>,
413-
zmm_vector<T1>>::type;
417+
half_vector<T1>,
418+
full_vector<T1>>::type;
414419
using valtype =
415420
typename std::conditional<sizeof(T1) != sizeof(T2)
416421
&& sizeof(T2) == sizeof(int32_t),
417-
ymm_vector<T2>,
418-
zmm_vector<T2>>::type;
419-
/*
420-
* Enable testing the heapsort key-value sort in the CI:
421-
*/
422+
half_vector<T2>,
423+
full_vector<T2>>::type;
424+
422425
#ifdef XSS_TEST_KEYVALUE_BASE_CASE
423426
int maxiters = -1;
424427
bool minarrsize = true;
@@ -428,57 +431,43 @@ avx512_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan = false)
428431
#endif // XSS_TEST_KEYVALUE_BASE_CASE
429432

430433
if (minarrsize) {
431-
arrsize_t nan_count = 0;
432-
if constexpr (xss::fp::is_floating_point_v<T1>) {
434+
if constexpr (std::is_floating_point_v<T1>) {
435+
arrsize_t nan_count = 0;
433436
if (UNLIKELY(hasnan)) {
434-
nan_count = replace_nan_with_inf<zmm_vector<T1>>(keys, arrsize);
437+
nan_count
438+
= replace_nan_with_inf<full_vector<T1>>(keys, arrsize);
435439
}
440+
kvsort_<keytype, valtype>(keys,
441+
indexes,
442+
0,
443+
arrsize - 1,
444+
2 * (arrsize_t)log2(arrsize));
445+
replace_inf_with_nan(keys, arrsize, nan_count);
436446
}
437447
else {
438448
UNUSED(hasnan);
449+
kvsort_<keytype, valtype>(keys,
450+
indexes,
451+
0,
452+
arrsize - 1,
453+
2 * (arrsize_t)log2(arrsize));
439454
}
440-
qsort_64bit_<keytype, valtype>(keys, indexes, 0, arrsize - 1, maxiters);
441-
replace_inf_with_nan(keys, arrsize, nan_count);
442455
}
443456
}
444457

445458
template <typename T1, typename T2>
446459
X86_SIMD_SORT_INLINE void
447-
avx2_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan = false)
460+
avx512_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan = false)
448461
{
449-
using keytype =
450-
typename std::conditional<sizeof(T1) != sizeof(T2)
451-
&& sizeof(T1) == sizeof(int32_t),
452-
avx2_half_vector<T1>,
453-
avx2_vector<T1>>::type;
454-
using valtype =
455-
typename std::conditional<sizeof(T1) != sizeof(T2)
456-
&& sizeof(T2) == sizeof(int32_t),
457-
avx2_half_vector<T2>,
458-
avx2_vector<T2>>::type;
462+
xss_qsort_kv<T1, T2, zmm_vector, ymm_vector>(
463+
keys, indexes, arrsize, hasnan);
464+
}
459465

460-
if (arrsize > 1) {
461-
if constexpr (std::is_floating_point_v<T1>) {
462-
arrsize_t nan_count = 0;
463-
if (UNLIKELY(hasnan)) {
464-
nan_count
465-
= replace_nan_with_inf<avx2_vector<T1>>(keys, arrsize);
466-
}
467-
qsort_64bit_<keytype, valtype>(keys,
468-
indexes,
469-
0,
470-
arrsize - 1,
471-
2 * (arrsize_t)log2(arrsize));
472-
replace_inf_with_nan(keys, arrsize, nan_count);
473-
}
474-
else {
475-
UNUSED(hasnan);
476-
qsort_64bit_<keytype, valtype>(keys,
477-
indexes,
478-
0,
479-
arrsize - 1,
480-
2 * (arrsize_t)log2(arrsize));
481-
}
482-
}
466+
template <typename T1, typename T2>
467+
X86_SIMD_SORT_INLINE void
468+
avx2_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan = false)
469+
{
470+
xss_qsort_kv<T1, T2, avx2_vector, avx2_half_vector>(
471+
keys, indexes, arrsize, hasnan);
483472
}
484473
#endif // AVX512_QSORT_64BIT_KV

src/xss-common-qsort.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,6 @@ X86_SIMD_SORT_INLINE void COEX(mm_t &a, mm_t &b)
189189
b = vtype::max(temp, b);
190190
}
191191

192-
template <typename maskType, typename vtype>
193-
typename vtype::opmask_t convert_int_to_mask(maskType mask);
194-
195192
template <typename vtype,
196193
typename reg_t = typename vtype::reg_t,
197194
typename opmask_t = typename vtype::opmask_t>

0 commit comments

Comments
 (0)