Skip to content

Commit ecd964a

Browse files
committed
[IVF] Simplify how we buildCentroidAssignments
1 parent 445c3eb commit ecd964a

File tree

3 files changed

+26
-62
lines changed

3 files changed

+26
-62
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,36 +9,10 @@
99

1010
package org.elasticsearch.index.codec.vectors;
1111

12-
final class CentroidAssignments {
13-
14-
private final int numCentroids;
15-
private final float[][] cachedCentroids;
16-
private final int[][] assignmentsByCluster;
17-
18-
private CentroidAssignments(int numCentroids, float[][] cachedCentroids, int[][] assignmentsByCluster) {
19-
this.numCentroids = numCentroids;
20-
this.cachedCentroids = cachedCentroids;
21-
this.assignmentsByCluster = assignmentsByCluster;
22-
}
12+
record CentroidAssignments(int numCentroids, float[][] centroids, int[][] assignmentsByCluster) {
2313

2414
CentroidAssignments(float[][] centroids, int[][] assignmentsByCluster) {
2515
this(centroids.length, centroids, assignmentsByCluster);
26-
}
27-
28-
CentroidAssignments(int numCentroids, int[][] assignmentsByCluster) {
29-
this(numCentroids, null, assignmentsByCluster);
30-
}
31-
32-
// Getters and setters
33-
public int numCentroids() {
34-
return numCentroids;
35-
}
36-
37-
public float[][] cachedCentroids() {
38-
return cachedCentroids;
39-
}
40-
41-
public int[][] assignmentsByCluster() {
42-
return assignmentsByCluster;
16+
assert centroids.length == assignmentsByCluster.length;
4317
}
4418
}

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] glo
148148
}
149149
}
150150

151+
@Override
151152
CentroidAssignments calculateAndWriteCentroids(
152153
FieldInfo fieldInfo,
153154
FloatVectorValues floatVectorValues,
@@ -156,16 +157,7 @@ CentroidAssignments calculateAndWriteCentroids(
156157
float[] globalCentroid
157158
) throws IOException {
158159
// TODO: take advantage of prior generated clusters from mergeState in the future
159-
return calculateAndWriteCentroids(fieldInfo, floatVectorValues, centroidOutput, globalCentroid, false);
160-
}
161-
162-
CentroidAssignments calculateAndWriteCentroids(
163-
FieldInfo fieldInfo,
164-
FloatVectorValues floatVectorValues,
165-
IndexOutput centroidOutput,
166-
float[] globalCentroid
167-
) throws IOException {
168-
return calculateAndWriteCentroids(fieldInfo, floatVectorValues, centroidOutput, globalCentroid, true);
160+
return calculateAndWriteCentroids(fieldInfo, floatVectorValues, centroidOutput, globalCentroid);
169161
}
170162

171163
/**
@@ -176,26 +168,22 @@ CentroidAssignments calculateAndWriteCentroids(
176168
* @param floatVectorValues the float vector values to merge
177169
* @param centroidOutput the centroid output
178170
* @param globalCentroid the global centroid, calculated by this method and used to quantize the centroids
179-
* @param cacheCentroids whether the centroids are kept or discarded once computed
180171
* @return the vector assignments, soar assignments, and if asked the centroids themselves that were computed
181172
* @throws IOException if an I/O error occurs
182173
*/
174+
@Override
183175
CentroidAssignments calculateAndWriteCentroids(
184176
FieldInfo fieldInfo,
185177
FloatVectorValues floatVectorValues,
186178
IndexOutput centroidOutput,
187-
float[] globalCentroid,
188-
boolean cacheCentroids
189-
) throws IOException {
179+
float[] globalCentroid
180+
) throws IOException {
190181

191182
long nanoTime = System.nanoTime();
192183

193184
// TODO: consider hinting / bootstrapping hierarchical kmeans with the prior segments centroids
194185
KMeansResult kMeansResult = new HierarchicalKMeans(floatVectorValues.dimension()).cluster(floatVectorValues, vectorPerCluster);
195186
float[][] centroids = kMeansResult.centroids();
196-
int[] assignments = kMeansResult.assignments();
197-
int[] soarAssignments = kMeansResult.soarAssignments();
198-
199187
// TODO: for flush we are doing this over the vectors and here centroids which seems duplicative
200188
// preliminary tests suggest recall is good using only centroids but need to do further evaluation
201189
// TODO: push this logic into vector util?
@@ -215,7 +203,13 @@ CentroidAssignments calculateAndWriteCentroids(
215203
logger.debug("calculate centroids and assign vectors time ms: {}", (System.nanoTime() - nanoTime) / 1000000.0);
216204
logger.debug("final centroid count: {}", centroids.length);
217205
}
206+
return buildCentroidAssignments(kMeansResult);
207+
}
218208

209+
static CentroidAssignments buildCentroidAssignments(KMeansResult kMeansResult) {
210+
float[][] centroids = kMeansResult.centroids();
211+
int[] assignments = kMeansResult.assignments();
212+
int[] soarAssignments = kMeansResult.soarAssignments();
219213
int[] centroidVectorCount = new int[centroids.length];
220214
for (int i = 0; i < assignments.length; i++) {
221215
centroidVectorCount[assignments[i]]++;
@@ -242,12 +236,7 @@ CentroidAssignments calculateAndWriteCentroids(
242236
}
243237
}
244238
}
245-
246-
if (cacheCentroids) {
247-
return new CentroidAssignments(centroids, assignmentsByCluster);
248-
} else {
249-
return new CentroidAssignments(centroids.length, assignmentsByCluster);
250-
}
239+
return new CentroidAssignments(centroids, assignmentsByCluster);
251240
}
252241

