Skip to content

Commit da0075b

Browse files
committed
KMeansIntermediate shares assigments
1 parent ceed8b4 commit da0075b

File tree

3 files changed

+6
-38
lines changed

3 files changed

+6
-38
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java

Lines changed: 3 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
package org.elasticsearch.index.codec.vectors.cluster;
1111

1212
import org.apache.lucene.index.FloatVectorValues;
13-
import org.apache.lucene.util.VectorUtil;
1413

1514
import java.io.IOException;
1615

@@ -85,42 +84,16 @@ KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int ta
8584

8685
// TODO: instead of creating a sub-cluster assignments reuse the parent array each time
8786
int[] assignments = new int[vectors.size()];
88-
8987
KMeansLocal kmeans = new KMeansLocal(m, maxIterations);
9088
float[][] centroids = KMeansLocal.pickInitialCentroids(vectors, k);
91-
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids);
89+
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, vectors::ordToDoc);
9290
kmeans.cluster(vectors, kMeansIntermediate);
9391

9492
// TODO: consider adding cluster size counts to the kmeans algo
9593
// handle assignment here so we can track distance and cluster size
9694
int[] centroidVectorCount = new int[centroids.length];
97-
float[][] nextCentroids = new float[centroids.length][dimension];
98-
for (int i = 0; i < vectors.size(); i++) {
99-
float smallest = Float.MAX_VALUE;
100-
int centroidIdx = -1;
101-
float[] vector = vectors.vectorValue(i);
102-
for (int j = 0; j < centroids.length; j++) {
103-
float[] centroid = centroids[j];
104-
float d = VectorUtil.squareDistance(vector, centroid);
105-
if (d < smallest) {
106-
smallest = d;
107-
centroidIdx = j;
108-
}
109-
}
110-
centroidVectorCount[centroidIdx]++;
111-
for (int j = 0; j < dimension; j++) {
112-
nextCentroids[centroidIdx][j] += vector[j];
113-
}
114-
assignments[i] = centroidIdx;
115-
}
116-
117-
// update centroids based on assignments of all vectors
118-
for (int i = 0; i < centroids.length; i++) {
119-
if (centroidVectorCount[i] > 0) {
120-
for (int j = 0; j < dimension; j++) {
121-
centroids[i][j] = nextCentroids[i][j] / centroidVectorCount[i];
122-
}
123-
}
95+
for (int assigment : assignments) {
96+
centroidVectorCount[assigment]++;
12497
}
12598

12699
int effectiveK = 0;
@@ -130,8 +103,6 @@ KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int ta
130103
}
131104
}
132105

133-
kMeansIntermediate = new KMeansIntermediate(centroids, assignments, vectors::ordToDoc);
134-
135106
if (effectiveK == 1) {
136107
return kMeansIntermediate;
137108
}

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansIntermediate.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,6 @@ private KMeansIntermediate(float[][] centroids, int[] assignments, IntToIntFunct
3131
this(new float[0][0], new int[0], i -> i, new int[0]);
3232
}
3333

34-
KMeansIntermediate(float[][] centroids) {
35-
this(centroids, new int[0], i -> i, new int[0]);
36-
}
37-
3834
KMeansIntermediate(float[][] centroids, int[] assignments) {
3935
this(centroids, assignments, i -> i, new int[0]);
4036
}

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,8 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, L
276276
return;
277277
}
278278

279-
int[] assignments = new int[n];
279+
int[] assignments = kMeansIntermediate.assignments();
280+
assert assignments.length == n;
280281
float[][] nextCentroids = new float[centroids.length][vectors.dimension()];
281282
for (int i = 0; i < maxIterations; i++) {
282283
if (stepLloyd(vectors, centroids, nextCentroids, assignments, sampleSize, neighborhoods) == false) {
@@ -296,7 +297,7 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, L
296297
* @param maxIterations the max iterations to shift centroids
297298
*/
298299
public static void cluster(FloatVectorValues vectors, float[][] centroids, int sampleSize, int maxIterations) throws IOException {
299-
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids);
300+
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, new int[vectors.size()], vectors::ordToDoc);
300301
KMeansLocal kMeans = new KMeansLocal(sampleSize, maxIterations);
301302
kMeans.cluster(vectors, kMeansIntermediate);
302303
}

0 commit comments

Comments
 (0)