Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/135380.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 135380
summary: Add DirectIO bulk rescoring
area: Vector Search
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ interface BulkVectorScorer extends VectorScorer {
interface BulkScorer {
/**
* Scores up to {@code nextCount} docs in the provided {@code buffer}.
* <p>
* The method returns nothing: results are handed back through {@code buffer}
* (the implementation visible in this change fills {@code buffer.docs},
* {@code buffer.features} and {@code buffer.size}).
*
* @param nextCount maximum number of docs to score in this call
* @param liveDocs  live-docs filter; the implementation visible in this change treats
*                  {@code null} as "every doc is live"
* @param buffer    destination for the scored docs and their scores
* @throws IOException if reading the underlying vector data fails
*/
void nextDocsAndScores(int nextCount, Bits liveDocs, DocAndFloatFeatureBuffer buffer) throws IOException;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,36 @@
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.codecs.lucene95.HasIndexSlice;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsReader;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsWriter;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.ConjunctionUtils;
import org.apache.lucene.search.DocAndFloatFeatureBuffer;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.FlushInfo;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.MergeInfo;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.index.codec.vectors.BulkScorableFloatVectorValues;
import org.elasticsearch.index.codec.vectors.BulkScorableVectorValues;
import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat;
import org.elasticsearch.index.codec.vectors.MergeReaderWrapper;
import org.elasticsearch.index.codec.vectors.es818.DirectIOHint;
import org.elasticsearch.index.store.FsDirectoryFactory;

import java.io.IOException;
import java.util.List;
import java.util.Set;

public class DirectIOCapableLucene99FlatVectorsFormat extends DirectIOCapableFlatVectorsFormat {
Expand Down Expand Up @@ -71,7 +87,11 @@ public FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectI
);
// Use mmap for merges and direct I/O for searches.
return new MergeReaderWrapper(
new Lucene99FlatVectorsReader(directIOState, vectorsScorer),
new Lucene99FlatBulkScoringVectorsReader(
directIOState,
new Lucene99FlatVectorsReader(directIOState, vectorsScorer),
vectorsScorer
),
new Lucene99FlatVectorsReader(state, vectorsScorer)
);
} else {
Expand Down Expand Up @@ -113,4 +133,203 @@ public IOContext withHints(FileOpenHint... hints) {
return new DirectIOContext(Set.of(hints));
}
}

static class Lucene99FlatBulkScoringVectorsReader extends FlatVectorsReader {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So much ceremony required...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@thecoop yes, some of it should go away with Lucene 10.4. The nice thing is that the top-level format name remains unchanged, so it's an easy removal once the new Lucene is released.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a PR for the corresponding changes on lucene_snapshot?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@thecoop no, not yet.

private final Lucene99FlatVectorsReader inner;
private final SegmentReadState state;

/**
 * Wraps {@code inner} so that float vector values are exposed through a bulk-scorable view.
 *
 * @param state  segment read state; used later to look up the field's similarity function
 * @param inner  reader that every operation is delegated to
 * @param scorer flat vectors scorer handed to the superclass and to the bulk-scoring wrapper
 */
Lucene99FlatBulkScoringVectorsReader(SegmentReadState state, Lucene99FlatVectorsReader inner, FlatVectorsScorer scorer) {
    super(scorer);
    this.state = state;
    this.inner = inner;
}

/** Closes the wrapped reader; this wrapper holds no resources of its own. */
@Override
public void close() throws IOException {
    this.inner.close();
}

/** Delegates float-target random scoring to the wrapped reader unchanged. */
@Override
public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
    final RandomVectorScorer delegateScorer = this.inner.getRandomVectorScorer(field, target);
    return delegateScorer;
}

/** Delegates byte-target random scoring to the wrapped reader unchanged. */
@Override
public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException {
    final RandomVectorScorer delegateScorer = this.inner.getRandomVectorScorer(field, target);
    return delegateScorer;
}

/** Runs the wrapped reader's integrity check. */
@Override
public void checkIntegrity() throws IOException {
    this.inner.checkIntegrity();
}

/**
 * Returns the field's float vectors wrapped in a {@link RescorerOffHeapVectorValues}.
 *
 * @return the bulk-scorable wrapper, or {@code null} when the wrapped reader has no
 *         (or zero) vectors for {@code field}
 */
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
    final FloatVectorValues delegate = inner.getFloatVectorValues(field);
    final boolean hasVectors = delegate != null && delegate.size() != 0;
    if (hasVectors == false) {
        return null;
    }
    // The similarity function comes from the field metadata of the segment being read.
    final FieldInfo fieldInfo = state.fieldInfos.fieldInfo(field);
    final VectorSimilarityFunction similarity = fieldInfo.getVectorSimilarityFunction();
    return new RescorerOffHeapVectorValues(delegate, similarity, vectorScorer);
}

