Skip to content

Commit 33ede40

Browse files
committed
clustersPerNeighborhood must be > 2
1 parent 719ee96 commit 33ede40

File tree

3 files changed

+52
-28
lines changed

3 files changed

+52
-28
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ public KMeansResult cluster(FloatVectorValues vectors, int targetSize) throws IO
6868
KMeansIntermediate kMeansIntermediate = clusterAndSplit(vectors, targetSize);
6969
if (kMeansIntermediate.centroids().length > 1 && kMeansIntermediate.centroids().length < vectors.size()) {
7070
int localSampleSize = Math.min(kMeansIntermediate.centroids().length * samplesPerCluster / 2, vectors.size());
71-
KMeansLocal kMeansLocal = new KMeansLocal(localSampleSize, maxIterations, clustersPerNeighborhood, DEFAULT_SOAR_LAMBDA);
72-
kMeansLocal.cluster(vectors, kMeansIntermediate, true);
71+
KMeansLocal kMeansLocal = new KMeansLocal(localSampleSize, maxIterations);
72+
kMeansLocal.cluster(vectors, kMeansIntermediate, clustersPerNeighborhood, DEFAULT_SOAR_LAMBDA);
7373
}
7474

7575
return kMeansIntermediate;

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -33,20 +33,12 @@ class KMeansLocal {
3333
// second closest centroid.
3434
private static final float SOAR_MIN_DISTANCE = 1e-16f;
3535

36-
final int sampleSize;
37-
final int maxIterations;
38-
final int clustersPerNeighborhood;
39-
final float soarLambda;
36+
private final int sampleSize;
37+
private final int maxIterations;
4038

41-
KMeansLocal(int sampleSize, int maxIterations, int clustersPerNeighborhood, float soarLambda) {
39+
KMeansLocal(int sampleSize, int maxIterations) {
4240
this.sampleSize = sampleSize;
4341
this.maxIterations = maxIterations;
44-
this.clustersPerNeighborhood = clustersPerNeighborhood;
45-
this.soarLambda = soarLambda;
46-
}
47-
48-
KMeansLocal(int sampleSize, int maxIterations) {
49-
this(sampleSize, maxIterations, -1, -1f);
5042
}
5143

5244
/**
@@ -179,8 +171,13 @@ private void computeNeighborhoods(float[][] centers, List<int[]> neighborhoods,
179171
}
180172
}
181173

182-
private int[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoods, float[][] centroids, int[] assignments)
183-
throws IOException {
174+
private int[] assignSpilled(
175+
FloatVectorValues vectors,
176+
List<int[]> neighborhoods,
177+
float[][] centroids,
178+
int[] assignments,
179+
float soarLambda
180+
) throws IOException {
184181
// SOAR uses an adjusted distance for assigning spilled documents which is
185182
// given by:
186183
//
@@ -238,7 +235,7 @@ private int[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoods
238235
}
239236

240237
/**
241-
* cluster using a lloyd k-means algorithm that is not neighbor aware
238+
* cluster using a lloyd k-means algorithm that does not consider prior clustered neighborhoods when adjusting centroids
242239
*
243240
* @param vectors the vectors to cluster
244241
* @param kMeansIntermediate the output object to populate which minimally includes centroids,
@@ -247,7 +244,7 @@ private int[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoods
247244
* @throws IOException is thrown if vectors is inaccessible
248245
*/
249246
void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate) throws IOException {
250-
cluster(vectors, kMeansIntermediate, false);
247+
doCluster(vectors, kMeansIntermediate, -1, -1);
251248
}
252249

253250
/**
@@ -259,13 +256,23 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate) t
259256
* the prior assignments of the given vectors; care should be taken in
260257
* passing in a valid output object with a centroids array that is the size of centroids expected
261258
* and assignments that are the same size as the vectors. The SOAR assignments are overwritten by this operation.
262-
* @param neighborAware whether nearby neighboring centroids and their vectors should be used to update the centroid positions,
263-
* implies SOAR assignments
264-
* @throws IOException is thrown if vectors is inaccessible
259+
* @param clustersPerNeighborhood number of nearby neighboring centroids to be used to update the centroid positions.
260+
* @param soarLambda lambda used for SOAR assignments
261+
*
262+
* @throws IOException is thrown if vectors is inaccessible or if the clustersPerNeighborhood is less than 2
265263
*/
266-
void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, boolean neighborAware) throws IOException {
264+
void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, int clustersPerNeighborhood, float soarLambda)
265+
throws IOException {
266+
if (clustersPerNeighborhood < 2) {
267+
throw new IllegalArgumentException("clustersPerNeighborhood must be at least 2, got [" + clustersPerNeighborhood + "]");
268+
}
269+
doCluster(vectors, kMeansIntermediate, clustersPerNeighborhood, soarLambda);
270+
}
271+
272+
private void doCluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, int clustersPerNeighborhood, float soarLambda)
273+
throws IOException {
267274
float[][] centroids = kMeansIntermediate.centroids();
268-
boolean computeNeighborhoods = neighborAware && clustersPerNeighborhood > 0;
275+
boolean computeNeighborhoods = clustersPerNeighborhood != -1;
269276

270277
List<int[]> neighborhoods = null;
271278
if (computeNeighborhoods) {
@@ -281,7 +288,7 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, b
281288
int[] assignments = kMeansIntermediate.assignments();
282289
assert assignments != null;
283290
assert assignments.length == vectors.size();
284-
kMeansIntermediate.setSoarAssignments(assignSpilled(vectors, neighborhoods, centroids, assignments));
291+
kMeansIntermediate.setSoarAssignments(assignSpilled(vectors, neighborhoods, centroids, assignments, soarLambda));
285292
}
286293
}
287294

