diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java index e92ece41077a6..f7192fd83b08d 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java @@ -9,11 +9,18 @@ package org.elasticsearch.index.codec.vectors; -record CentroidAssignments(int numCentroids, float[][] centroids, int[] assignments, int[] overspillAssignments) { +record CentroidAssignments( + int numCentroids, + float[][] centroids, + int[] assignments, + int[] overspillAssignments, + int[] centroidVectorCount +) { - CentroidAssignments(float[][] centroids, int[] assignments, int[] overspillAssignments) { - this(centroids.length, centroids, assignments, overspillAssignments); + CentroidAssignments(float[][] centroids, int[] assignments, int[] overspillAssignments, int[] centroidVectorCount) { + this(centroids.length, centroids, assignments, overspillAssignments, centroidVectorCount); assert assignments.length == overspillAssignments.length || overspillAssignments.length == 0 : "assignments and overspillAssignments must have the same length"; + assert centroids.length == centroidVectorCount.length : "centroids and counts must have the same length"; } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java index d16163d6934e8..b32113b7fb5b0 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java @@ -66,16 +66,9 @@ LongValues buildAndWritePostingsLists( IndexOutput postingsOutput, long fileOffset, int[] assignments, - int[] overspillAssignments + int[] overspillAssignments, + int[] centroidVectorCount ) throws IOException { - int[] centroidVectorCount = new int[centroidSupplier.size()]; - for (int i = 0; i < assignments.length; i++) { - centroidVectorCount[assignments[i]]++; - // if soar assignments are present, count them as well - if (overspillAssignments.length > i && overspillAssignments[i] != -1) { - centroidVectorCount[overspillAssignments[i]]++; - } - } int maxPostingListSize = 0; int[][] assignmentsByCluster = new int[centroidSupplier.size()][]; @@ -146,7 +139,8 @@ LongValues buildAndWritePostingsLists( long fileOffset, MergeState mergeState, int[] assignments, - int[] overspillAssignments + int[] overspillAssignments, + int[] centroidVectorCount ) throws IOException { // first, quantize all the vectors into a temporary file String quantizedVectorsTempName = null; @@ -194,14 +188,6 @@ LongValues buildAndWritePostingsLists( mergeState.segmentInfo.dir.deleteFile(quantizedVectorsTemp.getName()); } } - int[] centroidVectorCount = new int[centroidSupplier.size()]; - for (int i = 0; i < assignments.length; i++) { - centroidVectorCount[assignments[i]]++; - // if soar assignments are present, count them as well - if (overspillAssignments.length > i && overspillAssignments[i] != -1) { - centroidVectorCount[overspillAssignments[i]]++; - } - } int maxPostingListSize = 0; int[][] assignmentsByCluster = new int[centroidSupplier.size()][]; @@ -423,10 +409,7 @@ public int size() { HierarchicalKMeans.MAXK, -1 // disable SOAR assignments ).cluster(floatVectorValues, centroidsPerParentCluster); - final int[] centroidVectorCount = new int[kMeansResult.centroids().length]; - for (int i = 0; i < kMeansResult.assignments().length; i++) { - centroidVectorCount[kMeansResult.assignments()[i]]++; - } + final int[] centroidVectorCount = kMeansResult.centroidCounts(); final int[][] vectorsPerCentroid = new int[kMeansResult.centroids().length][]; int maxVectorsPerCentroidLength = 0; for (int i = 0; i < kMeansResult.centroids().length; i++) { @@ -481,10 +464,12 @@ CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues fl static CentroidAssignments buildCentroidAssignments(FloatVectorValues floatVectorValues, int vectorPerCluster) throws IOException { KMeansResult kMeansResult = new HierarchicalKMeans(floatVectorValues.dimension()).cluster(floatVectorValues, vectorPerCluster); - float[][] centroids = kMeansResult.centroids(); - int[] assignments = kMeansResult.assignments(); - int[] soarAssignments = kMeansResult.soarAssignments(); - return new CentroidAssignments(centroids, assignments, soarAssignments); + return new CentroidAssignments( + kMeansResult.centroids(), + kMeansResult.assignments(), + kMeansResult.soarAssignments(), + kMeansResult.centroidCounts() + ); } static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java index 308ee391b5f4a..9c949e1ec05e8 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java @@ -138,7 +138,8 @@ abstract LongValues buildAndWritePostingsLists( IndexOutput postingsOutput, long fileOffset, int[] assignments, - int[] overspillAssignments + int[] overspillAssignments, + int[] centroidVectorCount ) throws IOException; abstract LongValues buildAndWritePostingsLists( @@ -149,7 +150,8 @@ abstract LongValues buildAndWritePostingsLists( long fileOffset, MergeState mergeState, int[] assignments, - int[] overspillAssignments + int[] overspillAssignments, + int[] centroidVectorCount ) throws IOException; abstract CentroidSupplier createCentroidSupplier( @@ -179,7 +181,8 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { ivfClusters, postingListOffset, centroidAssignments.assignments(), - centroidAssignments.overspillAssignments() + centroidAssignments.overspillAssignments(), + centroidAssignments.centroidVectorCount() ); final long postingListLength = ivfClusters.getFilePointer() - postingListOffset; // write centroids @@ -306,6 +309,7 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws final int numCentroids; final int[] assignments; final int[] overspillAssignments; + final int[] centroidVectorCount; final float[] calculatedGlobalCentroid = new float[fieldInfo.getVectorDimension()]; String centroidTempName = null; IndexOutput centroidTemp = null; @@ -327,6 +331,7 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws numCentroids = centroidAssignments.numCentroids(); assignments = centroidAssignments.assignments(); overspillAssignments = centroidAssignments.overspillAssignments(); + centroidVectorCount = centroidAssignments.centroidVectorCount(); success = true; } finally { if (success == false && centroidTempName != null) { @@ -362,7 +367,8 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws postingListOffset, mergeState, assignments, - overspillAssignments + overspillAssignments, + centroidVectorCount ); postingListLength = ivfClusters.getFilePointer() - postingListOffset; // write centroids diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java index de654fb851554..93dc01a450178 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java @@ -10,6 +10,7 @@ package org.elasticsearch.index.codec.vectors.cluster; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.util.ArrayUtil; import java.io.IOException; import java.util.Arrays; @@ -72,7 +73,7 @@ public KMeansResult cluster(FloatVectorValues vectors, int targetSize) throws IO for (int j = 0; j < dimension; j++) { centroid[j] /= vectors.size(); } - return new KMeansIntermediate(new float[][] { centroid }, new int[vectors.size()]); + return new KMeansIntermediate(new float[][] { centroid }, new int[vectors.size()], new int[] { vectors.size() }); } // partition the space @@ -100,27 +101,26 @@ KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int ta Arrays.fill(assignments, -1); KMeansLocal kmeans = new KMeansLocal(m, maxIterations); float[][] centroids = KMeansLocal.pickInitialCentroids(vectors, k); - KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, vectors::ordToDoc); + int[] centroidVectorCount = new int[centroids.length]; + KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, centroidVectorCount, vectors::ordToDoc); kmeans.cluster(vectors, kMeansIntermediate); - // TODO: consider adding cluster size counts to the kmeans algo - // handle assignment here so we can track distance and cluster size - int[] centroidVectorCount = new int[centroids.length]; - int effectiveCluster = -1; int effectiveK = 0; - for (int assigment : assignments) { - centroidVectorCount[assigment]++; - // this cluster has received an assignment, its now effective, but only count it once - if (centroidVectorCount[assigment] == 1) { + int effectiveCluster = -1; + for (int i = 0; i < centroidVectorCount.length; i++) { + if (centroidVectorCount[i] > 0) { effectiveK++; - effectiveCluster = assigment; + if (effectiveK > 1) { + break; + } + effectiveCluster = i; } } if (effectiveK == 1) { final float[][] singleClusterCentroid = new float[1][]; singleClusterCentroid[0] = centroids[effectiveCluster]; - kMeansIntermediate.setCentroids(singleClusterCentroid); + kMeansIntermediate.setCentroidsAndCounts(singleClusterCentroid, new int[] { vectors.size() }); Arrays.fill(kMeansIntermediate.assignments(), 0); return kMeansIntermediate; } @@ -149,13 +149,22 @@ KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int ta adjustedCentroid, newSize - adjustedCentroid ); + final int[] newCounts = new int[newSize]; + System.arraycopy(kMeansIntermediate.centroidCounts(), 0, newCounts, 0, adjustedCentroid); + System.arraycopy( + kMeansIntermediate.centroidCounts(), + adjustedCentroid + 1, + newCounts, + adjustedCentroid, + newSize - adjustedCentroid + ); // we need to update the assignments to reflect the new centroid ordinals for (int i = 0; i < kMeansIntermediate.assignments().length; i++) { if (kMeansIntermediate.assignments()[i] > adjustedCentroid) { kMeansIntermediate.assignments()[i]--; } } - kMeansIntermediate.setCentroids(newCentroids); + kMeansIntermediate.setCentroidsAndCounts(newCentroids, newCounts); removedElements++; } } @@ -184,18 +193,26 @@ void updateAssignmentsWithRecursiveSplit(KMeansIntermediate current, int cluster int newCentroidsSize = current.centroids().length + subPartitions.centroids().length - 1; // update based on the outcomes from the split clusters recursion - float[][] newCentroids = new float[newCentroidsSize][]; - System.arraycopy(current.centroids(), 0, newCentroids, 0, current.centroids().length); + final float[][] newCentroids = ArrayUtil.growExact(current.centroids(), newCentroidsSize); + final int[] newCounts = ArrayUtil.growExact(current.centroidCounts(), newCentroidsSize); // replace the original cluster int origCentroidOrd = 0; newCentroids[cluster] = subPartitions.centroids()[0]; + newCounts[cluster] = subPartitions.centroidCounts()[0]; // append the remainder System.arraycopy(subPartitions.centroids(), 1, newCentroids, current.centroids().length, subPartitions.centroids().length - 1); assert Arrays.stream(newCentroids).allMatch(Objects::nonNull); + System.arraycopy( + subPartitions.centroidCounts(), + 1, + newCounts, + current.centroidCounts().length, + subPartitions.centroidCounts().length - 1 + ); - current.setCentroids(newCentroids); + current.setCentroidsAndCounts(newCentroids, newCounts); for (int i = 0; i < subPartitions.assignments().length; i++) { // this is a new centroid that was added, and so we'll need to remap it @@ -205,5 +222,22 @@ void updateAssignmentsWithRecursiveSplit(KMeansIntermediate current, int cluster current.assignments()[parentOrd] = subPartitions.assignments()[i] + orgCentroidsSize - 1; } } + assert assertCounts(newCounts, current.assignments()); + } + + private static boolean assertCounts(int[] counts, int[] assignments) { + int[] newCounts = new int[counts.length]; + for (int assignment : assignments) { + if (assignment != -1) { + newCounts[assignment]++; + } + } + for (int i = 0; i < counts.length; i++) { + if (counts[i] != newCounts[i]) { + return false; + } + } + return true; } + } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansIntermediate.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansIntermediate.java index e44112610812a..76d3144c7a75e 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansIntermediate.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansIntermediate.java @@ -17,22 +17,28 @@ class KMeansIntermediate extends KMeansResult { private final IntToIntFunction assignmentOrds; - private KMeansIntermediate(float[][] centroids, int[] assignments, IntToIntFunction assignmentOrds, int[] soarAssignments) { - super(centroids, assignments, soarAssignments); + private KMeansIntermediate( + float[][] centroids, + int[] assignments, + int[] counts, + IntToIntFunction assignmentOrds, + int[] soarAssignments + ) { + super(centroids, assignments, soarAssignments, counts); assert assignmentOrds != null; this.assignmentOrds = assignmentOrds; } - KMeansIntermediate(float[][] centroids, int[] assignments, IntToIntFunction assignmentOrdinals) { - this(centroids, assignments, assignmentOrdinals, new int[0]); + KMeansIntermediate(float[][] centroids, int[] assignments, int[] counts, IntToIntFunction assignmentOrdinals) { + this(centroids, assignments, counts, assignmentOrdinals, new int[0]); } KMeansIntermediate() { - this(new float[0][0], new int[0], i -> i, new int[0]); + this(new float[0][0], new int[0], new int[0], i -> i, new int[0]); } - KMeansIntermediate(float[][] centroids, int[] assignments) { - this(centroids, assignments, i -> i, new int[0]); + KMeansIntermediate(float[][] centroids, int[] assignments, int[] counts) { + this(centroids, assignments, counts, i -> i, new int[0]); } public int ordToDoc(int ord) { diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java index 744fd248b2a49..50ad96655e398 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java @@ -77,7 +77,6 @@ private static boolean stepLloyd( NeighborHood[] neighborhoods ) throws IOException { boolean changed = false; - int dim = vectors.dimension(); centroidChanged.clear(); final float[] distances = new float[4]; for (int idx = 0; idx < vectors.size(); idx++) { @@ -92,42 +91,15 @@ private static boolean stepLloyd( } if (assignment != bestCentroidOffset) { if (assignment != -1) { + centroidCounts[assignment]--; centroidChanged.set(assignment); } + centroidCounts[bestCentroidOffset]++; centroidChanged.set(bestCentroidOffset); assignments[vectorOrd] = bestCentroidOffset; changed = true; } } - if (changed) { - Arrays.fill(centroidCounts, 0); - for (int idx = 0; idx < vectors.size(); idx++) { - final int assignment = assignments[translateOrd.apply(idx)]; - if (centroidChanged.get(assignment)) { - float[] centroid = centroids[assignment]; - if (centroidCounts[assignment]++ == 0) { - Arrays.fill(centroid, 0.0f); - } - float[] vector = vectors.vectorValue(idx); - for (int d = 0; d < dim; d++) { - centroid[d] += vector[d]; - } - } - } - - for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) { - if (centroidChanged.get(clusterIdx)) { - float count = (float) centroidCounts[clusterIdx]; - if (count > 0) { - float[] centroid = centroids[clusterIdx]; - for (int d = 0; d < dim; d++) { - centroid[d] /= count; - } - } - } - } - } - return changed; } @@ -276,6 +248,7 @@ private void assignSpilled( assert spilledAssignments != null; assert spilledAssignments.length == vectors.size(); float[][] centroids = kmeansIntermediate.centroids(); + int[] counts = kmeansIntermediate.centroidCounts(); float[] diffs = new float[vectors.dimension()]; final float[] distances = new float[4]; @@ -359,6 +332,7 @@ private void assignSpilled( assert bestAssignment != -1 : "Failed to assign soar vector to centroid"; spilledAssignments[i] = bestAssignment; + counts[bestAssignment]++; } } @@ -424,9 +398,11 @@ private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansInterme int k = centroids.length; int n = vectors.size(); int[] assignments = kMeansIntermediate.assignments(); + int[] centroidCounts = kMeansIntermediate.centroidCounts(); if (k == 1) { Arrays.fill(assignments, 0); + centroidCounts[0] = assignments.length; return; } IntToIntFunction translateOrd = i -> i; @@ -438,33 +414,55 @@ private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansInterme assert assignments.length == n; FixedBitSet centroidChanged = new FixedBitSet(centroids.length); - int[] centroidCounts = new int[centroids.length]; for (int i = 0; i < maxIterations; i++) { // This is potentially sampled, so we need to translate ordinals if (stepLloyd(sampledVectors, translateOrd, centroids, centroidChanged, centroidCounts, assignments, neighborhoods) == false) { break; } + updateCentroids(vectors, centroids, assignments, centroidCounts, centroidChanged); } // If we were sampled, do a once over the full set of vectors to finalize the centroids if (sampleSize < n || maxIterations == 0) { // No ordinal translation needed here, we are using the full set of vectors - stepLloyd(vectors, i -> i, centroids, centroidChanged, centroidCounts, assignments, neighborhoods); + if (stepLloyd(vectors, i -> i, centroids, centroidChanged, centroidCounts, assignments, neighborhoods)) { + updateCentroids(vectors, centroids, assignments, centroidCounts, centroidChanged); + } } } - /** - * helper that calls {@link KMeansLocal#cluster(FloatVectorValues, KMeansIntermediate)} given a set of initialized centroids, - * this call is not neighbor aware - * - * @param vectors the vectors to cluster - * @param centroids the initialized centroids to be shifted using k-means - * @param sampleSize the subset of vectors to use when shifting centroids - * @param maxIterations the max iterations to shift centroids - */ - public static void cluster(FloatVectorValues vectors, float[][] centroids, int sampleSize, int maxIterations) throws IOException { - KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, new int[vectors.size()], vectors::ordToDoc); - KMeansLocal kMeans = new KMeansLocal(sampleSize, maxIterations); - kMeans.cluster(vectors, kMeansIntermediate); + private void updateCentroids( + FloatVectorValues vectors, + float[][] centroids, + int[] assignments, + int[] centroidCounts, + FixedBitSet centroidChanged + ) throws IOException { + int dim = vectors.dimension(); + for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) { + if (centroidChanged.get(clusterIdx) && centroidCounts[clusterIdx] > 0) { + Arrays.fill(centroids[clusterIdx], 0.0f); + } + } + for (int idx = 0; idx < assignments.length; idx++) { + final int assignment = assignments[idx]; + if (assignment != -1 && centroidChanged.get(assignment)) { + float[] centroid = centroids[assignment]; + float[] vector = vectors.vectorValue(idx); + for (int d = 0; d < dim; d++) { + centroid[d] += vector[d]; + } + } + } + for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) { + if (centroidChanged.get(clusterIdx)) { + float count = (float) centroidCounts[clusterIdx]; + if (count > 0) { + float[] centroid = centroids[clusterIdx]; + for (int d = 0; d < dim; d++) { + centroid[d] /= count; + } + } + } + } } - } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansResult.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansResult.java index 5c2f4afb03f1a..840f26bd1fb5d 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansResult.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansResult.java @@ -14,24 +14,31 @@ */ public class KMeansResult { private float[][] centroids; + private int[] centroidCounts; private final int[] assignments; private int[] soarAssignments; - KMeansResult(float[][] centroids, int[] assignments, int[] soarAssignments) { + KMeansResult(float[][] centroids, int[] assignments, int[] soarAssignments, int[] centroidCounts) { assert centroids != null; assert assignments != null; assert soarAssignments != null; this.centroids = centroids; this.assignments = assignments; this.soarAssignments = soarAssignments; + this.centroidCounts = centroidCounts; } public float[][] centroids() { return centroids; } - void setCentroids(float[][] centroids) { + public int[] centroidCounts() { + return centroidCounts; + } + + void setCentroidsAndCounts(float[][] centroids, int[] counts) { this.centroids = centroids; + this.centroidCounts = counts; } public int[] assignments() { diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeansTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeansTests.java index 0c2ce31f51339..f8ab65eef8ec0 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeansTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeansTests.java @@ -31,6 +31,7 @@ public void testHKmeans() throws IOException { HierarchicalKMeans hkmeans = new HierarchicalKMeans(dims, maxIterations, sampleSize, clustersPerNeighborhood, soarLambda); KMeansResult result = hkmeans.cluster(vectors, targetSize); + KMeansLocalTests.assertCounts(result); float[][] centroids = result.centroids(); int[] assignments = result.assignments(); diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocalTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocalTests.java index a2d34d28f3784..7512e5d2b1a33 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocalTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocalTests.java @@ -15,6 +15,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import static org.hamcrest.Matchers.containsString; @@ -23,7 +24,7 @@ public class KMeansLocalTests extends ESTestCase { public void testIllegalClustersPerNeighborhood() { KMeansLocal kMeansLocal = new KMeansLocal(randomInt(), randomInt()); - KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(new float[0][], new int[0], i -> i); + KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(new float[0][], new int[0], new int[0], i -> i); IllegalArgumentException ex = expectThrows( IllegalArgumentException.class, () -> kMeansLocal.cluster( @@ -47,10 +48,11 @@ public void testKMeansNeighbors() throws IOException { FloatVectorValues vectors = generateData(nVectors, dims, nClusters); float[][] centroids = KMeansLocal.pickInitialCentroids(vectors, nClusters); - KMeansLocal.cluster(vectors, centroids, sampleSize, maxIterations); + cluster(vectors, centroids, sampleSize, maxIterations); int[] assignments = new int[vectors.size()]; int[] assignmentOrdinals = new int[vectors.size()]; + int[] counts = new int[centroids.length]; for (int i = 0; i < vectors.size(); i++) { float minDist = Float.MAX_VALUE; int ord = -1; @@ -63,14 +65,16 @@ public void testKMeansNeighbors() throws IOException { } assignments[i] = ord; assignmentOrdinals[i] = i; + counts[ord]++; } - KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, i -> assignmentOrdinals[i]); + KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, counts, i -> assignmentOrdinals[i]); KMeansLocal kMeansLocal = new KMeansLocal(sampleSize, maxIterations); kMeansLocal.cluster(vectors, kMeansIntermediate, clustersPerNeighborhood, soarLambda); assertEquals(nClusters, centroids.length); assertNotNull(kMeansIntermediate.soarAssignments()); + assertCounts(kMeansIntermediate); } public void testKMeansNeighborsAllZero() throws IOException { @@ -88,10 +92,11 @@ public void testKMeansNeighborsAllZero() throws IOException { FloatVectorValues fvv = FloatVectorValues.fromFloats(vectors, 5); float[][] centroids = KMeansLocal.pickInitialCentroids(fvv, nClusters); - KMeansLocal.cluster(fvv, centroids, sampleSize, maxIterations); + cluster(fvv, centroids, sampleSize, maxIterations); int[] assignments = new int[vectors.size()]; int[] assignmentOrdinals = new int[vectors.size()]; + int[] counts = new int[centroids.length]; for (int i = 0; i < vectors.size(); i++) { float minDist = Float.MAX_VALUE; int ord = -1; @@ -104,9 +109,10 @@ public void testKMeansNeighborsAllZero() throws IOException { } assignments[i] = ord; assignmentOrdinals[i] = i; + counts[ord]++; } - KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, i -> assignmentOrdinals[i]); + KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, counts, i -> assignmentOrdinals[i]); KMeansLocal kMeansLocal = new KMeansLocal(sampleSize, maxIterations); kMeansLocal.cluster(fvv, kMeansIntermediate, clustersPerNeighborhood, soarLambda); @@ -119,6 +125,7 @@ public void testKMeansNeighborsAllZero() throws IOException { } } } + assertCounts(kMeansIntermediate); } private static FloatVectorValues generateData(int nSamples, int nDims, int nClusters) { @@ -141,4 +148,31 @@ private static FloatVectorValues generateData(int nSamples, int nDims, int nClus } return FloatVectorValues.fromFloats(vectors, nDims); } + + static void assertCounts(KMeansResult kMeansResult) { + int[] counts = new int[kMeansResult.centroidCounts().length]; + for (int i = 0; i < kMeansResult.assignments().length; i++) { + counts[kMeansResult.assignments()[i]]++; + if (kMeansResult.soarAssignments().length > i && kMeansResult.soarAssignments()[i] != -1) { + counts[kMeansResult.soarAssignments()[i]]++; + } + } + for (int i = 0; i < counts.length; i++) { + assertEquals(counts[i], kMeansResult.centroidCounts()[i]); + } + } + + private static void cluster(FloatVectorValues vectors, float[][] centroids, int sampleSize, int maxIterations) throws IOException { + int[] assignments = new int[vectors.size()]; + Arrays.fill(assignments, -1); + KMeansIntermediate kMeansIntermediate = new KMeansIntermediate( + centroids, + assignments, + new int[centroids.length], + vectors::ordToDoc + ); + KMeansLocal kMeans = new KMeansLocal(sampleSize, maxIterations); + kMeans.cluster(vectors, kMeansIntermediate); + assertCounts(kMeansIntermediate); + } }