253242
static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections)

server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
166166
globalCentroid
167167
);
168168

169-
CentroidSupplier centroidSupplier = new OnHeapCentroidSupplier(centroidAssignments.cachedCentroids());
169+
CentroidSupplier centroidSupplier = new OnHeapCentroidSupplier(centroidAssignments.centroids());
170170

171171
long centroidLength = ivfCentroids.getFilePointer() - centroidOffset;
172172
final long[] offsets = buildAndWritePostingsLists(
@@ -280,26 +280,27 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
280280
IndexInput docs = docsFileName == null ? null : mergeState.segmentInfo.dir.openInput(docsFileName, IOContext.DEFAULT)
281281
) {
282282
final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, docs, vectors, numVectors);
283-
success = false;
284-
long centroidOffset;
285-
long centroidLength;
283+
284+
final long centroidOffset;
285+
final long centroidLength;
286+
final int numCentroids;
287+
final int[][] assignmentsByCluster;
288+
final float[] calculatedGlobalCentroid = new float[fieldInfo.getVectorDimension()];
286289
String centroidTempName = null;
287-
int numCentroids;
288290
IndexOutput centroidTemp = null;
289-
CentroidAssignments centroidAssignments;
290-
float[] calculatedGlobalCentroid = new float[fieldInfo.getVectorDimension()];
291+
success = false;
291292
try {
292293
centroidTemp = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "civf_", IOContext.DEFAULT);
293294
centroidTempName = centroidTemp.getName();
294-
centroidAssignments = calculateAndWriteCentroids(
295+
CentroidAssignments centroidAssignments = calculateAndWriteCentroids(
295296
fieldInfo,
296-
floatVectorValues,
297+
getFloatVectorValues(fieldInfo, docs, vectors, numVectors),
297298
centroidTemp,
298299
mergeState,
299300
calculatedGlobalCentroid
300301
);
301302
numCentroids = centroidAssignments.numCentroids();
302-
303+
assignmentsByCluster = centroidAssignments.assignmentsByCluster();
303304
success = true;
304305
} finally {
305306
if (success == false && centroidTempName != null) {
@@ -336,7 +337,7 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
336337
centroidSupplier,
337338
floatVectorValues,
338339
ivfClusters,
339-
centroidAssignments.assignmentsByCluster()
340+
assignmentsByCluster
340341
);
341342
assert offsets.length == centroidSupplier.size();
342343
writeMeta(fieldInfo, centroidOffset, centroidLength, offsets, calculatedGlobalCentroid);

0 commit comments

Comments
 (0)