Skip to content

Commit 8a29de1

Browse files
committed
Prefetch vectors during merge
1 parent 8ca0947 commit 8a29de1

File tree

8 files changed

+203
-80
lines changed

8 files changed

+203
-80
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.apache.lucene.util.packed.PackedLongValues;
2626
import org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeans;
2727
import org.elasticsearch.index.codec.vectors.cluster.KMeansResult;
28+
import org.elasticsearch.index.codec.vectors.cluster.PrefetchingFloatVectorValues;
2829
import org.elasticsearch.logging.LogManager;
2930
import org.elasticsearch.logging.Logger;
3031
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
@@ -63,7 +64,7 @@ public DefaultIVFVectorsWriter(
6364
LongValues buildAndWritePostingsLists(
6465
FieldInfo fieldInfo,
6566
CentroidSupplier centroidSupplier,
66-
FloatVectorValues floatVectorValues,
67+
PrefetchingFloatVectorValues floatVectorValues,
6768
IndexOutput postingsOutput,
6869
long fileOffset,
6970
int[] assignments,
@@ -155,7 +156,7 @@ LongValues buildAndWritePostingsLists(
155156
LongValues buildAndWritePostingsLists(
156157
FieldInfo fieldInfo,
157158
CentroidSupplier centroidSupplier,
158-
FloatVectorValues floatVectorValues,
159+
PrefetchingFloatVectorValues floatVectorValues,
159160
IndexOutput postingsOutput,
160161
long fileOffset,
161162
MergeState mergeState,
@@ -426,7 +427,7 @@ private void writeCentroidsWithoutParents(
426427
private record CentroidGroups(float[][] centroids, int[][] vectors, int maxVectorsPerCentroidLength) {}
427428

428429
private CentroidGroups buildCentroidGroups(FieldInfo fieldInfo, CentroidSupplier centroidSupplier) throws IOException {
429-
final FloatVectorValues floatVectorValues = FloatVectorValues.fromFloats(new AbstractList<>() {
430+
final PrefetchingFloatVectorValues floatVectorValues = PrefetchingFloatVectorValues.floats(new AbstractList<>() {
430431
@Override
431432
public float[] get(int index) {
432433
try {
@@ -479,7 +480,7 @@ public int size() {
479480
* @throws IOException if an I/O error occurs
480481
*/
481482
@Override
482-
CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, float[] globalCentroid)
483+
CentroidAssignments calculateCentroids(FieldInfo fieldInfo, PrefetchingFloatVectorValues floatVectorValues, float[] globalCentroid)
483484
throws IOException {
484485

485486
long nanoTime = System.nanoTime();
@@ -506,7 +507,8 @@ CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues fl
506507
return centroidAssignments;
507508
}
508509

509-
static CentroidAssignments buildCentroidAssignments(FloatVectorValues floatVectorValues, int vectorPerCluster) throws IOException {
510+
static CentroidAssignments buildCentroidAssignments(PrefetchingFloatVectorValues floatVectorValues, int vectorPerCluster)
511+
throws IOException {
510512
KMeansResult kMeansResult = new HierarchicalKMeans(floatVectorValues.dimension()).cluster(floatVectorValues, vectorPerCluster);
511513
float[][] centroids = kMeansResult.centroids();
512514
int[] assignments = kMeansResult.assignments();

server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java

Lines changed: 34 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.apache.lucene.util.VectorUtil;
3333
import org.elasticsearch.core.IOUtils;
3434
import org.elasticsearch.core.SuppressForbidden;
35+
import org.elasticsearch.index.codec.vectors.cluster.PrefetchingFloatVectorValues;
3536

3637
import java.io.IOException;
3738
import java.io.UncheckedIOException;
@@ -120,8 +121,11 @@ public final KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOExc
120121
return rawVectorDelegate;
121122
}
122123

123-
abstract CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, float[] globalCentroid)
124-
throws IOException;
124+
abstract CentroidAssignments calculateCentroids(
125+
FieldInfo fieldInfo,
126+
PrefetchingFloatVectorValues floatVectorValues,
127+
float[] globalCentroid
128+
) throws IOException;
125129

126130
abstract void writeCentroids(
127131
FieldInfo fieldInfo,
@@ -134,7 +138,7 @@ abstract void writeCentroids(
134138
abstract LongValues buildAndWritePostingsLists(
135139
FieldInfo fieldInfo,
136140
CentroidSupplier centroidSupplier,
137-
FloatVectorValues floatVectorValues,
141+
PrefetchingFloatVectorValues floatVectorValues,
138142
IndexOutput postingsOutput,
139143
long fileOffset,
140144
int[] assignments,
@@ -144,7 +148,7 @@ abstract LongValues buildAndWritePostingsLists(
144148
abstract LongValues buildAndWritePostingsLists(
145149
FieldInfo fieldInfo,
146150
CentroidSupplier centroidSupplier,
147-
FloatVectorValues floatVectorValues,
151+
PrefetchingFloatVectorValues floatVectorValues,
148152
IndexOutput postingsOutput,
149153
long fileOffset,
150154
MergeState mergeState,
@@ -165,7 +169,11 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
165169
for (FieldWriter fieldWriter : fieldWriters) {
166170
final float[] globalCentroid = new float[fieldWriter.fieldInfo.getVectorDimension()];
167171
// build a float vector values with random access
168-
final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldWriter.fieldInfo, fieldWriter.delegate, maxDoc);
172+
final PrefetchingFloatVectorValues floatVectorValues = getFloatVectorValues(
173+
fieldWriter.fieldInfo,
174+
fieldWriter.delegate,
175+
maxDoc
176+
);
169177
// build centroids
170178
final CentroidAssignments centroidAssignments = calculateCentroids(fieldWriter.fieldInfo, floatVectorValues, globalCentroid);
171179
// wrap centroids with a supplier
@@ -199,47 +207,22 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
199207
}
200208
}
201209

202-
private static FloatVectorValues getFloatVectorValues(
210+
private static PrefetchingFloatVectorValues getFloatVectorValues(
203211
FieldInfo fieldInfo,
204212
FlatFieldVectorsWriter<float[]> fieldVectorsWriter,
205213
int maxDoc
206214
) throws IOException {
207215
List<float[]> vectors = fieldVectorsWriter.getVectors();
208216
if (vectors.size() == maxDoc) {
209-
return FloatVectorValues.fromFloats(vectors, fieldInfo.getVectorDimension());
217+
return PrefetchingFloatVectorValues.floats(vectors, fieldInfo.getVectorDimension());
210218
}
211219
final DocIdSetIterator iterator = fieldVectorsWriter.getDocsWithFieldSet().iterator();
212220
final int[] docIds = new int[vectors.size()];
213221
for (int i = 0; i < docIds.length; i++) {
214222
docIds[i] = iterator.nextDoc();
215223
}
216224
assert iterator.nextDoc() == NO_MORE_DOCS;
217-
return new FloatVectorValues() {
218-
@Override
219-
public float[] vectorValue(int ord) {
220-
return vectors.get(ord);
221-
}
222-
223-
@Override
224-
public FloatVectorValues copy() {
225-
return this;
226-
}
227-
228-
@Override
229-
public int dimension() {
230-
return fieldInfo.getVectorDimension();
231-
}
232-
233-
@Override
234-
public int size() {
235-
return vectors.size();
236-
}
237-
238-
@Override
239-
public int ordToDoc(int ord) {
240-
return docIds[ord];
241-
}
242-
};
225+
return PrefetchingFloatVectorValues.floats(vectors, fieldInfo.getVectorDimension(), docIds);
243226
}
244227

245228
@Override
@@ -297,7 +280,7 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
297280
IndexInput vectors = mergeState.segmentInfo.dir.openInput(tempRawVectorsFileName, IOContext.DEFAULT);
298281
IndexInput docs = docsFileName == null ? null : mergeState.segmentInfo.dir.openInput(docsFileName, IOContext.DEFAULT)
299282
) {
300-
final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, docs, vectors, numVectors);
283+
final PrefetchingFloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, docs, vectors, numVectors);
301284

302285
final long centroidOffset;
303286
final long centroidLength;
@@ -396,15 +379,26 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
396379
}
397380
}
398381

399-
private static FloatVectorValues getFloatVectorValues(FieldInfo fieldInfo, IndexInput docs, IndexInput vectors, int numVectors)
400-
throws IOException {
382+
private static PrefetchingFloatVectorValues getFloatVectorValues(
383+
FieldInfo fieldInfo,
384+
IndexInput docs,
385+
IndexInput vectors,
386+
int numVectors
387+
) throws IOException {
401388
if (numVectors == 0) {
402-
return FloatVectorValues.fromFloats(List.of(), fieldInfo.getVectorDimension());
389+
return PrefetchingFloatVectorValues.floats(List.of(), fieldInfo.getVectorDimension());
403390
}
404391
final long vectorLength = (long) Float.BYTES * fieldInfo.getVectorDimension();
405392
final float[] vector = new float[fieldInfo.getVectorDimension()];
406393
final RandomAccessInput randomDocs = docs == null ? null : docs.randomAccessSlice(0, docs.length());
407-
return new FloatVectorValues() {
394+
return new PrefetchingFloatVectorValues() {
395+
@Override
396+
public void prefetch(int... ord) throws IOException {
397+
for (int o : ord) {
398+
vectors.prefetch(o * vectorLength, vectorLength);
399+
}
400+
}
401+
408402
@Override
409403
public float[] vectorValue(int ord) throws IOException {
410404
vectors.seek(ord * vectorLength);
@@ -413,7 +407,8 @@ public float[] vectorValue(int ord) throws IOException {
413407
}
414408

415409
@Override
416-
public FloatVectorValues copy() {
410+
public PrefetchingFloatVectorValues copy() {
411+
assert false;
417412
return this;
418413
}
419414

server/src/main/java/org/elasticsearch/index/codec/vectors/SampleReader.java

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,21 @@
2020
package org.elasticsearch.index.codec.vectors;
2121

2222
import org.apache.lucene.codecs.lucene95.HasIndexSlice;
23-
import org.apache.lucene.index.FloatVectorValues;
2423
import org.apache.lucene.store.IndexInput;
2524
import org.apache.lucene.util.Bits;
25+
import org.elasticsearch.index.codec.vectors.cluster.PrefetchingFloatVectorValues;
2626

2727
import java.io.IOException;
2828
import java.util.Arrays;
2929
import java.util.Random;
3030
import java.util.function.IntUnaryOperator;
3131

32-
public class SampleReader extends FloatVectorValues implements HasIndexSlice {
33-
private final FloatVectorValues origin;
32+
public class SampleReader extends PrefetchingFloatVectorValues implements HasIndexSlice {
33+
private final PrefetchingFloatVectorValues origin;
3434
private final int sampleSize;
3535
private final IntUnaryOperator sampleFunction;
3636

37-
SampleReader(FloatVectorValues origin, int sampleSize, IntUnaryOperator sampleFunction) {
37+
SampleReader(PrefetchingFloatVectorValues origin, int sampleSize, IntUnaryOperator sampleFunction) {
3838
this.origin = origin;
3939
this.sampleSize = sampleSize;
4040
this.sampleFunction = sampleFunction;
@@ -51,7 +51,14 @@ public int dimension() {
5151
}
5252

5353
@Override
54-
public FloatVectorValues copy() throws IOException {
54+
public void prefetch(int... ord) throws IOException {
55+
for (int o : ord) {
56+
origin.prefetch(sampleFunction.applyAsInt(o));
57+
}
58+
}
59+
60+
@Override
61+
public PrefetchingFloatVectorValues copy() throws IOException {
5562
throw new IllegalStateException("Not supported");
5663
}
5764

@@ -81,7 +88,7 @@ public Bits getAcceptOrds(Bits acceptDocs) {
8188
throw new IllegalStateException("Not supported");
8289
}
8390

84-
public static SampleReader createSampleReader(FloatVectorValues origin, int k, long seed) {
91+
public static SampleReader createSampleReader(PrefetchingFloatVectorValues origin, int k, long seed) {
8592
// TODO can we do something algorithmically that aligns an ordinal with a unique integer between 0 and numVectors?
8693
if (k >= origin.size()) {
8794
new SampleReader(origin, origin.size(), i -> i);

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/FloatVectorValuesSlice.java

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,14 @@
99

1010
package org.elasticsearch.index.codec.vectors.cluster;
1111

12-
import org.apache.lucene.index.FloatVectorValues;
13-
1412
import java.io.IOException;
1513

16-
class FloatVectorValuesSlice extends FloatVectorValues {
14+
class FloatVectorValuesSlice extends PrefetchingFloatVectorValues {
1715

18-
private final FloatVectorValues allValues;
16+
private final PrefetchingFloatVectorValues allValues;
1917
private final int[] slice;
2018

21-
FloatVectorValuesSlice(FloatVectorValues allValues, int[] slice) {
19+
FloatVectorValuesSlice(PrefetchingFloatVectorValues allValues, int[] slice) {
2220
assert slice != null;
2321
assert slice.length <= allValues.size();
2422
this.allValues = allValues;
@@ -46,7 +44,14 @@ public int ordToDoc(int ord) {
4644
}
4745

4846
@Override
49-
public FloatVectorValues copy() throws IOException {
47+
public PrefetchingFloatVectorValues copy() throws IOException {
5048
return new FloatVectorValuesSlice(this.allValues.copy(), this.slice);
5149
}
50+
51+
@Override
52+
public void prefetch(int... ord) throws IOException {
53+
for (int o : ord) {
54+
this.allValues.prefetch(slice[o]);
55+
}
56+
}
5257
}

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99

1010
package org.elasticsearch.index.codec.vectors.cluster;
1111

12-
import org.apache.lucene.index.FloatVectorValues;
13-
1412
import java.io.IOException;
1513
import java.util.Arrays;
1614
import java.util.Objects;
@@ -52,7 +50,7 @@ public HierarchicalKMeans(int dimension, int maxIterations, int samplesPerCluste
5250
* @return the centroids and the vectors assignments and SOAR (spilled from nearby neighborhoods) assignments
5351
* @throws IOException is thrown if vectors is inaccessible
5452
*/
55-
public KMeansResult cluster(FloatVectorValues vectors, int targetSize) throws IOException {
53+
public KMeansResult cluster(PrefetchingFloatVectorValues vectors, int targetSize) throws IOException {
5654

5755
if (vectors.size() == 0) {
5856
return new KMeansIntermediate();
@@ -86,7 +84,7 @@ public KMeansResult cluster(FloatVectorValues vectors, int targetSize) throws IO
8684
return kMeansIntermediate;
8785
}
8886

89-
KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int targetSize) throws IOException {
87+
KMeansIntermediate clusterAndSplit(final PrefetchingFloatVectorValues vectors, final int targetSize) throws IOException {
9088
if (vectors.size() <= targetSize) {
9189
return new KMeansIntermediate();
9290
}
@@ -132,7 +130,7 @@ KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int ta
132130
final int count = centroidVectorCount[c];
133131
final int adjustedCentroid = c - removedElements;
134132
if (100 * count > 134 * targetSize) {
135-
final FloatVectorValues sample = createClusterSlice(count, adjustedCentroid, vectors, assignments);
133+
final PrefetchingFloatVectorValues sample = createClusterSlice(count, adjustedCentroid, vectors, assignments);
136134
// TODO: consider iterative here instead of recursive
137135
// recursive call to build out the sub partitions around this centroid c
138136
// subsequently reconcile and flatten the space of all centroids and assignments into one structure we can return
@@ -163,7 +161,12 @@ KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int ta
163161
return kMeansIntermediate;
164162
}
165163

166-
static FloatVectorValues createClusterSlice(int clusterSize, int cluster, FloatVectorValues vectors, int[] assignments) {
164+
static PrefetchingFloatVectorValues createClusterSlice(
165+
int clusterSize,
166+
int cluster,
167+
PrefetchingFloatVectorValues vectors,
168+
int[] assignments
169+
) {
167170
int[] slice = new int[clusterSize];
168171
int idx = 0;
169172
for (int i = 0; i < assignments.length; i++) {

0 commit comments

Comments
 (0)