1010package org .elasticsearch .index .codec .vectors .cluster ;
1111
1212import org .apache .lucene .index .FloatVectorValues ;
13+ import org .apache .lucene .util .FixedBitSet ;
1314import org .apache .lucene .util .VectorUtil ;
1415import org .apache .lucene .util .hnsw .IntToIntFunction ;
1516import org .elasticsearch .index .codec .vectors .SampleReader ;
@@ -70,17 +71,14 @@ private static boolean stepLloyd(
7071 FloatVectorValues vectors ,
7172 IntToIntFunction translateOrd ,
7273 float [][] centroids ,
73- float [][] nextCentroids ,
74+ FixedBitSet centroidChanged ,
75+ int [] centroidCounts ,
7476 int [] assignments ,
7577 NeighborHood [] neighborhoods
7678 ) throws IOException {
7779 boolean changed = false ;
7880 int dim = vectors .dimension ();
79- int [] centroidCounts = new int [centroids .length ];
80-
81- for (float [] nextCentroid : nextCentroids ) {
82- Arrays .fill (nextCentroid , 0.0f );
83- }
81+ centroidChanged .clear ();
8482 final float [] distances = new float [4 ];
8583 for (int idx = 0 ; idx < vectors .size (); idx ++) {
8684 float [] vector = vectors .vectorValue (idx );
@@ -93,20 +91,39 @@ private static boolean stepLloyd(
9391 bestCentroidOffset = getBestCentroid (centroids , vector , distances );
9492 }
9593 if (assignment != bestCentroidOffset ) {
94+ if (assignment != -1 ) {
95+ centroidChanged .set (assignment );
96+ }
97+ centroidChanged .set (bestCentroidOffset );
9698 assignments [vectorOrd ] = bestCentroidOffset ;
9799 changed = true ;
98100 }
99- centroidCounts [bestCentroidOffset ]++;
100- for (int d = 0 ; d < dim ; d ++) {
101- nextCentroids [bestCentroidOffset ][d ] += vector [d ];
102- }
103101 }
102+ if (changed ) {
103+ Arrays .fill (centroidCounts , 0 );
104+ for (int idx = 0 ; idx < vectors .size (); idx ++) {
105+ final int assignment = assignments [translateOrd .apply (idx )];
106+ if (centroidChanged .get (assignment )) {
107+ float [] centroid = centroids [assignment ];
108+ if (centroidCounts [assignment ]++ == 0 ) {
109+ Arrays .fill (centroid , 0.0f );
110+ }
111+ float [] vector = vectors .vectorValue (idx );
112+ for (int d = 0 ; d < dim ; d ++) {
113+ centroid [d ] += vector [d ];
114+ }
115+ }
116+ }
104117
105- for (int clusterIdx = 0 ; clusterIdx < centroids .length ; clusterIdx ++) {
106- if (centroidCounts [clusterIdx ] > 0 ) {
107- float countF = (float ) centroidCounts [clusterIdx ];
108- for (int d = 0 ; d < dim ; d ++) {
109- centroids [clusterIdx ][d ] = nextCentroids [clusterIdx ][d ] / countF ;
118+ for (int clusterIdx = 0 ; clusterIdx < centroids .length ; clusterIdx ++) {
119+ if (centroidChanged .get (clusterIdx )) {
120+ float count = (float ) centroidCounts [clusterIdx ];
121+ if (count > 0 ) {
122+ float [] centroid = centroids [clusterIdx ];
123+ for (int d = 0 ; d < dim ; d ++) {
124+ centroid [d ] /= count ;
125+ }
126+ }
110127 }
111128 }
112129 }
@@ -420,17 +437,18 @@ private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansInterme
420437 }
421438
422439 assert assignments .length == n ;
423- float [][] nextCentroids = new float [centroids .length ][vectors .dimension ()];
440+ FixedBitSet centroidChanged = new FixedBitSet (centroids .length );
441+ int [] centroidCounts = new int [centroids .length ];
424442 for (int i = 0 ; i < maxIterations ; i ++) {
425443 // This is potentially sampled, so we need to translate ordinals
426- if (stepLloyd (sampledVectors , translateOrd , centroids , nextCentroids , assignments , neighborhoods ) == false ) {
444+ if (stepLloyd (sampledVectors , translateOrd , centroids , centroidChanged , centroidCounts , assignments , neighborhoods ) == false ) {
427445 break ;
428446 }
429447 }
430448 // If we were sampled, do a once over the full set of vectors to finalize the centroids
431449 if (sampleSize < n || maxIterations == 0 ) {
432450 // No ordinal translation needed here, we are using the full set of vectors
433- stepLloyd (vectors , i -> i , centroids , nextCentroids , assignments , neighborhoods );
451+ stepLloyd (vectors , i -> i , centroids , centroidChanged , centroidCounts , assignments , neighborhoods );
434452 }
435453 }
436454
0 commit comments