Skip to content

Commit 8e7331e

Browse files
committed
New bulk scorer for binary quantized vectors via optimized scalar quantization
1 parent c2fdc06 commit 8e7331e

File tree

8 files changed

+960
-0
lines changed

8 files changed

+960
-0
lines changed
Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
package org.elasticsearch.benchmark.vector;
10+
11+
import org.apache.lucene.index.VectorSimilarityFunction;
12+
import org.apache.lucene.store.Directory;
13+
import org.apache.lucene.store.IOContext;
14+
import org.apache.lucene.store.IndexInput;
15+
import org.apache.lucene.store.IndexOutput;
16+
import org.apache.lucene.store.MMapDirectory;
17+
import org.apache.lucene.util.VectorUtil;
18+
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
19+
import org.elasticsearch.common.logging.LogConfigurator;
20+
import org.elasticsearch.simdvec.internal.vectorization.ES91OSQVectorsScorer;
21+
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
22+
import org.openjdk.jmh.annotations.Benchmark;
23+
import org.openjdk.jmh.annotations.BenchmarkMode;
24+
import org.openjdk.jmh.annotations.Fork;
25+
import org.openjdk.jmh.annotations.Measurement;
26+
import org.openjdk.jmh.annotations.Mode;
27+
import org.openjdk.jmh.annotations.OutputTimeUnit;
28+
import org.openjdk.jmh.annotations.Param;
29+
import org.openjdk.jmh.annotations.Scope;
30+
import org.openjdk.jmh.annotations.Setup;
31+
import org.openjdk.jmh.annotations.State;
32+
import org.openjdk.jmh.annotations.Warmup;
33+
import org.openjdk.jmh.infra.Blackhole;
34+
35+
import java.io.IOException;
36+
import java.nio.file.Files;
37+
import java.util.Random;
38+
import java.util.concurrent.TimeUnit;
39+
40+
@BenchmarkMode(Mode.Throughput)
41+
@OutputTimeUnit(TimeUnit.MILLISECONDS)
42+
@State(Scope.Benchmark)
43+
// first iteration is complete garbage, so make sure we really warmup
44+
@Warmup(iterations = 4, time = 1)
45+
// real iterations. not useful to spend tons of time here, better to fork more
46+
@Measurement(iterations = 5, time = 1)
47+
// engage some noise reduction
48+
@Fork(value = 1)
49+
public class OSQScorerBenchmark {
50+
51+
static {
52+
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
53+
}
54+
55+
@Param({ "1024" })
56+
int dims;
57+
58+
int length;
59+
60+
int numVectors = ES91OSQVectorsScorer.BULK_SIZE * 10;
61+
int numQueries = 10;
62+
63+
byte[][] binaryVectors;
64+
byte[][] binaryQueries;
65+
OptimizedScalarQuantizer.QuantizationResult result;
66+
float centroidDp;
67+
68+
byte[] scratch;
69+
ES91OSQVectorsScorer scorer;
70+
71+
IndexInput in;
72+
73+
float[] scratchScores;
74+
float[] corrections;
75+
76+
@Setup
77+
public void setup() throws IOException {
78+
Random random = new Random(123);
79+
80+
this.length = OptimizedScalarQuantizer.discretize(dims, 64) / 8;
81+
82+
binaryVectors = new byte[numVectors][length];
83+
for (byte[] binaryVector : binaryVectors) {
84+
random.nextBytes(binaryVector);
85+
}
86+
87+
Directory dir = new MMapDirectory(Files.createTempDirectory("vectorData"));
88+
IndexOutput out = dir.createOutput("vectors", IOContext.DEFAULT);
89+
byte[] correctionBytes = new byte[14 * ES91OSQVectorsScorer.BULK_SIZE];
90+
for (int i = 0; i < numVectors; i += ES91OSQVectorsScorer.BULK_SIZE) {
91+
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
92+
out.writeBytes(binaryVectors[i + j], 0, binaryVectors[i + j].length);
93+
}
94+
random.nextBytes(correctionBytes);
95+
out.writeBytes(correctionBytes, 0, correctionBytes.length);
96+
}
97+
out.close();
98+
in = dir.openInput("vectors", IOContext.DEFAULT);
99+
100+
binaryQueries = new byte[numVectors][4 * length];
101+
for (byte[] binaryVector : binaryVectors) {
102+
random.nextBytes(binaryVector);
103+
}
104+
result = new OptimizedScalarQuantizer.QuantizationResult(
105+
random.nextFloat(),
106+
random.nextFloat(),
107+
random.nextFloat(),
108+
Short.toUnsignedInt((short) random.nextInt())
109+
);
110+
centroidDp = random.nextFloat();
111+
112+
scratch = new byte[length];
113+
scorer = ESVectorizationProvider.getInstance().newES91OSQVectorsScorer(in, dims);
114+
scratchScores = new float[16];
115+
corrections = new float[3];
116+
}
117+
118+
@Benchmark
119+
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
120+
public void scoreFromArray(Blackhole bh) throws IOException {
121+
for (int j = 0; j < numQueries; j++) {
122+
in.seek(0);
123+
for (int i = 0; i < numVectors; i++) {
124+
in.readBytes(scratch, 0, length);
125+
float qDist = VectorUtil.int4BitDotProduct(binaryQueries[j], scratch);
126+
in.readFloats(corrections, 0, corrections.length);
127+
int addition = Short.toUnsignedInt(in.readShort());
128+
float score = scorer.score(
129+
result,
130+
VectorSimilarityFunction.EUCLIDEAN,
131+
centroidDp,
132+
corrections[0],
133+
corrections[1],
134+
addition,
135+
corrections[2],
136+
qDist
137+
);
138+
bh.consume(score);
139+
}
140+
}
141+
}
142+
143+
@Benchmark
144+
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
145+
public void scoreFromMemorySegmentOnlyVector(Blackhole bh) throws IOException {
146+
for (int j = 0; j < numQueries; j++) {
147+
in.seek(0);
148+
for (int i = 0; i < numVectors; i++) {
149+
float qDist = scorer.quantizeScore(binaryQueries[j]);
150+
in.readFloats(corrections, 0, corrections.length);
151+
int addition = Short.toUnsignedInt(in.readShort());
152+
float score = scorer.score(
153+
result,
154+
VectorSimilarityFunction.EUCLIDEAN,
155+
centroidDp,
156+
corrections[0],
157+
corrections[1],
158+
addition,
159+
corrections[2],
160+
qDist
161+
);
162+
bh.consume(score);
163+
}
164+
}
165+
}
166+
167+
@Benchmark
168+
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
169+
public void scoreFromMemorySegmentOnlyVectorBulk(Blackhole bh) throws IOException {
170+
for (int j = 0; j < numQueries; j++) {
171+
in.seek(0);
172+
for (int i = 0; i < numVectors; i += 16) {
173+
scorer.quantizeScoreBulk(binaryQueries[j], ES91OSQVectorsScorer.BULK_SIZE, scratchScores);
174+
for (int k = 0; k < ES91OSQVectorsScorer.BULK_SIZE; k++) {
175+
in.readFloats(corrections, 0, corrections.length);
176+
int addition = Short.toUnsignedInt(in.readShort());
177+
float score = scorer.score(
178+
result,
179+
VectorSimilarityFunction.EUCLIDEAN,
180+
centroidDp,
181+
corrections[0],
182+
corrections[1],
183+
addition,
184+
corrections[2],
185+
scratchScores[k]
186+
);
187+
bh.consume(score);
188+
}
189+
}
190+
}
191+
}
192+
193+
@Benchmark
194+
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
195+
public void scoreFromMemorySegmentAllBulk(Blackhole bh) throws IOException {
196+
for (int j = 0; j < numQueries; j++) {
197+
in.seek(0);
198+
for (int i = 0; i < numVectors; i += 16) {
199+
scorer.scoreBulk(binaryQueries[j], result, VectorSimilarityFunction.EUCLIDEAN, centroidDp, scratchScores);
200+
bh.consume(scratchScores);
201+
}
202+
}
203+
}
204+
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99

