
Commit d2dc8ab

Author: Raghuveer Devulapalli
Merge pull request #39 from r-devulap/key-value-fail-case
Fix bug in the key-value sorting networks
2 parents eb581ce + a26a60c

19 files changed: +212 -73 lines
3 files renamed without changes.

benchmarks/bench-qsort.cpp

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+#include "bench-qsort.hpp"
+#include "bench-argsort.hpp"
+#include "bench-partial-qsort.hpp"
+#include "bench-qselect.hpp"
2 files renamed without changes.

benchmarks/bench_qsort.cpp

Lines changed: 0 additions & 4 deletions
This file was deleted.

benchmarks/meson.build

Lines changed: 2 additions & 2 deletions
@@ -2,7 +2,7 @@ libbench = []
 
 if cpp.has_argument('-march=icelake-client')
     libbench += static_library('bench_qsort',
-        files('bench_qsort.cpp', ),
+        files('bench-qsort.cpp', ),
         dependencies: gbench_dep,
         include_directories : [src, utils],
         cpp_args : ['-O3', '-march=icelake-client'],
@@ -11,7 +11,7 @@ endif
 
 if cancompilefp16
     libbench += static_library('bench_qsortfp16',
-        files('bench_qsortfp16.cpp', ),
+        files('bench-qsortfp16.cpp', ),
         dependencies: gbench_dep,
         include_directories : [src, utils],
         cpp_args : ['-O3', '-march=sapphirerapids'],

src/avx512-64bit-common.h

Lines changed: 6 additions & 6 deletions
@@ -156,7 +156,7 @@ struct ymm_vector<float> {
     }
     static void storeu(void *mem, zmm_t x)
     {
-        return _mm256_storeu_ps((float*)mem, x);
+        _mm256_storeu_ps((float*)mem, x);
     }
 };
 template <>
@@ -285,7 +285,7 @@ struct ymm_vector<uint32_t> {
     }
     static void storeu(void *mem, zmm_t x)
     {
-        return _mm256_storeu_epi32(mem, x);
+        _mm256_storeu_epi32(mem, x);
     }
 };
 template <>
@@ -414,7 +414,7 @@ struct ymm_vector<int32_t> {
     }
     static void storeu(void *mem, zmm_t x)
     {
-        return _mm256_storeu_epi32(mem, x);
+        _mm256_storeu_epi32(mem, x);
     }
 };
 template <>
@@ -538,7 +538,7 @@ struct zmm_vector<int64_t> {
     }
     static void storeu(void *mem, zmm_t x)
     {
-        return _mm512_storeu_si512(mem, x);
+        _mm512_storeu_si512(mem, x);
     }
 };
 template <>
@@ -650,7 +650,7 @@ struct zmm_vector<uint64_t> {
     }
     static void storeu(void *mem, zmm_t x)
     {
-        return _mm512_storeu_si512(mem, x);
+        _mm512_storeu_si512(mem, x);
     }
 };
 template <>
@@ -770,7 +770,7 @@ struct zmm_vector<double> {
     }
     static void storeu(void *mem, zmm_t x)
     {
-        return _mm512_storeu_pd(mem, x);
+        _mm512_storeu_pd(mem, x);
     }
 };
 X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(double *arr, int64_t arrsize)
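
The changes in this file only drop the stray `return` in front of store intrinsics that already return `void`; the stores themselves are unchanged. Below is a minimal sketch of the fixed wrapper pattern, assuming AVX support. `ymm_float_sketch` is a hypothetical stand-in, not the library's actual `ymm_vector<float>` definition.

```cpp
// Minimal sketch of the corrected storeu wrapper (hypothetical struct name).
#include <immintrin.h>

struct ymm_float_sketch {
    using zmm_t = __m256; // 8 packed floats

    static zmm_t loadu(const void *mem)
    {
        return _mm256_loadu_ps(static_cast<const float *>(mem));
    }
    static void storeu(void *mem, zmm_t x)
    {
        // The intrinsic returns void, so the wrapper just calls it;
        // writing `return _mm256_storeu_ps(...)` adds nothing and reads
        // as if a value were being produced.
        _mm256_storeu_ps(static_cast<float *>(mem), x);
    }
};
```

The generated code is the same either way; the edit simply stops returning the result of a void expression from a void function.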

src/avx512-64bit-keyvalue-networks.hpp

Lines changed: 35 additions & 20 deletions
@@ -96,10 +96,14 @@ X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_64bit(zmm_t &key_zmm1,
     zmm_t key_zmm3 = vtype1::min(key_zmm1, key_zmm2);
     zmm_t key_zmm4 = vtype1::max(key_zmm1, key_zmm2);
 
-    index_type index_zmm3 = vtype2::mask_mov(
-            index_zmm2, vtype1::eq(key_zmm3, key_zmm1), index_zmm1);
-    index_type index_zmm4 = vtype2::mask_mov(
-            index_zmm1, vtype1::eq(key_zmm3, key_zmm1), index_zmm2);
+    typename vtype1::opmask_t movmask = vtype1::eq(key_zmm3, key_zmm1);
+
+    index_type index_zmm3 = vtype2::mask_mov(index_zmm2, movmask, index_zmm1);
+    index_type index_zmm4 = vtype2::mask_mov(index_zmm1, movmask, index_zmm2);
+
+    /* need to reverse the lower registers to keep the correct order */
+    key_zmm4 = vtype1::permutexvar(rev_index1, key_zmm4);
+    index_zmm4 = vtype2::permutexvar(rev_index2, index_zmm4);
 
     // 2) Recursive half cleaner for each
     key_zmm1 = bitonic_merge_zmm_64bit<vtype1, vtype2>(key_zmm3, index_zmm3);
@@ -129,14 +133,17 @@ X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm,
     zmm_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm3r);
     zmm_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm2r);
 
+    typename vtype1::opmask_t movmask1 = vtype1::eq(key_zmm_t1, key_zmm[0]);
+    typename vtype1::opmask_t movmask2 = vtype1::eq(key_zmm_t2, key_zmm[1]);
+
     index_type index_zmm_t1 = vtype2::mask_mov(
-            index_zmm3r, vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]);
+            index_zmm3r, movmask1, index_zmm[0]);
     index_type index_zmm_m1 = vtype2::mask_mov(
-            index_zmm[0], vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm3r);
+            index_zmm[0], movmask1, index_zmm3r);
     index_type index_zmm_t2 = vtype2::mask_mov(
-            index_zmm2r, vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]);
+            index_zmm2r, movmask2, index_zmm[1]);
     index_type index_zmm_m2 = vtype2::mask_mov(
-            index_zmm[1], vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm2r);
+            index_zmm[1], movmask2, index_zmm2r);
 
     // 2) Recursive half clearer: 16
     zmm_t key_zmm_t3 = vtype1::permutexvar(rev_index1, key_zmm_m2);
@@ -149,14 +156,17 @@ X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm,
     zmm_t key_zmm2 = vtype1::min(key_zmm_t3, key_zmm_t4);
     zmm_t key_zmm3 = vtype1::max(key_zmm_t3, key_zmm_t4);
 
+    movmask1 = vtype1::eq(key_zmm0, key_zmm_t1);
+    movmask2 = vtype1::eq(key_zmm2, key_zmm_t3);
+
     index_type index_zmm0 = vtype2::mask_mov(
-            index_zmm_t2, vtype1::eq(key_zmm0, key_zmm_t1), index_zmm_t1);
+            index_zmm_t2, movmask1, index_zmm_t1);
     index_type index_zmm1 = vtype2::mask_mov(
-            index_zmm_t1, vtype1::eq(key_zmm0, key_zmm_t1), index_zmm_t2);
+            index_zmm_t1, movmask1, index_zmm_t2);
     index_type index_zmm2 = vtype2::mask_mov(
-            index_zmm_t4, vtype1::eq(key_zmm2, key_zmm_t3), index_zmm_t3);
+            index_zmm_t4, movmask2, index_zmm_t3);
     index_type index_zmm3 = vtype2::mask_mov(
-            index_zmm_t3, vtype1::eq(key_zmm2, key_zmm_t3), index_zmm_t4);
+            index_zmm_t3, movmask2, index_zmm_t4);
 
     key_zmm[0] = bitonic_merge_zmm_64bit<vtype1, vtype2>(key_zmm0, index_zmm0);
     key_zmm[1] = bitonic_merge_zmm_64bit<vtype1, vtype2>(key_zmm1, index_zmm1);
@@ -197,22 +207,27 @@ X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *key_zmm,
     zmm_t key_zmm_m3 = vtype1::max(key_zmm[2], key_zmm5r);
     zmm_t key_zmm_m4 = vtype1::max(key_zmm[3], key_zmm4r);
 
+    typename vtype1::opmask_t movmask1 = vtype1::eq(key_zmm_t1, key_zmm[0]);
+    typename vtype1::opmask_t movmask2 = vtype1::eq(key_zmm_t2, key_zmm[1]);
+    typename vtype1::opmask_t movmask3 = vtype1::eq(key_zmm_t3, key_zmm[2]);
+    typename vtype1::opmask_t movmask4 = vtype1::eq(key_zmm_t4, key_zmm[3]);
+
     index_type index_zmm_t1 = vtype2::mask_mov(
-            index_zmm7r, vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]);
+            index_zmm7r, movmask1, index_zmm[0]);
     index_type index_zmm_m1 = vtype2::mask_mov(
-            index_zmm[0], vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm7r);
+            index_zmm[0], movmask1, index_zmm7r);
     index_type index_zmm_t2 = vtype2::mask_mov(
-            index_zmm6r, vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]);
+            index_zmm6r, movmask2, index_zmm[1]);
     index_type index_zmm_m2 = vtype2::mask_mov(
-            index_zmm[1], vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm6r);
+            index_zmm[1], movmask2, index_zmm6r);
     index_type index_zmm_t3 = vtype2::mask_mov(
-            index_zmm5r, vtype1::eq(key_zmm_t3, key_zmm[2]), index_zmm[2]);
+            index_zmm5r, movmask3, index_zmm[2]);
     index_type index_zmm_m3 = vtype2::mask_mov(
-            index_zmm[2], vtype1::eq(key_zmm_t3, key_zmm[2]), index_zmm5r);
+            index_zmm[2], movmask3, index_zmm5r);
     index_type index_zmm_t4 = vtype2::mask_mov(
-            index_zmm4r, vtype1::eq(key_zmm_t4, key_zmm[3]), index_zmm[3]);
+            index_zmm4r, movmask4, index_zmm[3]);
     index_type index_zmm_m4 = vtype2::mask_mov(
-            index_zmm[3], vtype1::eq(key_zmm_t4, key_zmm[3]), index_zmm4r);
+            index_zmm[3], movmask4, index_zmm4r);
 
     zmm_t key_zmm_t5 = vtype1::permutexvar(rev_index1, key_zmm_m4);
     zmm_t key_zmm_t6 = vtype1::permutexvar(rev_index1, key_zmm_m3);
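
The substantive fix is in `bitonic_merge_two_zmm_64bit`: the equality mask is computed once into `movmask` so that both `mask_mov` calls pick indices with exactly the same comparison, and the max half (`key_zmm4` / `index_zmm4`) is now reversed with `permutexvar` before the recursive half cleaner, which is the missing reordering the commit message refers to. The scalar analogue below (illustrative names only, not the library's AVX-512 code) shows why a key and its payload index must be selected by one and the same mask.

```cpp
// Scalar analogue of the key/value compare-exchange built from
// vtype1::min/max + vtype2::mask_mov: the comparison that selects each
// key must also select its payload index, or keys and values diverge.
#include <algorithm>
#include <cstdint>

static void coex_keyvalue(int64_t &key1, uint64_t &idx1,
                          int64_t &key2, uint64_t &idx2)
{
    int64_t lo = std::min(key1, key2);
    int64_t hi = std::max(key1, key2);

    // Corresponds to movmask = vtype1::eq(key_zmm3, key_zmm1):
    // true where the smaller key came from the first operand.
    bool from_first = (lo == key1);

    // Both index picks reuse the same mask, mirroring the two mask_mov calls.
    uint64_t lo_idx = from_first ? idx1 : idx2;
    uint64_t hi_idx = from_first ? idx2 : idx1;

    key1 = lo;  idx1 = lo_idx;
    key2 = hi;  idx2 = hi_idx;
}
```

In the vector version the added lines additionally pass `key_zmm4` through `permutexvar(rev_index1, ...)` and `index_zmm4` through `rev_index2`, so both key and index lanes are reordered identically before recursing; the larger `_four_` and `_eight_` variants in this diff only hoist the repeated `eq` calls into the `movmask*` temporaries.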
