Skip to content

Commit df928a4

Browse files
committed
Introduce an Int7VectorScorer for scoring 7-bit quantize vectors
1 parent 8d3634a commit df928a4

File tree

10 files changed

+1144
-4
lines changed

10 files changed

+1144
-4
lines changed
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
package org.elasticsearch.benchmark.vector;
10+
11+
import org.apache.lucene.index.VectorSimilarityFunction;
12+
import org.apache.lucene.store.Directory;
13+
import org.apache.lucene.store.IOContext;
14+
import org.apache.lucene.store.IndexInput;
15+
import org.apache.lucene.store.IndexOutput;
16+
import org.apache.lucene.store.MMapDirectory;
17+
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
18+
import org.elasticsearch.common.logging.LogConfigurator;
19+
import org.elasticsearch.core.IOUtils;
20+
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
21+
import org.elasticsearch.simdvec.ES92Int7VectorsScorer;
22+
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
23+
import org.openjdk.jmh.annotations.Benchmark;
24+
import org.openjdk.jmh.annotations.BenchmarkMode;
25+
import org.openjdk.jmh.annotations.Fork;
26+
import org.openjdk.jmh.annotations.Measurement;
27+
import org.openjdk.jmh.annotations.Mode;
28+
import org.openjdk.jmh.annotations.OutputTimeUnit;
29+
import org.openjdk.jmh.annotations.Param;
30+
import org.openjdk.jmh.annotations.Scope;
31+
import org.openjdk.jmh.annotations.Setup;
32+
import org.openjdk.jmh.annotations.State;
33+
import org.openjdk.jmh.annotations.TearDown;
34+
import org.openjdk.jmh.annotations.Warmup;
35+
import org.openjdk.jmh.infra.Blackhole;
36+
37+
import java.io.IOException;
38+
import java.nio.file.Files;
39+
import java.util.concurrent.ThreadLocalRandom;
40+
import java.util.concurrent.TimeUnit;
41+
42+
@BenchmarkMode(Mode.Throughput)
43+
@OutputTimeUnit(TimeUnit.MILLISECONDS)
44+
@State(Scope.Benchmark)
45+
// first iteration is complete garbage, so make sure we really warmup
46+
@Warmup(iterations = 4, time = 1)
47+
// real iterations. not useful to spend tons of time here, better to fork more
48+
@Measurement(iterations = 5, time = 1)
49+
// engage some noise reduction
50+
@Fork(value = 1)
51+
public class Int7ScorerBenchmark {
52+
53+
static {
54+
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
55+
}
56+
57+
@Param({ "384", "782", "1024" })
58+
int dims;
59+
60+
int numVectors = 20 * ES92Int7VectorsScorer.BULK_SIZE;
61+
int numQueries = 5;
62+
63+
byte[] scratch;
64+
byte[][] binaryVectors;
65+
byte[][] binaryQueries;
66+
float[] scores = new float[ES92Int7VectorsScorer.BULK_SIZE];
67+
68+
ES92Int7VectorsScorer scorer;
69+
Directory dir;
70+
IndexInput in;
71+
72+
OptimizedScalarQuantizer.QuantizationResult queryCorrections;
73+
float centroidDp;
74+
75+
@Setup
76+
public void setup() throws IOException {
77+
binaryVectors = new byte[numVectors][dims];
78+
dir = new MMapDirectory(Files.createTempDirectory("vectorData"));
79+
try (IndexOutput out = dir.createOutput("vectors", IOContext.DEFAULT)) {
80+
for (byte[] binaryVector : binaryVectors) {
81+
for (int i = 0; i < dims; i++) {
82+
// 4-bit quantization
83+
binaryVector[i] = (byte) ThreadLocalRandom.current().nextInt(128);
84+
}
85+
out.writeBytes(binaryVector, 0, binaryVector.length);
86+
ThreadLocalRandom.current().nextBytes(binaryVector);
87+
out.writeBytes(binaryVector, 0, 16); // corrections
88+
}
89+
}
90+
91+
queryCorrections = new OptimizedScalarQuantizer.QuantizationResult(
92+
ThreadLocalRandom.current().nextFloat(),
93+
ThreadLocalRandom.current().nextFloat(),
94+
ThreadLocalRandom.current().nextFloat(),
95+
Short.toUnsignedInt((short) ThreadLocalRandom.current().nextInt())
96+
);
97+
centroidDp = ThreadLocalRandom.current().nextFloat();
98+
99+
in = dir.openInput("vectors", IOContext.DEFAULT);
100+
binaryQueries = new byte[numVectors][dims];
101+
for (byte[] binaryVector : binaryVectors) {
102+
for (int i = 0; i < dims; i++) {
103+
// 7-bit quantization
104+
binaryVector[i] = (byte) ThreadLocalRandom.current().nextInt(128);
105+
}
106+
}
107+
108+
scratch = new byte[dims];
109+
scorer = ESVectorizationProvider.getInstance().newES92Int7VectorsScorer(in, dims);
110+
}
111+
112+
@TearDown
113+
public void teardown() throws IOException {
114+
IOUtils.close(dir, in);
115+
}
116+
117+
@Benchmark
118+
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
119+
public void scoreFromMemorySegment(Blackhole bh) throws IOException {
120+
for (int j = 0; j < numQueries; j++) {
121+
in.seek(0);
122+
for (int i = 0; i < numVectors; i++) {
123+
bh.consume(
124+
scorer.score(
125+
binaryQueries[j],
126+
queryCorrections.lowerInterval(),
127+
queryCorrections.upperInterval(),
128+
queryCorrections.quantizedComponentSum(),
129+
queryCorrections.additionalCorrection(),
130+
VectorSimilarityFunction.EUCLIDEAN,
131+
centroidDp
132+
)
133+
);
134+
}
135+
}
136+
}
137+
138+
@Benchmark
139+
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
140+
public void scoreFromMemorySegmentBulk(Blackhole bh) throws IOException {
141+
for (int j = 0; j < numQueries; j++) {
142+
in.seek(0);
143+
for (int i = 0; i < numVectors; i += ES91Int4VectorsScorer.BULK_SIZE) {
144+
scorer.scoreBulk(
145+
binaryQueries[j],
146+
queryCorrections.lowerInterval(),
147+
queryCorrections.upperInterval(),
148+
queryCorrections.quantizedComponentSum(),
149+
queryCorrections.additionalCorrection(),
150+
VectorSimilarityFunction.EUCLIDEAN,
151+
centroidDp,
152+
scores
153+
);
154+
for (float score : scores) {
155+
bh.consume(score);
156+
}
157+
}
158+
}
159+
}
160+
}
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
package org.elasticsearch.simdvec;
10+
11+
import org.apache.lucene.index.VectorSimilarityFunction;
12+
import org.apache.lucene.store.IndexInput;
13+
import org.apache.lucene.util.VectorUtil;
14+
15+
import java.io.IOException;
16+
17+
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
18+
import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
19+
20+
/**
21+
* Scorer for 7 bit quantized vectors stored in a {@link IndexInput}.
22+
* Queries are expected to be quantized using 7 bits as well.
23+
* */
24+
public class ES92Int7VectorsScorer {
25+
26+
public static final int BULK_SIZE = 16;
27+
protected static final float SEVEN_BIT_SCALE = 1f / ((1 << 7) - 1);
28+
29+
/** The wrapper {@link IndexInput}. */
30+
protected final IndexInput in;
31+
protected final int dimensions;
32+
33+
private final float[] lowerIntervals = new float[BULK_SIZE];
34+
private final float[] upperIntervals = new float[BULK_SIZE];
35+
private final int[] targetComponentSums = new int[BULK_SIZE];
36+
private final float[] additionalCorrections = new float[BULK_SIZE];
37+
38+
/** Sole constructor, called by sub-classes. */
39+
public ES92Int7VectorsScorer(IndexInput in, int dimensions) {
40+
this.in = in;
41+
this.dimensions = dimensions;
42+
}
43+
44+
/**
45+
* compute the quantize distance between the provided quantized query and the quantized vector
46+
* that is read from the wrapped {@link IndexInput}.
47+
*/
48+
public long int7DotProduct(byte[] b) throws IOException {
49+
int total = 0;
50+
for (int i = 0; i < dimensions; i++) {
51+
total += in.readByte() * b[i];
52+
}
53+
return total;
54+
}
55+
56+
/**
57+
* compute the quantize distance between the provided quantized query and the quantized vectors
58+
* that are read from the wrapped {@link IndexInput}. The number of quantized vectors to read is
59+
* determined by {code count} and the results are stored in the provided {@code scores} array.
60+
*/
61+
public void int7DotProductBulk(byte[] b, int count, float[] scores) throws IOException {
62+
for (int i = 0; i < count; i++) {
63+
scores[i] = int7DotProduct(b);
64+
}
65+
}
66+
67+
/**
68+
* Computes the score by applying the necessary corrections to the provided quantized distance.
69+
*/
70+
public float score(
71+
byte[] q,
72+
float queryLowerInterval,
73+
float queryUpperInterval,
74+
int queryComponentSum,
75+
float queryAdditionalCorrection,
76+
VectorSimilarityFunction similarityFunction,
77+
float centroidDp
78+
) throws IOException {
79+
float score = int7DotProduct(q);
80+
in.readFloats(lowerIntervals, 0, 3);
81+
int addition = in.readInt();
82+
return applyCorrections(
83+
queryLowerInterval,
84+
queryUpperInterval,
85+
queryComponentSum,
86+
queryAdditionalCorrection,
87+
similarityFunction,
88+
centroidDp,
89+
lowerIntervals[0],
90+
lowerIntervals[1],
91+
addition,
92+
lowerIntervals[2],
93+
score
94+
);
95+
}
96+
97+
/**
98+
* compute the distance between the provided quantized query and the quantized vectors that are
99+
* read from the wrapped {@link IndexInput}.
100+
*
101+
* <p>The number of vectors to score is defined by {@link #BULK_SIZE}. The expected format of the
102+
* input is as follows: First the quantized vectors are read from the input,then all the lower
103+
* intervals as floats, then all the upper intervals as floats, then all the target component sums
104+
* as shorts, and finally all the additional corrections as floats.
105+
*
106+
* <p>The results are stored in the provided scores array.
107+
*/
108+
public void scoreBulk(
109+
byte[] q,
110+
float queryLowerInterval,
111+
float queryUpperInterval,
112+
int queryComponentSum,
113+
float queryAdditionalCorrection,
114+
VectorSimilarityFunction similarityFunction,
115+
float centroidDp,
116+
float[] scores
117+
) throws IOException {
118+
int7DotProductBulk(q, BULK_SIZE, scores);
119+
in.readFloats(lowerIntervals, 0, BULK_SIZE);
120+
in.readFloats(upperIntervals, 0, BULK_SIZE);
121+
in.readInts(targetComponentSums, 0, BULK_SIZE);
122+
in.readFloats(additionalCorrections, 0, BULK_SIZE);
123+
for (int i = 0; i < BULK_SIZE; i++) {
124+
scores[i] = applyCorrections(
125+
queryLowerInterval,
126+
queryUpperInterval,
127+
queryComponentSum,
128+
queryAdditionalCorrection,
129+
similarityFunction,
130+
centroidDp,
131+
lowerIntervals[i],
132+
upperIntervals[i],
133+
targetComponentSums[i],
134+
additionalCorrections[i],
135+
scores[i]
136+
);
137+
}
138+
}
139+
140+
/**
141+
* Computes the score by applying the necessary corrections to the provided quantized distance.
142+
*/
143+
public float applyCorrections(
144+
float queryLowerInterval,
145+
float queryUpperInterval,
146+
int queryComponentSum,
147+
float queryAdditionalCorrection,
148+
VectorSimilarityFunction similarityFunction,
149+
float centroidDp,
150+
float lowerInterval,
151+
float upperInterval,
152+
int targetComponentSum,
153+
float additionalCorrection,
154+
float qcDist
155+
) {
156+
float ax = lowerInterval;
157+
// Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
158+
float lx = (upperInterval - ax) * SEVEN_BIT_SCALE;
159+
float ay = queryLowerInterval;
160+
float ly = (queryUpperInterval - ay) * SEVEN_BIT_SCALE;
161+
float y1 = queryComponentSum;
162+
float score = ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
163+
// For euclidean, we need to invert the score and apply the additional correction, which is
164+
// assumed to be the squared l2norm of the centroid centered vectors.
165+
if (similarityFunction == EUCLIDEAN) {
166+
score = queryAdditionalCorrection + additionalCorrection - 2 * score;
167+
return Math.max(1 / (1f + score), 0);
168+
} else {
169+
// For cosine and max inner product, we need to apply the additional correction, which is
170+
// assumed to be the non-centered dot-product between the vector and the centroid
171+
score += queryAdditionalCorrection + additionalCorrection - centroidDp;
172+
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
173+
return VectorUtil.scaleMaxInnerProductScore(score);
174+
}
175+
return Math.max((1f + score) / 2f, 0);
176+
}
177+
}
178+
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ public static ES91Int4VectorsScorer getES91Int4VectorsScorer(IndexInput input, i
5151
return ESVectorizationProvider.getInstance().newES91Int4VectorsScorer(input, dimension);
5252
}
5353