/** Byte vectors get no bulk-scoring wrapper; the wrapped reader's values are returned as-is. */
@Override
public ByteVectorValues getByteVectorValues(String field) throws IOException {
    final ByteVectorValues delegate = this.inner.getByteVectorValues(field);
    return delegate;
}

/** Reports the wrapped reader's RAM usage; nothing is added for this wrapper itself. */
@Override
public long ramBytesUsed() {
    final long delegateBytes = this.inner.ramBytesUsed();
    return delegateBytes;
}
}

/**
 * {@link FloatVectorValues} decorator that adds bulk (re)scoring on top of an existing
 * vector-values instance. Plain accessors are forwarded to the wrapped instance.
 */
static class RescorerOffHeapVectorValues extends FloatVectorValues implements BulkScorableFloatVectorValues {
    VectorSimilarityFunction similarityFunction;
    FloatVectorValues inner;
    // Slice over the raw vector data; null when the wrapped values do not expose one.
    IndexInput inputSlice;
    FlatVectorsScorer scorer;

    /**
     * @param inner              vector values to wrap
     * @param similarityFunction similarity used when building the random scorer
     * @param scorer             factory for the random scorer used in bulk scoring
     */
    RescorerOffHeapVectorValues(FloatVectorValues inner, VectorSimilarityFunction similarityFunction, FlatVectorsScorer scorer) {
        this.similarityFunction = similarityFunction;
        this.scorer = scorer;
        this.inner = inner;
        // Only values that expose their backing slice can be prefetched from.
        this.inputSlice = inner instanceof HasIndexSlice slice ? slice.getSlice() : null;
    }

    /** Forwards to the wrapped values. */
    @Override
    public float[] vectorValue(int ord) throws IOException {
        return this.inner.vectorValue(ord);
    }

    /** Forwards to the wrapped values. */
    @Override
    public int dimension() {
        return this.inner.dimension();
    }

    /** Forwards to the wrapped values. */
    @Override
    public int size() {
        return this.inner.size();
    }

    /** Copies the wrapped values and re-wraps the copy with the same similarity and scorer. */
    @Override
    public RescorerOffHeapVectorValues copy() throws IOException {
        final FloatVectorValues innerCopy = this.inner.copy();
        return new RescorerOffHeapVectorValues(innerCopy, this.similarityFunction, this.scorer);
    }

    /** Rescoring uses exactly the same scorer as {@link #bulkScorer(float[])}. */
    @Override
    public BulkVectorScorer bulkRescorer(float[] target) throws IOException {
        return this.bulkScorer(target);
    }

    /**
     * Builds a bulk scorer for {@code target} from the wrapped values' iterator and a random
     * scorer, passing along the slice and the per-vector byte length
     * ({@code dimension() * Float.BYTES}).
     */
    @Override
    public BulkVectorScorer bulkScorer(float[] target) throws IOException {
        final DocIndexIterator docIndexIterator = this.inner.iterator();
        final RandomVectorScorer randomVectorScorer = this.scorer.getRandomVectorScorer(this.similarityFunction, this.inner, target);
        final int vectorByteSize = dimension() * Float.BYTES;
        return new PreFetchingFloatBulkScorer(randomVectorScorer, docIndexIterator, this.inputSlice, vectorByteSize);
    }

