diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml index 34ff98c18372f..c99248332b8a9 100644 --- a/.idea/inspectionProfiles/Project_Default.xml +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -12,4 +12,4 @@ - + \ No newline at end of file diff --git a/docs/changelog/135380.yaml b/docs/changelog/135380.yaml new file mode 100644 index 0000000000000..46b1fa5db9d6d --- /dev/null +++ b/docs/changelog/135380.yaml @@ -0,0 +1,5 @@ +pr: 135380 +summary: Add DirectIO bulk rescoring +area: Vector Search +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/BulkScorableVectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/BulkScorableVectorValues.java index 800eefb3f6118..ce3601059df08 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/BulkScorableVectorValues.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/BulkScorableVectorValues.java @@ -30,7 +30,6 @@ interface BulkVectorScorer extends VectorScorer { interface BulkScorer { /** * Scores up to {@code nextCount} docs in the provided {@code buffer}. - * Returns the maxScore of docs scored. 
*/ void nextDocsAndScores(int nextCount, Bits liveDocs, DocAndFloatFeatureBuffer buffer) throws IOException; } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java index a15cbba346353..8d9e8bf448adc 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java @@ -11,20 +11,36 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsReader; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.lucene95.HasIndexSlice; import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsReader; import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsWriter; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.ConjunctionUtils; +import org.apache.lucene.search.DocAndFloatFeatureBuffer; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.FlushInfo; import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.MergeInfo; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.index.codec.vectors.BulkScorableFloatVectorValues; +import 
org.elasticsearch.index.codec.vectors.BulkScorableVectorValues;
 import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat;
 import org.elasticsearch.index.codec.vectors.MergeReaderWrapper;
 import org.elasticsearch.index.codec.vectors.es818.DirectIOHint;
 import org.elasticsearch.index.store.FsDirectoryFactory;
 
 import java.io.IOException;
+import java.util.List;
 import java.util.Set;
 
 public class DirectIOCapableLucene99FlatVectorsFormat extends DirectIOCapableFlatVectorsFormat {
@@ -71,7 +87,11 @@ public FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectI
             );
             // Use mmap for merges and direct I/O for searches.
             return new MergeReaderWrapper(
-                new Lucene99FlatVectorsReader(directIOState, vectorsScorer),
+                new Lucene99FlatBulkScoringVectorsReader(
+                    directIOState,
+                    new Lucene99FlatVectorsReader(directIOState, vectorsScorer),
+                    vectorsScorer
+                ),
                 new Lucene99FlatVectorsReader(state, vectorsScorer)
             );
         } else {
@@ -113,4 +133,227 @@ public IOContext withHints(FileOpenHint... hints) {
             return new DirectIOContext(Set.of(hints));
         }
     }
