From 7083868e690cf546ee61de2b08df5f35db868d72 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Mon, 17 Feb 2025 19:15:24 +0100 Subject: [PATCH 1/7] Add IT test for RescoreKnnVectorQuery --- .../search/query/RescoreKnnVectorQueryIT.java | 239 ++++++++++++++++++ 1 file changed, 239 insertions(+) create mode 100644 server/src/internalClusterTest/java/org/elasticsearch/search/query/RescoreKnnVectorQueryIT.java diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/query/RescoreKnnVectorQueryIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/query/RescoreKnnVectorQueryIT.java new file mode 100644 index 0000000000000..d53f8b1807b08 --- /dev/null +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/query/RescoreKnnVectorQueryIT.java @@ -0,0 +1,239 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.query; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.elasticsearch.action.index.IndexRequestBuilder; +import org.elasticsearch.action.search.SearchRequestBuilder; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.VectorIndexType; +import org.elasticsearch.index.mapper.vectors.DenseVectorScriptDocValues; +import org.elasticsearch.index.query.MatchAllQueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.index.query.functionscore.ScriptScoreQueryBuilder; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.script.MockScriptPlugin; +import org.elasticsearch.script.Script; +import org.elasticsearch.script.ScriptType; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.retriever.KnnRetrieverBuilder; +import org.elasticsearch.search.vectors.KnnSearchBuilder; +import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; +import org.elasticsearch.search.vectors.RescoreVectorBuilder; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.junit.Before; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailuresAndResponse; +import static org.hamcrest.Matchers.equalTo; + +public class RescoreKnnVectorQueryIT extends ESIntegTestCase { + + public static final String INDEX_NAME = "test"; + public static final String VECTOR_FIELD = "vector"; + public static final String VECTOR_SCORE_SCRIPT = "vector_scoring"; + public static final String QUERY_VECTOR_PARAM = "query_vector"; + + @Override + protected Collection> nodePlugins() { + return Collections.singleton(CustomScriptPlugin.class); + } + + public static class CustomScriptPlugin extends MockScriptPlugin { + + private static final VectorSimilarityFunction SIMILARITY_FUNCTION = DenseVectorFieldMapper.VectorSimilarity.L2_NORM + .vectorSimilarityFunction(IndexVersion.current(), DenseVectorFieldMapper.ElementType.FLOAT); + + @Override + protected Map, Object>> pluginScripts() { + return Map.of(VECTOR_SCORE_SCRIPT, vars -> { + Map doc = (Map) vars.get("doc"); + return SIMILARITY_FUNCTION.compare( + ((DenseVectorScriptDocValues) doc.get(VECTOR_FIELD)).getVectorValue(), + (float[]) vars.get(QUERY_VECTOR_PARAM) + ); + }); + } + } + + @Before + public void setup() throws IOException { + String type = randomFrom( + Arrays.stream(VectorIndexType.values()) + .filter(VectorIndexType::isQuantized) + .map(t -> t.name().toLowerCase(Locale.ROOT)) + .collect(Collectors.toCollection(ArrayList::new)) + ); + XContentBuilder mapping = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(VECTOR_FIELD) + .field("type", "dense_vector") + .field("similarity", "l2_norm") + .startObject("index_options") + .field("type", type) + .endObject() + .endObject() + .endObject() + .endObject(); + + Settings settings = Settings.builder() + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5)) + .build(); + prepareCreate(INDEX_NAME).setMapping(mapping).setSettings(settings).get(); + ensureGreen(INDEX_NAME); + } + + private record TestParams( + int numDocs, + int numDims, + float[] queryVector, + int k, + int numCands, + RescoreVectorBuilder rescoreVectorBuilder + ) { + public static TestParams generate() { + int numDims = randomIntBetween(32, 512) * 2; // Ensure even dimensions + int numDocs = randomIntBetween(10, 100); + int k = randomIntBetween(1, numDocs - 5); + return new TestParams( + numDocs, + numDims, + randomVector(numDims), + k, + (int) (k * randomFloatBetween(1.0f, 10.0f, true)), + new RescoreVectorBuilder(randomFloatBetween(1.0f, 100f, true)) + ); + } + } + + public void testKnnSearchRescore() { + BiFunction knnSearchGenerator = (testParams, requestBuilder) -> { + KnnSearchBuilder knnSearch = new KnnSearchBuilder( + VECTOR_FIELD, + testParams.queryVector, + testParams.k, + testParams.numCands, + testParams.rescoreVectorBuilder, + null + ); + return requestBuilder.setKnnSearch(List.of(knnSearch)); + }; + testKnnRescore(knnSearchGenerator); + } + + public void testKnnQueryRescore() { + BiFunction knnQueryGenerator = (testParams, requestBuilder) -> { + KnnVectorQueryBuilder knnQuery = new KnnVectorQueryBuilder( + VECTOR_FIELD, + testParams.queryVector, + testParams.k, + testParams.numCands, + testParams.rescoreVectorBuilder, + null + ); + return requestBuilder.setQuery(knnQuery); + }; + testKnnRescore(knnQueryGenerator); + } + + public void testKnnRetriever() { + BiFunction knnQueryGenerator = (testParams, requestBuilder) -> { + KnnRetrieverBuilder knnRetriever = new KnnRetrieverBuilder( + VECTOR_FIELD, + testParams.queryVector, + null, + testParams.k, + testParams.numCands, + testParams.rescoreVectorBuilder, + null + ); + return requestBuilder.setSource(new SearchSourceBuilder().retriever(knnRetriever)); + }; + testKnnRescore(knnQueryGenerator); + } + + private void testKnnRescore(BiFunction searchRequestGenerator) { + TestParams testParams = TestParams.generate(); + + int numDocs = testParams.numDocs; + IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs]; + + for (int i = 0; i < numDocs; i++) { + docs[i] = prepareIndex(INDEX_NAME).setId("" + i).setSource(VECTOR_FIELD, randomVector(testParams.numDims)); + } + indexRandom(true, docs); + + float[] queryVector = testParams.queryVector; + float oversample = randomFloatBetween(1.0f, 100f, true); + RescoreVectorBuilder rescoreVectorBuilder = new RescoreVectorBuilder(oversample); + + SearchRequestBuilder requestBuilder = searchRequestGenerator.apply( + testParams, + prepareSearch(INDEX_NAME).setSize(numDocs).setTrackTotalHits(randomBoolean()) + ); + + assertNoFailuresAndResponse(requestBuilder, knnResponse -> { compareWithExactSearch(knnResponse, queryVector, numDocs); }); + } + + private static void compareWithExactSearch(SearchResponse knnResponse, float[] queryVector, int docCount) { + // Do an exact query and compare + Script script = new Script( + ScriptType.INLINE, + CustomScriptPlugin.NAME, + VECTOR_SCORE_SCRIPT, + Map.of(QUERY_VECTOR_PARAM, queryVector) + ); + ScriptScoreQueryBuilder scriptScoreQueryBuilder = QueryBuilders.scriptScoreQuery(new MatchAllQueryBuilder(), script); + assertNoFailuresAndResponse(prepareSearch(INDEX_NAME).setQuery(scriptScoreQueryBuilder).setSize(docCount), exactResponse -> { + assertHitCount(exactResponse, docCount); + + int i = 0; + SearchHit[] exactHits = exactResponse.getHits().getHits(); + for (SearchHit knnHit : knnResponse.getHits().getHits()) { + while (i < exactHits.length && exactHits[i].getId().equals(knnHit.getId()) == false) { + i++; + } + if (i >= exactHits.length) { + fail("Knn doc not found in exact search"); + } + assertThat("Real score is not the same as rescored score", knnHit.getScore(), equalTo(exactHits[i].getScore())); + } + }); + } + + private static float[] randomVector(int numDimensions) { + float[] vector = new float[numDimensions]; + for (int j = 0; j < numDimensions; j++) { + vector[j] = randomFloatBetween(0, 1, true); + } + return vector; + } +} From dddf1682c325b0feb885611b9404eec36ff6de60 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Mon, 17 Feb 2025 19:15:45 +0100 Subject: [PATCH 2/7] Add assertion to check docs are in order --- .../elasticsearch/search/vectors/KnnScoreDocQuery.java | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java index 3d13f3cd82b9c..7d54fea035371 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java @@ -21,6 +21,7 @@ import org.apache.lucene.search.Scorer; import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.Weight; +import org.elasticsearch.core.Assertions; import java.io.IOException; import java.util.Arrays; @@ -55,6 +56,13 @@ public class KnnScoreDocQuery extends Query { * @param reader IndexReader */ KnnScoreDocQuery(int[] docs, float[] scores, IndexReader reader) { + if (Assertions.ENABLED) { + assert docs.length == scores.length; + for (int i = 1; i < docs.length; i++) { + assert docs[i - 1] < docs[i] : "doc ids are not in order: " + Arrays.toString(docs); + } + } + this.docs = docs; this.scores = scores; this.segmentStarts = findSegmentStarts(reader, docs); From 46a983ab211cd395728262ea474c379cb27ca00b Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Mon, 17 Feb 2025 19:20:42 +0100 Subject: [PATCH 3/7] Refactor RescoreKnnVectorQueryTests to create multiple segments, add quantized codecs, check doc ordering --- .../vectors/DenseVectorFieldMapper.java | 2 +- .../vectors/RescoreKnnVectorQueryTests.java | 110 ++++++++++-------- 2 files changed, 63 insertions(+), 49 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index ce41c2164e205..193b2f8d90433 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -1225,7 +1225,7 @@ public final int hashCode() { } } - private enum VectorIndexType { + public enum VectorIndexType { HNSW("hnsw", false) { @Override public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { diff --git a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java index 861a8b11db567..05b7bc9ef4f82 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java @@ -9,36 +9,39 @@ package org.elasticsearch.search.vectors; +import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.document.Document; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.DirectoryReader; -import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; -import org.apache.lucene.index.KnnVectorValues; -import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.queries.function.FunctionScoreQuery; +import org.apache.lucene.search.DoubleValuesSource; import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.Weight; import org.apache.lucene.store.Directory; +import org.elasticsearch.index.codec.Elasticsearch900Lucene101Codec; +import org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat; +import org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.zstd.Zstd814StoredFieldsFormat; +import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource; import org.elasticsearch.search.profile.query.QueryProfiler; import org.elasticsearch.test.ESTestCase; import java.io.IOException; import java.io.UnsupportedEncodingException; -import java.util.Arrays; -import java.util.Collection; -import java.util.HashSet; -import java.util.Map; -import java.util.PriorityQueue; -import java.util.stream.Collectors; - -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; @@ -59,51 +62,45 @@ public void testRescoreDocs() throws Exception { // Use a RescoreKnnVectorQuery with a match all query, to ensure we get scoring of 1 from the inner query // and thus we're rescoring the top k docs. float[] queryVector = randomVector(numDims); + Query innerQuery; + if (randomBoolean()) { + innerQuery = new KnnFloatVectorQuery(FIELD_NAME, queryVector, (int) (k * randomFloatBetween(1.0f, 10.0f, true))); + } else { + innerQuery = new MatchAllDocsQuery(); + } RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery( FIELD_NAME, queryVector, VectorSimilarityFunction.COSINE, k, - new MatchAllDocsQuery() + innerQuery ); IndexSearcher searcher = newSearcher(reader, true, false); - TopDocs docs = searcher.search(rescoreKnnVectorQuery, numDocs); - Map rescoredDocs = Arrays.stream(docs.scoreDocs) - .collect(Collectors.toMap(scoreDoc -> scoreDoc.doc, scoreDoc -> scoreDoc.score)); - - assertThat(rescoredDocs.size(), equalTo(k)); - - Collection rescoredScores = new HashSet<>(rescoredDocs.values()); - - // Collect all docs sequentially, and score them using the similarity function to get the top K scores - PriorityQueue topK = new PriorityQueue<>((o1, o2) -> Float.compare(o2, o1)); - - for (LeafReaderContext leafReaderContext : reader.leaves()) { - FloatVectorValues vectorValues = leafReaderContext.reader().getFloatVectorValues(FIELD_NAME); - KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); - while (iterator.nextDoc() != NO_MORE_DOCS) { - float[] vectorData = vectorValues.vectorValue(iterator.docID()); - float score = VectorSimilarityFunction.COSINE.compare(queryVector, vectorData); - topK.add(score); - int docId = iterator.docID(); - // If the doc has been retrieved from the RescoreKnnVectorQuery, check the score is the same and remove it - // to ensure we found them all - if (rescoredDocs.containsKey(docId)) { - assertThat(rescoredDocs.get(docId), equalTo(score)); - rescoredDocs.remove(docId); - } - } - } - - assertThat(rescoredDocs.size(), equalTo(0)); + TopDocs rescoredDocs = searcher.search(rescoreKnnVectorQuery, numDocs); + assertThat(rescoredDocs.scoreDocs.length, equalTo(k)); - // Check top scoring docs are contained in rescored docs - for (int i = 0; i < k; i++) { - Float topScore = topK.poll(); - if (rescoredScores.contains(topScore) == false) { - fail("Top score " + topScore + " not contained in rescored doc scores " + rescoredScores); + // Get real scores + DoubleValuesSource valueSource = new VectorSimilarityFloatValueSource( + FIELD_NAME, + queryVector, + VectorSimilarityFunction.COSINE + ); + FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(new MatchAllDocsQuery(), valueSource); + TopDocs realScoreTopDocs = searcher.search(functionScoreQuery, numDocs); + + int i = 0; + ScoreDoc[] realScoreDocs = realScoreTopDocs.scoreDocs; + for (ScoreDoc rescoreDoc : rescoredDocs.scoreDocs) { + // There are docs that won't be found in the rescored search, but every doc found must be in the same order + // and have the same score + while (i < realScoreDocs.length && realScoreDocs[i].doc != rescoreDoc.doc) { + i++; + } + if (i >= realScoreDocs.length) { + fail("Rescored doc not found in real score docs"); } + assertThat("Real score is not the same as rescored score", rescoreDoc.score, equalTo(realScoreDocs[i].score)); } } } @@ -205,16 +202,33 @@ public void profile(QueryProfiler queryProfiler) { } private static void addRandomDocuments(int numDocs, Directory d, int numDims) throws IOException { + IndexWriterConfig iwc = new IndexWriterConfig(); + // Pick codec from quantized vector formats to ensure scores use real scores when using knn rescore + KnnVectorsFormat format = randomFrom( + new ES818BinaryQuantizedVectorsFormat(), + new ES818HnswBinaryQuantizedVectorsFormat(), + new ES813Int8FlatVectorFormat(), + new ES813Int8FlatVectorFormat(), + new ES814HnswScalarQuantizedVectorsFormat() + ); + iwc.setCodec(new Elasticsearch900Lucene101Codec(randomFrom(Zstd814StoredFieldsFormat.Mode.values())) { + @Override + public KnnVectorsFormat getKnnVectorsFormatForField(String field) { + return format; + } + }); try (IndexWriter w = new IndexWriter(d, newIndexWriterConfig())) { for (int i = 0; i < numDocs; i++) { Document document = new Document(); float[] vector = randomVector(numDims); - KnnFloatVectorField vectorField = new KnnFloatVectorField(FIELD_NAME, vector); + KnnFloatVectorField vectorField = new KnnFloatVectorField(FIELD_NAME, vector, VectorSimilarityFunction.COSINE); document.add(vectorField); w.addDocument(document); + if (randomBoolean() && (i % 10 == 0)) { + w.commit(); + } } w.commit(); - w.forceMerge(1); } } } From 3513d20780d5deab77696731625dbcf916c93e87 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Mon, 17 Feb 2025 19:22:30 +0100 Subject: [PATCH 4/7] Minor fixes --- .../elasticsearch/search/query/RescoreKnnVectorQueryIT.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/query/RescoreKnnVectorQueryIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/query/RescoreKnnVectorQueryIT.java index d53f8b1807b08..526b6f11fb978 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/query/RescoreKnnVectorQueryIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/query/RescoreKnnVectorQueryIT.java @@ -65,8 +65,7 @@ protected Collection> nodePlugins() { return Collections.singleton(CustomScriptPlugin.class); } - public static class CustomScriptPlugin extends MockScriptPlugin { - + private static class CustomScriptPlugin extends MockScriptPlugin { private static final VectorSimilarityFunction SIMILARITY_FUNCTION = DenseVectorFieldMapper.VectorSimilarity.L2_NORM .vectorSimilarityFunction(IndexVersion.current(), DenseVectorFieldMapper.ElementType.FLOAT); From 813f00117a38b92ec7bb383ef89b7a7800807564 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 18 Feb 2025 07:43:57 +0100 Subject: [PATCH 5/7] Changing plugin class visibility --- .../org/elasticsearch/search/query/RescoreKnnVectorQueryIT.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/query/RescoreKnnVectorQueryIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/query/RescoreKnnVectorQueryIT.java index 526b6f11fb978..c8812cfc109f2 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/query/RescoreKnnVectorQueryIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/query/RescoreKnnVectorQueryIT.java @@ -65,7 +65,7 @@ protected Collection> nodePlugins() { return Collections.singleton(CustomScriptPlugin.class); } - private static class CustomScriptPlugin extends MockScriptPlugin { + public static class CustomScriptPlugin extends MockScriptPlugin { private static final VectorSimilarityFunction SIMILARITY_FUNCTION = DenseVectorFieldMapper.VectorSimilarity.L2_NORM .vectorSimilarityFunction(IndexVersion.current(), DenseVectorFieldMapper.ElementType.FLOAT); From a073f43c76953cb40e38ec8cbed73529220efeb9 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 20 Feb 2025 13:06:23 +0100 Subject: [PATCH 6/7] Sort in query constructor, use ScoreDoc[] vs building individual arrays on the clients --- .../search/vectors/KnnScoreDocQuery.java | 22 +++++++++---------- .../vectors/KnnScoreDocQueryBuilder.java | 10 +-------- .../search/vectors/RescoreKnnVectorQuery.java | 13 +---------- 3 files changed, 12 insertions(+), 33 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java index 7d54fea035371..6542aed294f4d 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java @@ -17,14 +17,15 @@ import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.Weight; -import org.elasticsearch.core.Assertions; import java.io.IOException; import java.util.Arrays; +import java.util.Comparator; import java.util.Objects; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; @@ -51,20 +52,17 @@ public class KnnScoreDocQuery extends Query { /** * Creates a query. * - * @param docs the global doc IDs of documents that match, in ascending order - * @param scores the scores of the matching documents + * @param scoreDocs an array of ScoreDocs to use for the query * @param reader IndexReader */ - KnnScoreDocQuery(int[] docs, float[] scores, IndexReader reader) { - if (Assertions.ENABLED) { - assert docs.length == scores.length; - for (int i = 1; i < docs.length; i++) { - assert docs[i - 1] < docs[i] : "doc ids are not in order: " + Arrays.toString(docs); - } + KnnScoreDocQuery(ScoreDoc[] scoreDocs, IndexReader reader) { + Arrays.sort(scoreDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc)); + this.docs = new int[scoreDocs.length]; + this.scores = new float[scoreDocs.length]; + for (int i = 0; i < scoreDocs.length; i++) { + docs[i] = scoreDocs[i].doc; + scores[i] = scoreDocs[i].score; } - - this.docs = docs; - this.scores = scores; this.segmentStarts = findSegmentStarts(reader, docs); this.contextIdentity = reader.getContext().id(); } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java index 6fa83ccfb6ac2..1a81f4b984e93 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java @@ -141,15 +141,7 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep @Override protected Query doToQuery(SearchExecutionContext context) throws IOException { - int numDocs = scoreDocs.length; - int[] docs = new int[numDocs]; - float[] scores = new float[numDocs]; - for (int i = 0; i < numDocs; i++) { - docs[i] = scoreDocs[i].doc; - scores[i] = scoreDocs[i].score; - } - - return new KnnScoreDocQuery(docs, scores, context.getIndexReader()); + return new KnnScoreDocQuery(scoreDocs, context.getIndexReader()); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java index 31d9767e9a857..99568a507ffb9 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java @@ -16,14 +16,12 @@ import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; -import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource; import org.elasticsearch.search.profile.query.QueryProfiler; import java.io.IOException; import java.util.Arrays; -import java.util.Comparator; import java.util.Objects; /** @@ -60,16 +58,7 @@ public Query rewrite(IndexSearcher searcher) throws IOException { // Retrieve top k documents from the rescored query TopDocs topDocs = searcher.search(query, k); vectorOperations = topDocs.totalHits.value(); - ScoreDoc[] scoreDocs = topDocs.scoreDocs; - Arrays.sort(scoreDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc)); - int[] docIds = new int[scoreDocs.length]; - float[] scores = new float[scoreDocs.length]; - for (int i = 0; i < scoreDocs.length; i++) { - docIds[i] = scoreDocs[i].doc; - scores[i] = scoreDocs[i].score; - } - - return new KnnScoreDocQuery(docIds, scores, searcher.getIndexReader()); + return new KnnScoreDocQuery(topDocs.scoreDocs, searcher.getIndexReader()); } public Query innerQuery() { From ee464fe2b21ad76ee2a22474d72e27a40df50e6d Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 21 Feb 2025 19:41:30 +0100 Subject: [PATCH 7/7] Clarify javadoc --- .../org/elasticsearch/search/vectors/KnnScoreDocQuery.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java index 6542aed294f4d..35906940a6418 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java @@ -33,9 +33,8 @@ /** * A query that matches the provided docs with their scores. * - * Note: this query was adapted from Lucene's DocAndScoreQuery from the class + * Note: this query was originally adapted from Lucene's DocAndScoreQuery from the class * {@link org.apache.lucene.search.KnnFloatVectorQuery}, which is package-private. - * There are no changes to the behavior, just some renames. */ public class KnnScoreDocQuery extends Query { private final int[] docs; @@ -56,6 +55,7 @@ public class KnnScoreDocQuery extends Query { * @param reader IndexReader */ KnnScoreDocQuery(ScoreDoc[] scoreDocs, IndexReader reader) { + // Ensure that the docs are sorted by docId, as they are later searched using binary search Arrays.sort(scoreDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc)); this.docs = new int[scoreDocs.length]; this.scores = new float[scoreDocs.length];