    /** Forwards to the wrapped values. */
    @Override
    public VectorScorer scorer(float[] target) throws IOException {
        return this.inner.scorer(target);
    }
}

/**
 * Bulk vector scorer that pairs a {@link RandomVectorScorer} with the iterator over the
 * vector values, plus the slice and per-vector byte size handed on to {@link FloatBulkScorer}.
 */
private record PreFetchingFloatBulkScorer(
    RandomVectorScorer inner,
    KnnVectorValues.DocIndexIterator indexIterator,
    IndexInput inputSlice,
    int byteSize
) implements BulkScorableVectorValues.BulkVectorScorer {

    /** Number of vectors scored per batch by the bulk scorer created in {@link #bulkScore}. */
    private static final int BULK_SIZE = 32;

    /** Scores the vector the iterator is currently positioned on. */
    @Override
    public float score() throws IOException {
        final int currentOrd = indexIterator.index();
        return inner.score(currentOrd);
    }

    @Override
    public DocIdSetIterator iterator() {
        return indexIterator;
    }

    /**
     * Creates a bulk scorer over the docs that have a vector and, when {@code matchingDocs}
     * is non-null, also match it. The iterator is advanced to its first doc if unpositioned.
     */
    @Override
    public BulkScorer bulkScore(DocIdSetIterator matchingDocs) throws IOException {
        final DocIdSetIterator docsToScore;
        if (matchingDocs != null) {
            final List<DocIdSetIterator> both = List.of(matchingDocs, indexIterator);
            docsToScore = ConjunctionUtils.intersectIterators(both);
        } else {
            docsToScore = indexIterator;
        }
        final boolean unpositioned = docsToScore.docID() == -1;
        if (unpositioned) {
            docsToScore.nextDoc();
        }
        return new FloatBulkScorer(inner, inputSlice, byteSize, BULK_SIZE, indexIterator, docsToScore);
    }
}

/**
 * Scores matching docs in batches of up to {@code bulkSize} vectors, issuing a prefetch for
 * each vector of a batch (when a slice is available) before the batch is scored.
 */
private static class FloatBulkScorer implements BulkScorableVectorValues.BulkVectorScorer.BulkScorer {
    private final KnnVectorValues.DocIndexIterator indexIterator;
    private final DocIdSetIterator matchingDocs;
    private final RandomVectorScorer inner;
    private final int bulkSize;
    // May be null: the vector values only provide a slice when they implement HasIndexSlice.
    private final IndexInput inputSlice;
    private final int byteSize;
    private final int[] docBuffer;
    private final float[] scoreBuffer;

    /**
     * @param fvv          scorer used to score batches of ordinals and to map ordinals to doc ids
     * @param inputSlice   slice to prefetch vector bytes from, or {@code null} to skip prefetching
     * @param byteSize     number of bytes occupied by one vector in {@code inputSlice}
     * @param bulkSize     maximum number of vectors scored per batch; must be positive
     * @param iterator     iterator supplying the vector ordinal of the current doc
     * @param matchingDocs iterator over the docs to score, already positioned on a doc
     * @throws IllegalArgumentException if {@code bulkSize} is not positive
     */
    FloatBulkScorer(
        RandomVectorScorer fvv,
        IndexInput inputSlice,
        int byteSize,
        int bulkSize,
        KnnVectorValues.DocIndexIterator iterator,
        DocIdSetIterator matchingDocs
    ) {
        if (bulkSize <= 0) {
            throw new IllegalArgumentException("bulkSize must be positive, got: " + bulkSize);
        }
        this.indexIterator = iterator;
        this.matchingDocs = matchingDocs;
        this.inner = fvv;
        this.bulkSize = bulkSize;
        this.inputSlice = inputSlice;
        this.docBuffer = new int[bulkSize];
        this.scoreBuffer = new float[bulkSize];
        this.byteSize = byteSize;
    }

