@@ -70,17 +70,14 @@ private static boolean stepLloyd(
7070 FloatVectorValues vectors ,
7171 IntToIntFunction translateOrd ,
7272 float [][] centroids ,
73- float [][] nextCentroids ,
73+ int [] centroidCounts ,
7474 int [] assignments ,
7575 NeighborHood [] neighborhoods
7676 ) throws IOException {
7777 boolean changed = false ;
7878 int dim = vectors .dimension ();
79- int [] centroidCounts = new int [ centroids . length ] ;
79+ Arrays . fill ( centroidCounts , 0 ) ;
8080
81- for (float [] nextCentroid : nextCentroids ) {
82- Arrays .fill (nextCentroid , 0.0f );
83- }
8481 final float [] distances = new float [4 ];
8582 for (int idx = 0 ; idx < vectors .size (); idx ++) {
8683 float [] vector = vectors .vectorValue (idx );
@@ -97,16 +94,32 @@ private static boolean stepLloyd(
9794 changed = true ;
9895 }
9996 centroidCounts [bestCentroidOffset ]++;
100- for (int d = 0 ; d < dim ; d ++) {
101- nextCentroids [bestCentroidOffset ][d ] += vector [d ];
102- }
10397 }
10498
105- for (int clusterIdx = 0 ; clusterIdx < centroids .length ; clusterIdx ++) {
106- if (centroidCounts [clusterIdx ] > 0 ) {
107- float countF = (float ) centroidCounts [clusterIdx ];
108- for (int d = 0 ; d < dim ; d ++) {
109- centroids [clusterIdx ][d ] = nextCentroids [clusterIdx ][d ] / countF ;
99+ if (changed ) {
100+ for (int clusterIdx = 0 ; clusterIdx < centroids .length ; clusterIdx ++) {
101+ if (centroidCounts [clusterIdx ] > 0 ) {
102+ Arrays .fill (centroids [clusterIdx ], 0.0f );
103+ }
104+ }
105+
106+ for (int idx = 0 ; idx < vectors .size (); idx ++) {
107+ final int assignment = assignments [translateOrd .apply (idx )];
108+ if (centroidCounts [assignment ] > 0 ) {
109+ float [] vector = vectors .vectorValue (idx );
110+ float [] centroid = centroids [assignment ];
111+ for (int d = 0 ; d < dim ; d ++) {
112+ centroid [d ] += vector [d ];
113+ }
114+ }
115+ }
116+
117+ for (int clusterIdx = 0 ; clusterIdx < centroids .length ; clusterIdx ++) {
118+ if (centroidCounts [clusterIdx ] > 0 ) {
119+ float countF = (float ) centroidCounts [clusterIdx ];
120+ for (int d = 0 ; d < dim ; d ++) {
121+ centroids [clusterIdx ][d ] /= countF ;
122+ }
110123 }
111124 }
112125 }
@@ -420,17 +433,17 @@ private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansInterme
420433 }
421434
422435 assert assignments .length == n ;
423- float [][] nextCentroids = new float [centroids .length ][ vectors . dimension () ];
436+ int [] centroidCounts = new int [centroids .length ];
424437 for (int i = 0 ; i < maxIterations ; i ++) {
425438 // This is potentially sampled, so we need to translate ordinals
426- if (stepLloyd (sampledVectors , translateOrd , centroids , nextCentroids , assignments , neighborhoods ) == false ) {
439+ if (stepLloyd (sampledVectors , translateOrd , centroids , centroidCounts , assignments , neighborhoods ) == false ) {
427440 break ;
428441 }
429442 }
430443 // If we were sampled, do a once over the full set of vectors to finalize the centroids
431444 if (sampleSize < n || maxIterations == 0 ) {
432445 // No ordinal translation needed here, we are using the full set of vectors
433- stepLloyd (vectors , i -> i , centroids , nextCentroids , assignments , neighborhoods );
446+ stepLloyd (vectors , i -> i , centroids , centroidCounts , assignments , neighborhoods );
434447 }
435448 }
436449
0 commit comments