
Commit 44497b7

benwtrent and elasticsearchmachine authored
Adj ivf postings list building (#130843)
* Make postings list building more IO friendly
* iter
* iter
* fixing assertion
* [CI] Auto commit changes from spotless

---------

Co-authored-by: elasticsearchmachine <[email protected]>
1 parent ce481e5 commit 44497b7

File tree

4 files changed
+324 -76 lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java

Lines changed: 5 additions & 4 deletions
@@ -9,10 +9,11 @@
 
 package org.elasticsearch.index.codec.vectors;
 
-record CentroidAssignments(int numCentroids, float[][] centroids, int[][] assignmentsByCluster) {
+record CentroidAssignments(int numCentroids, float[][] centroids, int[] assignments, int[] overspillAssignments) {
 
-    CentroidAssignments(float[][] centroids, int[][] assignmentsByCluster) {
-        this(centroids.length, centroids, assignmentsByCluster);
-        assert centroids.length == assignmentsByCluster.length;
+    CentroidAssignments(float[][] centroids, int[] assignments, int[] overspillAssignments) {
+        this(centroids.length, centroids, assignments, overspillAssignments);
+        assert assignments.length == overspillAssignments.length || overspillAssignments.length == 0
+            : "assignments and overspillAssignments must have the same length";
     }
 }
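
For orientation, a hypothetical construction of the reworked record (these arrays are illustrative, not from the commit): the compact constructor now takes flat per-vector arrays instead of pre-bucketed clusters, and the assertion admits an empty overspill array for the no-SOAR case.

    float[][] centroids = { { 0.1f, 0.2f }, { 0.9f, 0.8f } };
    int[] assignments = { 0, 1, 1 };  // winning centroid per vector ordinal
    int[] overspill = { 1, 0, -1 };   // SOAR second choice per ordinal, -1 = none
    CentroidAssignments withSoar = new CentroidAssignments(centroids, assignments, overspill);
    CentroidAssignments noSoar = new CentroidAssignments(centroids, assignments, new int[0]); // empty overspill is allowed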

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java

Lines changed: 283 additions & 35 deletions
@@ -14,9 +14,11 @@
 import org.apache.lucene.index.FloatVectorValues;
 import org.apache.lucene.index.MergeState;
 import org.apache.lucene.index.SegmentWriteState;
+import org.apache.lucene.store.IOContext;
 import org.apache.lucene.store.IndexInput;
 import org.apache.lucene.store.IndexOutput;
 import org.apache.lucene.util.VectorUtil;
+import org.apache.lucene.util.hnsw.IntToIntFunction;
 import org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeans;
 import org.elasticsearch.index.codec.vectors.cluster.KMeansResult;
 import org.elasticsearch.logging.LogManager;
@@ -49,32 +51,58 @@ long[] buildAndWritePostingsLists(
         CentroidSupplier centroidSupplier,
         FloatVectorValues floatVectorValues,
         IndexOutput postingsOutput,
-        int[][] assignmentsByCluster
+        int[] assignments,
+        int[] overspillAssignments
     ) throws IOException {
+        int[] centroidVectorCount = new int[centroidSupplier.size()];
+        for (int i = 0; i < assignments.length; i++) {
+            centroidVectorCount[assignments[i]]++;
+            // if soar assignments are present, count them as well
+            if (overspillAssignments.length > i && overspillAssignments[i] != -1) {
+                centroidVectorCount[overspillAssignments[i]]++;
+            }
+        }
+
+        int[][] assignmentsByCluster = new int[centroidSupplier.size()][];
+        for (int c = 0; c < centroidSupplier.size(); c++) {
+            assignmentsByCluster[c] = new int[centroidVectorCount[c]];
+        }
+        Arrays.fill(centroidVectorCount, 0);
+
+        for (int i = 0; i < assignments.length; i++) {
+            int c = assignments[i];
+            assignmentsByCluster[c][centroidVectorCount[c]++] = i;
+            // if soar assignments are present, add them to the cluster as well
+            if (overspillAssignments.length > i) {
+                int s = overspillAssignments[i];
+                if (s != -1) {
+                    assignmentsByCluster[s][centroidVectorCount[s]++] = i;
+                }
+            }
+        }
         // write the posting lists
         final long[] offsets = new long[centroidSupplier.size()];
-        OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
         DocIdsWriter docIdsWriter = new DocIdsWriter();
-        DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(
-            ES91OSQVectorsScorer.BULK_SIZE,
-            quantizer,
+        DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
+        OnHeapQuantizedVectors onHeapQuantizedVectors = new OnHeapQuantizedVectors(
             floatVectorValues,
-            postingsOutput
+            fieldInfo.getVectorDimension(),
+            new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction())
         );
         for (int c = 0; c < centroidSupplier.size(); c++) {
             float[] centroid = centroidSupplier.centroid(c);
-            // TODO: add back in sorting vectors by distance to centroid
             int[] cluster = assignmentsByCluster[c];
             // TODO align???
             offsets[c] = postingsOutput.getFilePointer();
             int size = cluster.length;
             postingsOutput.writeVInt(size);
             postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
+            onHeapQuantizedVectors.reset(centroid, size, ord -> cluster[ord]);
             // TODO we might want to consider putting the docIds in a separate file
             // to aid with only having to fetch vectors from slower storage when they are required
             // keeping them in the same file indicates we pull the entire file into cache
             docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
-            bulkWriter.writeOrds(j -> cluster[j], cluster.length, centroid);
+            bulkWriter.writeVectors(onHeapQuantizedVectors);
         }
 
         if (logger.isDebugEnabled()) {
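
The grouping that opens this method is a counting-sort style two-pass bucketing: a first pass sizes every cluster (counting SOAR overspill slots where present), the counters are reset, and a second pass fills the per-cluster ordinal arrays. A self-contained sketch of the same technique, with hypothetical inputs:

    import java.util.Arrays;

    final class BucketingSketch {
        // Groups vector ordinals by centroid without per-cluster resizing:
        // pass 1 counts bucket sizes, pass 2 reuses the counters as write cursors.
        static int[][] bucketByCentroid(int numCentroids, int[] assignments, int[] overspill) {
            int[] counts = new int[numCentroids];
            for (int i = 0; i < assignments.length; i++) {
                counts[assignments[i]]++;
                if (overspill.length > i && overspill[i] != -1) {
                    counts[overspill[i]]++; // overspilled vectors appear in two buckets
                }
            }
            int[][] buckets = new int[numCentroids][];
            for (int c = 0; c < numCentroids; c++) {
                buckets[c] = new int[counts[c]];
            }
            Arrays.fill(counts, 0); // reset: counters now act as per-bucket write cursors
            for (int i = 0; i < assignments.length; i++) {
                buckets[assignments[i]][counts[assignments[i]]++] = i;
                if (overspill.length > i && overspill[i] != -1) {
                    buckets[overspill[i]][counts[overspill[i]]++] = i;
                }
            }
            return buckets;
        }
    }

Each vector is visited exactly twice and no intermediate lists are allocated, which keeps the grouping cheap even for large segments.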
@@ -84,6 +112,124 @@ long[] buildAndWritePostingsLists(
         return offsets;
     }
 
+    @Override
+    long[] buildAndWritePostingsLists(
+        FieldInfo fieldInfo,
+        CentroidSupplier centroidSupplier,
+        FloatVectorValues floatVectorValues,
+        IndexOutput postingsOutput,
+        MergeState mergeState,
+        int[] assignments,
+        int[] overspillAssignments
+    ) throws IOException {
+        // first, quantize all the vectors into a temporary file
+        String quantizedVectorsTempName = null;
+        IndexOutput quantizedVectorsTemp = null;
+        boolean success = false;
+        try {
+            quantizedVectorsTemp = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "qvec_", IOContext.DEFAULT);
+            quantizedVectorsTempName = quantizedVectorsTemp.getName();
+            OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
+            int[] quantized = new int[fieldInfo.getVectorDimension()];
+            byte[] binary = new byte[BQVectorUtils.discretize(fieldInfo.getVectorDimension(), 64) / 8];
+            float[] overspillScratch = new float[fieldInfo.getVectorDimension()];
+            for (int i = 0; i < assignments.length; i++) {
+                int c = assignments[i];
+                float[] centroid = centroidSupplier.centroid(c);
+                float[] vector = floatVectorValues.vectorValue(i);
+                boolean overspill = overspillAssignments.length > i && overspillAssignments[i] != -1;
+                // if overspilling, this means we quantize twice, and quantization mutates the in-memory representation of the vector
+                // so, make a copy of the vector to avoid mutating it
+                if (overspill) {
+                    System.arraycopy(vector, 0, overspillScratch, 0, fieldInfo.getVectorDimension());
+                }
+
+                OptimizedScalarQuantizer.QuantizationResult result = quantizer.scalarQuantize(vector, quantized, (byte) 1, centroid);
+                BQVectorUtils.packAsBinary(quantized, binary);
+                writeQuantizedValue(quantizedVectorsTemp, binary, result);
+                if (overspill) {
+                    int s = overspillAssignments[i];
+                    // write the overspill vector as well
+                    result = quantizer.scalarQuantize(overspillScratch, quantized, (byte) 1, centroidSupplier.centroid(s));
+                    BQVectorUtils.packAsBinary(quantized, binary);
+                    writeQuantizedValue(quantizedVectorsTemp, binary, result);
+                } else {
+                    // write a zero vector for the overspill
+                    Arrays.fill(binary, (byte) 0);
+                    OptimizedScalarQuantizer.QuantizationResult zeroResult = new OptimizedScalarQuantizer.QuantizationResult(0f, 0f, 0f, 0);
+                    writeQuantizedValue(quantizedVectorsTemp, binary, zeroResult);
+                }
+            }
+            // close the temporary file so we can read it later
+            quantizedVectorsTemp.close();
+            success = true;
+        } finally {
+            if (success == false && quantizedVectorsTemp != null) {
+                mergeState.segmentInfo.dir.deleteFile(quantizedVectorsTemp.getName());
+            }
+        }
+        int[] centroidVectorCount = new int[centroidSupplier.size()];
+        for (int i = 0; i < assignments.length; i++) {
+            centroidVectorCount[assignments[i]]++;
+            // if soar assignments are present, count them as well
+            if (overspillAssignments.length > i && overspillAssignments[i] != -1) {
+                centroidVectorCount[overspillAssignments[i]]++;
+            }
+        }
+
+        int[][] assignmentsByCluster = new int[centroidSupplier.size()][];
+        boolean[][] isOverspillByCluster = new boolean[centroidSupplier.size()][];
+        for (int c = 0; c < centroidSupplier.size(); c++) {
+            assignmentsByCluster[c] = new int[centroidVectorCount[c]];
+            isOverspillByCluster[c] = new boolean[centroidVectorCount[c]];
+        }
+        Arrays.fill(centroidVectorCount, 0);
+
+        for (int i = 0; i < assignments.length; i++) {
+            int c = assignments[i];
+            assignmentsByCluster[c][centroidVectorCount[c]++] = i;
+            // if soar assignments are present, add them to the cluster as well
+            if (overspillAssignments.length > i) {
+                int s = overspillAssignments[i];
+                if (s != -1) {
+                    assignmentsByCluster[s][centroidVectorCount[s]] = i;
+                    isOverspillByCluster[s][centroidVectorCount[s]++] = true;
+                }
+            }
+        }
+        // now we can read the quantized vectors from the temporary file
+        try (IndexInput quantizedVectorsInput = mergeState.segmentInfo.dir.openInput(quantizedVectorsTempName, IOContext.DEFAULT)) {
+            final long[] offsets = new long[centroidSupplier.size()];
+            OffHeapQuantizedVectors offHeapQuantizedVectors = new OffHeapQuantizedVectors(
+                quantizedVectorsInput,
+                fieldInfo.getVectorDimension()
+            );
+            DocIdsWriter docIdsWriter = new DocIdsWriter();
+            DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
+            for (int c = 0; c < centroidSupplier.size(); c++) {
+                float[] centroid = centroidSupplier.centroid(c);
+                int[] cluster = assignmentsByCluster[c];
+                boolean[] isOverspill = isOverspillByCluster[c];
+                // TODO align???
+                offsets[c] = postingsOutput.getFilePointer();
+                int size = cluster.length;
+                postingsOutput.writeVInt(size);
+                postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
+                offHeapQuantizedVectors.reset(size, ord -> isOverspill[ord], ord -> cluster[ord]);
+                // TODO we might want to consider putting the docIds in a separate file
+                // to aid with only having to fetch vectors from slower storage when they are required
+                // keeping them in the same file indicates we pull the entire file into cache
+                docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
+                bulkWriter.writeVectors(offHeapQuantizedVectors);
+            }
+
+            if (logger.isDebugEnabled()) {
+                printClusterQualityStatistics(assignmentsByCluster);
+            }
+            return offsets;
+        }
+    }
+
     private static void printClusterQualityStatistics(int[][] clusters) {
         float min = Float.MAX_VALUE;
         float max = Float.MIN_VALUE;
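
The merge path above quantizes every vector once into a sequentially written temp file (copying the raw vector first when it overspills, since quantization mutates its input), then serves random reads from that file while writing postings. Because every ordinal owns two fixed-size records, primary then overspill (zeroed when absent), a vector's position is pure arithmetic. A sketch of that layout math, using the record shape written above:

    // Per ordinal, the "qvec_" temp file holds two records of identical size:
    //   [ packed 1-bit vector | 3 correction floats | bit-sum short ]  primary
    //   [ same shape, zeroed when there is no overspill ]              overspill
    static long recordOffset(int ord, boolean isOverspill, int dimension) {
        int packedBytes = BQVectorUtils.discretize(dimension, 64) / 8;
        int recordSize = packedBytes + 3 * Float.BYTES + Short.BYTES;
        return (long) ord * (recordSize * 2L) + (isOverspill ? recordSize : 0);
    }
    // e.g. dimension = 768: packedBytes = 96, recordSize = 110,
    // so ordinal 5's overspill record starts at byte 1210.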
@@ -210,33 +356,7 @@ static CentroidAssignments buildCentroidAssignments(KMeansResult kMeansResult) {
         float[][] centroids = kMeansResult.centroids();
         int[] assignments = kMeansResult.assignments();
         int[] soarAssignments = kMeansResult.soarAssignments();
-        int[] centroidVectorCount = new int[centroids.length];
-        for (int i = 0; i < assignments.length; i++) {
-            centroidVectorCount[assignments[i]]++;
-            // if soar assignments are present, count them as well
-            if (soarAssignments.length > i && soarAssignments[i] != -1) {
-                centroidVectorCount[soarAssignments[i]]++;
-            }
-        }
-
-        int[][] assignmentsByCluster = new int[centroids.length][];
-        for (int c = 0; c < centroids.length; c++) {
-            assignmentsByCluster[c] = new int[centroidVectorCount[c]];
-        }
-        Arrays.fill(centroidVectorCount, 0);
-
-        for (int i = 0; i < assignments.length; i++) {
-            int c = assignments[i];
-            assignmentsByCluster[c][centroidVectorCount[c]++] = i;
-            // if soar assignments are present, add them to the cluster as well
-            if (soarAssignments.length > i) {
-                int s = soarAssignments[i];
-                if (s != -1) {
-                    assignmentsByCluster[s][centroidVectorCount[s]++] = i;
-                }
-            }
-        }
-        return new CentroidAssignments(centroids, assignmentsByCluster);
+        return new CentroidAssignments(centroids, assignments, soarAssignments);
     }
 
     static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections)
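
With the bucketing deleted here, buildCentroidAssignments simply forwards the flat k-means output, and grouping happens once inside each postings-writing path. A hedged sketch of the resulting call flow (surrounding plumbing elided; variable names are illustrative):

    CentroidAssignments ca = buildCentroidAssignments(kMeansResult);
    long[] offsets = buildAndWritePostingsLists(
        fieldInfo,
        centroidSupplier,
        floatVectorValues,
        postingsOutput,
        ca.assignments(),          // record accessor for the flat assignments
        ca.overspillAssignments()  // empty array when SOAR is disabled
    );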
@@ -281,4 +401,132 @@ public float[] centroid(int centroidOrdinal) throws IOException {
             return scratch;
         }
     }
+
+    interface QuantizedVectorValues {
+        int count();
+
+        byte[] next() throws IOException;
+
+        OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException;
+    }
+
+    interface IntToBooleanFunction {
+        boolean apply(int ord);
+    }
+
+    static class OnHeapQuantizedVectors implements QuantizedVectorValues {
+        private final FloatVectorValues vectorValues;
+        private final OptimizedScalarQuantizer quantizer;
+        private final byte[] quantizedVector;
+        private final int[] quantizedVectorScratch;
+        private OptimizedScalarQuantizer.QuantizationResult corrections;
+        private float[] currentCentroid;
+        private IntToIntFunction ordTransformer = null;
+        private int currOrd = -1;
+        private int count;
+
+        OnHeapQuantizedVectors(FloatVectorValues vectorValues, int dimension, OptimizedScalarQuantizer quantizer) {
+            this.vectorValues = vectorValues;
+            this.quantizer = quantizer;
+            this.quantizedVector = new byte[BQVectorUtils.discretize(dimension, 64) / 8];
+            this.quantizedVectorScratch = new int[dimension];
+            this.corrections = null;
+        }
+
+        private void reset(float[] centroid, int count, IntToIntFunction ordTransformer) {
+            this.currentCentroid = centroid;
+            this.ordTransformer = ordTransformer;
+            this.currOrd = -1;
+            this.count = count;
+        }
+
+        @Override
+        public int count() {
+            return count;
+        }
+
+        @Override
+        public byte[] next() throws IOException {
+            if (currOrd >= count() - 1) {
+                throw new IllegalStateException("No more vectors to read, current ord: " + currOrd + ", count: " + count());
+            }
+            currOrd++;
+            int ord = ordTransformer.apply(currOrd);
+            float[] vector = vectorValues.vectorValue(ord);
+            corrections = quantizer.scalarQuantize(vector, quantizedVectorScratch, (byte) 1, currentCentroid);
+            BQVectorUtils.packAsBinary(quantizedVectorScratch, quantizedVector);
+            return quantizedVector;
+        }
+
+        @Override
+        public OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException {
+            if (currOrd == -1) {
+                throw new IllegalStateException("No vector read yet, call next first");
+            }
+            return corrections;
+        }
+    }
+
+    static class OffHeapQuantizedVectors implements QuantizedVectorValues {
+        private final IndexInput quantizedVectorsInput;
+        private final byte[] binaryScratch;
+        private final float[] corrections = new float[3];
+
+        private final int vectorByteSize;
+        private short bitSum;
+        private int currOrd = -1;
+        private int count;
+        private IntToBooleanFunction isOverspill = null;
+        private IntToIntFunction ordTransformer = null;
+
+        OffHeapQuantizedVectors(IndexInput quantizedVectorsInput, int dimension) {
+            this.quantizedVectorsInput = quantizedVectorsInput;
+            this.binaryScratch = new byte[BQVectorUtils.discretize(dimension, 64) / 8];
+            this.vectorByteSize = (binaryScratch.length + 3 * Float.BYTES + Short.BYTES);
+        }
+
+        private void reset(int count, IntToBooleanFunction isOverspill, IntToIntFunction ordTransformer) {
+            this.count = count;
+            this.isOverspill = isOverspill;
+            this.ordTransformer = ordTransformer;
+            this.currOrd = -1;
+        }
+
+        @Override
+        public int count() {
+            return count;
+        }
+
+        @Override
+        public byte[] next() throws IOException {
+            if (currOrd >= count - 1) {
+                throw new IllegalStateException("No more vectors to read, current ord: " + currOrd + ", count: " + count);
+            }
+            currOrd++;
+            int ord = ordTransformer.apply(currOrd);
+            boolean isOverspill = this.isOverspill.apply(currOrd);
+            return getVector(ord, isOverspill);
+        }
+
+        @Override
+        public OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException {
+            if (currOrd == -1) {
+                throw new IllegalStateException("No vector read yet, call readQuantizedVector first");
+            }
+            return new OptimizedScalarQuantizer.QuantizationResult(corrections[0], corrections[1], corrections[2], bitSum);
+        }
+
+        byte[] getVector(int ord, boolean isOverspill) throws IOException {
+            readQuantizedVector(ord, isOverspill);
+            return binaryScratch;
+        }
+
+        public void readQuantizedVector(int ord, boolean isOverspill) throws IOException {
+            long offset = (long) ord * (vectorByteSize * 2L) + (isOverspill ? vectorByteSize : 0);
+            quantizedVectorsInput.seek(offset);
+            quantizedVectorsInput.readBytes(binaryScratch, 0, binaryScratch.length);
+            quantizedVectorsInput.readFloats(corrections, 0, 3);
+            bitSum = quantizedVectorsInput.readShort();
+        }
+    }
 }
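
QuantizedVectorValues is a forward-only cursor: next() advances and returns a scratch buffer that is overwritten by the following call, and getCorrections() reports the corrections for the most recently returned vector. A hypothetical consumer loop follows; DiskBBQBulkWriter.writeVectors is not shown in this diff, so its real shape is an assumption.

    // Hypothetical consumer of the cursor contract, for illustration only.
    static void writeAll(QuantizedVectorValues vectors, IndexOutput out) throws IOException {
        for (int i = 0; i < vectors.count(); i++) {
            byte[] packed = vectors.next();                 // scratch bytes, reused on the next call
            OptimizedScalarQuantizer.QuantizationResult corrections = vectors.getCorrections();
            writeQuantizedValue(out, packed, corrections);  // helper declared earlier in this file
        }
    }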
