Skip to content

Commit ce74df5

Browse files
authored
Fix iterating for best centroid when algorithm is neighbour aware and decrease SAMPLES_PER_CLUSTER_DEFAULT (#130069)
* KMeansIntermediate shares assigments
1 parent 7c213ba commit ce74df5

File tree

3 files changed

+34
-62
lines changed

3 files changed

+34
-62
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java

Lines changed: 5 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
package org.elasticsearch.index.codec.vectors.cluster;
1111

1212
import org.apache.lucene.index.FloatVectorValues;
13-
import org.apache.lucene.util.VectorUtil;
1413

1514
import java.io.IOException;
1615

@@ -21,7 +20,7 @@ public class HierarchicalKMeans {
2120

2221
static final int MAXK = 128;
2322
static final int MAX_ITERATIONS_DEFAULT = 6;
24-
static final int SAMPLES_PER_CLUSTER_DEFAULT = 256;
23+
static final int SAMPLES_PER_CLUSTER_DEFAULT = 64;
2524
static final float DEFAULT_SOAR_LAMBDA = 1.0f;
2625

2726
final int dimension;
@@ -67,8 +66,7 @@ public KMeansResult cluster(FloatVectorValues vectors, int targetSize) throws IO
6766
// partition the space
6867
KMeansIntermediate kMeansIntermediate = clusterAndSplit(vectors, targetSize);
6968
if (kMeansIntermediate.centroids().length > 1 && kMeansIntermediate.centroids().length < vectors.size()) {
70-
float f = Math.min((float) samplesPerCluster / targetSize, 1.0f);
71-
int localSampleSize = (int) (f * vectors.size());
69+
int localSampleSize = Math.min(kMeansIntermediate.centroids().length * samplesPerCluster, vectors.size());
7270
KMeansLocal kMeansLocal = new KMeansLocal(localSampleSize, maxIterations, clustersPerNeighborhood, DEFAULT_SOAR_LAMBDA);
7371
kMeansLocal.cluster(vectors, kMeansIntermediate, true);
7472
}
@@ -86,42 +84,16 @@ KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int ta
8684

8785
// TODO: instead of creating a sub-cluster assignments reuse the parent array each time
8886
int[] assignments = new int[vectors.size()];
89-
9087
KMeansLocal kmeans = new KMeansLocal(m, maxIterations);
9188
float[][] centroids = KMeansLocal.pickInitialCentroids(vectors, k);
92-
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids);
89+
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, vectors::ordToDoc);
9390
kmeans.cluster(vectors, kMeansIntermediate);
9491

9592
// TODO: consider adding cluster size counts to the kmeans algo
9693
// handle assignment here so we can track distance and cluster size
9794
int[] centroidVectorCount = new int[centroids.length];
98-
float[][] nextCentroids = new float[centroids.length][dimension];
99-
for (int i = 0; i < vectors.size(); i++) {
100-
float smallest = Float.MAX_VALUE;
101-
int centroidIdx = -1;
102-
float[] vector = vectors.vectorValue(i);
103-
for (int j = 0; j < centroids.length; j++) {
104-
float[] centroid = centroids[j];
105-
float d = VectorUtil.squareDistance(vector, centroid);
106-
if (d < smallest) {
107-
smallest = d;
108-
centroidIdx = j;
109-
}
110-
}
111-
centroidVectorCount[centroidIdx]++;
112-
for (int j = 0; j < dimension; j++) {
113-
nextCentroids[centroidIdx][j] += vector[j];
114-
}
115-
assignments[i] = centroidIdx;
116-
}
117-
118-
// update centroids based on assignments of all vectors
119-
for (int i = 0; i < centroids.length; i++) {
120-
if (centroidVectorCount[i] > 0) {
121-
for (int j = 0; j < dimension; j++) {
122-
centroids[i][j] = nextCentroids[i][j] / centroidVectorCount[i];
123-
}
124-
}
95+
for (int assigment : assignments) {
96+
centroidVectorCount[assigment]++;
12597
}
12698

12799
int effectiveK = 0;
@@ -131,8 +103,6 @@ KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int ta
131103
}
132104
}
133105

134-
kMeansIntermediate = new KMeansIntermediate(centroids, assignments, vectors::ordToDoc);
135-
136106
if (effectiveK == 1) {
137107
return kMeansIntermediate;
138108
}

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansIntermediate.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,6 @@ private KMeansIntermediate(float[][] centroids, int[] assignments, IntToIntFunct
3131
this(new float[0][0], new int[0], i -> i, new int[0]);
3232
}
3333

