Skip to content

Commit 046c6af

Browse files
committed
Add DirectIO bulk rescoring
1 parent 3e7e63e commit 046c6af

File tree

4 files changed

+226
-6
lines changed

4 files changed

+226
-6
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/BulkScorableFloatVectorValues.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ public interface BulkScorableFloatVectorValues extends BulkScorableVectorValues
1919
* Returns a {@link BulkVectorScorer} that can score against the provided {@code target} vector.
2020
* It scores as fast as possible, potentially sacrificing some fidelity.
2121
*/
22-
BulkVectorScorer scorer(float[] target) throws IOException;
22+
BulkVectorScorer bulkScorer(float[] target) throws IOException;
2323

2424
/**
2525
* Returns a {@link BulkVectorScorer} that can rescore against the provided {@code target} vector.
2626
* It will score to the highest fidelity possible, potentially sacrificing some speed.
2727
*/
28-
BulkVectorScorer rescorer(float[] target) throws IOException;
28+
BulkVectorScorer bulkRescorer(float[] target) throws IOException;
2929
}

server/src/main/java/org/elasticsearch/index/codec/vectors/BulkScorableVectorValues.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,8 @@ interface BulkVectorScorer extends VectorScorer {
3030
interface Bulk {
3131
/**
3232
* Scores up to {@code nextCount} docs, writing their doc IDs and scores into the provided {@code buffer}.
33-
* Returns the maxScore of docs scored.
3433
*/
35-
float nextDocsAndScores(int nextCount, Bits liveDocs, DocAndFloatFeatureBuffer buffer) throws IOException;
34+
void nextDocsAndScores(int nextCount, Bits liveDocs, DocAndFloatFeatureBuffer buffer) throws IOException;
3635
}
3736
}
3837
}

server/src/main/java/org/elasticsearch/index/codec/vectors/es818/DirectIOLucene99FlatVectorsFormat.java

Lines changed: 222 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,35 @@
2222
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
2323
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
2424
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
25+
import org.apache.lucene.codecs.lucene95.HasIndexSlice;
2526
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsReader;
2627
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsWriter;
28+
import org.apache.lucene.index.ByteVectorValues;
29+
import org.apache.lucene.index.FieldInfo;
30+
import org.apache.lucene.index.FloatVectorValues;
31+
import org.apache.lucene.index.KnnVectorValues;
2732
import org.apache.lucene.index.SegmentReadState;
2833
import org.apache.lucene.index.SegmentWriteState;
34+
import org.apache.lucene.index.VectorSimilarityFunction;
35+
import org.apache.lucene.search.ConjunctionUtils;
36+
import org.apache.lucene.search.DocAndFloatFeatureBuffer;
37+
import org.apache.lucene.search.DocIdSetIterator;
38+
import org.apache.lucene.search.VectorScorer;
2939
import org.apache.lucene.store.FlushInfo;
3040
import org.apache.lucene.store.IOContext;
41+
import org.apache.lucene.store.IndexInput;
3142
import org.apache.lucene.store.MergeInfo;
43+
import org.apache.lucene.util.Bits;
44+
import org.apache.lucene.util.hnsw.RandomVectorScorer;
3245
import org.elasticsearch.common.util.set.Sets;
3346
import org.elasticsearch.index.codec.vectors.AbstractFlatVectorsFormat;
47+
import org.elasticsearch.index.codec.vectors.BulkScorableFloatVectorValues;
48+
import org.elasticsearch.index.codec.vectors.BulkScorableVectorValues;
3449
import org.elasticsearch.index.codec.vectors.MergeReaderWrapper;
3550
import org.elasticsearch.index.store.FsDirectoryFactory;
3651

3752
import java.io.IOException;
53+
import java.util.List;
3854
import java.util.Set;
3955

