From e03f24017b0d7f3ea7a0a1ddac536c9f9fc87193 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 18 Dec 2024 18:46:46 +0100 Subject: [PATCH 1/3] Use request size when k is null to calculate the number of results to retrieve from each shard --- .../vectors/DenseVectorFieldMapper.java | 5 ++- .../VectorSimilarityFloatValueSource.java | 11 +---- .../search/vectors/KnnVectorQueryBuilder.java | 23 +++++----- .../search/vectors/RescoreKnnVectorQuery.java | 23 +++------- .../vectors/DenseVectorFieldMapperTests.java | 31 ++++++++++++-- .../vectors/DenseVectorFieldTypeTests.java | 42 ++++++++++++------- .../vectors/RescoreKnnVectorQueryTests.java | 41 +++++------------- 7 files changed, 89 insertions(+), 87 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 8c6e874ff577f..3d394ed8a1549 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -2008,6 +2008,7 @@ public Query createKnnQuery( VectorData queryVector, Integer k, int numCands, + int requestSize, Float numCandsFactor, Query filter, Float similarityThreshold, @@ -2024,6 +2025,7 @@ public Query createKnnQuery( queryVector.asFloatVector(), k, numCands, + requestSize, numCandsFactor, filter, similarityThreshold, @@ -2090,6 +2092,7 @@ private Query createKnnFloatQuery( float[] queryVector, Integer k, int numCands, + int requestSize, Float numCandsFactor, Query filter, Float similarityThreshold, @@ -2127,7 +2130,7 @@ && isNotUnitVector(squaredMagnitude)) { name(), queryVector, similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.FLOAT), - k, + k == null ? requestSize : k, knnQuery ); } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java index 74a7dbe168e6b..80cbec1e2f6c2 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java @@ -18,8 +18,6 @@ import org.apache.lucene.search.DoubleValues; import org.apache.lucene.search.DoubleValuesSource; import org.apache.lucene.search.IndexSearcher; -import org.elasticsearch.search.profile.query.QueryProfiler; -import org.elasticsearch.search.vectors.QueryProfilerProvider; import java.io.IOException; import java.util.Arrays; @@ -29,12 +27,11 @@ * DoubleValuesSource that is used to calculate scores according to a similarity function for a KnnFloatVectorField, using the * original vector values stored in the index */ -public class VectorSimilarityFloatValueSource extends DoubleValuesSource implements QueryProfilerProvider { +public class VectorSimilarityFloatValueSource extends DoubleValuesSource { private final String field; private final float[] target; private final VectorSimilarityFunction vectorSimilarityFunction; - private long vectorOpsCount; public VectorSimilarityFloatValueSource(String field, float[] target, VectorSimilarityFunction vectorSimilarityFunction) { this.field = field; @@ -52,7 +49,6 @@ public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws return new DoubleValues() { @Override public double doubleValue() throws IOException { - vectorOpsCount++; return vectorSimilarityFunction.compare(target, vectorValues.vectorValue(iterator.index())); } @@ -73,11 +69,6 @@ public DoubleValuesSource rewrite(IndexSearcher reader) throws IOException { return this; } - @Override - public void profile(QueryProfiler queryProfiler) { - queryProfiler.addVectorOpsCount(vectorOpsCount); - } - @Override public int hashCode() { return Objects.hash(field, Arrays.hashCode(target), vectorSimilarityFunction); diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java index 88f6312fa7e6f..c4d6c3da7b5f0 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java @@ -528,8 +528,8 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { String parentPath = context.nestedLookup().getNestedParent(fieldName); Float numCandidatesFactor = rescoreVectorBuilder() == null ? null : rescoreVectorBuilder.numCandidatesFactor(); + BitSetProducer parentBitSet = null; if (parentPath != null) { - final BitSetProducer parentBitSet; final Query parentFilter; NestedObjectMapper originalObjectMapper = context.nestedScope().getObjectMapper(); if (originalObjectMapper != null) { @@ -558,17 +558,18 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { // Now join the filterQuery & parentFilter to provide the matching blocks of children filterQuery = new ToChildBlockJoinQuery(filterQuery, parentBitSet); } - return vectorFieldType.createKnnQuery( - queryVector, - k, - adjustedNumCands, - numCandidatesFactor, - filterQuery, - vectorSimilarity, - parentBitSet - ); } - return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, numCandidatesFactor, filterQuery, vectorSimilarity, null); + + return vectorFieldType.createKnnQuery( + queryVector, + k, + adjustedNumCands, + requestSize, + numCandidatesFactor, + filterQuery, + vectorSimilarity, + parentBitSet + ); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java index a9c606b1f8618..79ede6873ad1f 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java @@ -32,16 +32,15 @@ public class RescoreKnnVectorQuery extends Query implements QueryProfilerProvide private final String fieldName; private final float[] floatTarget; private final VectorSimilarityFunction vectorSimilarityFunction; - private final Integer k; + private final int k; private final Query innerQuery; - - private QueryProfilerProvider vectorProfiling; + private long vectorOperations = 0; public RescoreKnnVectorQuery( String fieldName, float[] floatTarget, VectorSimilarityFunction vectorSimilarityFunction, - Integer k, + int k, Query innerQuery ) { this.fieldName = fieldName; @@ -54,19 +53,12 @@ public RescoreKnnVectorQuery( @Override public Query rewrite(IndexSearcher searcher) throws IOException { DoubleValuesSource valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction); - // Vector similarity VectorSimilarityFloatValueSource keep track of the compared vectors - we need that in case we don't need - // to calculate top k and return directly the query to understand how many comparisons were done - vectorProfiling = (QueryProfilerProvider) valueSource; FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(innerQuery, valueSource); Query query = searcher.rewrite(functionScoreQuery); - if (k == null) { - // No need to calculate top k - let the request size limit the results. - return query; - } - // Retrieve top k documents from the rescored query TopDocs topDocs = searcher.search(query, k); + vectorOperations = topDocs.totalHits.value(); ScoreDoc[] scoreDocs = topDocs.scoreDocs; int[] docIds = new int[scoreDocs.length]; float[] scores = new float[scoreDocs.length]; @@ -82,7 +74,7 @@ public Query innerQuery() { return innerQuery; } - public Integer k() { + public int k() { return k; } @@ -92,10 +84,7 @@ public void profile(QueryProfiler queryProfiler) { queryProfilerProvider.profile(queryProfiler); } - if (vectorProfiling == null) { - throw new IllegalStateException("Query should have been rewritten"); - } - vectorProfiling.profile(queryProfiler); + queryProfiler.addVectorOpsCount(vectorOperations); } @Override diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index 342d61b78defd..742797e000517 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -1674,7 +1674,7 @@ public void testByteVectorQueryBoundaries() throws IOException { Exception e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 128, 0, 0 }), 3, 3, null, null, null, null) + () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 128, 0, 0 }), 3, 3, 3, null, null, null, null) ); assertThat( e.getMessage(), @@ -1687,6 +1687,7 @@ public void testByteVectorQueryBoundaries() throws IOException { VectorData.fromFloats(new float[] { 0.0f, 0f, -129.0f }), 3, 3, + 3, null, null, null, @@ -1700,7 +1701,16 @@ public void testByteVectorQueryBoundaries() throws IOException { e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0.5f, 0.0f }), 3, 3, null, null, null, null) + () -> denseVectorFieldType.createKnnQuery( + VectorData.fromFloats(new float[] { 0.0f, 0.5f, 0.0f }), + 3, + 3, + 3, + null, + null, + null, + null + ) ); assertThat( e.getMessage(), @@ -1709,7 +1719,16 @@ public void testByteVectorQueryBoundaries() throws IOException { e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0, 0.0f, -0.25f }), 3, 3, null, null, null, null) + () -> denseVectorFieldType.createKnnQuery( + VectorData.fromFloats(new float[] { 0, 0.0f, -0.25f }), + 3, + 3, + 3, + null, + null, + null, + null + ) ); assertThat( e.getMessage(), @@ -1722,6 +1741,7 @@ public void testByteVectorQueryBoundaries() throws IOException { VectorData.fromFloats(new float[] { Float.NaN, 0f, 0.0f }), 3, 3, + 3, null, null, null, @@ -1736,6 +1756,7 @@ public void testByteVectorQueryBoundaries() throws IOException { VectorData.fromFloats(new float[] { Float.POSITIVE_INFINITY, 0f, 0.0f }), 3, 3, + 3, null, null, null, @@ -1753,6 +1774,7 @@ public void testByteVectorQueryBoundaries() throws IOException { VectorData.fromFloats(new float[] { 0, Float.NEGATIVE_INFINITY, 0.0f }), 3, 3, + 3, null, null, null, @@ -1787,6 +1809,7 @@ public void testFloatVectorQueryBoundaries() throws IOException { VectorData.fromFloats(new float[] { Float.NaN, 0f, 0.0f }), 3, 3, + 3, null, null, null, @@ -1801,6 +1824,7 @@ public void testFloatVectorQueryBoundaries() throws IOException { VectorData.fromFloats(new float[] { Float.POSITIVE_INFINITY, 0f, 0.0f }), 3, 3, + 3, null, null, null, @@ -1818,6 +1842,7 @@ public void testFloatVectorQueryBoundaries() throws IOException { VectorData.fromFloats(new float[] { 0, Float.NEGATIVE_INFINITY, 0.0f }), 3, 3, + 3, null, null, null, diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index d37b4a4bacb4e..5f580c9ec5352 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -193,7 +193,7 @@ public void testCreateNestedKnnQuery() { for (int i = 0; i < dims; i++) { queryVector[i] = randomFloat(); } - Query query = field.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, producer); + Query query = field.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, 10, null, null, null, producer); assertThat(query, instanceOf(DiversifyingChildrenFloatKnnVectorQuery.class)); } { @@ -214,11 +214,11 @@ public void testCreateNestedKnnQuery() { floatQueryVector[i] = queryVector[i]; } VectorData vectorData = new VectorData(null, queryVector); - Query query = field.createKnnQuery(vectorData, 10, 10, null, null, null, producer); + Query query = field.createKnnQuery(vectorData, 10, 10, 10, null, null, null, producer); assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class)); vectorData = new VectorData(floatQueryVector, null); - query = field.createKnnQuery(vectorData, 10, 10, null, null, null, producer); + query = field.createKnnQuery(vectorData, 10, 10, 10, null, null, null, producer); assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class)); } } @@ -283,6 +283,7 @@ public void testFloatCreateKnnQuery() { VectorData.fromFloats(new float[] { 0.3f, 0.1f, 1.0f, 0.0f }), 10, 10, + 10, null, null, null, @@ -307,7 +308,7 @@ public void testFloatCreateKnnQuery() { } e = expectThrows( IllegalArgumentException.class, - () -> dotProductField.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, null) + () -> dotProductField.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("The [dot_product] similarity can only be used with unit-length vectors.")); @@ -323,7 +324,7 @@ public void testFloatCreateKnnQuery() { ); e = expectThrows( IllegalArgumentException.class, - () -> cosineField.createKnnQuery(VectorData.fromFloats(new float[BBQ_MIN_DIMS]), 10, 10, null, null, null, null) + () -> cosineField.createKnnQuery(VectorData.fromFloats(new float[BBQ_MIN_DIMS]), 10, 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude.")); } @@ -344,7 +345,7 @@ public void testCreateKnnQueryMaxDims() { for (int i = 0; i < 4096; i++) { queryVector[i] = randomFloat(); } - Query query = fieldWith4096dims.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, null); + Query query = fieldWith4096dims.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, 10, null, null, null, null); assertThat(query, instanceOf(KnnFloatVectorQuery.class)); } @@ -364,7 +365,7 @@ public void testCreateKnnQueryMaxDims() { queryVector[i] = randomByte(); } VectorData vectorData = new VectorData(null, queryVector); - Query query = fieldWith4096dims.createKnnQuery(vectorData, 10, 10, null, null, null, null); + Query query = fieldWith4096dims.createKnnQuery(vectorData, 10, 10, 10, null, null, null, null); assertThat(query, instanceOf(KnnByteVectorQuery.class)); } } @@ -382,7 +383,7 @@ public void testByteCreateKnnQuery() { ); IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> unindexedField.createKnnQuery(VectorData.fromFloats(new float[] { 0.3f, 0.1f, 1.0f }), 10, 10, null, null, null, null) + () -> unindexedField.createKnnQuery(VectorData.fromFloats(new float[] { 0.3f, 0.1f, 1.0f }), 10, 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("to perform knn search on field [f], its mapping must have [index] set to [true]")); @@ -398,13 +399,13 @@ public void testByteCreateKnnQuery() { ); e = expectThrows( IllegalArgumentException.class, - () -> cosineField.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0.0f, 0.0f }), 10, 10, null, null, null, null) + () -> cosineField.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0.0f, 0.0f }), 10, 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude.")); e = expectThrows( IllegalArgumentException.class, - () -> cosineField.createKnnQuery(new VectorData(null, new byte[] { 0, 0, 0 }), 10, 10, null, null, null, null) + () -> cosineField.createKnnQuery(new VectorData(null, new byte[] { 0, 0, 0 }), 10, 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude.")); } @@ -426,6 +427,7 @@ public void testRescoreOversampleUsedWithoutQuantization() { new VectorData(null, new byte[] { 1, 4, 10 }), 10, 100, + 10, randomFloatBetween(1.0F, 10.0F, false), null, null, @@ -456,18 +458,29 @@ public void testRescoreOversampleModifiesNumCandidates() { ); // Total results is k, internal k is multiplied by oversample - checkRescoreQueryParameters(fieldType, 10, 200, 2.5F, null, 500, 10); + checkRescoreQueryParameters(fieldType, 10, 200, randomInt(), 2.5F, null, 500, 10); // If numCands < k, update numCands to k - checkRescoreQueryParameters(fieldType, 10, 20, 2.5F, null, 50, 10); + checkRescoreQueryParameters(fieldType, 10, 20, randomInt(), 2.5F, null, 50, 10); // Oversampling limits for num candidates - checkRescoreQueryParameters(fieldType, 1000, 1000, 11.0F, null, 10000, 1000); - checkRescoreQueryParameters(fieldType, 5000, 7500, 2.5F, null, 10000, 5000); + checkRescoreQueryParameters(fieldType, 1000, 1000, randomInt(), 11.0F, null, 10000, 1000); + checkRescoreQueryParameters(fieldType, 5000, 7500, randomInt(), 2.5F, null, 10000, 5000); + + // Check the same as the above, for null k - take request size + // Total results is k, internal k is multiplied by oversample + checkRescoreQueryParameters(fieldType, null, 200, 25, 2.5F, null, 500, 25); + // If numCands < k, update numCands to k + checkRescoreQueryParameters(fieldType, null, 20, 25, 2.5F, null, 50, 25); + // Oversampling limits for num candidates + checkRescoreQueryParameters(fieldType, null, 1000, 25, 11.0F, null, 10000, 25); + checkRescoreQueryParameters(fieldType, null, 7500, 25, 2.5F, null, 10000, 25); + } private static void checkRescoreQueryParameters( DenseVectorFieldType fieldType, Integer k, int candidates, + int requestSize, float numCandsFactor, Integer expectedK, int expectedCandidates, @@ -477,6 +490,7 @@ private static void checkRescoreQueryParameters( VectorData.fromFloats(new float[] { 1, 4, 10 }), k, candidates, + requestSize, numCandsFactor, null, null, diff --git a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java index 7bbe7dcc155c5..861a8b11db567 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java @@ -9,8 +9,6 @@ package org.elasticsearch.search.vectors; -import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; - import org.apache.lucene.document.Document; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.DirectoryReader; @@ -33,11 +31,9 @@ import java.io.IOException; import java.io.UnsupportedEncodingException; -import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.HashSet; -import java.util.List; import java.util.Map; import java.util.PriorityQueue; import java.util.stream.Collectors; @@ -49,21 +45,11 @@ public class RescoreKnnVectorQueryTests extends ESTestCase { public static final String FIELD_NAME = "float_vector"; - private final int numDocs; - private final Integer k; - - public RescoreKnnVectorQueryTests(boolean useK) { - this.numDocs = randomIntBetween(10, 100); - this.k = useK ? randomIntBetween(1, numDocs - 1) : null; - } public void testRescoreDocs() throws Exception { + int numDocs = randomIntBetween(10, 100); int numDims = randomIntBetween(5, 100); - - Integer adjustedK = k; - if (k == null) { - adjustedK = numDocs; - } + int k = randomIntBetween(1, numDocs - 1); try (Directory d = newDirectory()) { addRandomDocuments(numDocs, d, numDims); @@ -77,7 +63,7 @@ public void testRescoreDocs() throws Exception { FIELD_NAME, queryVector, VectorSimilarityFunction.COSINE, - adjustedK, + k, new MatchAllDocsQuery() ); @@ -86,7 +72,7 @@ public void testRescoreDocs() throws Exception { Map rescoredDocs = Arrays.stream(docs.scoreDocs) .collect(Collectors.toMap(scoreDoc -> scoreDoc.doc, scoreDoc -> scoreDoc.score)); - assertThat(rescoredDocs.size(), equalTo(adjustedK)); + assertThat(rescoredDocs.size(), equalTo(k)); Collection rescoredScores = new HashSet<>(rescoredDocs.values()); @@ -113,7 +99,7 @@ public void testRescoreDocs() throws Exception { assertThat(rescoredDocs.size(), equalTo(0)); // Check top scoring docs are contained in rescored docs - for (int i = 0; i < adjustedK; i++) { + for (int i = 0; i < k; i++) { Float topScore = topK.poll(); if (rescoredScores.contains(topScore) == false) { fail("Top score " + topScore + " not contained in rescored doc scores " + rescoredScores); @@ -124,7 +110,9 @@ public void testRescoreDocs() throws Exception { } public void testProfiling() throws Exception { + int numDocs = randomIntBetween(10, 100); int numDims = randomIntBetween(5, 100); + int k = randomIntBetween(1, numDocs - 1); try (Directory d = newDirectory()) { addRandomDocuments(numDocs, d, numDims); @@ -132,13 +120,13 @@ public void testProfiling() throws Exception { try (IndexReader reader = DirectoryReader.open(d)) { float[] queryVector = randomVector(numDims); - checkProfiling(queryVector, reader, new MatchAllDocsQuery()); - checkProfiling(queryVector, reader, new MockQueryProfilerProvider(randomIntBetween(1, 100))); + checkProfiling(k, numDocs, queryVector, reader, new MatchAllDocsQuery()); + checkProfiling(k, numDocs, queryVector, reader, new MockQueryProfilerProvider(randomIntBetween(1, 100))); } } } - private void checkProfiling(float[] queryVector, IndexReader reader, Query innerQuery) throws IOException { + private void checkProfiling(int k, int numDocs, float[] queryVector, IndexReader reader, Query innerQuery) throws IOException { RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery( FIELD_NAME, queryVector, @@ -229,13 +217,4 @@ private static void addRandomDocuments(int numDocs, Directory d, int numDims) th w.forceMerge(1); } } - - @ParametersFactory - public static Iterable parameters() { - List params = new ArrayList<>(); - params.add(new Object[] { true }); - params.add(new Object[] { false }); - - return params; - } } From 0a2f895775e6db90b6b8de3dcb74616e7da7bbd6 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 3 Jan 2025 17:17:49 +0100 Subject: [PATCH 2/3] Remove unnecessary override --- .../elasticsearch/search/vectors/KnnVectorQueryBuilder.java | 5 ----- 1 file changed, 5 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java index c4d6c3da7b5f0..4d422ddfad051 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java @@ -487,11 +487,6 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException { return this; } - @Override - protected QueryBuilder doIndexMetadataRewrite(QueryRewriteContext context) throws IOException { - return super.doIndexMetadataRewrite(context); - } - @Override protected Query doToQuery(SearchExecutionContext context) throws IOException { MappedFieldType fieldType = context.getFieldType(fieldName); From f1e2972f000284802518dbe305e912ad44e3d804 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 8 Jan 2025 08:12:36 +0100 Subject: [PATCH 3/3] Remove request size as it is already provided from the query using k --- .../vectors/DenseVectorFieldMapper.java | 5 +-- .../search/vectors/KnnVectorQueryBuilder.java | 1 - .../vectors/DenseVectorFieldMapperTests.java | 31 ++-------------- .../vectors/DenseVectorFieldTypeTests.java | 35 ++++++------------- 4 files changed, 15 insertions(+), 57 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index e224eaac71d36..b2b23baacc4db 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -2021,7 +2021,6 @@ public Query createKnnQuery( VectorData queryVector, int k, int numCands, - int requestSize, Float numCandsFactor, Query filter, Float similarityThreshold, @@ -2038,7 +2037,6 @@ public Query createKnnQuery( queryVector.asFloatVector(), k, numCands, - requestSize, numCandsFactor, filter, similarityThreshold, @@ -2105,7 +2103,6 @@ private Query createKnnFloatQuery( float[] queryVector, int k, int numCands, - int requestSize, Float numCandsFactor, Query filter, Float similarityThreshold, @@ -2143,7 +2140,7 @@ && isNotUnitVector(squaredMagnitude)) { name(), queryVector, similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.FLOAT), - k == null ? requestSize : k, + k, knnQuery ); } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java index 1f509f460d0fb..4a6dbab399141 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java @@ -560,7 +560,6 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { queryVector, k, adjustedNumCands, - requestSize, numCandidatesFactor, filterQuery, vectorSimilarity, diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index f0e8d2943517c..3f574a29469c2 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -1674,7 +1674,7 @@ public void testByteVectorQueryBoundaries() throws IOException { Exception e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 128, 0, 0 }), 3, 3, 3, null, null, null, null) + () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 128, 0, 0 }), 3, 3, null, null, null, null) ); assertThat( e.getMessage(), @@ -1687,7 +1687,6 @@ public void testByteVectorQueryBoundaries() throws IOException { VectorData.fromFloats(new float[] { 0.0f, 0f, -129.0f }), 3, 3, - 3, null, null, null, @@ -1701,16 +1700,7 @@ public void testByteVectorQueryBoundaries() throws IOException { e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery( - VectorData.fromFloats(new float[] { 0.0f, 0.5f, 0.0f }), - 3, - 3, - 3, - null, - null, - null, - null - ) + () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0.5f, 0.0f }), 3, 3, null, null, null, null) ); assertThat( e.getMessage(), @@ -1719,16 +1709,7 @@ public void testByteVectorQueryBoundaries() throws IOException { e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery( - VectorData.fromFloats(new float[] { 0, 0.0f, -0.25f }), - 3, - 3, - 3, - null, - null, - null, - null - ) + () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0, 0.0f, -0.25f }), 3, 3, null, null, null, null) ); assertThat( e.getMessage(), @@ -1741,7 +1722,6 @@ public void testByteVectorQueryBoundaries() throws IOException { VectorData.fromFloats(new float[] { Float.NaN, 0f, 0.0f }), 3, 3, - 3, null, null, null, @@ -1756,7 +1736,6 @@ public void testByteVectorQueryBoundaries() throws IOException { VectorData.fromFloats(new float[] { Float.POSITIVE_INFINITY, 0f, 0.0f }), 3, 3, - 3, null, null, null, @@ -1774,7 +1753,6 @@ public void testByteVectorQueryBoundaries() throws IOException { VectorData.fromFloats(new float[] { 0, Float.NEGATIVE_INFINITY, 0.0f }), 3, 3, - 3, null, null, null, @@ -1809,7 +1787,6 @@ public void testFloatVectorQueryBoundaries() throws IOException { VectorData.fromFloats(new float[] { Float.NaN, 0f, 0.0f }), 3, 3, - 3, null, null, null, @@ -1824,7 +1801,6 @@ public void testFloatVectorQueryBoundaries() throws IOException { VectorData.fromFloats(new float[] { Float.POSITIVE_INFINITY, 0f, 0.0f }), 3, 3, - 3, null, null, null, @@ -1842,7 +1818,6 @@ public void testFloatVectorQueryBoundaries() throws IOException { VectorData.fromFloats(new float[] { 0, Float.NEGATIVE_INFINITY, 0.0f }), 3, 3, - 3, null, null, null, diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index 5f580c9ec5352..be4c677d20b03 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -193,7 +193,7 @@ public void testCreateNestedKnnQuery() { for (int i = 0; i < dims; i++) { queryVector[i] = randomFloat(); } - Query query = field.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, 10, null, null, null, producer); + Query query = field.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, producer); assertThat(query, instanceOf(DiversifyingChildrenFloatKnnVectorQuery.class)); } { @@ -214,11 +214,11 @@ public void testCreateNestedKnnQuery() { floatQueryVector[i] = queryVector[i]; } VectorData vectorData = new VectorData(null, queryVector); - Query query = field.createKnnQuery(vectorData, 10, 10, 10, null, null, null, producer); + Query query = field.createKnnQuery(vectorData, 10, 10, null, null, null, producer); assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class)); vectorData = new VectorData(floatQueryVector, null); - query = field.createKnnQuery(vectorData, 10, 10, 10, null, null, null, producer); + query = field.createKnnQuery(vectorData, 10, 10, null, null, null, producer); assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class)); } } @@ -283,7 +283,6 @@ public void testFloatCreateKnnQuery() { VectorData.fromFloats(new float[] { 0.3f, 0.1f, 1.0f, 0.0f }), 10, 10, - 10, null, null, null, @@ -308,7 +307,7 @@ public void testFloatCreateKnnQuery() { } e = expectThrows( IllegalArgumentException.class, - () -> dotProductField.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, 10, null, null, null, null) + () -> dotProductField.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("The [dot_product] similarity can only be used with unit-length vectors.")); @@ -324,7 +323,7 @@ public void testFloatCreateKnnQuery() { ); e = expectThrows( IllegalArgumentException.class, - () -> cosineField.createKnnQuery(VectorData.fromFloats(new float[BBQ_MIN_DIMS]), 10, 10, 10, null, null, null, null) + () -> cosineField.createKnnQuery(VectorData.fromFloats(new float[BBQ_MIN_DIMS]), 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude.")); } @@ -345,7 +344,7 @@ public void testCreateKnnQueryMaxDims() { for (int i = 0; i < 4096; i++) { queryVector[i] = randomFloat(); } - Query query = fieldWith4096dims.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, 10, null, null, null, null); + Query query = fieldWith4096dims.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, null); assertThat(query, instanceOf(KnnFloatVectorQuery.class)); } @@ -365,7 +364,7 @@ public void testCreateKnnQueryMaxDims() { queryVector[i] = randomByte(); } VectorData vectorData = new VectorData(null, queryVector); - Query query = fieldWith4096dims.createKnnQuery(vectorData, 10, 10, 10, null, null, null, null); + Query query = fieldWith4096dims.createKnnQuery(vectorData, 10, 10, null, null, null, null); assertThat(query, instanceOf(KnnByteVectorQuery.class)); } } @@ -383,7 +382,7 @@ public void testByteCreateKnnQuery() { ); IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> unindexedField.createKnnQuery(VectorData.fromFloats(new float[] { 0.3f, 0.1f, 1.0f }), 10, 10, 10, null, null, null, null) + () -> unindexedField.createKnnQuery(VectorData.fromFloats(new float[] { 0.3f, 0.1f, 1.0f }), 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("to perform knn search on field [f], its mapping must have [index] set to [true]")); @@ -399,13 +398,13 @@ public void testByteCreateKnnQuery() { ); e = expectThrows( IllegalArgumentException.class, - () -> cosineField.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0.0f, 0.0f }), 10, 10, 10, null, null, null, null) + () -> cosineField.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0.0f, 0.0f }), 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude.")); e = expectThrows( IllegalArgumentException.class, - () -> cosineField.createKnnQuery(new VectorData(null, new byte[] { 0, 0, 0 }), 10, 10, 10, null, null, null, null) + () -> cosineField.createKnnQuery(new VectorData(null, new byte[] { 0, 0, 0 }), 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude.")); } @@ -427,7 +426,6 @@ public void testRescoreOversampleUsedWithoutQuantization() { new VectorData(null, new byte[] { 1, 4, 10 }), 10, 100, - 10, randomFloatBetween(1.0F, 10.0F, false), null, null, @@ -464,21 +462,11 @@ public void testRescoreOversampleModifiesNumCandidates() { // Oversampling limits for num candidates checkRescoreQueryParameters(fieldType, 1000, 1000, randomInt(), 11.0F, null, 10000, 1000); checkRescoreQueryParameters(fieldType, 5000, 7500, randomInt(), 2.5F, null, 10000, 5000); - - // Check the same as the above, for null k - take request size - // Total results is k, internal k is multiplied by oversample - checkRescoreQueryParameters(fieldType, null, 200, 25, 2.5F, null, 500, 25); - // If numCands < k, update numCands to k - checkRescoreQueryParameters(fieldType, null, 20, 25, 2.5F, null, 50, 25); - // Oversampling limits for num candidates - checkRescoreQueryParameters(fieldType, null, 1000, 25, 11.0F, null, 10000, 25); - checkRescoreQueryParameters(fieldType, null, 7500, 25, 2.5F, null, 10000, 25); - } private static void checkRescoreQueryParameters( DenseVectorFieldType fieldType, - Integer k, + int k, int candidates, int requestSize, float numCandsFactor, @@ -490,7 +478,6 @@ private static void checkRescoreQueryParameters( VectorData.fromFloats(new float[] { 1, 4, 10 }), k, candidates, - requestSize, numCandsFactor, null, null,