+
+    /**
+     * Flat vectors reader that delegates every call to the wrapped {@link Lucene99FlatVectorsReader},
+     * except {@link #getFloatVectorValues(String)}, which wraps the delegate's values in
+     * {@link RescorerOffHeapVectorValues} so that callers can bulk score them.
+     */
+    static class Lucene99FlatBulkScoringVectorsReader extends FlatVectorsReader {
+        private final Lucene99FlatVectorsReader inner;
+        private final SegmentReadState state;
+
+        Lucene99FlatBulkScoringVectorsReader(SegmentReadState state, Lucene99FlatVectorsReader inner, FlatVectorsScorer scorer) {
+            super(scorer);
+            this.inner = inner;
+            this.state = state;
+        }
+
+        @Override
+        public void close() throws IOException {
+            inner.close();
+        }
+
+        @Override
+        public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
+            return inner.getRandomVectorScorer(field, target);
+        }
+
+        @Override
+        public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException {
+            return inner.getRandomVectorScorer(field, target);
+        }
+
+        @Override
+        public void checkIntegrity() throws IOException {
+            inner.checkIntegrity();
+        }
+
+        /**
+         * Returns the delegate's float vectors wrapped for bulk scoring, or {@code null} when the delegate
+         * returns no values (or values of size zero) for {@code field}.
+         */
+        @Override
+        public FloatVectorValues getFloatVectorValues(String field) throws IOException {
+            FloatVectorValues vectorValues = inner.getFloatVectorValues(field);
+            if (vectorValues == null || vectorValues.size() == 0) {
+                return null;
+            }
+            FieldInfo info = state.fieldInfos.fieldInfo(field);
+            return new RescorerOffHeapVectorValues(vectorValues, info.getVectorSimilarityFunction(), vectorScorer);
+        }
+
+        @Override
+        public ByteVectorValues getByteVectorValues(String field) throws IOException {
+            return inner.getByteVectorValues(field);
+        }
+
+        @Override
+        public long ramBytesUsed() {
+            return inner.ramBytesUsed();
+        }
+    }
+
+    /**
+     * {@link FloatVectorValues} decorator that adds bulk scoring on top of {@code inner}. When {@code inner}
+     * implements {@link HasIndexSlice}, its slice is handed to the bulk scorer so vector bytes can be
+     * prefetched before they are scored; otherwise no slice is available and prefetching is skipped.
+     */
+    static class RescorerOffHeapVectorValues extends FloatVectorValues implements BulkScorableFloatVectorValues {
+        private final VectorSimilarityFunction similarityFunction;
+        private final FloatVectorValues inner;
+        // null when inner does not expose an index slice
+        private final IndexInput inputSlice;
+        private final FlatVectorsScorer scorer;
+
+        RescorerOffHeapVectorValues(FloatVectorValues inner, VectorSimilarityFunction similarityFunction, FlatVectorsScorer scorer) {
+            this.inner = inner;
+            if (inner instanceof HasIndexSlice slice) {
+                this.inputSlice = slice.getSlice();
+            } else {
+                this.inputSlice = null;
+            }
+            this.similarityFunction = similarityFunction;
+            this.scorer = scorer;
+        }
+
+        @Override
+        public float[] vectorValue(int ord) throws IOException {
+            return inner.vectorValue(ord);
+        }
+
+        @Override
+        public int dimension() {
+            return inner.dimension();
+        }
+
+        @Override
+        public int size() {
+            return inner.size();
+        }
+
+        @Override
+        public RescorerOffHeapVectorValues copy() throws IOException {
+            return new RescorerOffHeapVectorValues(inner.copy(), similarityFunction, scorer);
+        }
+
+        @Override
+        public BulkVectorScorer bulkRescorer(float[] target) throws IOException {
+            return bulkScorer(target);
+        }
+
+        @Override
+        public BulkVectorScorer bulkScorer(float[] target) throws IOException {
+            DocIndexIterator indexIterator = inner.iterator();
+            RandomVectorScorer randomScorer = scorer.getRandomVectorScorer(similarityFunction, inner, target);
+            return new PreFetchingFloatBulkScorer(randomScorer, indexIterator, inputSlice, dimension() * Float.BYTES);
+        }
+
+        @Override
+        public VectorScorer scorer(float[] target) throws IOException {
+            return inner.scorer(target);
+        }
+    }
+
+    /**
+     * {@link BulkScorableVectorValues.BulkVectorScorer} that scores through {@code inner} and, for bulk
+     * scoring, prefetches {@code byteSize} bytes per vector from {@code inputSlice} (which may be null).
+     */
+    private record PreFetchingFloatBulkScorer(
+        RandomVectorScorer inner,
+        KnnVectorValues.DocIndexIterator indexIterator,
+        IndexInput inputSlice,
+        int byteSize
+    ) implements BulkScorableVectorValues.BulkVectorScorer {
+
+        // number of vectors prefetched and scored per batch
+        private static final int BULK_SIZE = 32;
+
+        @Override
+        public float score() throws IOException {
+            return inner.score(indexIterator.index());
+        }
+
+        @Override
+        public DocIdSetIterator iterator() {
+            return indexIterator;
+        }
+
+        @Override
+        public BulkScorer bulkScore(DocIdSetIterator matchingDocs) throws IOException {
+            DocIdSetIterator conjunctionScorer = matchingDocs == null
+                ? indexIterator
+                : ConjunctionUtils.intersectIterators(List.of(matchingDocs, indexIterator));
+            if (conjunctionScorer.docID() == -1) {
+                conjunctionScorer.nextDoc();
+            }
+            return new FloatBulkScorer(inner, inputSlice, byteSize, BULK_SIZE, indexIterator, conjunctionScorer);
+        }
+    }
+
+    /**
+     * Collects up to {@code nextCount} live matching docs per call, scores their vectors in batches of at
+     * most {@code bulkSize}, and prefetches each batch from {@code inputSlice} first when a slice is available.
+     */
+    private static class FloatBulkScorer implements BulkScorableVectorValues.BulkVectorScorer.BulkScorer {
+        private final KnnVectorValues.DocIndexIterator indexIterator;
+        private final DocIdSetIterator matchingDocs;
+        private final RandomVectorScorer inner;
+        private final int bulkSize;
+        private final IndexInput inputSlice;
+        private final int byteSize;
+        private final int[] docBuffer;
+        private final float[] scoreBuffer;
+
+        FloatBulkScorer(
+            RandomVectorScorer fvv,
+            IndexInput inputSlice,
+            int byteSize,
+            int bulkSize,
+            KnnVectorValues.DocIndexIterator iterator,
+            DocIdSetIterator matchingDocs
+        ) {
+            this.indexIterator = iterator;
+            this.matchingDocs = matchingDocs;
+            this.inner = fvv;
+            this.bulkSize = bulkSize;
+            this.inputSlice = inputSlice;
+            this.docBuffer = new int[bulkSize];
+            this.scoreBuffer = new float[bulkSize];
+            this.byteSize = byteSize;
+        }
+
+        @Override
+        public void nextDocsAndScores(int nextCount, Bits liveDocs, DocAndFloatFeatureBuffer buffer) throws IOException {
+            buffer.growNoCopy(nextCount);
+            int size = 0;
+            for (int doc = matchingDocs.docID(); doc != DocIdSetIterator.NO_MORE_DOCS && size < nextCount; doc = matchingDocs.nextDoc()) {
+                if (liveDocs == null || liveDocs.get(doc)) {
+                    // store the vector ordinal for now; it is mapped back to a doc id once scoring is done
+                    buffer.docs[size++] = indexIterator.index();
+                }
+            }
+            for (int offset = 0; offset < size; offset += bulkSize) {
+                scoreBatch(buffer, offset, Math.min(bulkSize, size - offset));
+            }
+            buffer.size = size;
+            // fix the docIds in buffer
+            for (int j = 0; j < size; j++) {
+                buffer.docs[j] = inner.ordToDoc(buffer.docs[j]);
+            }
+        }
+
+        /** Prefetches (when a slice is available) and scores {@code count} ordinals starting at {@code offset}. */
+        private void scoreBatch(DocAndFloatFeatureBuffer buffer, int offset, int count) throws IOException {
+            if (inputSlice != null) {
+                for (int j = 0; j < count; j++) {
+                    long ord = buffer.docs[offset + j];
+                    inputSlice.prefetch(ord * byteSize, byteSize);
+                }
+            }
+            System.arraycopy(buffer.docs, offset, docBuffer, 0, count);
+            inner.bulkScore(docBuffer, scoreBuffer, count);
+            System.arraycopy(scoreBuffer, 0, buffer.features, offset, count);
+        }
+    }
 }
