Skip to content

Commit e717741

Browse files
authored
Adding more bits capability to diskbbq (#135877)
This lays the initial groundwork for supporting more bits. I am thinking it will be fairly restrictive, only allowing particular bit widths for indexing, each paired with a static set of query bit widths. This PR just lays down the initial interfaces for the "next" format; it doesn't really do anything new yet. One significant change is that it no longer pads every vector to be divisible by 64. Maybe we want to do that, but for smaller vectors it's very wasteful. We already support iterating over vectors whose length isn't exactly divisible by the number of bits available in the SIMD implementation on the CPU.
1 parent e5b3064 commit e717741

File tree

15 files changed

+1328
-103
lines changed

15 files changed

+1328
-103
lines changed
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
package org.elasticsearch.simdvec;
10+
11+
import org.apache.lucene.index.VectorSimilarityFunction;
12+
import org.apache.lucene.store.IndexInput;
13+
import org.apache.lucene.util.BitUtil;
14+
import org.apache.lucene.util.VectorUtil;
15+
16+
import java.io.IOException;
17+
18+
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
19+
import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
20+
21+
/** Scorer for quantized vectors stored as an {@link IndexInput}. */
22+
public class ESNextOSQVectorsScorer {
23+
24+
public static final int BULK_SIZE = 16;
25+
26+
protected static final float[] BIT_SCALES = new float[] {
27+
1f,
28+
1f / ((1 << 2) - 1),
29+
1f / ((1 << 3) - 1),
30+
1f / ((1 << 4) - 1),
31+
1f / ((1 << 5) - 1),
32+
1f / ((1 << 6) - 1),
33+
1f / ((1 << 7) - 1),
34+
1f / ((1 << 8) - 1), };
35+
36+
/** The wrapper {@link IndexInput}. */
37+
protected final IndexInput in;
38+
39+
protected final byte queryBits;
40+
protected final byte indexBits;
41+
protected final int length;
42+
protected final int dimensions;
43+
44+
protected final float[] lowerIntervals = new float[BULK_SIZE];
45+
protected final float[] upperIntervals = new float[BULK_SIZE];
46+
protected final int[] targetComponentSums = new int[BULK_SIZE];
47+
protected final float[] additionalCorrections = new float[BULK_SIZE];
48+
49+
/** Sole constructor, called by sub-classes. */
50+
public ESNextOSQVectorsScorer(IndexInput in, byte queryBits, byte indexBits, int dimensions, int dataLength) {
51+
if (queryBits != 4 || indexBits != 1) {
52+
throw new IllegalArgumentException("Only asymmetric 4-bit query and 1-bit index supported");
53+
}
54+
this.in = in;
55+
this.queryBits = queryBits;
56+
this.indexBits = indexBits;
57+
this.dimensions = dimensions;
58+
this.length = dataLength;
59+
}
60+
61+
/**
62+
* compute the quantize distance between the provided quantized query and the quantized vector
63+
* that is read from the wrapped {@link IndexInput}.
64+
*/
65+
public long quantizeScore(byte[] q) throws IOException {
66+
if (indexBits == 1) {
67+
if (queryBits == 4) {
68+
return quantized4BitScore(q);
69+
}
70+
throw new IllegalArgumentException("Only asymmetric 4-bit query supported");
71+
}
72+
throw new IllegalArgumentException("Only 1-bit index supported");
73+
74+
}
75+
76+
private long quantized4BitScore(byte[] q) throws IOException {
77+
assert q.length == length * 4;
78+
final int size = length;
79+
long subRet0 = 0;
80+
long subRet1 = 0;
81+
long subRet2 = 0;
82+
long subRet3 = 0;
83+
int r = 0;
84+
for (final int upperBound = size & -Long.BYTES; r < upperBound; r += Long.BYTES) {
85+
final long value = in.readLong();
86+
subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r) & value);
87+
subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r + size) & value);
88+
subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r + 2 * size) & value);
89+
subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r + 3 * size) & value);
90+
}
91+
for (final int upperBound = size & -Integer.BYTES; r < upperBound; r += Integer.BYTES) {
92+
final int value = in.readInt();
93+
subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r) & value);
94+
subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r + size) & value);
95+
subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r + 2 * size) & value);
96+
subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r + 3 * size) & value);
97+
}
98+
for (; r < size; r++) {
99+
final byte value = in.readByte();
100+
subRet0 += Integer.bitCount((q[r] & value) & 0xFF);
101+
subRet1 += Integer.bitCount((q[r + size] & value) & 0xFF);
102+
subRet2 += Integer.bitCount((q[r + 2 * size] & value) & 0xFF);
103+
subRet3 += Integer.bitCount((q[r + 3 * size] & value) & 0xFF);
104+
}
105+
return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
106+
}
107+
108+
/**
109+
* compute the quantize distance between the provided quantized query and the quantized vectors
110+
* that are read from the wrapped {@link IndexInput}. The number of quantized vectors to read is
111+
* determined by {code count} and the results are stored in the provided {@code scores} array.
112+
*/
113+
public void quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOException {
114+
if (indexBits == 1) {
115+
if (queryBits == 4) {
116+
for (int i = 0; i < count; i++) {
117+
scores[i] = quantizeScore(q);
118+
}
119+
return;
120+
}
121+
throw new IllegalArgumentException("Only asymmetric 4-bit query supported");
122+
}
123+
}
124+
125+
/**
126+
* Computes the score by applying the necessary corrections to the provided quantized distance.
127+
*/
128+
public float score(
129+
float queryLowerInterval,
130+
float queryUpperInterval,
131+
int queryComponentSum,
132+
float queryAdditionalCorrection,
133+
VectorSimilarityFunction similarityFunction,
134+
float centroidDp,
135+
float lowerInterval,
136+
float upperInterval,
137+
int targetComponentSum,
138+
float additionalCorrection,
139+
float qcDist
140+
) {
141+
float ax = lowerInterval;
142+
// Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
143+
float lx = (upperInterval - ax) * BIT_SCALES[indexBits];
144+
float ay = queryLowerInterval;
145+
float ly = (queryUpperInterval - ay) * BIT_SCALES[queryBits];
146+
float y1 = queryComponentSum;
147+
float score = ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
148+
// For euclidean, we need to invert the score and apply the additional correction, which is
149+
// assumed to be the squared l2norm of the centroid centered vectors.
150+
if (similarityFunction == EUCLIDEAN) {
151+
score = queryAdditionalCorrection + additionalCorrection - 2 * score;
152+
return Math.max(1 / (1f + score), 0);
153+
} else {
154+
// For cosine and max inner product, we need to apply the additional correction, which is
155+
// assumed to be the non-centered dot-product between the vector and the centroid
156+
score += queryAdditionalCorrection + additionalCorrection - centroidDp;
157+
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
158+
return VectorUtil.scaleMaxInnerProductScore(score);
159+
}
160+
return Math.max((1f + score) / 2f, 0);
161+
}
162+
}
163+
164+
/**
165+
* compute the distance between the provided quantized query and the quantized vectors that are
166+
* read from the wrapped {@link IndexInput}.
167+
*
168+
* <p>The number of vectors to score is defined by {@link #BULK_SIZE}. The expected format of the
169+
* input is as follows: First the quantized vectors are read from the input,then all the lower
170+
* intervals as floats, then all the upper intervals as floats, then all the target component sums
171+
* as shorts, and finally all the additional corrections as floats.
172+
*
173+
* <p>The results are stored in the provided scores array.
174+
*/
175+
public float scoreBulk(
176+
byte[] q,
177+
float queryLowerInterval,
178+
float queryUpperInterval,
179+
int queryComponentSum,
180+
float queryAdditionalCorrection,
181+
VectorSimilarityFunction similarityFunction,
182+
float centroidDp,
183+
float[] scores
184+
) throws IOException {
185+
quantizeScoreBulk(q, BULK_SIZE, scores);
186+
in.readFloats(lowerIntervals, 0, BULK_SIZE);
187+
in.readFloats(upperIntervals, 0, BULK_SIZE);
188+
for (int i = 0; i < BULK_SIZE; i++) {
189+
targetComponentSums[i] = Short.toUnsignedInt(in.readShort());
190+
}
191+
in.readFloats(additionalCorrections, 0, BULK_SIZE);
192+
float maxScore = Float.NEGATIVE_INFINITY;
193+
for (int i = 0; i < BULK_SIZE; i++) {
194+
scores[i] = score(
195+
queryLowerInterval,
196+
queryUpperInterval,
197+
queryComponentSum,
198+
queryAdditionalCorrection,
199+
similarityFunction,
200+
centroidDp,
201+
lowerIntervals[i],
202+
upperIntervals[i],
203+
targetComponentSums[i],
204+
additionalCorrections[i],
205+
scores[i]
206+
);
207+
if (scores[i] > maxScore) {
208+
maxScore = scores[i];
209+
}
210+
}
211+
return maxScore;
212+
}
213+
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,16 @@ public static ES91OSQVectorsScorer getES91OSQVectorsScorer(IndexInput input, int
4848
return ESVectorizationProvider.getInstance().newES91OSQVectorsScorer(input, dimension);
4949
}
5050