34-
KMeansIntermediate(float[][] centroids) {
35-
this(centroids, new int[0], i -> i, new int[0]);
36-
}
37-
3834
KMeansIntermediate(float[][] centroids, int[] assignments) {
3935
this(centroids, assignments, i -> i, new int[0]);
4036
}

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -87,17 +87,17 @@ private boolean stepLloyd(
8787

8888
for (int i = 0; i < sampleSize; i++) {
8989
float[] vector = vectors.vectorValue(i);
90-
int[] neighborOffsets = null;
91-
int centroidIdx = -1;
90+
final int assignment = assignments[i];
91+
final int bestCentroidOffset;
9292
if (neighborhoods != null) {
93-
neighborOffsets = neighborhoods.get(assignments[i]);
94-
centroidIdx = assignments[i];
93+
bestCentroidOffset = getBestCentroidFromNeighbours(centroids, vector, assignment, neighborhoods.get(assignment));
94+
} else {
95+
bestCentroidOffset = getBestCentroid(centroids, vector);
9596
}
96-
int bestCentroidOffset = getBestCentroidOffset(centroids, vector, centroidIdx, neighborOffsets);
97-
if (assignments[i] != bestCentroidOffset) {
97+
if (assignment != bestCentroidOffset) {
98+
assignments[i] = bestCentroidOffset;
9899
changed = true;
99100
}
100-
assignments[i] = bestCentroidOffset;
101101
centroidCounts[bestCentroidOffset]++;
102102
for (int d = 0; d < dim; d++) {
103103
nextCentroids[bestCentroidOffset][d] += vector[d];
@@ -116,23 +116,28 @@ private boolean stepLloyd(
116116
return changed;
117117
}
118118

119-
int getBestCentroidOffset(float[][] centroids, float[] vector, int centroidIdx, int[] centroidOffsets) {
119+
int getBestCentroidFromNeighbours(float[][] centroids, float[] vector, int centroidIdx, int[] centroidOffsets) {
120120
int bestCentroidOffset = centroidIdx;
121-
float minDsq;
122-
if (centroidIdx > 0 && centroidIdx < centroids.length) {
123-
minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]);
124-
} else {
125-
minDsq = Float.MAX_VALUE;
121+
assert centroidIdx >= 0 && centroidIdx < centroids.length;
122+
float minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]);
123+
for (int offset : centroidOffsets) {
124+
float dsq = VectorUtil.squareDistance(vector, centroids[offset]);
125+
if (dsq < minDsq) {
126+
minDsq = dsq;
127+
bestCentroidOffset = offset;
128+
}
126129
}
130+
return bestCentroidOffset;
131+
}
127132

128-
int k = 0;
129-
for (int j = 0; j < centroids.length; j++) {
130-
if (centroidOffsets == null || j == centroidOffsets[k]) {
131-
float dsq = VectorUtil.squareDistance(vector, centroids[j]);
132-
if (dsq < minDsq) {
133-
minDsq = dsq;
134-
bestCentroidOffset = j;
135-
}
133+
int getBestCentroid(float[][] centroids, float[] vector) {
134+
int bestCentroidOffset = 0;
135+
float minDsq = Float.MAX_VALUE;
136+
for (int i = 0; i < centroids.length; i++) {
137+
float dsq = VectorUtil.squareDistance(vector, centroids[i]);
138+
if (dsq < minDsq) {
139+
minDsq = dsq;
140+
bestCentroidOffset = i;
136141
}
137142
}
138143
return bestCentroidOffset;
@@ -271,7 +276,8 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, L
271276
return;
272277
}
273278

274-
int[] assignments = new int[n];
279+
int[] assignments = kMeansIntermediate.assignments();
280+
assert assignments.length == n;
275281
float[][] nextCentroids = new float[centroids.length][vectors.dimension()];
276282
for (int i = 0; i < maxIterations; i++) {
277283
if (stepLloyd(vectors, centroids, nextCentroids, assignments, sampleSize, neighborhoods) == false) {
@@ -291,7 +297,7 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, L
291297
* @param maxIterations the max iterations to shift centroids
292298
*/
293299
public static void cluster(FloatVectorValues vectors, float[][] centroids, int sampleSize, int maxIterations) throws IOException {
294-
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids);
300+
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, new int[vectors.size()], vectors::ordToDoc);
295301
KMeansLocal kMeans = new KMeansLocal(sampleSize, maxIterations);
296302
kMeans.cluster(vectors, kMeansIntermediate);
297303
}

0 commit comments

Comments
 (0)