Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .idea/inspectionProfiles/Project_Default.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions docs/changelog/135380.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 135380
summary: Add DirectIO bulk rescoring
area: Vector Search
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ interface BulkVectorScorer extends VectorScorer {
interface BulkScorer {
/**
* Scores up to {@code nextCount} docs in the provided {@code buffer}.
* The method returns nothing; results are delivered through {@code buffer}.
*
* @param nextCount upper bound on the number of docs to score in this call
* @param liveDocs live-docs filter applied while collecting docs
* @param buffer destination for the scored docs and their scores
* @throws IOException if reading the underlying data fails
*/
void nextDocsAndScores(int nextCount, Bits liveDocs, DocAndFloatFeatureBuffer buffer) throws IOException;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,36 @@
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.codecs.lucene95.HasIndexSlice;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsReader;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsWriter;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.ConjunctionUtils;
import org.apache.lucene.search.DocAndFloatFeatureBuffer;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.FlushInfo;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.MergeInfo;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.index.codec.vectors.BulkScorableFloatVectorValues;
import org.elasticsearch.index.codec.vectors.BulkScorableVectorValues;
import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat;
import org.elasticsearch.index.codec.vectors.MergeReaderWrapper;
import org.elasticsearch.index.codec.vectors.es818.DirectIOHint;
import org.elasticsearch.index.store.FsDirectoryFactory;

import java.io.IOException;
import java.util.List;
import java.util.Set;

public class DirectIOCapableLucene99FlatVectorsFormat extends DirectIOCapableFlatVectorsFormat {
Expand Down Expand Up @@ -71,7 +87,11 @@ public FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectI
);
// Use mmap for merges and direct I/O for searches.
return new MergeReaderWrapper(
new Lucene99FlatVectorsReader(directIOState, vectorsScorer),
new Lucene99FlatBulkScoringVectorsReader(
directIOState,
new Lucene99FlatVectorsReader(directIOState, vectorsScorer),
vectorsScorer
),
new Lucene99FlatVectorsReader(state, vectorsScorer)
);
} else {
Expand Down Expand Up @@ -113,4 +133,203 @@ public IOContext withHints(FileOpenHint... hints) {
return new DirectIOContext(Set.of(hints));
}
}

static class Lucene99FlatBulkScoringVectorsReader extends FlatVectorsReader {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So much ceremony required...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@thecoop yes, some of it should go away with Lucene 10.4. The nice thing is that the top-level format name remains unchanged, so it's an easy removal once the new Lucene is released.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a PR for the corresponding changes on lucene_snapshot?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@thecoop no, not yet.

// Wrapped reader that the methods of this class delegate to.
private final Lucene99FlatVectorsReader inner;
// Segment read state; getFloatVectorValues uses it to look up the field's FieldInfo.
private final SegmentReadState state;

/**
* Creates a reader that delegates to {@code inner}.
*
* @param state segment read state, kept for per-field {@link FieldInfo} lookups
* @param inner reader that all calls are forwarded to
* @param scorer flat vectors scorer, passed to the superclass constructor
*/
Lucene99FlatBulkScoringVectorsReader(SegmentReadState state, Lucene99FlatVectorsReader inner, FlatVectorsScorer scorer) {
super(scorer);
this.inner = inner;
this.state = state;
}

/** Closes the wrapped {@code inner} reader. */
@Override
public void close() throws IOException {
inner.close();
}

/** Returns the random vector scorer that the wrapped {@code inner} reader creates for a float {@code target}. */
@Override
public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
return inner.getRandomVectorScorer(field, target);
}

/** Returns the random vector scorer that the wrapped {@code inner} reader creates for a byte {@code target}. */
@Override
public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException {
return inner.getRandomVectorScorer(field, target);
}

/** Forwards the integrity check to the wrapped {@code inner} reader. */
@Override
public void checkIntegrity() throws IOException {
inner.checkIntegrity();
}

/**
 * Returns the float vectors of {@code field} wrapped so that they can be bulk scored.
 *
 * @param field name of the vector field
 * @return a {@link RescorerOffHeapVectorValues} around the wrapped reader's values, or {@code null}
 *         when the wrapped reader returns {@code null} or a zero-sized set of values
 * @throws IOException if the wrapped reader fails to load the values
 */
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
    final FloatVectorValues delegate = inner.getFloatVectorValues(field);
    final boolean hasVectors = delegate != null && delegate.size() != 0;
    if (hasVectors == false) {
        return null;
    }
    // The field info is only looked up once we know there is something to wrap.
    final VectorSimilarityFunction similarity = state.fieldInfos.fieldInfo(field).getVectorSimilarityFunction();
    return new RescorerOffHeapVectorValues(delegate, similarity, vectorScorer);
}

/** Returns the byte vector values of the wrapped {@code inner} reader unchanged (no bulk-scoring wrapper is applied). */
@Override
public ByteVectorValues getByteVectorValues(String field) throws IOException {
return inner.getByteVectorValues(field);
}