51+
public static ESNextOSQVectorsScorer getESNextOSQVectorsScorer(
52+
IndexInput input,
53+
byte queryBits,
54+
byte indexBits,
55+
int dimension,
56+
int dataLength
57+
) throws IOException {
58+
return ESVectorizationProvider.getInstance().newESNextOSQVectorsScorer(input, queryBits, indexBits, dimension, dataLength);
59+
}
60+
5161
public static ES91Int4VectorsScorer getES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException {
5262
return ESVectorizationProvider.getInstance().newES91Int4VectorsScorer(input, dimension);
5363
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
1414
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
1515
import org.elasticsearch.simdvec.ES92Int7VectorsScorer;
16+
import org.elasticsearch.simdvec.ESNextOSQVectorsScorer;
17+
18+
import java.io.IOException;
1619

1720
final class DefaultESVectorizationProvider extends ESVectorizationProvider {
1821
private final ESVectorUtilSupport vectorUtilSupport;
@@ -31,6 +34,12 @@ public ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimens
3134
return new ES91OSQVectorsScorer(input, dimension);
3235
}
3336

37+
@Override
38+
public ESNextOSQVectorsScorer newESNextOSQVectorsScorer(IndexInput input, byte queryBits, byte indexBits, int dimension, int dataLength)
39+
throws IOException {
40+
return new ESNextOSQVectorsScorer(input, queryBits, indexBits, dimension, dataLength);
41+
}
42+
3443
@Override
3544
public ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dimension) {
3645
return new ES91Int4VectorsScorer(input, dimension);

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
1414
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
1515
import org.elasticsearch.simdvec.ES92Int7VectorsScorer;
16+
import org.elasticsearch.simdvec.ESNextOSQVectorsScorer;
1617

1718
import java.io.IOException;
1819
import java.util.Objects;
@@ -33,6 +34,14 @@ public static ESVectorizationProvider getInstance() {
3334
/** Create a new {@link ES91OSQVectorsScorer} for the given {@link IndexInput}. */
3435
public abstract ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException;
3536

37+
/** Create a new {@link ESNextOSQVectorsScorer} for the given {@link IndexInput}, query/index bit widths, dimension and data length. */
public abstract ESNextOSQVectorsScorer newESNextOSQVectorsScorer(
38+
IndexInput input,
39+
byte queryBits,
40+
byte indexBits,
41+
int dimension,
42+
int dataLength
43+
) throws IOException;
44+
3645
/** Create a new {@link ES91Int4VectorsScorer} for the given {@link IndexInput}. */
3746
public abstract ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException;
3847

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
1717
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
1818
import org.elasticsearch.simdvec.ES92Int7VectorsScorer;
19+
import org.elasticsearch.simdvec.ESNextOSQVectorsScorer;
1920

2021
import java.io.IOException;
2122
import java.util.Locale;
@@ -40,6 +41,14 @@ public static ESVectorizationProvider getInstance() {
4041
/** Create a new {@link ES91OSQVectorsScorer} for the given {@link IndexInput}. */
4142
public abstract ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException;
4243

44+
/** Create a new {@link ESNextOSQVectorsScorer} for the given {@link IndexInput}, query/index bit widths, dimension and data length. */
public abstract ESNextOSQVectorsScorer newESNextOSQVectorsScorer(
45+
IndexInput input,
46+
byte queryBits,
47+
byte indexBits,
48+
int dimension,
49+
int dataLength
50+
) throws IOException;
51+
4352
/** Create a new {@link ES91Int4VectorsScorer} for the given {@link IndexInput}. */
4453
public abstract ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException;
4554

0 commit comments

Comments
 (0)