Skip to content

Commit 74faf47

Browse files
authored
New bulk scorer for binary quantized vectors via optimized scalar quantization (#127189)
* New bulk scorer for binary quantized vectors via optimized scalar quantization * fixing headers * fixing tests
1 parent 85d375c commit 74faf47

File tree

8 files changed

+987
-0
lines changed

8 files changed

+987
-0
lines changed
Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
package org.elasticsearch.benchmark.vector;
10+
11+
import org.apache.lucene.index.VectorSimilarityFunction;
12+
import org.apache.lucene.store.Directory;
13+
import org.apache.lucene.store.IOContext;
14+
import org.apache.lucene.store.IndexInput;
15+
import org.apache.lucene.store.IndexOutput;
16+
import org.apache.lucene.store.MMapDirectory;
17+
import org.apache.lucene.util.VectorUtil;
18+
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
19+
import org.elasticsearch.common.logging.LogConfigurator;
20+
import org.elasticsearch.simdvec.internal.vectorization.ES91OSQVectorsScorer;
21+
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
22+
import org.openjdk.jmh.annotations.Benchmark;
23+
import org.openjdk.jmh.annotations.BenchmarkMode;
24+
import org.openjdk.jmh.annotations.Fork;
25+
import org.openjdk.jmh.annotations.Measurement;
26+
import org.openjdk.jmh.annotations.Mode;
27+
import org.openjdk.jmh.annotations.OutputTimeUnit;
28+
import org.openjdk.jmh.annotations.Param;
29+
import org.openjdk.jmh.annotations.Scope;
30+
import org.openjdk.jmh.annotations.Setup;
31+
import org.openjdk.jmh.annotations.State;
32+
import org.openjdk.jmh.annotations.Warmup;
33+
import org.openjdk.jmh.infra.Blackhole;
34+
35+
import java.io.IOException;
36+
import java.nio.file.Files;
37+
import java.util.Random;
38+
import java.util.concurrent.TimeUnit;
39+
40+
@BenchmarkMode(Mode.Throughput)
41+
@OutputTimeUnit(TimeUnit.MILLISECONDS)
42+
@State(Scope.Benchmark)
43+
// first iteration is complete garbage, so make sure we really warmup
44+
@Warmup(iterations = 4, time = 1)
45+
// real iterations. not useful to spend tons of time here, better to fork more
46+
@Measurement(iterations = 5, time = 1)
47+
// engage some noise reduction
48+
@Fork(value = 1)
49+
public class OSQScorerBenchmark {
50+
51+
static {
52+
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
53+
}
54+
55+
@Param({ "1024" })
56+
int dims;
57+
58+
int length;
59+
60+
int numVectors = ES91OSQVectorsScorer.BULK_SIZE * 10;
61+
int numQueries = 10;
62+
63+
byte[][] binaryVectors;
64+
byte[][] binaryQueries;
65+
OptimizedScalarQuantizer.QuantizationResult result;
66+
float centroidDp;
67+
68+
byte[] scratch;
69+
ES91OSQVectorsScorer scorer;
70+
71+
IndexInput in;
72+
73+
float[] scratchScores;
74+
float[] corrections;
75+
76+
@Setup
77+
public void setup() throws IOException {
78+
Random random = new Random(123);
79+
80+
this.length = OptimizedScalarQuantizer.discretize(dims, 64) / 8;
81+
82+
binaryVectors = new byte[numVectors][length];
83+
for (byte[] binaryVector : binaryVectors) {
84+
random.nextBytes(binaryVector);
85+
}
86+
87+
Directory dir = new MMapDirectory(Files.createTempDirectory("vectorData"));
88+
IndexOutput out = dir.createOutput("vectors", IOContext.DEFAULT);
89+
byte[] correctionBytes = new byte[14 * ES91OSQVectorsScorer.BULK_SIZE];
90+
for (int i = 0; i < numVectors; i += ES91OSQVectorsScorer.BULK_SIZE) {
91+
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
92+
out.writeBytes(binaryVectors[i + j], 0, binaryVectors[i + j].length);
93+
}
94+
random.nextBytes(correctionBytes);
95+
out.writeBytes(correctionBytes, 0, correctionBytes.length);
96+
}
97+
out.close();
98+
in = dir.openInput("vectors", IOContext.DEFAULT);
99+
100+
binaryQueries = new byte[numVectors][4 * length];
101+
for (byte[] binaryVector : binaryVectors) {
102+
random.nextBytes(binaryVector);
103+
}
104+
result = new OptimizedScalarQuantizer.QuantizationResult(
105+
random.nextFloat(),
106+
random.nextFloat(),
107+
random.nextFloat(),
108+
Short.toUnsignedInt((short) random.nextInt())
109+
);
110+
centroidDp = random.nextFloat();
111+
112+
scratch = new byte[length];
113+
scorer = ESVectorizationProvider.getInstance().newES91OSQVectorsScorer(in, dims);
114+
scratchScores = new float[16];
115+
corrections = new float[3];
116+
}
117+
118+
@Benchmark
119+
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
120+
public void scoreFromArray(Blackhole bh) throws IOException {
121+
for (int j = 0; j < numQueries; j++) {
122+
in.seek(0);
123+
for (int i = 0; i < numVectors; i++) {
124+
in.readBytes(scratch, 0, length);
125+
float qDist = VectorUtil.int4BitDotProduct(binaryQueries[j], scratch);
126+
in.readFloats(corrections, 0, corrections.length);
127+
int addition = Short.toUnsignedInt(in.readShort());
128+
float score = scorer.score(
129+
result,
130+
VectorSimilarityFunction.EUCLIDEAN,
131+
centroidDp,
132+
corrections[0],
133+
corrections[1],
134+
addition,
135+
corrections[2],
136+
qDist
137+
);
138+
bh.consume(score);
139+
}
140+
}
141+
}
142+
143+
@Benchmark
144+
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
145+
public void scoreFromMemorySegmentOnlyVector(Blackhole bh) throws IOException {
146+
for (int j = 0; j < numQueries; j++) {
147+
in.seek(0);
148+
for (int i = 0; i < numVectors; i++) {
149+
float qDist = scorer.quantizeScore(binaryQueries[j]);
150+
in.readFloats(corrections, 0, corrections.length);
151+
int addition = Short.toUnsignedInt(in.readShort());
152+
float score = scorer.score(
153+
result,
154+
VectorSimilarityFunction.EUCLIDEAN,
155+
centroidDp,
156+
corrections[0],
157+
corrections[1],
158+
addition,
159+
corrections[2],
160+
qDist
161+
);
162+
bh.consume(score);
163+
}
164+
}
165+
}
166+
167+
@Benchmark
168+
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
169+
public void scoreFromMemorySegmentOnlyVectorBulk(Blackhole bh) throws IOException {
170+
for (int j = 0; j < numQueries; j++) {
171+
in.seek(0);
172+
for (int i = 0; i < numVectors; i += 16) {
173+
scorer.quantizeScoreBulk(binaryQueries[j], ES91OSQVectorsScorer.BULK_SIZE, scratchScores);
174+
for (int k = 0; k < ES91OSQVectorsScorer.BULK_SIZE; k++) {
175+
in.readFloats(corrections, 0, corrections.length);
176+
int addition = Short.toUnsignedInt(in.readShort());
177+
float score = scorer.score(
178+
result,
179+
VectorSimilarityFunction.EUCLIDEAN,
180+
centroidDp,
181+
corrections[0],
182+
corrections[1],
183+
addition,
184+
corrections[2],
185+
scratchScores[k]
186+
);
187+
bh.consume(score);
188+
}
189+
}
190+
}
191+
}
192+
193+
@Benchmark
194+
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
195+
public void scoreFromMemorySegmentAllBulk(Blackhole bh) throws IOException {
196+
for (int j = 0; j < numQueries; j++) {
197+
in.seek(0);
198+
for (int i = 0; i < numVectors; i += 16) {
199+
scorer.scoreBulk(binaryQueries[j], result, VectorSimilarityFunction.EUCLIDEAN, centroidDp, scratchScores);
200+
bh.consume(scratchScores);
201+
}
202+
}
203+
}
204+
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99

1010
package org.elasticsearch.simdvec.internal.vectorization;
1111

12+
import org.apache.lucene.store.IndexInput;
13+
14+
import java.io.IOException;
15+
1216
final class DefaultESVectorizationProvider extends ESVectorizationProvider {
1317
private final ESVectorUtilSupport vectorUtilSupport;
1418

@@ -20,4 +24,9 @@ final class DefaultESVectorizationProvider extends ESVectorizationProvider {
2024
public ESVectorUtilSupport getVectorUtilSupport() {
2125
return vectorUtilSupport;
2226
}
27+
28+
@Override
29+
public ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException {
30+
return new ES91OSQVectorsScorer(input, dimension);
31+
}
2332
}
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
package org.elasticsearch.simdvec.internal.vectorization;
10+
11+
import org.apache.lucene.index.VectorSimilarityFunction;
12+
import org.apache.lucene.store.IndexInput;
13+
import org.apache.lucene.util.BitUtil;
14+
import org.apache.lucene.util.VectorUtil;
15+
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
16+
17+
import java.io.IOException;
18+
19+
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
20+
import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
21+
22+
/** Scorer for quantized vectors stored as an {@link IndexInput}. */
23+
public class ES91OSQVectorsScorer {
24+
25+
public static final int BULK_SIZE = 16;
26+
27+
protected static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1);
28+
29+
/** The wrapper {@link IndexInput}. */
30+
protected final IndexInput in;
31+
32+
protected final int length;
33+
protected final int dimensions;
34+
35+
protected final float[] lowerIntervals = new float[BULK_SIZE];
36+
protected final float[] upperIntervals = new float[BULK_SIZE];
37+
protected final int[] targetComponentSums = new int[BULK_SIZE];
38+
protected final float[] additionalCorrections = new float[BULK_SIZE];
39+
40+
/** Sole constructor, called by sub-classes. */
41+
public ES91OSQVectorsScorer(IndexInput in, int dimensions) {
42+
this.in = in;
43+
this.dimensions = dimensions;
44+
this.length = OptimizedScalarQuantizer.discretize(dimensions, 64) / 8;
45+
}
46+
47+
/**
48+
* compute the quantize distance between the provided quantized query and the quantized vector
49+
* that is read from the wrapped {@link IndexInput}.
50+
*/
51+
public long quantizeScore(byte[] q) throws IOException {
52+
assert q.length == length * 4;
53+
final int size = length;
54+
long subRet0 = 0;
55+
long subRet1 = 0;
56+
long subRet2 = 0;
57+
long subRet3 = 0;
58+
int r = 0;
59+
for (final int upperBound = size & -Long.BYTES; r < upperBound; r += Long.BYTES) {
60+
final long value = in.readLong();
61+
subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r) & value);
62+
subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r + size) & value);
63+
subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r + 2 * size) & value);
64+
subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r + 3 * size) & value);
65+
}
66+
for (final int upperBound = size & -Integer.BYTES; r < upperBound; r += Integer.BYTES) {
67+
final int value = in.readInt();
68+
subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r) & value);
69+
subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r + size) & value);
70+
subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r + 2 * size) & value);
71+
subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r + 3 * size) & value);
72+
}
73+
for (; r < size; r++) {
74+
final byte value = in.readByte();
75+
subRet0 += Integer.bitCount((q[r] & value) & 0xFF);
76+
subRet1 += Integer.bitCount((q[r + size] & value) & 0xFF);
77+
subRet2 += Integer.bitCount((q[r + 2 * size] & value) & 0xFF);
78+
subRet3 += Integer.bitCount((q[r + 3 * size] & value) & 0xFF);
79+
}
80+
return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
81+
}
82+
83+
/**
84+
* compute the quantize distance between the provided quantized query and the quantized vectors
85+
* that are read from the wrapped {@link IndexInput}. The number of quantized vectors to read is
86+
* determined by {code count} and the results are stored in the provided {@code scores} array.
87+
*/
88+
public void quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOException {
89+
for (int i = 0; i < count; i++) {
90+
scores[i] = quantizeScore(q);
91+
}
92+
}
93+
94+
/**
95+
* Computes the score by applying the necessary corrections to the provided quantized distance.
96+
*/
97+
public float score(
98+
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
99+
VectorSimilarityFunction similarityFunction,
100+
float centroidDp,
101+
float lowerInterval,
102+
float upperInterval,
103+
int targetComponentSum,
104+
float additionalCorrection,
105+
float qcDist
106+
) {
107+
float ax = lowerInterval;
108+
// Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
109+
float lx = upperInterval - ax;
110+
float ay = queryCorrections.lowerInterval();
111+
float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
112+
float y1 = queryCorrections.quantizedComponentSum();
113+
float score = ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
114+
// For euclidean, we need to invert the score and apply the additional correction, which is
115+
// assumed to be the squared l2norm of the centroid centered vectors.
116+
if (similarityFunction == EUCLIDEAN) {
117+
score = queryCorrections.additionalCorrection() + additionalCorrection - 2 * score;
118+
return Math.max(1 / (1f + score), 0);
119+
} else {
120+
// For cosine and max inner product, we need to apply the additional correction, which is
121+
// assumed to be the non-centered dot-product between the vector and the centroid
122+
score += queryCorrections.additionalCorrection() + additionalCorrection - centroidDp;
123+
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
124+
return VectorUtil.scaleMaxInnerProductScore(score);
125+
}
126+
return Math.max((1f + score) / 2f, 0);
127+
}
128+
}
129+
130+
/**
131+
* compute the distance between the provided quantized query and the quantized vectors that are
132+
* read from the wrapped {@link IndexInput}.
133+
*
134+
* <p>The number of vectors to score is defined by {@link #BULK_SIZE}. The expected format of the
135+
* input is as follows: First the quantized vectors are read from the input,then all the lower
136+
* intervals as floats, then all the upper intervals as floats, then all the target component sums
137+
* as shorts, and finally all the additional corrections as floats.
138+
*
139+
* <p>The results are stored in the provided scores array.
140+
*/
141+
public void scoreBulk(
142+
byte[] q,
143+
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
144+
VectorSimilarityFunction similarityFunction,
145+
float centroidDp,
146+
float[] scores
147+
) throws IOException {
148+
quantizeScoreBulk(q, BULK_SIZE, scores);
149+
in.readFloats(lowerIntervals, 0, BULK_SIZE);
150+
in.readFloats(upperIntervals, 0, BULK_SIZE);
151+
for (int i = 0; i < BULK_SIZE; i++) {
152+
targetComponentSums[i] = Short.toUnsignedInt(in.readShort());
153+
}
154+
in.readFloats(additionalCorrections, 0, BULK_SIZE);
155+
for (int i = 0; i < BULK_SIZE; i++) {
156+
scores[i] = score(
157+
queryCorrections,
158+
similarityFunction,
159+
centroidDp,
160+
lowerIntervals[i],
161+
upperIntervals[i],
162+
targetComponentSums[i],
163+
additionalCorrections[i],
164+
scores[i]
165+
);
166+
}
167+
}
168+
}

0 commit comments

Comments
 (0)