Skip to content

Commit c609cd1

Browse files
committed
Merge remote-tracking branch 'elastic/main' into fix-values-aggregator
2 parents 4c44fa1 + ef9f544 commit c609cd1

File tree

45 files changed

+1748
-275
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+1748
-275
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/vector/VectorScorerBenchmark.java renamed to benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Int7uScorerBenchmark.java

Lines changed: 5 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -55,24 +55,23 @@
5555
/**
5656
* Benchmark that compares various scalar quantized vector similarity function
5757
* implementations: scalar, lucene's panama-ized, and Elasticsearch's native.
58-
* Run with ./gradlew -p benchmarks run --args 'VectorScorerBenchmark'
58+
* Run with ./gradlew -p benchmarks run --args 'Int7uScorerBenchmark'
5959
*/
60-
public class VectorScorerBenchmark {
60+
public class Int7uScorerBenchmark {
6161

6262
static {
6363
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
6464
}
6565

6666
@Param({ "96", "768", "1024" })
67-
int dims;
68-
int size = 2; // there are only two vectors to compare
67+
public int dims;
68+
final int size = 2; // there are only two vectors to compare
6969

7070
Directory dir;
7171
IndexInput in;
7272
VectorScorerFactory factory;
7373

74-
byte[] vec1;
75-
byte[] vec2;
74+
byte[] vec1, vec2;
7675
float vec1Offset;
7776
float vec2Offset;
7877
float scoreCorrectionConstant;
@@ -139,39 +138,6 @@ public void setup() throws IOException {
139138
nativeDotScorerQuery = factory.getInt7SQVectorScorer(VectorSimilarityFunction.DOT_PRODUCT, values, queryVec).get();
140139
luceneSqrScorerQuery = luceneScorer(values, VectorSimilarityFunction.EUCLIDEAN, queryVec);
141140
nativeSqrScorerQuery = factory.getInt7SQVectorScorer(VectorSimilarityFunction.EUCLIDEAN, values, queryVec).get();
142-
143-
// sanity
144-
var f1 = dotProductLucene();
145-
var f2 = dotProductNative();
146-
var f3 = dotProductScalar();
147-
if (f1 != f2) {
148-
throw new AssertionError("lucene[" + f1 + "] != " + "native[" + f2 + "]");
149-
}
150-
if (f1 != f3) {
151-
throw new AssertionError("lucene[" + f1 + "] != " + "scalar[" + f3 + "]");
152-
}
153-
// square distance
154-
f1 = squareDistanceLucene();
155-
f2 = squareDistanceNative();
156-
f3 = squareDistanceScalar();
157-
if (f1 != f2) {
158-
throw new AssertionError("lucene[" + f1 + "] != " + "native[" + f2 + "]");
159-
}
160-
if (f1 != f3) {
161-
throw new AssertionError("lucene[" + f1 + "] != " + "scalar[" + f3 + "]");
162-
}
163-
164-
var q1 = dotProductLuceneQuery();
165-
var q2 = dotProductNativeQuery();
166-
if (q1 != q2) {
167-
throw new AssertionError("query: lucene[" + q1 + "] != " + "native[" + q2 + "]");
168-
}
169-
170-
var sqr1 = squareDistanceLuceneQuery();
171-
var sqr2 = squareDistanceNativeQuery();
172-
if (sqr1 != sqr2) {
173-
throw new AssertionError("query: lucene[" + q1 + "] != " + "native[" + q2 + "]");
174-
}
175141
}
176142

177143
@TearDown
Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
package org.elasticsearch.benchmark.vector;
10+
11+
import org.apache.lucene.util.VectorUtil;
12+
import org.elasticsearch.common.logging.LogConfigurator;
13+
import org.elasticsearch.common.logging.NodeNamePatternConverter;
14+
import org.elasticsearch.nativeaccess.NativeAccess;
15+
import org.elasticsearch.nativeaccess.VectorSimilarityFunctions;
16+
import org.openjdk.jmh.annotations.Benchmark;
17+
import org.openjdk.jmh.annotations.BenchmarkMode;
18+
import org.openjdk.jmh.annotations.Fork;
19+
import org.openjdk.jmh.annotations.Level;
20+
import org.openjdk.jmh.annotations.Measurement;
21+
import org.openjdk.jmh.annotations.Mode;
22+
import org.openjdk.jmh.annotations.OutputTimeUnit;
23+
import org.openjdk.jmh.annotations.Param;
24+
import org.openjdk.jmh.annotations.Scope;
25+
import org.openjdk.jmh.annotations.Setup;
26+
import org.openjdk.jmh.annotations.State;
27+
import org.openjdk.jmh.annotations.TearDown;
28+
import org.openjdk.jmh.annotations.Warmup;
29+
30+
import java.lang.foreign.Arena;
31+
import java.lang.foreign.MemorySegment;
32+
import java.lang.foreign.ValueLayout;
33+
import java.nio.ByteOrder;
34+
import java.util.concurrent.ThreadLocalRandom;
35+
import java.util.concurrent.TimeUnit;
36+
37+
@BenchmarkMode(Mode.AverageTime)
38+
@OutputTimeUnit(TimeUnit.NANOSECONDS)
39+
@State(Scope.Benchmark)
40+
@Warmup(iterations = 3, time = 1)
41+
@Measurement(iterations = 5, time = 1)
42+
public class JDKVectorFloat32Benchmark {
43+
44+
static {
45+
NodeNamePatternConverter.setGlobalNodeName("foo");
46+
LogConfigurator.loadLog4jPlugins();
47+
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
48+
}
49+
50+
static final ValueLayout.OfFloat LAYOUT_LE_FLOAT = ValueLayout.JAVA_FLOAT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
51+
52+
float[] floatsA;
53+
float[] floatsB;
54+
float[] scratch;
55+
MemorySegment heapSegA, heapSegB;
56+
MemorySegment nativeSegA, nativeSegB;
57+
58+
Arena arena;
59+
60+
@Param({ "1", "128", "207", "256", "300", "512", "702", "1024", "1536", "2048" })
61+
public int size;
62+
63+
@Setup(Level.Iteration)
64+
public void init() {
65+
ThreadLocalRandom random = ThreadLocalRandom.current();
66+
67+
floatsA = new float[size];
68+
floatsB = new float[size];
69+
scratch = new float[size];
70+
for (int i = 0; i < size; ++i) {
71+
floatsA[i] = random.nextFloat();
72+
floatsB[i] = random.nextFloat();
73+
}
74+
heapSegA = MemorySegment.ofArray(floatsA);
75+
heapSegB = MemorySegment.ofArray(floatsB);
76+
77+
arena = Arena.ofConfined();
78+
nativeSegA = arena.allocate((long) floatsA.length * Float.BYTES);
79+
MemorySegment.copy(MemorySegment.ofArray(floatsA), LAYOUT_LE_FLOAT, 0L, nativeSegA, LAYOUT_LE_FLOAT, 0L, floatsA.length);
80+
nativeSegB = arena.allocate((long) floatsB.length * Float.BYTES);
81+
MemorySegment.copy(MemorySegment.ofArray(floatsB), LAYOUT_LE_FLOAT, 0L, nativeSegB, LAYOUT_LE_FLOAT, 0L, floatsB.length);
82+
}
83+
84+
@TearDown
85+
public void teardown() {
86+
arena.close();
87+
}
88+
89+
// -- cosine
90+
91+
@Benchmark
92+
@Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
93+
public float cosineLucene() {
94+
return VectorUtil.cosine(floatsA, floatsB);
95+
}
96+
97+
@Benchmark
98+
@Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
99+
public float cosineLuceneWithCopy() {
100+
// add a copy to better reflect what Lucene has to do to get the target vector on-heap
101+
MemorySegment.copy(nativeSegB, LAYOUT_LE_FLOAT, 0L, scratch, 0, scratch.length);
102+
return VectorUtil.cosine(floatsA, scratch);
103+
}
104+
105+
@Benchmark
106+
@Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
107+
public float cosineNativeWithNativeSeg() {
108+
return cosineFloat32(nativeSegA, nativeSegB, size);
109+
}
110+
111+
@Benchmark
112+
@Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
113+
public float cosineNativeWithHeapSeg() {
114+
return cosineFloat32(heapSegA, heapSegB, size);
115+
}
116+
117+
// -- dot product
118+
119+
@Benchmark
120+
@Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
121+
public float dotProductLucene() {
122+
return VectorUtil.dotProduct(floatsA, floatsB);
123+
}
124+
125+
@Benchmark
126+
@Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
127+
public float dotProductLuceneWithCopy() {
128+
// add a copy to better reflect what Lucene has to do to get the target vector on-heap
129+
MemorySegment.copy(nativeSegB, LAYOUT_LE_FLOAT, 0L, scratch, 0, scratch.length);
130+
return VectorUtil.dotProduct(floatsA, scratch);
131+
}
132+
133+
@Benchmark
134+
@Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
135+
public float dotProductNativeWithNativeSeg() {
136+
return dotProductFloat32(nativeSegA, nativeSegB, size);
137+
}
138+
139+
@Benchmark
140+
@Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
141+
public float dotProductNativeWithHeapSeg() {
142+
return dotProductFloat32(heapSegA, heapSegB, size);
143+
}
144+
145+
// -- square distance
146+
147+
@Benchmark
148+
@Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
149+
public float squareDistanceLucene() {
150+
return VectorUtil.squareDistance(floatsA, floatsB);
151+
}
152+
153+
@Benchmark
154+
@Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
155+
public float squareDistanceLuceneWithCopy() {
156+
// add a copy to better reflect what Lucene has to do to get the target vector on-heap
157+
MemorySegment.copy(nativeSegB, LAYOUT_LE_FLOAT, 0L, scratch, 0, scratch.length);
158+
return VectorUtil.squareDistance(floatsA, scratch);
159+
}
160+
161+
@Benchmark
162+
@Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
163+
public float squareDistanceNativeWithNativeSeg() {
164+
return squareDistanceFloat32(nativeSegA, nativeSegB, size);
165+
}
166+
167+
@Benchmark
168+
@Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
169+
public float squareDistanceNativeWithHeapSeg() {
170+
return squareDistanceFloat32(heapSegA, heapSegB, size);
171+
}
172+
173+
static final VectorSimilarityFunctions vectorSimilarityFunctions = vectorSimilarityFunctions();
174+
175+
static VectorSimilarityFunctions vectorSimilarityFunctions() {
176+
return NativeAccess.instance().getVectorSimilarityFunctions().get();
177+
}
178+
179+
float cosineFloat32(MemorySegment a, MemorySegment b, int length) {
180+
try {
181+
return (float) vectorSimilarityFunctions.cosineHandleFloat32().invokeExact(a, b, length);
182+
} catch (Throwable e) {
183+
if (e instanceof Error err) {
184+
throw err;
185+
} else if (e instanceof RuntimeException re) {
186+
throw re;
187+
} else {
188+
throw new RuntimeException(e);
189+
}
190+
}
191+
}
192+
193+
float dotProductFloat32(MemorySegment a, MemorySegment b, int length) {
194+
try {
195+
return (float) vectorSimilarityFunctions.dotProductHandleFloat32().invokeExact(a, b, length);
196+
} catch (Throwable e) {
197+
if (e instanceof Error err) {
198+
throw err;
199+
} else if (e instanceof RuntimeException re) {
200+
throw re;
201+
} else {
202+
throw new RuntimeException(e);
203+
}
204+
}
205+
}
206+
207+
float squareDistanceFloat32(MemorySegment a, MemorySegment b, int length) {
208+
try {
209+
return (float) vectorSimilarityFunctions.squareDistanceHandleFloat32().invokeExact(a, b, length);
210+
} catch (Throwable e) {
211+
if (e instanceof Error err) {
212+
throw err;
213+
} else if (e instanceof RuntimeException re) {
214+
throw re;
215+
} else {
216+
throw new RuntimeException(e);
217+
}
218+
}
219+
}
220+
}

benchmarks/src/main/java/org/elasticsearch/benchmark/vector/JDKVectorInt7uBenchmark.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public class JDKVectorInt7uBenchmark {
5252

5353
Arena arena;
5454

55-
@Param({ "1", "128", "207", "256", "300", "512", "702", "1024" })
55+
@Param({ "1", "128", "207", "256", "300", "512", "702", "1024", "1536", "2048" })
5656
public int size;
5757

5858
@Setup(Level.Iteration)
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.benchmark.vector;
11+
12+
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
13+
14+
import org.apache.lucene.util.Constants;
15+
import org.elasticsearch.test.ESTestCase;
16+
import org.junit.BeforeClass;
17+
import org.openjdk.jmh.annotations.Param;
18+
19+
import java.util.Arrays;
20+
21+
public class Int7uScorerBenchmarkTests extends ESTestCase {
22+
23+
final double delta = 1e-3;
24+
final int dims;
25+
26+
public Int7uScorerBenchmarkTests(int dims) {
27+
this.dims = dims;
28+
}
29+
30+
@BeforeClass
31+
public static void skipWindows() {
32+
assumeFalse("doesn't work on windows yet", Constants.WINDOWS);
33+
}
34+
35+
public void testDotProduct() throws Exception {
36+
for (int i = 0; i < 100; i++) {
37+
var bench = new Int7uScorerBenchmark();
38+
bench.dims = dims;
39+
bench.setup();
40+
try {
41+
float expected = bench.dotProductScalar();
42+
assertEquals(expected, bench.dotProductLucene(), delta);
43+
assertEquals(expected, bench.dotProductNative(), delta);
44+
45+
expected = bench.dotProductLuceneQuery();
46+
assertEquals(expected, bench.dotProductNativeQuery(), delta);
47+
} finally {
48+
bench.teardown();
49+
}
50+
}
51+
}
52+
53+
public void testSquareDistance() throws Exception {
54+
for (int i = 0; i < 100; i++) {
55+
var bench = new Int7uScorerBenchmark();
56+
bench.dims = dims;
57+
bench.setup();
58+
try {
59+
float expected = bench.squareDistanceScalar();
60+
assertEquals(expected, bench.squareDistanceLucene(), delta);
61+
assertEquals(expected, bench.squareDistanceNative(), delta);
62+
63+
expected = bench.squareDistanceLuceneQuery();
64+
assertEquals(expected, bench.squareDistanceNativeQuery(), delta);
65+
} finally {
66+
bench.teardown();
67+
}
68+
}
69+
}
70+
71+
@ParametersFactory
72+
public static Iterable<Object[]> parametersFactory() {
73+
try {
74+
var params = Int7uScorerBenchmark.class.getField("dims").getAnnotationsByType(Param.class)[0].value();
75+
return () -> Arrays.stream(params).map(Integer::parseInt).map(i -> new Object[] { i }).iterator();
76+
} catch (NoSuchFieldException e) {
77+
throw new AssertionError(e);
78+
}
79+
}
80+
}

0 commit comments

Comments
 (0)