diff --git a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java
index 1b0258e6810f6..ebe62fa0cba98 100644
--- a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java
+++ b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java
@@ -13,9 +13,13 @@
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.KnnFloatVectorField;
 import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.FilterDirectoryReader;
+import org.apache.lucene.index.FilterLeafReader;
+import org.apache.lucene.index.FloatVectorValues;
 import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.IndexWriter;
 import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.queries.function.FunctionScoreQuery;
 import org.apache.lucene.search.BooleanClause;
@@ -31,11 +35,14 @@
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.VectorScorer;
 import org.apache.lucene.search.Weight;
 import org.apache.lucene.store.Directory;
+import org.apache.lucene.util.Bits;
 import org.elasticsearch.index.codec.Elasticsearch92Lucene103Codec;
 import org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat;
 import
org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat;
+import org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat;
 import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat;
 import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat;
 import org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat;
@@ -48,6 +55,8 @@
 import java.util.ArrayList;
 import java.util.List;
 
+import static org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat.DEFAULT_CENTROIDS_PER_PARENT_CLUSTER;
+import static org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat.DEFAULT_VECTORS_PER_CLUSTER;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.greaterThan;
 
@@ -116,6 +125,68 @@ public void testRescoreDocs() throws Exception {
         }
     }
 
+    /**
+     * Checks that rescoring through the reader as opened and through {@link SingleVectorQueryIndexReader}
+     * (whose vector values are plain {@link FloatVectorValues}) yields the same docs and scores.
+     */
+    public void testRescoreSingleAndBulkEquality() throws Exception {
+        int numDocs = randomIntBetween(10, 100);
+        int numDims = randomIntBetween(5, 100);
+        int k = randomIntBetween(1, numDocs - 1);
+
+        var queryVector = randomVector(numDims);
+
+        List<Query> innerQueries = new ArrayList<>();
+        innerQueries.add(new KnnFloatVectorQuery(FIELD_NAME, randomVector(numDims), (int) (k * randomFloatBetween(1.0f, 10.0f, true))));
+        innerQueries.add(
+            new BooleanQuery.Builder().add(new DenseVectorQuery.Floats(queryVector, FIELD_NAME), BooleanClause.Occur.SHOULD)
+                .add(new FieldExistsQuery(FIELD_NAME), BooleanClause.Occur.FILTER)
+                .build()
+        );
+        innerQueries.add(new MatchAllDocsQuery());
+
+        try (Directory d = newDirectory()) {
+            addRandomDocuments(numDocs, d, numDims);
+            try (DirectoryReader reader = DirectoryReader.open(d)) {
+                for (var innerQuery : innerQueries) {
+                    RescoreKnnVectorQuery rescoreKnnVectorQuery = RescoreKnnVectorQuery.fromInnerQuery(
+                        FIELD_NAME,
+                        queryVector,
+                        VectorSimilarityFunction.COSINE,
+                        k,
+                        k,
+                        innerQuery
+                    );
+
+                    IndexSearcher searcher = newSearcher(reader, true, false);
+                    TopDocs rescoredDocs = searcher.search(rescoreKnnVectorQuery, numDocs);
+                    assertThat(rescoredDocs.scoreDocs.length, equalTo(k));
+
+                    // repeat against a reader whose vector values do not implement the bulk-scoring interface
+                    searcher = newSearcher(new SingleVectorQueryIndexReader(reader), true, false);
+                    rescoreKnnVectorQuery = RescoreKnnVectorQuery.fromInnerQuery(
+                        FIELD_NAME,
+                        queryVector,
+                        VectorSimilarityFunction.COSINE,
+                        k,
+                        k,
+                        innerQuery
+                    );
+                    TopDocs singleRescored = searcher.search(rescoreKnnVectorQuery, numDocs);
+                    assertThat(singleRescored.scoreDocs.length, equalTo(k));
+
+                    // both runs must return the same docs, in the same order, with identical scores
+                    ScoreDoc[] bulkDocs = rescoredDocs.scoreDocs;
+                    ScoreDoc[] singleDocs = singleRescored.scoreDocs;
+                    for (int i = 0; i < bulkDocs.length; i++) {
+                        assertThat(bulkDocs[i].doc, equalTo(singleDocs[i].doc));
+                        assertThat(bulkDocs[i].score, equalTo(singleDocs[i].score));
+                    }
+                }
+            }
+        }
+    }
+
     public void testProfiling() throws Exception {
         int numDocs = randomIntBetween(10, 100);
         int numDims = randomIntBetween(5, 100);
@@ -216,6 +287,7 @@ private static void addRandomDocuments(int numDocs, Directory d, int numDims) th
         IndexWriterConfig iwc = new IndexWriterConfig();
         // Pick codec from quantized vector formats to ensure scores use real scores when using knn rescore
         KnnVectorsFormat format = randomFrom(
+            new ES920DiskBBQVectorsFormat(DEFAULT_VECTORS_PER_CLUSTER, DEFAULT_CENTROIDS_PER_PARENT_CLUSTER, randomBoolean()),
             new ES818BinaryQuantizedVectorsFormat(),
             new ES818HnswBinaryQuantizedVectorsFormat(),
             new ES93HnswBinaryQuantizedVectorsFormat(),
@@ -243,4 +315,108 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
             w.commit();
         }
     }
