Skip to content

Commit da218c8

Browse files
authored
Add bulk processing capabilities to ES91Int4VectorsScorer (elastic#131202)
It uses the same approach as the one taken in ES91OSQVectorsScorer
1 parent 4db75b8 commit da218c8

File tree

4 files changed

+608
-35
lines changed

4 files changed

+608
-35
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Int4ScorerBenchmark.java

Lines changed: 72 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88
*/
99
package org.elasticsearch.benchmark.vector;
1010

11+
import org.apache.lucene.index.VectorSimilarityFunction;
1112
import org.apache.lucene.store.Directory;
1213
import org.apache.lucene.store.IOContext;
1314
import org.apache.lucene.store.IndexInput;
1415
import org.apache.lucene.store.IndexOutput;
1516
import org.apache.lucene.store.MMapDirectory;
1617
import org.apache.lucene.util.VectorUtil;
18+
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
1719
import org.elasticsearch.common.logging.LogConfigurator;
1820
import org.elasticsearch.core.IOUtils;
1921
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
@@ -52,20 +54,26 @@ public class Int4ScorerBenchmark {
5254
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
5355
}
5456

55-
@Param({ "384", "702", "1024" })
57+
@Param({ "384", "782", "1024" })
5658
int dims;
5759

58-
int numVectors = 200;
59-
int numQueries = 10;
60+
int numVectors = 20 * ES91Int4VectorsScorer.BULK_SIZE;
61+
int numQueries = 5;
6062

6163
byte[] scratch;
6264
byte[][] binaryVectors;
6365
byte[][] binaryQueries;
66+
float[] scores = new float[ES91Int4VectorsScorer.BULK_SIZE];
67+
68+
float[] scratchFloats = new float[3];
6469

6570
ES91Int4VectorsScorer scorer;
6671
Directory dir;
6772
IndexInput in;
6873

74+
OptimizedScalarQuantizer.QuantizationResult queryCorrections;
75+
float centroidDp;
76+
6977
@Setup
7078
public void setup() throws IOException {
7179
binaryVectors = new byte[numVectors][dims];
@@ -77,9 +85,19 @@ public void setup() throws IOException {
7785
binaryVector[i] = (byte) ThreadLocalRandom.current().nextInt(16);
7886
}
7987
out.writeBytes(binaryVector, 0, binaryVector.length);
88+
ThreadLocalRandom.current().nextBytes(binaryVector);
89+
out.writeBytes(binaryVector, 0, 14); // corrections
8090
}
8191
}
8292

93+
queryCorrections = new OptimizedScalarQuantizer.QuantizationResult(
94+
ThreadLocalRandom.current().nextFloat(),
95+
ThreadLocalRandom.current().nextFloat(),
96+
ThreadLocalRandom.current().nextFloat(),
97+
Short.toUnsignedInt((short) ThreadLocalRandom.current().nextInt())
98+
);
99+
centroidDp = ThreadLocalRandom.current().nextFloat();
100+
83101
in = dir.openInput("vectors", IOContext.DEFAULT);
84102
binaryQueries = new byte[numVectors][dims];
85103
for (byte[] binaryVector : binaryVectors) {
@@ -105,18 +123,66 @@ public void scoreFromArray(Blackhole bh) throws IOException {
105123
in.seek(0);
106124
for (int i = 0; i < numVectors; i++) {
107125
in.readBytes(scratch, 0, dims);
108-
bh.consume(VectorUtil.int4DotProduct(binaryQueries[j], scratch));
126+
int dp = VectorUtil.int4DotProduct(binaryQueries[j], scratch);
127+
in.readFloats(scratchFloats, 0, 3);
128+
float score = scorer.applyCorrections(
129+
queryCorrections.lowerInterval(),
130+
queryCorrections.upperInterval(),
131+
queryCorrections.quantizedComponentSum(),
132+
queryCorrections.additionalCorrection(),
133+
VectorSimilarityFunction.EUCLIDEAN,
134+
centroidDp, // assuming no centroid dot product for this benchmark
135+
scratchFloats[0],
136+
scratchFloats[1],
137+
Short.toUnsignedInt(in.readShort()),
138+
scratchFloats[2],
139+
dp
140+
);
141+
bh.consume(score);
109142
}
110143
}
111144
}
112145

