Skip to content

Commit 044f34b

Browse files
authored
Refactor bulk quantization writing into a unified class (#130354)
This is a small refactor laying the groundwork for more generalized bulk writing. I did some benchmarking and found no significant performance difference (as expected).
1 parent c1a4f8a commit 044f34b

File tree

3 files changed

+112
-98
lines changed

3 files changed

+112
-98
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java

Lines changed: 7 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,6 @@
2828
import java.nio.ByteOrder;
2929
import java.util.Arrays;
3030

31-
import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.INDEX_BITS;
32-
import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize;
33-
import static org.elasticsearch.index.codec.vectors.BQVectorUtils.packAsBinary;
34-
3531
/**
3632
* Default implementation of {@link IVFVectorsWriter}. It uses {@link HierarchicalKMeans} algorithm to
3733
* partition the vector space, and then stores the centroids and posting list in a sequential
@@ -58,12 +54,15 @@ long[] buildAndWritePostingsLists(
5854
// write the posting lists
5955
final long[] offsets = new long[centroidSupplier.size()];
6056
OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
61-
BinarizedFloatVectorValues binarizedByteVectorValues = new BinarizedFloatVectorValues(floatVectorValues, quantizer);
6257
DocIdsWriter docIdsWriter = new DocIdsWriter();
63-
58+
DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(
59+
ES91OSQVectorsScorer.BULK_SIZE,
60+
quantizer,
61+
floatVectorValues,
62+
postingsOutput
63+
);
6464
for (int c = 0; c < centroidSupplier.size(); c++) {
6565
float[] centroid = centroidSupplier.centroid(c);
66-
binarizedByteVectorValues.centroid = centroid;
6766
// TODO: add back in sorting vectors by distance to centroid
6867
int[] cluster = assignmentsByCluster[c];
6968
// TODO align???
@@ -75,7 +74,7 @@ long[] buildAndWritePostingsLists(
7574
// to aid with only having to fetch vectors from slower storage when they are required
7675
// keeping them in the same file indicates we pull the entire file into cache
7776
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
78-
writePostingList(cluster, postingsOutput, binarizedByteVectorValues);
77+
bulkWriter.writeOrds(j -> cluster[j], cluster.length, centroid);
7978
}
8079

8180
if (logger.isDebugEnabled()) {
@@ -115,54 +114,6 @@ private static void printClusterQualityStatistics(int[][] clusters) {
115114
);
116115
}
117116

118-
private void writePostingList(int[] cluster, IndexOutput postingsOutput, BinarizedFloatVectorValues binarizedByteVectorValues)
119-
throws IOException {
120-
int limit = cluster.length - ES91OSQVectorsScorer.BULK_SIZE + 1;
121-
int cidx = 0;
122-
OptimizedScalarQuantizer.QuantizationResult[] corrections =
123-
new OptimizedScalarQuantizer.QuantizationResult[ES91OSQVectorsScorer.BULK_SIZE];
124-
// Write vectors in bulks of ES91OSQVectorsScorer.BULK_SIZE.
125-
for (; cidx < limit; cidx += ES91OSQVectorsScorer.BULK_SIZE) {
126-
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
127-
int ord = cluster[cidx + j];
128-
byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord);
129-
// write vector
130-
postingsOutput.writeBytes(binaryValue, 0, binaryValue.length);
131-
corrections[j] = binarizedByteVectorValues.getCorrectiveTerms(ord);
132-
}
133-
// write corrections
134-
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
135-
postingsOutput.writeInt(Float.floatToIntBits(corrections[j].lowerInterval()));
136-
}
137-
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
138-
postingsOutput.writeInt(Float.floatToIntBits(corrections[j].upperInterval()));
139-
}
140-
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
141-
int targetComponentSum = corrections[j].quantizedComponentSum();
142-
assert targetComponentSum >= 0 && targetComponentSum <= 0xffff;
143-
postingsOutput.writeShort((short) targetComponentSum);
144-
}
145-
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
146-
postingsOutput.writeInt(Float.floatToIntBits(corrections[j].additionalCorrection()));
147-
}
148-
}
149-
// write tail
150-
for (; cidx < cluster.length; cidx++) {
151-
int ord = cluster[cidx];
152-
// write vector
153-
byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord);
154-
OptimizedScalarQuantizer.QuantizationResult correction = binarizedByteVectorValues.getCorrectiveTerms(ord);
155-
writeQuantizedValue(postingsOutput, binaryValue, correction);
156-
binarizedByteVectorValues.getCorrectiveTerms(ord);
157-
postingsOutput.writeBytes(binaryValue, 0, binaryValue.length);
158-
postingsOutput.writeInt(Float.floatToIntBits(correction.lowerInterval()));
159-
postingsOutput.writeInt(Float.floatToIntBits(correction.upperInterval()));
160-
postingsOutput.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
161-
assert correction.quantizedComponentSum() >= 0 && correction.quantizedComponentSum() <= 0xffff;
162-
postingsOutput.writeShort((short) correction.quantizedComponentSum());
163-
}
164-
}
165-
166117
@Override
167118
CentroidSupplier createCentroidSupplier(IndexInput centroidsInput, int numCentroids, FieldInfo fieldInfo, float[] globalCentroid) {
168119
return new OffHeapCentroidSupplier(centroidsInput, numCentroids, fieldInfo);
@@ -295,47 +246,6 @@ CentroidAssignments calculateAndWriteCentroids(
295246
}
296247
}
297248

298-
// TODO unify with OSQ format
299-
static class BinarizedFloatVectorValues {
300-
private OptimizedScalarQuantizer.QuantizationResult corrections;
301-
private final byte[] binarized;
302-
private final byte[] initQuantized;
303-
private float[] centroid;
304-
private final FloatVectorValues values;
305-
private final OptimizedScalarQuantizer quantizer;
306-
307-
private int lastOrd = -1;
308-
309-
BinarizedFloatVectorValues(FloatVectorValues delegate, OptimizedScalarQuantizer quantizer) {
310-
this.values = delegate;
311-
this.quantizer = quantizer;
312-
this.binarized = new byte[discretize(delegate.dimension(), 64) / 8];
313-
this.initQuantized = new byte[delegate.dimension()];
314-
}
315-
316-
public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int ord) {
317-
if (ord != lastOrd) {
318-
throw new IllegalStateException(
319-
"attempt to retrieve corrective terms for different ord " + ord + " than the quantization was done for: " + lastOrd
320-
);
321-
}
322-
return corrections;
323-
}
324-
325-
public byte[] vectorValue(int ord) throws IOException {
326-
if (ord != lastOrd) {
327-
binarize(ord);
328-
lastOrd = ord;
329-
}
330-
return binarized;
331-
}
332-
333-
private void binarize(int ord) throws IOException {
334-
corrections = quantizer.scalarQuantize(values.vectorValue(ord), initQuantized, INDEX_BITS, centroid);
335-
packAsBinary(initQuantized, binarized);
336-
}
337-
}
338-
339249
static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections)
340250
throws IOException {
341251
indexOutput.writeBytes(binaryValue, binaryValue.length);
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.index.codec.vectors;
11+
12+
import org.apache.lucene.index.FloatVectorValues;
13+
import org.apache.lucene.store.IndexOutput;
14+
import org.apache.lucene.util.hnsw.IntToIntFunction;
15+
16+
import java.io.IOException;
17+
18+
import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize;
19+
import static org.elasticsearch.index.codec.vectors.BQVectorUtils.packAsBinary;
20+
21+
/**
22+
* Base class for bulk writers that write vectors to disk using the BBQ encoding.
23+
* This class provides the structure for writing vectors in bulk, with specific
24+
* implementations for different bit sizes strategies.
25+
*/
26+
public abstract class DiskBBQBulkWriter {
27+
protected final int bulkSize;
28+
protected final OptimizedScalarQuantizer quantizer;
29+
protected final IndexOutput out;
30+
protected final FloatVectorValues fvv;
31+
32+
protected DiskBBQBulkWriter(int bulkSize, OptimizedScalarQuantizer quantizer, FloatVectorValues fvv, IndexOutput out) {
33+
this.bulkSize = bulkSize;
34+
this.quantizer = quantizer;
35+
this.out = out;
36+
this.fvv = fvv;
37+
}
38+
39+
public abstract void writeOrds(IntToIntFunction ords, int count, float[] centroid) throws IOException;
40+
41+
private static void writeCorrections(OptimizedScalarQuantizer.QuantizationResult[] corrections, IndexOutput out) throws IOException {
42+
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
43+
out.writeInt(Float.floatToIntBits(correction.lowerInterval()));
44+
}
45+
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
46+
out.writeInt(Float.floatToIntBits(correction.upperInterval()));
47+
}
48+
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
49+
int targetComponentSum = correction.quantizedComponentSum();
50+
assert targetComponentSum >= 0 && targetComponentSum <= 0xffff;
51+
out.writeShort((short) targetComponentSum);
52+
}
53+
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
54+
out.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
55+
}
56+
}
57+
58+
private static void writeCorrection(OptimizedScalarQuantizer.QuantizationResult correction, IndexOutput out) throws IOException {
59+
out.writeInt(Float.floatToIntBits(correction.lowerInterval()));
60+
out.writeInt(Float.floatToIntBits(correction.upperInterval()));
61+
out.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
62+
int targetComponentSum = correction.quantizedComponentSum();
63+
assert targetComponentSum >= 0 && targetComponentSum <= 0xffff;
64+
out.writeShort((short) targetComponentSum);
65+
}
66+
67+
public static class OneBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
68+
private final byte[] binarized;
69+
private final byte[] initQuantized;
70+
private final OptimizedScalarQuantizer.QuantizationResult[] corrections;
71+
72+
public OneBitDiskBBQBulkWriter(int bulkSize, OptimizedScalarQuantizer quantizer, FloatVectorValues fvv, IndexOutput out) {
73+
super(bulkSize, quantizer, fvv, out);
74+
this.binarized = new byte[discretize(fvv.dimension(), 64) / 8];
75+
this.initQuantized = new byte[fvv.dimension()];
76+
this.corrections = new OptimizedScalarQuantizer.QuantizationResult[bulkSize];
77+
}
78+
79+
@Override
80+
public void writeOrds(IntToIntFunction ords, int count, float[] centroid) throws IOException {
81+
int limit = count - bulkSize + 1;
82+
int i = 0;
83+
for (; i < limit; i += bulkSize) {
84+
for (int j = 0; j < bulkSize; j++) {
85+
int ord = ords.apply(i + j);
86+
float[] fv = fvv.vectorValue(ord);
87+
corrections[j] = quantizer.scalarQuantize(fv, initQuantized, (byte) 1, centroid);
88+
packAsBinary(initQuantized, binarized);
89+
out.writeBytes(binarized, binarized.length);
90+
}
91+
writeCorrections(corrections, out);
92+
}
93+
// write tail
94+
for (; i < count; ++i) {
95+
int ord = ords.apply(i);
96+
float[] fv = fvv.vectorValue(ord);
97+
OptimizedScalarQuantizer.QuantizationResult correction = quantizer.scalarQuantize(fv, initQuantized, (byte) 1, centroid);
98+
packAsBinary(initQuantized, binarized);
99+
out.writeBytes(binarized, binarized.length);
100+
writeCorrection(correction, out);
101+
}
102+
}
103+
}
104+
}

server/src/test/java/org/elasticsearch/search/vectors/AbstractDiversifyingChildrenIVFKnnVectorQueryTestCase.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ public void testSkewedIndex() throws IOException {
300300
);
301301
assertEquals(8, results.scoreDocs.length);
302302
assertIdMatches(reader, "10", results.scoreDocs[0].doc);
303-
assertIdMatches(reader, "8", results.scoreDocs[7].doc);
303+
assertIdMatches(reader, "6", results.scoreDocs[7].doc);
304304
}
305305
}
306306
}

0 commit comments

Comments
 (0)