Skip to content

Commit 216753a

Browse files
authored
Remove empty clusters in Hierarchical k-means (#132569)
This commit makes sure the result of the hierarchical k-means cannot produce empty clusters.
1 parent 397821e commit 216753a

File tree

2 files changed

+91
-4
lines changed

2 files changed

+91
-4
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,29 +106,57 @@ KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int ta
106106
// TODO: consider adding cluster size counts to the kmeans algo
107107
// handle assignment here so we can track distance and cluster size
108108
int[] centroidVectorCount = new int[centroids.length];
109+
int effectiveCluster = -1;
109110
int effectiveK = 0;
110111
for (int assigment : assignments) {
111112
centroidVectorCount[assigment]++;
112113
// this cluster has received an assignment, its now effective, but only count it once
113114
if (centroidVectorCount[assigment] == 1) {
114115
effectiveK++;
116+
effectiveCluster = assigment;
115117
}
116118
}
117119

118120
if (effectiveK == 1) {
121+
final float[][] singleClusterCentroid = new float[1][];
122+
singleClusterCentroid[0] = centroids[effectiveCluster];
123+
kMeansIntermediate.setCentroids(singleClusterCentroid);
124+
Arrays.fill(kMeansIntermediate.assignments(), 0);
119125
return kMeansIntermediate;
120126
}
121127

128+
int removedElements = 0;
122129
for (int c = 0; c < centroidVectorCount.length; c++) {
123130
// Recurse for each cluster which is larger than targetSize
124131
// Give ourselves 30% margin for the target size
125-
if (100 * centroidVectorCount[c] > 134 * targetSize) {
126-
FloatVectorValues sample = createClusterSlice(centroidVectorCount[c], c, vectors, assignments);
127-
132+
final int count = centroidVectorCount[c];
133+
final int adjustedCentroid = c - removedElements;
134+
if (100 * count > 134 * targetSize) {
135+
final FloatVectorValues sample = createClusterSlice(count, adjustedCentroid, vectors, assignments);
128136
// TODO: consider iterative here instead of recursive
129137
// recursive call to build out the sub partitions around this centroid c
130138
// subsequently reconcile and flatten the space of all centroids and assignments into one structure we can return
131-
updateAssignmentsWithRecursiveSplit(kMeansIntermediate, c, clusterAndSplit(sample, targetSize));
139+
updateAssignmentsWithRecursiveSplit(kMeansIntermediate, adjustedCentroid, clusterAndSplit(sample, targetSize));
140+
} else if (count == 0) {
141+
// remove empty clusters
142+
final int newSize = kMeansIntermediate.centroids().length - 1;
143+
final float[][] newCentroids = new float[newSize][];
144+
System.arraycopy(kMeansIntermediate.centroids(), 0, newCentroids, 0, adjustedCentroid);
145+
System.arraycopy(
146+
kMeansIntermediate.centroids(),
147+
adjustedCentroid + 1,
148+
newCentroids,
149+
adjustedCentroid,
150+
newSize - adjustedCentroid
151+
);
152+
// we need to update the assignments to reflect the new centroid ordinals
153+
for (int i = 0; i < kMeansIntermediate.assignments().length; i++) {
154+
if (kMeansIntermediate.assignments()[i] > adjustedCentroid) {
155+
kMeansIntermediate.assignments()[i]--;
156+
}
157+
}
158+
kMeansIntermediate.setCentroids(newCentroids);
159+
removedElements++;
132160
}
133161
}
134162

server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeansTests.java

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,4 +74,63 @@ private static FloatVectorValues generateData(int nSamples, int nDims, int nClus
7474
}
7575
return FloatVectorValues.fromFloats(vectors, nDims);
7676
}
77+
78+
public void testFewDifferentValues() throws IOException {
79+
int nVectors = random().nextInt(100, 1000);
80+
int targetSize = random().nextInt(4, 64);
81+
int dims = random().nextInt(2, 20);
82+
int diffValues = randomIntBetween(1, 5);
83+
float[][] values = new float[diffValues][dims];
84+
for (int i = 0; i < diffValues; i++) {
85+
for (int j = 0; j < dims; j++) {
86+
values[i][j] = random().nextFloat();
87+
}
88+
}
89+
List<float[]> vectorList = new ArrayList<>(nVectors);
90+
for (int i = 0; i < nVectors; i++) {
91+
vectorList.add(values[random().nextInt(diffValues)]);
92+
}
93+
FloatVectorValues vectors = FloatVectorValues.fromFloats(vectorList, dims);
94+
95+
HierarchicalKMeans hkmeans = new HierarchicalKMeans(
96+
dims,
97+
random().nextInt(1, 100),
98+
random().nextInt(Math.min(nVectors, 100), nVectors + 1),
99+
random().nextInt(2, 512),
100+
random().nextFloat(0.5f, 1.5f)
101+
);
102+
103+
KMeansResult result = hkmeans.cluster(vectors, targetSize);
104+
105+
float[][] centroids = result.centroids();
106+
int[] assignments = result.assignments();
107+
int[] soarAssignments = result.soarAssignments();
108+
109+
int[] counts = new int[centroids.length];
110+
for (int i = 0; i < assignments.length; i++) {
111+
counts[assignments[i]]++;
112+
}
113+
int totalCount = 0;
114+
for (int count : counts) {
115+
totalCount += count;
116+
assertTrue(count > 0);
117+
}
118+
assertEquals(nVectors, totalCount);
119+
120+
assertEquals(nVectors, assignments.length);
121+
122+
for (int assignment : assignments) {
123+
assertTrue(assignment >= 0 && assignment < centroids.length);
124+
}
125+
if (centroids.length > 1 && centroids.length < nVectors) {
126+
assertEquals(nVectors, soarAssignments.length);
127+
// verify no duplicates exist
128+
for (int i = 0; i < assignments.length; i++) {
129+
assertTrue(soarAssignments[i] >= 0 && soarAssignments[i] < centroids.length);
130+
assertNotEquals(assignments[i], soarAssignments[i]);
131+
}
132+
} else {
133+
assertEquals(0, soarAssignments.length);
134+
}
135+
}
77136
}

0 commit comments

Comments
 (0)