Skip to content

Commit 3a2ba61

Browse files
committed
using full hkmeans to gen parents
1 parent ffe2929 commit 3a2ba61

File tree

5 files changed

+81
-148
lines changed

5 files changed

+81
-148
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/AssignmentArraySorter.java

Lines changed: 9 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -9,34 +9,28 @@
99

1010
package org.elasticsearch.index.codec.vectors;
1111

12-
import org.apache.lucene.util.TimSorter;
13-
14-
class AssignmentArraySorter extends TimSorter {
12+
import org.apache.lucene.util.IntroSorter;
1513

14+
class AssignmentArraySorter extends IntroSorter {
15+
int pivot = -1;
1616
private final float[][] centroids;
17-
private final float[][] tmpC;
18-
1917
private final int[] centroidsOrds;
20-
private final int[] tmpA;
21-
2218
private final int[] sortOrdering;
23-
private final int[] tmpSort;
2419

2520
AssignmentArraySorter(float[][] centroids, int[] centroidsOrds, int[] sortOrdering) {
26-
super(centroids.length / 64);
2721
this.centroids = centroids;
2822
this.centroidsOrds = centroidsOrds;
2923
this.sortOrdering = sortOrdering;
24+
}
3025

31-
int maxTempSlots = centroids.length / 64;
32-
this.tmpC = new float[maxTempSlots][];
33-
this.tmpA = new int[maxTempSlots];
34-
this.tmpSort = new int[maxTempSlots];
26+
@Override
27+
protected void setPivot(int i) {
28+
pivot = sortOrdering[i];
3529
}
3630

3731
@Override
38-
protected int compare(int i, int j) {
39-
return Integer.compare(sortOrdering[i], sortOrdering[j]);
32+
protected int comparePivot(int j) {
33+
return Integer.compare(pivot, sortOrdering[j]);
4034
}
4135

4236
@Override
@@ -53,30 +47,4 @@ protected void swap(int i, int j) {
5347
sortOrdering[i] = sortOrdering[j];
5448
sortOrdering[j] = tmpSort;
5549
}
56-
57-
@Override
58-
protected void copy(int src, int dest) {
59-
centroids[dest] = centroids[src];
60-
centroidsOrds[dest] = centroidsOrds[src];
61-
sortOrdering[dest] = sortOrdering[src];
62-
}
63-
64-
@Override
65-
protected void save(int start, int len) {
66-
System.arraycopy(centroids, start, tmpC, 0, len);
67-
System.arraycopy(centroidsOrds, start, tmpA, 0, len);
68-
System.arraycopy(sortOrdering, start, tmpSort, 0, len);
69-
}
70-
71-
@Override
72-
protected void restore(int src, int dest) {
73-
centroids[dest] = tmpC[src];
74-
centroidsOrds[dest] = tmpA[src];
75-
sortOrdering[dest] = tmpSort[src];
76-
}
77-
78-
@Override
79-
protected int compareSaved(int i, int j) {
80-
return Integer.compare(tmpSort[i], sortOrdering[j]);
81-
}
8250
}

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -167,17 +167,19 @@ ParentCentroidQueryScorer getParentCentroidScorer(
167167
float[] targetQuery
168168
) throws IOException {
169169
FieldEntry fieldEntry = fields.get(fieldInfo.number);
170-
float[] globalCentroid = fieldEntry.globalCentroid();
171170
float globalCentroidDp = fieldEntry.globalCentroidDp();
172171
OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
173-
byte[] quantized = new byte[targetQuery.length];
174-
float[] targetScratch = ArrayUtil.copyArray(targetQuery);
175-
OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize(
176-
targetScratch,
177-
quantized,
172+
final int[] scratch = new int[targetQuery.length];
173+
final OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize(
174+
ArrayUtil.copyArray(targetQuery),
175+
scratch,
178176
(byte) 4,
179-
globalCentroid
177+
fieldEntry.globalCentroid()
180178
);
179+
final byte[] quantized = new byte[targetQuery.length];
180+
for (int i = 0; i < quantized.length; i++) {
181+
quantized[i] = (byte) scratch[i];
182+
}
181183
final ES91Int4VectorsScorer scorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, fieldInfo.getVectorDimension());
182184
return new ParentCentroidQueryScorer() {
183185
int currentCentroid = -1;

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java

Lines changed: 41 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99

1010
package org.elasticsearch.index.codec.vectors;
1111

12+
import com.carrotsearch.hppc.IntIntHashMap;
13+
import com.carrotsearch.hppc.IntIntMap;
14+
1215
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
1316
import org.apache.lucene.index.FieldInfo;
1417
import org.apache.lucene.index.FloatVectorValues;
@@ -30,8 +33,6 @@
3033
import java.util.Arrays;
3134
import java.util.List;
3235

33-
import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.DEFAULT_VECTORS_PER_CLUSTER;
34-
3536
/**
3637
* Default implementation of {@link IVFVectorsWriter}. It uses {@link HierarchicalKMeans} algorithm to
3738
* partition the vector space, and then stores the centroids and posting list in a sequential
@@ -153,7 +154,10 @@ static void writeCentroidsAndPartitions(
153154
(byte) 4,
154155
globalCentroid
155156
);
156-
writeQuantizedValue(centroidOutput, quantizedScratch, result);
157+
for (int i = 0; i < quantizedScratch.length; i++) {
158+
quantized[i] = (byte) quantizedScratch[i];
159+
}
160+
writeQuantizedValue(centroidOutput, quantized, result);
157161
centroidOutput.writeInt(centroidPartition.childOrdinal());
158162
centroidOutput.writeInt(centroidPartition.size());
159163
}
@@ -252,40 +256,32 @@ CentroidAssignments calculateAndWriteCentroids(
252256

253257
List<CentroidPartition> centroidPartitions = new ArrayList<>();
254258

255-
// TODO: make this configurable
256-
if (centroids.length > DEFAULT_VECTORS_PER_CLUSTER) {
257-
// TODO: sort by global centroids as well
258-
// TODO: have this take a function instead of just an int[] for sorting
259-
AssignmentArraySorter sorter = new AssignmentArraySorter(centroids, centroidOrds, kMeansResult.parentLayer());
260-
sorter.sort(0, centroids.length);
261-
262-
for (int i = 0; i < kMeansResult.parentLayer().length;) {
263-
// for any layer that was not partitioned we treat it duplicatively as a parent and child
264-
if (kMeansResult.parentLayer()[i] == -1) {
265-
centroidPartitions.add(new CentroidPartition(centroids[i], i, 1));
266-
i++;
267-
} else {
268-
int label = kMeansResult.parentLayer()[i];
269-
int centroidCount = 0;
270-
float[] parentPartitionCentroid = new float[fieldInfo.getVectorDimension()];
271-
int j = i;
272-
for (; j < kMeansResult.parentLayer().length; j++) {
273-
if (kMeansResult.parentLayer()[j] != label) {
274-
break;
275-
}
276-
for (int k = 0; k < parentPartitionCentroid.length; k++) {
277-
parentPartitionCentroid[k] += centroids[i][k];
278-
}
279-
centroidCount++;
280-
}
281-
int childOrdinal = i;
282-
i = j;
283-
for (int d = 0; d < parentPartitionCentroid.length; d++) {
284-
parentPartitionCentroid[d] /= centroidCount;
285-
}
286-
centroidPartitions.add(new CentroidPartition(parentPartitionCentroid, childOrdinal, centroidCount));
259+
List<float[]> centroidsList = Arrays.stream(centroids).toList();
260+
FloatVectorValues centroidsAsFVV = FloatVectorValues.fromFloats(centroidsList, fieldInfo.getVectorDimension());
261+
262+
HierarchicalKMeans hierarchicalKMeans = new HierarchicalKMeans(fieldInfo.getVectorDimension());
263+
KMeansResult result = hierarchicalKMeans.cluster(centroidsAsFVV, centroids.length / (int) Math.sqrt(centroids.length));
264+
float[][] parentCentroids = result.centroids();
265+
int[] parentChildAssignments = result.assignments();
266+
// TODO: explore using soar assignments here as well
267+
//int[] parentChildSoarAssignments = result.soarAssignments();
268+
269+
AssignmentArraySorter sorter = new AssignmentArraySorter(centroids, centroidOrds, parentChildAssignments);
270+
sorter.sort(0, centroids.length);
271+
272+
for(int i = 0; i < parentChildAssignments.length; i++) {
273+
int label = parentChildAssignments[i];
274+
int centroidCount = 0;
275+
int j = i;
276+
for(; j < parentChildAssignments.length; j++) {
277+
if(parentChildAssignments[j] != label) {
278+
break;
287279
}
280+
centroidCount++;
288281
}
282+
int childOrdinal = i;
283+
i = j;
284+
centroidPartitions.add(new CentroidPartition(parentCentroids[label], childOrdinal, centroidCount));
289285
}
290286

291287
writeCentroidsAndPartitions(centroidPartitions, centroids, fieldInfo, globalCentroid, centroidOutput);
@@ -298,7 +294,11 @@ CentroidAssignments calculateAndWriteCentroids(
298294
logger.debug("final centroid count: {}", centroids.length);
299295
}
300296

301-
int[][] assignmentsByCluster = mapAssignmentsByCluster(centroids.length, assignments, soarAssignments, centroidOrds);
297+
IntIntMap centroidOrdsToIdx = new IntIntHashMap(centroidOrds.length);
298+
for(int i = 0; i < centroidOrds.length; i++) {
299+
centroidOrdsToIdx.put(centroidOrds[i], i);
300+
}
301+
int[][] assignmentsByCluster = mapAssignmentsByCluster(centroids.length, assignments, soarAssignments, centroidOrdsToIdx);
302302

303303
if (cacheCentroids) {
304304
return new CentroidAssignments(centroidPartitions.size(), centroids, assignmentsByCluster);
@@ -307,26 +307,14 @@ CentroidAssignments calculateAndWriteCentroids(
307307
}
308308
}
309309

310-
// FIXME: clean this up
311-
static int[][] mapAssignmentsByCluster(int centroidCount, int[] assignments, int[] soarAssignments, int[] centroidOrds) {
310+
static int[][] mapAssignmentsByCluster(int centroidCount, int[] assignments, int[] soarAssignments, IntIntMap centroidOrds) {
312311
int[] centroidVectorCount = new int[centroidCount];
313312
for (int i = 0; i < assignments.length; i++) {
314-
int c = -1;
315-
// FIXME: create a reverse mapping prior to this step? .. expensive
316-
for (int j = 0; j < centroidOrds.length; j++) {
317-
if (assignments[i] == centroidOrds[j]) {
318-
c = j;
319-
}
320-
}
313+
int c = centroidOrds.get(assignments[i]);
321314
centroidVectorCount[c]++;
322315
// if soar assignments are present, count them as well
323316
if (soarAssignments.length > i && soarAssignments[i] != -1) {
324-
int s = -1;
325-
for (int j = 0; j < centroidOrds.length; j++) {
326-
if (soarAssignments[i] == centroidOrds[j]) {
327-
s = j;
328-
}
329-
}
317+
int s = centroidOrds.get(soarAssignments[i]);
330318
centroidVectorCount[s]++;
331319
}
332320
}
@@ -338,21 +326,11 @@ static int[][] mapAssignmentsByCluster(int centroidCount, int[] assignments, int
338326
Arrays.fill(centroidVectorCount, 0);
339327

340328
for (int i = 0; i < assignments.length; i++) {
341-
int c = -1;
342-
for (int j = 0; j < centroidOrds.length; j++) {
343-
if (assignments[i] == centroidOrds[j]) {
344-
c = j;
345-
}
346-
}
329+
int c = centroidOrds.get(assignments[i]);
347330
assignmentsByCluster[c][centroidVectorCount[c]++] = i;
348331
// if soar assignments are present, add them to the cluster as well
349332
if (soarAssignments.length > i) {
350-
int s = -1;
351-
for (int j = 0; j < centroidOrds.length; j++) {
352-
if (soarAssignments[i] == centroidOrds[j]) {
353-
s = j;
354-
}
355-
}
333+
int s = centroidOrds.getOrDefault(soarAssignments[i], -1);
356334
if (s != -1) {
357335
assignmentsByCluster[s][centroidVectorCount[s]++] = i;
358336
}

server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -221,10 +221,6 @@ public final ByteVectorValues getByteVectorValues(String field) throws IOExcepti
221221
return rawVectorsReader.getByteVectorValues(field);
222222
}
223223

224-
// FIXME: remove the diagnostics
225-
int centroidsRead = 0;
226-
int ii = 0;
227-
228224
@Override
229225
public final void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
230226
final FieldInfo fieldInfo = state.fieldInfos.fieldInfo(field);
@@ -296,8 +292,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector
296292

297293
while (parentCentroidQueue.size() > 0 && (centroidsVisited < nProbe || knnCollectorImpl.numCollected() < knnCollector.k())) {
298294
NeighborQueue centroidQueue = new NeighborQueue(centroidQueryScorer.size(), true);
299-
centroidsRead++;
300-
centroidsRead += updateCentroidQueueWNextParent(
295+
updateCentroidQueueWNextParent(
301296
parentCentroidQueryScorer,
302297
parentCentroidQueue,
303298
centroidQueryScorer,
@@ -317,9 +312,12 @@ public final void search(String field, float[] target, KnnCollector knnCollector
317312
++centroidsVisited;
318313
float centroidScore = centroidQueue.topScore();
319314
// the next parent likely contains centroids we need to evaluate prior to evaluating this next centroid
320-
while (parentCentroidQueue.size() > 0 && centroidScore < nextParentScore) {
321-
centroidsRead++;
322-
centroidsRead += updateCentroidQueueWNextParent(
315+
// TODO: for each parent centroid we could store the furthest centroid distance from that parent and then the comparison here
316+
// ... would be centroidScore < (nextParentScore + furthestCentroidScore) which is better than just a buffer
317+
// TODO: try a ParentNProbe here that's for instance the sqrt(nProbe) that forces a fixed
318+
// ... number of parents to be explored at each step
319+
while (parentCentroidQueue.size() > 0 && centroidScore < (nextParentScore + nextParentScore * 0.05) ) {
320+
updateCentroidQueueWNextParent(
323321
parentCentroidQueryScorer,
324322
parentCentroidQueue,
325323
centroidQueryScorer,
@@ -350,11 +348,6 @@ public final void search(String field, float[] target, KnnCollector knnCollector
350348
}
351349
}
352350
}
353-
354-
if (ii == 1999) {
355-
System.out.println("total centroids (parent & child) read:" + (centroidsRead / (ii + 1)));
356-
}
357-
ii++;
358351
}
359352

360353
private static int updateCentroidQueueWNextParent(

server/src/test/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriterTests.java

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,47 +9,39 @@
99

1010
package org.elasticsearch.index.codec.vectors;
1111

12+
import com.carrotsearch.hppc.IntIntHashMap;
13+
import com.carrotsearch.hppc.IntIntMap;
1214
import org.elasticsearch.test.ESTestCase;
1315

14-
import java.util.Arrays;
15-
1616
import static org.elasticsearch.index.codec.vectors.DefaultIVFVectorsWriter.mapAssignmentsByCluster;
1717

1818
public class DefaultIVFVectorsWriterTests extends ESTestCase {
1919

20-
// FIXME: clean this up
2120
public void testAssignmentsByCluster() {
2221

2322
// the assignments represent where vectors (by index) got assigned
2423
int[] assignments = new int[] { 3, 3, 2, 1, 1, 1, 1, 0, 2, 3, 4, 4, 2, 2, 2 };
2524
int[] soarAssignments = new int[] { 0, 0, 0, 1, 2, 3, 2, 1, 1, 1, 4, 4, 1, 2, 3 };
2625

27-
assert assignments.length == soarAssignments.length;
26+
assertEquals(assignments.length, soarAssignments.length);
2827

2928
// 0, 1, 2, 3, 4
3029
// these subsequently get sorted in the order in which the centroids would have been sorted in logic
3130
int[] centroidOrds = new int[] { 3, 2, 4, 0, 1 };
3231

33-
int[][] assignmentsByCluster = mapAssignmentsByCluster(centroidOrds.length, assignments, soarAssignments, centroidOrds);
34-
35-
/*
36-
* correct answer
37-
* [0, 1, 9, 5, 14]
38-
* [2, 8, 12, 13, 14, 4, 6, 13]
39-
* [10, 11, 10, 11]
40-
* [7, 0, 1, 2]
41-
* [3, 4, 5, 6, 3, 7, 8, 9, 12]
42-
*/
43-
for (int i = 0; i < assignmentsByCluster.length; i++) {
44-
System.out.println(Arrays.toString(assignmentsByCluster[i]));
32+
IntIntMap centroidOrdsToIdx = new IntIntHashMap(centroidOrds.length);
33+
for(int i = 0; i < centroidOrds.length; i++) {
34+
// idx 0 1 2 3 4
35+
// ord 3 2 4 0 1
36+
centroidOrdsToIdx.put(centroidOrds[i], i);
4537
}
4638

47-
/*
48-
* [0, 1, 5, 9, 14]
49-
* [2, 4, 6, 8, 12, 13, 13, 14]
50-
* [10, 10, 11, 11]
51-
* [0, 1, 2, 7]
52-
* [3, 3, 4, 5, 6, 7, 8, 9, 12]
53-
*/
39+
int[][] assignmentsByCluster = mapAssignmentsByCluster(centroidOrds.length, assignments, soarAssignments, centroidOrdsToIdx);
40+
41+
assertArrayEquals(assignmentsByCluster[0], new int[] {0, 1, 5, 9, 14});
42+
assertArrayEquals(assignmentsByCluster[1], new int[] {2, 4, 6, 8, 12, 13, 13, 14});
43+
assertArrayEquals(assignmentsByCluster[2], new int[] {10, 10, 11, 11});
44+
assertArrayEquals(assignmentsByCluster[3], new int[] {0, 1, 2, 7});
45+
assertArrayEquals(assignmentsByCluster[4], new int[] {3, 3, 4, 5, 6, 7, 8, 9, 12});
5446
}
5547
}

0 commit comments

Comments
 (0)