Skip to content

Commit bfd6692

Browse files
Commit message: Reduce heap usage in hierarchical k-means
Parent commit: 6578b9e (this commit: bfd6692)

File tree

2 files changed: 48 additions (+48) and 32 deletions (-32)

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java

Lines changed: 19 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313

1414
import java.io.IOException;
1515
import java.util.Arrays;
16+
import java.util.Objects;
1617

1718
/**
1819
* An implementation of the hierarchical k-means algorithm that better partitions data than naive k-means
@@ -148,30 +149,32 @@ static FloatVectorValues createClusterSlice(int clusterSize, int cluster, FloatV
148149
}
149150

150151
void updateAssignmentsWithRecursiveSplit(KMeansIntermediate current, int cluster, KMeansIntermediate subPartitions) {
152+
if (subPartitions.centroids().length == 0) {
153+
return; // nothing to do, sub-partitions is empty
154+
}
151155
int orgCentroidsSize = current.centroids().length;
152156
int newCentroidsSize = current.centroids().length + subPartitions.centroids().length - 1;
153157

154158
// update based on the outcomes from the split clusters recursion
155-
if (subPartitions.centroids().length > 1) {
156-
float[][] newCentroids = new float[newCentroidsSize][dimension];
157-
System.arraycopy(current.centroids(), 0, newCentroids, 0, current.centroids().length);
159+
float[][] newCentroids = new float[newCentroidsSize][];
160+
System.arraycopy(current.centroids(), 0, newCentroids, 0, current.centroids().length);
158161

159-
// replace the original cluster
160-
int origCentroidOrd = 0;
161-
newCentroids[cluster] = subPartitions.centroids()[0];
162+
// replace the original cluster
163+
int origCentroidOrd = 0;
164+
newCentroids[cluster] = subPartitions.centroids()[0];
162165

163-
// append the remainder
164-
System.arraycopy(subPartitions.centroids(), 1, newCentroids, current.centroids().length, subPartitions.centroids().length - 1);
166+
// append the remainder
167+
System.arraycopy(subPartitions.centroids(), 1, newCentroids, current.centroids().length, subPartitions.centroids().length - 1);
168+
assert Arrays.stream(newCentroids).allMatch(Objects::nonNull);
165169

166-
current.setCentroids(newCentroids);
170+
current.setCentroids(newCentroids);
167171

168-
for (int i = 0; i < subPartitions.assignments().length; i++) {
169-
// this is a new centroid that was added, and so we'll need to remap it
170-
if (subPartitions.assignments()[i] != origCentroidOrd) {
171-
int parentOrd = subPartitions.ordToDoc(i);
172-
assert current.assignments()[parentOrd] == cluster;
173-
current.assignments()[parentOrd] = subPartitions.assignments()[i] + orgCentroidsSize - 1;
174-
}
172+
for (int i = 0; i < subPartitions.assignments().length; i++) {
173+
// this is a new centroid that was added, and so we'll need to remap it
174+
if (subPartitions.assignments()[i] != origCentroidOrd) {
175+
int parentOrd = subPartitions.ordToDoc(i);
176+
assert current.assignments()[parentOrd] == cluster;
177+
current.assignments()[parentOrd] = subPartitions.assignments()[i] + orgCentroidsSize - 1;
175178
}
176179
}
177180
}

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java

Lines changed: 29 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -70,17 +70,14 @@ private static boolean stepLloyd(
7070
FloatVectorValues vectors,
7171
IntToIntFunction translateOrd,
7272
float[][] centroids,
73-
float[][] nextCentroids,
73+
int[] centroidCounts,
7474
int[] assignments,
7575
NeighborHood[] neighborhoods
7676
) throws IOException {
7777
boolean changed = false;
7878
int dim = vectors.dimension();
79-
int[] centroidCounts = new int[centroids.length];
79+
Arrays.fill(centroidCounts, 0);
8080

81-
for (float[] nextCentroid : nextCentroids) {
82-
Arrays.fill(nextCentroid, 0.0f);
83-
}
8481
final float[] distances = new float[4];
8582
for (int idx = 0; idx < vectors.size(); idx++) {
8683
float[] vector = vectors.vectorValue(idx);
@@ -97,16 +94,32 @@ private static boolean stepLloyd(
9794
changed = true;
9895
}
9996
centroidCounts[bestCentroidOffset]++;
100-
for (int d = 0; d < dim; d++) {
101-
nextCentroids[bestCentroidOffset][d] += vector[d];
102-
}
10397
}
10498

105-
for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) {
106-
if (centroidCounts[clusterIdx] > 0) {
107-
float countF = (float) centroidCounts[clusterIdx];
108-
for (int d = 0; d < dim; d++) {
109-
centroids[clusterIdx][d] = nextCentroids[clusterIdx][d] / countF;
99+
if (changed) {
100+
for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) {
101+
if (centroidCounts[clusterIdx] > 0) {
102+
Arrays.fill(centroids[clusterIdx], 0.0f);
103+
}
104+
}
105+
106+
for (int idx = 0; idx < vectors.size(); idx++) {
107+
final int assignment = assignments[translateOrd.apply(idx)];
108+
if (centroidCounts[assignment] > 0) {
109+
float[] vector = vectors.vectorValue(idx);
110+
float[] centroid = centroids[assignment];
111+
for (int d = 0; d < dim; d++) {
112+
centroid[d] += vector[d];
113+
}
114+
}
115+
}
116+
117+
for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) {
118+
if (centroidCounts[clusterIdx] > 0) {
119+
float countF = (float) centroidCounts[clusterIdx];
120+
for (int d = 0; d < dim; d++) {
121+
centroids[clusterIdx][d] /= countF;
122+
}
110123
}
111124
}
112125
}
@@ -420,17 +433,17 @@ private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansInterme
420433
}
421434

422435
assert assignments.length == n;
423-
float[][] nextCentroids = new float[centroids.length][vectors.dimension()];
436+
int[] centroidCounts = new int[centroids.length];
424437
for (int i = 0; i < maxIterations; i++) {
425438
// This is potentially sampled, so we need to translate ordinals
426-
if (stepLloyd(sampledVectors, translateOrd, centroids, nextCentroids, assignments, neighborhoods) == false) {
439+
if (stepLloyd(sampledVectors, translateOrd, centroids, centroidCounts, assignments, neighborhoods) == false) {
427440
break;
428441
}
429442
}
430443
// If we were sampled, do a once over the full set of vectors to finalize the centroids
431444
if (sampleSize < n || maxIterations == 0) {
432445
// No ordinal translation needed here, we are using the full set of vectors
433-
stepLloyd(vectors, i -> i, centroids, nextCentroids, assignments, neighborhoods);
446+
stepLloyd(vectors, i -> i, centroids, centroidCounts, assignments, neighborhoods);
434447
}
435448
}
436449

0 commit comments

Comments
 (0)