Skip to content

Commit c925e9b

Browse files
kaivalnpbenwtrent
authored andcommitted
Implement off-heap quantized scoring (#14863)
Off-heap scoring for quantized vectors! Related to #13515. This scorer is in line with [`Lucene99MemorySegmentFlatVectorsScorer`](https://github.com/apache/lucene/blob/77f0d1f6d6762ca6ac9af5acc0c950365050d939/lucene/core/src/java24/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java#L30), and will automatically be used with [`PanamaVectorizationProvider`](https://github.com/apache/lucene/blob/77f0d1f6d6762ca6ac9af5acc0c950365050d939/lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java#L30C13-L30C40) (i.e. on adding `jdk.incubator.vector`). Note that the computations are already vectorized; this change only avoids the unnecessary copy to heap. I added off-heap dot product functions for two compressed 4-bit ints (i.e. no need to "decompress" them) -- I can try to come up with similar ones for Euclidean distance if this approach seems fine.
1 parent b437fff commit c925e9b

File tree

17 files changed

+1113
-217
lines changed

17 files changed

+1113
-217
lines changed

lucene/CHANGES.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ Optimizations
4545
* GITHUB#15160: Increased the size used for blocks of postings from 128 to 256.
4646
This gives a noticeable speedup to many queries. (Adrien Grand)
4747

48+
* GITHUB#14863: Perform scoring for 4, 7, 8 bit quantized vectors off-heap. (Kaival Parikh)
49+
4850
Bug Fixes
4951
---------------------
5052
* GITHUB#14161: PointInSetQuery's constructor now throws IllegalArgumentException

lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorUtilBenchmark.java

Lines changed: 112 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,13 @@ static void compressBytes(byte[] raw, byte[] compressed) {
4343
private byte[] bytesA;
4444
private byte[] bytesB;
4545
private byte[] halfBytesA;
46+
private byte[] halfBytesAPacked;
4647
private byte[] halfBytesB;
4748
private byte[] halfBytesBPacked;
4849
private float[] floatsA;
4950
private float[] floatsB;
50-
private int expectedhalfByteDotProduct;
51+
private int expectedHalfByteDotProduct;
52+
private int expectedHalfByteSquareDistance;
5153

5254
@Param({"1", "128", "207", "256", "300", "512", "702", "1024"})
5355
int size;
@@ -63,16 +65,23 @@ public void init() {
6365
random.nextBytes(bytesB);
6466
// random half byte arrays for binary methods
6567
// this means that all values must be between 0 and 15
66-
expectedhalfByteDotProduct = 0;
68+
expectedHalfByteDotProduct = 0;
69+
expectedHalfByteSquareDistance = 0;
6770
halfBytesA = new byte[size];
6871
halfBytesB = new byte[size];
6972
for (int i = 0; i < size; ++i) {
7073
halfBytesA[i] = (byte) random.nextInt(16);
7174
halfBytesB[i] = (byte) random.nextInt(16);
72-
expectedhalfByteDotProduct += halfBytesA[i] * halfBytesB[i];
75+
expectedHalfByteDotProduct += halfBytesA[i] * halfBytesB[i];
76+
77+
int diff = halfBytesA[i] - halfBytesB[i];
78+
expectedHalfByteSquareDistance += diff * diff;
7379
}
7480
// pack the half byte arrays
7581
if (size % 2 == 0) {
82+
halfBytesAPacked = new byte[(size + 1) >> 1];
83+
compressBytes(halfBytesA, halfBytesAPacked);
84+
7685
halfBytesBPacked = new byte[(size + 1) >> 1];
7786
compressBytes(halfBytesB, halfBytesBPacked);
7887
}
@@ -97,6 +106,74 @@ public float binaryCosineVector() {
97106
return VectorUtil.cosine(bytesA, bytesB);
98107
}
99108

109+
@Benchmark
110+
public int binarySquareScalar() {
111+
return VectorUtil.squareDistance(bytesA, bytesB);
112+
}
113+
114+
@Benchmark
115+
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
116+
public int binarySquareVector() {
117+
return VectorUtil.squareDistance(bytesA, bytesB);
118+
}
119+
120+
@Benchmark
121+
public int binaryHalfByteSquareScalar() {
122+
int v = VectorUtil.int4SquareDistance(halfBytesA, halfBytesB);
123+
if (v != expectedHalfByteSquareDistance) {
124+
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
125+
}
126+
return v;
127+
}
128+
129+
@Benchmark
130+
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
131+
public int binaryHalfByteSquareVector() {
132+
int v = VectorUtil.int4SquareDistance(halfBytesA, halfBytesB);
133+
if (v != expectedHalfByteSquareDistance) {
134+
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
135+
}
136+
return v;
137+
}
138+
139+
@Benchmark
140+
public int binaryHalfByteSquareSinglePackedScalar() {
141+
int v = VectorUtil.int4SquareDistanceSinglePacked(halfBytesA, halfBytesBPacked);
142+
if (v != expectedHalfByteSquareDistance) {
143+
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
144+
}
145+
return v;
146+
}
147+
148+
@Benchmark
149+
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
150+
public int binaryHalfByteSquareSinglePackedVector() {
151+
int v = VectorUtil.int4SquareDistanceSinglePacked(halfBytesA, halfBytesBPacked);
152+
if (v != expectedHalfByteSquareDistance) {
153+
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
154+
}
155+
return v;
156+
}
157+
158+
@Benchmark
159+
public int binaryHalfByteSquareBothPackedScalar() {
160+
int v = VectorUtil.int4SquareDistanceBothPacked(halfBytesAPacked, halfBytesBPacked);
161+
if (v != expectedHalfByteSquareDistance) {
162+
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
163+
}
164+
return v;
165+
}
166+
167+
@Benchmark
168+
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
169+
public int binaryHalfByteSquareBothPackedVector() {
170+
int v = VectorUtil.int4SquareDistanceBothPacked(halfBytesAPacked, halfBytesBPacked);
171+
if (v != expectedHalfByteSquareDistance) {
172+
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
173+
}
174+
return v;
175+
}
176+
100177
@Benchmark
101178
public int binaryDotProductScalar() {
102179
return VectorUtil.dotProduct(bytesA, bytesB);
@@ -120,14 +197,22 @@ public int binaryDotProductUint8Vector() {
120197
}
121198

122199
@Benchmark
123-
public int binarySquareScalar() {
124-
return VectorUtil.squareDistance(bytesA, bytesB);
200+
public int binaryHalfByteDotProductScalar() {
201+
int v = VectorUtil.int4DotProduct(halfBytesA, halfBytesB);
202+
if (v != expectedHalfByteDotProduct) {
203+
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
204+
}
205+
return v;
125206
}
126207

127208
@Benchmark
128209
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
129-
public int binarySquareVector() {
130-
return VectorUtil.squareDistance(bytesA, bytesB);
210+
public int binaryHalfByteDotProductVector() {
211+
int v = VectorUtil.int4DotProduct(halfBytesA, halfBytesB);
212+
if (v != expectedHalfByteDotProduct) {
213+
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
214+
}
215+
return v;
131216
}
132217

133218
@Benchmark
@@ -142,37 +227,39 @@ public int binarySquareUint8Vector() {
142227
}
143228

144229
@Benchmark
145-
public int binaryHalfByteScalar() {
146-
return VectorUtil.int4DotProduct(halfBytesA, halfBytesB);
230+
public int binaryHalfByteDotProductSinglePackedScalar() {
231+
int v = VectorUtil.int4DotProductSinglePacked(halfBytesA, halfBytesBPacked);
232+
if (v != expectedHalfByteDotProduct) {
233+
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
234+
}
235+
return v;
147236
}
148237

149238
@Benchmark
150239
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
151-
public int binaryHalfByteVector() {
152-
return VectorUtil.int4DotProduct(halfBytesA, halfBytesB);
240+
public int binaryHalfByteDotProductSinglePackedVector() {
241+
int v = VectorUtil.int4DotProductSinglePacked(halfBytesA, halfBytesBPacked);
242+
if (v != expectedHalfByteDotProduct) {
243+
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
244+
}
245+
return v;
153246
}
154247

155248
@Benchmark
156-
public int binaryHalfByteScalarPacked() {
157-
if (size % 2 != 0) {
158-
throw new RuntimeException("Size must be even for this benchmark");
159-
}
160-
int v = VectorUtil.int4DotProductPacked(halfBytesA, halfBytesBPacked);
161-
if (v != expectedhalfByteDotProduct) {
162-
throw new RuntimeException("Expected " + expectedhalfByteDotProduct + " but got " + v);
249+
public int binaryHalfByteDotProductBothPackedScalar() {
250+
int v = VectorUtil.int4DotProductBothPacked(halfBytesAPacked, halfBytesBPacked);
251+
if (v != expectedHalfByteDotProduct) {
252+
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
163253
}
164254
return v;
165255
}
166256

167257
@Benchmark
168258
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
169-
public int binaryHalfByteVectorPacked() {
170-
if (size % 2 != 0) {
171-
throw new RuntimeException("Size must be even for this benchmark");
172-
}
173-
int v = VectorUtil.int4DotProductPacked(halfBytesA, halfBytesBPacked);
174-
if (v != expectedhalfByteDotProduct) {
175-
throw new RuntimeException("Expected " + expectedhalfByteDotProduct + " but got " + v);
259+
public int binaryHalfByteDotProductBothPackedVector() {
260+
int v = VectorUtil.int4DotProductBothPacked(halfBytesAPacked, halfBytesBPacked);
261+
if (v != expectedHalfByteDotProduct) {
262+
throw new RuntimeException("Expected " + expectedHalfByteDotProduct + " but got " + v);
176263
}
177264
return v;
178265
}

lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,8 @@ private FlatVectorScorerUtil() {}
3737
public static FlatVectorsScorer getLucene99FlatVectorsScorer() {
3838
return IMPL.getLucene99FlatVectorsScorer();
3939
}
40+
41+
public static FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() {
42+
return IMPL.getLucene99ScalarQuantizedVectorsScorer();
43+
}
4044
}

lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
2424
import org.apache.lucene.index.KnnVectorValues;
2525
import org.apache.lucene.index.VectorSimilarityFunction;
26+
import org.apache.lucene.util.FloatToFloatFunction;
2627
import org.apache.lucene.util.VectorUtil;
2728
import org.apache.lucene.util.hnsw.RandomVectorScorer;
2829
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
@@ -237,7 +238,7 @@ public float score(int vectorOrdinal) throws IOException {
237238
values.getSlice().seek((long) vectorOrdinal * (values.getVectorByteLength() + Float.BYTES));
238239
values.getSlice().readBytes(compressedVector, 0, compressedVector.length);
239240
float vectorOffset = values.getScoreCorrectionConstant(vectorOrdinal);
240-
int dotProduct = VectorUtil.int4DotProductPacked(targetBytes, compressedVector);
241+
int dotProduct = VectorUtil.int4DotProductSinglePacked(targetBytes, compressedVector);
241242
// For the current implementation of scalar quantization, all dotproducts should
242243
// be >= 0;
243244
assert dotProduct >= 0;
@@ -293,11 +294,6 @@ public void setScoringOrdinal(int node) throws IOException {
293294
}
294295
}
295296

296-
@FunctionalInterface
297-
private interface FloatToFloatFunction {
298-
float apply(float f);
299-
}
300-
301297
private static final class ScalarQuantizedRandomVectorScorerSupplier
302298
implements RandomVectorScorerSupplier {
303299

lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
package org.apache.lucene.codecs.lucene99;
1919

2020
import java.io.IOException;
21-
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
2221
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
2322
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
2423
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
24+
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
2525
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
2626
import org.apache.lucene.index.SegmentReadState;
2727
import org.apache.lucene.index.SegmentWriteState;
@@ -68,7 +68,7 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
6868

6969
final byte bits;
7070
final boolean compress;
71-
final Lucene99ScalarQuantizedVectorScorer flatVectorScorer;
71+
final FlatVectorsScorer flatVectorScorer;
7272

7373
/** Constructs a format using default graph construction parameters */
7474
public Lucene99ScalarQuantizedVectorsFormat() {
@@ -115,8 +115,7 @@ public Lucene99ScalarQuantizedVectorsFormat(
115115
this.bits = (byte) bits;
116116
this.confidenceInterval = confidenceInterval;
117117
this.compress = compress;
118-
this.flatVectorScorer =
119-
new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);
118+
this.flatVectorScorer = FlatVectorScorerUtil.getLucene99ScalarQuantizedVectorsScorer();
120119
}
121120

122121
public static float calculateDefaultConfidenceInterval(int vectorDimension) {

lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java

Lines changed: 62 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -162,24 +162,35 @@ public int uint8DotProduct(byte[] a, byte[] b) {
162162
}
163163

164164
@Override
165-
public int int4DotProduct(byte[] a, boolean apacked, byte[] b, boolean bpacked) {
166-
assert (apacked && bpacked) == false;
167-
if (apacked || bpacked) {
168-
byte[] packed = apacked ? a : b;
169-
byte[] unpacked = apacked ? b : a;
170-
int total = 0;
171-
for (int i = 0; i < packed.length; i++) {
172-
byte packedByte = packed[i];
173-
byte unpacked1 = unpacked[i];
174-
byte unpacked2 = unpacked[i + packed.length];
175-
total += (packedByte & 0x0F) * unpacked2;
176-
total += ((packedByte & 0xFF) >> 4) * unpacked1;
177-
}
178-
return total;
179-
}
165+
public int int4DotProduct(byte[] a, byte[] b) {
180166
return dotProduct(a, b);
181167
}
182168

169+
@Override
170+
public int int4DotProductSinglePacked(byte[] unpacked, byte[] packed) {
171+
int total = 0;
172+
for (int i = 0; i < packed.length; i++) {
173+
byte packedByte = packed[i];
174+
byte unpacked1 = unpacked[i];
175+
byte unpacked2 = unpacked[i + packed.length];
176+
total += (packedByte & 0x0F) * unpacked2;
177+
total += ((packedByte & 0xFF) >> 4) * unpacked1;
178+
}
179+
return total;
180+
}
181+
182+
@Override
183+
public int int4DotProductBothPacked(byte[] a, byte[] b) {
184+
int total = 0;
185+
for (int i = 0; i < a.length; i++) {
186+
byte aByte = a[i];
187+
byte bByte = b[i];
188+
total += (aByte & 0x0F) * (bByte & 0x0F);
189+
total += ((aByte & 0xFF) >> 4) * ((bByte & 0xFF) >> 4);
190+
}
191+
return total;
192+
}
193+
183194
@Override
184195
public float cosine(byte[] a, byte[] b) {
185196
// Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14.
@@ -208,6 +219,42 @@ public int squareDistance(byte[] a, byte[] b) {
208219
return squareSum;
209220
}
210221

222+
@Override
223+
public int int4SquareDistance(byte[] a, byte[] b) {
224+
return squareDistance(a, b);
225+
}
226+
227+
@Override
228+
public int int4SquareDistanceSinglePacked(byte[] unpacked, byte[] packed) {
229+
int total = 0;
230+
for (int i = 0; i < packed.length; i++) {
231+
byte packedByte = packed[i];
232+
byte unpacked1 = unpacked[i];
233+
byte unpacked2 = unpacked[i + packed.length];
234+
235+
int diff1 = (packedByte & 0x0F) - unpacked2;
236+
int diff2 = ((packedByte & 0xFF) >> 4) - unpacked1;
237+
238+
total += diff1 * diff1 + diff2 * diff2;
239+
}
240+
return total;
241+
}
242+
243+
@Override
244+
public int int4SquareDistanceBothPacked(byte[] a, byte[] b) {
245+
int total = 0;
246+
for (int i = 0; i < a.length; i++) {
247+
byte aByte = a[i];
248+
byte bByte = b[i];
249+
250+
int diff1 = (aByte & 0x0F) - (bByte & 0x0F);
251+
int diff2 = ((aByte & 0xFF) >> 4) - ((bByte & 0xFF) >> 4);
252+
253+
total += diff1 * diff1 + diff2 * diff2;
254+
}
255+
return total;
256+
}
257+
211258
@Override
212259
public int uint8SquareDistance(byte[] a, byte[] b) {
213260
// Note: this will not overflow if dim < 2^16, since max(ubyte * ubyte) = 2^16.

lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
2121
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
22+
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorScorer;
2223
import org.apache.lucene.store.IndexInput;
2324

2425
/** Default provider returning scalar implementations. */
@@ -40,6 +41,11 @@ public FlatVectorsScorer getLucene99FlatVectorsScorer() {
4041
return DefaultFlatVectorScorer.INSTANCE;
4142
}
4243

44+
@Override
45+
public FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() {
46+
return new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);
47+
}
48+
4349
@Override
4450
public PostingDecodingUtil newPostingDecodingUtil(IndexInput input) {
4551
return new PostingDecodingUtil(input);

0 commit comments

Comments
 (0)