
Commit 5dcf12d

Make postings list building more IO friendly
1 parent 7fac8ff commit 5dcf12d

3 files changed (+185, -37 lines)


server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java

Lines changed: 4 additions & 4 deletions
@@ -9,10 +9,10 @@
 
 package org.elasticsearch.index.codec.vectors;
 
-record CentroidAssignments(int numCentroids, float[][] centroids, int[][] assignmentsByCluster) {
+record CentroidAssignments(int numCentroids, float[][] centroids, int[] assignments, int[] overspillAssignments) {
 
-    CentroidAssignments(float[][] centroids, int[][] assignmentsByCluster) {
-        this(centroids.length, centroids, assignmentsByCluster);
-        assert centroids.length == assignmentsByCluster.length;
+    CentroidAssignments(float[][] centroids, int[] assignments, int[] overspillAssignments) {
+        this(centroids.length, centroids, assignments, overspillAssignments);
+        assert assignments.length == overspillAssignments.length : "assignments and overspillAssignments must have the same length";
     }
 }
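The record now carries the flat per-vector arrays instead of pre-grouped per-cluster lists: assignments[i] is the primary centroid ordinal for vector ordinal i, and overspillAssignments[i] is the secondary SOAR centroid or -1 (the array may also be empty when overspill is disabled). A minimal sketch of how a consumer resolves a vector's clusters under this representation (helper name is hypothetical, not part of the commit):

```java
// Hypothetical helper illustrating the flat representation; not from the commit.
static int[] centroidsForVector(int i, int[] assignments, int[] overspillAssignments) {
    int primary = assignments[i];
    // overspillAssignments may be empty (SOAR disabled) or hold -1 (no overspill)
    if (i < overspillAssignments.length && overspillAssignments[i] != -1) {
        return new int[] { primary, overspillAssignments[i] };
    }
    return new int[] { primary };
}
```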

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java

Lines changed: 160 additions & 28 deletions
@@ -14,6 +14,7 @@
 import org.apache.lucene.index.FloatVectorValues;
 import org.apache.lucene.index.MergeState;
 import org.apache.lucene.index.SegmentWriteState;
+import org.apache.lucene.store.IOContext;
 import org.apache.lucene.store.IndexInput;
 import org.apache.lucene.store.IndexOutput;
 import org.apache.lucene.util.VectorUtil;
@@ -49,8 +50,35 @@ long[] buildAndWritePostingsLists(
         CentroidSupplier centroidSupplier,
         FloatVectorValues floatVectorValues,
         IndexOutput postingsOutput,
-        int[][] assignmentsByCluster
+        int[] assignments,
+        int[] overspillAssignments
     ) throws IOException {
+        int[] centroidVectorCount = new int[centroidSupplier.size()];
+        for (int i = 0; i < assignments.length; i++) {
+            centroidVectorCount[assignments[i]]++;
+            // if soar assignments are present, count them as well
+            if (overspillAssignments.length > i && overspillAssignments[i] != -1) {
+                centroidVectorCount[overspillAssignments[i]]++;
+            }
+        }
+
+        int[][] assignmentsByCluster = new int[centroidSupplier.size()][];
+        for (int c = 0; c < centroidSupplier.size(); c++) {
+            assignmentsByCluster[c] = new int[centroidVectorCount[c]];
+        }
+        Arrays.fill(centroidVectorCount, 0);
+
+        for (int i = 0; i < assignments.length; i++) {
+            int c = assignments[i];
+            assignmentsByCluster[c][centroidVectorCount[c]++] = i;
+            // if soar assignments are present, add them to the cluster as well
+            if (overspillAssignments.length > i) {
+                int s = overspillAssignments[i];
+                if (s != -1) {
+                    assignmentsByCluster[s][centroidVectorCount[s]++] = i;
+                }
+            }
+        }
         // write the posting lists
         final long[] offsets = new long[centroidSupplier.size()];
         OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
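The grouping above is a two-pass counting layout: the first pass sizes every cluster exactly, the counters are then zeroed and reused as write cursors, and the second pass scatters vector ordinals into pre-sized arrays, so no growable lists or copies are needed. A self-contained sketch of the same pattern (class and method names are illustrative):

```java
import java.util.Arrays;

// Illustrative two-pass grouping, mirroring the flush path above.
class GroupByCluster {
    static int[][] group(int numCentroids, int[] assignments, int[] overspill) {
        int[] counts = new int[numCentroids];
        for (int i = 0; i < assignments.length; i++) {
            counts[assignments[i]]++;
            // a SOAR overspill assignment counts toward its cluster too
            if (i < overspill.length && overspill[i] != -1) {
                counts[overspill[i]]++;
            }
        }
        int[][] byCluster = new int[numCentroids][];
        for (int c = 0; c < numCentroids; c++) {
            byCluster[c] = new int[counts[c]];
        }
        Arrays.fill(counts, 0); // reuse the counters as per-cluster write cursors
        for (int i = 0; i < assignments.length; i++) {
            byCluster[assignments[i]][counts[assignments[i]]++] = i;
            if (i < overspill.length && overspill[i] != -1) {
                byCluster[overspill[i]][counts[overspill[i]]++] = i;
            }
        }
        return byCluster;
    }
}
```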
@@ -84,6 +112,92 @@ long[] buildAndWritePostingsLists(
         return offsets;
     }
 
+    @Override
+    long[] buildAndWritePostingsLists(
+        FieldInfo fieldInfo,
+        CentroidSupplier centroidSupplier,
+        FloatVectorValues floatVectorValues,
+        IndexOutput postingsOutput,
+        MergeState mergeState,
+        int[] assignments,
+        int[] overspillAssignments
+    ) throws IOException {
+        // first, quantize all the vectors into a temporary file
+        String quantizedVectorsTempName = null;
+        IndexOutput quantizedVectorsTemp = null;
+        boolean success = false;
+        try {
+            quantizedVectorsTemp = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "qvec_", IOContext.DEFAULT);
+            quantizedVectorsTempName = quantizedVectorsTemp.getName();
+            OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
+            int[] quantized = new int[fieldInfo.getVectorDimension()];
+            byte[] binary = new byte[BQVectorUtils.discretize(fieldInfo.getVectorDimension(), 64) / 8];
+            for (int i = 0; i < assignments.length; i++) {
+                int c = assignments[i];
+                float[] centroid = centroidSupplier.centroid(c);
+                float[] vector = floatVectorValues.vectorValue(i);
+                OptimizedScalarQuantizer.QuantizationResult result = quantizer.scalarQuantize(vector, quantized, (byte) 1, centroid);
+                BQVectorUtils.packAsBinary(quantized, binary);
+                writeQuantizedValue(quantizedVectorsTemp, binary, result);
+                boolean overspill = overspillAssignments.length > i && overspillAssignments[i] != -1;
+                if (overspill) {
+                    int s = overspillAssignments[i];
+                    // write the overspill vector as well
+                    result = quantizer.scalarQuantize(vector, quantized, (byte) 1, centroidSupplier.centroid(s));
+                    BQVectorUtils.packAsBinary(quantized, binary);
+                    writeQuantizedValue(quantizedVectorsTemp, binary, result);
+                } else {
+                    // write a zero vector for the overspill
+                    Arrays.fill(binary, (byte) 0);
+                    OptimizedScalarQuantizer.QuantizationResult zeroResult = new OptimizedScalarQuantizer.QuantizationResult(0f, 0f, 0f, 0);
+                    writeQuantizedValue(quantizedVectorsTemp, binary, zeroResult);
+                }
+            }
+            // close the temporary file so we can read it later
+            quantizedVectorsTemp.close();
+            success = true;
+        } finally {
+            if (success == false && quantizedVectorsTemp != null) {
+                mergeState.segmentInfo.dir.deleteFile(quantizedVectorsTemp.getName());
+            }
+        }
+        // now we can read the quantized vectors from the temporary file
+        try (IndexInput quantizedVectorsInput = mergeState.segmentInfo.dir.openInput(quantizedVectorsTempName, IOContext.DEFAULT)) {
+            final long[] offsets = new long[centroidSupplier.size()];
+            OffHeapQuantizedVectors offHeapQuantizedVectors = new OffHeapQuantizedVectors(
+                quantizedVectorsInput,
+                fieldInfo.getVectorDimension()
+            );
+            DocIdsWriter docIdsWriter = new DocIdsWriter();
+            DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(
+                ES91OSQVectorsScorer.BULK_SIZE,
+                quantizer,
+                floatVectorValues,
+                postingsOutput
+            );
+            for (int c = 0; c < centroidSupplier.size(); c++) {
+                float[] centroid = centroidSupplier.centroid(c);
+                // TODO: add back in sorting vectors by distance to centroid
+                int[] cluster = assignmentsByCluster[c];
+                // TODO align???
+                offsets[c] = postingsOutput.getFilePointer();
+                int size = cluster.length;
+                postingsOutput.writeVInt(size);
+                postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
+                // TODO we might want to consider putting the docIds in a separate file
+                // to aid with only having to fetch vectors from slower storage when they are required
+                // keeping them in the same file indicates we pull the entire file into cache
+                docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
+                bulkWriter.writeOrds(j -> cluster[j], cluster.length, centroid);
+            }
+
+            if (logger.isDebugEnabled()) {
+                printClusterQualityStatistics(assignmentsByCluster);
+            }
+            return offsets;
+        }
+    }
+
 private static void printClusterQualityStatistics(int[][] clusters) {
     float min = Float.MAX_VALUE;
     float max = Float.MIN_VALUE;
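The merge path is what makes the build IO friendly: quantization happens in one sequential pass over the (potentially off-heap) float vectors, each ordinal getting two fixed-size slots in a temporary file (primary form, then overspill form or a zero filler), and the later cluster-ordered random access hits that compact temp file instead of the raw vectors. A hedged sketch of the temp-file lifecycle with the Lucene Directory API (names illustrative; unlike the diff above, this sketch also deletes the file after reading):

```java
import java.io.IOException;

import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;

// Illustrative write-then-read temp-file pattern, as in the merge path above.
static void quantizeThenStream(Directory dir, String segmentName) throws IOException {
    IndexOutput tmp = dir.createTempOutput(segmentName, "qvec_", IOContext.DEFAULT);
    String tmpName = tmp.getName();
    boolean success = false;
    try {
        // phase 1: one sequential pass, one fixed-size record per vector ordinal
        tmp.close();
        success = true;
    } finally {
        if (success == false) {
            dir.deleteFile(tmpName); // don't leak the temp file on failure
        }
    }
    try (IndexInput in = dir.openInput(tmpName, IOContext.DEFAULT)) {
        // phase 2: seek to each cluster member's record while writing posting lists
    } finally {
        dir.deleteFile(tmpName);
    }
}
```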
@@ -210,33 +324,7 @@ static CentroidAssignments buildCentroidAssignments(KMeansResult kMeansResult) {
         float[][] centroids = kMeansResult.centroids();
         int[] assignments = kMeansResult.assignments();
         int[] soarAssignments = kMeansResult.soarAssignments();
-        int[] centroidVectorCount = new int[centroids.length];
-        for (int i = 0; i < assignments.length; i++) {
-            centroidVectorCount[assignments[i]]++;
-            // if soar assignments are present, count them as well
-            if (soarAssignments.length > i && soarAssignments[i] != -1) {
-                centroidVectorCount[soarAssignments[i]]++;
-            }
-        }
-
-        int[][] assignmentsByCluster = new int[centroids.length][];
-        for (int c = 0; c < centroids.length; c++) {
-            assignmentsByCluster[c] = new int[centroidVectorCount[c]];
-        }
-        Arrays.fill(centroidVectorCount, 0);
-
-        for (int i = 0; i < assignments.length; i++) {
-            int c = assignments[i];
-            assignmentsByCluster[c][centroidVectorCount[c]++] = i;
-            // if soar assignments are present, add them to the cluster as well
-            if (soarAssignments.length > i) {
-                int s = soarAssignments[i];
-                if (s != -1) {
-                    assignmentsByCluster[s][centroidVectorCount[s]++] = i;
-                }
-            }
-        }
-        return new CentroidAssignments(centroids, assignmentsByCluster);
+        return new CentroidAssignments(centroids, assignments, soarAssignments);
     }
 
     static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections)
@@ -281,4 +369,48 @@ public float[] centroid(int centroidOrdinal) throws IOException {
             return scratch;
         }
     }
+
+    static class OffHeapQuantizedVectors {
+        private final IndexInput quantizedVectorsInput;
+        private final byte[] binaryScratch;
+        private final float[] corrections = new float[3];
+
+        private final int vectorByteSize;
+        private short bitSum;
+        private int currOrd = -1;
+        private boolean isOverspill = false;
+
+        OffHeapQuantizedVectors(IndexInput quantizedVectorsInput, int dimension) {
+            this.quantizedVectorsInput = quantizedVectorsInput;
+            this.binaryScratch = new byte[BQVectorUtils.discretize(dimension, 64) / 8];
+            this.vectorByteSize = (binaryScratch.length + 3 * Float.BYTES + Short.BYTES);
+        }
+
+        byte[] getVector(int ord, boolean isOverspill) throws IOException {
+            readQuantizedVector(ord, isOverspill);
+            return binaryScratch;
+        }
+
+        OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException {
+            if (currOrd == -1) {
+                throw new IllegalStateException("No vector read yet, call readQuantizedVector first");
+            }
+            return new OptimizedScalarQuantizer.QuantizationResult(corrections[0], corrections[1], corrections[2], bitSum);
+        }
+
+        public void readQuantizedVector(int ord, boolean isOverspill) throws IOException {
+            if (ord == currOrd && isOverspill == this.isOverspill) {
+                return; // no need to read again
+            }
+            long offset = (long) ord * (vectorByteSize * 2) + (isOverspill ? vectorByteSize : 0);
+            quantizedVectorsInput.seek(offset);
+            quantizedVectorsInput.readBytes(binaryScratch, 0, binaryScratch.length);
+            quantizedVectorsInput.readFloats(corrections, 0, 3);
+            bitSum = quantizedVectorsInput.readShort();
+            if (ord != currOrd) {
+                currOrd = ord;
+            }
+            this.isOverspill = isOverspill;
+        }
+    }
 }
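OffHeapQuantizedVectors treats the temp file as a fixed-stride array: each slot holds the packed one-bit vector followed by three float corrections and a short bit count, and each ordinal owns two consecutive slots. Writing a zeroed filler slot for vectors without overspill wastes a little space but keeps the stride constant, so no per-ordinal offset table is needed. The seek arithmetic reduces to the following (method name hypothetical; sizes mirror the constructor):

```java
// Stride math behind OffHeapQuantizedVectors.readQuantizedVector (illustrative).
static long slotOffset(int ord, boolean isOverspill, int dimension) {
    // BQVectorUtils.discretize(dimension, 64) rounds up to a multiple of 64 bits
    int packedBytes = ((dimension + 63) / 64) * 64 / 8;
    int vectorByteSize = packedBytes + 3 * Float.BYTES + Short.BYTES;
    // two slots per ordinal: primary first, overspill (or zero filler) second
    return (long) ord * (vectorByteSize * 2L) + (isOverspill ? vectorByteSize : 0);
}
```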

server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java

Lines changed: 21 additions & 5 deletions
@@ -139,7 +139,18 @@ abstract long[] buildAndWritePostingsLists(
         CentroidSupplier centroidSupplier,
         FloatVectorValues floatVectorValues,
         IndexOutput postingsOutput,
-        int[][] assignmentsByCluster
+        int[] assignments,
+        int[] overspillAssignments
+    ) throws IOException;
+
+    abstract long[] buildAndWritePostingsLists(
+        FieldInfo fieldInfo,
+        CentroidSupplier centroidSupplier,
+        FloatVectorValues floatVectorValues,
+        IndexOutput postingsOutput,
+        MergeState mergeState,
+        int[] assignments,
+        int[] overspillAssignments
     ) throws IOException;
 
     abstract CentroidSupplier createCentroidSupplier(
@@ -174,7 +185,8 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
             centroidSupplier,
             floatVectorValues,
             ivfClusters,
-            centroidAssignments.assignmentsByCluster()
+            centroidAssignments.assignments(),
+            centroidAssignments.overspillAssignments()
         );
         // write posting lists
         writeMeta(fieldWriter.fieldInfo, centroidOffset, centroidLength, offsets, globalCentroid);
@@ -284,7 +296,8 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
         final long centroidOffset;
         final long centroidLength;
         final int numCentroids;
-        final int[][] assignmentsByCluster;
+        final int[] assignments;
+        final int[] overspillAssignments;
         final float[] calculatedGlobalCentroid = new float[fieldInfo.getVectorDimension()];
         String centroidTempName = null;
         IndexOutput centroidTemp = null;
@@ -300,7 +313,8 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
                 calculatedGlobalCentroid
             );
             numCentroids = centroidAssignments.numCentroids();
-            assignmentsByCluster = centroidAssignments.assignmentsByCluster();
+            assignments = centroidAssignments.assignments();
+            overspillAssignments = centroidAssignments.overspillAssignments();
             success = true;
         } finally {
             if (success == false && centroidTempName != null) {
@@ -337,7 +351,9 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
             centroidSupplier,
             floatVectorValues,
             ivfClusters,
-            assignmentsByCluster
+            mergeState,
+            assignments,
+            overspillAssignments
         );
         assert offsets.length == centroidSupplier.size();
         writeMeta(fieldInfo, centroidOffset, centroidLength, offsets, calculatedGlobalCentroid);
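The abstract API now splits into two overloads: the flush path keeps the in-memory signature, while the merge overload also receives MergeState, whose segmentInfo.dir gives the writer the Directory it needs to create and re-open the quantized temp file. An abridged, hypothetical skeleton of the split (real signatures also carry FieldInfo, CentroidSupplier, FloatVectorValues, and IndexOutput):

```java
import org.apache.lucene.index.MergeState;

// Hypothetical skeleton; parameter lists trimmed for clarity.
abstract class PostingsBuilder {
    // flush path: all source data is already on the heap
    abstract long[] buildAndWritePostingsLists(int[] assignments, int[] overspillAssignments);

    // merge path: MergeState supplies the Directory for the "qvec_" temp file
    abstract long[] buildAndWritePostingsLists(MergeState mergeState, int[] assignments, int[] overspillAssignments);
}
```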