4056
/**
@@ -86,8 +102,13 @@ public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException
86102
);
87103
// Use mmap for merges and direct I/O for searches.
88104
// TODO: Open the mmap file with sequential access instead of random (current behavior).
105+
// TODO: maybe we should force completely RANDOM access always for inner reader formats (outside of merges)?
89106
return new MergeReaderWrapper(
90-
new Lucene99FlatVectorsReader(directIOState, vectorsScorer),
107+
new Lucene99FlatBulkScoringVectorsReader(
108+
directIOState,
109+
new Lucene99FlatVectorsReader(directIOState, vectorsScorer),
110+
vectorsScorer
111+
),
91112
new Lucene99FlatVectorsReader(state, vectorsScorer)
92113
);
93114
} else {
@@ -129,4 +150,204 @@ public IOContext withHints(FileOpenHint... hints) {
129150
return new DirectIOContext(Set.of(hints));
130151
}
131152
}
153+
154+
static class Lucene99FlatBulkScoringVectorsReader extends FlatVectorsReader {
155+
private final Lucene99FlatVectorsReader inner;
156+
private final SegmentReadState state;
157+
158+
Lucene99FlatBulkScoringVectorsReader(SegmentReadState state, Lucene99FlatVectorsReader inner, FlatVectorsScorer scorer) {
159+
super(scorer);
160+
this.inner = inner;
161+
this.state = state;
162+
}
163+
164+
@Override
165+
public void close() throws IOException {
166+
inner.close();
167+
}
168+
169+
@Override
170+
public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
171+
return inner.getRandomVectorScorer(field, target);
172+
}
173+
174+
@Override
175+
public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException {
176+
return inner.getRandomVectorScorer(field, target);
177+
}
178+
179+
@Override
180+
public void checkIntegrity() throws IOException {
181+
inner.checkIntegrity();
182+
}
183+
184+
@Override
185+
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
186+
FloatVectorValues vectorValues = inner.getFloatVectorValues(field);
187+
if (vectorValues == null || vectorValues.size() == 0) {
188+
return null;
189+
}
190+
FieldInfo info = state.fieldInfos.fieldInfo(field);
191+
return new RescorerOffHeapVectorValues(vectorValues, info.getVectorSimilarityFunction(), vectorScorer);
192+
}
193+
194+
@Override
195+
public ByteVectorValues getByteVectorValues(String field) throws IOException {
196+
return inner.getByteVectorValues(field);
197+
}
198+
199+
@Override
200+
public long ramBytesUsed() {
201+
return inner.ramBytesUsed();
202+
}
203+
204+
static class RescorerOffHeapVectorValues extends FloatVectorValues implements BulkScorableFloatVectorValues {
205+
VectorSimilarityFunction similarityFunction;
206+
FloatVectorValues inner;
207+
IndexInput inputSlice;
208+
FlatVectorsScorer scorer;
209+
210+
RescorerOffHeapVectorValues(FloatVectorValues inner, VectorSimilarityFunction similarityFunction, FlatVectorsScorer scorer) {
211+
this.inner = inner;
212+
if (inner instanceof HasIndexSlice slice) {
213+
this.inputSlice = slice.getSlice();
214+
} else {
215+
this.inputSlice = null;
216+
}
217+
this.similarityFunction = similarityFunction;
218+
this.scorer = scorer;
219+
}
220+
221+
@Override
222+
public float[] vectorValue(int ord) throws IOException {
223+
return inner.vectorValue(ord);
224+
}
225+
226+
@Override
227+
public int dimension() {
228+
return inner.dimension();
229+
}
230+
231+
@Override
232+
public int size() {
233+
return inner.size();
234+
}
235+
236+
@Override
237+
public RescorerOffHeapVectorValues copy() throws IOException {
238+
return new RescorerOffHeapVectorValues(inner.copy(), similarityFunction, scorer);
239+
}
240+
241+
@Override
242+
public BulkVectorScorer bulkRescorer(float[] target) throws IOException {
243+
return bulkScorer(target);
244+
}
245+
246+
@Override
247+
public BulkVectorScorer bulkScorer(float[] target) throws IOException {
248+
DocIndexIterator indexIterator = inner.iterator();
249+
RandomVectorScorer randomScorer = scorer.getRandomVectorScorer(similarityFunction, inner, target);
250+
return new PreFetchingFloatBulkScorer(randomScorer, indexIterator, inputSlice, dimension() * Float.BYTES);
251+
}
252+
253+
@Override
254+
public VectorScorer scorer(float[] target) throws IOException {
255+
return inner.scorer(target);
256+
}
257+
}
258+
259+
private record PreFetchingFloatBulkScorer(
260+
RandomVectorScorer inner,
261+
KnnVectorValues.DocIndexIterator indexIterator,
262+
IndexInput inputSlice,
263+
int byteSize
264+
) implements BulkScorableVectorValues.BulkVectorScorer {
265+
266+
@Override
267+
public float score() throws IOException {
268+
return inner.score(indexIterator.index());
269+
}
270+
271+
@Override
272+
public DocIdSetIterator iterator() {
273+
return indexIterator;
274+
}
275+
276+
@Override
277+
public Bulk bulk(DocIdSetIterator matchingDocs) throws IOException {
278+
DocIdSetIterator conjunctionScorer = matchingDocs == null
279+
? indexIterator
280+
: ConjunctionUtils.intersectIterators(List.of(matchingDocs, indexIterator));
281+
if (conjunctionScorer.docID() == -1) {
282+
conjunctionScorer.nextDoc();
283+
}
284+
return new FloatBulkScorer(inner, inputSlice, byteSize, 32, indexIterator, conjunctionScorer);
285+
}
286+
}
287+
288+
private static class FloatBulkScorer implements BulkScorableVectorValues.BulkVectorScorer.Bulk {
289+
private final KnnVectorValues.DocIndexIterator indexIterator;
290+
private final DocIdSetIterator matchingDocs;
291+
private final RandomVectorScorer inner;
292+
private final int bulkSize;
293+
private final IndexInput inputSlice;
294+
private final int byteSize;
295+
private final int[] docBuffer;
296+
private final float[] scoreBuffer;
297+
298+
FloatBulkScorer(
299+
RandomVectorScorer fvv,
300+
IndexInput inputSlice,
301+
int byteSize,
302+
int bulkSize,
303+
KnnVectorValues.DocIndexIterator iterator,
304+
DocIdSetIterator matchingDocs
305+
) {
306+
this.indexIterator = iterator;
307+
this.matchingDocs = matchingDocs;
308+
this.inner = fvv;
309+
this.bulkSize = bulkSize;
310+
this.inputSlice = inputSlice;
311+
this.docBuffer = new int[bulkSize];
312+
this.scoreBuffer = new float[bulkSize];
313+
this.byteSize = byteSize;
314+
}
315+
316+
@Override
317+
public void nextDocsAndScores(int nextCount, Bits liveDocs, DocAndFloatFeatureBuffer buffer) throws IOException {
318+
buffer.growNoCopy(nextCount);
319+
int size = 0;
320+
for (int doc = matchingDocs.docID(); doc != DocIdSetIterator.NO_MORE_DOCS && size < nextCount; doc = matchingDocs
321+
.nextDoc()) {
322+
if (liveDocs == null || liveDocs.get(doc)) {
323+
buffer.docs[size++] = indexIterator.index();
324+
}
325+
}
326+
int loopBound = size - (size % bulkSize);
327+
int i = 0;
328+
for (; i < loopBound; i += bulkSize) {
329+
for (int j = 0; j < bulkSize; j++) {
330+
long ord = buffer.docs[i + j];
331+
inputSlice.prefetch(ord * byteSize, byteSize);
332+
}
333+
System.arraycopy(buffer.docs, i, docBuffer, 0, bulkSize);
334+
inner.bulkScore(docBuffer, scoreBuffer, bulkSize);
335+
System.arraycopy(scoreBuffer, 0, buffer.features, i, bulkSize);
336+
}
337+
int countLeft = size - i;
338+
for (int j = i; j < size; j++) {
339+
long ord = buffer.docs[j];
340+
inputSlice.prefetch(ord * byteSize, byteSize);
341+
}
342+
System.arraycopy(buffer.docs, i, docBuffer, 0, countLeft);
343+
inner.bulkScore(docBuffer, scoreBuffer, countLeft);
344+
System.arraycopy(scoreBuffer, 0, buffer.features, i, countLeft);
345+
buffer.size = size;
346+
// fix the docIds in buffer
347+
for (int j = 0; j < size; j++) {
348+
buffer.docs[j] = inner.ordToDoc(buffer.docs[j]);
349+
}
350+
}
351+
}
352+
}
132353
}

server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ private void rescoreBulk(
317317
List<ScoreDoc> queue,
318318
DocIdSetIterator filterIterator
319319
) throws IOException {
320-
BulkScorableVectorValues.BulkVectorScorer vectorReScorer = rescorableVectorValues.rescorer(floatTarget);
320+
BulkScorableVectorValues.BulkVectorScorer vectorReScorer = rescorableVectorValues.bulkRescorer(floatTarget);
321321
var iterator = vectorReScorer.iterator();
322322
BulkScorableVectorValues.BulkVectorScorer.Bulk bulkScorer = vectorReScorer.bulk(filterIterator);
323323
DocAndFloatFeatureBuffer buffer = new DocAndFloatFeatureBuffer();

0 commit comments

Comments
 (0)