    /**
     * Collects up to {@code nextCount} live matching docs, scores them batch by batch and
     * leaves doc ids in {@code buffer.docs}, scores in {@code buffer.features} and the number
     * of entries in {@code buffer.size}.
     *
     * @param nextCount maximum number of docs to collect
     * @param liveDocs  live-docs filter, or {@code null} to accept every doc
     * @param buffer    output buffer; grown to hold {@code nextCount} entries
     */
    @Override
    public void nextDocsAndScores(int nextCount, Bits liveDocs, DocAndFloatFeatureBuffer buffer) throws IOException {
        buffer.growNoCopy(nextCount);
        int size = 0;
        // buffer.docs temporarily holds vector ordinals; they are mapped to doc ids at the end.
        for (int doc = matchingDocs.docID(); doc != DocIdSetIterator.NO_MORE_DOCS && size < nextCount; doc = matchingDocs.nextDoc()) {
            if (liveDocs == null || liveDocs.get(doc)) {
                buffer.docs[size++] = indexIterator.index();
            }
        }
        // Full batches first, then one shorter batch for the remainder (skipped when empty).
        for (int start = 0; start < size; start += bulkSize) {
            scoreBatch(buffer, start, Math.min(bulkSize, size - start));
        }
        buffer.size = size;
        // fix the docIds in buffer
        for (int j = 0; j < size; j++) {
            buffer.docs[j] = inner.ordToDoc(buffer.docs[j]);
        }
    }

