Skip to content

Commit 8540f36

Browse files
ChrisHegarty authored and mridula-s109 committed
JDKVectorLibrary: update low-level bounds checks and add benchmark (elastic#130216)
This commit updates the low-level bounds checks in JDKVectorLibrary and adds a benchmark, so that we can more easily benchmark the low-level operations. Note: I added the mr-jar gradle plugin to the benchmarks so that we can compile with preview features in Java 21, namely MemorySegment.
1 parent 6ec21ef commit 8540f36

File tree

6 files changed

+225
-18
lines changed

6 files changed

+225
-18
lines changed

benchmarks/build.gradle

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import org.elasticsearch.gradle.OS
1313
apply plugin: org.elasticsearch.gradle.internal.ElasticsearchJavaBasePlugin
1414
apply plugin: 'java-library'
1515
apply plugin: 'application'
16+
apply plugin: 'elasticsearch.mrjar'
1617

1718
var os = org.gradle.internal.os.OperatingSystem.current()
1819

@@ -46,6 +47,7 @@ dependencies {
4647
api(project(':x-pack:plugin:core'))
4748
api(project(':x-pack:plugin:esql'))
4849
api(project(':x-pack:plugin:esql:compute'))
50+
implementation project(path: ':libs:native')
4951
implementation project(path: ':libs:simdvec')
5052
expression(project(path: ':modules:lang-expression', configuration: 'zip'))
5153
painless(project(path: ':modules:lang-painless', configuration: 'zip'))
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
package org.elasticsearch.benchmark.vector;
10+
11+
import org.apache.lucene.util.VectorUtil;
12+
import org.elasticsearch.common.logging.LogConfigurator;
13+
import org.elasticsearch.common.logging.NodeNamePatternConverter;
14+
import org.elasticsearch.nativeaccess.NativeAccess;
15+
import org.elasticsearch.nativeaccess.VectorSimilarityFunctions;
16+
import org.openjdk.jmh.annotations.Benchmark;
17+
import org.openjdk.jmh.annotations.BenchmarkMode;
18+
import org.openjdk.jmh.annotations.Fork;
19+
import org.openjdk.jmh.annotations.Level;
20+
import org.openjdk.jmh.annotations.Measurement;
21+
import org.openjdk.jmh.annotations.Mode;
22+
import org.openjdk.jmh.annotations.OutputTimeUnit;
23+
import org.openjdk.jmh.annotations.Param;
24+
import org.openjdk.jmh.annotations.Scope;
25+
import org.openjdk.jmh.annotations.Setup;
26+
import org.openjdk.jmh.annotations.State;
27+
import org.openjdk.jmh.annotations.TearDown;
28+
import org.openjdk.jmh.annotations.Warmup;
29+
30+
import java.lang.foreign.Arena;
31+
import java.lang.foreign.MemorySegment;
32+
import java.util.concurrent.ThreadLocalRandom;
33+
import java.util.concurrent.TimeUnit;
34+
35+
@BenchmarkMode(Mode.AverageTime)
36+
@OutputTimeUnit(TimeUnit.NANOSECONDS)
37+
@State(Scope.Benchmark)
38+
@Warmup(iterations = 3, time = 1)
39+
@Measurement(iterations = 5, time = 1)
40+
public class JDKVectorInt7uBenchmark {
41+
42+
static {
43+
NodeNamePatternConverter.setGlobalNodeName("foo");
44+
LogConfigurator.loadLog4jPlugins();
45+
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
46+
}
47+
48+
byte[] byteArrayA;
49+
byte[] byteArrayB;
50+
MemorySegment heapSegA, heapSegB;
51+
MemorySegment nativeSegA, nativeSegB;
52+
53+
Arena arena;
54+
55+
@Param({ "1", "128", "207", "256", "300", "512", "702", "1024" })
56+
public int size;
57+
58+
@Setup(Level.Iteration)
59+
public void init() {
60+
byteArrayA = new byte[size];
61+
byteArrayB = new byte[size];
62+
for (int i = 0; i < size; ++i) {
63+
randomInt7BytesBetween(byteArrayA);
64+
randomInt7BytesBetween(byteArrayB);
65+
}
66+
heapSegA = MemorySegment.ofArray(byteArrayA);
67+
heapSegB = MemorySegment.ofArray(byteArrayB);
68+
69+
arena = Arena.ofConfined();
70+
nativeSegA = arena.allocate((long) byteArrayA.length);
71+
MemorySegment.copy(MemorySegment.ofArray(byteArrayA), 0L, nativeSegA, 0L, byteArrayA.length);
72+
nativeSegB = arena.allocate((long) byteArrayB.length);
73+
MemorySegment.copy(MemorySegment.ofArray(byteArrayB), 0L, nativeSegB, 0L, byteArrayB.length);
74+
}
75+
76+
@TearDown
77+
public void teardown() {
78+
arena.close();
79+
}
80+
81+
@Benchmark
82+
@Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
83+
public int dotProductLucene() {
84+
return VectorUtil.dotProduct(byteArrayA, byteArrayB);
85+
}
86+
87+
@Benchmark
88+
@Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
89+
public int dotProductNativeWithNativeSeg() {
90+
return dotProduct7u(nativeSegA, nativeSegB, size);
91+
}
92+
93+
@Benchmark
94+
@Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
95+
public int dotProductNativeWithHeapSeg() {
96+
return dotProduct7u(heapSegA, heapSegB, size);
97+
}
98+
99+
static final VectorSimilarityFunctions vectorSimilarityFunctions = vectorSimilarityFunctions();
100+
101+
static VectorSimilarityFunctions vectorSimilarityFunctions() {
102+
return NativeAccess.instance().getVectorSimilarityFunctions().get();
103+
}
104+
105+
int dotProduct7u(MemorySegment a, MemorySegment b, int length) {
106+
try {
107+
return (int) vectorSimilarityFunctions.dotProductHandle7u().invokeExact(a, b, length);
108+
} catch (Throwable e) {
109+
if (e instanceof Error err) {
110+
throw err;
111+
} else if (e instanceof RuntimeException re) {
112+
throw re;
113+
} else {
114+
throw new RuntimeException(e);
115+
}
116+
}
117+
}
118+
119+
// Unsigned int7 byte vectors have values in the range of 0 to 127 (inclusive).
120+
static final byte MIN_INT7_VALUE = 0;
121+
static final byte MAX_INT7_VALUE = 127;
122+
123+
static void randomInt7BytesBetween(byte[] bytes) {
124+
var random = ThreadLocalRandom.current();
125+
for (int i = 0, len = bytes.length; i < len;) {
126+
bytes[i++] = (byte) random.nextInt(MIN_INT7_VALUE, MAX_INT7_VALUE + 1);
127+
}
128+
}
129+
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.benchmark.vector;
11+
12+
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
13+
14+
import org.elasticsearch.test.ESTestCase;
15+
import org.openjdk.jmh.annotations.Param;
16+
17+
import java.util.Arrays;
18+
19+
public class JDKVectorInt7uBenchmarkTests extends ESTestCase {
20+
21+
final double delta = 1e-3;
22+
final int size;
23+
24+
public JDKVectorInt7uBenchmarkTests(int size) {
25+
this.size = size;
26+
}
27+
28+
public void testDotProduct() {
29+
for (int i = 0; i < 100; i++) {
30+
var bench = new JDKVectorInt7uBenchmark();
31+
bench.size = size;
32+
bench.init();
33+
try {
34+
float expected = dotProductScalar(bench.byteArrayA, bench.byteArrayB);
35+
assertEquals(expected, bench.dotProductLucene(), delta);
36+
assertEquals(expected, bench.dotProductNativeWithHeapSeg(), delta);
37+
assertEquals(expected, bench.dotProductNativeWithNativeSeg(), delta);
38+
} finally {
39+
bench.teardown();
40+
}
41+
}
42+
}
43+
44+
@ParametersFactory
45+
public static Iterable<Object[]> parametersFactory() {
46+
try {
47+
var params = JDKVectorInt7uBenchmark.class.getField("size").getAnnotationsByType(Param.class)[0].value();
48+
return () -> Arrays.stream(params).map(Integer::parseInt).map(i -> new Object[] { i }).iterator();
49+
} catch (NoSuchFieldException e) {
50+
throw new AssertionError(e);
51+
}
52+
}
53+
54+
/** Computes the dot product of the given vectors a and b. */
55+
static int dotProductScalar(byte[] a, byte[] b) {
56+
int res = 0;
57+
for (int i = 0; i < a.length; i++) {
58+
res += a[i] * b[i];
59+
}
60+
return res;
61+
}
62+
}

libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.lang.invoke.MethodHandle;
2121
import java.lang.invoke.MethodHandles;
2222
import java.lang.invoke.MethodType;
23+
import java.util.Objects;
2324

2425
import static java.lang.foreign.ValueLayout.ADDRESS;
2526
import static java.lang.foreign.ValueLayout.JAVA_INT;
@@ -99,13 +100,8 @@ private static final class JdkVectorSimilarityFunctions implements VectorSimilar
99100
* @param length the vector dimensions
100101
*/
101102
static int dotProduct7u(MemorySegment a, MemorySegment b, int length) {
102-
assert length >= 0;
103-
if (a.byteSize() != b.byteSize()) {
104-
throw new IllegalArgumentException("dimensions differ: " + a.byteSize() + "!=" + b.byteSize());
105-
}
106-
if (length > a.byteSize()) {
107-
throw new IllegalArgumentException("length: " + length + ", greater than vector dimensions: " + a.byteSize());
108-
}
103+
checkByteSize(a, b);
104+
Objects.checkFromIndexSize(0, length, (int) a.byteSize());
109105
return dot7u(a, b, length);
110106
}
111107

@@ -119,14 +115,15 @@ static int dotProduct7u(MemorySegment a, MemorySegment b, int length) {
119115
* @param length the vector dimensions
120116
*/
121117
static int squareDistance7u(MemorySegment a, MemorySegment b, int length) {
122-
assert length >= 0;
118+
checkByteSize(a, b);
119+
Objects.checkFromIndexSize(0, length, (int) a.byteSize());
120+
return sqr7u(a, b, length);
121+
}
122+
123+
static void checkByteSize(MemorySegment a, MemorySegment b) {
123124
if (a.byteSize() != b.byteSize()) {
124125
throw new IllegalArgumentException("dimensions differ: " + a.byteSize() + "!=" + b.byteSize());
125126
}
126-
if (length > a.byteSize()) {
127-
throw new IllegalArgumentException("length: " + length + ", greater than vector dimensions: " + a.byteSize());
128-
}
129-
return sqr7u(a, b, length);
130127
}
131128

132129
private static int dot7u(MemorySegment a, MemorySegment b, int length) {

libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryTests.java

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,19 @@ public class JDKVectorLibraryTests extends VectorSimilarityFunctionsTests {
2828
static final byte MAX_INT7_VALUE = 127;
2929

3030
static final Class<IllegalArgumentException> IAE = IllegalArgumentException.class;
31+
static final Class<IndexOutOfBoundsException> IOOBE = IndexOutOfBoundsException.class;
3132

3233
static final int[] VECTOR_DIMS = { 1, 4, 6, 8, 13, 16, 25, 31, 32, 33, 64, 100, 128, 207, 256, 300, 512, 702, 1023, 1024, 1025 };
3334

3435
final int size;
3536

3637
static Arena arena;
3738

39+
final double delta;
40+
3841
public JDKVectorLibraryTests(int size) {
3942
this.size = size;
43+
this.delta = 1e-5 * size; // scale the delta with the size
4044
}
4145

4246
@BeforeClass
@@ -103,11 +107,24 @@ static boolean testWithHeapSegments() {
103107
public void testIllegalDims() {
104108
assumeTrue(notSupportedMsg(), supported());
105109
var segment = arena.allocate((long) size * 3);
106-
var e = expectThrows(IAE, () -> dotProduct7u(segment.asSlice(0L, size), segment.asSlice(size, size + 1), size));
107-
assertThat(e.getMessage(), containsString("dimensions differ"));
108110

109-
e = expectThrows(IAE, () -> dotProduct7u(segment.asSlice(0L, size), segment.asSlice(size, size), size + 1));
110-
assertThat(e.getMessage(), containsString("greater than vector dimensions"));
111+
var e1 = expectThrows(IAE, () -> dotProduct7u(segment.asSlice(0L, size), segment.asSlice(size, size + 1), size));
112+
assertThat(e1.getMessage(), containsString("dimensions differ"));
113+
114+
var e2 = expectThrows(IOOBE, () -> dotProduct7u(segment.asSlice(0L, size), segment.asSlice(size, size), size + 1));
115+
assertThat(e2.getMessage(), containsString("out of bounds for length"));
116+
117+
var e3 = expectThrows(IOOBE, () -> dotProduct7u(segment.asSlice(0L, size), segment.asSlice(size, size), -1));
118+
assertThat(e3.getMessage(), containsString("out of bounds for length"));
119+
120+
var e4 = expectThrows(IAE, () -> squareDistance7u(segment.asSlice(0L, size), segment.asSlice(size, size + 1), size));
121+
assertThat(e4.getMessage(), containsString("dimensions differ"));
122+
123+
var e5 = expectThrows(IOOBE, () -> squareDistance7u(segment.asSlice(0L, size), segment.asSlice(size, size), size + 1));
124+
assertThat(e5.getMessage(), containsString("out of bounds for length"));
125+
126+
var e6 = expectThrows(IOOBE, () -> squareDistance7u(segment.asSlice(0L, size), segment.asSlice(size, size), -1));
127+
assertThat(e6.getMessage(), containsString("out of bounds for length"));
111128
}
112129

113130
int dotProduct7u(MemorySegment a, MemorySegment b, int length) {

libs/simdvec/native/build.gradle

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@ var os = org.gradle.internal.os.OperatingSystem.current()
2323
// 1. Temporarily comment out the download in libs/native/library/build.gradle
2424
// libs "org.elasticsearch:vec:${vecVersion}@zip"
2525
// 2. Copy your locally built libvec binary, e.g.
26-
// cp libs/vec/native/build/libs/vec/shared/libvec.dylib libs/native/libraries/build/platform/darwin-aarch64/libvec.dylib
26+
// cp libs/simdvec/native/build/libs/vec/shared/aarch64/libvec.dylib libs/native/libraries/build/platform/darwin-aarch64/libvec.dylib
2727
//
2828
// Look at the disassembly:
29-
// objdump --disassemble-symbols=_dot8s build/libs/vec/shared/libvec.dylib
29+
// objdump --disassemble-symbols=_dot7u build/libs/vec/shared/aarch64/libvec.dylib
3030
// Note: symbol decoration may differ on Linux, i.e. the leading underscore is not present
3131
//
3232
// gcc -shared -fpic -o libvec.so -I src/vec/headers/ src/vec/c/vec.c -O3

0 commit comments

Comments
 (0)