diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java index 46b2f0a09cf7f..ddc427868c672 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java @@ -17,14 +17,14 @@ import org.elasticsearch.search.profile.query.QueryProfiler; public class ESDiversifyingChildrenByteKnnVectorQuery extends DiversifyingChildrenByteKnnVectorQuery implements QueryProfilerProvider { - private final Integer kParam; + private final int kParam; private long vectorOpsCount; public ESDiversifyingChildrenByteKnnVectorQuery( String field, byte[] query, Query childFilter, - Integer k, + int k, int numCands, BitSetProducer parentsFilter, KnnSearchStrategy strategy @@ -35,7 +35,7 @@ public ESDiversifyingChildrenByteKnnVectorQuery( @Override protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { - TopDocs topK = kParam == null ? super.mergeLeafResults(perLeafResults) : TopDocs.merge(kParam, perLeafResults); + TopDocs topK = TopDocs.merge(kParam, perLeafResults); vectorOpsCount = topK.totalHits.value(); return topK; } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenFloatKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenFloatKnnVectorQuery.java index 5635281ab0e8a..42e33cd948aef 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenFloatKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenFloatKnnVectorQuery.java @@ -17,14 +17,14 @@ import org.elasticsearch.search.profile.query.QueryProfiler; public class ESDiversifyingChildrenFloatKnnVectorQuery extends DiversifyingChildrenFloatKnnVectorQuery implements QueryProfilerProvider { - private final Integer kParam; + private final int kParam; private long vectorOpsCount; public ESDiversifyingChildrenFloatKnnVectorQuery( String field, float[] query, Query childFilter, - Integer k, + int k, int numCands, BitSetProducer parentsFilter, KnnSearchStrategy strategy @@ -35,7 +35,7 @@ public ESDiversifyingChildrenFloatKnnVectorQuery( @Override protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { - TopDocs topK = kParam == null ? super.mergeLeafResults(perLeafResults) : TopDocs.merge(kParam, perLeafResults); + TopDocs topK = TopDocs.merge(kParam, perLeafResults); vectorOpsCount = topK.totalHits.value(); return topK; } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java index 295efd8f9b05e..6e90da12bd7e7 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java @@ -16,10 +16,10 @@ import org.elasticsearch.search.profile.query.QueryProfiler; public class ESKnnByteVectorQuery extends KnnByteVectorQuery implements QueryProfilerProvider { - private final Integer kParam; + private final int kParam; private long vectorOpsCount; - public ESKnnByteVectorQuery(String field, byte[] target, Integer k, int numCands, Query filter, KnnSearchStrategy strategy) { + public ESKnnByteVectorQuery(String field, byte[] target, int k, int numCands, Query filter, KnnSearchStrategy strategy) { super(field, target, numCands, filter, strategy); this.kParam = k; } @@ -27,7 +27,7 @@ public ESKnnByteVectorQuery(String field, byte[] target, Integer k, int numCands @Override protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { // if k param is set, we get only top k results from each shard - TopDocs topK = kParam == null ? super.mergeLeafResults(perLeafResults) : TopDocs.merge(kParam, perLeafResults); + TopDocs topK = TopDocs.merge(kParam, perLeafResults); vectorOpsCount = topK.totalHits.value(); return topK; } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java index 8ef4aad147049..04f6104476c51 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java @@ -16,10 +16,10 @@ import org.elasticsearch.search.profile.query.QueryProfiler; public class ESKnnFloatVectorQuery extends KnnFloatVectorQuery implements QueryProfilerProvider { - private final Integer kParam; + private final int kParam; private long vectorOpsCount; - public ESKnnFloatVectorQuery(String field, float[] target, Integer k, int numCands, Query filter, KnnSearchStrategy strategy) { + public ESKnnFloatVectorQuery(String field, float[] target, int k, int numCands, Query filter, KnnSearchStrategy strategy) { super(field, target, numCands, filter, strategy); this.kParam = k; } @@ -27,7 +27,7 @@ public ESKnnFloatVectorQuery(String field, float[] target, Integer k, int numCan @Override protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { // if k param is set, we get only top k results from each shard - TopDocs topK = kParam == null ? super.mergeLeafResults(perLeafResults) : TopDocs.merge(kParam, perLeafResults); + TopDocs topK = TopDocs.merge(kParam, perLeafResults); vectorOpsCount = topK.totalHits.value(); return topK; } @@ -37,7 +37,7 @@ public void profile(QueryProfiler queryProfiler) { queryProfiler.addVectorOpsCount(vectorOpsCount); } - public Integer kParam() { + public int kParam() { return kParam; } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index c02cb5436e90d..15d36670929b1 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -712,7 +712,7 @@ private static void checkRescoreQueryParameters( int k, int candidates, float oversample, - Integer expectedK, + int expectedK, int expectedCandidates, int expectedResults ) { diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index 27549b3c4030b..0e295fb02eaaa 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -187,14 +187,16 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que assertThat(((VectorSimilarityQuery) query).getSimilarity(), equalTo(queryBuilder.getVectorSimilarity())); query = ((VectorSimilarityQuery) query).getInnerKnnQuery(); } - Integer k = queryBuilder.k(); - if (k == null) { + int k; + if (queryBuilder.k() == null) { k = context.requestSize() == null || context.requestSize() < 0 ? DEFAULT_SIZE : context.requestSize(); + } else { + k = queryBuilder.k(); } if (queryBuilder.rescoreVectorBuilder() != null && isQuantizedElementType()) { if (queryBuilder.rescoreVectorBuilder().oversample() > 0) { RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery) query; - assertEquals(k.intValue(), (rescoreQuery.k())); + assertEquals(k, (rescoreQuery.k())); query = rescoreQuery.innerQuery(); } else { assertFalse(query instanceof RescoreKnnVectorQuery); @@ -213,7 +215,7 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que Query filterQuery = booleanQuery.clauses().isEmpty() ? null : booleanQuery; Integer numCands = queryBuilder.numCands(); if (queryBuilder.rescoreVectorBuilder() != null && isQuantizedElementType()) { - Float oversample = queryBuilder.rescoreVectorBuilder().oversample(); + float oversample = queryBuilder.rescoreVectorBuilder().oversample(); k = Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * oversample)); numCands = Math.max(numCands, k); } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java index 8cca3f9ed8a21..33ab8324ffb96 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java @@ -110,10 +110,10 @@ protected KnnSearchBuilder createTestInstance() { @Override protected KnnSearchBuilder mutateInstance(KnnSearchBuilder instance) { - switch (random().nextInt(8)) { - case 0: + return switch (random().nextInt(8)) { + case 0 -> { String newField = randomValueOtherThan(instance.field, () -> randomAlphaOfLength(5)); - return new KnnSearchBuilder( + yield new KnnSearchBuilder( newField, instance.queryVector, instance.k, @@ -121,9 +121,10 @@ protected KnnSearchBuilder mutateInstance(KnnSearchBuilder instance) { instance.getRescoreVectorBuilder(), instance.similarity ).boost(instance.boost); - case 1: + } + case 1 -> { float[] newVector = randomValueOtherThan(instance.queryVector.asFloatVector(), () -> randomVector(5)); - return new KnnSearchBuilder( + yield new KnnSearchBuilder( instance.field, newVector, instance.k, @@ -131,10 +132,11 @@ protected KnnSearchBuilder mutateInstance(KnnSearchBuilder instance) { instance.getRescoreVectorBuilder(), instance.similarity ).boost(instance.boost); - case 2: + } + case 2 -> { // given how the test instance is created, we have a 20-value gap between `k` and `numCands` so we SHOULD be safe Integer newK = randomValueOtherThan(instance.k, () -> instance.k + ESTestCase.randomInt(10)); - return new KnnSearchBuilder( + yield new KnnSearchBuilder( instance.field, instance.queryVector, newK, @@ -142,9 +144,10 @@ protected KnnSearchBuilder mutateInstance(KnnSearchBuilder instance) { instance.getRescoreVectorBuilder(), instance.similarity ).boost(instance.boost); - case 3: + } + case 3 -> { Integer newNumCands = randomValueOtherThan(instance.numCands, () -> instance.numCands + ESTestCase.randomInt(100)); - return new KnnSearchBuilder( + yield new KnnSearchBuilder( instance.field, instance.queryVector, instance.k, @@ -152,20 +155,20 @@ protected KnnSearchBuilder mutateInstance(KnnSearchBuilder instance) { instance.getRescoreVectorBuilder(), instance.similarity ).boost(instance.boost); - case 4: - return new KnnSearchBuilder( - instance.field, - instance.queryVector, - instance.k, - instance.numCands, - instance.getRescoreVectorBuilder(), - instance.similarity - ).addFilterQueries(instance.filterQueries) - .addFilterQuery(QueryBuilders.termQuery("new_field", "new-value")) - .boost(instance.boost); - case 5: + } + case 4 -> new KnnSearchBuilder( + instance.field, + instance.queryVector, + instance.k, + instance.numCands, + instance.getRescoreVectorBuilder(), + instance.similarity + ).addFilterQueries(instance.filterQueries) + .addFilterQuery(QueryBuilders.termQuery("new_field", "new-value")) + .boost(instance.boost); + case 5 -> { float newBoost = randomValueOtherThan(instance.boost, ESTestCase::randomFloat); - return new KnnSearchBuilder( + yield new KnnSearchBuilder( instance.field, instance.queryVector, instance.k, @@ -173,30 +176,28 @@ protected KnnSearchBuilder mutateInstance(KnnSearchBuilder instance) { instance.getRescoreVectorBuilder(), instance.similarity ).addFilterQueries(instance.filterQueries).boost(newBoost); - case 6: - return new KnnSearchBuilder( - instance.field, - instance.queryVector, - instance.k, - instance.numCands, + } + case 6 -> new KnnSearchBuilder( + instance.field, + instance.queryVector, + instance.k, + instance.numCands, + instance.getRescoreVectorBuilder(), + randomValueOtherThan(instance.similarity, ESTestCase::randomFloat) + ).addFilterQueries(instance.filterQueries).boost(instance.boost); + case 7 -> new KnnSearchBuilder( + instance.field, + instance.queryVector, + instance.k, + instance.numCands, + randomValueOtherThan( instance.getRescoreVectorBuilder(), - randomValueOtherThan(instance.similarity, ESTestCase::randomFloat) - ).addFilterQueries(instance.filterQueries).boost(instance.boost); - case 7: - return new KnnSearchBuilder( - instance.field, - instance.queryVector, - instance.k, - instance.numCands, - randomValueOtherThan( - instance.getRescoreVectorBuilder(), - () -> new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)) - ), - instance.similarity - ).addFilterQueries(instance.filterQueries).boost(instance.boost); - default: - throw new IllegalStateException(); - } + () -> new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)) + ), + instance.similarity + ).addFilterQueries(instance.filterQueries).boost(instance.boost); + default -> throw new IllegalStateException(); + }; } public void testToQueryBuilder() { diff --git a/server/src/test/java/org/elasticsearch/search/vectors/TestQueryVectorBuilderPlugin.java b/server/src/test/java/org/elasticsearch/search/vectors/TestQueryVectorBuilderPlugin.java index 5733a51bb7e9c..0b320c709e9c5 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/TestQueryVectorBuilderPlugin.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/TestQueryVectorBuilderPlugin.java @@ -45,7 +45,7 @@ public static class TestQueryVectorBuilder implements QueryVectorBuilder { PARSER.declareFloatArray(ConstructingObjectParser.constructorArg(), QUERY_VECTOR); } - private List vectorToBuild; + private final List vectorToBuild; public TestQueryVectorBuilder(List vectorToBuild) { this.vectorToBuild = vectorToBuild;