Skip to content

Commit ca3183f

Browse files
committed
switch to 32 bulk size
1 parent 7e75746 commit ca3183f

File tree

10 files changed: 86 additions, 84 deletions.

benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java

Lines changed: 16 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@
4343
@OutputTimeUnit(TimeUnit.MILLISECONDS)
4444
@State(Scope.Benchmark)
4545
// The first iteration is complete garbage, so make sure we really warm up.
46-
@Warmup(iterations = 4, time = 1)
46+
@Warmup(iterations = 3, time = 1)
4747
// real iterations. not useful to spend tons of time here, better to fork more
4848
@Measurement(iterations = 5, time = 1)
4949
// engage some noise reduction
@@ -54,13 +54,16 @@ public class OSQScorerBenchmark {
5454
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
5555
}
5656

57+
@Param({ "16", "32", "64", "128", "256" })
58+
int bulkSize;
59+
5760
@Param({ "384", "782", "1024" })
5861
int dims;
5962

6063
int length;
6164

62-
int numVectors = ES91OSQVectorsScorer.BULK_SIZE * 10;
63-
int numQueries = 10;
65+
final int numVectors = 16 * 64;
66+
final int numQueries = 10;
6467

6568
byte[][] binaryVectors;
6669
byte[][] binaryQueries;
@@ -83,7 +86,6 @@ public class OSQScorerBenchmark {
8386
@Setup
8487
public void setup() throws IOException {
8588
Random random = new Random(123);
86-
8789
this.length = OptimizedScalarQuantizer.discretize(dims, 64) / 8;
8890

8991
binaryVectors = new byte[numVectors][length];
@@ -95,9 +97,9 @@ public void setup() throws IOException {
9597
dirNiofs = new NIOFSDirectory(Files.createTempDirectory("vectorDataNFIOS"));
9698
IndexOutput outMmap = dirMmap.createOutput("vectors", IOContext.DEFAULT);
9799
IndexOutput outNfios = dirNiofs.createOutput("vectors", IOContext.DEFAULT);
98-
byte[] correctionBytes = new byte[14 * ES91OSQVectorsScorer.BULK_SIZE];
99-
for (int i = 0; i < numVectors; i += ES91OSQVectorsScorer.BULK_SIZE) {
100-
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
100+
byte[] correctionBytes = new byte[14 * bulkSize];
101+
for (int i = 0; i < numVectors; i += bulkSize) {
102+
for (int j = 0; j < bulkSize; j++) {
101103
outMmap.writeBytes(binaryVectors[i + j], 0, binaryVectors[i + j].length);
102104
outNfios.writeBytes(binaryVectors[i + j], 0, binaryVectors[i + j].length);
103105
}
@@ -123,9 +125,9 @@ public void setup() throws IOException {
123125
centroidDp = random.nextFloat();
124126

125127
scratch = new byte[length];
126-
scorerMmap = ESVectorizationProvider.getInstance().newES91OSQVectorsScorer(inMmap, dims);
127-
scorerNfios = ESVectorizationProvider.getInstance().newES91OSQVectorsScorer(inNiofs, dims);
128-
scratchScores = new float[16];
128+
scorerMmap = ESVectorizationProvider.getInstance().newES91OSQVectorsScorer(inMmap, bulkSize, dims);
129+
scorerNfios = ESVectorizationProvider.getInstance().newES91OSQVectorsScorer(inNiofs, bulkSize, dims);
130+
scratchScores = new float[bulkSize];
129131
corrections = new float[3];
130132
}
131133

@@ -134,24 +136,18 @@ public void teardown() throws IOException {
134136
IOUtils.close(dirMmap, inMmap, dirNiofs, inNiofs);
135137
}
136138

137-
@Benchmark
138139
public void scoreFromMemorySegmentOnlyVectorMmapScalar(Blackhole bh) throws IOException {
139140
scoreFromMemorySegmentOnlyVector(bh, inMmap, scorerMmap);
140141
}
141142

142-
@Benchmark
143-
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
144143
public void scoreFromMemorySegmentOnlyVectorMmapVect(Blackhole bh) throws IOException {
145144
scoreFromMemorySegmentOnlyVector(bh, inMmap, scorerMmap);
146145
}
147146

148-
@Benchmark
149147
public void scoreFromMemorySegmentOnlyVectorNiofsScalar(Blackhole bh) throws IOException {
150148
scoreFromMemorySegmentOnlyVector(bh, inNiofs, scorerNfios);
151149
}
152150

153-
@Benchmark
154-
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
155151
public void scoreFromMemorySegmentOnlyVectorNiofsVect(Blackhole bh) throws IOException {
156152
scoreFromMemorySegmentOnlyVector(bh, inNiofs, scorerNfios);
157153
}
@@ -181,34 +177,28 @@ private void scoreFromMemorySegmentOnlyVector(Blackhole bh, IndexInput in, ES91O
181177
}
182178
}
183179

184-
@Benchmark
185180
public void scoreFromMemorySegmentOnlyVectorBulkMmapScalar(Blackhole bh) throws IOException {
186181
scoreFromMemorySegmentOnlyVectorBulk(bh, inMmap, scorerMmap);
187182
}
188183

189-
@Benchmark
190-
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
191184
public void scoreFromMemorySegmentOnlyVectorBulkMmapVect(Blackhole bh) throws IOException {
192185
scoreFromMemorySegmentOnlyVectorBulk(bh, inMmap, scorerMmap);
193186
}
194187

195-
@Benchmark
196188
public void scoreFromMemorySegmentOnlyVectorBulkNiofsScalar(Blackhole bh) throws IOException {
197189
scoreFromMemorySegmentOnlyVectorBulk(bh, inNiofs, scorerNfios);
198190
}
199191

200-
@Benchmark
201-
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
202192
public void scoreFromMemorySegmentOnlyVectorBulkNiofsVect(Blackhole bh) throws IOException {
203193
scoreFromMemorySegmentOnlyVectorBulk(bh, inNiofs, scorerNfios);
204194
}
205195

206196
private void scoreFromMemorySegmentOnlyVectorBulk(Blackhole bh, IndexInput in, ES91OSQVectorsScorer scorer) throws IOException {
207197
for (int j = 0; j < numQueries; j++) {
208198
in.seek(0);
209-
for (int i = 0; i < numVectors; i += 16) {
210-
scorer.quantizeScoreBulk(binaryQueries[j], ES91OSQVectorsScorer.BULK_SIZE, scratchScores);
211-
for (int k = 0; k < ES91OSQVectorsScorer.BULK_SIZE; k++) {
199+
for (int i = 0; i < numVectors; i += bulkSize) {
200+
scorer.quantizeScoreBulk(binaryQueries[j], bulkSize, scratchScores);
201+
for (int k = 0; k < bulkSize; k++) {
212202
in.readFloats(corrections, 0, corrections.length);
213203
int addition = Short.toUnsignedInt(in.readShort());
214204
float score = scorer.score(
@@ -230,7 +220,6 @@ private void scoreFromMemorySegmentOnlyVectorBulk(Blackhole bh, IndexInput in, E
230220
}
231221
}
232222

233-
@Benchmark
234223
public void scoreFromMemorySegmentAllBulkMmapScalar(Blackhole bh) throws IOException {
235224
scoreFromMemorySegmentAllBulk(bh, inMmap, scorerMmap);
236225
}
@@ -241,7 +230,6 @@ public void scoreFromMemorySegmentAllBulkMmapVect(Blackhole bh) throws IOExcepti
241230
scoreFromMemorySegmentAllBulk(bh, inMmap, scorerMmap);
242231
}
243232

244-
@Benchmark
245233
public void scoreFromMemorySegmentAllBulkNiofsScalar(Blackhole bh) throws IOException {
246234
scoreFromMemorySegmentAllBulk(bh, inNiofs, scorerNfios);
247235
}
@@ -255,7 +243,7 @@ public void scoreFromMemorySegmentAllBulkNiofsVect(Blackhole bh) throws IOExcept
255243
private void scoreFromMemorySegmentAllBulk(Blackhole bh, IndexInput in, ES91OSQVectorsScorer scorer) throws IOException {
256244
for (int j = 0; j < numQueries; j++) {
257245
in.seek(0);
258-
for (int i = 0; i < numVectors; i += 16) {
246+
for (int i = 0; i < numVectors; i += bulkSize) {
259247
scorer.scoreBulk(
260248
binaryQueries[j],
261249
result.lowerInterval(),

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91OSQVectorsScorer.java

Lines changed: 18 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@
2222
/** Scorer for quantized vectors stored as an {@link IndexInput}. */
2323
public class ES91OSQVectorsScorer {
2424

25-
public static final int BULK_SIZE = 16;
25+
public static final int BULK_SIZE = 32;
2626

2727
protected static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1);
2828

@@ -32,16 +32,22 @@ public class ES91OSQVectorsScorer {
3232
protected final int length;
3333
protected final int dimensions;
3434

35-
protected final float[] lowerIntervals = new float[BULK_SIZE];
36-
protected final float[] upperIntervals = new float[BULK_SIZE];
37-
protected final int[] targetComponentSums = new int[BULK_SIZE];
38-
protected final float[] additionalCorrections = new float[BULK_SIZE];
35+
protected final float[] lowerIntervals;
36+
protected final float[] upperIntervals;
37+
protected final int[] targetComponentSums;
38+
protected final float[] additionalCorrections;
39+
protected final int bulkSize;
3940

4041
/** Sole constructor; {@code bulkSize} is the number of vectors scored per {@code scoreBulk} call and sizes the per-batch correction buffers. */
41-
public ES91OSQVectorsScorer(IndexInput in, int dimensions) {
42+
public ES91OSQVectorsScorer(IndexInput in, int bulkSize, int dimensions) {
4243
this.in = in;
4344
this.dimensions = dimensions;
4445
this.length = OptimizedScalarQuantizer.discretize(dimensions, 64) / 8;
46+
this.lowerIntervals = new float[bulkSize];
47+
this.upperIntervals = new float[bulkSize];
48+
this.targetComponentSums = new int[bulkSize];
49+
this.additionalCorrections = new float[bulkSize];
50+
this.bulkSize = bulkSize;
4551
}
4652

4753
/**
@@ -151,15 +157,15 @@ public float scoreBulk(
151157
float centroidDp,
152158
float[] scores
153159
) throws IOException {
154-
quantizeScoreBulk(q, BULK_SIZE, scores);
155-
in.readFloats(lowerIntervals, 0, BULK_SIZE);
156-
in.readFloats(upperIntervals, 0, BULK_SIZE);
157-
for (int i = 0; i < BULK_SIZE; i++) {
160+
quantizeScoreBulk(q, this.bulkSize, scores);
161+
in.readFloats(lowerIntervals, 0, this.bulkSize);
162+
in.readFloats(upperIntervals, 0, this.bulkSize);
163+
for (int i = 0; i < this.bulkSize; i++) {
158164
targetComponentSums[i] = Short.toUnsignedInt(in.readShort());
159165
}
160-
in.readFloats(additionalCorrections, 0, BULK_SIZE);
166+
in.readFloats(additionalCorrections, 0, this.bulkSize);
161167
float maxScore = Float.NEGATIVE_INFINITY;
162-
for (int i = 0; i < BULK_SIZE; i++) {
168+
for (int i = 0; i < this.bulkSize; i++) {
163169
scores[i] = score(
164170
queryLowerInterval,
165171
queryUpperInterval,

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES92Int7VectorsScorer.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@
2323
* */
2424
public class ES92Int7VectorsScorer {
2525

26-
public static final int BULK_SIZE = 16;
26+
public static final int BULK_SIZE = ES91OSQVectorsScorer.BULK_SIZE;
2727
protected static final float SEVEN_BIT_SCALE = 1f / ((1 << 7) - 1);
2828

2929
/** The wrapper {@link IndexInput}. */

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESNextOSQVectorsScorer.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121
/** Scorer for quantized vectors stored as an {@link IndexInput}. */
2222
public class ESNextOSQVectorsScorer {
2323

24-
public static final int BULK_SIZE = 16;
24+
public static final int BULK_SIZE = ES91OSQVectorsScorer.BULK_SIZE;
2525

2626
protected static final float[] BIT_SCALES = new float[] {
2727
1f,

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -30,8 +30,8 @@ public ESVectorUtilSupport getVectorUtilSupport() {
3030
}
3131

3232
@Override
33-
public ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) {
34-
return new ES91OSQVectorsScorer(input, dimension);
33+
public ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int bulkSize, int dimension) {
34+
return new ES91OSQVectorsScorer(input, bulkSize, dimension);
3535
}
3636

3737
@Override

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -31,8 +31,12 @@ public static ESVectorizationProvider getInstance() {
3131

3232
public abstract ESVectorUtilSupport getVectorUtilSupport();
3333

34+
public final ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException {
35+
return newES91OSQVectorsScorer(input, ES91OSQVectorsScorer.BULK_SIZE, dimension);
36+
}
37+
3438
/** Create a new {@link ES91OSQVectorsScorer} for the given {@link IndexInput} that scores {@code bulkSize} vectors per bulk call. */
35-
public abstract ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException;
39+
public abstract ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int bulkSize, int dimension) throws IOException;
3640

3741
public abstract ESNextOSQVectorsScorer newESNextOSQVectorsScorer(
3842
IndexInput input,

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -38,8 +38,12 @@ public static ESVectorizationProvider getInstance() {
3838

3939
public abstract ESVectorUtilSupport getVectorUtilSupport();
4040

41+
public final ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException {
42+
return newES91OSQVectorsScorer(input, ES91OSQVectorsScorer.BULK_SIZE, dimension);
43+
}
44+
4145
/** Create a new {@link ES91OSQVectorsScorer} for the given {@link IndexInput} that scores {@code bulkSize} vectors per bulk call. */
42-
public abstract ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException;
46+
public abstract ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int bulkSize, int dimension) throws IOException;
4347

4448
public abstract ESNextOSQVectorsScorer newESNextOSQVectorsScorer(
4549
IndexInput input,

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java

Lines changed: 14 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -48,8 +48,8 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
4848

4949
private final MemorySegment memorySegment;
5050

51-
public MemorySegmentES91OSQVectorsScorer(IndexInput in, int dimensions, MemorySegment memorySegment) {
52-
super(in, dimensions);
51+
public MemorySegmentES91OSQVectorsScorer(IndexInput in, int bulkSize, int dimensions, MemorySegment memorySegment) {
52+
super(in, bulkSize, dimensions);
5353
this.memorySegment = memorySegment;
5454
}
5555

@@ -411,8 +411,8 @@ private float score128Bulk(
411411
float centroidDp,
412412
float[] scores
413413
) throws IOException {
414-
quantizeScore128Bulk(q, BULK_SIZE, scores);
415-
int limit = FLOAT_SPECIES_128.loopBound(BULK_SIZE);
414+
quantizeScore128Bulk(q, this.bulkSize, scores);
415+
int limit = FLOAT_SPECIES_128.loopBound(this.bulkSize);
416416
int i = 0;
417417
long offset = in.getFilePointer();
418418
float ay = queryLowerInterval;
@@ -424,19 +424,19 @@ private float score128Bulk(
424424
var lx = FloatVector.fromMemorySegment(
425425
FLOAT_SPECIES_128,
426426
memorySegment,
427-
offset + 4 * BULK_SIZE + i * Float.BYTES,
427+
offset + 4 * this.bulkSize + i * Float.BYTES,
428428
ByteOrder.LITTLE_ENDIAN
429429
).sub(ax);
430430
var targetComponentSums = ShortVector.fromMemorySegment(
431431
SHORT_SPECIES_128,
432432
memorySegment,
433-
offset + 8 * BULK_SIZE + i * Short.BYTES,
433+
offset + 8 * this.bulkSize + i * Short.BYTES,
434434
ByteOrder.LITTLE_ENDIAN
435435
).convert(VectorOperators.S2I, 0).reinterpretAsInts().and(0xffff).convert(VectorOperators.I2F, 0);
436436
var additionalCorrections = FloatVector.fromMemorySegment(
437437
FLOAT_SPECIES_128,
438438
memorySegment,
439-
offset + 10 * BULK_SIZE + i * Float.BYTES,
439+
offset + 10 * this.bulkSize + i * Float.BYTES,
440440
ByteOrder.LITTLE_ENDIAN
441441
);
442442
var qcDist = FloatVector.fromArray(FLOAT_SPECIES_128, scores, i);
@@ -472,7 +472,7 @@ private float score128Bulk(
472472
}
473473
}
474474
}
475-
in.seek(offset + 14L * BULK_SIZE);
475+
in.seek(offset + 14L * this.bulkSize);
476476
return maxScore;
477477
}
478478

@@ -486,8 +486,8 @@ private float score256Bulk(
486486
float centroidDp,
487487
float[] scores
488488
) throws IOException {
489-
quantizeScore256Bulk(q, BULK_SIZE, scores);
490-
int limit = FLOAT_SPECIES_256.loopBound(BULK_SIZE);
489+
quantizeScore256Bulk(q, this.bulkSize, scores);
490+
int limit = FLOAT_SPECIES_256.loopBound(this.bulkSize);
491491
int i = 0;
492492
long offset = in.getFilePointer();
493493
float ay = queryLowerInterval;
@@ -499,19 +499,19 @@ private float score256Bulk(
499499
var lx = FloatVector.fromMemorySegment(
500500
FLOAT_SPECIES_256,
501501
memorySegment,
502-
offset + 4 * BULK_SIZE + i * Float.BYTES,
502+
offset + 4 * this.bulkSize + i * Float.BYTES,
503503
ByteOrder.LITTLE_ENDIAN
504504
).sub(ax);
505505
var targetComponentSums = ShortVector.fromMemorySegment(
506506
SHORT_SPECIES_256,
507507
memorySegment,
508-
offset + 8 * BULK_SIZE + i * Short.BYTES,
508+
offset + 8 * this.bulkSize + i * Short.BYTES,
509509
ByteOrder.LITTLE_ENDIAN
510510
).convert(VectorOperators.S2I, 0).reinterpretAsInts().and(0xffff).convert(VectorOperators.I2F, 0);
511511
var additionalCorrections = FloatVector.fromMemorySegment(
512512
FLOAT_SPECIES_256,
513513
memorySegment,
514-
offset + 10 * BULK_SIZE + i * Float.BYTES,
514+
offset + 10 * this.bulkSize + i * Float.BYTES,
515515
ByteOrder.LITTLE_ENDIAN
516516
);
517517
var qcDist = FloatVector.fromArray(FLOAT_SPECIES_256, scores, i);
@@ -547,7 +547,7 @@ private float score256Bulk(
547547
}
548548
}
549549
}
550-
in.seek(offset + 14L * BULK_SIZE);
550+
in.seek(offset + 14L * this.bulkSize);
551551
return maxScore;
552552
}
553553
}

0 commit comments

Comments
 (0)