Skip to content

Commit 192ff7e

Browse files
authored
[DiskBBQ] add method for calculate centroids during merge (elastic#137039)
We currently use the same API method to compute centroids during flush and during merge. In order to explore improvements on how we compute centroids during merge (e.g. reusing some of the existing centroids), this commit adds a new API method for calculating centroids during merging. The implementation currently forwards the request to the calculation of centroids during flush. I moved the computation of the global centroid to CentroidAssignments as it feels like it belongs there.
1 parent d7fda61 commit 192ff7e

File tree

4 files changed

+53
-54
lines changed

4 files changed

+53
-54
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/CentroidAssignments.java

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,33 @@
99

1010
package org.elasticsearch.index.codec.vectors.diskbbq;
1111

12-
public record CentroidAssignments(int numCentroids, float[][] centroids, int[] assignments, int[] overspillAssignments) {
12+
public record CentroidAssignments(
13+
int numCentroids,
14+
float[][] centroids,
15+
int[] assignments,
16+
int[] overspillAssignments,
17+
float[] globalCentroid
18+
) {
1319

14-
public CentroidAssignments(float[][] centroids, int[] assignments, int[] overspillAssignments) {
15-
this(centroids.length, centroids, assignments, overspillAssignments);
20+
public CentroidAssignments(int dims, float[][] centroids, int[] assignments, int[] overspillAssignments) {
21+
this(centroids.length, centroids, assignments, overspillAssignments, computeGlobalCentroid(dims, centroids));
1622
assert assignments.length == overspillAssignments.length || overspillAssignments.length == 0
1723
: "assignments and overspillAssignments must have the same length";
24+
25+
}
26+
27+
private static float[] computeGlobalCentroid(int dims, float[][] centroids) {
28+
final float[] globalCentroid = new float[dims];
29+
// TODO: push this logic into vector util?
30+
for (float[] centroid : centroids) {
31+
assert centroid.length == dims;
32+
for (int j = 0; j < centroid.length; j++) {
33+
globalCentroid[j] += centroid[j];
34+
}
35+
}
36+
for (int j = 0; j < globalCentroid.length; j++) {
37+
globalCentroid[j] /= centroids.length;
38+
}
39+
return globalCentroid;
1840
}
1941
}

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsWriter.java

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -523,47 +523,34 @@ public int size() {
523523
return new CentroidGroups(kMeansResult.centroids(), vectorsPerCentroid, maxVectorsPerCentroidLength);
524524
}
525525

526+
@Override
527+
public CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, MergeState mergeState)
528+
throws IOException {
529+
return calculateCentroids(fieldInfo, floatVectorValues);
530+
}
531+
526532
/**
527533
* Calculate the centroids for the given field.
528534
* We use the {@link HierarchicalKMeans} algorithm to partition the space of all vectors across merging segments
529535
*
530536
* @param fieldInfo merging field info
531537
* @param floatVectorValues the float vector values to merge
532-
* @param globalCentroid the global centroid, calculated by this method and used to quantize the centroids
533538
* @return the vector assignments, soar assignments, and if asked the centroids themselves that were computed
534539
* @throws IOException if an I/O error occurs
535540
*/
536541
@Override
537-
public CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, float[] globalCentroid)
538-
throws IOException {
539-
542+
public CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues) throws IOException {
540543
// TODO: consider hinting / bootstrapping hierarchical kmeans with the prior segments centroids
541-
CentroidAssignments centroidAssignments = buildCentroidAssignments(floatVectorValues, vectorPerCluster);
542-
float[][] centroids = centroidAssignments.centroids();
543544
// TODO: for flush we are doing this over the vectors and here centroids which seems duplicative
544545
// preliminary tests suggest recall is good using only centroids but need to do further evaluation
545-
// TODO: push this logic into vector util?
546-
for (float[] centroid : centroids) {
547-
for (int j = 0; j < centroid.length; j++) {
548-
globalCentroid[j] += centroid[j];
549-
}
550-
}
551-
for (int j = 0; j < globalCentroid.length; j++) {
552-
globalCentroid[j] /= centroids.length;
553-
}
554-
546+
KMeansResult kMeansResult = new HierarchicalKMeans(floatVectorValues.dimension()).cluster(floatVectorValues, vectorPerCluster);
547+
float[][] centroids = kMeansResult.centroids();
555548
if (logger.isDebugEnabled()) {
556549
logger.debug("final centroid count: {}", centroids.length);
557550
}
558-
return centroidAssignments;
559-
}
560-
561-
static CentroidAssignments buildCentroidAssignments(FloatVectorValues floatVectorValues, int vectorPerCluster) throws IOException {
562-
KMeansResult kMeansResult = new HierarchicalKMeans(floatVectorValues.dimension()).cluster(floatVectorValues, vectorPerCluster);
563-
float[][] centroids = kMeansResult.centroids();
564551
int[] assignments = kMeansResult.assignments();
565552
int[] soarAssignments = kMeansResult.soarAssignments();
566-
return new CentroidAssignments(centroids, assignments, soarAssignments);
553+
return new CentroidAssignments(fieldInfo.getVectorDimension(), centroids, assignments, soarAssignments);
567554
}
568555

569556
static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections)

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsWriter.java

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,9 @@ public final KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOExc
141141
return rawVectorDelegate;
142142
}
143143

144-
public abstract CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, float[] globalCentroid)
144+
public abstract CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues) throws IOException;
145+
146+
public abstract CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, MergeState mergeState)
145147
throws IOException;
146148

