@@ -106,29 +106,57 @@ KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int ta
106106 // TODO: consider adding cluster size counts to the kmeans algo
107107 // handle assignment here so we can track distance and cluster size
108108 int [] centroidVectorCount = new int [centroids .length ];
109+ int effectiveCluster = -1 ;
109110 int effectiveK = 0 ;
110111 for (int assigment : assignments ) {
111112 centroidVectorCount [assigment ]++;
112113 // this cluster has received an assignment, so it's now effective, but only count it once
113114 if (centroidVectorCount [assigment ] == 1 ) {
114115 effectiveK ++;
116+ effectiveCluster = assigment ;
115117 }
116118 }
117119
118120 if (effectiveK == 1 ) {
121+ final float [][] singleClusterCentroid = new float [1 ][];
122+ singleClusterCentroid [0 ] = centroids [effectiveCluster ];
123+ kMeansIntermediate .setCentroids (singleClusterCentroid );
124+ Arrays .fill (kMeansIntermediate .assignments (), 0 );
119125 return kMeansIntermediate ;
120126 }
121127
128+ int removedElements = 0 ;
122129 for (int c = 0 ; c < centroidVectorCount .length ; c ++) {
123130 // Recurse for each cluster which is larger than targetSize
124131 // Give ourselves a 34% margin over the target size
125- if (100 * centroidVectorCount [c ] > 134 * targetSize ) {
126- FloatVectorValues sample = createClusterSlice (centroidVectorCount [c ], c , vectors , assignments );
127-
132+ final int count = centroidVectorCount [c ];
133+ final int adjustedCentroid = c - removedElements ;
134+ if (100 * count > 134 * targetSize ) {
135+ final FloatVectorValues sample = createClusterSlice (count , adjustedCentroid , vectors , assignments );
128136 // TODO: consider iterative here instead of recursive
129137 // recursive call to build out the sub partitions around this centroid c
130138 // subsequently reconcile and flatten the space of all centroids and assignments into one structure we can return
131- updateAssignmentsWithRecursiveSplit (kMeansIntermediate , c , clusterAndSplit (sample , targetSize ));
139+ updateAssignmentsWithRecursiveSplit (kMeansIntermediate , adjustedCentroid , clusterAndSplit (sample , targetSize ));
140+ } else if (count == 0 ) {
141+ // remove empty clusters
142+ final int newSize = kMeansIntermediate .centroids ().length - 1 ;
143+ final float [][] newCentroids = new float [newSize ][];
144+ System .arraycopy (kMeansIntermediate .centroids (), 0 , newCentroids , 0 , adjustedCentroid );
145+ System .arraycopy (
146+ kMeansIntermediate .centroids (),
147+ adjustedCentroid + 1 ,
148+ newCentroids ,
149+ adjustedCentroid ,
150+ newSize - adjustedCentroid
151+ );
152+ // we need to update the assignments to reflect the new centroid ordinals
153+ for (int i = 0 ; i < kMeansIntermediate .assignments ().length ; i ++) {
154+ if (kMeansIntermediate .assignments ()[i ] > adjustedCentroid ) {
155+ kMeansIntermediate .assignments ()[i ]--;
156+ }
157+ }
158+ kMeansIntermediate .setCentroids (newCentroids );
159+ removedElements ++;
132160 }
133161 }
134162
0 commit comments