diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/SampleReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/SampleReader.java index 89428bfb573a6..f2d7944f1088c 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/SampleReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/SampleReader.java @@ -25,10 +25,11 @@ import org.apache.lucene.util.Bits; import java.io.IOException; +import java.util.Arrays; import java.util.Random; import java.util.function.IntUnaryOperator; -class SampleReader extends FloatVectorValues implements HasIndexSlice { +public class SampleReader extends FloatVectorValues implements HasIndexSlice { private final FloatVectorValues origin; private final int sampleSize; private final IntUnaryOperator sampleFunction; @@ -71,7 +72,8 @@ public int getVectorByteLength() { @Override public int ordToDoc(int ord) { - throw new IllegalStateException("Not supported"); + // get the original ordinal from the sample ordinal + return sampleFunction.applyAsInt(ord); } @Override @@ -79,13 +81,15 @@ public Bits getAcceptOrds(Bits acceptDocs) { throw new IllegalStateException("Not supported"); } - static SampleReader createSampleReader(FloatVectorValues origin, int k, long seed) { + public static SampleReader createSampleReader(FloatVectorValues origin, int k, long seed) { // TODO can we do something algorithmically that aligns an ordinal with a unique integer between 0 and numVectors? if (k >= origin.size()) { new SampleReader(origin, origin.size(), i -> i); } // TODO maybe use bigArrays? int[] samples = reservoirSample(origin.size(), k, seed); + // sort to prevent random backwards access weirdness + Arrays.sort(samples); return new SampleReader(origin, samples.length, i -> samples[i]); } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java index 6f7705bfcc1ab..83462fb882f32 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java @@ -12,6 +12,7 @@ import org.apache.lucene.index.FloatVectorValues; import java.io.IOException; +import java.util.Arrays; /** * An implementation of the hierarchical k-means algorithm that better partitions data than naive k-means @@ -84,6 +85,8 @@ KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int ta // TODO: instead of creating a sub-cluster assignments reuse the parent array each time int[] assignments = new int[vectors.size()]; + // ensure we don't over assign to cluster 0 without adjusting it + Arrays.fill(assignments, -1); KMeansLocal kmeans = new KMeansLocal(m, maxIterations); float[][] centroids = KMeansLocal.pickInitialCentroids(vectors, k); KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, vectors::ordToDoc); diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java index 0bae5862fda8f..3d037ecf749db 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java @@ -11,6 +11,8 @@ import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.IntToIntFunction; +import org.elasticsearch.index.codec.vectors.SampleReader; import org.elasticsearch.simdvec.ESVectorUtil; import java.io.IOException; @@ -74,12 +76,12 @@ static float[][] pickInitialCentroids(FloatVectorValues vectors, int centroidCou return centroids; } - private boolean stepLloyd( + private static boolean stepLloyd( FloatVectorValues vectors, + IntToIntFunction translateOrd, float[][] centroids, float[][] nextCentroids, int[] assignments, - int sampleSize, List neighborhoods ) throws IOException { boolean changed = false; @@ -90,9 +92,10 @@ private boolean stepLloyd( Arrays.fill(nextCentroid, 0.0f); } - for (int i = 0; i < sampleSize; i++) { - float[] vector = vectors.vectorValue(i); - final int assignment = assignments[i]; + for (int idx = 0; idx < vectors.size(); idx++) { + float[] vector = vectors.vectorValue(idx); + int vectorOrd = translateOrd.apply(idx); + final int assignment = assignments[vectorOrd]; final int bestCentroidOffset; if (neighborhoods != null) { bestCentroidOffset = getBestCentroidFromNeighbours(centroids, vector, assignment, neighborhoods.get(assignment)); @@ -100,7 +103,7 @@ private boolean stepLloyd( bestCentroidOffset = getBestCentroid(centroids, vector); } if (assignment != bestCentroidOffset) { - assignments[i] = bestCentroidOffset; + assignments[vectorOrd] = bestCentroidOffset; changed = true; } centroidCounts[bestCentroidOffset]++; @@ -121,7 +124,7 @@ private boolean stepLloyd( return changed; } - int getBestCentroidFromNeighbours(float[][] centroids, float[] vector, int centroidIdx, int[] centroidOffsets) { + private static int getBestCentroidFromNeighbours(float[][] centroids, float[] vector, int centroidIdx, int[] centroidOffsets) { int bestCentroidOffset = centroidIdx; assert centroidIdx >= 0 && centroidIdx < centroids.length; float minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]); @@ -135,7 +138,7 @@ int getBestCentroidFromNeighbours(float[][] centroids, float[] vector, int centr return bestCentroidOffset; } - int getBestCentroid(float[][] centroids, float[] vector) { + private static int getBestCentroid(float[][] centroids, float[] vector) { int bestCentroidOffset = 0; float minDsq = Float.MAX_VALUE; for (int i = 0; i < centroids.length; i++) { @@ -281,7 +284,7 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, b } } - void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, List neighborhoods) throws IOException { + private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, List neighborhoods) throws IOException { float[][] centroids = kMeansIntermediate.centroids(); int k = centroids.length; int n = vectors.size(); @@ -289,16 +292,26 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, L if (k == 1 || k >= n) { return; } - + IntToIntFunction translateOrd = i -> i; + FloatVectorValues sampledVectors = vectors; + if (sampleSize < n) { + sampledVectors = SampleReader.createSampleReader(vectors, sampleSize, 42L); + translateOrd = sampledVectors::ordToDoc; + } int[] assignments = kMeansIntermediate.assignments(); assert assignments.length == n; float[][] nextCentroids = new float[centroids.length][vectors.dimension()]; for (int i = 0; i < maxIterations; i++) { - if (stepLloyd(vectors, centroids, nextCentroids, assignments, sampleSize, neighborhoods) == false) { + // This is potentially sampled, so we need to translate ordinals + if (stepLloyd(sampledVectors, translateOrd, centroids, nextCentroids, assignments, neighborhoods) == false) { break; } } - stepLloyd(vectors, centroids, nextCentroids, assignments, vectors.size(), neighborhoods); + // If we were sampled, do a once over the full set of vectors to finalize the centroids + if (sampleSize < n) { + // No ordinal translation needed here, we are using the full set of vectors + stepLloyd(vectors, i -> i, centroids, nextCentroids, assignments, neighborhoods); + } } /**