Skip to content

Commit 5aa2682

Browse files
committed
Fixed off-by-one error and other cleanup
1 parent 3b0764e commit 5aa2682

File tree

3 files changed

+70
-37
lines changed

3 files changed

+70
-37
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java

Lines changed: 62 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,7 @@
2929
import java.io.IOException;
3030
import java.nio.ByteBuffer;
3131
import java.nio.ByteOrder;
32-
import java.util.ArrayList;
3332
import java.util.Arrays;
34-
import java.util.List;
3533

3634
/**
3735
* Default implementation of {@link IVFVectorsWriter}. It uses {@link HierarchicalKMeans} algorithm to
@@ -131,7 +129,7 @@ CentroidSupplier createCentroidSupplier(
131129
}
132130

133131
static void writeCentroidsAndPartitions(
134-
List<CentroidPartition> centroidPartitions,
132+
CentroidPartition[] centroidPartitions,
135133
float[][] centroids,
136134
FieldInfo fieldInfo,
137135
float[] globalCentroid,
@@ -144,22 +142,24 @@ static void writeCentroidsAndPartitions(
144142
// TODO do we want to store these distances as well for future use?
145143
// TODO: sort centroids by global centroid (was doing so previously here)
146144

147-
// write the top level partition parent nodes and their pointers to the centroids within the partition
148-
// a size of 1 indicates a leaf node that did not have a parent node (orphans)
149-
for (CentroidPartition centroidPartition : centroidPartitions) {
150-
System.arraycopy(centroidPartition.centroid(), 0, centroidScratch, 0, centroidPartition.centroid().length);
151-
OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(
152-
centroidScratch,
153-
quantizedScratch,
154-
(byte) 4,
155-
globalCentroid
156-
);
157-
for (int i = 0; i < quantizedScratch.length; i++) {
158-
quantized[i] = (byte) quantizedScratch[i];
145+
if (centroidPartitions != null) {
146+
// write the top level partition parent nodes and their pointers to the centroids within the partition
147+
// a size of 1 indicates a leaf node that did not have a parent node (orphans)
148+
for (CentroidPartition centroidPartition : centroidPartitions) {
149+
System.arraycopy(centroidPartition.centroid(), 0, centroidScratch, 0, centroidPartition.centroid().length);
150+
OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(
151+
centroidScratch,
152+
quantizedScratch,
153+
(byte) 4,
154+
globalCentroid
155+
);
156+
for (int i = 0; i < quantizedScratch.length; i++) {
157+
quantized[i] = (byte) quantizedScratch[i];
158+
}
159+
writeQuantizedValue(centroidOutput, quantized, result);
160+
centroidOutput.writeInt(centroidPartition.childOrdinal());
161+
centroidOutput.writeInt(centroidPartition.size());
159162
}
160-
writeQuantizedValue(centroidOutput, quantized, result);
161-
centroidOutput.writeInt(centroidPartition.childOrdinal());
162-
centroidOutput.writeInt(centroidPartition.size());
163163
}
164164

165165
// write the quantized centroids which will be duplicate for orphans
@@ -242,43 +242,41 @@ CentroidAssignments calculateAndWriteCentroids(
242242
centroidOrds[i] = i;
243243
}
244244

245-
List<CentroidPartition> centroidPartitions = new ArrayList<>();
245+
CentroidPartition[] centroidPartitions = null;
246+
int partitionsCount = 0;
246247

247248
if (centroids.length > IVFVectorsFormat.DEFAULT_VECTORS_PER_CLUSTER) {
248-
List<float[]> centroidsList = Arrays.stream(centroids).toList();
249-
FloatVectorValues centroidsAsFVV = FloatVectorValues.fromFloats(centroidsList, fieldInfo.getVectorDimension());
250-
251-
HierarchicalKMeans hierarchicalKMeans = new HierarchicalKMeans(fieldInfo.getVectorDimension());
252-
KMeansResult result = hierarchicalKMeans.cluster(centroidsAsFVV, centroids.length / (int) Math.sqrt(centroids.length));
249+
KMeansResult result = clusterParentCentroids(fieldInfo, centroids);
253250
float[][] parentCentroids = result.centroids();
254251
int[] parentChildAssignments = result.assignments();
255-
// TODO: explore using soar assignments here as well
256-
// int[] parentChildSoarAssignments = result.soarAssignments();
252+
// TODO: explore soar assignments here as well
253+
254+
centroidPartitions = new CentroidPartition[parentCentroids.length];
257255

258256
AssignmentArraySorter sorter = new AssignmentArraySorter(centroids, centroidOrds, parentChildAssignments);
259257
sorter.sort(0, centroids.length);
260258

261-
for (int i = 0; i < parentChildAssignments.length; i++) {
259+
for (int i = 0; i < parentChildAssignments.length;) {
262260
int label = parentChildAssignments[i];
263261
int centroidCount = 0;
262+
int childOffset = i;
264263
int j = i;
265264
for (; j < parentChildAssignments.length; j++) {
266265
if (parentChildAssignments[j] != label) {
267266
break;
268267
}
269268
centroidCount++;
270269
}
271-
int childOrdinal = i;
272270
i = j;
273-
centroidPartitions.add(new CentroidPartition(parentCentroids[label], childOrdinal, centroidCount));
271+
centroidPartitions[partitionsCount++] = new CentroidPartition(parentCentroids[label], childOffset, centroidCount);
274272
}
275273
}
276274

277275
writeCentroidsAndPartitions(centroidPartitions, centroids, fieldInfo, globalCentroid, centroidOutput);
278276

279277
if (logger.isDebugEnabled()) {
280278
logger.debug("calculate centroids and assign vectors time ms: {}", (System.nanoTime() - nanoTime) / 1000000.0);
281-
logger.debug("final parent centroid count {}: ", centroidPartitions.size());
279+
logger.debug("final parent centroid count {}: ", partitionsCount);
282280
logger.debug("final centroid count: {}", centroids.length);
283281
}
284282

@@ -291,7 +289,40 @@ CentroidAssignments calculateAndWriteCentroids(
291289
int[] soarAssignments = kMeansResult.soarAssignments();
292290

293291
int[][] assignmentsByCluster = buildCentroidAssignments(centroids.length, assignments, soarAssignments, centroidOrdsToIdx);
294-
return new CentroidAssignments(centroidPartitions.size(), centroids, assignmentsByCluster);
292+
return new CentroidAssignments(partitionsCount, centroids, assignmentsByCluster);
293+
}
294+
295+
private KMeansResult clusterParentCentroids(FieldInfo fieldInfo, float[][] centroids) throws IOException {
296+
FloatVectorValues centroidsAsFVV = new FloatVectorValues() {
297+
@Override
298+
public int size() {
299+
return centroids.length;
300+
}
301+
302+
@Override
303+
public int dimension() {
304+
return fieldInfo.getVectorDimension();
305+
}
306+
307+
@Override
308+
public float[] vectorValue(int targetOrd) {
309+
return centroids[targetOrd];
310+
}
311+
312+
@Override
313+
public FloatVectorValues copy() {
314+
return this;
315+
}
316+
317+
@Override
318+
public DocIndexIterator iterator() {
319+
return createDenseIterator();
320+
}
321+
};
322+
323+
HierarchicalKMeans hierarchicalKMeans = new HierarchicalKMeans(fieldInfo.getVectorDimension());
324+
KMeansResult result = hierarchicalKMeans.cluster(centroidsAsFVV, centroids.length / (int) Math.sqrt(centroids.length));
325+
return result;
295326
}
296327

297328
static int[][] buildCentroidAssignments(int centroidCount, int[] assignments, int[] soarAssignments, IntIntMap centroidOrds) {

server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646
*/
4747
public abstract class IVFVectorsReader extends KnnVectorsReader {
4848

49+
private static final float PARENT_SCORE_BUFFER = 0.05f;
50+
4951
private final IndexInput ivfCentroids, ivfClusters;
5052
private final SegmentReadState state;
5153
private final FieldInfos fieldInfos;
@@ -311,7 +313,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector
311313
// ... would be centroidScore < (nextParentScore + furthestCentroidScore) which is better than just a buffer
312314
// TODO: try a ParentNProbe here that's for instance the sqrt(nProbe) that forces a fixed
313315
// ... number of parents to be explored at each step
314-
while (parentCentroidQueue.size() > 0 && centroidScore < (nextParentScore + nextParentScore * 0.05)) {
316+
while (parentCentroidQueue.size() > 0 && centroidScore < (nextParentScore + nextParentScore * PARENT_SCORE_BUFFER)) {
315317
updateCentroidQueueWNextParent(parentCentroidQueryScorer, parentCentroidQueue, centroidQueryScorer, centroidQueue);
316318
if (parentCentroidQueue.size() > 0) {
317319
nextParentScore = parentCentroidQueue.topScore();
@@ -358,7 +360,7 @@ private static int updateCentroidQueueWNextParent(
358360
childCentroidOrdinal = parentCentroidQueryScorer.getChildCentroidStart(parentCentroidOrdinal);
359361
childCentroidCount = parentCentroidQueryScorer.getChildCount(parentCentroidOrdinal);
360362
}
361-
// TODO: add back scorePostingLists? seems like it's not doing anything at this point
363+
// TODO: add back scorePostingLists? or make this function abstract? or break it apart? not sure
362364
centroidQueryScorer.bulkScore(centroidQueue, childCentroidOrdinal, childCentroidOrdinal + childCentroidCount);
363365

364366
return childCentroidCount;

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ public KMeansResult cluster(FloatVectorValues vectors, int targetSize) throws IO
6666
}
6767

6868
// partition the space
69-
KMeansIntermediate kMeansIntermediate = clusterAndSplit(vectors, targetSize, 0);
69+
KMeansIntermediate kMeansIntermediate = clusterAndSplit(vectors, targetSize);
7070
if (kMeansIntermediate.centroids().length > 1 && kMeansIntermediate.centroids().length < vectors.size()) {
7171
int localSampleSize = Math.min(kMeansIntermediate.centroids().length * samplesPerCluster / 2, vectors.size());
7272
KMeansLocal kMeansLocal = new KMeansLocal(localSampleSize, maxIterations);
@@ -76,7 +76,7 @@ public KMeansResult cluster(FloatVectorValues vectors, int targetSize) throws IO
7676
return kMeansIntermediate;
7777
}
7878

79-
KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int targetSize, final int depth) throws IOException {
79+
KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int targetSize) throws IOException {
8080
if (vectors.size() <= targetSize) {
8181
return new KMeansIntermediate();
8282
}
@@ -118,7 +118,7 @@ KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int ta
118118
// TODO: consider iterative here instead of recursive
119119
// recursive call to build out the sub partitions around this centroid c
120120
// subsequently reconcile and flatten the space of all centroids and assignments into one structure we can return
121-
updateAssignmentsWithRecursiveSplit(kMeansIntermediate, c, clusterAndSplit(sample, targetSize, depth + 1), depth);
121+
updateAssignmentsWithRecursiveSplit(kMeansIntermediate, c, clusterAndSplit(sample, targetSize));
122122
}
123123
}
124124

@@ -138,7 +138,7 @@ static FloatVectorValues createClusterSlice(int clusterSize, int cluster, FloatV
138138
return new FloatVectorValuesSlice(vectors, slice);
139139
}
140140

141-
void updateAssignmentsWithRecursiveSplit(KMeansIntermediate current, int cluster, KMeansIntermediate subPartitions, int depth) {
141+
void updateAssignmentsWithRecursiveSplit(KMeansIntermediate current, int cluster, KMeansIntermediate subPartitions) {
142142
int orgCentroidsSize = current.centroids().length;
143143
int newCentroidsSize = current.centroids().length + subPartitions.centroids().length - 1;
144144

0 commit comments

Comments
 (0)