From bfd6692ed31ad810c6851f2550ff577a6a050923 Mon Sep 17 00:00:00 2001 From: Ignacio Vera Date: Mon, 4 Aug 2025 12:46:28 +0100 Subject: [PATCH 1/3] Reduce heap usage in hierarchical k-means --- .../vectors/cluster/HierarchicalKMeans.java | 35 ++++++++------- .../codec/vectors/cluster/KMeansLocal.java | 45 ++++++++++++------- 2 files changed, 48 insertions(+), 32 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java index 0bf15943f0060..72f1c9e0cdf0e 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java @@ -13,6 +13,7 @@ import java.io.IOException; import java.util.Arrays; +import java.util.Objects; /** * An implementation of the hierarchical k-means algorithm that better partitions data than naive k-means @@ -148,30 +149,32 @@ static FloatVectorValues createClusterSlice(int clusterSize, int cluster, FloatV } void updateAssignmentsWithRecursiveSplit(KMeansIntermediate current, int cluster, KMeansIntermediate subPartitions) { + if (subPartitions.centroids().length == 0) { + return; // nothing to do, sub-partitions is empty + } int orgCentroidsSize = current.centroids().length; int newCentroidsSize = current.centroids().length + subPartitions.centroids().length - 1; // update based on the outcomes from the split clusters recursion - if (subPartitions.centroids().length > 1) { - float[][] newCentroids = new float[newCentroidsSize][dimension]; - System.arraycopy(current.centroids(), 0, newCentroids, 0, current.centroids().length); + float[][] newCentroids = new float[newCentroidsSize][]; + System.arraycopy(current.centroids(), 0, newCentroids, 0, current.centroids().length); - // replace the original cluster - int origCentroidOrd = 0; - newCentroids[cluster] = subPartitions.centroids()[0]; + // replace the original cluster + int origCentroidOrd = 0; + newCentroids[cluster] = subPartitions.centroids()[0]; - // append the remainder - System.arraycopy(subPartitions.centroids(), 1, newCentroids, current.centroids().length, subPartitions.centroids().length - 1); + // append the remainder + System.arraycopy(subPartitions.centroids(), 1, newCentroids, current.centroids().length, subPartitions.centroids().length - 1); + assert Arrays.stream(newCentroids).allMatch(Objects::nonNull); - current.setCentroids(newCentroids); + current.setCentroids(newCentroids); - for (int i = 0; i < subPartitions.assignments().length; i++) { - // this is a new centroid that was added, and so we'll need to remap it - if (subPartitions.assignments()[i] != origCentroidOrd) { - int parentOrd = subPartitions.ordToDoc(i); - assert current.assignments()[parentOrd] == cluster; - current.assignments()[parentOrd] = subPartitions.assignments()[i] + orgCentroidsSize - 1; - } + for (int i = 0; i < subPartitions.assignments().length; i++) { + // this is a new centroid that was added, and so we'll need to remap it + if (subPartitions.assignments()[i] != origCentroidOrd) { + int parentOrd = subPartitions.ordToDoc(i); + assert current.assignments()[parentOrd] == cluster; + current.assignments()[parentOrd] = subPartitions.assignments()[i] + orgCentroidsSize - 1; } } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java index 0aabdc9d74590..86609aa36e5b3 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java @@ -70,17 +70,14 @@ private static boolean stepLloyd( FloatVectorValues vectors, IntToIntFunction translateOrd, float[][] centroids, - float[][] nextCentroids, + int[] centroidCounts, int[] assignments, NeighborHood[] neighborhoods ) throws IOException { boolean changed = false; int dim = vectors.dimension(); - int[] centroidCounts = new int[centroids.length]; + Arrays.fill(centroidCounts, 0); - for (float[] nextCentroid : nextCentroids) { - Arrays.fill(nextCentroid, 0.0f); - } final float[] distances = new float[4]; for (int idx = 0; idx < vectors.size(); idx++) { float[] vector = vectors.vectorValue(idx); @@ -97,16 +94,32 @@ private static boolean stepLloyd( changed = true; } centroidCounts[bestCentroidOffset]++; - for (int d = 0; d < dim; d++) { - nextCentroids[bestCentroidOffset][d] += vector[d]; - } } - for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) { - if (centroidCounts[clusterIdx] > 0) { - float countF = (float) centroidCounts[clusterIdx]; - for (int d = 0; d < dim; d++) { - centroids[clusterIdx][d] = nextCentroids[clusterIdx][d] / countF; + if (changed) { + for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) { + if (centroidCounts[clusterIdx] > 0) { + Arrays.fill(centroids[clusterIdx], 0.0f); + } + } + + for (int idx = 0; idx < vectors.size(); idx++) { + final int assignment = assignments[translateOrd.apply(idx)]; + if (centroidCounts[assignment] > 0) { + float[] vector = vectors.vectorValue(idx); + float[] centroid = centroids[assignment]; + for (int d = 0; d < dim; d++) { + centroid[d] += vector[d]; + } + } + } + + for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) { + if (centroidCounts[clusterIdx] > 0) { + float countF = (float) centroidCounts[clusterIdx]; + for (int d = 0; d < dim; d++) { + centroids[clusterIdx][d] /= countF; + } } } } @@ -420,17 +433,17 @@ private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansInterme } assert assignments.length == n; - float[][] nextCentroids = new float[centroids.length][vectors.dimension()]; + int[] centroidCounts = new int[centroids.length]; for (int i = 0; i < maxIterations; i++) { // This is potentially sampled, so we need to translate ordinals - if (stepLloyd(sampledVectors, translateOrd, centroids, nextCentroids, assignments, neighborhoods) == false) { + if (stepLloyd(sampledVectors, translateOrd, centroids, centroidCounts, assignments, neighborhoods) == false) { break; } } // If we were sampled, do a once over the full set of vectors to finalize the centroids if (sampleSize < n || maxIterations == 0) { // No ordinal translation needed here, we are using the full set of vectors - stepLloyd(vectors, i -> i, centroids, nextCentroids, assignments, neighborhoods); + stepLloyd(vectors, i -> i, centroids, centroidCounts, assignments, neighborhoods); } } From ee615538a4674e9a12e19446c6146aaf5178c34f Mon Sep 17 00:00:00 2001 From: Ignacio Vera Date: Tue, 5 Aug 2025 08:50:31 +0100 Subject: [PATCH 2/3] Ony recompute changed centroids --- .../codec/vectors/cluster/KMeansLocal.java | 25 +++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java index 86609aa36e5b3..bde10b618e50e 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java @@ -10,6 +10,7 @@ package org.elasticsearch.index.codec.vectors.cluster; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.IntToIntFunction; import org.elasticsearch.index.codec.vectors.SampleReader; @@ -70,14 +71,14 @@ private static boolean stepLloyd( FloatVectorValues vectors, IntToIntFunction translateOrd, float[][] centroids, + FixedBitSet centroidChanged, int[] centroidCounts, int[] assignments, NeighborHood[] neighborhoods ) throws IOException { boolean changed = false; int dim = vectors.dimension(); - Arrays.fill(centroidCounts, 0); - + centroidChanged.clear(); final float[] distances = new float[4]; for (int idx = 0; idx < vectors.size(); idx++) { float[] vector = vectors.vectorValue(idx); @@ -90,22 +91,25 @@ private static boolean stepLloyd( bestCentroidOffset = getBestCentroid(centroids, vector, distances); } if (assignment != bestCentroidOffset) { + if (assignment != -1) { + centroidChanged.set(assignment); + } + centroidChanged.set(bestCentroidOffset); assignments[vectorOrd] = bestCentroidOffset; changed = true; } - centroidCounts[bestCentroidOffset]++; } - if (changed) { for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) { - if (centroidCounts[clusterIdx] > 0) { + if (centroidChanged.get(clusterIdx)) { Arrays.fill(centroids[clusterIdx], 0.0f); } } - + Arrays.fill(centroidCounts, 0); for (int idx = 0; idx < vectors.size(); idx++) { final int assignment = assignments[translateOrd.apply(idx)]; - if (centroidCounts[assignment] > 0) { + if (centroidChanged.get(assignment)) { + centroidCounts[assignment]++; float[] vector = vectors.vectorValue(idx); float[] centroid = centroids[assignment]; for (int d = 0; d < dim; d++) { @@ -115,7 +119,7 @@ private static boolean stepLloyd( } for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) { - if (centroidCounts[clusterIdx] > 0) { + if (centroidChanged.get(clusterIdx)) { float countF = (float) centroidCounts[clusterIdx]; for (int d = 0; d < dim; d++) { centroids[clusterIdx][d] /= countF; @@ -433,17 +437,18 @@ private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansInterme } assert assignments.length == n; + FixedBitSet centroidChanged = new FixedBitSet(centroids.length); int[] centroidCounts = new int[centroids.length]; for (int i = 0; i < maxIterations; i++) { // This is potentially sampled, so we need to translate ordinals - if (stepLloyd(sampledVectors, translateOrd, centroids, centroidCounts, assignments, neighborhoods) == false) { + if (stepLloyd(sampledVectors, translateOrd, centroids, centroidChanged, centroidCounts, assignments, neighborhoods) == false) { break; } } // If we were sampled, do a once over the full set of vectors to finalize the centroids if (sampleSize < n || maxIterations == 0) { // No ordinal translation needed here, we are using the full set of vectors - stepLloyd(vectors, i -> i, centroids, centroidCounts, assignments, neighborhoods); + stepLloyd(vectors, i -> i, centroids, centroidChanged, centroidCounts, assignments, neighborhoods); } } From 9a2c5b7040c4d985387f319bb10c5c8105727e5b Mon Sep 17 00:00:00 2001 From: Ignacio Vera Date: Tue, 5 Aug 2025 09:44:26 +0100 Subject: [PATCH 3/3] iter --- .../codec/vectors/cluster/KMeansLocal.java | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java index bde10b618e50e..744fd248b2a49 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java @@ -100,18 +100,15 @@ private static boolean stepLloyd( } } if (changed) { - for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) { - if (centroidChanged.get(clusterIdx)) { - Arrays.fill(centroids[clusterIdx], 0.0f); - } - } Arrays.fill(centroidCounts, 0); for (int idx = 0; idx < vectors.size(); idx++) { final int assignment = assignments[translateOrd.apply(idx)]; if (centroidChanged.get(assignment)) { - centroidCounts[assignment]++; - float[] vector = vectors.vectorValue(idx); float[] centroid = centroids[assignment]; + if (centroidCounts[assignment]++ == 0) { + Arrays.fill(centroid, 0.0f); + } + float[] vector = vectors.vectorValue(idx); for (int d = 0; d < dim; d++) { centroid[d] += vector[d]; } @@ -120,9 +117,12 @@ private static boolean stepLloyd( for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) { if (centroidChanged.get(clusterIdx)) { - float countF = (float) centroidCounts[clusterIdx]; - for (int d = 0; d < dim; d++) { - centroids[clusterIdx][d] /= countF; + float count = (float) centroidCounts[clusterIdx]; + if (count > 0) { + float[] centroid = centroids[clusterIdx]; + for (int d = 0; d < dim; d++) { + centroid[d] /= count; + } } } }