Skip to content

Commit d90055e

Browse files
authored
Reduce heap usage in hierarchical k-means (#132391)
It reduces heap by avoiding multiple copies of centroids on heap.
1 parent 785803b commit d90055e

File tree

2 files changed

+55
-34
lines changed

2 files changed

+55
-34
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import java.io.IOException;
1515
import java.util.Arrays;
16+
import java.util.Objects;
1617

1718
/**
1819
* An implementation of the hierarchical k-means algorithm that better partitions data than naive k-means
@@ -148,30 +149,32 @@ static FloatVectorValues createClusterSlice(int clusterSize, int cluster, FloatV
148149
}
149150

150151
void updateAssignmentsWithRecursiveSplit(KMeansIntermediate current, int cluster, KMeansIntermediate subPartitions) {
152+
if (subPartitions.centroids().length == 0) {
153+
return; // nothing to do, sub-partitions is empty
154+
}
151155
int orgCentroidsSize = current.centroids().length;
152156
int newCentroidsSize = current.centroids().length + subPartitions.centroids().length - 1;
153157

154158
// update based on the outcomes from the split clusters recursion
155-
if (subPartitions.centroids().length > 1) {
156-
float[][] newCentroids = new float[newCentroidsSize][dimension];
157-
System.arraycopy(current.centroids(), 0, newCentroids, 0, current.centroids().length);
159+
float[][] newCentroids = new float[newCentroidsSize][];
160+
System.arraycopy(current.centroids(), 0, newCentroids, 0, current.centroids().length);
158161

159-
// replace the original cluster
160-
int origCentroidOrd = 0;
161-
newCentroids[cluster] = subPartitions.centroids()[0];
162+
// replace the original cluster
163+
int origCentroidOrd = 0;
164+
newCentroids[cluster] = subPartitions.centroids()[0];
162165

163-
// append the remainder
164-
System.arraycopy(subPartitions.centroids(), 1, newCentroids, current.centroids().length, subPartitions.centroids().length - 1);
166+
// append the remainder
167+
System.arraycopy(subPartitions.centroids(), 1, newCentroids, current.centroids().length, subPartitions.centroids().length - 1);
168+
assert Arrays.stream(newCentroids).allMatch(Objects::nonNull);
165169

166-
current.setCentroids(newCentroids);
170+
current.setCentroids(newCentroids);
167171

168-
for (int i = 0; i < subPartitions.assignments().length; i++) {
169-
// this is a new centroid that was added, and so we'll need to remap it
170-
if (subPartitions.assignments()[i] != origCentroidOrd) {
171-
int parentOrd = subPartitions.ordToDoc(i);
172-
assert current.assignments()[parentOrd] == cluster;
173-
current.assignments()[parentOrd] = subPartitions.assignments()[i] + orgCentroidsSize - 1;
174-
}
172+
for (int i = 0; i < subPartitions.assignments().length; i++) {
173+
// this is a new centroid that was added, and so we'll need to remap it
174+
if (subPartitions.assignments()[i] != origCentroidOrd) {
175+
int parentOrd = subPartitions.ordToDoc(i);
176+
assert current.assignments()[parentOrd] == cluster;
177+
current.assignments()[parentOrd] = subPartitions.assignments()[i] + orgCentroidsSize - 1;
175178
}
176179
}
177180
}

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
package org.elasticsearch.index.codec.vectors.cluster;
1111

1212
import org.apache.lucene.index.FloatVectorValues;
13+
import org.apache.lucene.util.FixedBitSet;
1314
import org.apache.lucene.util.VectorUtil;
1415
import org.apache.lucene.util.hnsw.IntToIntFunction;
1516
import org.elasticsearch.index.codec.vectors.SampleReader;
@@ -70,17 +71,14 @@ private static boolean stepLloyd(
7071
FloatVectorValues vectors,
7172
IntToIntFunction translateOrd,
7273
float[][] centroids,
73-
float[][] nextCentroids,
74+
FixedBitSet centroidChanged,
75+
int[] centroidCounts,
7476
int[] assignments,
7577
NeighborHood[] neighborhoods
7678
) throws IOException {
7779
boolean changed = false;
7880
int dim = vectors.dimension();
79-
int[] centroidCounts = new int[centroids.length];
80-
81-
for (float[] nextCentroid : nextCentroids) {
82-
Arrays.fill(nextCentroid, 0.0f);
83-
}
81+
centroidChanged.clear();
8482
final float[] distances = new float[4];
8583
for (int idx = 0; idx < vectors.size(); idx++) {
8684
float[] vector = vectors.vectorValue(idx);
@@ -93,20 +91,39 @@ private static boolean stepLloyd(
9391
bestCentroidOffset = getBestCentroid(centroids, vector, distances);
9492
}
9593
if (assignment != bestCentroidOffset) {
94+
if (assignment != -1) {
95+
centroidChanged.set(assignment);
96+
}
97+
centroidChanged.set(bestCentroidOffset);
9698
assignments[vectorOrd] = bestCentroidOffset;
9799
changed = true;
98100
}
99-
centroidCounts[bestCentroidOffset]++;
100-
for (int d = 0; d < dim; d++) {
101-
nextCentroids[bestCentroidOffset][d] += vector[d];
102-
}
103101
}
102+
if (changed) {
103+
Arrays.fill(centroidCounts, 0);
104+
for (int idx = 0; idx < vectors.size(); idx++) {
105+
final int assignment = assignments[translateOrd.apply(idx)];
106+
if (centroidChanged.get(assignment)) {
107+
float[] centroid = centroids[assignment];
108+
if (centroidCounts[assignment]++ == 0) {
109+
Arrays.fill(centroid, 0.0f);
110+
}
111+
float[] vector = vectors.vectorValue(idx);
112+
for (int d = 0; d < dim; d++) {
113+
centroid[d] += vector[d];
114+
}
115+
}
116+
}
104117

105-
for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) {
106-
if (centroidCounts[clusterIdx] > 0) {
107-
float countF = (float) centroidCounts[clusterIdx];
108-
for (int d = 0; d < dim; d++) {
109-
centroids[clusterIdx][d] = nextCentroids[clusterIdx][d] / countF;
118+
for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) {
119+
if (centroidChanged.get(clusterIdx)) {
120+
float count = (float) centroidCounts[clusterIdx];
121+
if (count > 0) {
122+
float[] centroid = centroids[clusterIdx];
123+
for (int d = 0; d < dim; d++) {
124+
centroid[d] /= count;
125+
}
126+
}
110127
}
111128
}
112129
}
@@ -420,17 +437,18 @@ private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansInterme
420437
}
421438

422439
assert assignments.length == n;
423-
float[][] nextCentroids = new float[centroids.length][vectors.dimension()];
440+
FixedBitSet centroidChanged = new FixedBitSet(centroids.length);
441+
int[] centroidCounts = new int[centroids.length];
424442
for (int i = 0; i < maxIterations; i++) {
425443
// This is potentially sampled, so we need to translate ordinals
426-
if (stepLloyd(sampledVectors, translateOrd, centroids, nextCentroids, assignments, neighborhoods) == false) {
444+
if (stepLloyd(sampledVectors, translateOrd, centroids, centroidChanged, centroidCounts, assignments, neighborhoods) == false) {
427445
break;
428446
}
429447
}
430448
// If we were sampled, do a once over the full set of vectors to finalize the centroids
431449
if (sampleSize < n || maxIterations == 0) {
432450
// No ordinal translation needed here, we are using the full set of vectors
433-
stepLloyd(vectors, i -> i, centroids, nextCentroids, assignments, neighborhoods);
451+
stepLloyd(vectors, i -> i, centroids, centroidChanged, centroidCounts, assignments, neighborhoods);
434452
}
435453
}
436454

0 commit comments

Comments
 (0)