diff --git a/docs/changelog/118774.yaml b/docs/changelog/118774.yaml new file mode 100644 index 0000000000000..cbd1ca82d1c59 --- /dev/null +++ b/docs/changelog/118774.yaml @@ -0,0 +1,5 @@ +pr: 118774 +summary: Apply default k for knn query eagerly +area: Vector Search +type: bug +issues: [] diff --git a/docs/reference/query-dsl/knn-query.asciidoc b/docs/reference/query-dsl/knn-query.asciidoc index daf9e9499a189..e42bd78d9f14a 100644 --- a/docs/reference/query-dsl/knn-query.asciidoc +++ b/docs/reference/query-dsl/knn-query.asciidoc @@ -100,7 +100,7 @@ include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=knn-query-vector-builde -- (Optional, integer) The number of nearest neighbors to return from each shard. {es} collects `k` results from each shard, then merges them to find the global top results. -This value must be less than or equal to `num_candidates`. Defaults to `num_candidates`. +This value must be less than or equal to `num_candidates`. Defaults to search request size. -- `num_candidates`:: diff --git a/rest-api-spec/build.gradle b/rest-api-spec/build.gradle index ed8a1f147b4fa..a104ec675adc2 100644 --- a/rest-api-spec/build.gradle +++ b/rest-api-spec/build.gradle @@ -60,5 +60,23 @@ tasks.named("yamlRestCompatTestTransform").configure ({ task -> task.skipTest("cat.aliases/10_basic/Deprecated local parameter", "CAT APIs not covered by compatibility policy") task.skipTest("cat.shards/10_basic/Help", "sync_id is removed in 9.0") task.skipTest("search/500_date_range/from, to, include_lower, include_upper deprecated", "deprecated parameters are removed in 9.0") + task.skipTest("search.vectors/41_knn_search_bbq_hnsw/Test knn search", "Scoring has changed in latest versions") + task.skipTest("search.vectors/42_knn_search_bbq_flat/Test knn search", "Scoring has changed in latest versions") + task.skipTest("search.vectors/180_update_dense_vector_type/Test create and update dense vector mapping with bulk indexing", "waiting for #118774 backport") + task.skipTest("search.vectors/160_knn_query_missing_params/kNN query in a bool clause - missing num_candidates", "waiting for #118774 backport") + task.skipTest("search.vectors/110_knn_query_with_filter/Simple knn query", "waiting for #118774 backport") + task.skipTest("search.vectors/160_knn_query_missing_params/kNN search used in nested field - missing num_candidates", "waiting for #118774 backport") + task.skipTest("search.vectors/180_update_dense_vector_type/Test create and update dense vector mapping to int4 with per-doc indexing and flush", "waiting for #118774 backport") + task.skipTest("search.vectors/110_knn_query_with_filter/PRE_FILTER: knn query with internal filter as pre-filter", "waiting for #118774 backport") + task.skipTest("search.vectors/180_update_dense_vector_type/Index, update and merge", "waiting for #118774 backport") + task.skipTest("search.vectors/160_knn_query_missing_params/kNN query with missing num_candidates param - size provided", "waiting for #118774 backport") + task.skipTest("search.vectors/110_knn_query_with_filter/POST_FILTER: knn query with filter from a parent bool query as post-filter", "waiting for #118774 backport") + task.skipTest("search.vectors/120_knn_query_multiple_shards/Aggregations with collected number of docs depends on num_candidates", "waiting for #118774 backport") + task.skipTest("search.vectors/180_update_dense_vector_type/Test create and update dense vector mapping with per-doc indexing and flush", "waiting for #118774 backport") + task.skipTest("search.vectors/110_knn_query_with_filter/PRE_FILTER: knn query with alias filter as pre-filter", "waiting for #118774 backport") + task.skipTest("search.vectors/140_knn_query_with_other_queries/Function score query with knn query", "waiting for #118774 backport") + task.skipTest("search.vectors/130_knn_query_nested_search/nested kNN search inner_hits size > 1", "waiting for #118774 backport") + task.skipTest("search.vectors/110_knn_query_with_filter/PRE_FILTER: pre-filter across multiple aliases", "waiting for #118774 backport") + task.skipTest("search.vectors/160_knn_query_missing_params/kNN search in a dis_max query - missing num_candidates", "waiting for #118774 backport") task.skipTest("search.highlight/30_max_analyzed_offset/Plain highlighter with max_analyzed_offset < 0 should FAIL", "semantics of test has changed") }) diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/110_knn_query_with_filter.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/110_knn_query_with_filter.yml index 618951711cffd..3d4841a16d82d 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/110_knn_query_with_filter.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/110_knn_query_with_filter.yml @@ -59,7 +59,9 @@ setup: --- "Simple knn query": - + - requires: + cluster_features: "search.vectors.k_param_supported" + reason: 'k param for knn as query is required' - do: search: index: my_index @@ -71,8 +73,9 @@ setup: field: my_vector query_vector: [1, 1, 1, 1] num_candidates: 5 + k: 5 - - match: { hits.total.value: 5 } # collector sees num_candidates docs + - match: { hits.total.value: 5 } - length: {hits.hits: 3} - match: { hits.hits.0._id: "1" } - match: { hits.hits.0.fields.my_name.0: v1 } @@ -93,8 +96,9 @@ setup: field: my_vector query_vector: [1, 1, 1, 1] num_candidates: 5 + k: 5 - - match: { hits.total.value: 5 } # collector sees num_candidates docs + - match: { hits.total.value: 5 } - length: {hits.hits: 3} - match: { hits.hits.0._id: "2" } - match: { hits.hits.0.fields.my_name.0: v2 } @@ -140,6 +144,7 @@ setup: field: my_vector query_vector: [1, 1, 1, 1] num_candidates: 5 + k: 5 - match: { hits.total.value: 5 } - length: { hits.hits: 3 } @@ -184,6 +189,7 @@ setup: field: my_vector query_vector: [1, 1, 1, 1] num_candidates: 100 + k: 100 - match: { hits.total.value: 10 } # 5 docs from each alias - length: {hits.hits: 6} @@ -213,6 +219,7 @@ setup: field: my_vector query_vector: [1, 1, 1, 1] num_candidates: 5 + k: 5 filter: term: my_name: v2 @@ -243,9 +250,10 @@ setup: field: my_vector query_vector: [1, 1, 1, 1] num_candidates: 5 + k: 5 - match: { hits.total.value: 2 } - - length: {hits.hits: 2} # knn query returns top 5 docs, but they are post-filtered to 2 docs + - length: {hits.hits: 2} # knn query returns top 3 docs, but they are post-filtered to 2 docs - match: { hits.hits.0._id: "2" } - match: { hits.hits.0.fields.my_name.0: v2 } - match: { hits.hits.1._id: "4" } @@ -271,4 +279,4 @@ setup: my_name: v1 - match: { hits.total.value: 0} - - length: { hits.hits: 0 } # knn query returns top 5 docs, but they are post-filtered to 0 docs + - length: { hits.hits: 0 } # knn query returns top 3 docs, but they are post-filtered to 0 docs diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/120_knn_query_multiple_shards.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/120_knn_query_multiple_shards.yml index c6f3e187f7953..c68565e6629f5 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/120_knn_query_multiple_shards.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/120_knn_query_multiple_shards.yml @@ -166,55 +166,3 @@ setup: - close_to: { hits.hits.2._score: { value: 120, error: 0.00001 } } - close_to: { hits.hits.2.matched_queries.bm25_query: { value: 100.0, error: 0.00001 } } - close_to: { hits.hits.2.matched_queries.knn_query: { value: 20.0, error: 0.00001 } } - ---- -"Aggregations with collected number of docs depends on num_candidates": - - do: - search: - index: my_index - body: - size: 2 - query: - knn: - field: my_vector - query_vector: [1, 1, 1, 1] - num_candidates: 100 # collect up to 100 candidates from each shard - aggs: - my_agg: - terms: - field: my_name - order: - _key: asc - - - length: {hits.hits: 2} - - match: {hits.total.value: 12} - - match: {aggregations.my_agg.buckets.0.key: 'v1'} - - match: {aggregations.my_agg.buckets.1.key: 'v2'} - - match: {aggregations.my_agg.buckets.0.doc_count: 6} - - match: {aggregations.my_agg.buckets.1.doc_count: 6} - - - do: - search: - index: my_index - body: - size: 2 - query: - knn: - field: my_vector - query_vector: [ 1, 1, 1, 1 ] - num_candidates: 3 # collect 3 candidates from each shard - aggs: - my_agg2: - terms: - field: my_name - order: - _key: asc - my_sum_buckets: - sum_bucket: - buckets_path: "my_agg2>_count" - - - length: { hits.hits: 2 } - - match: { hits.total.value: 6 } - - match: { aggregations.my_agg2.buckets.0.key: 'v1' } - - match: { aggregations.my_agg2.buckets.1.key: 'v2' } - - match: { aggregations.my_sum_buckets.value: 6.0 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/130_knn_query_nested_search.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/130_knn_query_nested_search.yml index 79ff3f61742f8..bf07144975650 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/130_knn_query_nested_search.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/130_knn_query_nested_search.yml @@ -273,6 +273,7 @@ setup: knn: field: nested.vector query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 5 num_candidates: 5 inner_hits: { size: 2, "fields": [ "nested.paragraph_id" ], _source: false } @@ -295,6 +296,7 @@ setup: knn: field: nested.vector query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 5 num_candidates: 5 inner_hits: { size: 2, "fields": [ "nested.paragraph_id" ], _source: false } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/140_knn_query_with_other_queries.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/140_knn_query_with_other_queries.yml index d52a5daf22344..1e54e497f286f 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/140_knn_query_with_other_queries.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/140_knn_query_with_other_queries.yml @@ -69,6 +69,7 @@ setup: field: my_vector query_vector: [ 1, 1, 1, 1 ] num_candidates: 5 + k: 5 functions: - filter: { match: { my_name: v1 } } weight: 10 diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/160_knn_query_missing_params.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/160_knn_query_missing_params.yml index 02962e049e267..26c52060dfb22 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/160_knn_query_missing_params.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/160_knn_query_missing_params.yml @@ -100,8 +100,9 @@ setup: knn: field: vector query_vector: [1, 1, 1] + k: 2 size: 1 - - match: { hits.total: 2 } # due to num_candidates defined as round(1.5 * size), so we only see 2 results + - match: { hits.total: 2 } # k defaults to size - length: { hits.hits: 1 } # one result is only returned though --- @@ -117,6 +118,7 @@ setup: field: vector query_vector: [-1, -1, -1] num_candidates: 1 + k: 1 size: 10 - match: { hits.total: 1 } @@ -137,9 +139,10 @@ setup: - knn: field: vector query_vector: [ 1, 1, 0] + k: 1 size: 1 - - match: { hits.total: 2 } # due to num_candidates defined as round(1.5 * size), so we only see 2 results from cat:A + - match: { hits.total: 1 } - length: { hits.hits: 1 } --- @@ -154,6 +157,7 @@ setup: - knn: field: vector query_vector: [1, 1, 0] + k: 2 - match: category: B tie_breaker: 0.8 @@ -175,6 +179,7 @@ setup: knn: field: nested.vector query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 2 inner_hits: { size: 1, "fields": [ "nested.paragraph_id" ], _source: false } size: 1 diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/180_update_dense_vector_type.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/180_update_dense_vector_type.yml index 855daeaa7f163..99943ef2671bb 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/180_update_dense_vector_type.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/180_update_dense_vector_type.yml @@ -109,6 +109,7 @@ setup: field: embedding query_vector: [1, 1, 1, 1] num_candidates: 10 + k: 10 - match: { hits.total.value: 10 } - length: {hits.hits: 3} @@ -215,6 +216,7 @@ setup: field: embedding query_vector: [ 1, 1, 1, 1 ] num_candidates: 20 + k: 20 - match: { hits.total.value: 20 } - length: { hits.hits: 3 } @@ -322,6 +324,7 @@ setup: field: embedding query_vector: [ 1, 1, 1, 1 ] num_candidates: 30 + k: 30 - match: { hits.total.value: 30 } - length: { hits.hits: 4 } @@ -430,6 +433,7 @@ setup: field: embedding query_vector: [ 1, 1, 1, 1 ] num_candidates: 40 + k: 40 - match: { hits.total.value: 40 } - length: { hits.hits: 5 } @@ -499,6 +503,7 @@ setup: field: embedding query_vector: [1, 1, 1, 1] num_candidates: 10 + k: 10 - match: { hits.total.value: 10 } - length: {hits.hits: 3} @@ -559,6 +564,7 @@ setup: field: embedding query_vector: [ 1, 1, 1, 1 ] num_candidates: 20 + k: 20 - match: { hits.total.value: 20 } - length: { hits.hits: 3 } @@ -620,6 +626,7 @@ setup: field: embedding query_vector: [ 1, 1, 1, 1 ] num_candidates: 30 + k: 30 - match: { hits.total.value: 30 } - length: { hits.hits: 4 } @@ -682,6 +689,7 @@ setup: field: embedding query_vector: [ 1, 1, 1, 1 ] num_candidates: 40 + k: 40 - match: { hits.total.value: 40 } - length: { hits.hits: 5 } @@ -751,6 +759,7 @@ setup: field: embedding query_vector: [ 1, 1, 1, 1 ] num_candidates: 10 + k: 10 - match: { hits.total.value: 10 } - length: { hits.hits: 3 } @@ -791,6 +800,7 @@ setup: field: embedding query_vector: [ 1, 1, 1, 1 ] num_candidates: 10 + k: 10 - match: { hits.total.value: 10 } - length: { hits.hits: 3 } @@ -833,6 +843,7 @@ setup: field: embedding query_vector: [ 1, 1, 1, 1 ] num_candidates: 20 + k: 20 - match: { hits.total.value: 20 } - length: { hits.hits: 3 } @@ -869,6 +880,7 @@ setup: field: embedding query_vector: [ 1, 1, 1, 1 ] num_candidates: 20 + k: 20 - match: { hits.total.value: 20 } - length: { hits.hits: 3 } @@ -911,6 +923,7 @@ setup: field: embedding query_vector: [ 1, 1, 1, 1 ] num_candidates: 30 + k: 30 - match: { hits.total.value: 30 } - length: { hits.hits: 4 } @@ -933,6 +946,7 @@ setup: knn: field: embedding query_vector: [ 1, 1, 1, 1 ] + k: 30 num_candidates: 30 - match: { hits.total.value: 30 } @@ -1769,6 +1783,7 @@ setup: field: embedding query_vector: [1, 1, 1, 1] num_candidates: 10 + k: 10 - match: { hits.total.value: 10 } - length: {hits.hits: 3} @@ -1875,6 +1890,7 @@ setup: field: embedding query_vector: [ 1, 1, 1, 1 ] num_candidates: 20 + k: 20 - match: { hits.total.value: 20 } - length: { hits.hits: 3 } @@ -1982,6 +1998,7 @@ setup: field: embedding query_vector: [ 1, 1, 1, 1 ] num_candidates: 30 + k: 30 - match: { hits.total.value: 30 } - length: { hits.hits: 4 } @@ -2090,6 +2107,7 @@ setup: field: embedding query_vector: [ 1, 1, 1, 1 ] num_candidates: 40 + k: 40 - match: { hits.total.value: 40 } - length: { hits.hits: 5 } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index a99f21803556f..b2b23baacc4db 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -2019,7 +2019,7 @@ && isNotUnitVector(squaredMagnitude)) { public Query createKnnQuery( VectorData queryVector, - Integer k, + int k, int numCands, Float numCandsFactor, Query filter, @@ -2052,7 +2052,7 @@ private boolean needsRescore(Float rescoreOversample) { private Query createKnnBitQuery( byte[] queryVector, - Integer k, + int k, int numCands, Query filter, Float similarityThreshold, @@ -2074,7 +2074,7 @@ private Query createKnnBitQuery( private Query createKnnByteQuery( byte[] queryVector, - Integer k, + int k, int numCands, Query filter, Float similarityThreshold, @@ -2101,7 +2101,7 @@ private Query createKnnByteQuery( private Query createKnnFloatQuery( float[] queryVector, - Integer k, + int k, int numCands, Float numCandsFactor, Query filter, diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java index 1548a62d2b3ef..c343141490799 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java +++ b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java @@ -37,6 +37,8 @@ private SearchCapabilities() {} private static final String NESTED_RETRIEVER_INNER_HITS_SUPPORT = "nested_retriever_inner_hits_support"; /** Fixed the math in {@code moving_fn}'s {@code linearWeightedAvg}. */ private static final String MOVING_FN_RIGHT_MATH = "moving_fn_right_math"; + /** knn query where k defaults to the request size. */ + private static final String K_DEFAULT_TO_SIZE = "k_default_to_size"; private static final String RANDOM_SAMPLER_WITH_SCORED_SUBAGGS = "random_sampler_with_scored_subaggs"; private static final String OPTIMIZED_SCALAR_QUANTIZATION_BBQ = "optimized_scalar_quantization_bbq"; @@ -57,6 +59,7 @@ private SearchCapabilities() {} capabilities.add(OPTIMIZED_SCALAR_QUANTIZATION_BBQ); capabilities.add(KNN_QUANTIZED_VECTOR_RESCORE); capabilities.add(MOVING_FN_RIGHT_MATH); + capabilities.add(K_DEFAULT_TO_SIZE); if (Build.current().isSnapshot()) { capabilities.add(KQL_QUERY_SUPPORTED); } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java index b18ce2dff65cb..9b9718efcf523 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java @@ -465,7 +465,7 @@ public KnnVectorQueryBuilder toQueryBuilder() { if (queryVectorBuilder != null) { throw new IllegalArgumentException("missing rewrite"); } - return new KnnVectorQueryBuilder(field, queryVector, null, numCands, rescoreVectorBuilder, similarity).boost(boost) + return new KnnVectorQueryBuilder(field, queryVector, numCands, numCands, rescoreVectorBuilder, similarity).boost(boost) .queryName(queryName) .addFilterQueries(filterQueries); } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java index 81b00f1329591..12573d5ad496e 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java @@ -256,7 +256,7 @@ public KnnVectorQueryBuilder toQueryBuilder() { if (numCands > NUM_CANDS_LIMIT) { throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]"); } - return new KnnVectorQueryBuilder(field, queryVector, null, numCands, null, null); + return new KnnVectorQueryBuilder(field, queryVector, numCands, numCands, null, null); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java index 88f6312fa7e6f..a65757cc25876 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java @@ -495,15 +495,16 @@ protected QueryBuilder doIndexMetadataRewrite(QueryRewriteContext context) throw @Override protected Query doToQuery(SearchExecutionContext context) throws IOException { MappedFieldType fieldType = context.getFieldType(fieldName); - int requestSize; - if (k != null) { - requestSize = k; + int k; + if (this.k != null) { + k = this.k; } else { - requestSize = context.requestSize() == null || context.requestSize() < 0 ? DEFAULT_SIZE : context.requestSize(); + k = context.requestSize() == null || context.requestSize() < 0 ? DEFAULT_SIZE : context.requestSize(); + if (numCands != null) { + k = Math.min(k, numCands); + } } - int adjustedNumCands = numCands == null - ? Math.round(Math.min(NUM_CANDS_MULTIPLICATIVE_FACTOR * requestSize, NUM_CANDS_LIMIT)) - : numCands; + int adjustedNumCands = numCands == null ? Math.round(Math.min(NUM_CANDS_MULTIPLICATIVE_FACTOR * k, NUM_CANDS_LIMIT)) : numCands; if (fieldType == null) { return new MatchNoDocsQuery(); } diff --git a/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java b/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java index 353188af8be3c..568186e0cae5b 100644 --- a/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java @@ -418,7 +418,7 @@ public void testKnnSearchAction() throws IOException { float[] queryVector = randomVector(); assertResponse( client().prepareSearch("index1", "index2") - .setQuery(new KnnVectorQueryBuilder("vector", queryVector, null, 5, null, null)) + .setQuery(new KnnVectorQueryBuilder("vector", queryVector, 5, 5, null, null)) .setSize(2), response -> { // The total hits is num_cands * num_shards, since the query gathers num_cands hits from each shard diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index 375712ee60861..244d539403315 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -47,7 +47,6 @@ import java.util.stream.Stream; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.NUM_CANDS_OVERSAMPLE_LIMIT; -import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -82,7 +81,7 @@ private void checkIndexTypeAndDimensions() { abstract KnnVectorQueryBuilder createKnnVectorQueryBuilder( String fieldName, - Integer k, + int k, int numCands, RescoreVectorBuilder rescoreVectorBuilder, Float similarity @@ -138,8 +137,8 @@ protected void initializeAdditionalMappings(MapperService mapperService) throws @Override protected KnnVectorQueryBuilder doCreateTestQueryBuilder() { String fieldName = randomBoolean() ? VECTOR_FIELD : VECTOR_ALIAS_FIELD; - Integer k = randomBoolean() ? null : randomIntBetween(1, 100); - int numCands = randomIntBetween(k == null ? DEFAULT_SIZE : k + 20, 1000); + int k = randomIntBetween(1, 100); + int numCands = randomIntBetween(k + 20, 1000); KnnVectorQueryBuilder queryBuilder = createKnnVectorQueryBuilder( fieldName, k, diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java index f6c2e754cec63..26066389c63f1 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java @@ -20,7 +20,7 @@ DenseVectorFieldMapper.ElementType elementType() { @Override protected KnnVectorQueryBuilder createKnnVectorQueryBuilder( String fieldName, - Integer k, + int k, int numCands, RescoreVectorBuilder rescoreVectorBuilder, Float similarity diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java index 6f67e4be29a06..70d29ab525ef1 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java @@ -20,7 +20,7 @@ DenseVectorFieldMapper.ElementType elementType() { @Override KnnVectorQueryBuilder createKnnVectorQueryBuilder( String fieldName, - Integer k, + int k, int numCands, RescoreVectorBuilder rescoreVectorBuilder, Float similarity diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java index a39438af5b72a..108dc60e2ee3b 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java @@ -224,9 +224,9 @@ public void testToQueryBuilder() { builder.addFilterQuery(filter); } - QueryBuilder expected = new KnnVectorQueryBuilder(field, vector, null, numCands, rescoreVectorBuilder, similarity).addFilterQueries( - filterQueries - ).boost(boost); + QueryBuilder expected = new KnnVectorQueryBuilder(field, vector, numCands, numCands, rescoreVectorBuilder, similarity) + .addFilterQueries(filterQueries) + .boost(boost); assertEquals(expected, builder.toQueryBuilder()); }