Skip to content

Commit dd61ba5

Browse files
committed
split kmeansresult into two classes, updated centroid assignments interface to pass around a IntArrayList[] which reverts and simplifies some of the interfaces, added Ben's diff around simplifying ordinal resolution in hkmeans which greatly simplified a number of things with FFVSlices
1 parent 5112408 commit dd61ba5

File tree

11 files changed

+132
-127
lines changed

11 files changed

+132
-127
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,26 @@
99

1010
package org.elasticsearch.index.codec.vectors;
1111

12+
import org.apache.lucene.internal.hppc.IntArrayList;
13+
1214
final class CentroidAssignments {
1315

1416
private final int numCentroids;
1517
private final float[][] cachedCentroids;
16-
private final int[] assignments;
17-
private final int[] soarAssignments;
18+
private final IntArrayList[] assignmentsByCluster;
1819

19-
private CentroidAssignments(int numCentroids, float[][] cachedCentroids, int[] assignments, int[] soarAssignments) {
20+
private CentroidAssignments(int numCentroids, float[][] cachedCentroids, IntArrayList[] assignmentsByCluster) {
2021
this.numCentroids = numCentroids;
2122
this.cachedCentroids = cachedCentroids;
22-
this.assignments = assignments;
23-
this.soarAssignments = soarAssignments;
23+
this.assignmentsByCluster = assignmentsByCluster;
2424
}
2525

26-
CentroidAssignments(float[][] centroids, int[] assignments, int[] soarAssignments) {
27-
this(centroids.length, centroids, assignments, soarAssignments);
26+
CentroidAssignments(float[][] centroids, IntArrayList[] assignmentsByCluster) {
27+
this(centroids.length, centroids, assignmentsByCluster);
2828
}
2929

30-
CentroidAssignments(int numCentroids, int[] assignments, int[] soarAssignments) {
31-
this(numCentroids, null, assignments, soarAssignments);
30+
CentroidAssignments(int numCentroids, IntArrayList[] assignmentsByCluster) {
31+
this(numCentroids, null, assignmentsByCluster);
3232
}
3333

3434
// Getters and setters
@@ -40,11 +40,7 @@ public float[][] cachedCentroids() {
4040
return cachedCentroids;
4141
}
4242

43-
public int[] assignments() {
44-
return assignments;
45-
}
46-
47-
public int[] soarAssignments() {
48-
return soarAssignments;
43+
public IntArrayList[] assignmentsByCluster() {
44+
return assignmentsByCluster;
4945
}
5046
}

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java

Lines changed: 24 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -56,41 +56,19 @@ long[] buildAndWritePostingsLists(
5656
FloatVectorValues floatVectorValues,
5757
IndexOutput postingsOutput,
5858
InfoStream infoStream,
59-
CentroidAssignments centroidAssignments
59+
IntArrayList[] assignmentsByCluster
6060
) throws IOException {
61-
6261
// write the posting lists
6362
final long[] offsets = new long[centroidSupplier.size()];
6463
OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
6564
BinarizedFloatVectorValues binarizedByteVectorValues = new BinarizedFloatVectorValues(floatVectorValues, quantizer);
6665
DocIdsWriter docIdsWriter = new DocIdsWriter();
6766

68-
int[] assignments = centroidAssignments.assignments();
69-
int[] soarAssignments = centroidAssignments.soarAssignments();
70-
71-
IntArrayList[] clustersForMetrics = null;
72-
if (infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
73-
clustersForMetrics = new IntArrayList[centroidSupplier.size()];
74-
}
75-
7667
for (int c = 0; c < centroidSupplier.size(); c++) {
7768
float[] centroid = centroidSupplier.centroid(c);
7869
binarizedByteVectorValues.centroid = centroid;
79-
8070
// TODO: add back in sorting vectors by distance to centroid
81-
IntArrayList cluster = new IntArrayList(vectorPerCluster);
82-
for (int j = 0; j < assignments.length; j++) {
83-
if (assignments[j] == c) {
84-
cluster.add(j);
85-
}
86-
}
87-
88-
for (int j = 0; j < soarAssignments.length; j++) {
89-
if (soarAssignments[j] == c) {
90-
cluster.add(j);
91-
}
92-
}
93-
71+
IntArrayList cluster = assignmentsByCluster[c];
9472
// TODO align???
9573
offsets[c] = postingsOutput.getFilePointer();
9674
int size = cluster.size();
@@ -99,16 +77,12 @@ long[] buildAndWritePostingsLists(
9977
// TODO we might want to consider putting the docIds in a separate file
10078
// to aid with only having to fetch vectors from slower storage when they are required
10179
// keeping them in the same file indicates we pull the entire file into cache
102-
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster.get(j)), cluster.size(), postingsOutput);
80+
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster.get(j)), size, postingsOutput);
10381
writePostingList(cluster, postingsOutput, binarizedByteVectorValues);
104-
105-
if (infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
106-
clustersForMetrics[c] = cluster;
107-
}
10882
}
10983

11084
if (infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
111-
printClusterQualityStatistics(clustersForMetrics, infoStream);
85+
printClusterQualityStatistics(assignmentsByCluster, infoStream);
11286
}
11387

11488
return offsets;
@@ -302,10 +276,28 @@ CentroidAssignments calculateAndWriteCentroids(
302276
infoStream.message(IVF_VECTOR_COMPONENT, "final centroid count: " + centroids.length);
303277
}
304278

279+
IntArrayList[] assignmentsByCluster = new IntArrayList[centroids.length];
280+
for (int c = 0; c < centroids.length; c++) {
281+
IntArrayList cluster = new IntArrayList(vectorPerCluster);
282+
for (int j = 0; j < assignments.length; j++) {
283+
if (assignments[j] == c) {
284+
cluster.add(j);
285+
}
286+
}
287+
288+
for (int j = 0; j < soarAssignments.length; j++) {
289+
if (soarAssignments[j] == c) {
290+
cluster.add(j);
291+
}
292+
}
293+
294+
assignmentsByCluster[c] = cluster;
295+
}
296+
305297
if (cacheCentroids) {
306-
return new CentroidAssignments(centroids, assignments, soarAssignments);
298+
return new CentroidAssignments(centroids, assignmentsByCluster);
307299
} else {
308-
return new CentroidAssignments(centroids.length, assignments, soarAssignments);
300+
return new CentroidAssignments(centroids.length, assignmentsByCluster);
309301
}
310302
}
311303

server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.apache.lucene.index.Sorter;
2626
import org.apache.lucene.index.VectorEncoding;
2727
import org.apache.lucene.index.VectorSimilarityFunction;
28+
import org.apache.lucene.internal.hppc.IntArrayList;
2829
import org.apache.lucene.search.DocIdSetIterator;
2930
import org.apache.lucene.store.IOContext;
3031
import org.apache.lucene.store.IndexInput;
@@ -144,7 +145,7 @@ abstract long[] buildAndWritePostingsLists(
144145
FloatVectorValues floatVectorValues,
145146
IndexOutput postingsOutput,
146147
InfoStream infoStream,
147-
CentroidAssignments centroidAssignments
148+
IntArrayList[] assignmentsByCluster
148149
) throws IOException;
149150

150151
abstract CentroidSupplier createCentroidSupplier(
@@ -181,7 +182,7 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
181182
floatVectorValues,
182183
ivfClusters,
183184
segmentWriteState.infoStream,
184-
centroidAssignments
185+
centroidAssignments.assignmentsByCluster()
185186
);
186187
// write posting lists
187188
writeMeta(fieldWriter.fieldInfo, centroidOffset, centroidLength, offsets, globalCentroid);
@@ -324,7 +325,7 @@ public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) thro
324325
floatVectorValues,
325326
ivfClusters,
326327
mergeState.infoStream,
327-
centroidAssignments
328+
centroidAssignments.assignmentsByCluster()
328329
);
329330
assert offsets.length == centroidSupplier.size();
330331
writeMeta(fieldInfo, centroidOffset, centroidLength, offsets, calculatedGlobalCentroid);

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/FloatVectorValuesSlice.java

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import org.apache.lucene.index.FloatVectorValues;
1313

