diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
index 34ff98c18372f..c99248332b8a9 100644
--- a/.idea/inspectionProfiles/Project_Default.xml
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -12,4 +12,4 @@
-
+
\ No newline at end of file
diff --git a/docs/changelog/135380.yaml b/docs/changelog/135380.yaml
new file mode 100644
index 0000000000000..46b1fa5db9d6d
--- /dev/null
+++ b/docs/changelog/135380.yaml
@@ -0,0 +1,5 @@
+pr: 135380
+summary: Add DirectIO bulk rescoring
+area: Vector Search
+type: enhancement
+issues: []
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/BulkScorableVectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/BulkScorableVectorValues.java
index 800eefb3f6118..ce3601059df08 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/BulkScorableVectorValues.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/BulkScorableVectorValues.java
@@ -30,7 +30,6 @@ interface BulkVectorScorer extends VectorScorer {
interface BulkScorer {
/**
* Scores up to {@code nextCount} docs in the provided {@code buffer}.
- * Returns the maxScore of docs scored.
*/
void nextDocsAndScores(int nextCount, Bits liveDocs, DocAndFloatFeatureBuffer buffer) throws IOException;
}
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java
index a15cbba346353..8d9e8bf448adc 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java
@@ -11,20 +11,36 @@
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
+import org.apache.lucene.codecs.lucene95.HasIndexSlice;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsReader;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsWriter;
+import org.apache.lucene.index.ByteVectorValues;
+import org.apache.lucene.index.FieldInfo;
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.ConjunctionUtils;
+import org.apache.lucene.search.DocAndFloatFeatureBuffer;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.FlushInfo;
import org.apache.lucene.store.IOContext;
+import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.MergeInfo;
+import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.elasticsearch.common.util.set.Sets;
+import org.elasticsearch.index.codec.vectors.BulkScorableFloatVectorValues;
+import org.elasticsearch.index.codec.vectors.BulkScorableVectorValues;
import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat;
import org.elasticsearch.index.codec.vectors.MergeReaderWrapper;
import org.elasticsearch.index.codec.vectors.es818.DirectIOHint;
import org.elasticsearch.index.store.FsDirectoryFactory;
import java.io.IOException;
+import java.util.List;
import java.util.Set;
public class DirectIOCapableLucene99FlatVectorsFormat extends DirectIOCapableFlatVectorsFormat {
@@ -71,7 +87,11 @@ public FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectI
);
// Use mmap for merges and direct I/O for searches.
return new MergeReaderWrapper(
- new Lucene99FlatVectorsReader(directIOState, vectorsScorer),
+ new Lucene99FlatBulkScoringVectorsReader(
+ directIOState,
+ new Lucene99FlatVectorsReader(directIOState, vectorsScorer),
+ vectorsScorer
+ ),
new Lucene99FlatVectorsReader(state, vectorsScorer)
);
} else {
@@ -113,4 +133,203 @@ public IOContext withHints(FileOpenHint... hints) {
return new DirectIOContext(Set.of(hints));
}
}
+
+    /**
+     * Delegating {@link FlatVectorsReader} that wraps the float vector values returned by a
+     * {@link Lucene99FlatVectorsReader} in {@link RescorerOffHeapVectorValues}, enabling bulk
+     * (batched, prefetching) scoring on the direct-I/O search path. All other operations are
+     * forwarded to the inner reader unchanged.
+     */
+    static class Lucene99FlatBulkScoringVectorsReader extends FlatVectorsReader {
+        private final Lucene99FlatVectorsReader inner;
+        // Kept to resolve per-field similarity functions when wrapping vector values.
+        private final SegmentReadState state;
+
+        Lucene99FlatBulkScoringVectorsReader(SegmentReadState state, Lucene99FlatVectorsReader inner, FlatVectorsScorer scorer) {
+            super(scorer);
+            this.inner = inner;
+            this.state = state;
+        }
+
+        @Override
+        public void close() throws IOException {
+            inner.close();
+        }
+
+        @Override
+        public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
+            return inner.getRandomVectorScorer(field, target);
+        }
+
+        @Override
+        public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException {
+            return inner.getRandomVectorScorer(field, target);
+        }
+
+        @Override
+        public void checkIntegrity() throws IOException {
+            inner.checkIntegrity();
+        }
+
+        @Override
+        public FloatVectorValues getFloatVectorValues(String field) throws IOException {
+            FloatVectorValues vectorValues = inner.getFloatVectorValues(field);
+            // Nothing to bulk-score for an absent or empty field.
+            if (vectorValues == null || vectorValues.size() == 0) {
+                return null;
+            }
+            FieldInfo info = state.fieldInfos.fieldInfo(field);
+            return new RescorerOffHeapVectorValues(vectorValues, info.getVectorSimilarityFunction(), vectorScorer);
+        }
+
+        @Override
+        public ByteVectorValues getByteVectorValues(String field) throws IOException {
+            // Byte vectors are not bulk-scored; delegate untouched.
+            return inner.getByteVectorValues(field);
+        }
+
+        @Override
+        public long ramBytesUsed() {
+            return inner.ramBytesUsed();
+        }
+    }
+
+    /**
+     * {@link FloatVectorValues} wrapper that exposes bulk scoring ({@link BulkScorableFloatVectorValues})
+     * on top of a delegate. When the delegate exposes its raw storage via {@link HasIndexSlice}, the
+     * slice is used to prefetch vector bytes ahead of bulk scoring; otherwise prefetching is unavailable.
+     */
+    static class RescorerOffHeapVectorValues extends FloatVectorValues implements BulkScorableFloatVectorValues {
+        private final VectorSimilarityFunction similarityFunction;
+        private final FloatVectorValues inner;
+        // Raw vector data slice used for prefetch hints; null when the delegate does not
+        // implement HasIndexSlice. NOTE(review): downstream scorers must tolerate null here.
+        private final IndexInput inputSlice;
+        private final FlatVectorsScorer scorer;
+
+        RescorerOffHeapVectorValues(FloatVectorValues inner, VectorSimilarityFunction similarityFunction, FlatVectorsScorer scorer) {
+            this.inner = inner;
+            if (inner instanceof HasIndexSlice slice) {
+                this.inputSlice = slice.getSlice();
+            } else {
+                this.inputSlice = null;
+            }
+            this.similarityFunction = similarityFunction;
+            this.scorer = scorer;
+        }
+
+        @Override
+        public float[] vectorValue(int ord) throws IOException {
+            return inner.vectorValue(ord);
+        }
+
+        @Override
+        public int dimension() {
+            return inner.dimension();
+        }
+
+        @Override
+        public int size() {
+            return inner.size();
+        }
+
+        @Override
+        public RescorerOffHeapVectorValues copy() throws IOException {
+            return new RescorerOffHeapVectorValues(inner.copy(), similarityFunction, scorer);
+        }
+
+        @Override
+        public BulkVectorScorer bulkRescorer(float[] target) throws IOException {
+            // Rescoring and initial scoring share the same bulk implementation here.
+            return bulkScorer(target);
+        }
+
+        @Override
+        public BulkVectorScorer bulkScorer(float[] target) throws IOException {
+            DocIndexIterator indexIterator = inner.iterator();
+            RandomVectorScorer randomScorer = scorer.getRandomVectorScorer(similarityFunction, inner, target);
+            // byteSize = bytes per raw float vector, used to compute prefetch offsets.
+            return new PreFetchingFloatBulkScorer(randomScorer, indexIterator, inputSlice, dimension() * Float.BYTES);
+        }
+
+        @Override
+        public VectorScorer scorer(float[] target) throws IOException {
+            return inner.scorer(target);
+        }
+    }
+
+    /**
+     * {@link BulkScorableVectorValues.BulkVectorScorer} adapter: single-doc scoring delegates to the
+     * wrapped {@link RandomVectorScorer}, while {@link #bulkScore} intersects the vector iterator with
+     * an optional filter and hands off to {@link FloatBulkScorer} for batched, prefetching scoring.
+     * {@code inputSlice} may be null when the underlying values expose no raw storage slice.
+     */
+    private record PreFetchingFloatBulkScorer(
+        RandomVectorScorer inner,
+        KnnVectorValues.DocIndexIterator indexIterator,
+        IndexInput inputSlice,
+        int byteSize
+    ) implements BulkScorableVectorValues.BulkVectorScorer {
+
+        @Override
+        public float score() throws IOException {
+            // Score the vector at the iterator's current ordinal.
+            return inner.score(indexIterator.index());
+        }
+
+        @Override
+        public DocIdSetIterator iterator() {
+            return indexIterator;
+        }
+
+        @Override
+        public BulkScorer bulkScore(DocIdSetIterator matchingDocs) throws IOException {
+            DocIdSetIterator conjunctionScorer = matchingDocs == null
+                ? indexIterator
+                : ConjunctionUtils.intersectIterators(List.of(matchingDocs, indexIterator));
+            // Position the iterator on its first match so the bulk scorer can read docID() directly.
+            if (conjunctionScorer.docID() == -1) {
+                conjunctionScorer.nextDoc();
+            }
+            return new FloatBulkScorer(inner, inputSlice, byteSize, 32, indexIterator, conjunctionScorer);
+        }
+    }
+
+    /**
+     * Streams matching docs and their similarity scores in fixed-size batches. Before each batch is
+     * scored, the raw vector bytes are prefetched through the underlying {@link IndexInput} slice so
+     * direct-I/O reads overlap with scoring. Prefetching is skipped when no slice is available.
+     */
+    private static class FloatBulkScorer implements BulkScorableVectorValues.BulkVectorScorer.BulkScorer {
+        private final KnnVectorValues.DocIndexIterator indexIterator;
+        private final DocIdSetIterator matchingDocs;
+        private final RandomVectorScorer inner;
+        private final int bulkSize;
+        // May be null when the wrapped vector values expose no raw slice; prefetching is then a no-op.
+        private final IndexInput inputSlice;
+        private final int byteSize;
+        private final int[] docBuffer;
+        private final float[] scoreBuffer;
+
+        FloatBulkScorer(
+            RandomVectorScorer fvv,
+            IndexInput inputSlice,
+            int byteSize,
+            int bulkSize,
+            KnnVectorValues.DocIndexIterator iterator,
+            DocIdSetIterator matchingDocs
+        ) {
+            this.indexIterator = iterator;
+            this.matchingDocs = matchingDocs;
+            this.inner = fvv;
+            this.bulkSize = bulkSize;
+            this.inputSlice = inputSlice;
+            this.docBuffer = new int[bulkSize];
+            this.scoreBuffer = new float[bulkSize];
+            this.byteSize = byteSize;
+        }
+
+        @Override
+        public void nextDocsAndScores(int nextCount, Bits liveDocs, DocAndFloatFeatureBuffer buffer) throws IOException {
+            buffer.growNoCopy(nextCount);
+            int size = 0;
+            // Collect up to nextCount live ordinals; buffer.docs temporarily holds ordinals,
+            // converted back to doc ids at the end of this method.
+            for (int doc = matchingDocs.docID(); doc != DocIdSetIterator.NO_MORE_DOCS && size < nextCount; doc = matchingDocs.nextDoc()) {
+                if (liveDocs == null || liveDocs.get(doc)) {
+                    buffer.docs[size++] = indexIterator.index();
+                }
+            }
+            // Score full bulkSize batches first, then the remainder.
+            int loopBound = size - (size % bulkSize);
+            int i = 0;
+            for (; i < loopBound; i += bulkSize) {
+                prefetch(buffer.docs, i, bulkSize);
+                System.arraycopy(buffer.docs, i, docBuffer, 0, bulkSize);
+                inner.bulkScore(docBuffer, scoreBuffer, bulkSize);
+                System.arraycopy(scoreBuffer, 0, buffer.features, i, bulkSize);
+            }
+            int countLeft = size - i;
+            if (countLeft > 0) {
+                prefetch(buffer.docs, i, countLeft);
+                System.arraycopy(buffer.docs, i, docBuffer, 0, countLeft);
+                inner.bulkScore(docBuffer, scoreBuffer, countLeft);
+                System.arraycopy(scoreBuffer, 0, buffer.features, i, countLeft);
+            }
+            buffer.size = size;
+            // Convert the collected ordinals back into document ids.
+            for (int j = 0; j < size; j++) {
+                buffer.docs[j] = inner.ordToDoc(buffer.docs[j]);
+            }
+        }
+
+        /**
+         * Hints the underlying input about the vector blocks we are about to read. No-op when the
+         * wrapped values expose no raw slice ({@code inputSlice == null}), which previously
+         * triggered a {@link NullPointerException}.
+         */
+        private void prefetch(int[] ords, int from, int count) throws IOException {
+            if (inputSlice == null) {
+                return;
+            }
+            for (int j = 0; j < count; j++) {
+                long ord = ords[from + j];
+                inputSlice.prefetch(ord * byteSize, byteSize);
+            }
+        }
+    }
}
diff --git a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java
index 1b0258e6810f6..ebe62fa0cba98 100644
--- a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java
+++ b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java
@@ -13,9 +13,13 @@
import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.FilterDirectoryReader;
+import org.apache.lucene.index.FilterLeafReader;
+import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.queries.function.FunctionScoreQuery;
import org.apache.lucene.search.BooleanClause;
@@ -31,11 +35,14 @@
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.store.Directory;
+import org.apache.lucene.util.Bits;
import org.elasticsearch.index.codec.Elasticsearch92Lucene103Codec;
import org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat;
import org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat;
+import org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat;
import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat;
import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat;
import org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat;
@@ -48,6 +55,8 @@
import java.util.ArrayList;
import java.util.List;
+import static org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat.DEFAULT_CENTROIDS_PER_PARENT_CLUSTER;
+import static org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat.DEFAULT_VECTORS_PER_CLUSTER;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
@@ -116,6 +125,64 @@ public void testRescoreDocs() throws Exception {
}
}
+    /**
+     * Verifies that bulk rescoring and single-vector rescoring produce identical results: the same
+     * query is run once against the plain reader (which may use the bulk path) and once against a
+     * reader wrapped so its vector values do not expose bulk scoring, then docs and scores are
+     * compared pairwise.
+     */
+    public void testRescoreSingleAndBulkEquality() throws Exception {
+        int numDocs = randomIntBetween(10, 100);
+        int numDims = randomIntBetween(5, 100);
+        int k = randomIntBetween(1, numDocs - 1);
+
+        var queryVector = randomVector(numDims);
+
+        // Exercise several inner-query shapes: approximate knn, an exact scoring query, and match-all.
+        List innerQueries = new ArrayList<>();
+        innerQueries.add(new KnnFloatVectorQuery(FIELD_NAME, randomVector(numDims), (int) (k * randomFloatBetween(1.0f, 10.0f, true))));
+        innerQueries.add(
+            new BooleanQuery.Builder().add(new DenseVectorQuery.Floats(queryVector, FIELD_NAME), BooleanClause.Occur.SHOULD)
+                .add(new FieldExistsQuery(FIELD_NAME), BooleanClause.Occur.FILTER)
+                .build()
+        );
+        innerQueries.add(new MatchAllDocsQuery());
+
+        try (Directory d = newDirectory()) {
+            addRandomDocuments(numDocs, d, numDims);
+            try (DirectoryReader reader = DirectoryReader.open(d)) {
+                for (var innerQuery : innerQueries) {
+                    RescoreKnnVectorQuery rescoreKnnVectorQuery = RescoreKnnVectorQuery.fromInnerQuery(
+                        FIELD_NAME,
+                        queryVector,
+                        VectorSimilarityFunction.COSINE,
+                        k,
+                        k,
+                        innerQuery
+                    );
+
+                    // Run against the unwrapped reader (bulk rescoring path allowed).
+                    IndexSearcher searcher = newSearcher(reader, true, false);
+                    TopDocs rescoredDocs = searcher.search(rescoreKnnVectorQuery, numDocs);
+                    assertThat(rescoredDocs.scoreDocs.length, equalTo(k));
+
+                    // Run the same query against a reader whose vector values hide the bulk API,
+                    // forcing the single-vector rescoring path.
+                    searcher = newSearcher(new SingleVectorQueryIndexReader(reader), true, false);
+                    rescoreKnnVectorQuery = RescoreKnnVectorQuery.fromInnerQuery(
+                        FIELD_NAME,
+                        queryVector,
+                        VectorSimilarityFunction.COSINE,
+                        k,
+                        k,
+                        innerQuery
+                    );
+                    TopDocs singleRescored = searcher.search(rescoreKnnVectorQuery, numDocs);
+                    assertThat(singleRescored.scoreDocs.length, equalTo(k));
+
+                    // Get real scores
+                    ScoreDoc[] singleRescoreDocs = singleRescored.scoreDocs;
+                    int i = 0;
+                    // Both paths must agree exactly on ordering, doc ids, and scores.
+                    for (ScoreDoc rescoreDoc : rescoredDocs.scoreDocs) {
+                        assertThat(rescoreDoc.doc, equalTo(singleRescoreDocs[i].doc));
+                        assertThat(rescoreDoc.score, equalTo(singleRescoreDocs[i].score));
+                        i++;
+                    }
+                }
+            }
+        }
+    }
+
public void testProfiling() throws Exception {
int numDocs = randomIntBetween(10, 100);
int numDims = randomIntBetween(5, 100);
@@ -216,6 +283,7 @@ private static void addRandomDocuments(int numDocs, Directory d, int numDims) th
IndexWriterConfig iwc = new IndexWriterConfig();
// Pick codec from quantized vector formats to ensure scores use real scores when using knn rescore
KnnVectorsFormat format = randomFrom(
+ new ES920DiskBBQVectorsFormat(DEFAULT_VECTORS_PER_CLUSTER, DEFAULT_CENTROIDS_PER_PARENT_CLUSTER, randomBoolean()),
new ES818BinaryQuantizedVectorsFormat(),
new ES818HnswBinaryQuantizedVectorsFormat(),
new ES93HnswBinaryQuantizedVectorsFormat(),
@@ -243,4 +311,100 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
w.commit();
}
}
+
+    /**
+     * Test-only {@link FilterDirectoryReader} that re-wraps every leaf's float vector values in
+     * {@link RegularFloatVectorValues}, hiding any bulk-scoring interface the codec provides so
+     * queries are forced down the single-vector scoring path.
+     */
+    private static class SingleVectorQueryIndexReader extends FilterDirectoryReader {
+
+        /**
+         * Create a new FilterDirectoryReader that filters a passed in DirectoryReader, using the supplied
+         * SubReaderWrapper to wrap its subreader.
+         *
+         * @param in the DirectoryReader to filter
+         */
+        SingleVectorQueryIndexReader(DirectoryReader in) throws IOException {
+            super(in, new SubReaderWrapper() {
+                @Override
+                public LeafReader wrap(LeafReader reader) {
+                    return new FilterLeafReader(reader) {
+                        // Cache helpers return null: the wrapper changes vector behavior, so cached
+                        // per-reader/per-core data from the delegate must not be reused.
+                        @Override
+                        public CacheHelper getReaderCacheHelper() {
+                            return null;
+                        }
+
+                        @Override
+                        public CacheHelper getCoreCacheHelper() {
+                            return null;
+                        }
+
+                        @Override
+                        public FloatVectorValues getFloatVectorValues(String field) throws IOException {
+                            FloatVectorValues values = super.getFloatVectorValues(field);
+                            if (values == null) {
+                                return null;
+                            }
+                            return new RegularFloatVectorValues(values);
+                        }
+                    };
+                }
+            });
+        }
+
+        @Override
+        protected DirectoryReader doWrapDirectoryReader(DirectoryReader in) throws IOException {
+            return new SingleVectorQueryIndexReader(in);
+        }
+
+        @Override
+        public CacheHelper getReaderCacheHelper() {
+            return null;
+        }
+    }
+
+    /**
+     * Plain delegating {@link FloatVectorValues}. Because it deliberately does NOT implement the
+     * bulk-scoring interface, wrapping a codec's vector values in this class forces callers onto
+     * the per-document scoring path — the point of the single-vs-bulk equality test above.
+     */
+    private static final class RegularFloatVectorValues extends FloatVectorValues {
+
+        private final FloatVectorValues in;
+
+        RegularFloatVectorValues(FloatVectorValues in) {
+            this.in = in;
+        }
+
+        @Override
+        public VectorScorer scorer(float[] target) throws IOException {
+            return in.scorer(target);
+        }
+
+        @Override
+        public int ordToDoc(int ord) {
+            return in.ordToDoc(ord);
+        }
+
+        @Override
+        public Bits getAcceptOrds(Bits acceptDocs) {
+            return in.getAcceptOrds(acceptDocs);
+        }
+
+        @Override
+        public DocIndexIterator iterator() {
+            return in.iterator();
+        }
+
+        @Override
+        public float[] vectorValue(int ord) throws IOException {
+            return in.vectorValue(ord);
+        }
+
+        @Override
+        public FloatVectorValues copy() throws IOException {
+            return new RegularFloatVectorValues(in.copy());
+        }
+
+        @Override
+        public int dimension() {
+            return in.dimension();
+        }
+
+        @Override
+        public int size() {
+            return in.size();
+        }
+    }
}