147149
public record CentroidOffsetAndLength(LongValues offsets, LongValues lengths) {}
@@ -191,11 +193,10 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
191193
writeMeta(fieldWriter.fieldInfo, 0, 0, 0, 0, 0, null);
192194
continue;
193195
}
194-
final float[] globalCentroid = new float[fieldWriter.fieldInfo.getVectorDimension()];
195196
// build a float vector values with random access
196197
final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldWriter.fieldInfo, fieldWriter.delegate, maxDoc);
197198
// build centroids
198-
final CentroidAssignments centroidAssignments = calculateCentroids(fieldWriter.fieldInfo, floatVectorValues, globalCentroid);
199+
final CentroidAssignments centroidAssignments = calculateCentroids(fieldWriter.fieldInfo, floatVectorValues);
199200
// wrap centroids with a supplier
200201
final CentroidSupplier centroidSupplier = CentroidSupplier.fromArray(centroidAssignments.centroids());
201202
// write posting lists
@@ -211,6 +212,7 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
211212
);
212213
final long postingListLength = ivfClusters.getFilePointer() - postingListOffset;
213214
// write centroids
215+
final float[] globalCentroid = centroidAssignments.globalCentroid();
214216
final long centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES);
215217
writeCentroids(fieldWriter.fieldInfo, centroidSupplier, globalCentroid, centroidOffsetAndLength, ivfCentroids);
216218
final long centroidLength = ivfCentroids.getFilePointer() - centroidOffset;
@@ -377,7 +379,7 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
377379
final int numCentroids;
378380
final int[] assignments;
379381
final int[] overspillAssignments;
380-
final float[] calculatedGlobalCentroid = new float[fieldInfo.getVectorDimension()];
382+
final float[] calculatedGlobalCentroid;
381383
String centroidTempName = null;
382384
IndexOutput centroidTemp = null;
383385
success = false;
@@ -387,7 +389,7 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
387389
CentroidAssignments centroidAssignments = calculateCentroids(
388390
fieldInfo,
389391
getFloatVectorValues(fieldInfo, docs, vectors, numVectors),
390-
calculatedGlobalCentroid
392+
mergeState
391393
);
392394
// write the centroids to a temporary file so we are not holding them on heap
393395
final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
@@ -397,6 +399,7 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
397399
}
398400
numCentroids = centroidAssignments.numCentroids();
399401
assignments = centroidAssignments.assignments();
402+
calculatedGlobalCentroid = centroidAssignments.globalCentroid();
400403
overspillAssignments = centroidAssignments.overspillAssignments();
401404
success = true;
402405
} finally {

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsWriter.java

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -511,47 +511,34 @@ public int size() {
511511
return new CentroidGroups(kMeansResult.centroids(), vectorsPerCentroid, maxVectorsPerCentroidLength);
512512
}
513513

514+
@Override
515+
public CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, MergeState mergeState)
516+
throws IOException {
517+
return calculateCentroids(fieldInfo, floatVectorValues);
518+
}
519+
514520
/**
515521
* Calculate the centroids for the given field.
516522
* We use the {@link HierarchicalKMeans} algorithm to partition the space of all vectors across merging segments
517523
*
518524
* @param fieldInfo merging field info
519525
* @param floatVectorValues the float vector values to merge
520-
* @param globalCentroid the global centroid, calculated by this method and used to quantize the centroids
521526
* @return the vector assignments, soar assignments, and if asked the centroids themselves that were computed
522527
* @throws IOException if an I/O error occurs
523528
*/
524529
@Override
525-
public CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, float[] globalCentroid)
526-
throws IOException {
527-
530+
public CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues) throws IOException {
528531
// TODO: consider hinting / bootstrapping hierarchical kmeans with the prior segments centroids
529-
CentroidAssignments centroidAssignments = buildCentroidAssignments(floatVectorValues, vectorPerCluster);
530-
float[][] centroids = centroidAssignments.centroids();
531532
// TODO: for flush we are doing this over the vectors and here centroids which seems duplicative
532533
// preliminary tests suggest recall is good using only centroids but need to do further evaluation
533-
// TODO: push this logic into vector util?
534-
for (float[] centroid : centroids) {
535-
for (int j = 0; j < centroid.length; j++) {
536-
globalCentroid[j] += centroid[j];
537-
}
538-
}
539-
for (int j = 0; j < globalCentroid.length; j++) {
540-
globalCentroid[j] /= centroids.length;
541-
}
542-
534+
KMeansResult kMeansResult = new HierarchicalKMeans(floatVectorValues.dimension()).cluster(floatVectorValues, vectorPerCluster);
535+
float[][] centroids = kMeansResult.centroids();
543536
if (logger.isDebugEnabled()) {
544537
logger.debug("final centroid count: {}", centroids.length);
545538
}
546-
return centroidAssignments;
547-
}
548-
549-
static CentroidAssignments buildCentroidAssignments(FloatVectorValues floatVectorValues, int vectorPerCluster) throws IOException {
550-
KMeansResult kMeansResult = new HierarchicalKMeans(floatVectorValues.dimension()).cluster(floatVectorValues, vectorPerCluster);
551-
float[][] centroids = kMeansResult.centroids();
552539
int[] assignments = kMeansResult.assignments();
553540
int[] soarAssignments = kMeansResult.soarAssignments();
554-
return new CentroidAssignments(centroids, assignments, soarAssignments);
541+
return new CentroidAssignments(fieldInfo.getVectorDimension(), centroids, assignments, soarAssignments);
555542
}
556543

557544
static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections)

0 commit comments

Comments
 (0)