Skip to content

Commit ef9f544

Browse files
authored
Handle more edge cases in HierarchicalKMeans (#130622)
This commit handles the following cases: 1) when the number of centroids is one (with better checks for when the number of centroids < number of vectors); 2) when maxIterations is equal to 0 (no sampling). Fixes #130497
1 parent f047f78 commit ef9f544

File tree

3 files changed

+12
-11
lines changed

3 files changed

+12
-11
lines changed

muted-tests.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -578,9 +578,6 @@ tests:
578578
- class: org.elasticsearch.xpack.esql.qa.single_node.GenerativeIT
579579
method: test
580580
issue: https://github.com/elastic/elasticsearch/issues/130067
581-
- class: org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeansTests
582-
method: testHKmeans
583-
issue: https://github.com/elastic/elasticsearch/issues/130497
584581
- class: org.elasticsearch.backwards.MixedClusterClientYamlTestSuiteIT
585582
method: test {p0=search.vectors/40_knn_search/Dimensions are dynamically set}
586583
issue: https://github.com/elastic/elasticsearch/issues/130626

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -329,8 +329,10 @@ private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansInterme
329329
float[][] centroids = kMeansIntermediate.centroids();
330330
int k = centroids.length;
331331
int n = vectors.size();
332+
int[] assignments = kMeansIntermediate.assignments();
332333

333-
if (k == 1 || k >= n) {
334+
if (k == 1) {
335+
Arrays.fill(assignments, 0);
334336
return;
335337
}
336338
IntToIntFunction translateOrd = i -> i;
@@ -339,7 +341,7 @@ private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansInterme
339341
sampledVectors = SampleReader.createSampleReader(vectors, sampleSize, 42L);
340342
translateOrd = sampledVectors::ordToDoc;
341343
}
342-
int[] assignments = kMeansIntermediate.assignments();
344+
343345
assert assignments.length == n;
344346
float[][] nextCentroids = new float[centroids.length][vectors.dimension()];
345347
for (int i = 0; i < maxIterations; i++) {
@@ -349,7 +351,7 @@ private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansInterme
349351
}
350352
}
351353
// If we were sampled, do a once over the full set of vectors to finalize the centroids
352-
if (sampleSize < n) {
354+
if (sampleSize < n || maxIterations == 0) {
353355
// No ordinal translation needed here, we are using the full set of vectors
354356
stepLloyd(vectors, i -> i, centroids, nextCentroids, assignments, neighborhoods);
355357
}

server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeansTests.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ public class HierarchicalKMeansTests extends ESTestCase {
1919

2020
public void testHKmeans() throws IOException {
2121
int nClusters = random().nextInt(1, 10);
22-
int nVectors = random().nextInt(nClusters * 100, nClusters * 200);
22+
int nVectors = random().nextInt(1, nClusters * 200);
2323
int dims = random().nextInt(2, 20);
24-
int sampleSize = random().nextInt(100, nVectors + 1);
25-
int maxIterations = random().nextInt(0, 100);
24+
int sampleSize = random().nextInt(Math.min(nVectors, 100), nVectors + 1);
25+
int maxIterations = random().nextInt(1, 100);
2626
int clustersPerNeighborhood = random().nextInt(2, 512);
2727
float soarLambda = random().nextFloat(0.5f, 1.5f);
2828
FloatVectorValues vectors = generateData(nVectors, dims, nClusters);
@@ -36,14 +36,16 @@ public void testHKmeans() throws IOException {
3636
int[] assignments = result.assignments();
3737
int[] soarAssignments = result.soarAssignments();
3838

39-
assertEquals(nClusters, centroids.length, 6);
39+
assertEquals(Math.min(nClusters, nVectors), centroids.length, 8);
4040
assertEquals(nVectors, assignments.length);
41-
if (centroids.length > 1 && clustersPerNeighborhood > 0) {
41+
if (centroids.length > 1 && centroids.length < nVectors) {
4242
assertEquals(nVectors, soarAssignments.length);
4343
// verify no duplicates exist
4444
for (int i = 0; i < assignments.length; i++) {
4545
assert assignments[i] != soarAssignments[i];
4646
}
47+
} else {
48+
assertEquals(0, soarAssignments.length);
4749
}
4850
}
4951

0 commit comments

Comments
 (0)