@@ -223,44 +223,17 @@ static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] glo
223223 }
224224 }
225225
226- record SegmentCentroid ( int segment , int centroid , int centroidSize ) {}
227-
228- @ Override
229- protected int calculateAndWriteCentroids (
226+ static float [][] gatherInitCentroids (
227+ List < FloatVectorValues > centroidList ,
228+ List < SegmentCentroid > segmentCentroids ,
229+ int desiredClusters ,
230230 FieldInfo fieldInfo ,
231- FloatVectorValues floatVectorValues ,
232- IndexOutput temporaryCentroidOutput ,
233- MergeState mergeState ,
234- float [] globalCentroid
231+ MergeState mergeState
235232 ) throws IOException {
236- if (floatVectorValues .size () == 0 ) {
237- return 0 ;
233+ if (centroidList .size () == 0 ) {
234+ return null ;
238235 }
239- int desiredClusters = ((floatVectorValues .size () - 1 ) / vectorPerCluster ) + 1 ;
240- // init centroids from merge state
241- List <FloatVectorValues > centroidList = new ArrayList <>();
242- List <SegmentCentroid > segmentCentroids = new ArrayList <>(desiredClusters );
243-
244- int segmentIdx = 0 ;
245236 long startTime = System .nanoTime ();
246- for (var reader : mergeState .knnVectorsReaders ) {
247- IVFVectorsReader ivfVectorsReader = IVFVectorsFormat .getIVFReader (reader , fieldInfo .name );
248- if (ivfVectorsReader == null ) {
249- continue ;
250- }
251-
252- FloatVectorValues centroid = ivfVectorsReader .getCentroids (fieldInfo );
253- if (centroid == null ) {
254- continue ;
255- }
256- centroidList .add (centroid );
257- for (int i = 0 ; i < centroid .size (); i ++) {
258- int size = ivfVectorsReader .centroidSize (fieldInfo .name , i );
259- segmentCentroids .add (new SegmentCentroid (segmentIdx , i , size ));
260- }
261- segmentIdx ++;
262- }
263-
264237 // sort centroid list by floatvector size
265238 FloatVectorValues baseSegment = centroidList .get (0 );
266239 for (var l : centroidList ) {
@@ -334,6 +307,9 @@ protected int calculateAndWriteCentroids(
334307 sum [label - 1 ] += segmentCentroid .centroidSize ;
335308 }
336309 for (int i = 0 ; i < initCentroids .length ; i ++) {
310+ if (sum [i ] == 0 || sum [i ] == 1 ) {
311+ continue ;
312+ }
337313 for (int j = 0 ; j < initCentroids [i ].length ; j ++) {
338314 initCentroids [i ][j ] /= sum [i ];
339315 }
@@ -348,6 +324,67 @@ protected int calculateAndWriteCentroids(
348324 "Gathered initCentroids:" + initCentroids .length + " for desired: " + desiredClusters
349325 );
350326 }
327+ return initCentroids ;
328+ }
329+
/**
 * Identifies one centroid contributed by one of the segments being merged.
 *
 * @param segment      index of the source segment, counted only over the merging segments that
 *                     actually exposed centroids for the field
 * @param centroid     ordinal of the centroid within that segment's centroid values
 * @param centroidSize size reported for that centroid by the segment's reader; it is summed per
 *                     merged label and used as the divisor when averaging the initial centroids
 */
record SegmentCentroid(
    int segment,
    int centroid,
    int centroidSize
) {}
331+
332+ /**
333+ * Calculate the centroids for the given field and write them to the given
334+ * temporary centroid output.
335+ * When merging, we first bootstrap the KMeans algorithm with the centroids contained in the merging segments.
336+ * To prevent centroids that are too similar from having an outsized impact, all centroids that are closer than
337+ * the largest segment's intra-cluster distance are merged into a single centroid.
338+ * The resulting centroids are then used to initialize the KMeans algorithm.
339+ *
340+ * @param fieldInfo merging field info
341+ * @param floatVectorValues the float vector values to merge
342+ * @param temporaryCentroidOutput the temporary centroid output
343+ * @param mergeState the merge state
344+ * @param globalCentroid the global centroid, calculated by this method and used to quantize the centroids
345+ * @return the number of centroids written
346+ * @throws IOException if an I/O error occurs
347+ */
348+ @ Override
349+ protected int calculateAndWriteCentroids (
350+ FieldInfo fieldInfo ,
351+ FloatVectorValues floatVectorValues ,
352+ IndexOutput temporaryCentroidOutput ,
353+ MergeState mergeState ,
354+ float [] globalCentroid
355+ ) throws IOException {
356+ if (floatVectorValues .size () == 0 ) {
357+ return 0 ;
358+ }
359+ int maxNumClusters = ((floatVectorValues .size () - 1 ) / vectorPerCluster ) + 1 ;
360+ int desiredClusters = (int ) Math .max (Math .sqrt (floatVectorValues .size ()), maxNumClusters );
361+ // init centroids from merge state
362+ List <FloatVectorValues > centroidList = new ArrayList <>();
363+ List <SegmentCentroid > segmentCentroids = new ArrayList <>(desiredClusters );
364+
365+ int segmentIdx = 0 ;
366+ for (var reader : mergeState .knnVectorsReaders ) {
367+ IVFVectorsReader ivfVectorsReader = IVFVectorsFormat .getIVFReader (reader , fieldInfo .name );
368+ if (ivfVectorsReader == null ) {
369+ continue ;
370+ }
371+
372+ FloatVectorValues centroid = ivfVectorsReader .getCentroids (fieldInfo );
373+ if (centroid == null ) {
374+ continue ;
375+ }
376+ centroidList .add (centroid );
377+ for (int i = 0 ; i < centroid .size (); i ++) {
378+ int size = ivfVectorsReader .centroidSize (fieldInfo .name , i );
379+ if (size == 0 ) {
380+ continue ;
381+ }
382+ segmentCentroids .add (new SegmentCentroid (segmentIdx , i , size ));
383+ }
384+ segmentIdx ++;
385+ }
386+
387+ float [][] initCentroids = gatherInitCentroids (centroidList , segmentCentroids , desiredClusters , fieldInfo , mergeState );
351388
352389 // FIXME: run a custom version of KMeans that is just better...
353390 long nanoTime = System .nanoTime ();
@@ -369,6 +406,15 @@ protected int calculateAndWriteCentroids(
369406 float [][] centroids = kMeans .centroids ();
370407
371408 // write them
409+ // calculate the global centroid from all the centroids:
410+ for (float [] centroid : centroids ) {
411+ for (int j = 0 ; j < centroid .length ; j ++) {
412+ globalCentroid [j ] += centroid [j ];
413+ }
414+ }
415+ for (int j = 0 ; j < globalCentroid .length ; j ++) {
416+ globalCentroid [j ] /= centroids .length ;
417+ }
372418 writeCentroids (centroids , fieldInfo , globalCentroid , temporaryCentroidOutput );
373419 return centroids .length ;
374420 }
@@ -477,14 +523,11 @@ static void assignCentroids(CentroidAssignmentScorer scorer, FloatVectorValues v
477523 // pop the best
478524 int sz = neighborsToCheck .size ();
479525 int best = neighborsToCheck .consumeNodesAndScoresMin (ordScoreIterator .ords , ordScoreIterator .scores );
480- // reset the ordScoreIterator as it has consumed the ords and scores
481- ordScoreIterator .idx = sz ;
526+ // Set the size to the number of neighbors we actually found
527+ ordScoreIterator .setSize ( sz ) ;
482528 bestScore = ordScoreIterator .getScore (best );
483529 bestCentroid = ordScoreIterator .getOrd (best );
484530 }
485- if (clusters [bestCentroid ] == null ) {
486- clusters [bestCentroid ] = new IntArrayList (16 );
487- }
488531 clusters [bestCentroid ].add (docID );
489532 if (soarClusterCheckCount > 0 ) {
490533 assignCentroidSOAR (
@@ -495,7 +538,7 @@ static void assignCentroids(CentroidAssignmentScorer scorer, FloatVectorValues v
495538 bestScore ,
496539 scratch ,
497540 scorer ,
498- vectors ,
541+ vector ,
499542 clusters
500543 );
501544 }
@@ -511,10 +554,9 @@ static void assignCentroidSOAR(
511554 float bestScore ,
512555 float [] scratch ,
513556 CentroidAssignmentScorer scorer ,
514- FloatVectorValues vectors ,
557+ float [] vector ,
515558 IntArrayList [] clusters
516559 ) throws IOException {
517- float [] vector = vectors .vectorValue (vecOrd );
518560 ESVectorUtil .subtract (vector , bestCentroid , scratch );
519561 int bestSecondaryCentroid = -1 ;
520562 float minDist = Float .MAX_VALUE ;
@@ -546,6 +588,14 @@ static class OrdScoreIterator {
546588 this .scores = new float [size ];
547589 }
548590
591+ int setSize (int size ) {
592+ if (size > ords .length ) {
593+ throw new IllegalArgumentException ("size must be <= " + ords .length );
594+ }
595+ this .idx = size ;
596+ return size ;
597+ }
598+
549599 int getOrd (int idx ) {
550600 return ords [idx ];
551601 }
@@ -606,15 +656,15 @@ static class OffHeapCentroidAssignmentScorer implements CentroidAssignmentScorer
606656 private final int dimension ;
607657 private final float [] scratch ;
608658 private float [] q ;
609- private final long centroidByteSize ;
659+ private final long rawCentroidOffset ;
610660 private int currOrd = -1 ;
611661
612662 OffHeapCentroidAssignmentScorer (IndexInput centroidsInput , int numCentroids , FieldInfo info ) {
613663 this .centroidsInput = centroidsInput ;
614664 this .numCentroids = numCentroids ;
615665 this .dimension = info .getVectorDimension ();
616666 this .scratch = new float [dimension ];
617- this .centroidByteSize = dimension + 3 * Float .BYTES + Short .BYTES ;
667+ this .rawCentroidOffset = ( dimension + 3 * Float .BYTES + Short .BYTES ) * numCentroids ;
618668 }
619669
620670 @ Override
@@ -627,7 +677,7 @@ public float[] centroid(int centroidOrdinal) throws IOException {
627677 if (centroidOrdinal == currOrd ) {
628678 return scratch ;
629679 }
630- centroidsInput .seek (numCentroids * centroidByteSize + (long ) centroidOrdinal * dimension * Float .BYTES );
680+ centroidsInput .seek (rawCentroidOffset + (long ) centroidOrdinal * dimension * Float .BYTES );
631681 centroidsInput .readFloats (scratch , 0 , dimension );
632682 this .currOrd = centroidOrdinal ;
633683 return scratch ;
0 commit comments