/** Reports the value returned by the wrapped {@code inner} reader. */
@Override
public long ramBytesUsed() {
return inner.ramBytesUsed();
}
}

/**
 * {@link FloatVectorValues} wrapper that adds bulk scoring on top of another set of float vector values.
 *
 * <p>Vector access, size, dimension, the ordinal/doc mapping and iteration are all forwarded to the wrapped
 * values. When the wrapped values expose their backing {@link IndexInput} through {@link HasIndexSlice}, that
 * slice is handed to the bulk scorer; otherwise the slice is {@code null}.
 */
static class RescorerOffHeapVectorValues extends FloatVectorValues implements BulkScorableFloatVectorValues {
    private final VectorSimilarityFunction similarityFunction;
    private final FloatVectorValues inner;
    /** Backing slice of {@code inner}, or {@code null} when {@code inner} is not a {@link HasIndexSlice}. */
    private final IndexInput inputSlice;
    private final FlatVectorsScorer scorer;

    /**
     * @param inner values to wrap
     * @param similarityFunction similarity used when creating scorers against {@code inner}
     * @param scorer factory for the random vector scorer used by {@link #bulkScorer(float[])}
     */
    RescorerOffHeapVectorValues(FloatVectorValues inner, VectorSimilarityFunction similarityFunction, FlatVectorsScorer scorer) {
        this.inner = inner;
        this.inputSlice = inner instanceof HasIndexSlice slice ? slice.getSlice() : null;
        this.similarityFunction = similarityFunction;
        this.scorer = scorer;
    }

    @Override
    public float[] vectorValue(int ord) throws IOException {
        return inner.vectorValue(ord);
    }

    @Override
    public int dimension() {
        return inner.dimension();
    }

    @Override
    public int size() {
        return inner.size();
    }

    /**
     * Forwards the ordinal to doc id mapping to the wrapped values, so that a wrapper around values whose
     * ordinals differ from their doc ids reports the same mapping as the values it wraps.
     */
    @Override
    public int ordToDoc(int ord) {
        return inner.ordToDoc(ord);
    }

    /** Forwards the doc-based to ordinal-based accept-bits translation to the wrapped values. */
    @Override
    public Bits getAcceptOrds(Bits acceptDocs) {
        return inner.getAcceptOrds(acceptDocs);
    }

    /** Returns a fresh iterator over the wrapped values, so callers can iterate the wrapper directly. */
    @Override
    public DocIndexIterator iterator() {
        return inner.iterator();
    }

    @Override
    public RescorerOffHeapVectorValues copy() throws IOException {
        return new RescorerOffHeapVectorValues(inner.copy(), similarityFunction, scorer);
    }

    /** Rescoring uses the same scorer as {@link #bulkScorer(float[])}. */
    @Override
    public BulkVectorScorer bulkRescorer(float[] target) throws IOException {
        return bulkScorer(target);
    }

    /**
     * Creates a bulk scorer for {@code target} that iterates the wrapped values and prefetches from the
     * backing slice, one vector being {@code dimension() * Float.BYTES} bytes long.
     */
    @Override
    public BulkVectorScorer bulkScorer(float[] target) throws IOException {
        DocIndexIterator indexIterator = inner.iterator();
        RandomVectorScorer randomScorer = scorer.getRandomVectorScorer(similarityFunction, inner, target);
        return new PreFetchingFloatBulkScorer(randomScorer, indexIterator, inputSlice, dimension() * Float.BYTES);
    }

    @Override
    public VectorScorer scorer(float[] target) throws IOException {
        return inner.scorer(target);
    }
}

/**
 * Bulk vector scorer that scores the vector the index iterator is currently positioned on, and that
 * creates {@link FloatBulkScorer}s working in batches of {@link #BULK_SIZE} vectors.
 *
 * @param inner scorer used to score individual ordinals
 * @param indexIterator iterator over the vector values; also returned by {@link #iterator()}
 * @param inputSlice slice passed on to the {@link FloatBulkScorer}
 * @param byteSize size in bytes of one vector, passed on to the {@link FloatBulkScorer}
 */
private record PreFetchingFloatBulkScorer(
    RandomVectorScorer inner,
    KnnVectorValues.DocIndexIterator indexIterator,
    IndexInput inputSlice,
    int byteSize
) implements BulkScorableVectorValues.BulkVectorScorer {

    /** Number of vectors handed to the underlying scorer per batch. */
    private static final int BULK_SIZE = 32;

    @Override
    public float score() throws IOException {
        final int currentOrd = indexIterator.index();
        return inner.score(currentOrd);
    }

    @Override
    public DocIdSetIterator iterator() {
        return indexIterator;
    }

    /**
     * Builds a bulk scorer over the docs that are in both {@code matchingDocs} and the index iterator,
     * or over the index iterator alone when {@code matchingDocs} is {@code null}.
     */
    @Override
    public BulkScorer bulkScore(DocIdSetIterator matchingDocs) throws IOException {
        final DocIdSetIterator docs;
        if (matchingDocs != null) {
            docs = ConjunctionUtils.intersectIterators(List.of(matchingDocs, indexIterator));
        } else {
            docs = indexIterator;
        }
        // Move an unpositioned iterator onto its first doc before handing it to the bulk scorer.
        if (docs.docID() == -1) {
            docs.nextDoc();
        }
        return new FloatBulkScorer(inner, inputSlice, byteSize, BULK_SIZE, indexIterator, docs);
    }
}