server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocalTests.java

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,32 @@
1717
import java.util.ArrayList;
1818
import java.util.List;
1919

20+
import static org.hamcrest.Matchers.containsString;
21+
2022
public class KMeansLocalTests extends ESTestCase {
2123

24+
public void testIllegalClustersPerNeighborhood() {
25+
KMeansLocal kMeansLocal = new KMeansLocal(randomInt(), randomInt());
26+
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(new float[0][], new int[0], i -> i);
27+
IllegalArgumentException ex = expectThrows(
28+
IllegalArgumentException.class,
29+
() -> kMeansLocal.cluster(
30+
FloatVectorValues.fromFloats(List.of(), randomInt(1024)),
31+
kMeansIntermediate,
32+
randomIntBetween(Integer.MIN_VALUE, 1),
33+
randomFloat()
34+
)
35+
);
36+
assertThat(ex.getMessage(), containsString("clustersPerNeighborhood must be at least 2"));
37+
}
38+
2239
public void testKMeansNeighbors() throws IOException {
2340
int nClusters = random().nextInt(1, 10);
2441
int nVectors = random().nextInt(nClusters * 100, nClusters * 200);
2542
int dims = random().nextInt(2, 20);
2643
int sampleSize = random().nextInt(100, nVectors + 1);
2744
int maxIterations = random().nextInt(0, 100);
28-
int clustersPerNeighborhood = random().nextInt(0, 512);
45+
int clustersPerNeighborhood = random().nextInt(2, 512);
2946
float soarLambda = random().nextFloat(0.5f, 1.5f);
3047
FloatVectorValues vectors = generateData(nVectors, dims, nClusters);
3148

@@ -49,8 +66,8 @@ public void testKMeansNeighbors() throws IOException {
4966
}
5067

5168
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, i -> assignmentOrdinals[i]);
52-
KMeansLocal kMeansLocal = new KMeansLocal(sampleSize, maxIterations, clustersPerNeighborhood, soarLambda);
53-
kMeansLocal.cluster(vectors, kMeansIntermediate, true);
69+
KMeansLocal kMeansLocal = new KMeansLocal(sampleSize, maxIterations);
70+
kMeansLocal.cluster(vectors, kMeansIntermediate, clustersPerNeighborhood, soarLambda);
5471

5572
assertEquals(nClusters, centroids.length);
5673
assertNotNull(kMeansIntermediate.soarAssignments());
@@ -90,8 +107,8 @@ public void testKMeansNeighborsAllZero() throws IOException {
90107
}
91108

92109
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, i -> assignmentOrdinals[i]);
93-
KMeansLocal kMeansLocal = new KMeansLocal(sampleSize, maxIterations, clustersPerNeighborhood, soarLambda);
94-
kMeansLocal.cluster(fvv, kMeansIntermediate, true);
110+
KMeansLocal kMeansLocal = new KMeansLocal(sampleSize, maxIterations);
111+
kMeansLocal.cluster(fvv, kMeansIntermediate, clustersPerNeighborhood, soarLambda);
95112

96113
assertEquals(nClusters, centroids.length);
97114
assertNotNull(kMeansIntermediate.soarAssignments());

0 commit comments

Comments
 (0)