diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java index 74a7dbe168e6b..80cbec1e2f6c2 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java @@ -18,8 +18,6 @@ import org.apache.lucene.search.DoubleValues; import org.apache.lucene.search.DoubleValuesSource; import org.apache.lucene.search.IndexSearcher; -import org.elasticsearch.search.profile.query.QueryProfiler; -import org.elasticsearch.search.vectors.QueryProfilerProvider; import java.io.IOException; import java.util.Arrays; @@ -29,12 +27,11 @@ * DoubleValuesSource that is used to calculate scores according to a similarity function for a KnnFloatVectorField, using the * original vector values stored in the index */ -public class VectorSimilarityFloatValueSource extends DoubleValuesSource implements QueryProfilerProvider { +public class VectorSimilarityFloatValueSource extends DoubleValuesSource { private final String field; private final float[] target; private final VectorSimilarityFunction vectorSimilarityFunction; - private long vectorOpsCount; public VectorSimilarityFloatValueSource(String field, float[] target, VectorSimilarityFunction vectorSimilarityFunction) { this.field = field; @@ -52,7 +49,6 @@ public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws return new DoubleValues() { @Override public double doubleValue() throws IOException { - vectorOpsCount++; return vectorSimilarityFunction.compare(target, vectorValues.vectorValue(iterator.index())); } @@ -73,11 +69,6 @@ public DoubleValuesSource rewrite(IndexSearcher reader) throws IOException { return this; } - @Override - public void profile(QueryProfiler queryProfiler) { - queryProfiler.addVectorOpsCount(vectorOpsCount); - } - @Override public int hashCode() { return Objects.hash(field, Arrays.hashCode(target), vectorSimilarityFunction); diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java index a65757cc25876..4a6dbab399141 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java @@ -487,11 +487,6 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException { return this; } - @Override - protected QueryBuilder doIndexMetadataRewrite(QueryRewriteContext context) throws IOException { - return super.doIndexMetadataRewrite(context); - } - @Override protected Query doToQuery(SearchExecutionContext context) throws IOException { MappedFieldType fieldType = context.getFieldType(fieldName); @@ -529,8 +524,8 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { String parentPath = context.nestedLookup().getNestedParent(fieldName); Float numCandidatesFactor = rescoreVectorBuilder() == null ? null : rescoreVectorBuilder.numCandidatesFactor(); + BitSetProducer parentBitSet = null; if (parentPath != null) { - final BitSetProducer parentBitSet; final Query parentFilter; NestedObjectMapper originalObjectMapper = context.nestedScope().getObjectMapper(); if (originalObjectMapper != null) { @@ -559,17 +554,17 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { // Now join the filterQuery & parentFilter to provide the matching blocks of children filterQuery = new ToChildBlockJoinQuery(filterQuery, parentBitSet); } - return vectorFieldType.createKnnQuery( - queryVector, - k, - adjustedNumCands, - numCandidatesFactor, - filterQuery, - vectorSimilarity, - parentBitSet - ); } - return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, numCandidatesFactor, filterQuery, vectorSimilarity, null); + + return vectorFieldType.createKnnQuery( + queryVector, + k, + adjustedNumCands, + numCandidatesFactor, + filterQuery, + vectorSimilarity, + parentBitSet + ); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java index a9c606b1f8618..79ede6873ad1f 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java @@ -32,16 +32,15 @@ public class RescoreKnnVectorQuery extends Query implements QueryProfilerProvide private final String fieldName; private final float[] floatTarget; private final VectorSimilarityFunction vectorSimilarityFunction; - private final Integer k; + private final int k; private final Query innerQuery; - - private QueryProfilerProvider vectorProfiling; + private long vectorOperations = 0; public RescoreKnnVectorQuery( String fieldName, float[] floatTarget, VectorSimilarityFunction vectorSimilarityFunction, - Integer k, + int k, Query innerQuery ) { this.fieldName = fieldName; @@ -54,19 +53,12 @@ public RescoreKnnVectorQuery( @Override public Query rewrite(IndexSearcher searcher) throws IOException { DoubleValuesSource valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction); - // Vector similarity VectorSimilarityFloatValueSource keep track of the compared vectors - we need that in case we don't need - // to calculate top k and return directly the query to understand how many comparisons were done - vectorProfiling = (QueryProfilerProvider) valueSource; FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(innerQuery, valueSource); Query query = searcher.rewrite(functionScoreQuery); - if (k == null) { - // No need to calculate top k - let the request size limit the results. - return query; - } - // Retrieve top k documents from the rescored query TopDocs topDocs = searcher.search(query, k); + vectorOperations = topDocs.totalHits.value(); ScoreDoc[] scoreDocs = topDocs.scoreDocs; int[] docIds = new int[scoreDocs.length]; float[] scores = new float[scoreDocs.length]; @@ -82,7 +74,7 @@ public Query innerQuery() { return innerQuery; } - public Integer k() { + public int k() { return k; } @@ -92,10 +84,7 @@ public void profile(QueryProfiler queryProfiler) { queryProfilerProvider.profile(queryProfiler); } - if (vectorProfiling == null) { - throw new IllegalStateException("Query should have been rewritten"); - } - vectorProfiling.profile(queryProfiler); + queryProfiler.addVectorOpsCount(vectorOperations); } @Override diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index d37b4a4bacb4e..be4c677d20b03 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -456,18 +456,19 @@ public void testRescoreOversampleModifiesNumCandidates() { ); // Total results is k, internal k is multiplied by oversample - checkRescoreQueryParameters(fieldType, 10, 200, 2.5F, null, 500, 10); + checkRescoreQueryParameters(fieldType, 10, 200, randomInt(), 2.5F, null, 500, 10); // If numCands < k, update numCands to k - checkRescoreQueryParameters(fieldType, 10, 20, 2.5F, null, 50, 10); + checkRescoreQueryParameters(fieldType, 10, 20, randomInt(), 2.5F, null, 50, 10); // Oversampling limits for num candidates - checkRescoreQueryParameters(fieldType, 1000, 1000, 11.0F, null, 10000, 1000); - checkRescoreQueryParameters(fieldType, 5000, 7500, 2.5F, null, 10000, 5000); + checkRescoreQueryParameters(fieldType, 1000, 1000, randomInt(), 11.0F, null, 10000, 1000); + checkRescoreQueryParameters(fieldType, 5000, 7500, randomInt(), 2.5F, null, 10000, 5000); } private static void checkRescoreQueryParameters( DenseVectorFieldType fieldType, - Integer k, + int k, int candidates, + int requestSize, float numCandsFactor, Integer expectedK, int expectedCandidates, diff --git a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java index 7bbe7dcc155c5..861a8b11db567 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java @@ -9,8 +9,6 @@ package org.elasticsearch.search.vectors; -import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; - import org.apache.lucene.document.Document; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.DirectoryReader; @@ -33,11 +31,9 @@ import java.io.IOException; import java.io.UnsupportedEncodingException; -import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.HashSet; -import java.util.List; import java.util.Map; import java.util.PriorityQueue; import java.util.stream.Collectors; @@ -49,21 +45,11 @@ public class RescoreKnnVectorQueryTests extends ESTestCase { public static final String FIELD_NAME = "float_vector"; - private final int numDocs; - private final Integer k; - - public RescoreKnnVectorQueryTests(boolean useK) { - this.numDocs = randomIntBetween(10, 100); - this.k = useK ? randomIntBetween(1, numDocs - 1) : null; - } public void testRescoreDocs() throws Exception { + int numDocs = randomIntBetween(10, 100); int numDims = randomIntBetween(5, 100); - - Integer adjustedK = k; - if (k == null) { - adjustedK = numDocs; - } + int k = randomIntBetween(1, numDocs - 1); try (Directory d = newDirectory()) { addRandomDocuments(numDocs, d, numDims); @@ -77,7 +63,7 @@ public void testRescoreDocs() throws Exception { FIELD_NAME, queryVector, VectorSimilarityFunction.COSINE, - adjustedK, + k, new MatchAllDocsQuery() ); @@ -86,7 +72,7 @@ public void testRescoreDocs() throws Exception { Map rescoredDocs = Arrays.stream(docs.scoreDocs) .collect(Collectors.toMap(scoreDoc -> scoreDoc.doc, scoreDoc -> scoreDoc.score)); - assertThat(rescoredDocs.size(), equalTo(adjustedK)); + assertThat(rescoredDocs.size(), equalTo(k)); Collection rescoredScores = new HashSet<>(rescoredDocs.values()); @@ -113,7 +99,7 @@ public void testRescoreDocs() throws Exception { assertThat(rescoredDocs.size(), equalTo(0)); // Check top scoring docs are contained in rescored docs - for (int i = 0; i < adjustedK; i++) { + for (int i = 0; i < k; i++) { Float topScore = topK.poll(); if (rescoredScores.contains(topScore) == false) { fail("Top score " + topScore + " not contained in rescored doc scores " + rescoredScores); @@ -124,7 +110,9 @@ public void testRescoreDocs() throws Exception { } public void testProfiling() throws Exception { + int numDocs = randomIntBetween(10, 100); int numDims = randomIntBetween(5, 100); + int k = randomIntBetween(1, numDocs - 1); try (Directory d = newDirectory()) { addRandomDocuments(numDocs, d, numDims); @@ -132,13 +120,13 @@ public void testProfiling() throws Exception { try (IndexReader reader = DirectoryReader.open(d)) { float[] queryVector = randomVector(numDims); - checkProfiling(queryVector, reader, new MatchAllDocsQuery()); - checkProfiling(queryVector, reader, new MockQueryProfilerProvider(randomIntBetween(1, 100))); + checkProfiling(k, numDocs, queryVector, reader, new MatchAllDocsQuery()); + checkProfiling(k, numDocs, queryVector, reader, new MockQueryProfilerProvider(randomIntBetween(1, 100))); } } } - private void checkProfiling(float[] queryVector, IndexReader reader, Query innerQuery) throws IOException { + private void checkProfiling(int k, int numDocs, float[] queryVector, IndexReader reader, Query innerQuery) throws IOException { RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery( FIELD_NAME, queryVector, @@ -229,13 +217,4 @@ private static void addRandomDocuments(int numDocs, Directory d, int numDims) th w.forceMerge(1); } } - - @ParametersFactory - public static Iterable parameters() { - List params = new ArrayList<>(); - params.add(new Object[] { true }); - params.add(new Object[] { false }); - - return params; - } }