1414
import java.io.IOException;
15-
import java.util.stream.IntStream;
1615

1716
class FloatVectorValuesSlice extends FloatVectorValues {
1817

@@ -29,10 +28,6 @@ class FloatVectorValuesSlice extends FloatVectorValues {
2928
}
3029
}
3130

32-
FloatVectorValuesSlice(FloatVectorValues allValues) {
33-
this(allValues, null);
34-
}
35-
3631
@Override
3732
public float[] vectorValue(int ord) throws IOException {
3833
if (this.slice == null) {
@@ -56,11 +51,12 @@ public int size() {
5651
}
5752
}
5853

59-
public int[] slice() {
54+
@Override
55+
public int ordToDoc(int ord) {
6056
if (this.slice == null) {
61-
return IntStream.range(0, allValues.size()).toArray();
57+
return ord;
6258
} else {
63-
return this.slice;
59+
return this.slice[ord];
6460
}
6561
}
6662

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -51,34 +51,34 @@ public HierarchicalKMeans(int dimension) {
5151
* @return the centroids and the vectors assignments and SOAR (spilled from nearby neighborhoods) assignments
5252
* @throws IOException is thrown if vectors is inaccessible
5353
*/
54-
public KMeansResult cluster(FloatVectorValues vectors, int targetSize) throws IOException {
54+
public KMeansIntermediate cluster(FloatVectorValues vectors, int targetSize) throws IOException {
5555

5656
if (vectors.size() == 0) {
57-
return new KMeansResult();
57+
return new KMeansIntermediate();
5858
}
5959

6060
// if we have a small number of vectors pick one and output that as the centroid
6161
if (vectors.size() < targetSize) {
6262
float[] centroid = new float[dimension];
6363
System.arraycopy(vectors.vectorValue(0), 0, centroid, 0, dimension);
64-
return new KMeansResult(new float[][] { centroid }, new int[vectors.size()]);
64+
return new KMeansIntermediate(new float[][] { centroid }, new int[vectors.size()]);
6565
}
6666

6767
// partition the space
68-
KMeansResult kMeansResult = cluster(new FloatVectorValuesSlice(vectors), targetSize);
69-
if (kMeansResult.centroids().length > 1 && kMeansResult.centroids().length < vectors.size()) {
68+
KMeansIntermediate kMeansIntermediate = clusterAndSplit(vectors, targetSize);
69+
if (kMeansIntermediate.centroids().length > 1 && kMeansIntermediate.centroids().length < vectors.size()) {
7070
float f = Math.min((float) samplesPerCluster / targetSize, 1.0f);
7171
int localSampleSize = (int) (f * vectors.size());
7272
KMeansLocal kMeansLocal = new KMeansLocal(localSampleSize, maxIterations, clustersPerNeighborhood, DEFAULT_SOAR_LAMBDA);
73-
kMeansLocal.cluster(vectors, kMeansResult);
73+
kMeansLocal.cluster(vectors, kMeansIntermediate);
7474
}
7575

76-
return kMeansResult;
76+
return kMeansIntermediate;
7777
}
7878

79-
KMeansResult cluster(final FloatVectorValuesSlice vectors, final int targetSize) throws IOException {
79+
KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int targetSize) throws IOException {
8080
if (vectors.size() <= targetSize) {
81-
return new KMeansResult();
81+
return new KMeansIntermediate();
8282
}
8383

8484
int k = Math.clamp((int) ((vectors.size() + targetSize / 2.0f) / (float) targetSize), 2, MAXK);
@@ -89,8 +89,8 @@ KMeansResult cluster(final FloatVectorValuesSlice vectors, final int targetSize)
8989

9090
KMeans kmeans = new KMeans(m, maxIterations);
9191
float[][] centroids = KMeans.pickInitialCentroids(vectors, k);
92-
KMeansResult kMeansResult = new KMeansResult(centroids);
93-
kmeans.cluster(vectors, kMeansResult);
92+
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids);
93+
kmeans.cluster(vectors, kMeansIntermediate);
9494

9595
int[] clusterSizes = new int[centroids.length];
9696

@@ -134,30 +134,29 @@ KMeansResult cluster(final FloatVectorValuesSlice vectors, final int targetSize)
134134
}
135135
}
136136

137-
int[] assignmentOrdinals = vectors.slice();
138-
kMeansResult = new KMeansResult(centroids, assignments, assignmentOrdinals);
137+
kMeansIntermediate = new KMeansIntermediate(centroids, assignments, vectors::ordToDoc);
139138

140139
if (effectiveK == 1) {
141-
return kMeansResult;
140+
return kMeansIntermediate;
142141
}
143142

144143
for (int c = 0; c < clusterSizes.length; c++) {
145144
// Recurse for each cluster which is larger than targetSize
146145
// Give ourselves 30% margin for the target size
147146
if (100 * clusterSizes[c] > 134 * targetSize) {
148-
FloatVectorValuesSlice sample = createClusterSlice(clusterSizes[c], c, vectors, assignments);
147+
FloatVectorValues sample = createClusterSlice(clusterSizes[c], c, vectors, assignments);
149148

150149
// TODO: consider iterative here instead of recursive
151150
// recursive call to build out the sub partitions around this centroid c
152151
// subsequently reconcile and flatten the space of all centroids and assignments into one structure we can return
153-
updateAssignmentsWithRecursiveSplit(kMeansResult, c, cluster(sample, targetSize));
152+
updateAssignmentsWithRecursiveSplit(kMeansIntermediate, c, clusterAndSplit(sample, targetSize));
154153
}
155154
}
156155

157-
return kMeansResult;
156+
return kMeansIntermediate;
158157
}
159158

160-
static FloatVectorValuesSlice createClusterSlice(int clusterSize, int cluster, FloatVectorValuesSlice vectors, int[] assignments) {
159+
static FloatVectorValues createClusterSlice(int clusterSize, int cluster, FloatVectorValues vectors, int[] assignments) {
161160
int[] slice = new int[clusterSize];
162161
int idx = 0;
163162
for (int i = 0; i < assignments.length; i++) {
@@ -170,7 +169,7 @@ static FloatVectorValuesSlice createClusterSlice(int clusterSize, int cluster, F
170169
return new FloatVectorValuesSlice(vectors, slice);
171170
}
172171

173-
void updateAssignmentsWithRecursiveSplit(KMeansResult current, int cluster, KMeansResult subPartitions) {
172+
void updateAssignmentsWithRecursiveSplit(KMeansIntermediate current, int cluster, KMeansIntermediate subPartitions) {
174173
int orgCentroidsSize = current.centroids().length;
175174
int newCentroidsSize = current.centroids().length + subPartitions.centroids().length - 1;
176175

@@ -191,7 +190,7 @@ void updateAssignmentsWithRecursiveSplit(KMeansResult current, int cluster, KMea
191190
for (int i = 0; i < subPartitions.assignments().length; i++) {
192191
// this is a new centroid that was added, and so we'll need to remap it
193192
if (subPartitions.assignments()[i] != origCentroidOrd) {
194-
int parentOrd = subPartitions.assignmentOrds()[i];
193+
int parentOrd = subPartitions.ordToDoc(i);
195194
assert current.assignments()[parentOrd] == cluster;
196195
current.assignments()[parentOrd] = subPartitions.assignments()[i] + orgCentroidsSize - 1;
197196
}

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeans.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -114,17 +114,17 @@ int getBestCentroidOffset(float[][] centroids, float[] vector, int vectorIdx, Cl
114114
* cluster using a lloyd k-means algorithm
115115
*
116116
* @param vectors the vectors to cluster
117-
* @param kMeansResult the output object to populate which minimally includes centroids,
117+
* @param kMeansIntermediate the output object to populate which minimally includes centroids,
118118
* but may include assignments and soar assignments as well; care should be taken in
119119
* passing in a valid output object with a centroids array that is the size of centroids expected
120120
* @throws IOException is thrown if vectors is inaccessible
121121
*/
122-
void cluster(FloatVectorValues vectors, KMeansResult kMeansResult) throws IOException {
123-
cluster(vectors, kMeansResult, new ClusteringAugment());
122+
void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate) throws IOException {
123+
cluster(vectors, kMeansIntermediate, new ClusteringAugment());
124124
}
125125

126-
void cluster(FloatVectorValues vectors, KMeansResult kMeansResult, ClusteringAugment augment) throws IOException {
127-
float[][] centroids = kMeansResult.centroids();
126+
void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, ClusteringAugment augment) throws IOException {
127+
float[][] centroids = kMeansIntermediate.centroids();
128128
int k = centroids.length;
129129
int n = vectors.size();
130130

@@ -143,17 +143,17 @@ void cluster(FloatVectorValues vectors, KMeansResult kMeansResult, ClusteringAug
143143
}
144144

145145
/**
146-
* helper that calls {@link KMeans#cluster(FloatVectorValues, KMeansResult)} given a set of initialized centroids
146+
* helper that calls {@link KMeans#cluster(FloatVectorValues, KMeansIntermediate)} given a set of initialized centroids
147147
*
148148
* @param vectors the vectors to cluster
149149
* @param centroids the initialized centroids to be shifted using k-means
150150
* @param sampleSize the subset of vectors to use when shifting centroids
151151
* @param maxIterations the max iterations to shift centroids
152152
*/
153153
public static void cluster(FloatVectorValues vectors, float[][] centroids, int sampleSize, int maxIterations) throws IOException {
154-
KMeansResult kMeansResult = new KMeansResult(centroids);
154+
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids);
155155
KMeans kMeans = new KMeans(sampleSize, maxIterations);
156-
kMeans.cluster(vectors, kMeansResult);
156+
kMeans.cluster(vectors, kMeansIntermediate);
157157
}
158158

159159
/**

0 commit comments

Comments
 (0)