Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
package org.elasticsearch.index.codec.vectors.cluster;

import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.util.VectorUtil;

import java.io.IOException;

Expand All @@ -21,7 +20,7 @@ public class HierarchicalKMeans {

static final int MAXK = 128;
static final int MAX_ITERATIONS_DEFAULT = 6;
static final int SAMPLES_PER_CLUSTER_DEFAULT = 256;
static final int SAMPLES_PER_CLUSTER_DEFAULT = 64;
Copy link
Contributor

@john-wagster john-wagster Jun 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might be worth including some of the runs you were doing in the PR comments just so we can look back at them if we need to to confirm recall wasn't hurt by doing this

I'll run a couple runs myself here real quick too to double check with a different model

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ran this whole PR and just the sampling change only on glove 200d 1m, 3m and 10m and for both saw no major drops in recall

static final float DEFAULT_SOAR_LAMBDA = 1.0f;

final int dimension;
Expand Down Expand Up @@ -67,8 +66,7 @@ public KMeansResult cluster(FloatVectorValues vectors, int targetSize) throws IO
// partition the space
KMeansIntermediate kMeansIntermediate = clusterAndSplit(vectors, targetSize);
if (kMeansIntermediate.centroids().length > 1 && kMeansIntermediate.centroids().length < vectors.size()) {
float f = Math.min((float) samplesPerCluster / targetSize, 1.0f);
int localSampleSize = (int) (f * vectors.size());
int localSampleSize = Math.min(kMeansIntermediate.centroids().length * samplesPerCluster, vectors.size());
KMeansLocal kMeansLocal = new KMeansLocal(localSampleSize, maxIterations, clustersPerNeighborhood, DEFAULT_SOAR_LAMBDA);
kMeansLocal.cluster(vectors, kMeansIntermediate, true);
}
Expand All @@ -86,42 +84,16 @@ KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int ta

// TODO: instead of creating a sub-cluster assignments reuse the parent array each time
int[] assignments = new int[vectors.size()];

KMeansLocal kmeans = new KMeansLocal(m, maxIterations);
float[][] centroids = KMeansLocal.pickInitialCentroids(vectors, k);
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids);
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, vectors::ordToDoc);
kmeans.cluster(vectors, kMeansIntermediate);

// TODO: consider adding cluster size counts to the kmeans algo
// handle assignment here so we can track distance and cluster size
int[] centroidVectorCount = new int[centroids.length];
float[][] nextCentroids = new float[centroids.length][dimension];
for (int i = 0; i < vectors.size(); i++) {
float smallest = Float.MAX_VALUE;
int centroidIdx = -1;
float[] vector = vectors.vectorValue(i);
for (int j = 0; j < centroids.length; j++) {
float[] centroid = centroids[j];
float d = VectorUtil.squareDistance(vector, centroid);
if (d < smallest) {
smallest = d;
centroidIdx = j;
}
}
centroidVectorCount[centroidIdx]++;
for (int j = 0; j < dimension; j++) {
nextCentroids[centroidIdx][j] += vector[j];
}
assignments[i] = centroidIdx;
}

// update centroids based on assignments of all vectors
for (int i = 0; i < centroids.length; i++) {
if (centroidVectorCount[i] > 0) {
for (int j = 0; j < dimension; j++) {
centroids[i][j] = nextCentroids[i][j] / centroidVectorCount[i];
}
}
for (int assigment : assignments) {
centroidVectorCount[assigment]++;
}

int effectiveK = 0;
Expand All @@ -131,8 +103,6 @@ KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int ta
}
}

kMeansIntermediate = new KMeansIntermediate(centroids, assignments, vectors::ordToDoc);

