Skip to content

Commit 5c133e7

Browse files
author
Raghuveer Devulapalli
authored
Merge pull request #120 from r-devulap/kv-32bit
Improve key-value sort performance
2 parents 9b978ec + 845bc36 commit 5c133e7

File tree

7 files changed

+421
-41
lines changed

7 files changed

+421
-41
lines changed

benchmarks/bench-objsort.hpp

Lines changed: 21 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -5,19 +5,19 @@ static constexpr char euclidean[] = "euclidean";
55
static constexpr char taxicab[] = "taxicab";
66
static constexpr char chebyshev[] = "chebyshev";
77

8-
template <const char* val>
8+
template <typename T, const char* val>
99
struct Point3D {
10-
double x;
11-
double y;
12-
double z;
10+
T x;
11+
T y;
12+
T z;
1313
static constexpr std::string_view name {val};
1414
Point3D()
1515
{
16-
x = (double)rand() / RAND_MAX;
17-
y = (double)rand() / RAND_MAX;
18-
z = (double)rand() / RAND_MAX;
16+
x = (T)rand() / RAND_MAX;
17+
y = (T)rand() / RAND_MAX;
18+
z = (T)rand() / RAND_MAX;
1919
}
20-
double distance()
20+
T distance()
2121
{
2222
if constexpr (name == "x") {
2323
return x;
@@ -77,7 +77,7 @@ static void simdobjsort(benchmark::State &state)
7777
std::vector<T> arr_bkp = arr;
7878
// benchmark
7979
for (auto _ : state) {
80-
x86simdsort::object_qsort(arr.data(), arr.size(), [](T p) -> double {
80+
x86simdsort::object_qsort(arr.data(), arr.size(), [](T p) {
8181
return p.distance();
8282
});
8383
state.PauseTiming();
@@ -89,20 +89,22 @@ static void simdobjsort(benchmark::State &state)
8989
}
9090
}
9191

92-
#define BENCHMARK_OBJSORT(func, T) \
93-
BENCHMARK_TEMPLATE(func, T) \
92+
#define BENCHMARK_OBJSORT(func, T, type, dist) \
93+
BENCHMARK_TEMPLATE(func, T<type,dist>) \
9494
->Arg(10e1) \
9595
->Arg(10e2) \
9696
->Arg(10e3) \
9797
->Arg(10e4) \
9898
->Arg(10e5) \
9999
->Arg(10e6);
100100

101-
BENCHMARK_OBJSORT(simdobjsort, Point3D<x>)
102-
BENCHMARK_OBJSORT(scalarobjsort, Point3D<x>)
103-
BENCHMARK_OBJSORT(simdobjsort, Point3D<taxicab>)
104-
BENCHMARK_OBJSORT(scalarobjsort, Point3D<taxicab>)
105-
BENCHMARK_OBJSORT(simdobjsort, Point3D<euclidean>)
106-
BENCHMARK_OBJSORT(scalarobjsort, Point3D<euclidean>)
107-
BENCHMARK_OBJSORT(simdobjsort, Point3D<chebyshev>)
108-
BENCHMARK_OBJSORT(scalarobjsort, Point3D<chebyshev>)
101+
BENCHMARK_OBJSORT(simdobjsort, Point3D, double, x)
102+
BENCHMARK_OBJSORT(scalarobjsort, Point3D, double, x)
103+
BENCHMARK_OBJSORT(simdobjsort, Point3D, float, x)
104+
BENCHMARK_OBJSORT(scalarobjsort, Point3D, float, x)
105+
BENCHMARK_OBJSORT(simdobjsort, Point3D, double, taxicab )
106+
BENCHMARK_OBJSORT(scalarobjsort, Point3D, double, taxicab)
107+
BENCHMARK_OBJSORT(simdobjsort, Point3D, double, euclidean)
108+
BENCHMARK_OBJSORT(scalarobjsort, Point3D, double, euclidean)
109+
BENCHMARK_OBJSORT(simdobjsort, Point3D, double, chebyshev)
110+
BENCHMARK_OBJSORT(scalarobjsort, Point3D, double, chebyshev)

run-bench.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
parser.add_argument("-b", '--branch', type=str, default="main", required=False)
88
parser.add_argument('--benchcompare', type=str, help='Compare simd bench with stdsort methods. Requires one of qsort, qselect, partialsort, argsort or argselect')
99
parser.add_argument("-f", '--filter', type=str, required=False)
10+
parser.add_argument("-r", '--repeat', type=int, required=False)
1011
args = parser.parse_args()
1112

1213
if len(sys.argv) == 1:
@@ -15,6 +16,9 @@
1516
filterb = ""
1617
if args.filter is not None:
1718
filterb = args.filter
19+
repeatnum = 1
20+
if args.repeat is not None:
21+
repeatnum = args.repeat
1822

1923
if args.benchcompare:
2024
baseline = ""
@@ -43,11 +47,11 @@
4347
else:
4448
parser.print_help(sys.stderr)
4549
parser.error("ERROR: Unknown argument '%s'" % args.benchcompare)
46-
rc = subprocess.check_call("./scripts/bench-compare.sh '%s' '%s'" % (baseline, contender), shell=True)
50+
rc = subprocess.check_call("./scripts/bench-compare.sh '%s' '%s' '%d'" % (baseline, contender, repeatnum), shell=True)
4751

4852
if args.branchcompare:
4953
branch = args.branch
5054
if args.filter is None:
51-
rc = subprocess.check_call("./scripts/branch-compare.sh '%s'" % (branch), shell=True)
55+
rc = subprocess.check_call("./scripts/branch-compare.sh '%s' '%d'" % (branch, repeatnum), shell=True)
5256
else:
53-
rc = subprocess.check_call("./scripts/branch-compare.sh '%s' '%s'" % (branch, args.filter), shell=True)
57+
rc = subprocess.check_call("./scripts/branch-compare.sh '%s' '%s' '%d'" % (branch, args.filter, repeatnum), shell=True)

scripts/bench-compare.sh

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,4 +14,4 @@ compare=$(realpath .bench/google-benchmark/tools/compare.py)
1414
meson setup -Dbuild_benchmarks=true --warnlevel 0 --buildtype release builddir-${branch}
1515
cd builddir-${branch}
1616
ninja
17-
$compare filters ./benchexe $1 $2
17+
$compare filters ./benchexe $1 $2 --benchmark_repetitions=$3

scripts/branch-compare.sh

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -44,10 +44,10 @@ build_branch $basebranch
4444
contender=$(realpath ${branch}/builddir/benchexe)
4545
baseline=$(realpath ${basebranch}/builddir/benchexe)
4646

47-
if [ -z "$2" ]; then
47+
if [ -z "$3" ]; then
4848
echo "Comparing all benchmarks .."
49-
$compare benchmarks $baseline $contender
49+
$compare benchmarks $baseline $contender --benchmark_repetitions=$2
5050
else
5151
echo "Comparing benchmark $2 .."
52-
$compare benchmarksfiltered $baseline $2 $contender $2
52+
$compare benchmarksfiltered $baseline $2 $contender $2 --benchmark_repetitions=$3
5353
fi

src/avx512-32bit-qsort.hpp

Lines changed: 117 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@ template <>
3232
struct zmm_vector<int32_t> {
3333
using type_t = int32_t;
3434
using reg_t = __m512i;
35+
using regi_t = __m512i;
3536
using halfreg_t = __m256i;
3637
using opmask_t = __mmask16;
3738
static const uint8_t numlanes = 16;
@@ -65,6 +66,10 @@ struct zmm_vector<int32_t> {
6566
{
6667
return _mm512_cmp_epi32_mask(x, y, _MM_CMPINT_NLT);
6768
}
69+
static opmask_t eq(reg_t x, reg_t y)
70+
{
71+
return _mm512_cmpeq_epi32_mask(x, y);
72+
}
6873
static opmask_t get_partial_loadmask(uint64_t num_to_read)
6974
{
7075
return ((0x1ull << num_to_read) - 0x1ull);
@@ -123,6 +128,40 @@ struct zmm_vector<int32_t> {
123128
{
124129
return _mm512_set1_epi32(v);
125130
}
131+
static regi_t seti(int v1,
132+
int v2,
133+
int v3,
134+
int v4,
135+
int v5,
136+
int v6,
137+
int v7,
138+
int v8,
139+
int v9,
140+
int v10,
141+
int v11,
142+
int v12,
143+
int v13,
144+
int v14,
145+
int v15,
146+
int v16)
147+
{
148+
return _mm512_set_epi32(v1,
149+
v2,
150+
v3,
151+
v4,
152+
v5,
153+
v6,
154+
v7,
155+
v8,
156+
v9,
157+
v10,
158+
v11,
159+
v12,
160+
v13,
161+
v14,
162+
v15,
163+
v16);
164+
}
126165
template <uint8_t mask>
127166
static reg_t shuffle(reg_t zmm)
128167
{
@@ -171,6 +210,7 @@ template <>
171210
struct zmm_vector<uint32_t> {
172211
using type_t = uint32_t;
173212
using reg_t = __m512i;
213+
using regi_t = __m512i;
174214
using halfreg_t = __m256i;
175215
using opmask_t = __mmask16;
176216
static const uint8_t numlanes = 16;
@@ -214,6 +254,10 @@ struct zmm_vector<uint32_t> {
214254
{
215255
return _mm512_cmp_epu32_mask(x, y, _MM_CMPINT_NLT);
216256
}
257+
static opmask_t eq(reg_t x, reg_t y)
258+
{
259+
return _mm512_cmpeq_epu32_mask(x, y);
260+
}
217261
static opmask_t get_partial_loadmask(uint64_t num_to_read)
218262
{
219263
return ((0x1ull << num_to_read) - 0x1ull);
@@ -262,6 +306,40 @@ struct zmm_vector<uint32_t> {
262306
{
263307
return _mm512_set1_epi32(v);
264308
}
309+
static regi_t seti(int v1,
310+
int v2,
311+
int v3,
312+
int v4,
313+
int v5,
314+
int v6,
315+
int v7,
316+
int v8,
317+
int v9,
318+
int v10,
319+
int v11,
320+
int v12,
321+
int v13,
322+
int v14,
323+
int v15,
324+
int v16)
325+
{
326+
return _mm512_set_epi32(v1,
327+
v2,
328+
v3,
329+
v4,
330+
v5,
331+
v6,
332+
v7,
333+
v8,
334+
v9,
335+
v10,
336+
v11,
337+
v12,
338+
v13,
339+
v14,
340+
v15,
341+
v16);
342+
}
265343
template <uint8_t mask>
266344
static reg_t shuffle(reg_t zmm)
267345
{
@@ -310,6 +388,7 @@ template <>
310388
struct zmm_vector<float> {
311389
using type_t = float;
312390
using reg_t = __m512;
391+
using regi_t = __m512i;
313392
using halfreg_t = __m256;
314393
using opmask_t = __mmask16;
315394
static const uint8_t numlanes = 16;
@@ -343,6 +422,10 @@ struct zmm_vector<float> {
343422
{
344423
return _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ);
345424
}
425+
static opmask_t eq(reg_t x, reg_t y)
426+
{
427+
return _mm512_cmpeq_ps_mask(x, y);
428+
}
346429
static opmask_t get_partial_loadmask(uint64_t num_to_read)
347430
{
348431
return ((0x1ull << num_to_read) - 0x1ull);
@@ -415,6 +498,40 @@ struct zmm_vector<float> {
415498
{
416499
return _mm512_set1_ps(v);
417500
}
501+
static regi_t seti(int v1,
502+
int v2,
503+
int v3,
504+
int v4,
505+
int v5,
506+
int v6,
507+
int v7,
508+
int v8,
509+
int v9,
510+
int v10,
511+
int v11,
512+
int v12,
513+
int v13,
514+
int v14,
515+
int v15,
516+
int v16)
517+
{
518+
return _mm512_set_epi32(v1,
519+
v2,
520+
v3,
521+
v4,
522+
v5,
523+
v6,
524+
v7,
525+
v8,
526+
v9,
527+
v10,
528+
v11,
529+
v12,
530+
v13,
531+
v14,
532+
v15,
533+
v16);
534+
}
418535
template <uint8_t mask>
419536
static reg_t shuffle(reg_t zmm)
420537
{

0 commit comments

Comments
 (0)