Skip to content

Commit 80d91ce

Browse files
authored
Initial 2 bit support for DiskBBQ next (#136989)
This adds initial 2-bit support to DiskBBQ next. In my correctness testing it improved recall by roughly 5-10%, depending on the data. I did not implement the SIMD paths here, as that would make this change get out of hand. However, I expect each vector operation to be right around 2x as slow as the binary one, since it is effectively two binary operations. There might be a faster way to encode 2-bit values and then operate on them. Marking as non-issue, as this is an unreleased format change.
1 parent ced3ca8 commit 80d91ce

File tree

15 files changed

+265
-35
lines changed

15 files changed

+265
-35
lines changed

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESNextOSQVectorsScorer.java

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ public class ESNextOSQVectorsScorer {
4848

4949
/** Sole constructor, called by sub-classes. */
5050
public ESNextOSQVectorsScorer(IndexInput in, byte queryBits, byte indexBits, int dimensions, int dataLength) {
51-
if (queryBits != 4 || indexBits != 1) {
52-
throw new IllegalArgumentException("Only asymmetric 4-bit query and 1-bit index supported");
51+
if (queryBits != 4 || (indexBits != 1 && indexBits != 2)) {
52+
throw new IllegalArgumentException("Only asymmetric 4-bit query and 1 or 2-bit index supported");
5353
}
5454
this.in = in;
5555
this.queryBits = queryBits;
@@ -65,15 +65,27 @@ public ESNextOSQVectorsScorer(IndexInput in, byte queryBits, byte indexBits, int
6565
public long quantizeScore(byte[] q) throws IOException {
6666
if (indexBits == 1) {
6767
if (queryBits == 4) {
68-
return quantized4BitScore(q);
68+
return quantized4BitScore(q, length);
6969
}
7070
throw new IllegalArgumentException("Only asymmetric 4-bit query supported");
7171
}
72+
if (indexBits == 2) {
73+
if (queryBits == 4) {
74+
return quantized4BitScore2BitIndex(q);
75+
}
76+
}
7277
throw new IllegalArgumentException("Only 1-bit index supported");
78+
}
7379

80+
private long quantized4BitScore2BitIndex(byte[] q) throws IOException {
81+
assert q.length == length * 2;
82+
assert length % 2 == 0 : "length must be even for 2-bit index length: " + length + " dimensions: " + dimensions;
83+
int lower = (int) quantized4BitScore(q, length / 2);
84+
int upper = (int) quantized4BitScore(q, length / 2);
85+
return lower + ((long) upper << 1);
7486
}
7587

76-
private long quantized4BitScore(byte[] q) throws IOException {
88+
private long quantized4BitScore(byte[] q, int length) throws IOException {
7789
assert q.length == length * 4;
7890
final int size = length;
7991
long subRet0 = 0;
@@ -120,6 +132,15 @@ public void quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOExce
120132
}
121133
throw new IllegalArgumentException("Only asymmetric 4-bit query supported");
122134
}
135+
if (indexBits == 2) {
136+
if (queryBits == 4) {
137+
for (int i = 0; i < count; i++) {
138+
scores[i] = quantizeScore(q);
139+
}
140+
return;
141+
}
142+
throw new IllegalArgumentException("Only asymmetric 4-bit query supported");
143+
}
123144
}
124145

125146
/**
@@ -140,9 +161,9 @@ public float score(
140161
) {
141162
float ax = lowerInterval;
142163
// Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
143-
float lx = (upperInterval - ax) * BIT_SCALES[indexBits];
164+
float lx = (upperInterval - ax) * BIT_SCALES[indexBits - 1];
144165
float ay = queryLowerInterval;
145-
float ly = (queryUpperInterval - ay) * BIT_SCALES[queryBits];
166+
float ly = (queryUpperInterval - ay) * BIT_SCALES[queryBits - 1];
146167
float y1 = queryComponentSum;
147168
float score = ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
148169
// For euclidean, we need to invert the score and apply the additional correction, which is

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,13 @@ public static void packAsBinary(int[] vector, byte[] packed) {
393393
IMPL.packAsBinary(vector, packed);
394394
}
395395

396+
public static void packDibit(int[] vector, byte[] packed) {
397+
if (packed.length * Byte.SIZE / 2 < vector.length) {
398+
throw new IllegalArgumentException("packed array is too small: " + packed.length * Byte.SIZE / 2 + " < " + vector.length);
399+
}
400+
IMPL.packDibit(vector, packed);
401+
}
402+
396403
/**
397404
* The idea here is to organize the query vector bits such that the first bit
398405
* of every dimension is in the first set dimensions bits, or (dimensions/8) bytes. The second,

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,11 +321,58 @@ public void soarDistanceBulk(
321321
distances[3] = soarDistance(v1, c3, originalResidual, soarLambda, rnorm);
322322
}
323323

324+
/** Packs 2-bit values by delegating to the scalar {@code packDibitImpl}. */
@Override
325+
public void packDibit(int[] vector, byte[] packed) {
326+
packDibitImpl(vector, packed);
327+
}
328+
324329
@Override
325330
public void packAsBinary(int[] vector, byte[] packed) {
326331
packAsBinaryImpl(vector, packed);
327332
}
328333

334+
/**
335+
* Packs two bit vector (values 0-3) into a byte array with lower bits first.
336+
* The striding is similar to transposeHalfByte
337+
*
338+
* @param vector the input vector with values 0-3
339+
* @param packed the output packed byte array
340+
*/
341+
public static void packDibitImpl(int[] vector, byte[] packed) {
342+
int limit = vector.length - 7;
343+
int i = 0;
344+
int index = 0;
345+
for (; i < limit; i += 8, index++) {
346+
assert vector[i] >= 0 && vector[i] <= 3;
347+
assert vector[i + 1] >= 0 && vector[i + 1] <= 3;
348+
assert vector[i + 2] >= 0 && vector[i + 2] <= 3;
349+
assert vector[i + 3] >= 0 && vector[i + 3] <= 3;
350+
assert vector[i + 4] >= 0 && vector[i + 4] <= 3;
351+
assert vector[i + 5] >= 0 && vector[i + 5] <= 3;
352+
assert vector[i + 6] >= 0 && vector[i + 6] <= 3;
353+
assert vector[i + 7] >= 0 && vector[i + 7] <= 3;
354+
int lowerByte = (vector[i] & 1) << 7 | (vector[i + 1] & 1) << 6 | (vector[i + 2] & 1) << 5 | (vector[i + 3] & 1) << 4
355+
| (vector[i + 4] & 1) << 3 | (vector[i + 5] & 1) << 2 | (vector[i + 6] & 1) << 1 | (vector[i + 7] & 1);
356+
int upperByte = ((vector[i] >> 1) & 1) << 7 | ((vector[i + 1] >> 1) & 1) << 6 | ((vector[i + 2] >> 1) & 1) << 5 | ((vector[i
357+
+ 3] >> 1) & 1) << 4 | ((vector[i + 4] >> 1) & 1) << 3 | ((vector[i + 5] >> 1) & 1) << 2 | ((vector[i + 6] >> 1) & 1) << 1
358+
| ((vector[i + 7] >> 1) & 1);
359+
packed[index] = (byte) lowerByte;
360+
packed[index + packed.length / 2] = (byte) upperByte;
361+
}
362+
if (i == vector.length) {
363+
return;
364+
}
365+
int lowerByte = 0;
366+
int upperByte = 0;
367+
for (int j = 7; i < vector.length; j--, i++) {
368+
assert vector[i] >= 0 && vector[i] <= 3;
369+
lowerByte |= (vector[i] & 1) << j;
370+
upperByte |= ((vector[i] >> 1) & 1) << j;
371+
}
372+
packed[index] = (byte) lowerByte;
373+
packed[index + packed.length / 2] = (byte) upperByte;
374+
}
375+
329376
public static void packAsBinaryImpl(int[] vector, byte[] packed) {
330377
int limit = vector.length - 7;
331378
int i = 0;

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ void soarDistanceBulk(
6666

6767
void packAsBinary(int[] vector, byte[] packed);
6868

69+
void packDibit(int[] vector, byte[] packed);
70+
6971
void transposeHalfByte(int[] q, byte[] quantQueryByte);
7072

7173
int indexOf(byte[] bytes, int offset, int length, byte marker);

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,6 +1023,12 @@ private void packAsBinary128(int[] vector, byte[] packed) {
10231023
packed[index] = result;
10241024
}
10251025

1026+
/** Packs 2-bit values; currently delegates to the scalar {@code DefaultESVectorUtilSupport.packDibitImpl}. */
@Override
1027+
public void packDibit(int[] vector, byte[] packed) {
1028+
// TODO: no dedicated implementation here yet; falls back to the scalar default
1029+
DefaultESVectorUtilSupport.packDibitImpl(vector, packed);
1030+
}
1031+
10261032
@Override
10271033
public void transposeHalfByte(int[] q, byte[] quantQueryByte) {
10281034
// 128 / 32 == 4

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,11 @@ public ESVectorUtilSupport getVectorUtilSupport() {
3636
@Override
3737
public ESNextOSQVectorsScorer newESNextOSQVectorsScorer(IndexInput input, byte queryBits, byte indexBits, int dimension, int dataLength)
3838
throws IOException {
39-
if (PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS && input instanceof MemorySegmentAccessInput msai) {
39+
// TODO: Extend to other bit configurations as needed
40+
if (PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS
41+
&& input instanceof MemorySegmentAccessInput msai
42+
&& queryBits == 4
43+
&& indexBits == 1) {
4044
MemorySegment ms = msai.segmentSliceOrNull(0, input.length());
4145
if (ms != null) {
4246
return new MemorySegmentESNextOSQVectorsScorer(input, queryBits, indexBits, dimension, dataLength, ms);

libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
1313
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
14+
import org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat;
1415
import org.elasticsearch.simdvec.internal.vectorization.BaseVectorizationTests;
1516
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
1617

@@ -438,6 +439,41 @@ public void testTransposeHalfByte() {
438439
assertArrayEquals(packedLegacy, packed);
439440
}
440441

442+
public void testPackAsDibit() {
443+
int dims = randomIntBetween(16, 2048);
444+
int[] toPack = new int[dims];
445+
for (int i = 0; i < dims; i++) {
446+
toPack[i] = randomInt(3);
447+
}
448+
int length = ESNextDiskBBQVectorsFormat.QuantEncoding.TWO_BIT_4BIT_QUERY.getDocPackedLength(dims);
449+
;
450+
byte[] packed = new byte[length];
451+
byte[] packedLegacy = new byte[length];
452+
defaultedProvider.getVectorUtilSupport().packDibit(toPack, packedLegacy);
453+
defOrPanamaProvider.getVectorUtilSupport().packDibit(toPack, packed);
454+
assertArrayEquals(packedLegacy, packed);
455+
}
456+
457+
public void testPackDibitCorrectness() {
458+
// 5 bits
459+
// binary lower bits 1 1 0 0 1
460+
// binary upper bits 0 1 1 0 0
461+
// resulting dibit 1 3 2 0 1
462+
int[] toPack = new int[] { 1, 3, 2, 0, 1 };
463+
byte[] packed = new byte[2];
464+
ESVectorUtil.packDibit(toPack, packed);
465+
assertArrayEquals(new byte[] { (byte) 0b11001000, (byte) 0b01100000 }, packed);
466+
467+
// 8 bits
468+
// binary lower bits 1 1 0 0 1 0 1 0
469+
// binary upper bits 0 1 1 0 0 1 0 1
470+
// resulting dibit 1 3 2 0 1 2 1 2
471+
toPack = new int[] { 1, 3, 2, 0, 1, 2, 1, 2 };
472+
packed = new byte[2];
473+
ESVectorUtil.packDibit(toPack, packed);
474+
assertArrayEquals(new byte[] { (byte) 0b11001010, (byte) 0b01100101 }, packed);
475+
}
476+
441477
private float[] generateRandomVector(int size) {
442478
float[] vector = new float[size];
443479
for (int i = 0; i < size; ++i) {

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat;
3636
import org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat;
3737
import org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat;
38+
import org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat;
3839
import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat;
3940
import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat;
4041
import org.elasticsearch.logging.Level;
@@ -117,7 +118,14 @@ private static String formatIndexPath(CmdLineArgs args) {
117118
static Codec createCodec(CmdLineArgs args) {
118119
final KnnVectorsFormat format;
119120
if (args.indexType() == IndexType.IVF) {
120-
format = new ES920DiskBBQVectorsFormat(args.ivfClusterSize(), ES920DiskBBQVectorsFormat.DEFAULT_CENTROIDS_PER_PARENT_CLUSTER);
121+
ESNextDiskBBQVectorsFormat.QuantEncoding encoding = args.quantizeBits() == 1
122+
? ESNextDiskBBQVectorsFormat.QuantEncoding.ONE_BIT_4BIT_QUERY
123+
: ESNextDiskBBQVectorsFormat.QuantEncoding.TWO_BIT_4BIT_QUERY;
124+
format = new ESNextDiskBBQVectorsFormat(
125+
encoding,
126+
args.ivfClusterSize(),
127+
ES920DiskBBQVectorsFormat.DEFAULT_CENTROIDS_PER_PARENT_CLUSTER
128+
);
121129
} else if (args.indexType() == IndexType.GPU_HNSW) {
122130
if (args.quantizeBits() == 32) {
123131
format = new ES92GpuHnswVectorsFormat();

server/src/main/java/module-info.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,7 @@
496496
exports org.elasticsearch.index.codec.vectors.es818 to org.elasticsearch.test.knn;
497497
exports org.elasticsearch.inference.telemetry;
498498
exports org.elasticsearch.index.codec.vectors.diskbbq to org.elasticsearch.test.knn;
499+
exports org.elasticsearch.index.codec.vectors.diskbbq.next to org.elasticsearch.test.knn;
499500
exports org.elasticsearch.index.codec.vectors.cluster to org.elasticsearch.test.knn;
500501
exports org.elasticsearch.index.codec.vectors.es93 to org.elasticsearch.test.knn;
501502
exports org.elasticsearch.search.crossproject;

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/DiskBBQBulkWriter.java

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
* This class provides the structure for writing vectors in bulk, with specific
2121
* implementations for different bit sizes strategies.
2222
*/
23-
public abstract class DiskBBQBulkWriter {
23+
public abstract sealed class DiskBBQBulkWriter {
2424
protected final int bulkSize;
2525
protected final IndexOutput out;
2626

@@ -31,10 +31,25 @@ protected DiskBBQBulkWriter(int bulkSize, IndexOutput out) {
3131

3232
public abstract void writeVectors(QuantizedVectorValues qvv, CheckedIntConsumer<IOException> docsWriter) throws IOException;
3333

34-
public static class OneBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
34+
/**
35+
* Factory method to create a DiskBBQBulkWriter based on the bit size.
36+
* @param bitSize the bit size of the quantized vectors
37+
* @param bulkSize the number of vectors to write in bulk
38+
* @param out the IndexOutput to write to
39+
* @return a DiskBBQBulkWriter instance
40+
*/
41+
public static DiskBBQBulkWriter fromBitSize(int bitSize, int bulkSize, IndexOutput out) {
42+
return switch (bitSize) {
43+
case 1, 2 -> new SmallBitDiskBBQBulkWriter(bulkSize, out);
44+
case 7 -> new LargeBitDiskBBQBulkWriter(bulkSize, out);
45+
default -> throw new IllegalArgumentException("Unsupported bit size: " + bitSize);
46+
};
47+
}
48+
49+
private static final class SmallBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
3550
private final OptimizedScalarQuantizer.QuantizationResult[] corrections;
3651

37-
public OneBitDiskBBQBulkWriter(int bulkSize, IndexOutput out) {
52+
private SmallBitDiskBBQBulkWriter(int bulkSize, IndexOutput out) {
3853
super(bulkSize, out);
3954
this.corrections = new OptimizedScalarQuantizer.QuantizationResult[bulkSize];
4055
}
@@ -93,10 +108,10 @@ private void writeCorrection(OptimizedScalarQuantizer.QuantizationResult correct
93108
}
94109
}
95110

96-
public static class SevenBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
111+
private static final class LargeBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
97112
private final OptimizedScalarQuantizer.QuantizationResult[] corrections;
98113

99-
public SevenBitDiskBBQBulkWriter(int bulkSize, IndexOutput out) {
114+
private LargeBitDiskBBQBulkWriter(int bulkSize, IndexOutput out) {
100115
super(bulkSize, out);
101116
this.corrections = new OptimizedScalarQuantizer.QuantizationResult[bulkSize];
102117
}

0 commit comments

Comments
 (0)