Merged
Changes from 1 commit
@@ -13,6 +13,7 @@
 
 import java.io.IOException;
 import java.util.Arrays;
+import java.util.Objects;
 
 /**
  * An implementation of the hierarchical k-means algorithm that better partitions data than naive k-means

@@ -148,30 +149,32 @@ static FloatVectorValues createClusterSlice(int clusterSize, int cluster, FloatV
     }
 
     void updateAssignmentsWithRecursiveSplit(KMeansIntermediate current, int cluster, KMeansIntermediate subPartitions) {
         if (subPartitions.centroids().length == 0) {
             return; // nothing to do, sub-partitions is empty
         }
         int orgCentroidsSize = current.centroids().length;
         int newCentroidsSize = current.centroids().length + subPartitions.centroids().length - 1;
 
         // update based on the outcomes from the split clusters recursion
         if (subPartitions.centroids().length > 1) {
-            float[][] newCentroids = new float[newCentroidsSize][dimension];
+            float[][] newCentroids = new float[newCentroidsSize][];
             System.arraycopy(current.centroids(), 0, newCentroids, 0, current.centroids().length);
 
             // replace the original cluster
             int origCentroidOrd = 0;
             newCentroids[cluster] = subPartitions.centroids()[0];
 
             // append the remainder
             System.arraycopy(subPartitions.centroids(), 1, newCentroids, current.centroids().length, subPartitions.centroids().length - 1);
+            assert Arrays.stream(newCentroids).allMatch(Objects::nonNull);
 
             current.setCentroids(newCentroids);
 
             for (int i = 0; i < subPartitions.assignments().length; i++) {
                 // this is a new centroid that was added, and so we'll need to remap it
                 if (subPartitions.assignments()[i] != origCentroidOrd) {
                     int parentOrd = subPartitions.ordToDoc(i);
                     assert current.assignments()[parentOrd] == cluster;
                     current.assignments()[parentOrd] = subPartitions.assignments()[i] + orgCentroidsSize - 1;
                 }
             }
         }
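
The allocation change above is the point of this hunk: new float[newCentroidsSize][] reserves only the outer array, and each slot is then pointed at an already existing centroid row, whereas new float[newCentroidsSize][dimension] would allocate newCentroidsSize fresh rows that the arraycopy calls immediately overwrite. A minimal sketch of that pattern (the class and method names here are made up for illustration, not part of the PR):

    import java.util.Arrays;
    import java.util.Objects;

    // Illustrative only: merge an existing centroid table with the sub-partition centroids
    // by copying row references, allocating just the outer array.
    final class CentroidMergeSketch {
        static float[][] merge(float[][] current, float[][] sub, int cluster) {
            float[][] merged = new float[current.length + sub.length - 1][]; // rows stay null here
            System.arraycopy(current, 0, merged, 0, current.length);
            merged[cluster] = sub[0];                                         // replace the split cluster
            System.arraycopy(sub, 1, merged, current.length, sub.length - 1); // append the remainder
            assert Arrays.stream(merged).allMatch(Objects::nonNull);          // same invariant as the added assert
            return merged;
        }
    }
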
@@ -70,17 +70,14 @@ private static boolean stepLloyd(
         FloatVectorValues vectors,
         IntToIntFunction translateOrd,
         float[][] centroids,
-        float[][] nextCentroids,
+        int[] centroidCounts,
Member
What do you think of having a boolean[] isChanged or FixedBitSet isChanged so that we only adjust the centroids that actually changed (basically, keeping track of changed centroids with a new FixedBitSet(centroids.length))?

It seems to me that as the steps increase, fewer centroids will actually get mutated. This would still significantly reduce heap usage.
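
A rough sketch of this idea, assuming a Lucene FixedBitSet records which centroids gained or lost a vector during the assignment pass so that only those get reset and recomputed (the class and variable names are illustrative, not code from this PR):

    import java.util.Arrays;
    import org.apache.lucene.util.FixedBitSet;

    // Illustrative only: reset just the centroids whose membership changed; the
    // assignment loop would call isChanged.set(previousAssignment) and
    // isChanged.set(bestCentroidOffset) whenever a vector switches cluster,
    // with isChanged = new FixedBitSet(centroids.length).
    final class ChangedCentroidsSketch {
        static void resetChanged(float[][] centroids, int[] centroidCounts, FixedBitSet isChanged) {
            for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) {
                if (isChanged.get(clusterIdx) && centroidCounts[clusterIdx] > 0) {
                    Arrays.fill(centroids[clusterIdx], 0.0f); // untouched centroids keep their current values
                }
            }
        }
    }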

Contributor Author (@iverase) · Aug 4, 2025
I am not sure I follow you, but we only mutate changed centroids by checking whether the counts are bigger than 0, e.g.:

    for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) {
        if (centroidCounts[clusterIdx] > 0) {
            Arrays.fill(centroids[clusterIdx], 0.0f);
        }
    }

And we need the counts for computing the centroids:

    for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) {
        if (centroidCounts[clusterIdx] > 0) {
            float countF = (float) centroidCounts[clusterIdx];
            for (int d = 0; d < dim; d++) {
                centroids[clusterIdx][d] /= countF;
            }
        }
    }

In general I don't think this allocation is problematic, as later on we will allocate an array for SOAR assignments, which should be much bigger than this array. That's not the case for the nextCentroids array.
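
For a rough sense of scale (purely hypothetical numbers, not from this PR): nextCentroids holds centroids.length × dimension floats, while centroidCounts holds only centroids.length ints, e.g.:

    // Back-of-the-envelope comparison with made-up sizes: 10_000 centroids of 1_024 dims.
    final class HeapFootprintSketch {
        public static void main(String[] args) {
            long numCentroids = 10_000, dim = 1_024;
            long nextCentroidsBytes = numCentroids * dim * Float.BYTES; // accumulation buffer: ~40 MB
            long centroidCountsBytes = numCentroids * Integer.BYTES;    // counts array: ~40 KB
            System.out.println("nextCentroids  ~ " + nextCentroidsBytes / 1_000_000 + " MB");
            System.out.println("centroidCounts ~ " + centroidCountsBytes / 1_000 + " KB");
        }
    }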

Member
@iverase centroidCounts also counts centroids that have assigned vectors but never changed, doesn't it?

I am saying it seems like if ANY centroid changes at all, we rebuild all of them. This seems wrong to me.

Contributor Author
Doh! Sorry, I get you now; I will have a go tomorrow.

Contributor Author
Done, it makes sense to me; let me know what you think.

         int[] assignments,
         NeighborHood[] neighborhoods
     ) throws IOException {
         boolean changed = false;
         int dim = vectors.dimension();
-        int[] centroidCounts = new int[centroids.length];
+        Arrays.fill(centroidCounts, 0);
 
-        for (float[] nextCentroid : nextCentroids) {
-            Arrays.fill(nextCentroid, 0.0f);
-        }
         final float[] distances = new float[4];
         for (int idx = 0; idx < vectors.size(); idx++) {
             float[] vector = vectors.vectorValue(idx);
@@ -97,16 +94,32 @@
                 changed = true;
             }
             centroidCounts[bestCentroidOffset]++;
-            for (int d = 0; d < dim; d++) {
-                nextCentroids[bestCentroidOffset][d] += vector[d];
-            }
         }
 
-        for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) {
-            if (centroidCounts[clusterIdx] > 0) {
-                float countF = (float) centroidCounts[clusterIdx];
-                for (int d = 0; d < dim; d++) {
-                    centroids[clusterIdx][d] = nextCentroids[clusterIdx][d] / countF;
-                }
-            }
-        }
+        if (changed) {
+            for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) {
+                if (centroidCounts[clusterIdx] > 0) {
+                    Arrays.fill(centroids[clusterIdx], 0.0f);
+                }
+            }
+
+            for (int idx = 0; idx < vectors.size(); idx++) {
+                final int assignment = assignments[translateOrd.apply(idx)];
+                if (centroidCounts[assignment] > 0) {
+                    float[] vector = vectors.vectorValue(idx);
+                    float[] centroid = centroids[assignment];
+                    for (int d = 0; d < dim; d++) {
+                        centroid[d] += vector[d];
+                    }
+                }
+            }
+
+            for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) {
+                if (centroidCounts[clusterIdx] > 0) {
+                    float countF = (float) centroidCounts[clusterIdx];
+                    for (int d = 0; d < dim; d++) {
+                        centroids[clusterIdx][d] /= countF;
+                    }
+                }
+            }
+        }
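
The rebuilt step trades the centroids.length × dimension accumulation buffer for an extra pass over the (possibly sampled) vectors whenever any assignment changed: non-empty centroids are zeroed, summed from their assigned vectors, and divided by the counts. A toy, self-contained restatement of that recompute-from-assignments idea (hypothetical helper, not the production method):

    import java.util.Arrays;

    // Illustrative only: recomputing cluster means directly from the assignments,
    // which is what the second and third loops above do without a scratch buffer.
    final class RecomputeMeansSketch {
        static float[][] meansFromAssignments(float[][] vectors, int[] assignments, int k, int dim) {
            float[][] centroids = new float[k][dim];
            int[] counts = new int[k];
            for (int i = 0; i < vectors.length; i++) {
                counts[assignments[i]]++;
                for (int d = 0; d < dim; d++) {
                    centroids[assignments[i]][d] += vectors[i][d];
                }
            }
            for (int c = 0; c < k; c++) {
                if (counts[c] > 0) {
                    for (int d = 0; d < dim; d++) {
                        centroids[c][d] /= counts[c];
                    }
                }
            }
            return centroids;
        }

        public static void main(String[] args) {
            float[][] vectors = { { 1f, 0f }, { 3f, 0f }, { 0f, 2f } };
            int[] assignments = { 0, 0, 1 };
            // prints [[2.0, 0.0], [0.0, 2.0]]: the same means the old nextCentroids buffer produced
            System.out.println(Arrays.deepToString(meansFromAssignments(vectors, assignments, 2, 2)));
        }
    }
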
@@ -420,17 +433,17 @@ private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansInterme
         }
 
         assert assignments.length == n;
-        float[][] nextCentroids = new float[centroids.length][vectors.dimension()];
+        int[] centroidCounts = new int[centroids.length];
         for (int i = 0; i < maxIterations; i++) {
             // This is potentially sampled, so we need to translate ordinals
-            if (stepLloyd(sampledVectors, translateOrd, centroids, nextCentroids, assignments, neighborhoods) == false) {
+            if (stepLloyd(sampledVectors, translateOrd, centroids, centroidCounts, assignments, neighborhoods) == false) {
                 break;
             }
         }
         // If we were sampled, do a once over the full set of vectors to finalize the centroids
         if (sampleSize < n || maxIterations == 0) {
             // No ordinal translation needed here, we are using the full set of vectors
-            stepLloyd(vectors, i -> i, centroids, nextCentroids, assignments, neighborhoods);
+            stepLloyd(vectors, i -> i, centroids, centroidCounts, assignments, neighborhoods);
         }
     }