/**
 * Scores matching docs in batches of at most {@code bulkSize} vectors, issuing a prefetch for every vector
 * of a batch before the batch is scored.
 */
private static class FloatBulkScorer implements BulkScorableVectorValues.BulkVectorScorer.BulkScorer {
    private final KnnVectorValues.DocIndexIterator indexIterator;
    private final DocIdSetIterator matchingDocs;
    private final RandomVectorScorer inner;
    private final int bulkSize;
    /** May be {@code null}, in which case no prefetching is done. */
    private final IndexInput inputSlice;
    private final int byteSize;
    // Scratch arrays of length bulkSize handed to the underlying scorer for each batch.
    private final int[] docBuffer;
    private final float[] scoreBuffer;

    /**
     * @param fvv scorer used to score batches of ordinals and to map ordinals to doc ids
     * @param inputSlice slice to prefetch vectors from, or {@code null} to skip prefetching
     * @param byteSize size in bytes of one vector within {@code inputSlice}
     * @param bulkSize maximum number of vectors scored per batch; must be positive
     * @param iterator iterator providing the ordinal of the doc that {@code matchingDocs} is positioned on
     * @param matchingDocs docs to score; expected to be positioned on its first doc already
     * @throws IllegalArgumentException if {@code bulkSize} is not positive
     */
    FloatBulkScorer(
        RandomVectorScorer fvv,
        IndexInput inputSlice,
        int byteSize,
        int bulkSize,
        KnnVectorValues.DocIndexIterator iterator,
        DocIdSetIterator matchingDocs
    ) {
        if (bulkSize <= 0) {
            throw new IllegalArgumentException("bulkSize must be positive, got [" + bulkSize + "]");
        }
        this.indexIterator = iterator;
        this.matchingDocs = matchingDocs;
        this.inner = fvv;
        this.bulkSize = bulkSize;
        this.inputSlice = inputSlice;
        this.docBuffer = new int[bulkSize];
        this.scoreBuffer = new float[bulkSize];
        this.byteSize = byteSize;
    }

    /**
     * Collects up to {@code nextCount} live matching docs, scores them batch by batch and stores the doc
     * ids in {@code buffer.docs}, the scores in {@code buffer.features} and the count in {@code buffer.size}.
     *
     * @param nextCount maximum number of docs to score in this call
     * @param liveDocs live docs filter; {@code null} means every doc is live
     * @param buffer destination buffer, grown to hold {@code nextCount} entries
     * @throws IOException if iterating, prefetching or scoring fails
     */
    @Override
    public void nextDocsAndScores(int nextCount, Bits liveDocs, DocAndFloatFeatureBuffer buffer) throws IOException {
        buffer.growNoCopy(nextCount);
        final int size = collectOrdinals(nextCount, liveDocs, buffer.docs);
        for (int start = 0; start < size; start += bulkSize) {
            // Full batches first; the last batch holds whatever is left.
            final int count = Math.min(bulkSize, size - start);
            prefetch(buffer.docs, start, count);
            System.arraycopy(buffer.docs, start, docBuffer, 0, count);
            inner.bulkScore(docBuffer, scoreBuffer, count);
            System.arraycopy(scoreBuffer, 0, buffer.features, start, count);
        }
        buffer.size = size;
        // buffer.docs held vector ordinals until now; convert them to doc ids.
        for (int j = 0; j < size; j++) {
            buffer.docs[j] = inner.ordToDoc(buffer.docs[j]);
        }
    }

    /**
     * Writes the vector ordinals of up to {@code limit} live matching docs into {@code ords}, leaving
     * {@code matchingDocs} positioned on the first doc that was not consumed.
     */
    private int collectOrdinals(int limit, Bits liveDocs, int[] ords) throws IOException {
        int size = 0;
        for (int doc = matchingDocs.docID(); doc != DocIdSetIterator.NO_MORE_DOCS && size < limit; doc = matchingDocs.nextDoc()) {
            if (liveDocs == null || liveDocs.get(doc)) {
                ords[size++] = indexIterator.index();
            }
        }
        return size;
    }

    /** Issues a prefetch for each of the {@code count} ordinals starting at {@code start}; no-op without a slice. */
    private void prefetch(int[] ords, int start, int count) throws IOException {
        if (inputSlice == null) {
            return;
        }
        for (int j = start; j < start + count; j++) {
            // Widen before multiplying so the byte offset cannot overflow an int.
            final long ord = ords[j];
            inputSlice.prefetch(ord * byteSize, byteSize);
        }
    }
}
}
Loading