Add the current count of vectors in a cluster in hierarchical k-means #132587

Open · wants to merge 2 commits into main
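In short, this change makes the hierarchical k-means code track how many vectors land in each cluster while it assigns them, and threads that count through KMeansResult and CentroidAssignments so the IVF writers no longer re-derive it from the assignment arrays. A minimal sketch of the bookkeeping being centralized, mirroring the counting loops removed from the writers below (class and method names are illustrative, not from this PR):

```java
import java.util.Arrays;

// Illustrative sketch only: this is the per-centroid counting that the writers
// previously recomputed from the assignment arrays on every flush/merge, and
// that the k-means code now maintains as it assigns vectors.
final class CentroidCountSketch {

    static int[] countVectorsPerCentroid(int numCentroids, int[] assignments, int[] overspillAssignments) {
        int[] counts = new int[numCentroids];
        for (int i = 0; i < assignments.length; i++) {
            counts[assignments[i]]++;
            // if SOAR (overspill) assignments are present, count them as well
            if (overspillAssignments.length > i && overspillAssignments[i] != -1) {
                counts[overspillAssignments[i]]++;
            }
        }
        return counts;
    }

    public static void main(String[] args) {
        int[] assignments = { 0, 1, 1, 2 };
        int[] overspill = { 1, -1, 2, 0 };
        System.out.println(Arrays.toString(countVectorsPerCentroid(3, assignments, overspill))); // [2, 3, 2]
    }
}
```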
@@ -9,11 +9,18 @@

package org.elasticsearch.index.codec.vectors;

record CentroidAssignments(int numCentroids, float[][] centroids, int[] assignments, int[] overspillAssignments) {
record CentroidAssignments(
int numCentroids,
float[][] centroids,
int[] assignments,
int[] overspillAssignments,
int[] centroidVectorCount
) {

CentroidAssignments(float[][] centroids, int[] assignments, int[] overspillAssignments) {
this(centroids.length, centroids, assignments, overspillAssignments);
CentroidAssignments(float[][] centroids, int[] assignments, int[] overspillAssignments, int[] centroidVectorCount) {
this(centroids.length, centroids, assignments, overspillAssignments, centroidVectorCount);
assert assignments.length == overspillAssignments.length || overspillAssignments.length == 0
: "assignments and overspillAssignments must have the same length";
assert centroids.length == centroidVectorCount.length : "centroids and counts must have the same length";
}
}
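A hedged usage sketch of the extended record: the four-argument constructor infers numCentroids from the centroids array and asserts that the counts line up with the centroids. It assumes same-package access (the record is package-private) and uses illustrative values:

```java
// Inside a same-package test method; real callers get these arrays from HierarchicalKMeans.
float[][] centroids = { { 0.1f, 0.2f }, { 0.9f, 0.8f } };
int[] assignments = { 0, 1, 1 };
int[] overspill = new int[0];   // no SOAR overspill in this sketch
int[] counts = { 1, 2 };        // one entry per centroid

CentroidAssignments ca = new CentroidAssignments(centroids, assignments, overspill, counts);
assert ca.numCentroids() == 2;
assert ca.centroidVectorCount()[1] == 2;
```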
@@ -66,16 +66,9 @@ LongValues buildAndWritePostingsLists(
IndexOutput postingsOutput,
long fileOffset,
int[] assignments,
int[] overspillAssignments
int[] overspillAssignments,
int[] centroidVectorCount
) throws IOException {
int[] centroidVectorCount = new int[centroidSupplier.size()];
for (int i = 0; i < assignments.length; i++) {
centroidVectorCount[assignments[i]]++;
// if soar assignments are present, count them as well
if (overspillAssignments.length > i && overspillAssignments[i] != -1) {
centroidVectorCount[overspillAssignments[i]]++;
}
}

int maxPostingListSize = 0;
int[][] assignmentsByCluster = new int[centroidSupplier.size()][];
@@ -146,7 +139,8 @@ LongValues buildAndWritePostingsLists(
long fileOffset,
MergeState mergeState,
int[] assignments,
int[] overspillAssignments
int[] overspillAssignments,
int[] centroidVectorCount
) throws IOException {
// first, quantize all the vectors into a temporary file
String quantizedVectorsTempName = null;
@@ -194,14 +188,6 @@ LongValues buildAndWritePostingsLists(
mergeState.segmentInfo.dir.deleteFile(quantizedVectorsTemp.getName());
}
}
int[] centroidVectorCount = new int[centroidSupplier.size()];
for (int i = 0; i < assignments.length; i++) {
centroidVectorCount[assignments[i]]++;
// if soar assignments are present, count them as well
if (overspillAssignments.length > i && overspillAssignments[i] != -1) {
centroidVectorCount[overspillAssignments[i]]++;
}
}

int maxPostingListSize = 0;
int[][] assignmentsByCluster = new int[centroidSupplier.size()][];
@@ -423,10 +409,7 @@ public int size() {
HierarchicalKMeans.MAXK,
-1 // disable SOAR assignments
).cluster(floatVectorValues, centroidsPerParentCluster);
final int[] centroidVectorCount = new int[kMeansResult.centroids().length];
for (int i = 0; i < kMeansResult.assignments().length; i++) {
centroidVectorCount[kMeansResult.assignments()[i]]++;
}
final int[] centroidVectorCount = kMeansResult.centroidCounts();
final int[][] vectorsPerCentroid = new int[kMeansResult.centroids().length][];
int maxVectorsPerCentroidLength = 0;
for (int i = 0; i < kMeansResult.centroids().length; i++) {
@@ -481,10 +464,12 @@ CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues fl

static CentroidAssignments buildCentroidAssignments(FloatVectorValues floatVectorValues, int vectorPerCluster) throws IOException {
KMeansResult kMeansResult = new HierarchicalKMeans(floatVectorValues.dimension()).cluster(floatVectorValues, vectorPerCluster);
float[][] centroids = kMeansResult.centroids();
int[] assignments = kMeansResult.assignments();
int[] soarAssignments = kMeansResult.soarAssignments();
return new CentroidAssignments(centroids, assignments, soarAssignments);
return new CentroidAssignments(
kMeansResult.centroids(),
kMeansResult.assignments(),
kMeansResult.soarAssignments(),
kMeansResult.centroidCounts()
);
}

static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections)
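One reason to pass the counts in, sketched under the assumption that they already include SOAR overspill (as the removed loops above did): the per-cluster posting arrays can be allocated to their exact sizes in a single pass. The helper below is illustrative, not the writer's actual code:

```java
// Sketch: group vector ordinals by centroid using precomputed sizes.
static int[][] groupByCentroid(int numCentroids, int[] assignments, int[] overspillAssignments, int[] centroidVectorCount) {
    int[][] byCluster = new int[numCentroids][];
    for (int c = 0; c < numCentroids; c++) {
        byCluster[c] = new int[centroidVectorCount[c]];   // exact size, no second counting pass
    }
    int[] fill = new int[numCentroids];
    for (int ord = 0; ord < assignments.length; ord++) {
        int c = assignments[ord];
        byCluster[c][fill[c]++] = ord;
        if (overspillAssignments.length > ord && overspillAssignments[ord] != -1) {
            int o = overspillAssignments[ord];
            byCluster[o][fill[o]++] = ord;
        }
    }
    return byCluster;
}
```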
@@ -138,7 +138,8 @@ abstract LongValues buildAndWritePostingsLists(
IndexOutput postingsOutput,
long fileOffset,
int[] assignments,
int[] overspillAssignments
int[] overspillAssignments,
int[] centroidVectorCount
) throws IOException;

abstract LongValues buildAndWritePostingsLists(
@@ -149,7 +150,8 @@ abstract LongValues buildAndWritePostingsLists(
long fileOffset,
MergeState mergeState,
int[] assignments,
int[] overspillAssignments
int[] overspillAssignments,
int[] centroidVectorCount
) throws IOException;

abstract CentroidSupplier createCentroidSupplier(
@@ -179,7 +181,8 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
ivfClusters,
postingListOffset,
centroidAssignments.assignments(),
centroidAssignments.overspillAssignments()
centroidAssignments.overspillAssignments(),
centroidAssignments.centroidVectorCount()
);
final long postingListLength = ivfClusters.getFilePointer() - postingListOffset;
// write centroids
@@ -306,6 +309,7 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
final int numCentroids;
final int[] assignments;
final int[] overspillAssignments;
final int[] centroidVectorCount;
final float[] calculatedGlobalCentroid = new float[fieldInfo.getVectorDimension()];
String centroidTempName = null;
IndexOutput centroidTemp = null;
@@ -327,6 +331,7 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
numCentroids = centroidAssignments.numCentroids();
assignments = centroidAssignments.assignments();
overspillAssignments = centroidAssignments.overspillAssignments();
centroidVectorCount = centroidAssignments.centroidVectorCount();
success = true;
} finally {
if (success == false && centroidTempName != null) {
@@ -362,7 +367,8 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
postingListOffset,
mergeState,
assignments,
overspillAssignments
overspillAssignments,
centroidVectorCount
);
postingListLength = ivfClusters.getFilePointer() - postingListOffset;
// write centroids
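As a hedged data-flow note: the flush path passes centroidAssignments.centroidVectorCount() and the merge path passes the captured centroidVectorCount into the same posting-list builder parameter, so anything derived from cluster sizes can come straight from the array that clustering produced. A tiny illustrative sketch (placeholder names, not the real writer API):

```java
import java.util.Arrays;

// Placeholder sketch: both write paths now receive the same precomputed counts.
final class PostingsSizingSketch {
    static long totalPostings(int[] centroidVectorCount) {
        // counts already include SOAR overspill, so the posting-list total is a plain sum
        return Arrays.stream(centroidVectorCount).asLongStream().sum();
    }

    public static void main(String[] args) {
        System.out.println(totalPostings(new int[] { 2, 3, 2 })); // 7
    }
}
```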
@@ -10,6 +10,7 @@
package org.elasticsearch.index.codec.vectors.cluster;

import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.util.ArrayUtil;

import java.io.IOException;
import java.util.Arrays;
@@ -72,7 +73,7 @@ public KMeansResult cluster(FloatVectorValues vectors, int targetSize) throws IO
for (int j = 0; j < dimension; j++) {
centroid[j] /= vectors.size();
}
return new KMeansIntermediate(new float[][] { centroid }, new int[vectors.size()]);
return new KMeansIntermediate(new float[][] { centroid }, new int[vectors.size()], new int[] { vectors.size() });
}

// partition the space
@@ -100,27 +101,26 @@ KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int ta
Arrays.fill(assignments, -1);
KMeansLocal kmeans = new KMeansLocal(m, maxIterations);
float[][] centroids = KMeansLocal.pickInitialCentroids(vectors, k);
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, vectors::ordToDoc);
int[] centroidVectorCount = new int[centroids.length];
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, centroidVectorCount, vectors::ordToDoc);
kmeans.cluster(vectors, kMeansIntermediate);

// TODO: consider adding cluster size counts to the kmeans algo
// handle assignment here so we can track distance and cluster size
int[] centroidVectorCount = new int[centroids.length];
int effectiveCluster = -1;
int effectiveK = 0;
for (int assigment : assignments) {
centroidVectorCount[assigment]++;
// this cluster has received an assignment, its now effective, but only count it once
if (centroidVectorCount[assigment] == 1) {
int effectiveCluster = -1;
for (int i = 0; i < centroidVectorCount.length; i++) {
if (centroidVectorCount[i] > 0) {
effectiveK++;
effectiveCluster = assigment;
if (effectiveK > 1) {
break;
}
effectiveCluster = i;
}
}

if (effectiveK == 1) {
final float[][] singleClusterCentroid = new float[1][];
singleClusterCentroid[0] = centroids[effectiveCluster];
kMeansIntermediate.setCentroids(singleClusterCentroid);
kMeansIntermediate.setCentroidsAndCounts(singleClusterCentroid, new int[] { vectors.size() });
Arrays.fill(kMeansIntermediate.assignments(), 0);
return kMeansIntermediate;
}
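Because removed and added lines are interleaved above, here is a hedged restatement of the new shape of this check as a standalone helper: the counts array is filled during the clustering call itself, so the number of effective (non-empty) clusters can be read straight off it, with an early exit once a second non-empty cluster is found:

```java
// Sketch: returns the single non-empty cluster, or -1 if there are zero or
// more than one non-empty clusters (in which case the caller keeps them all).
static int singleEffectiveClusterOrMinusOne(int[] centroidVectorCount) {
    int effectiveK = 0;
    int effectiveCluster = -1;
    for (int i = 0; i < centroidVectorCount.length; i++) {
        if (centroidVectorCount[i] > 0) {
            effectiveK++;
            if (effectiveK > 1) {
                return -1;
            }
            effectiveCluster = i;
        }
    }
    return effectiveCluster;
}
```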
@@ -149,13 +149,22 @@ KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int ta
adjustedCentroid,
newSize - adjustedCentroid
);
final int[] newCounts = new int[newSize];
System.arraycopy(kMeansIntermediate.centroidCounts(), 0, newCounts, 0, adjustedCentroid);
System.arraycopy(
kMeansIntermediate.centroidCounts(),
adjustedCentroid + 1,
newCounts,
adjustedCentroid,
newSize - adjustedCentroid
);
// we need to update the assignments to reflect the new centroid ordinals
for (int i = 0; i < kMeansIntermediate.assignments().length; i++) {
if (kMeansIntermediate.assignments()[i] > adjustedCentroid) {
kMeansIntermediate.assignments()[i]--;
}
}
kMeansIntermediate.setCentroids(newCentroids);
kMeansIntermediate.setCentroidsAndCounts(newCentroids, newCounts);
removedElements++;
}
}
@@ -184,18 +193,26 @@ void updateAssignmentsWithRecursiveSplit(KMeansIntermediate current, int cluster
int newCentroidsSize = current.centroids().length + subPartitions.centroids().length - 1;

// update based on the outcomes from the split clusters recursion
float[][] newCentroids = new float[newCentroidsSize][];
System.arraycopy(current.centroids(), 0, newCentroids, 0, current.centroids().length);
final float[][] newCentroids = ArrayUtil.growExact(current.centroids(), newCentroidsSize);
final int[] newCounts = ArrayUtil.growExact(current.centroidCounts(), newCentroidsSize);

// replace the original cluster
int origCentroidOrd = 0;
newCentroids[cluster] = subPartitions.centroids()[0];
newCounts[cluster] = subPartitions.centroidCounts()[0];

// append the remainder
System.arraycopy(subPartitions.centroids(), 1, newCentroids, current.centroids().length, subPartitions.centroids().length - 1);
assert Arrays.stream(newCentroids).allMatch(Objects::nonNull);
System.arraycopy(
subPartitions.centroidCounts(),
1,
newCounts,
current.centroidCounts().length,
subPartitions.centroidCounts().length - 1
);

current.setCentroids(newCentroids);
current.setCentroidsAndCounts(newCentroids, newCounts);

for (int i = 0; i < subPartitions.assignments().length; i++) {
// this is a new centroid that was added, and so we'll need to remap it
@@ -205,5 +222,22 @@ void updateAssignmentsWithRecursiveSplit(KMeansIntermediate current, int cluster
current.assignments()[parentOrd] = subPartitions.assignments()[i] + orgCentroidsSize - 1;
}
}
assert assertCounts(newCounts, current.assignments());
}

private static boolean assertCounts(int[] counts, int[] assignments) {
int[] newCounts = new int[counts.length];
for (int assignment : assignments) {
if (assignment != -1) {
newCounts[assignment]++;
}
}
for (int i = 0; i < counts.length; i++) {
if (counts[i] != newCounts[i]) {
return false;
}
}
return true;
}

}
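For the empty-centroid removal and the recursive-split bookkeeping above, a hedged standalone sketch of the count maintenance: the removed centroid's count entry is spliced out and every assignment pointing past it shifts down by one, which is exactly the invariant assertCounts checks. Class and method names are illustrative:

```java
import java.util.Arrays;

final class CountSpliceSketch {

    // Remove centroid 'removed' (known to be empty) from counts and remap
    // assignments so ordinals above it shift down by one.
    static int[] removeEmptyCentroid(int[] counts, int[] assignments, int removed) {
        int[] newCounts = new int[counts.length - 1];
        System.arraycopy(counts, 0, newCounts, 0, removed);
        System.arraycopy(counts, removed + 1, newCounts, removed, counts.length - removed - 1);
        for (int i = 0; i < assignments.length; i++) {
            if (assignments[i] > removed) {
                assignments[i]--;
            }
        }
        return newCounts;
    }

    public static void main(String[] args) {
        int[] counts = { 2, 0, 3 };
        int[] assignments = { 0, 0, 2, 2, 2 };
        int[] spliced = removeEmptyCentroid(counts, assignments, 1);
        System.out.println(Arrays.toString(spliced));      // [2, 3]
        System.out.println(Arrays.toString(assignments));  // [0, 0, 1, 1, 1]
    }
}
```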
@@ -17,22 +17,28 @@
class KMeansIntermediate extends KMeansResult {
private final IntToIntFunction assignmentOrds;

private KMeansIntermediate(float[][] centroids, int[] assignments, IntToIntFunction assignmentOrds, int[] soarAssignments) {
super(centroids, assignments, soarAssignments);
private KMeansIntermediate(
float[][] centroids,
int[] assignments,
int[] counts,
IntToIntFunction assignmentOrds,
int[] soarAssignments
) {
super(centroids, assignments, soarAssignments, counts);
assert assignmentOrds != null;
this.assignmentOrds = assignmentOrds;
}

KMeansIntermediate(float[][] centroids, int[] assignments, IntToIntFunction assignmentOrdinals) {
this(centroids, assignments, assignmentOrdinals, new int[0]);
KMeansIntermediate(float[][] centroids, int[] assignments, int[] counts, IntToIntFunction assignmentOrdinals) {
this(centroids, assignments, counts, assignmentOrdinals, new int[0]);
}

KMeansIntermediate() {
this(new float[0][0], new int[0], i -> i, new int[0]);
this(new float[0][0], new int[0], new int[0], i -> i, new int[0]);
}

KMeansIntermediate(float[][] centroids, int[] assignments) {
this(centroids, assignments, i -> i, new int[0]);
KMeansIntermediate(float[][] centroids, int[] assignments, int[] counts) {
this(centroids, assignments, counts, i -> i, new int[0]);
}

public int ordToDoc(int ord) {
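A hedged construction sketch for the extended intermediate, assuming same-package access to the package-private constructors shown above and that the counts are exposed through KMeansResult.centroidCounts() (as the writer changes earlier in this diff use). The values and the identity ordinal mapping are illustrative:

```java
// Inside a same-package test method.
float[][] centroids = { { 0f, 0f }, { 1f, 1f } };
int[] assignments = { 0, 1, 1 };
int[] counts = { 1, 2 };

KMeansIntermediate intermediate = new KMeansIntermediate(centroids, assignments, counts, ord -> ord);
assert intermediate.centroidCounts()[1] == 2;
assert intermediate.ordToDoc(5) == 5; // identity mapping in this sketch
```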