Skip to content

Commit ef9128f

Browse files
authored
Add support for uint8 distance comparison (#15148)
Add distance functions that treat `byte[]` as unsigned and use them in `ScalarQuantizer` code paths. `ScalarQuantizer` assumes that all math will be unsigned, but this can't be true when all 8 bits of a `byte` are used. In all cases where this is used, values are widened to `short` or `int` before any operations occur, so we can perform unsigned widening instead. Fixes are required in a few other places to support unsigned extension; in particular, this is required when recalculating the vector offset as part of merging. This will be useful for supporting 8 bit OSQ in #15064.
1 parent 2550eb2 commit ef9128f

File tree

14 files changed

+229
-48
lines changed

14 files changed

+229
-48
lines changed

lucene/CHANGES.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ New Features
123123

124124
Improvements
125125
---------------------
126-
(No changes)
126+
# GITHUB#15148: Add support for uint8 distance and allow 8 bit scalar quantization (Trevor McCulloch)
127127

128128
Optimizations
129129
---------------------

lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorUtilBenchmark.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,17 @@ public int binaryDotProductVector() {
119119
return VectorUtil.dotProduct(bytesA, bytesB);
120120
}
121121

122+
@Benchmark
123+
public int binaryDotProductUint8Scalar() {
124+
return VectorUtil.uint8DotProduct(bytesA, bytesB);
125+
}
126+
127+
@Benchmark
128+
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
129+
public int binaryDotProductUint8Vector() {
130+
return VectorUtil.uint8DotProduct(bytesA, bytesB);
131+
}
132+
122133
@Benchmark
123134
public int binarySquareScalar() {
124135
return VectorUtil.squareDistance(bytesA, bytesB);
@@ -130,6 +141,17 @@ public int binarySquareVector() {
130141
return VectorUtil.squareDistance(bytesA, bytesB);
131142
}
132143

144+
@Benchmark
145+
public int binarySquareUint8Scalar() {
146+
return VectorUtil.uint8SquareDistance(bytesA, bytesB);
147+
}
148+
149+
@Benchmark
150+
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
151+
public int binarySquareUint8Vector() {
152+
return VectorUtil.uint8SquareDistance(bytesA, bytesB);
153+
}
154+
133155
@Benchmark
134156
public int binaryHalfByteScalar() {
135157
return VectorUtil.int4DotProduct(halfBytesA, halfBytesB);

lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ private Euclidean(QuantizedByteVectorValues values, float constMultiplier, byte[
156156
@Override
157157
public float score(int node) throws IOException {
158158
byte[] nodeVector = values.vectorValue(node);
159-
int squareDistance = VectorUtil.squareDistance(nodeVector, targetBytes);
159+
int squareDistance = VectorUtil.uint8SquareDistance(nodeVector, targetBytes);
160160
float adjustedDistance = squareDistance * constMultiplier;
161161
return 1 / (1f + adjustedDistance);
162162
}
@@ -194,8 +194,9 @@ public DotProduct(
194194
public float score(int vectorOrdinal) throws IOException {
195195
byte[] storedVector = values.vectorValue(vectorOrdinal);
196196
float vectorOffset = values.getScoreCorrectionConstant(vectorOrdinal);
197-
int dotProduct = VectorUtil.dotProduct(storedVector, targetBytes);
198-
// For the current implementation of scalar quantization, all dotproducts should be >= 0;
197+
int dotProduct = VectorUtil.uint8DotProduct(storedVector, targetBytes);
198+
// For the current implementation of scalar quantization, all dotproducts should
199+
// be >= 0;
199200
assert dotProduct >= 0;
200201
float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset;
201202
return scoreAdjustmentFunction.apply(adjustedDistance);
@@ -208,9 +209,10 @@ public void setScoringOrdinal(int node) throws IOException {
208209
}
209210
}
210211

211-
// TODO consider splitting this into two classes. right now the "query" vector is always
212+
// TODO consider splitting this into two classes. right now the "query" vector
213+
// is always
212214
// decompressed
213-
// it could stay compressed if we had a compressed version of the target vector
215+
// it could stay compressed if we had a compressed version of the target vector
214216
private static class CompressedInt4DotProduct
215217
extends UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer {
216218
private final float constMultiplier;
@@ -237,13 +239,15 @@ private CompressedInt4DotProduct(
237239

238240
@Override
239241
public float score(int vectorOrdinal) throws IOException {
240-
// get compressed vector, in Lucene99, vector values are stored and have a single value for
242+
// get compressed vector, in Lucene99, vector values are stored and have a
243+
// single value for
241244
// offset correction
242245
values.getSlice().seek((long) vectorOrdinal * (values.getVectorByteLength() + Float.BYTES));
243246
values.getSlice().readBytes(compressedVector, 0, compressedVector.length);
244247
float vectorOffset = values.getScoreCorrectionConstant(vectorOrdinal);
245248
int dotProduct = VectorUtil.int4DotProductPacked(targetBytes, compressedVector);
246-
// For the current implementation of scalar quantization, all dotproducts should be >= 0;
249+
// For the current implementation of scalar quantization, all dotproducts should
250+
// be >= 0;
247251
assert dotProduct >= 0;
248252
float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset;
249253
return scoreAdjustmentFunction.apply(adjustedDistance);
@@ -283,7 +287,8 @@ public float score(int vectorOrdinal) throws IOException {
283287
byte[] storedVector = values.vectorValue(vectorOrdinal);
284288
float vectorOffset = values.getScoreCorrectionConstant(vectorOrdinal);
285289
int dotProduct = VectorUtil.int4DotProduct(storedVector, targetBytes);
286-
// For the current implementation of scalar quantization, all dotproducts should be >= 0;
290+
// For the current implementation of scalar quantization, all dotproducts should
291+
// be >= 0;
287292
assert dotProduct >= 0;
288293
float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset;
289294
return scoreAdjustmentFunction.apply(adjustedDistance);

lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,8 @@
3434
public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
3535

3636
// The bits that are allowed for scalar quantization
37-
// We only allow signed byte (7), and half-byte (4)
38-
// NOTE: we used to allow 8 bits as well, but it was broken so we removed it
39-
// (https://github.com/apache/lucene/issues/13519)
40-
private static final int ALLOWED_BITS = (1 << 7) | (1 << 4);
37+
// We only allow unsigned byte (8), signed byte (7), and half-byte (4)
38+
private static final int ALLOWED_BITS = (1 << 8) | (1 << 7) | (1 << 4);
4139
public static final String QUANTIZED_VECTOR_COMPONENT = "QVEC";
4240

4341
public static final String NAME = "Lucene99ScalarQuantizedVectorsFormat";

lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,15 @@ public int dotProduct(byte[] a, byte[] b) {
154154
return total;
155155
}
156156

157+
@Override
158+
public int uint8DotProduct(byte[] a, byte[] b) {
159+
int total = 0;
160+
for (int i = 0; i < a.length; i++) {
161+
total += Byte.toUnsignedInt(a[i]) * Byte.toUnsignedInt(b[i]);
162+
}
163+
return total;
164+
}
165+
157166
@Override
158167
public int int4DotProduct(byte[] a, boolean apacked, byte[] b, boolean bpacked) {
159168
assert (apacked && bpacked) == false;
@@ -201,6 +210,17 @@ public int squareDistance(byte[] a, byte[] b) {
201210
return squareSum;
202211
}
203212

213+
@Override
214+
public int uint8SquareDistance(byte[] a, byte[] b) {
215+
// Note: this will not overflow a signed int if dim <= 2^15, since max((ubyte - ubyte)^2) = 255^2 < 2^16.
216+
int squareSum = 0;
217+
for (int i = 0; i < a.length; i++) {
218+
int diff = Byte.toUnsignedInt(a[i]) - Byte.toUnsignedInt(b[i]);
219+
squareSum += diff * diff;
220+
}
221+
return squareSum;
222+
}
223+
204224
@Override
205225
public int findNextGEQ(int[] buffer, int target, int from, int to) {
206226
for (int i = from; i < to; ++i) {
@@ -281,7 +301,7 @@ float recalculateOffset(byte[] vector, int start, float oldAlpha, float oldMinQu
281301
float correction = 0;
282302
for (int i = start; i < vector.length; i++) {
283303
// undo the old quantization
284-
float v = (oldAlpha * vector[i]) + oldMinQuantile;
304+
float v = (oldAlpha * Byte.toUnsignedInt(vector[i])) + oldMinQuantile;
285305
correction += quantizeFloat(v, null, 0);
286306
}
287307
return correction;

lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ public interface VectorUtilSupport {
3636
/** Returns the dot product computed over signed bytes. */
3737
int dotProduct(byte[] a, byte[] b);
3838

39+
/** Returns the dot product computed as though the bytes were unsigned. */
40+
int uint8DotProduct(byte[] a, byte[] b);
41+
3942
/** Returns the dot product over the computed bytes, assuming the values are int4 encoded. */
4043
int int4DotProduct(byte[] a, boolean apacked, byte[] b, boolean bpacked);
4144

@@ -45,6 +48,9 @@ public interface VectorUtilSupport {
4548
/** Returns the sum of squared differences of the two byte vectors. */
4649
int squareDistance(byte[] a, byte[] b);
4750

51+
/** Returns the sum of squared differences of the two unsigned byte vectors. */
52+
int uint8SquareDistance(byte[] a, byte[] b);
53+
4854
/**
4955
* Given an array {@code buffer} that is sorted between indexes {@code 0} inclusive and {@code to}
5056
* exclusive, find the first array index whose value is greater than or equal to {@code target}.

lucene/core/src/java/org/apache/lucene/util/VectorUtil.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,14 @@ public static int squareDistance(byte[] a, byte[] b) {
113113
return IMPL.squareDistance(a, b);
114114
}
115115

116+
/** Returns the sum of squared differences of the two vectors, where each byte is treated as unsigned. */
117+
public static int uint8SquareDistance(byte[] a, byte[] b) {
118+
if (a.length != b.length) {
119+
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
120+
}
121+
return IMPL.uint8SquareDistance(a, b);
122+
}
123+
116124
/**
117125
* Modifies the argument to be unit length, dividing by its l2-norm. IllegalArgumentException is
118126
* thrown for zero vectors.
@@ -167,6 +175,20 @@ public static int dotProduct(byte[] a, byte[] b) {
167175
return IMPL.dotProduct(a, b);
168176
}
169177

178+
/**
179+
* Dot product over bytes assuming that the values are actually unsigned.
180+
*
181+
* @param a uint8 byte vector
182+
* @param b another uint8 byte vector of the same dimension
183+
* @return the value of the dot product of the two vectors
184+
*/
185+
public static int uint8DotProduct(byte[] a, byte[] b) {
186+
if (a.length != b.length) {
187+
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
188+
}
189+
return IMPL.uint8DotProduct(a, b);
190+
}
191+
170192
public static int int4DotProduct(byte[] a, byte[] b) {
171193
if (a.length != b.length) {
172194
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);

lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedVectorSimilarity.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,12 @@ static ScalarQuantizedVectorSimilarity fromVectorSimilarity(
4242
case EUCLIDEAN -> new Euclidean(constMultiplier);
4343
case COSINE, DOT_PRODUCT ->
4444
new DotProduct(
45-
constMultiplier, bits <= 4 ? VectorUtil::int4DotProduct : VectorUtil::dotProduct);
45+
constMultiplier,
46+
bits <= 4 ? VectorUtil::int4DotProduct : VectorUtil::uint8DotProduct);
4647
case MAXIMUM_INNER_PRODUCT ->
4748
new MaximumInnerProduct(
48-
constMultiplier, bits <= 4 ? VectorUtil::int4DotProduct : VectorUtil::dotProduct);
49+
constMultiplier,
50+
bits <= 4 ? VectorUtil::int4DotProduct : VectorUtil::uint8DotProduct);
4951
};
5052
}
5153

@@ -62,7 +64,7 @@ public Euclidean(float constMultiplier) {
6264
@Override
6365
public float score(
6466
byte[] queryVector, float queryVectorOffset, byte[] storedVector, float vectorOffset) {
65-
int squareDistance = VectorUtil.squareDistance(storedVector, queryVector);
67+
int squareDistance = VectorUtil.uint8SquareDistance(storedVector, queryVector);
6668
float adjustedDistance = squareDistance * constMultiplier;
6769
return 1 / (1f + adjustedDistance);
6870
}

lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ public float recalculateCorrectiveOffset(
165165
public void deQuantize(byte[] src, float[] dest) {
166166
assert src.length == dest.length;
167167
for (int i = 0; i < src.length; i++) {
168-
dest[i] = (alpha * src[i]) + minQuantile;
168+
dest[i] = (alpha * Byte.toUnsignedInt(src[i])) + minQuantile;
169169
}
170170
}
171171

0 commit comments

Comments
 (0)