1010
package org.elasticsearch.simdvec.internal.vectorization;
1111

12+
import org.apache.lucene.store.IndexInput;
13+
14+
import java.io.IOException;
15+
1216
final class DefaultESVectorizationProvider extends ESVectorizationProvider {
1317
private final ESVectorUtilSupport vectorUtilSupport;
1418

@@ -20,4 +24,9 @@ final class DefaultESVectorizationProvider extends ESVectorizationProvider {
2024
public ESVectorUtilSupport getVectorUtilSupport() {
2125
return vectorUtilSupport;
2226
}
27+
28+
@Override
29+
public ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException {
30+
return new ES91OSQVectorsScorer(input, dimension);
31+
}
2332
}
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
package org.elasticsearch.simdvec.internal.vectorization;
8+
9+
import org.apache.lucene.index.VectorSimilarityFunction;
10+
import org.apache.lucene.store.IndexInput;
11+
import org.apache.lucene.util.BitUtil;
12+
import org.apache.lucene.util.VectorUtil;
13+
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
14+
15+
import java.io.IOException;
16+
17+
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
18+
import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
19+
20+
/** Scorer for quantized vectors stored as an {@link IndexInput}. */
21+
public class ES91OSQVectorsScorer {
22+
23+
public static final int BULK_SIZE = 16;
24+
25+
protected static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1);
26+
27+
/** The wrapper {@link IndexInput}. */
28+
protected final IndexInput in;
29+
30+
protected final int length;
31+
protected final int dimensions;
32+
33+
protected final float[] lowerIntervals = new float[BULK_SIZE];
34+
protected final float[] upperIntervals = new float[BULK_SIZE];
35+
protected final int[] targetComponentSums = new int[BULK_SIZE];
36+
protected final float[] additionalCorrections = new float[BULK_SIZE];
37+
38+
/** Sole constructor, called by sub-classes. */
39+
public ES91OSQVectorsScorer(IndexInput in, int dimensions) {
40+
this.in = in;
41+
this.dimensions = dimensions;
42+
this.length = OptimizedScalarQuantizer.discretize(dimensions, 64) / 8;
43+
}
44+
45+
/**
46+
* compute the quantize distance between the provided quantized query and the quantized vector
47+
* that is read from the wrapped {@link IndexInput}.
48+
*/
49+
public long quantizeScore(byte[] q) throws IOException {
50+
assert q.length == length * 4;
51+
final int size = length;
52+
long subRet0 = 0;
53+
long subRet1 = 0;
54+
long subRet2 = 0;
55+
long subRet3 = 0;
56+
int r = 0;
57+
for (final int upperBound = size & -Long.BYTES; r < upperBound; r += Long.BYTES) {
58+
final long value = in.readLong();
59+
subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r) & value);
60+
subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r + size) & value);
61+
subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r + 2 * size) & value);
62+
subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r + 3 * size) & value);
63+
}
64+
for (final int upperBound = size & -Integer.BYTES; r < upperBound; r += Integer.BYTES) {
65+
final int value = in.readInt();
66+
subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r) & value);
67+
subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r + size) & value);
68+
subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r + 2 * size) & value);
69+
subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r + 3 * size) & value);
70+
}
71+
for (; r < size; r++) {
72+
final byte value = in.readByte();
73+
subRet0 += Integer.bitCount((q[r] & value) & 0xFF);
74+
subRet1 += Integer.bitCount((q[r + size] & value) & 0xFF);
75+
subRet2 += Integer.bitCount((q[r + 2 * size] & value) & 0xFF);
76+
subRet3 += Integer.bitCount((q[r + 3 * size] & value) & 0xFF);
77+
}
78+
return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
79+
}
80+
81+
/**
82+
* compute the quantize distance between the provided quantized query and the quantized vectors
83+
* that are read from the wrapped {@link IndexInput}. The number of quantized vectors to read is
84+
* determined by {code count} and the results are stored in the provided {@code scores} array.
85+
*/
86+
public void quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOException {
87+
for (int i = 0; i < count; i++) {
88+
scores[i] = quantizeScore(q);
89+
}
90+
}
91+
92+
/**
93+
* Computes the score by applying the necessary corrections to the provided quantized distance.
94+
*/
95+
public float score(
96+
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
97+
VectorSimilarityFunction similarityFunction,
98+
float centroidDp,
99+
float lowerInterval,
100+
float upperInterval,
101+
int targetComponentSum,
102+
float additionalCorrection,
103+
float qcDist
104+
) {
105+
float ax = lowerInterval;
106+
// Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
107+
float lx = upperInterval - ax;
108+
float ay = queryCorrections.lowerInterval();
109+
float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
110+
float y1 = queryCorrections.quantizedComponentSum();
111+
float score = ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
112+
// For euclidean, we need to invert the score and apply the additional correction, which is
113+
// assumed to be the squared l2norm of the centroid centered vectors.
114+
if (similarityFunction == EUCLIDEAN) {
115+
score = queryCorrections.additionalCorrection() + additionalCorrection - 2 * score;
116+
return Math.max(1 / (1f + score), 0);
117+
} else {
118+
// For cosine and max inner product, we need to apply the additional correction, which is
119+
// assumed to be the non-centered dot-product between the vector and the centroid
120+
score += queryCorrections.additionalCorrection() + additionalCorrection - centroidDp;
121+
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
122+
return VectorUtil.scaleMaxInnerProductScore(score);
123+
}
124+
return Math.max((1f + score) / 2f, 0);
125+
}
126+
}
127+
128+
/**
129+
* compute the distance between the provided quantized query and the quantized vectors that are
130+
* read from the wrapped {@link IndexInput}.
131+
*
132+
* <p>The number of vectors to score is defined by {@link #BULK_SIZE}. The expected format of the
133+
* input is as follows: First the quantized vectors are read from the input,then all the lower
134+
* intervals as floats, then all the upper intervals as floats, then all the target component sums
135+
* as shorts, and finally all the additional corrections as floats.
136+
*
137+
* <p>The results are stored in the provided scores array.
138+
*/
139+
public void scoreBulk(
140+
byte[] q,
141+
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
142+
VectorSimilarityFunction similarityFunction,
143+
float centroidDp,
144+
float[] scores
145+
) throws IOException {
146+
quantizeScoreBulk(q, BULK_SIZE, scores);
147+
in.readFloats(lowerIntervals, 0, BULK_SIZE);
148+
in.readFloats(upperIntervals, 0, BULK_SIZE);
149+
for (int i = 0; i < BULK_SIZE; i++) {
150+
targetComponentSums[i] = Short.toUnsignedInt(in.readShort());
151+
}
152+
in.readFloats(additionalCorrections, 0, BULK_SIZE);
153+
for (int i = 0; i < BULK_SIZE; i++) {
154+
scores[i] = score(
155+
queryCorrections,
156+
similarityFunction,
157+
centroidDp,
158+
lowerIntervals[i],
159+
upperIntervals[i],
160+
targetComponentSums[i],
161+
additionalCorrections[i],
162+
scores[i]
163+
);
164+
}
165+
}
166+
}

0 commit comments

Comments
 (0)