+
+    /**
+     * Test-only directory reader whose leaves return float vector values wrapped in
+     * {@link RegularFloatVectorValues}; all of its cache helpers return {@code null}.
+     */
+    private static class SingleVectorQueryIndexReader extends FilterDirectoryReader {
+
+        /**
+         * Wraps every leaf of {@code in} so that its float vector values are returned as
+         * {@link RegularFloatVectorValues}.
+         *
+         * @param in the DirectoryReader to filter
+         */
+        SingleVectorQueryIndexReader(DirectoryReader in) throws IOException {
+            super(in, new SubReaderWrapper() {
+                @Override
+                public LeafReader wrap(LeafReader reader) {
+                    return new FilterLeafReader(reader) {
+                        @Override
+                        public CacheHelper getReaderCacheHelper() {
+                            return null;
+                        }
+
+                        @Override
+                        public CacheHelper getCoreCacheHelper() {
+                            return null;
+                        }
+
+                        @Override
+                        public FloatVectorValues getFloatVectorValues(String field) throws IOException {
+                            FloatVectorValues values = super.getFloatVectorValues(field);
+                            if (values == null) {
+                                return null;
+                            }
+                            return new RegularFloatVectorValues(values);
+                        }
+                    };
+                }
+            });
+        }
+
+        @Override
+        protected DirectoryReader doWrapDirectoryReader(DirectoryReader in) throws IOException {
+            return new SingleVectorQueryIndexReader(in);
+        }
+
+        @Override
+        public CacheHelper getReaderCacheHelper() {
+            return null;
+        }
+    }
+
+    /**
+     * Delegating {@link FloatVectorValues} that forwards every call to {@code in} but implements no
+     * additional interfaces, so any extra capabilities of the wrapped instance are hidden from callers.
+     */
+    private static final class RegularFloatVectorValues extends FloatVectorValues {
+
+        private final FloatVectorValues in;
+
+        RegularFloatVectorValues(FloatVectorValues in) {
+            this.in = in;
+        }
+
+        @Override
+        public VectorScorer scorer(float[] target) throws IOException {
+            return in.scorer(target);
+        }
+
+        @Override
+        public int ordToDoc(int ord) {
+            return in.ordToDoc(ord);
+        }
+
+        @Override
+        public Bits getAcceptOrds(Bits acceptDocs) {
+            return in.getAcceptOrds(acceptDocs);
+        }
+
+        @Override
+        public DocIndexIterator iterator() {
+            return in.iterator();
+        }
+
+        @Override
+        public float[] vectorValue(int ord) throws IOException {
+            return in.vectorValue(ord);
+        }
+
+        @Override
+        public FloatVectorValues copy() throws IOException {
+            return new RegularFloatVectorValues(in.copy());
+        }
+
+        @Override
+        public int dimension() {
+            return in.dimension();
+        }
+
+        @Override
+        public int size() {
+            return in.size();
+        }
+    }
 }