Skip to content

Commit 98b00be

Browse files
Merge branch 'main' into poc_runtime_function_evaluators_code_only
2 parents 7713043 + a9777af commit 98b00be

File tree

10 files changed

+268
-226
lines changed

10 files changed

+268
-226
lines changed

benchmarks/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ dependencies {
5353
api(project(':x-pack:plugin:logsdb'))
5454
implementation project(path: ':libs:native')
5555
implementation project(path: ':libs:simdvec')
56+
implementation (testFixtures(project(path: ':libs:simdvec')))
5657
implementation project(path: ':libs:swisshash')
5758
implementation project(path: ':libs:exponential-histogram')
5859
implementation(project(':x-pack:plugin:searchable-snapshots')) {

benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerOSQBenchmark.java

Lines changed: 36 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@
1616
import org.apache.lucene.store.IndexOutput;
1717
import org.apache.lucene.store.MMapDirectory;
1818
import org.apache.lucene.store.NIOFSDirectory;
19-
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
19+
import org.apache.lucene.util.VectorUtil;
2020
import org.elasticsearch.common.logging.LogConfigurator;
2121
import org.elasticsearch.core.IOUtils;
2222
import org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat;
2323
import org.elasticsearch.simdvec.ESNextOSQVectorsScorer;
2424
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
25+
import org.elasticsearch.simdvec.internal.vectorization.VectorScorerTestUtils;
2526
import org.elasticsearch.xpack.searchablesnapshots.store.SearchableSnapshotDirectoryFactory;
2627
import org.openjdk.jmh.annotations.Benchmark;
2728
import org.openjdk.jmh.annotations.BenchmarkMode;
@@ -42,6 +43,11 @@
4243
import java.util.Random;
4344
import java.util.concurrent.TimeUnit;
4445

46+
import static org.elasticsearch.simdvec.internal.vectorization.VectorScorerTestUtils.createOSQIndexData;
47+
import static org.elasticsearch.simdvec.internal.vectorization.VectorScorerTestUtils.createOSQQueryData;
48+
import static org.elasticsearch.simdvec.internal.vectorization.VectorScorerTestUtils.randomVector;
49+
import static org.elasticsearch.simdvec.internal.vectorization.VectorScorerTestUtils.writeBulkOSQVectorData;
50+
4551
@BenchmarkMode(Mode.Throughput)
4652
@OutputTimeUnit(TimeUnit.MILLISECONDS)
4753
@State(Scope.Benchmark)
@@ -72,7 +78,7 @@ public enum VectorImplementation {
7278
public int dims;
7379

7480
@Param({ "1", "2", "4" })
75-
public int bits;
81+
public byte bits;
7682

7783
int bulkSize = ESNextOSQVectorsScorer.BULK_SIZE;
7884

@@ -90,9 +96,7 @@ public enum VectorImplementation {
9096

9197
int length;
9298

93-
byte[][] binaryVectors;
94-
byte[][] binaryQueries;
95-
OptimizedScalarQuantizer.QuantizationResult result;
99+
VectorScorerTestUtils.OSQVectorData[] binaryQueries;
96100
float centroidDp;
97101

98102
byte[] scratch;
@@ -111,17 +115,12 @@ public void setup() throws IOException {
111115
}
112116

113117
void setup(Random random) throws IOException {
114-
this.length = switch (bits) {
115-
case 1 -> ESNextDiskBBQVectorsFormat.QuantEncoding.ONE_BIT_4BIT_QUERY.getDocPackedLength(dims);
116-
case 2 -> ESNextDiskBBQVectorsFormat.QuantEncoding.TWO_BIT_4BIT_QUERY.getDocPackedLength(dims);
117-
case 4 -> ESNextDiskBBQVectorsFormat.QuantEncoding.FOUR_BIT_SYMMETRIC.getDocPackedLength(dims);
118-
default -> throw new IllegalArgumentException("Unsupported bits: " + bits);
119-
};
118+
this.length = ESNextDiskBBQVectorsFormat.QuantEncoding.fromBits(bits).getDocPackedLength(dims);
120119

121-
binaryVectors = new byte[numVectors][length];
122-
for (byte[] binaryVector : binaryVectors) {
123-
random.nextBytes(binaryVector);
124-
}
120+
final float[] centroid = new float[dims];
121+
randomVector(random, centroid, similarityFunction);
122+
123+
var quantizer = new org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer(similarityFunction);
125124

126125
directory = switch (directoryType) {
127126
case MMAP -> new MMapDirectory(createTempDirectory("vectorDataMmap"));
@@ -130,35 +129,27 @@ void setup(Random random) throws IOException {
130129
};
131130

132131
try (IndexOutput output = directory.createOutput("vectors", IOContext.DEFAULT)) {
133-
byte[] correctionBytes = new byte[16 * bulkSize];
132+
VectorScorerTestUtils.OSQVectorData[] vectors = new VectorScorerTestUtils.OSQVectorData[bulkSize];
134133
for (int i = 0; i < numVectors; i += bulkSize) {
135134
for (int j = 0; j < bulkSize; j++) {
136-
output.writeBytes(binaryVectors[i + j], 0, binaryVectors[i + j].length);
135+
var vector = new float[dims];
136+
randomVector(random, vector, similarityFunction);
137+
vectors[j] = createOSQIndexData(vector, centroid, quantizer, dims, bits, length);
137138
}
138-
random.nextBytes(correctionBytes);
139-
output.writeBytes(correctionBytes, 0, correctionBytes.length);
139+
writeBulkOSQVectorData(bulkSize, output, vectors);
140140
}
141141
CodecUtil.writeFooter(output);
142142
}
143143
input = directory.openInput("vectors", IOContext.DEFAULT);
144-
int binaryQueryLength = switch (bits) {
145-
case 1 -> ESNextDiskBBQVectorsFormat.QuantEncoding.ONE_BIT_4BIT_QUERY.getQueryPackedLength(dims);
146-
case 2 -> ESNextDiskBBQVectorsFormat.QuantEncoding.TWO_BIT_4BIT_QUERY.getQueryPackedLength(dims);
147-
case 4 -> ESNextDiskBBQVectorsFormat.QuantEncoding.FOUR_BIT_SYMMETRIC.getQueryPackedLength(dims);
148-
default -> throw new IllegalArgumentException("Unsupported bits: " + bits);
149-
};
144+
int binaryQueryLength = ESNextDiskBBQVectorsFormat.QuantEncoding.fromBits(bits).getQueryPackedLength(dims);
150145

151-
binaryQueries = new byte[numVectors][binaryQueryLength];
152-
for (byte[] binaryQuery : binaryQueries) {
153-
random.nextBytes(binaryQuery);
146+
binaryQueries = new VectorScorerTestUtils.OSQVectorData[numVectors];
147+
var query = new float[dims];
148+
for (int i = 0; i < numVectors; ++i) {
149+
randomVector(random, query, similarityFunction);
150+
binaryQueries[i] = createOSQQueryData(query, centroid, quantizer, dims, (byte) 4, binaryQueryLength);
154151
}
155-
result = new OptimizedScalarQuantizer.QuantizationResult(
156-
random.nextFloat(),
157-
random.nextFloat(),
158-
random.nextFloat(),
159-
Short.toUnsignedInt((short) random.nextInt())
160-
);
161-
centroidDp = random.nextFloat();
152+
centroidDp = VectorUtil.dotProduct(centroid, centroid);
162153

163154
scratch = new byte[length];
164155
final int docBits;
@@ -202,14 +193,14 @@ public float[] score() throws IOException {
202193
for (int j = 0; j < numQueries; j++) {
203194
input.seek(0);
204195
for (int i = 0; i < numVectors; i++) {
205-
float qDist = scorer.quantizeScore(binaryQueries[j]);
196+
float qDist = scorer.quantizeScore(binaryQueries[j].quantizedVector());
206197
input.readFloats(corrections, 0, corrections.length);
207198
int addition = Short.toUnsignedInt(input.readShort());
208199
float score = scorer.score(
209-
result.lowerInterval(),
210-
result.upperInterval(),
211-
result.quantizedComponentSum(),
212-
result.additionalCorrection(),
200+
binaryQueries[j].lowerInterval(),
201+
binaryQueries[j].upperInterval(),
202+
binaryQueries[j].quantizedComponentSum(),
203+
binaryQueries[j].additionalCorrection(),
213204
similarityFunction,
214205
centroidDp,
215206
corrections[0],
@@ -231,11 +222,11 @@ public float[] bulkScore() throws IOException {
231222
input.seek(0);
232223
for (int i = 0; i < numVectors; i += scratchScores.length) {
233224
scorer.scoreBulk(
234-
binaryQueries[j],
235-
result.lowerInterval(),
236-
result.upperInterval(),
237-
result.quantizedComponentSum(),
238-
result.additionalCorrection(),
225+
binaryQueries[j].quantizedVector(),
226+
binaryQueries[j].lowerInterval(),
227+
binaryQueries[j].upperInterval(),
228+
binaryQueries[j].quantizedComponentSum(),
229+
binaryQueries[j].additionalCorrection(),
239230
similarityFunction,
240231
centroidDp,
241232
scratchScores

benchmarks/src/test/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerOSQBenchmarkTests.java

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.apache.lucene.util.Constants;
1616
import org.elasticsearch.core.IOUtils;
1717
import org.elasticsearch.test.ESTestCase;
18+
import org.elasticsearch.test.junit.annotations.TestLogging;
1819
import org.junit.BeforeClass;
1920
import org.openjdk.jmh.annotations.Param;
2021

@@ -24,17 +25,22 @@
2425

2526
import static org.elasticsearch.common.util.CollectionUtils.appendToCopy;
2627

28+
@TestLogging(
29+
reason = "Noisy logging",
30+
value = "org.elasticsearch.env.NodeEnvironment:WARN,org.elasticsearch.xpack.searchablesnapshots.cache.full.PersistentCache:WARN"
31+
)
2732
public class VectorScorerOSQBenchmarkTests extends ESTestCase {
2833

34+
private static final int REPETITIONS = 10;
2935
private final float deltaPercent = 0.1f;
3036
private final int dims;
31-
private final int bits;
37+
private final byte bits;
3238
private final VectorScorerOSQBenchmark.DirectoryType directoryType;
3339
private final VectorSimilarityFunction similarityFunction;
3440

3541
public VectorScorerOSQBenchmarkTests(
3642
int dims,
37-
int bits,
43+
byte bits,
3844
VectorScorerOSQBenchmark.DirectoryType directoryType,
3945
VectorSimilarityFunction similarityFunction
4046
) {
@@ -50,7 +56,7 @@ public static void skipWindows() {
5056
}
5157

5258
public void testSingleScalarVsVectorized() throws Exception {
53-
for (int i = 0; i < 100; i++) {
59+
for (int i = 0; i < REPETITIONS; i++) {
5460
var seed = randomLong();
5561

5662
var scalar = new VectorScorerOSQBenchmark();
@@ -85,7 +91,7 @@ public void testSingleScalarVsVectorized() throws Exception {
8591
}
8692

8793
public void testBulkScalarVsVectorized() throws Exception {
88-
for (int i = 0; i < 100; i++) {
94+
for (int i = 0; i < REPETITIONS; i++) {
8995
var seed = randomLong();
9096

9197
var scalar = new VectorScorerOSQBenchmark();
@@ -128,7 +134,7 @@ public static Iterable<Object[]> parametersFactory() {
128134

129135
return () -> Arrays.stream(dims)
130136
.map(Integer::parseInt)
131-
.flatMap(d -> Arrays.stream(bits).map(Integer::parseInt).map(b -> List.<Object>of(d, b)))
137+
.flatMap(d -> Arrays.stream(bits).map(Byte::parseByte).map(b -> List.<Object>of(d, b)))
132138
.flatMap(params -> Arrays.stream(VectorScorerOSQBenchmark.DirectoryType.values()).map(dir -> appendToCopy(params, dir)))
133139
.flatMap(params -> Arrays.stream(VectorSimilarityFunction.values()).map(f -> appendToCopy(params, f).toArray()))
134140
.iterator();

libs/simdvec/build.gradle

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,20 @@ import org.elasticsearch.gradle.internal.precommit.CheckForbiddenApisTask
1212
apply plugin: 'elasticsearch.publish'
1313
apply plugin: 'elasticsearch.build'
1414
apply plugin: 'elasticsearch.mrjar'
15+
apply plugin: 'java-test-fixtures'
1516

1617
dependencies {
1718
implementation project(':libs:native')
1819
implementation project(':libs:logging')
1920
implementation "org.apache.lucene:lucene-core:${versions.lucene}"
2021

22+
testImplementation(testArtifact(project(':x-pack:plugin:searchable-snapshots')))
2123
testImplementation(project(":test:framework")) {
2224
exclude group: 'org.elasticsearch', module: 'native'
2325
}
24-
testImplementation(testArtifact(project(':x-pack:plugin:searchable-snapshots')))
26+
testFixturesImplementation(project(":test:framework")) {
27+
exclude group: 'org.elasticsearch', module: 'native'
28+
}
2529
}
2630

2731
// compileMain21Java does not exist within idea (see MrJarPlugin) so we cannot reference directly by name

libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ public void testScore() throws Exception {
120120
long qDist = defaultScorer.quantizeScore(quantizeQuery);
121121
slice.readFloats(floatScratch, 0, 3);
122122
int quantizedComponentSum = slice.readShort();
123-
float defaulScore = defaultScorer.score(
123+
float defaultScore = defaultScorer.score(
124124
queryCorrections.lowerInterval(),
125125
queryCorrections.upperInterval(),
126126
queryCorrections.quantizedComponentSum(),
@@ -149,7 +149,7 @@ public void testScore() throws Exception {
149149
floatScratch[2],
150150
qDist
151151
);
152-
assertEquals(defaulScore, panamaScore, 1e-2f);
152+
assertEquals(defaultScore, panamaScore, 1e-2f);
153153
assertEquals(((long) (i + 1) * (length + 14)), slice.getFilePointer());
154154
assertEquals(padding + ((long) (i + 1) * (length + 14)), in.getFilePointer());
155155
}
@@ -234,9 +234,7 @@ public void testScoreBulk() throws Exception {
234234
scoresPanama
235235
);
236236
assertEquals(defaultMaxScore, panamaMaxScore, 1e-2f);
237-
for (int j = 0; j < bulkSize; j++) {
238-
assertEquals(scoresDefault[j], scoresPanama[j], 1e-2f);
239-
}
237+
assertArrayEqualsPercent(scoresDefault, scoresPanama, 0.05f, 1e-2f);
240238
assertEquals(((long) (bulkSize) * (length + 14)), slice.getFilePointer());
241239
assertEquals(padding + ((long) (i + bulkSize) * (length + 14)), in.getFilePointer());
242240
}

0 commit comments

Comments
 (0)