@@ -106,29 +106,57 @@ KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int ta
106
106
// TODO: consider adding cluster size counts to the kmeans algo
107
107
// handle assignment here so we can track distance and cluster size
108
108
int [] centroidVectorCount = new int [centroids .length ];
109
+ int effectiveCluster = -1 ;
109
110
int effectiveK = 0 ;
110
111
for (int assigment : assignments ) {
111
112
centroidVectorCount [assigment ]++;
112
113
// this cluster has received an assignment, its now effective, but only count it once
113
114
if (centroidVectorCount [assigment ] == 1 ) {
114
115
effectiveK ++;
116
+ effectiveCluster = assigment ;
115
117
}
116
118
}
117
119
118
120
if (effectiveK == 1 ) {
121
+ final float [][] singleClusterCentroid = new float [1 ][];
122
+ singleClusterCentroid [0 ] = centroids [effectiveCluster ];
123
+ kMeansIntermediate .setCentroids (singleClusterCentroid );
124
+ Arrays .fill (kMeansIntermediate .assignments (), 0 );
119
125
return kMeansIntermediate ;
120
126
}
121
127
128
+ int removedElements = 0 ;
122
129
for (int c = 0 ; c < centroidVectorCount .length ; c ++) {
123
130
// Recurse for each cluster which is larger than targetSize
124
131
// Give ourselves 30% margin for the target size
125
- if (100 * centroidVectorCount [c ] > 134 * targetSize ) {
126
- FloatVectorValues sample = createClusterSlice (centroidVectorCount [c ], c , vectors , assignments );
127
-
132
+ final int count = centroidVectorCount [c ];
133
+ final int adjustedCentroid = c - removedElements ;
134
+ if (100 * count > 134 * targetSize ) {
135
+ final FloatVectorValues sample = createClusterSlice (count , adjustedCentroid , vectors , assignments );
128
136
// TODO: consider iterative here instead of recursive
129
137
// recursive call to build out the sub partitions around this centroid c
130
138
// subsequently reconcile and flatten the space of all centroids and assignments into one structure we can return
131
- updateAssignmentsWithRecursiveSplit (kMeansIntermediate , c , clusterAndSplit (sample , targetSize ));
139
+ updateAssignmentsWithRecursiveSplit (kMeansIntermediate , adjustedCentroid , clusterAndSplit (sample , targetSize ));
140
+ } else if (count == 0 ) {
141
+ // remove empty clusters
142
+ final int newSize = kMeansIntermediate .centroids ().length - 1 ;
143
+ final float [][] newCentroids = new float [newSize ][];
144
+ System .arraycopy (kMeansIntermediate .centroids (), 0 , newCentroids , 0 , adjustedCentroid );
145
+ System .arraycopy (
146
+ kMeansIntermediate .centroids (),
147
+ adjustedCentroid + 1 ,
148
+ newCentroids ,
149
+ adjustedCentroid ,
150
+ newSize - adjustedCentroid
151
+ );
152
+ // we need to update the assignments to reflect the new centroid ordinals
153
+ for (int i = 0 ; i < kMeansIntermediate .assignments ().length ; i ++) {
154
+ if (kMeansIntermediate .assignments ()[i ] > adjustedCentroid ) {
155
+ kMeansIntermediate .assignments ()[i ]--;
156
+ }
157
+ }
158
+ kMeansIntermediate .setCentroids (newCentroids );
159
+ removedElements ++;
132
160
}
133
161
}
134
162
0 commit comments