113146
@Benchmark
114147
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
115-
public void scoreFromMemorySegmentOnlyVector(Blackhole bh) throws IOException {
148+
public void scoreFromMemorySegment(Blackhole bh) throws IOException {
116149
for (int j = 0; j < numQueries; j++) {
117150
in.seek(0);
118151
for (int i = 0; i < numVectors; i++) {
119-
bh.consume(scorer.int4DotProduct(binaryQueries[j]));
152+
bh.consume(
153+
scorer.score(
154+
binaryQueries[j],
155+
queryCorrections.lowerInterval(),
156+
queryCorrections.upperInterval(),
157+
queryCorrections.quantizedComponentSum(),
158+
queryCorrections.additionalCorrection(),
159+
VectorSimilarityFunction.EUCLIDEAN,
160+
centroidDp
161+
)
162+
);
163+
}
164+
}
165+
}
166+
167+
@Benchmark
168+
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
169+
public void scoreFromMemorySegmentBulk(Blackhole bh) throws IOException {
170+
for (int j = 0; j < numQueries; j++) {
171+
in.seek(0);
172+
for (int i = 0; i < numVectors; i += ES91Int4VectorsScorer.BULK_SIZE) {
173+
scorer.scoreBulk(
174+
binaryQueries[j],
175+
queryCorrections.lowerInterval(),
176+
queryCorrections.upperInterval(),
177+
queryCorrections.quantizedComponentSum(),
178+
queryCorrections.additionalCorrection(),
179+
VectorSimilarityFunction.EUCLIDEAN,
180+
centroidDp,
181+
scores
182+
);
183+
for (float score : scores) {
184+
bh.consume(score);
185+
}
120186
}
121187
}
122188
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91Int4VectorsScorer.java

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,15 @@
88
*/
99
package org.elasticsearch.simdvec;
1010

11+
import org.apache.lucene.index.VectorSimilarityFunction;
1112
import org.apache.lucene.store.IndexInput;
13+
import org.apache.lucene.util.VectorUtil;
1214

1315
import java.io.IOException;
1416

17+
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
18+
import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
19+
1520
/** Scorer for quantized vectors stored as an {@link IndexInput}.
1621
* <p>
1722
* Similar to {@link org.apache.lucene.util.VectorUtil#int4DotProduct(byte[], byte[])} but
@@ -20,18 +25,30 @@
2025
* */
2126
public class ES91Int4VectorsScorer {
2227

28+
public static final int BULK_SIZE = 16;
29+
protected static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1);
30+
2331
/** The wrapper {@link IndexInput}. */
2432
protected final IndexInput in;
2533
protected final int dimensions;
2634
protected byte[] scratch;
2735

36+
protected final float[] lowerIntervals = new float[BULK_SIZE];
37+
protected final float[] upperIntervals = new float[BULK_SIZE];
38+
protected final int[] targetComponentSums = new int[BULK_SIZE];
39+
protected final float[] additionalCorrections = new float[BULK_SIZE];
40+
2841
/** Sole constructor, called by sub-classes. */
2942
public ES91Int4VectorsScorer(IndexInput in, int dimensions) {
3043
this.in = in;
3144
this.dimensions = dimensions;
3245
scratch = new byte[dimensions];
3346
}
3447

48+
/**
49+
* compute the quantize distance between the provided quantized query and the quantized vector
50+
* that is read from the wrapped {@link IndexInput}.
51+
*/
3552
public long int4DotProduct(byte[] b) throws IOException {
3653
in.readBytes(scratch, 0, dimensions);
3754
int total = 0;
@@ -40,4 +57,129 @@ public long int4DotProduct(byte[] b) throws IOException {
4057
}
4158
return total;
4259
}
60+
61+
/**
62+
* compute the quantize distance between the provided quantized query and the quantized vectors
63+
* that are read from the wrapped {@link IndexInput}. The number of quantized vectors to read is
64+
* determined by {code count} and the results are stored in the provided {@code scores} array.
65+
*/
66+
public void int4DotProductBulk(byte[] b, int count, float[] scores) throws IOException {
67+
for (int i = 0; i < count; i++) {
68+
scores[i] = int4DotProduct(b);
69+
}
70+
}
71+
72+
/**
73+
* Computes the score by applying the necessary corrections to the provided quantized distance.
74+
*/
75+
public float score(
76+
byte[] q,
77+
float queryLowerInterval,
78+
float queryUpperInterval,
79+
int queryComponentSum,
80+
float queryAdditionalCorrection,
81+
VectorSimilarityFunction similarityFunction,
82+
float centroidDp
83+
) throws IOException {
84+
float score = int4DotProduct(q);
85+
in.readFloats(lowerIntervals, 0, 3);
86+
int addition = Short.toUnsignedInt(in.readShort());
87+
return applyCorrections(
88+
queryLowerInterval,
89+
queryUpperInterval,
90+
queryComponentSum,
91+
queryAdditionalCorrection,
92+
similarityFunction,
93+
centroidDp,
94+
lowerIntervals[0],
95+
lowerIntervals[1],
96+
addition,
97+
lowerIntervals[2],
98+
score
99+
);
100+
}
101+
102+
/**
103+
* compute the distance between the provided quantized query and the quantized vectors that are
104+
* read from the wrapped {@link IndexInput}.
105+
*
106+
* <p>The number of vectors to score is defined by {@link #BULK_SIZE}. The expected format of the
107+
* input is as follows: First the quantized vectors are read from the input,then all the lower
108+
* intervals as floats, then all the upper intervals as floats, then all the target component sums
109+
* as shorts, and finally all the additional corrections as floats.
110+
*
111+
* <p>The results are stored in the provided scores array.
112+
*/
113+
public void scoreBulk(
114+
byte[] q,
115+
float queryLowerInterval,
116+
float queryUpperInterval,
117+
int queryComponentSum,
118+
float queryAdditionalCorrection,
119+
VectorSimilarityFunction similarityFunction,
120+
float centroidDp,
121+
float[] scores
122+
) throws IOException {
123+
int4DotProductBulk(q, BULK_SIZE, scores);
124+
in.readFloats(lowerIntervals, 0, BULK_SIZE);
125+
in.readFloats(upperIntervals, 0, BULK_SIZE);
126+
for (int i = 0; i < BULK_SIZE; i++) {
127+
targetComponentSums[i] = Short.toUnsignedInt(in.readShort());
128+
}
129+
in.readFloats(additionalCorrections, 0, BULK_SIZE);
130+
for (int i = 0; i < BULK_SIZE; i++) {
131+
scores[i] = applyCorrections(
132+
queryLowerInterval,
133+
queryUpperInterval,
134+
queryComponentSum,
135+
queryAdditionalCorrection,
136+
similarityFunction,
137+
centroidDp,
138+
lowerIntervals[i],
139+
upperIntervals[i],
140+
targetComponentSums[i],
141+
additionalCorrections[i],
142+
scores[i]
143+
);
144+
}
145+
}
146+
147+
/**
148+
* Computes the score by applying the necessary corrections to the provided quantized distance.
149+
*/
150+
public float applyCorrections(
151+
float queryLowerInterval,
152+
float queryUpperInterval,
153+
int queryComponentSum,
154+
float queryAdditionalCorrection,
155+
VectorSimilarityFunction similarityFunction,
156+
float centroidDp,
157+
float lowerInterval,
158+
float upperInterval,
159+
int targetComponentSum,
160+
float additionalCorrection,
161+
float qcDist
162+
) {
163+
float ax = lowerInterval;
164+
// Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
165+
float lx = upperInterval - ax;
166+
float ay = queryLowerInterval;
167+
float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE;
168+
float y1 = queryComponentSum;
169+
float score = ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
170+
// For euclidean, we need to invert the score and apply the additional correction, which is
171+
// assumed to be the squared l2norm of the centroid centered vectors.
172+
if (similarityFunction == EUCLIDEAN) {
173+
score = queryAdditionalCorrection + additionalCorrection - 2 * score;
174+
return Math.max(1 / (1f + score), 0);
175+
} else {
176+
// For cosine and max inner product, we need to apply the additional correction, which is
177+
// assumed to be the non-centered dot-product between the vector and the centroid
178+
score += queryAdditionalCorrection + additionalCorrection - centroidDp;
179+
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
180+
return VectorUtil.scaleMaxInnerProductScore(score);
181+
}
182+
return Math.max((1f + score) / 2f, 0);
183+
}
184+
}
43185
}

0 commit comments

Comments
 (0)