Skip to content

Commit a4f2c28

Browse files
committed
fixing test failure, addressing pr comments
1 parent 817ee23 commit a4f2c28

File tree

2 files changed

+118
-94
lines changed

2 files changed

+118
-94
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java

Lines changed: 95 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -223,44 +223,17 @@ static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] glo
223223
}
224224
}
225225

226-
record SegmentCentroid(int segment, int centroid, int centroidSize) {}
227-
228-
@Override
229-
protected int calculateAndWriteCentroids(
226+
static float[][] gatherInitCentroids(
227+
List<FloatVectorValues> centroidList,
228+
List<SegmentCentroid> segmentCentroids,
229+
int desiredClusters,
230230
FieldInfo fieldInfo,
231-
FloatVectorValues floatVectorValues,
232-
IndexOutput temporaryCentroidOutput,
233-
MergeState mergeState,
234-
float[] globalCentroid
231+
MergeState mergeState
235232
) throws IOException {
236-
if (floatVectorValues.size() == 0) {
237-
return 0;
233+
if (centroidList.size() == 0) {
234+
return null;
238235
}
239-
int desiredClusters = ((floatVectorValues.size() - 1) / vectorPerCluster) + 1;
240-
// init centroids from merge state
241-
List<FloatVectorValues> centroidList = new ArrayList<>();
242-
List<SegmentCentroid> segmentCentroids = new ArrayList<>(desiredClusters);
243-
244-
int segmentIdx = 0;
245236
long startTime = System.nanoTime();
246-
for (var reader : mergeState.knnVectorsReaders) {
247-
IVFVectorsReader ivfVectorsReader = IVFVectorsFormat.getIVFReader(reader, fieldInfo.name);
248-
if (ivfVectorsReader == null) {
249-
continue;
250-
}
251-
252-
FloatVectorValues centroid = ivfVectorsReader.getCentroids(fieldInfo);
253-
if (centroid == null) {
254-
continue;
255-
}
256-
centroidList.add(centroid);
257-
for (int i = 0; i < centroid.size(); i++) {
258-
int size = ivfVectorsReader.centroidSize(fieldInfo.name, i);
259-
segmentCentroids.add(new SegmentCentroid(segmentIdx, i, size));
260-
}
261-
segmentIdx++;
262-
}
263-
264237
// sort centroid list by floatvector size
265238
FloatVectorValues baseSegment = centroidList.get(0);
266239
for (var l : centroidList) {
@@ -334,6 +307,9 @@ protected int calculateAndWriteCentroids(
334307
sum[label - 1] += segmentCentroid.centroidSize;
335308
}
336309
for (int i = 0; i < initCentroids.length; i++) {
310+
if (sum[i] == 0 || sum[i] == 1) {
311+
continue;
312+
}
337313
for (int j = 0; j < initCentroids[i].length; j++) {
338314
initCentroids[i][j] /= sum[i];
339315
}
@@ -348,6 +324,67 @@ protected int calculateAndWriteCentroids(
348324
"Gathered initCentroids:" + initCentroids.length + " for desired: " + desiredClusters
349325
);
350326
}
327+
return initCentroids;
328+
}
329+
330+
record SegmentCentroid(int segment, int centroid, int centroidSize) {}
331+
332+
/**
333+
* Calculate the centroids for the given field and write them to the given
334+
* temporary centroid output.
335+
* When merging, we first bootstrap the KMeans algorithm with the centroids contained in the merging segments.
336+
* To prevent centroids that are too similar from having an outsized impact, all centroids that are closer than
337+
* the largest segments intra-cluster distance are merged into a single centroid.
338+
* The resulting centroids are then used to initialize the KMeans algorithm.
339+
*
340+
* @param fieldInfo merging field info
341+
* @param floatVectorValues the float vector values to merge
342+
* @param temporaryCentroidOutput the temporary centroid output
343+
* @param mergeState the merge state
344+
* @param globalCentroid the global centroid, calculated by this method and used to quantize the centroids
345+
* @return the number of centroids written
346+
* @throws IOException if an I/O error occurs
347+
*/
348+
@Override
349+
protected int calculateAndWriteCentroids(
350+
FieldInfo fieldInfo,
351+
FloatVectorValues floatVectorValues,
352+
IndexOutput temporaryCentroidOutput,
353+
MergeState mergeState,
354+
float[] globalCentroid
355+
) throws IOException {
356+
if (floatVectorValues.size() == 0) {
357+
return 0;
358+
}
359+
int maxNumClusters = ((floatVectorValues.size() - 1) / vectorPerCluster) + 1;
360+
int desiredClusters = (int) Math.max(Math.sqrt(floatVectorValues.size()), maxNumClusters);
361+
// init centroids from merge state
362+
List<FloatVectorValues> centroidList = new ArrayList<>();
363+
List<SegmentCentroid> segmentCentroids = new ArrayList<>(desiredClusters);
364+
365+
int segmentIdx = 0;
366+
for (var reader : mergeState.knnVectorsReaders) {
367+
IVFVectorsReader ivfVectorsReader = IVFVectorsFormat.getIVFReader(reader, fieldInfo.name);
368+
if (ivfVectorsReader == null) {
369+
continue;
370+
}
371+
372+
FloatVectorValues centroid = ivfVectorsReader.getCentroids(fieldInfo);
373+
if (centroid == null) {
374+
continue;
375+
}
376+
centroidList.add(centroid);
377+
for (int i = 0; i < centroid.size(); i++) {
378+
int size = ivfVectorsReader.centroidSize(fieldInfo.name, i);
379+
if (size == 0) {
380+
continue;
381+
}
382+
segmentCentroids.add(new SegmentCentroid(segmentIdx, i, size));
383+
}
384+
segmentIdx++;
385+
}
386+
387+
float[][] initCentroids = gatherInitCentroids(centroidList, segmentCentroids, desiredClusters, fieldInfo, mergeState);
351388

352389
// FIXME: run a custom version of KMeans that is just better...
353390
long nanoTime = System.nanoTime();
@@ -369,6 +406,15 @@ protected int calculateAndWriteCentroids(
369406
float[][] centroids = kMeans.centroids();
370407

371408
// write them
409+
// calculate the global centroid from all the centroids:
410+
for (float[] centroid : centroids) {
411+
for (int j = 0; j < centroid.length; j++) {
412+
globalCentroid[j] += centroid[j];
413+
}
414+
}
415+
for (int j = 0; j < globalCentroid.length; j++) {
416+
globalCentroid[j] /= centroids.length;
417+
}
372418
writeCentroids(centroids, fieldInfo, globalCentroid, temporaryCentroidOutput);
373419
return centroids.length;
374420
}
@@ -477,14 +523,11 @@ static void assignCentroids(CentroidAssignmentScorer scorer, FloatVectorValues v
477523
// pop the best
478524
int sz = neighborsToCheck.size();
479525
int best = neighborsToCheck.consumeNodesAndScoresMin(ordScoreIterator.ords, ordScoreIterator.scores);
480-
// reset the ordScoreIterator as it has consumed the ords and scores
481-
ordScoreIterator.idx = sz;
526+
// Set the size to the number of neighbors we actually found
527+
ordScoreIterator.setSize(sz);
482528
bestScore = ordScoreIterator.getScore(best);
483529
bestCentroid = ordScoreIterator.getOrd(best);
484530
}
485-
if (clusters[bestCentroid] == null) {
486-
clusters[bestCentroid] = new IntArrayList(16);
487-
}
488531
clusters[bestCentroid].add(docID);
489532
if (soarClusterCheckCount > 0) {
490533
assignCentroidSOAR(
@@ -495,7 +538,7 @@ static void assignCentroids(CentroidAssignmentScorer scorer, FloatVectorValues v
495538
bestScore,
496539
scratch,
497540
scorer,
498-
vectors,
541+
vector,
499542
clusters
500543
);
501544
}
@@ -511,10 +554,9 @@ static void assignCentroidSOAR(
511554
float bestScore,
512555
float[] scratch,
513556
CentroidAssignmentScorer scorer,
514-
FloatVectorValues vectors,
557+
float[] vector,
515558
IntArrayList[] clusters
516559
) throws IOException {
517-
float[] vector = vectors.vectorValue(vecOrd);
518560
ESVectorUtil.subtract(vector, bestCentroid, scratch);
519561
int bestSecondaryCentroid = -1;
520562
float minDist = Float.MAX_VALUE;
@@ -546,6 +588,14 @@ static class OrdScoreIterator {
546588
this.scores = new float[size];
547589
}
548590

591+
int setSize(int size) {
592+
if (size > ords.length) {
593+
throw new IllegalArgumentException("size must be <= " + ords.length);
594+
}
595+
this.idx = size;
596+
return size;
597+
}
598+
549599
int getOrd(int idx) {
550600
return ords[idx];
551601
}
@@ -606,15 +656,15 @@ static class OffHeapCentroidAssignmentScorer implements CentroidAssignmentScorer
606656
private final int dimension;
607657
private final float[] scratch;
608658
private float[] q;
609-
private final long centroidByteSize;
659+
private final long rawCentroidOffset;
610660
private int currOrd = -1;
611661

612662
OffHeapCentroidAssignmentScorer(IndexInput centroidsInput, int numCentroids, FieldInfo info) {
613663
this.centroidsInput = centroidsInput;
614664
this.numCentroids = numCentroids;
615665
this.dimension = info.getVectorDimension();
616666
this.scratch = new float[dimension];
617-
this.centroidByteSize = dimension + 3 * Float.BYTES + Short.BYTES;
667+
this.rawCentroidOffset = (dimension + 3 * Float.BYTES + Short.BYTES) * numCentroids;
618668
}
619669

620670
@Override
@@ -627,7 +677,7 @@ public float[] centroid(int centroidOrdinal) throws IOException {
627677
if (centroidOrdinal == currOrd) {
628678
return scratch;
629679
}
630-
centroidsInput.seek(numCentroids * centroidByteSize + (long) centroidOrdinal * dimension * Float.BYTES);
680+
centroidsInput.seek(rawCentroidOffset + (long) centroidOrdinal * dimension * Float.BYTES);
631681
centroidsInput.readFloats(scratch, 0, dimension);
632682
this.currOrd = centroidOrdinal;
633683
return scratch;

server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java

Lines changed: 23 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,6 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
192192
floatVectorValues,
193193
ivfClusters
194194
);
195-
// write posting lists
196195
writeMeta(fieldWriter.fieldInfo, centroidOffset, centroidLength, offsets, globalCentroid);
197196
}
198197
}
@@ -256,54 +255,25 @@ public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) thro
256255
rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
257256
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
258257
final int numVectors;
259-
String name = null;
258+
String tempRawVectorsFileName = null;
260259
boolean success = false;
261260
// build a float vector values with random access. In order to do that we dump the vectors to
262261
// a temporary file
263262
// and write the docID follow by the vector
264263
try (IndexOutput out = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "ivf_", IOContext.DEFAULT)) {
265-
name = out.getName();
264+
tempRawVectorsFileName = out.getName();
266265
// TODO do this better, we shouldn't have to write to a temp file, we should be able to
267-
// to just from the merged vector values.
266+
// to just from the merged vector values, the tricky part is the random access.
268267
numVectors = writeFloatVectorValues(fieldInfo, out, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState));
268+
CodecUtil.writeFooter(out);
269269
success = true;
270270
} finally {
271-
if (success == false && name != null) {
272-
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, name);
271+
if (success == false && tempRawVectorsFileName != null) {
272+
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName);
273273
}
274274
}
275-
float[] globalCentroid = new float[fieldInfo.getVectorDimension()];
276-
int vectorCount = 0;
277-
for (int idx = 0; idx < mergeState.knnVectorsReaders.length; idx++) {
278-
if (mergeState.fieldInfos[idx] == null
279-
|| mergeState.fieldInfos[idx].hasVectorValues() == false
280-
|| mergeState.fieldInfos[idx].fieldInfo(fieldInfo.name) == null
281-
|| mergeState.fieldInfos[idx].fieldInfo(fieldInfo.name).hasVectorValues() == false) {
282-
continue;
283-
}
284-
KnnVectorsReader knnReaders = mergeState.knnVectorsReaders[idx];
285-
IVFVectorsReader ivfReader = getIVFReader(knnReaders, fieldInfo.name);
286-
if (ivfReader != null) {
287-
FloatVectorValues floatVectorValues = knnReaders.getFloatVectorValues(fieldInfo.name);
288-
if (floatVectorValues == null) {
289-
continue;
290-
}
291-
int numVecs = floatVectorValues.size();
292-
float[] readerGlobalCentroid = ivfReader.getGlobalCentroid(fieldInfo);
293-
if (readerGlobalCentroid != null) {
294-
vectorCount += numVecs;
295-
for (int i = 0; i < globalCentroid.length; i++) {
296-
globalCentroid[i] += readerGlobalCentroid[i] * numVecs;
297-
}
298-
}
299-
}
300-
}
301-
if (vectorCount > 0) {
302-
for (int i = 0; i < globalCentroid.length; i++) {
303-
globalCentroid[i] /= vectorCount;
304-
}
305-
}
306-
try (IndexInput in = mergeState.segmentInfo.dir.openInput(name, IOContext.DEFAULT)) {
275+
try (IndexInput in = mergeState.segmentInfo.dir.openInput(tempRawVectorsFileName, IOContext.DEFAULT)) {
276+
float[] calculatedGlobalCentroid = new float[fieldInfo.getVectorDimension()];
307277
final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, in, numVectors);
308278
success = false;
309279
CentroidAssignmentScorer centroidAssignmentScorer;
@@ -315,7 +285,13 @@ public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) thro
315285
try {
316286
centroidTemp = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "civf_", IOContext.DEFAULT);
317287
centroidTempName = centroidTemp.getName();
318-
numCentroids = calculateAndWriteCentroids(fieldInfo, floatVectorValues, centroidTemp, mergeState, globalCentroid);
288+
numCentroids = calculateAndWriteCentroids(
289+
fieldInfo,
290+
floatVectorValues,
291+
centroidTemp,
292+
mergeState,
293+
calculatedGlobalCentroid
294+
);
319295
success = true;
320296
} finally {
321297
if (success == false && centroidTempName != null) {
@@ -337,7 +313,7 @@ public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) thro
337313
try (IndexInput centroidInput = mergeState.segmentInfo.dir.openInput(centroidTempName, IOContext.DEFAULT)) {
338314
ivfCentroids.copyBytes(centroidInput, centroidInput.length() - CodecUtil.footerLength());
339315
centroidLength = ivfCentroids.getFilePointer() - centroidOffset;
340-
centroidAssignmentScorer = createCentroidScorer(centroidInput, numCentroids, fieldInfo, globalCentroid);
316+
centroidAssignmentScorer = createCentroidScorer(centroidInput, numCentroids, fieldInfo, calculatedGlobalCentroid);
341317
assert centroidAssignmentScorer.size() == numCentroids;
342318
// build a float vector values with random access
343319
// build centroids
@@ -348,20 +324,18 @@ public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) thro
348324
ivfClusters,
349325
mergeState
350326
);
351-
// write posting lists
352-
353-
// TODO handle this correctly by creating new centroid
354-
if (vectorCount == 0 && offsets.length > 0) {
355-
throw new IllegalStateException("No global centroid found for field: " + fieldInfo.name);
356-
}
357327
assert offsets.length == centroidAssignmentScorer.size();
358-
writeMeta(fieldInfo, centroidOffset, centroidLength, offsets, globalCentroid);
328+
writeMeta(fieldInfo, centroidOffset, centroidLength, offsets, calculatedGlobalCentroid);
359329
}
360330
} finally {
361-
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, name, centroidTempName);
331+
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(
332+
mergeState.segmentInfo.dir,
333+
tempRawVectorsFileName,
334+
centroidTempName
335+
);
362336
}
363337
} finally {
364-
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, name);
338+
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName);
365339
}
366340
}
367341
}

0 commit comments

Comments
 (0)