Skip to content

Commit 1f29688

Browse files
authored
Add 4-bit quantization to DiskBBQ next (#137336)
1 parent c097a1b commit 1f29688

File tree

6 files changed

+94
-25
lines changed

6 files changed

+94
-25
lines changed

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESNextOSQVectorsScorer.java

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ public class ESNextOSQVectorsScorer {
4848

4949
/** Sole constructor, called by sub-classes. */
5050
public ESNextOSQVectorsScorer(IndexInput in, byte queryBits, byte indexBits, int dimensions, int dataLength) {
51-
if (queryBits != 4 || (indexBits != 1 && indexBits != 2)) {
52-
throw new IllegalArgumentException("Only asymmetric 4-bit query and 1 or 2-bit index supported");
51+
if (queryBits != 4 || (indexBits != 1 && indexBits != 2 && indexBits != 4)) {
52+
throw new IllegalArgumentException("Only asymmetric 4-bit query and 1, 2 or 4-bit index supported");
5353
}
5454
this.in = in;
5555
this.queryBits = queryBits;
@@ -67,16 +67,31 @@ public long quantizeScore(byte[] q) throws IOException {
6767
if (queryBits == 4) {
6868
return quantized4BitScore(q, length);
6969
}
70-
throw new IllegalArgumentException("Only asymmetric 4-bit query supported");
70+
throw new IllegalArgumentException("Only asymmetric 4-bit query supported for 1-bit index");
7171
}
7272
if (indexBits == 2) {
7373
if (queryBits == 4) {
7474
return quantized4BitScore2BitIndex(q);
7575
}
7676
}
77+
if (indexBits == 4) {
78+
if (queryBits == 4) {
79+
return quantized4BitScoreSymmetric(q);
80+
}
81+
}
7782
throw new IllegalArgumentException("Only 1-bit index supported");
7883
}
7984

85+
private long quantized4BitScoreSymmetric(byte[] q) throws IOException {
86+
assert q.length == length : "length mismatch q " + q.length + " vs " + length;
87+
assert length % 4 == 0 : "length must be multiple of 4 for 4-bit index length: " + length + " dimensions: " + dimensions;
88+
int stripe0 = (int) quantized4BitScore(q, length / 4);
89+
int stripe1 = (int) quantized4BitScore(q, length / 4);
90+
int stripe2 = (int) quantized4BitScore(q, length / 4);
91+
int stripe3 = (int) quantized4BitScore(q, length / 4);
92+
return stripe0 + ((long) stripe1 << 1) + ((long) stripe2 << 2) + ((long) stripe3 << 3);
93+
}
94+
8095
private long quantized4BitScore2BitIndex(byte[] q) throws IOException {
8196
assert q.length == length * 2;
8297
assert length % 2 == 0 : "length must be even for 2-bit index length: " + length + " dimensions: " + dimensions;
@@ -130,7 +145,7 @@ public void quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOExce
130145
}
131146
return;
132147
}
133-
throw new IllegalArgumentException("Only asymmetric 4-bit query supported");
148+
throw new IllegalArgumentException("Only asymmetric 4-bit query supported for 1-bit index");
134149
}
135150
if (indexBits == 2) {
136151
if (queryBits == 4) {
@@ -139,7 +154,16 @@ public void quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOExce
139154
}
140155
return;
141156
}
142-
throw new IllegalArgumentException("Only asymmetric 4-bit query supported");
157+
throw new IllegalArgumentException("Only asymmetric 4-bit query supported for 2-bit index");
158+
}
159+
if (indexBits == 4) {
160+
if (queryBits == 4) {
161+
for (int i = 0; i < count; i++) {
162+
scores[i] = quantizeScore(q);
163+
}
164+
return;
165+
}
166+
throw new IllegalArgumentException("Only symmetric 4-bit query supported for 4-bit index");
143167
}
144168
}
145169

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ public ESVectorUtilSupport getVectorUtilSupport() {
3636
@Override
3737
public ESNextOSQVectorsScorer newESNextOSQVectorsScorer(IndexInput input, byte queryBits, byte indexBits, int dimension, int dataLength)
3838
throws IOException {
39-
// TODO: Extend to other bit configurations as needed
39+
// TODO: Extend to other bit configurations as needed (2 and 4 bit index vectors)
4040
if (PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS
4141
&& input instanceof MemorySegmentAccessInput msai
4242
&& queryBits == 4

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -117,43 +117,41 @@ private static String formatIndexPath(CmdLineArgs args) {
117117

118118
static Codec createCodec(CmdLineArgs args) {
119119
final KnnVectorsFormat format;
120+
int quantizeBits = args.quantizeBits();
120121
if (args.indexType() == IndexType.IVF) {
121-
ESNextDiskBBQVectorsFormat.QuantEncoding encoding = args.quantizeBits() == 1
122-
? ESNextDiskBBQVectorsFormat.QuantEncoding.ONE_BIT_4BIT_QUERY
123-
: ESNextDiskBBQVectorsFormat.QuantEncoding.TWO_BIT_4BIT_QUERY;
122+
ESNextDiskBBQVectorsFormat.QuantEncoding encoding = switch (quantizeBits) {
123+
case (1) -> ESNextDiskBBQVectorsFormat.QuantEncoding.ONE_BIT_4BIT_QUERY;
124+
case (2) -> ESNextDiskBBQVectorsFormat.QuantEncoding.TWO_BIT_4BIT_QUERY;
125+
case (4) -> ESNextDiskBBQVectorsFormat.QuantEncoding.FOUR_BIT_SYMMETRIC;
126+
default -> throw new IllegalArgumentException(
127+
"IVF index type only supports 1, 2 or 4 bits quantization, but got: " + quantizeBits
128+
);
129+
};
124130
format = new ESNextDiskBBQVectorsFormat(
125131
encoding,
126132
args.ivfClusterSize(),
127133
ES920DiskBBQVectorsFormat.DEFAULT_CENTROIDS_PER_PARENT_CLUSTER
128134
);
129135
} else if (args.indexType() == IndexType.GPU_HNSW) {
130-
if (args.quantizeBits() == 32) {
136+
if (quantizeBits == 32) {
131137
format = new ES92GpuHnswVectorsFormat();
132-
} else if (args.quantizeBits() == 7) {
138+
} else if (quantizeBits == 7) {
133139
format = new ES92GpuHnswSQVectorsFormat();
134140
} else {
135-
throw new IllegalArgumentException(
136-
"GPU HNSW index type only supports 7 or 32 bits quantization, but got: " + args.quantizeBits()
137-
);
141+
throw new IllegalArgumentException("GPU HNSW index type only supports 7 or 32 bits quantization, but got: " + quantizeBits);
138142
}
139143
} else {
140-
if (args.quantizeBits() == 1) {
144+
if (quantizeBits == 1) {
141145
if (args.indexType() == IndexType.FLAT) {
142146
format = new ES818BinaryQuantizedVectorsFormat();
143147
} else {
144148
format = new ES818HnswBinaryQuantizedVectorsFormat(args.hnswM(), args.hnswEfConstruction(), 1, null);
145149
}
146-
} else if (args.quantizeBits() < 32) {
150+
} else if (quantizeBits < 32) {
147151
if (args.indexType() == IndexType.FLAT) {
148-
format = new ES813Int8FlatVectorFormat(null, args.quantizeBits(), true);
152+
format = new ES813Int8FlatVectorFormat(null, quantizeBits, true);
149153
} else {
150-
format = new ES814HnswScalarQuantizedVectorsFormat(
151-
args.hnswM(),
152-
args.hnswEfConstruction(),
153-
null,
154-
args.quantizeBits(),
155-
true
156-
);
154+
format = new ES814HnswScalarQuantizedVectorsFormat(args.hnswM(), args.hnswEfConstruction(), null, quantizeBits, true);
157155
}
158156
} else {
159157
format = new Lucene99HnswVectorsFormat(args.hnswM(), args.hnswEfConstruction(), 1, null);

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/DiskBBQBulkWriter.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ protected DiskBBQBulkWriter(int bulkSize, IndexOutput out) {
4040
*/
4141
public static DiskBBQBulkWriter fromBitSize(int bitSize, int bulkSize, IndexOutput out) {
4242
return switch (bitSize) {
43-
case 1, 2 -> new SmallBitDiskBBQBulkWriter(bulkSize, out);
43+
case 1, 2, 4 -> new SmallBitDiskBBQBulkWriter(bulkSize, out);
4444
case 7 -> new LargeBitDiskBBQBulkWriter(bulkSize, out);
4545
default -> throw new IllegalArgumentException("Unsupported bit size: " + bitSize);
4646
};

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsFormat.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,34 @@ public int getDocPackedLength(int dimensions) {
108108
int discretized = discretizedDimensions(dimensions);
109109
return 2 * ((discretized + 7) / 8);
110110
}
111+
},
112+
FOUR_BIT_SYMMETRIC(2, (byte) 4, (byte) 4) {
113+
@Override
114+
public void packQuery(int[] quantized, byte[] destination) {
115+
ESVectorUtil.transposeHalfByte(quantized, destination);
116+
}
117+
118+
@Override
119+
public void pack(int[] quantized, byte[] destination) {
120+
ESVectorUtil.transposeHalfByte(quantized, destination);
121+
}
122+
123+
@Override
124+
public int getDocPackedLength(int dimensions) {
125+
int discretized = discretizedDimensions(dimensions);
126+
return 4 * ((discretized + 7) / 8);
127+
}
128+
129+
@Override
130+
public int getQueryPackedLength(int dimensions) {
131+
return getDocPackedLength(dimensions);
132+
}
133+
134+
@Override
135+
public int discretizedDimensions(int dimensions) {
136+
int totalBits = dimensions * 4;
137+
return (totalBits + 7) / 8 * 8 / 4;
138+
}
111139
};
112140

113141
private final int id;

server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/next/QuantEncodingTests.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,23 @@ public void testDibitAndNibblesPackSize() {
5252
assertEquals(8, encoding.getQueryPackedLength(15));
5353
assertEquals(8, encoding.getQueryPackedLength(16));
5454
}
55+
56+
public void testHalfByteAndNibbles() {
57+
ESNextDiskBBQVectorsFormat.QuantEncoding encoding = ESNextDiskBBQVectorsFormat.QuantEncoding.FOUR_BIT_SYMMETRIC;
58+
int discretized = encoding.discretizedDimensions(randomIntBetween(1, 1024));
59+
// should discretize to something that can be packed into bytes from four bits and nibbles
60+
assertEquals(0, discretized % 2);
61+
}
62+
63+
public void testHalfByteAndNibblesPackSize() {
64+
ESNextDiskBBQVectorsFormat.QuantEncoding encoding = ESNextDiskBBQVectorsFormat.QuantEncoding.FOUR_BIT_SYMMETRIC;
65+
assertEquals(4, encoding.getDocPackedLength(3));
66+
assertEquals(4, encoding.getQueryPackedLength(3));
67+
assertEquals(4, encoding.getDocPackedLength(8));
68+
assertEquals(4, encoding.getQueryPackedLength(8));
69+
assertEquals(8, encoding.getDocPackedLength(16));
70+
assertEquals(8, encoding.getDocPackedLength(16));
71+
assertEquals(8, encoding.getQueryPackedLength(16));
72+
assertEquals(8, encoding.getQueryPackedLength(16));
73+
}
5574
}

0 commit comments

Comments
 (0)