@@ -100,18 +100,15 @@ private static boolean stepLloyd(
100100 }
101101 }
102102 if (changed ) {
103- for (int clusterIdx = 0 ; clusterIdx < centroids .length ; clusterIdx ++) {
104- if (centroidChanged .get (clusterIdx )) {
105- Arrays .fill (centroids [clusterIdx ], 0.0f );
106- }
107- }
108103 Arrays .fill (centroidCounts , 0 );
109104 for (int idx = 0 ; idx < vectors .size (); idx ++) {
110105 final int assignment = assignments [translateOrd .apply (idx )];
111106 if (centroidChanged .get (assignment )) {
112- centroidCounts [assignment ]++;
113- float [] vector = vectors .vectorValue (idx );
114107 float [] centroid = centroids [assignment ];
108+ if (centroidCounts [assignment ]++ == 0 ) {
109+ Arrays .fill (centroid , 0.0f );
110+ }
111+ float [] vector = vectors .vectorValue (idx );
115112 for (int d = 0 ; d < dim ; d ++) {
116113 centroid [d ] += vector [d ];
117114 }
@@ -120,9 +117,12 @@ private static boolean stepLloyd(
120117
121118 for (int clusterIdx = 0 ; clusterIdx < centroids .length ; clusterIdx ++) {
122119 if (centroidChanged .get (clusterIdx )) {
123- float countF = (float ) centroidCounts [clusterIdx ];
124- for (int d = 0 ; d < dim ; d ++) {
125- centroids [clusterIdx ][d ] /= countF ;
120+ float count = (float ) centroidCounts [clusterIdx ];
121+ if (count > 0 ) {
122+ float [] centroid = centroids [clusterIdx ];
123+ for (int d = 0 ; d < dim ; d ++) {
124+ centroid [d ] /= count ;
125+ }
126126 }
127127 }
128128 }
0 commit comments