    /** Prefetches (if possible) and scores {@code count} ordinals of {@code buffer.docs} starting at {@code start}. */
    private void scoreBatch(DocAndFloatFeatureBuffer buffer, int start, int count) throws IOException {
        if (inputSlice != null) {
            for (int j = 0; j < count; j++) {
                // widen before multiplying so the byte offset cannot overflow an int
                long ord = buffer.docs[start + j];
                inputSlice.prefetch(ord * byteSize, byteSize);
            }
        }
        System.arraycopy(buffer.docs, start, docBuffer, 0, count);
        inner.bulkScore(docBuffer, scoreBuffer, count);
        System.arraycopy(scoreBuffer, 0, buffer.features, start, count);
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,19 +58,22 @@ public abstract class RescoreKnnVectorQuery extends Query implements QueryProfil
protected final int k;
protected final Query innerQuery;
protected long vectorOperations = 0;
protected final boolean enforceSingleRescore;

private RescoreKnnVectorQuery(
String fieldName,
float[] floatTarget,
VectorSimilarityFunction vectorSimilarityFunction,
int k,
Query innerQuery
Query innerQuery,
boolean enforceSingleRescore
) {
this.fieldName = fieldName;
this.floatTarget = floatTarget;
this.vectorSimilarityFunction = vectorSimilarityFunction;
this.k = k;
this.innerQuery = innerQuery;
this.enforceSingleRescore = enforceSingleRescore;
}

/**
Expand All @@ -95,9 +98,28 @@ public static RescoreKnnVectorQuery fromInnerQuery(
|| (innerQuery instanceof KnnByteVectorQuery bQuery && bQuery.getK() == rescoreK)
|| (innerQuery instanceof AbstractIVFKnnVectorQuery ivfQuery && ivfQuery.k == rescoreK)) {
// Queries that return only the top `k` results and do not require reduction before re-scoring.
return new InlineRescoreQuery(fieldName, floatTarget, vectorSimilarityFunction, k, innerQuery);
return new InlineRescoreQuery(fieldName, floatTarget, vectorSimilarityFunction, k, innerQuery, false);
}
return new LateRescoreQuery(fieldName, floatTarget, vectorSimilarityFunction, k, rescoreK, innerQuery, false);
}

// only used for testing purposes
static RescoreKnnVectorQuery fromInnerQuery(
String fieldName,
float[] floatTarget,
VectorSimilarityFunction vectorSimilarityFunction,
int k,
int rescoreK,
boolean enforceSingleRescore,
Query innerQuery
) {
if ((innerQuery instanceof KnnFloatVectorQuery fQuery && fQuery.getK() == rescoreK)
|| (innerQuery instanceof KnnByteVectorQuery bQuery && bQuery.getK() == rescoreK)
|| (innerQuery instanceof AbstractIVFKnnVectorQuery ivfQuery && ivfQuery.k == rescoreK)) {
// Queries that return only the top `k` results and do not require reduction before re-scoring.
return new InlineRescoreQuery(fieldName, floatTarget, vectorSimilarityFunction, k, innerQuery, enforceSingleRescore);
}
return new LateRescoreQuery(fieldName, floatTarget, vectorSimilarityFunction, k, rescoreK, innerQuery);
return new LateRescoreQuery(fieldName, floatTarget, vectorSimilarityFunction, k, rescoreK, innerQuery, enforceSingleRescore);
}

public Query innerQuery() {
Expand Down Expand Up @@ -164,14 +186,15 @@ private InlineRescoreQuery(
float[] floatTarget,
VectorSimilarityFunction vectorSimilarityFunction,
int k,
Query innerQuery
Query innerQuery,
boolean enforceSingleRescore
) {
super(fieldName, floatTarget, vectorSimilarityFunction, k, innerQuery);
super(fieldName, floatTarget, vectorSimilarityFunction, k, innerQuery, enforceSingleRescore);
}

@Override
public Query rewrite(IndexSearcher searcher) throws IOException {
var rescoreQuery = new DirectRescoreKnnVectorQuery(fieldName, floatTarget, innerQuery);
var rescoreQuery = new DirectRescoreKnnVectorQuery(fieldName, floatTarget, innerQuery, enforceSingleRescore);
var topDocs = searcher.search(rescoreQuery, k);
vectorOperations = topDocs.totalHits.value();
return new KnnScoreDocQuery(topDocs.scoreDocs, searcher.getIndexReader());
Expand Down Expand Up @@ -199,9 +222,10 @@ private LateRescoreQuery(
VectorSimilarityFunction vectorSimilarityFunction,
int k,
int rescoreK,
Query innerQuery
Query innerQuery,
boolean enforceSingleRescore
) {
super(fieldName, floatTarget, vectorSimilarityFunction, k, innerQuery);
super(fieldName, floatTarget, vectorSimilarityFunction, k, innerQuery, enforceSingleRescore);
this.rescoreK = rescoreK;
}

Expand All @@ -214,7 +238,7 @@ public Query rewrite(IndexSearcher searcher) throws IOException {

// Retrieve top `k` documents from the top `rescoreK` query
var topDocsQuery = new KnnScoreDocQuery(topDocs.scoreDocs, searcher.getIndexReader());
var rescoreQuery = new DirectRescoreKnnVectorQuery(fieldName, floatTarget, topDocsQuery);
var rescoreQuery = new DirectRescoreKnnVectorQuery(fieldName, floatTarget, topDocsQuery, enforceSingleRescore);
var rescoreTopDocs = searcher.search(rescoreQuery.rewrite(searcher), k);
return new KnnScoreDocQuery(rescoreTopDocs.scoreDocs, searcher.getIndexReader());
}
Expand All @@ -237,11 +261,13 @@ private static class DirectRescoreKnnVectorQuery extends Query {
private final float[] floatTarget;
private final String fieldName;
private final Query innerQuery;
private final boolean enforceSingleRescore;

DirectRescoreKnnVectorQuery(String fieldName, float[] floatTarget, Query innerQuery) {
DirectRescoreKnnVectorQuery(String fieldName, float[] floatTarget, Query innerQuery, boolean enforceSingleRescore) {
this.fieldName = fieldName;
this.floatTarget = floatTarget;
this.innerQuery = innerQuery;
this.enforceSingleRescore = enforceSingleRescore;
}

@Override
Expand Down Expand Up @@ -274,7 +300,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
continue;
}
var filterIterator = scorer.iterator();
if (knnVectorValues instanceof BulkScorableFloatVectorValues rescorableVectorValues) {
if (knnVectorValues instanceof BulkScorableFloatVectorValues rescorableVectorValues && enforceSingleRescore == false) {
rescoreBulk(leaf.docBase, rescorableVectorValues, results, filterIterator);
} else {
rescoreIndividually(
Expand Down
Loading