if (effectiveK == 1) {
return kMeansIntermediate;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@ private KMeansIntermediate(float[][] centroids, int[] assignments, IntToIntFunct
this(new float[0][0], new int[0], i -> i, new int[0]);
}

KMeansIntermediate(float[][] centroids) {
this(centroids, new int[0], i -> i, new int[0]);
}

KMeansIntermediate(float[][] centroids, int[] assignments) {
this(centroids, assignments, i -> i, new int[0]);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,17 @@ private boolean stepLloyd(

for (int i = 0; i < sampleSize; i++) {
float[] vector = vectors.vectorValue(i);
int[] neighborOffsets = null;
int centroidIdx = -1;
final int assignment = assignments[i];
final int bestCentroidOffset;
if (neighborhoods != null) {
neighborOffsets = neighborhoods.get(assignments[i]);
centroidIdx = assignments[i];
bestCentroidOffset = getBestCentroidFromNeighbours(centroids, vector, assignment, neighborhoods.get(assignment));
} else {
bestCentroidOffset = getBestCentroid(centroids, vector);
}
int bestCentroidOffset = getBestCentroidOffset(centroids, vector, centroidIdx, neighborOffsets);
if (assignments[i] != bestCentroidOffset) {
if (assignment != bestCentroidOffset) {
assignments[i] = bestCentroidOffset;
changed = true;
}
assignments[i] = bestCentroidOffset;
centroidCounts[bestCentroidOffset]++;
for (int d = 0; d < dim; d++) {
nextCentroids[bestCentroidOffset][d] += vector[d];
Expand All @@ -116,23 +116,28 @@ private boolean stepLloyd(
return changed;
}

int getBestCentroidOffset(float[][] centroids, float[] vector, int centroidIdx, int[] centroidOffsets) {
int getBestCentroidFromNeighbours(float[][] centroids, float[] vector, int centroidIdx, int[] centroidOffsets) {
int bestCentroidOffset = centroidIdx;
float minDsq;
if (centroidIdx > 0 && centroidIdx < centroids.length) {
minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]);
} else {
minDsq = Float.MAX_VALUE;
assert centroidIdx >= 0 && centroidIdx < centroids.length;
float minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]);
for (int offset : centroidOffsets) {
float dsq = VectorUtil.squareDistance(vector, centroids[offset]);
if (dsq < minDsq) {
minDsq = dsq;
bestCentroidOffset = offset;
}
}
return bestCentroidOffset;
}

int k = 0;
for (int j = 0; j < centroids.length; j++) {
if (centroidOffsets == null || j == centroidOffsets[k]) {
float dsq = VectorUtil.squareDistance(vector, centroids[j]);
if (dsq < minDsq) {
minDsq = dsq;
bestCentroidOffset = j;
}
int getBestCentroid(float[][] centroids, float[] vector) {
int bestCentroidOffset = 0;
float minDsq = Float.MAX_VALUE;
for (int i = 0; i < centroids.length; i++) {
float dsq = VectorUtil.squareDistance(vector, centroids[i]);
if (dsq < minDsq) {
minDsq = dsq;
bestCentroidOffset = i;
}
}
return bestCentroidOffset;
Expand Down Expand Up @@ -271,7 +276,8 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, L
return;
}

int[] assignments = new int[n];
int[] assignments = kMeansIntermediate.assignments();
assert assignments.length == n;
float[][] nextCentroids = new float[centroids.length][vectors.dimension()];
for (int i = 0; i < maxIterations; i++) {
if (stepLloyd(vectors, centroids, nextCentroids, assignments, sampleSize, neighborhoods) == false) {
Expand All @@ -291,7 +297,7 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, L
* @param maxIterations the max iterations to shift centroids
*/
public static void cluster(FloatVectorValues vectors, float[][] centroids, int sampleSize, int maxIterations) throws IOException {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is only used for tests and kinda silly now you can just get rid of this or I can clean it up in a subsequent PR

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's clean up in a follow up PR

KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids);
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, new int[vectors.size()], vectors::ordToDoc);
KMeansLocal kMeans = new KMeansLocal(sampleSize, maxIterations);
kMeans.cluster(vectors, kMeansIntermediate);
}
Expand Down