54+
public static ES92Int7VectorsScorer getES92Int7VectorsScorer(IndexInput input, int dimension) throws IOException {
55+
return ESVectorizationProvider.getInstance().newES92Int7VectorsScorer(input, dimension);
56+
}
57+
5458
public static long ipByteBinByte(byte[] q, byte[] d) {
5559
if (q.length != d.length * B_QUERY) {
5660
throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + B_QUERY + " x " + d.length);

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@
1212
import org.apache.lucene.store.IndexInput;
1313
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
1414
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
15-
16-
import java.io.IOException;
15+
import org.elasticsearch.simdvec.ES92Int7VectorsScorer;
1716

1817
final class DefaultESVectorizationProvider extends ESVectorizationProvider {
1918
private final ESVectorUtilSupport vectorUtilSupport;
@@ -28,12 +27,17 @@ public ESVectorUtilSupport getVectorUtilSupport() {
2827
}
2928

3029
@Override
31-
public ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException {
30+
public ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) {
3231
return new ES91OSQVectorsScorer(input, dimension);
3332
}
3433

3534
@Override
36-
public ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException {
35+
public ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dimension) {
3736
return new ES91Int4VectorsScorer(input, dimension);
3837
}
38+
39+
@Override
40+
public ES92Int7VectorsScorer newES92Int7VectorsScorer(IndexInput input, int dimension) {
41+
return new ES92Int7VectorsScorer(input, dimension);
42+
}
3943
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.apache.lucene.store.IndexInput;
1313
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
1414
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
15+
import org.elasticsearch.simdvec.ES92Int7VectorsScorer;
1516

1617
import java.io.IOException;
1718
import java.util.Objects;
@@ -35,6 +36,9 @@ public static ESVectorizationProvider getInstance() {
3536
/** Create a new {@link ES91Int4VectorsScorer} for the given {@link IndexInput}. */
3637
public abstract ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException;
3738

39+
/** Create a new {@link ES92Int7VectorsScorer} for the given {@link IndexInput}. */
40+
public abstract ES92Int7VectorsScorer newES92Int7VectorsScorer(IndexInput input, int dimension) throws IOException;
41+
3842
// visible for tests
3943
static ESVectorizationProvider lookup(boolean testMode) {
4044
return new DefaultESVectorizationProvider();

0 commit comments

Comments
 (0)