diff --git a/modules/percolator/src/internalClusterTest/java/org/elasticsearch/percolator/PercolatorQuerySearchIT.java b/modules/percolator/src/internalClusterTest/java/org/elasticsearch/percolator/PercolatorQuerySearchIT.java index 8a7f1405f8f4e..5e8ced116a1ff 100644 --- a/modules/percolator/src/internalClusterTest/java/org/elasticsearch/percolator/PercolatorQuerySearchIT.java +++ b/modules/percolator/src/internalClusterTest/java/org/elasticsearch/percolator/PercolatorQuerySearchIT.java @@ -1359,7 +1359,7 @@ public void testKnnQueryNotSupportedInPercolator() throws IOException { """); indicesAdmin().prepareCreate("index1").setMapping(mappings).get(); ensureGreen(); - QueryBuilder knnVectorQueryBuilder = new KnnVectorQueryBuilder("my_vector", new float[] { 1, 1, 1, 1, 1 }, 10, 10, null, null); + QueryBuilder knnVectorQueryBuilder = new KnnVectorQueryBuilder("my_vector", new float[] { 1, 1, 1, 1, 1 }, 10, 10, 10f, null, null); IndexRequestBuilder indexRequestBuilder = prepareIndex("index1").setId("knn_query1") .setSource(jsonBuilder().startObject().field("my_query", knnVectorQueryBuilder).endObject()); diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/20_knn_retriever.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/20_knn_retriever.yml index 1f07884c9fadf..98713cf570eea 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/20_knn_retriever.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/20_knn_retriever.yml @@ -53,6 +53,30 @@ setup: - match: {hits.hits.1._id: "3"} - match: {hits.hits.1.fields.name.0: "rabbit.jpg"} +--- +"kNN retrieve with visit_percentage": + - requires: + cluster_features: "mapper.bbq_disk_support" + reason: 'bbq disk support required' + - do: + search: + index: index1 + body: + fields: [ "name" ] + retriever: + knn: + field: vector + query_vector: [2, 2, 2, 2, 3] + k: 2 + num_candidates: 3 + 
visit_percentage: 1.0 + + - match: {hits.hits.0._id: "2"} + - match: {hits.hits.0.fields.name.0: "moose.jpg"} + + - match: {hits.hits.1._id: "3"} + - match: {hits.hits.1.fields.name.0: "rabbit.jpg"} + --- "kNN retriever with filter": - do: diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/135_knn_query_nested_search_ivf.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/135_knn_query_nested_search_ivf.yml index fe19a9b8578fb..4fc69240aed9b 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/135_knn_query_nested_search_ivf.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/135_knn_query_nested_search_ivf.yml @@ -191,3 +191,27 @@ setup: - match: {hits.hits.0._id: "3"} - match: {hits.hits.0.fields.name.0: "rabbit.jpg"} - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" } + +--- +"nested kNN search works with visit_percentage": + - do: + search: + index: test + body: + fields: [ "name" ] + query: + nested: + path: nested + query: + knn: + field: nested.vector + query_vector: [-0.5, 90, -10, 14.8, -156] + num_candidates: 3 + visit_percentage: 1.0 + - match: {hits.total.value: 3} + + - match: {hits.hits.0._id: "2"} + - match: {hits.hits.0.fields.name.0: "moose.jpg"} + + - match: {hits.hits.1._id: "3"} + - match: {hits.hits.1.fields.name.0: "rabbit.jpg"} diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/46_knn_search_bbq_ivf.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/46_knn_search_bbq_ivf.yml index 3ce9232fc4ecd..3fbbbc3de1f48 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/46_knn_search_bbq_ivf.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/46_knn_search_bbq_ivf.yml @@ -106,6 +106,28 @@ setup: - match: { hits.hits.1._id: "3" } - match: { 
hits.hits.2._id: "2" } --- +"Test knn search with visit_percentage": + - do: + search: + index: bbq_disk + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + visit_percentage: 1.0 + + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.1._id: "3" } + - match: { hits.hits.2._id: "2" } +--- "Vector rescoring has same scoring as exact search for kNN section": - skip: features: "headers" diff --git a/server/src/internalClusterTest/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldIndexTypeUpdateIT.java b/server/src/internalClusterTest/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldIndexTypeUpdateIT.java index de174774f980a..4d698107dd0e6 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldIndexTypeUpdateIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldIndexTypeUpdateIT.java @@ -142,7 +142,7 @@ public void testDenseVectorMappingUpdate() throws Exception { for (int i = 0; i < queryVector.length; i++) { queryVector[i] = randomFloatBetween(-1, 1, true); } - KnnVectorQueryBuilder queryBuilder = new KnnVectorQueryBuilder(VECTOR_FIELD, queryVector, null, null, null, null); + KnnVectorQueryBuilder queryBuilder = new KnnVectorQueryBuilder(VECTOR_FIELD, queryVector, null, null, null, null, null); assertNoFailuresAndResponse( client().prepareSearch(INDEX_NAME).setQuery(queryBuilder).setTrackTotalHits(true).setSize(expectedDocs), response -> { diff --git 
a/server/src/internalClusterTest/java/org/elasticsearch/index/store/DirectIOIT.java b/server/src/internalClusterTest/java/org/elasticsearch/index/store/DirectIOIT.java index b14f067992ba0..e27ba0e141491 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/index/store/DirectIOIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/index/store/DirectIOIT.java @@ -124,7 +124,7 @@ public void testDirectIOUsed() { indexVectors(); // do a search - var knn = List.of(new KnnSearchBuilder("fooVector", new VectorData(null, new byte[64]), 10, 20, null, null)); + var knn = List.of(new KnnSearchBuilder("fooVector", new VectorData(null, new byte[64]), 10, 20, 10f, null, null)); assertHitCount(prepareSearch("foo-vectors").setKnnSearch(knn), 10); mockLog.assertAllExpectationsMatched(); } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/KnnSearchIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/KnnSearchIT.java index 91409e5e70183..17d6024145e22 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/KnnSearchIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/KnnSearchIT.java @@ -77,13 +77,13 @@ public void testKnnSearchWithScroll() throws Exception { // test top level knn search { SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); - sourceBuilder.knnSearch(List.of(new KnnSearchBuilder(VECTOR_FIELD, new float[] { 0, 0 }, k, 100, null, null))); + sourceBuilder.knnSearch(List.of(new KnnSearchBuilder(VECTOR_FIELD, new float[] { 0, 0 }, k, 100, 10f, null, null))); executeScrollSearch(client, sourceBuilder, k); } // test top level knn search + another query { SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); - sourceBuilder.knnSearch(List.of(new KnnSearchBuilder(VECTOR_FIELD, new float[] { 0, 0 }, k, 100, null, null))); + sourceBuilder.knnSearch(List.of(new KnnSearchBuilder(VECTOR_FIELD, new float[] { 0, 0 }, k, 100, 10f, null, null))); 
sourceBuilder.query(QueryBuilders.existsQuery("category").boost(10)); executeScrollSearch(client, sourceBuilder, k + 10); } @@ -91,7 +91,7 @@ public void testKnnSearchWithScroll() throws Exception { // test knn query { SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); - sourceBuilder.query(new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 0, 0 }, k, 100, null, null)); + sourceBuilder.query(new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 0, 0 }, k, 100, 10f, null, null)); executeScrollSearch(client, sourceBuilder, k * numShards); } // test knn query + another query @@ -99,7 +99,7 @@ public void testKnnSearchWithScroll() throws Exception { SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); sourceBuilder.query( QueryBuilders.boolQuery() - .should(new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 0, 0 }, k, 100, null, null)) + .should(new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 0, 0 }, k, 100, 10f, null, null)) .should(QueryBuilders.existsQuery("category").boost(10)) ); executeScrollSearch(client, sourceBuilder, k * numShards + 10); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/nested/VectorNestedIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/nested/VectorNestedIT.java index cbe7e7be51902..0ff2b7336e654 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/nested/VectorNestedIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/nested/VectorNestedIT.java @@ -72,7 +72,9 @@ public void testSimpleNested() throws Exception { assertResponse( prepareSearch("test").setKnnSearch( - List.of(new KnnSearchBuilder("nested.vector", new float[] { 1, 1, 1 }, 1, 1, null, null).innerHit(new InnerHitBuilder())) + List.of( + new KnnSearchBuilder("nested.vector", new float[] { 1, 1, 1 }, 1, 1, 10f, null, null).innerHit(new InnerHitBuilder()) + ) ).setAllowPartialSearchResults(false), response -> assertThat(response.getHits().getHits().length, 
greaterThan(0)) ); @@ -153,7 +155,7 @@ private void testNestedWithTwoSegments(boolean flush) { waitForRelocation(ClusterHealthStatus.GREEN); refresh(); - var knn = new KnnSearchBuilder("nested.vector", new float[] { -0.5f, 90.0f, -10f, 14.8f, -156.0f }, 2, 3, null, null); + var knn = new KnnSearchBuilder("nested.vector", new float[] { -0.5f, 90.0f, -10f, 14.8f, -156.0f }, 2, 3, 10f, null, null); var request = prepareSearch("test").addFetchField("name").setKnnSearch(List.of(knn)); assertNoFailuresAndResponse(request, response -> { assertHitCount(response, 2); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/profile/dfs/DfsProfilerIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/profile/dfs/DfsProfilerIT.java index 95d69a6ebaa86..c84e955ec8ce5 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/profile/dfs/DfsProfilerIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/profile/dfs/DfsProfilerIT.java @@ -72,6 +72,7 @@ public void testProfileDfs() throws Exception { new float[] { randomFloat(), randomFloat(), randomFloat() }, randomIntBetween(5, 10), 50, + 10f, randomBoolean() ? null : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)), randomBoolean() ? 
null : randomFloat() ); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/query/RescoreKnnVectorQueryIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/query/RescoreKnnVectorQueryIT.java index 453812c0566f7..efffcb6951ae2 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/query/RescoreKnnVectorQueryIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/query/RescoreKnnVectorQueryIT.java @@ -116,6 +116,7 @@ private record TestParams( float[] queryVector, int k, int numCands, + Float visitPercentage, RescoreVectorBuilder rescoreVectorBuilder ) { public static TestParams generate() { @@ -128,6 +129,7 @@ public static TestParams generate() { randomVector(numDims), k, (int) (k * randomFloatBetween(1.0f, 10.0f, true)), + randomBoolean() ? null : randomFloatBetween(0.0f, 100.0f, true), new RescoreVectorBuilder(randomFloatBetween(1.0f, 100f, true)) ); } @@ -140,6 +142,7 @@ public void testKnnSearchRescore() { testParams.queryVector, testParams.k, testParams.numCands, + testParams.visitPercentage, testParams.rescoreVectorBuilder, null ); @@ -155,6 +158,7 @@ public void testKnnQueryRescore() { testParams.queryVector, testParams.k, testParams.numCands, + testParams.visitPercentage, testParams.rescoreVectorBuilder, null ); @@ -171,6 +175,7 @@ public void testKnnRetriever() { null, testParams.k, testParams.numCands, + testParams.visitPercentage, testParams.rescoreVectorBuilder, null ); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/query/VectorIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/query/VectorIT.java index 496bea95b7d65..ffa7727c53d7e 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/query/VectorIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/query/VectorIT.java @@ -77,7 +77,7 @@ public void testFilteredQueryStrategy() { float[] vector = new float[16]; randomVector(vector, 25); int 
upperLimit = 35; - var query = new KnnSearchBuilder(VECTOR_FIELD, vector, 1, 1, null, null).addFilterQuery( + var query = new KnnSearchBuilder(VECTOR_FIELD, vector, 1, 1, 10f, null, null).addFilterQuery( QueryBuilders.rangeQuery(NUM_ID_FIELD).lte(35) ); assertResponse(client().prepareSearch(INDEX_NAME).setKnnSearch(List.of(query)).setSize(1).setProfile(true), acornResponse -> { @@ -133,7 +133,7 @@ public void testHnswEarlyTerminationQuery() { float[] vector = new float[16]; randomVector(vector, 25); int upperLimit = 35; - var query = new KnnSearchBuilder(VECTOR_FIELD, vector, 1, 1, null, null); + var query = new KnnSearchBuilder(VECTOR_FIELD, vector, 1, 1, 10f, null, null); assertResponse(client().prepareSearch(INDEX_NAME).setKnnSearch(List.of(query)).setSize(1).setProfile(true), response -> { assertNotEquals(0, response.getHits().getHits().length); var profileResults = response.getProfileResults(); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RetrieverTelemetryIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RetrieverTelemetryIT.java index 1762e4fe299c4..0996c5e3976c0 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RetrieverTelemetryIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RetrieverTelemetryIT.java @@ -84,7 +84,7 @@ public void testTelemetryForRetrievers() throws IOException { // search#1 - this will record 1 entry for "retriever" in `sections`, and 1 for "knn" under `retrievers` { performSearch( - new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null, null)) + new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, 10f, null, null)) ); } @@ -99,7 +99,7 @@ public void testTelemetryForRetrievers() throws IOException { { performSearch( new SearchSourceBuilder().retriever( - new StandardRetrieverBuilder(new 
KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, null, null)) + new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, 10f, null, null)) ) ); } @@ -114,7 +114,7 @@ public void testTelemetryForRetrievers() throws IOException { // his will record 1 entry for "knn" in `sections` { performSearch( - new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null, null))) + new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, 10f, null, null))) ); } diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index cc2edee43fcb5..b12190c3fe805 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -360,6 +360,7 @@ static TransportVersion def(int id) { public static final TransportVersion ESQL_LOOKUP_JOIN_PRE_JOIN_FILTER = def(9_151_0_00); public static final TransportVersion INFERENCE_API_DISABLE_EIS_RATE_LIMITING = def(9_152_0_00); public static final TransportVersion GEMINI_THINKING_BUDGET_ADDED = def(9_153_0_00); + public static final TransportVersion VISIT_PERCENTAGE = def(9_154_0_00); /* * STOP! READ THIS FIRST! 
No, really, diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 67cf720dc78b5..c1ff919c1f014 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -2407,6 +2407,7 @@ public Query createKnnQuery( VectorData queryVector, int k, int numCands, + Float visitPercentage, Float oversample, Query filter, Float similarityThreshold, @@ -2438,6 +2439,7 @@ public Query createKnnQuery( queryVector.asFloatVector(), k, numCands, + visitPercentage, oversample, filter, similarityThreshold, @@ -2570,6 +2572,7 @@ private Query createKnnFloatQuery( float[] queryVector, int k, int numCands, + Float visitPercentage, Float queryOversample, Query filter, Float similarityThreshold, @@ -2618,6 +2621,7 @@ private Query createKnnFloatQuery( .build(); } else if (indexOptions instanceof BBQIVFIndexOptions bbqIndexOptions) { float defaultVisitRatio = (float) (bbqIndexOptions.defaultVisitPercentage / 100d); + float visitRatio = visitPercentage == null ? defaultVisitRatio : (float) (visitPercentage / 100d); knnQuery = parentFilter != null ? new DiversifyingChildrenIVFKnnFloatVectorQuery( name(), @@ -2626,9 +2630,9 @@ private Query createKnnFloatQuery( numCands, filter, parentFilter, - defaultVisitRatio + visitRatio ) - : new IVFKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, defaultVisitRatio); + : new IVFKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, visitRatio); } else { knnQuery = parentFilter != null ? 
new ESDiversifyingChildrenFloatKnnVectorQuery( diff --git a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java index 6db6b29515d21..862299c5cae1e 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java @@ -34,6 +34,7 @@ import java.util.function.Supplier; import static org.elasticsearch.common.Strings.format; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.IVF_FORMAT; import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; @@ -48,6 +49,7 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder { public static final ParseField FIELD_FIELD = new ParseField("field"); public static final ParseField K_FIELD = new ParseField("k"); public static final ParseField NUM_CANDS_FIELD = new ParseField("num_candidates"); + public static final ParseField VISIT_PERCENTAGE_FIELD = new ParseField("visit_percentage"); public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector"); public static final ParseField QUERY_VECTOR_BUILDER_FIELD = new ParseField("query_vector_builder"); public static final ParseField VECTOR_SIMILARITY = new ParseField("similarity"); @@ -67,15 +69,29 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder { } else { vectorArray = null; } - return new KnnRetrieverBuilder( - (String) args[0], - vectorArray, - (QueryVectorBuilder) args[2], - (int) args[3], - (int) args[4], - (RescoreVectorBuilder) args[6], - (Float) args[5] - ); + if (IVF_FORMAT.isEnabled()) { + return new KnnRetrieverBuilder( + (String) args[0], + vectorArray, + (QueryVectorBuilder) args[2], + (int) args[3], + (int) args[4], + (Float) args[5], + (RescoreVectorBuilder) args[7], + 
(Float) args[6] + ); + } else { + return new KnnRetrieverBuilder( + (String) args[0], + vectorArray, + (QueryVectorBuilder) args[2], + (int) args[3], + (int) args[4], + null, + (RescoreVectorBuilder) args[6], + (Float) args[5] + ); + } } ); @@ -89,6 +105,9 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder { ); PARSER.declareInt(constructorArg(), K_FIELD); PARSER.declareInt(constructorArg(), NUM_CANDS_FIELD); + if (IVF_FORMAT.isEnabled()) { + PARSER.declareFloat(optionalConstructorArg(), VISIT_PERCENTAGE_FIELD); + } PARSER.declareFloat(optionalConstructorArg(), VECTOR_SIMILARITY); PARSER.declareField( optionalConstructorArg(), @@ -108,6 +127,7 @@ public static KnnRetrieverBuilder fromXContent(XContentParser parser, RetrieverP private final QueryVectorBuilder queryVectorBuilder; private final int k; private final int numCands; + private final Float visitPercentage; private final RescoreVectorBuilder rescoreVectorBuilder; private final Float similarity; @@ -117,6 +137,7 @@ public KnnRetrieverBuilder( QueryVectorBuilder queryVectorBuilder, int k, int numCands, + Float visitPercentage, RescoreVectorBuilder rescoreVectorBuilder, Float similarity ) { @@ -142,6 +163,7 @@ public KnnRetrieverBuilder( this.queryVectorBuilder = queryVectorBuilder; this.k = k; this.numCands = numCands; + this.visitPercentage = visitPercentage; this.similarity = similarity; this.rescoreVectorBuilder = rescoreVectorBuilder; } @@ -152,6 +174,7 @@ private KnnRetrieverBuilder(KnnRetrieverBuilder clone, Supplier queryVe this.field = clone.field; this.k = clone.k; this.numCands = clone.numCands; + this.visitPercentage = clone.visitPercentage; this.similarity = clone.similarity; this.retrieverName = clone.retrieverName; this.preFilterQueryBuilders = clone.preFilterQueryBuilders; @@ -236,6 +259,7 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder null, k, numCands, + visitPercentage, rescoreVectorBuilder, similarity ); @@ -262,6 +286,10 @@ public void 
doToXContent(XContentBuilder builder, Params params) throws IOExcept builder.field(K_FIELD.getPreferredName(), k); builder.field(NUM_CANDS_FIELD.getPreferredName(), numCands); + if (visitPercentage != null) { + builder.field(VISIT_PERCENTAGE_FIELD.getPreferredName(), visitPercentage); + } + if (queryVector != null) { builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector.get()); } @@ -284,6 +312,7 @@ public boolean doEquals(Object o) { KnnRetrieverBuilder that = (KnnRetrieverBuilder) o; return k == that.k && numCands == that.numCands + && Objects.equals(visitPercentage, that.visitPercentage) && Objects.equals(field, that.field) && ((queryVector == null && that.queryVector == null) || (queryVector != null && that.queryVector != null && Arrays.equals(queryVector.get(), that.queryVector.get()))) @@ -294,7 +323,7 @@ public boolean doEquals(Object o) { @Override public int doHashCode() { - int result = Objects.hash(field, queryVectorBuilder, k, numCands, rescoreVectorBuilder, similarity); + int result = Objects.hash(field, queryVectorBuilder, k, numCands, visitPercentage, rescoreVectorBuilder, similarity); result = 31 * result + Arrays.hashCode(queryVector != null ? 
queryVector.get() : null); return result; } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java index 9b9718efcf523..f455797e68d12 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java @@ -34,6 +34,7 @@ import static org.elasticsearch.TransportVersions.V_8_11_X; import static org.elasticsearch.common.Strings.format; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.IVF_FORMAT; import static org.elasticsearch.index.query.AbstractQueryBuilder.DEFAULT_BOOST; import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; @@ -49,6 +50,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea public static final ParseField FIELD_FIELD = new ParseField("field"); public static final ParseField K_FIELD = new ParseField("k"); public static final ParseField NUM_CANDS_FIELD = new ParseField("num_candidates"); + public static final ParseField VISIT_PERCENTAGE_FIELD = new ParseField("visit_percentage"); public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector"); public static final ParseField QUERY_VECTOR_BUILDER_FIELD = new ParseField("query_vector_builder"); public static final ParseField VECTOR_SIMILARITY = new ParseField("similarity"); @@ -61,13 +63,24 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("knn", args -> { // TODO optimize parsing for when BYTE values are provided - return new Builder().field((String) args[0]) - .queryVector((VectorData) args[1]) - .queryVectorBuilder((QueryVectorBuilder) args[4]) - .k((Integer) args[2]) - 
.numCandidates((Integer) args[3]) - .similarity((Float) args[5]) - .rescoreVectorBuilder((RescoreVectorBuilder) args[6]); + if (IVF_FORMAT.isEnabled()) { + return new Builder().field((String) args[0]) + .queryVector((VectorData) args[1]) + .queryVectorBuilder((QueryVectorBuilder) args[5]) + .k((Integer) args[2]) + .numCandidates((Integer) args[3]) + .visitPercentage((Float) args[4]) + .similarity((Float) args[6]) + .rescoreVectorBuilder((RescoreVectorBuilder) args[7]); + } else { + return new Builder().field((String) args[0]) + .queryVector((VectorData) args[1]) + .queryVectorBuilder((QueryVectorBuilder) args[4]) + .k((Integer) args[2]) + .numCandidates((Integer) args[3]) + .similarity((Float) args[5]) + .rescoreVectorBuilder((RescoreVectorBuilder) args[6]); + } }); static { @@ -80,6 +93,9 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea ); PARSER.declareInt(optionalConstructorArg(), K_FIELD); PARSER.declareInt(optionalConstructorArg(), NUM_CANDS_FIELD); + if (IVF_FORMAT.isEnabled()) { + PARSER.declareFloat(optionalConstructorArg(), VISIT_PERCENTAGE_FIELD); + } PARSER.declareNamedObject( optionalConstructorArg(), (p, c, n) -> p.namedObject(QueryVectorBuilder.class, n, c), @@ -118,6 +134,7 @@ public static KnnSearchBuilder.Builder fromXContent(XContentParser parser) throw private final Supplier querySupplier; final int k; final int numCands; + final Float visitPercentage; final Float similarity; final List filterQueries; String queryName; @@ -132,6 +149,7 @@ public static KnnSearchBuilder.Builder fromXContent(XContentParser parser) throw * @param queryVector the query vector * @param k the final number of nearest neighbors to return as top hits * @param numCands the number of nearest neighbor candidates to consider per shard + * @param visitPercentage percentage of the total number of vectors to visit per shard * @param rescoreVectorBuilder rescore vector information */ public KnnSearchBuilder( @@ -139,6 +157,7 @@ public 
KnnSearchBuilder( float[] queryVector, int k, int numCands, + Float visitPercentage, RescoreVectorBuilder rescoreVectorBuilder, Float similarity ) { @@ -148,6 +167,7 @@ public KnnSearchBuilder( null, k, numCands, + visitPercentage, rescoreVectorBuilder, similarity ); @@ -160,16 +180,18 @@ public KnnSearchBuilder( * @param queryVector the query vector * @param k the final number of nearest neighbors to return as top hits * @param numCands the number of nearest neighbor candidates to consider per shard + * @param visitPercentage percentage of the total number of vectors to visit per shard */ public KnnSearchBuilder( String field, VectorData queryVector, int k, int numCands, + Float visitPercentage, RescoreVectorBuilder rescoreVectorBuilder, Float similarity ) { - this(field, queryVector, null, k, numCands, rescoreVectorBuilder, similarity); + this(field, queryVector, null, k, numCands, visitPercentage, rescoreVectorBuilder, similarity); } /** @@ -179,12 +201,14 @@ public KnnSearchBuilder( * @param queryVectorBuilder the query vector builder * @param k the final number of nearest neighbors to return as top hits * @param numCands the number of nearest neighbor candidates to consider per shard + * @param visitPercentage percentage of the total number of vectors to visit per shard */ public KnnSearchBuilder( String field, QueryVectorBuilder queryVectorBuilder, int k, int numCands, + Float visitPercentage, RescoreVectorBuilder rescoreVectorBuilder, Float similarity ) { @@ -194,6 +218,7 @@ public KnnSearchBuilder( Objects.requireNonNull(queryVectorBuilder, format("[%s] cannot be null", QUERY_VECTOR_BUILDER_FIELD.getPreferredName())), k, numCands, + visitPercentage, rescoreVectorBuilder, similarity ); @@ -205,6 +230,7 @@ public KnnSearchBuilder( QueryVectorBuilder queryVectorBuilder, int k, int numCands, + Float visitPercentage, RescoreVectorBuilder rescoreVectorBuilder, Float similarity ) { @@ -215,6 +241,7 @@ public KnnSearchBuilder( new ArrayList<>(), k, numCands, + 
visitPercentage, rescoreVectorBuilder, similarity, null, @@ -228,6 +255,7 @@ private KnnSearchBuilder( Supplier querySupplier, Integer k, Integer numCands, + Float visitPercentage, RescoreVectorBuilder rescoreVectorBuilder, List filterQueries, Float similarity @@ -237,6 +265,7 @@ private KnnSearchBuilder( this.queryVectorBuilder = null; this.k = k; this.numCands = numCands; + this.visitPercentage = visitPercentage; this.filterQueries = filterQueries; this.querySupplier = querySupplier; this.similarity = similarity; @@ -250,6 +279,7 @@ private KnnSearchBuilder( List filterQueries, int k, int numCandidates, + Float visitPercentage, RescoreVectorBuilder rescoreVectorBuilder, Float similarity, InnerHitBuilder innerHitBuilder, @@ -267,6 +297,9 @@ private KnnSearchBuilder( if (numCandidates > NUM_CANDS_LIMIT) { throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]"); } + if (visitPercentage != null && (visitPercentage < 0.0f || visitPercentage > 100.0f)) { + throw new IllegalArgumentException("[" + VISIT_PERCENTAGE_FIELD.getPreferredName() + "] must be between 0 and 100"); + } if (queryVector == null && queryVectorBuilder == null) { throw new IllegalArgumentException( format( @@ -290,6 +323,7 @@ private KnnSearchBuilder( this.queryVectorBuilder = queryVectorBuilder; this.k = k; this.numCands = numCandidates; + this.visitPercentage = visitPercentage; this.rescoreVectorBuilder = rescoreVectorBuilder; this.innerHitBuilder = innerHitBuilder; this.similarity = similarity; @@ -303,6 +337,11 @@ public KnnSearchBuilder(StreamInput in) throws IOException { this.field = in.readString(); this.k = in.readVInt(); this.numCands = in.readVInt(); + if (in.getTransportVersion().onOrAfter(TransportVersions.VISIT_PERCENTAGE)) { + this.visitPercentage = in.readOptionalFloat(); + } else { + this.visitPercentage = null; + } if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_14_0)) { this.queryVector = 
in.readOptionalWriteable(VectorData::new); } else { @@ -344,6 +383,10 @@ public int getNumCands() { return numCands; } + public Float getVisitPercentage() { + return visitPercentage; + } + public RescoreVectorBuilder getRescoreVectorBuilder() { return rescoreVectorBuilder; } @@ -416,10 +459,9 @@ public KnnSearchBuilder rewrite(QueryRewriteContext ctx) throws IOException { if (querySupplier.get() == null) { return this; } - return new KnnSearchBuilder(field, querySupplier.get(), k, numCands, rescoreVectorBuilder, similarity).boost(boost) - .queryName(queryName) - .addFilterQueries(filterQueries) - .innerHit(innerHitBuilder); + return new KnnSearchBuilder(field, querySupplier.get(), k, numCands, visitPercentage, rescoreVectorBuilder, similarity).boost( + boost + ).queryName(queryName).addFilterQueries(filterQueries).innerHit(innerHitBuilder); } if (queryVectorBuilder != null) { SetOnce toSet = new SetOnce<>(); @@ -439,7 +481,8 @@ public KnnSearchBuilder rewrite(QueryRewriteContext ctx) throws IOException { } ll.onResponse(null); }))); - return new KnnSearchBuilder(field, toSet::get, k, numCands, rescoreVectorBuilder, filterQueries, similarity).boost(boost) + return new KnnSearchBuilder(field, toSet::get, k, numCands, visitPercentage, rescoreVectorBuilder, filterQueries, similarity) + .boost(boost) .queryName(queryName) .innerHit(innerHitBuilder); } @@ -453,7 +496,7 @@ public KnnSearchBuilder rewrite(QueryRewriteContext ctx) throws IOException { rewrittenQueries.add(rewrittenQuery); } if (changed) { - return new KnnSearchBuilder(field, queryVector, k, numCands, rescoreVectorBuilder, similarity).boost(boost) + return new KnnSearchBuilder(field, queryVector, k, numCands, visitPercentage, rescoreVectorBuilder, similarity).boost(boost) .queryName(queryName) .addFilterQueries(rewrittenQueries) .innerHit(innerHitBuilder); @@ -465,9 +508,9 @@ public KnnVectorQueryBuilder toQueryBuilder() { if (queryVectorBuilder != null) { throw new IllegalArgumentException("missing 
rewrite"); } - return new KnnVectorQueryBuilder(field, queryVector, numCands, numCands, rescoreVectorBuilder, similarity).boost(boost) - .queryName(queryName) - .addFilterQueries(filterQueries); + return new KnnVectorQueryBuilder(field, queryVector, numCands, numCands, visitPercentage, rescoreVectorBuilder, similarity).boost( + boost + ).queryName(queryName).addFilterQueries(filterQueries); } public Float getSimilarity() { @@ -481,6 +524,7 @@ public boolean equals(Object o) { KnnSearchBuilder that = (KnnSearchBuilder) o; return k == that.k && numCands == that.numCands + && Objects.equals(visitPercentage, that.visitPercentage) && Objects.equals(rescoreVectorBuilder, that.rescoreVectorBuilder) && Objects.equals(field, that.field) && Objects.equals(queryVector, that.queryVector) @@ -499,6 +543,7 @@ public int hashCode() { field, k, numCands, + visitPercentage, querySupplier, queryVectorBuilder, rescoreVectorBuilder, @@ -517,6 +562,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(K_FIELD.getPreferredName(), k); builder.field(NUM_CANDS_FIELD.getPreferredName(), numCands); + if (visitPercentage != null) { + builder.field(VISIT_PERCENTAGE_FIELD.getPreferredName(), visitPercentage); + } + if (queryVectorBuilder != null) { builder.startObject(QUERY_VECTOR_BUILDER_FIELD.getPreferredName()); builder.field(queryVectorBuilder.getWriteableName(), queryVectorBuilder); @@ -561,6 +610,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(field); out.writeVInt(k); out.writeVInt(numCands); + if (out.getTransportVersion().onOrAfter(TransportVersions.VISIT_PERCENTAGE)) { + out.writeOptionalFloat(visitPercentage); + } if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_14_0)) { out.writeOptionalWriteable(queryVector); } else { @@ -601,6 +653,7 @@ public static class Builder { private QueryVectorBuilder queryVectorBuilder; private Integer k; private Integer numCandidates; + private Float 
visitPercentage; private Float similarity; private final List filterQueries = new ArrayList<>(); private String queryName; @@ -654,6 +707,11 @@ public Builder numCandidates(Integer numCands) { return this; } + public Builder visitPercentage(Float visitPercentage) { + this.visitPercentage = visitPercentage; + return this; + } + public Builder similarity(Float similarity) { this.similarity = similarity; return this; @@ -677,6 +735,7 @@ public KnnSearchBuilder build(int size) { filterQueries, adjustedK, adjustedNumCandidates, + visitPercentage, rescoreVectorBuilder, similarity, innerHitBuilder, diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java index 12573d5ad496e..609b9df8c3fed 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java @@ -30,7 +30,9 @@ import java.util.List; import java.util.Objects; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.IVF_FORMAT; import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; /** * A builder used in {@link RestKnnSearchAction} to convert the kNN REST request @@ -199,6 +201,7 @@ static class KnnSearch { static final ParseField FIELD_FIELD = new ParseField("field"); static final ParseField K_FIELD = new ParseField("k"); static final ParseField NUM_CANDS_FIELD = new ParseField("num_candidates"); + static final ParseField VISIT_PERCENTAGE_FIELD = new ParseField("visit_percentage"); static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector"); private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("knn", args -> { @@ -208,7 +211,11 @@ static class KnnSearch { for (int i = 0; i < vector.size(); 
i++) { vectorArray[i] = vector.get(i); } - return new KnnSearch((String) args[0], vectorArray, (int) args[2], (int) args[3]); + if (IVF_FORMAT.isEnabled()) { + return new KnnSearch((String) args[0], vectorArray, (int) args[2], (int) args[3], (Float) args[4]); + } else { + return new KnnSearch((String) args[0], vectorArray, (int) args[2], (int) args[3], null); + } }); static { @@ -216,6 +223,9 @@ static class KnnSearch { PARSER.declareFloatArray(constructorArg(), QUERY_VECTOR_FIELD); PARSER.declareInt(constructorArg(), K_FIELD); PARSER.declareInt(constructorArg(), NUM_CANDS_FIELD); + if (IVF_FORMAT.isEnabled()) { + PARSER.declareFloat(optionalConstructorArg(), VISIT_PERCENTAGE_FIELD); + } } public static KnnSearch parse(XContentParser parser) throws IOException { @@ -226,6 +236,7 @@ public static KnnSearch parse(XContentParser parser) throws IOException { final float[] queryVector; final int k; final int numCands; + final Float visitPercentage; /** * Defines a kNN search. @@ -235,11 +246,12 @@ public static KnnSearch parse(XContentParser parser) throws IOException { * @param k the final number of nearest neighbors to return as top hits * @param numCands the number of nearest neighbor candidates to consider per shard */ - KnnSearch(String field, float[] queryVector, int k, int numCands) { + KnnSearch(String field, float[] queryVector, int k, int numCands, Float visitPercentage) { this.field = field; this.queryVector = queryVector; this.k = k; this.numCands = numCands; + this.visitPercentage = visitPercentage; } public KnnVectorQueryBuilder toQueryBuilder() { @@ -256,7 +268,10 @@ public KnnVectorQueryBuilder toQueryBuilder() { if (numCands > NUM_CANDS_LIMIT) { throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]"); } - return new KnnVectorQueryBuilder(field, queryVector, numCands, numCands, null, null); + if (visitPercentage != null && (visitPercentage < 0.0f || visitPercentage > 100.0f)) { + throw 
new IllegalArgumentException("[" + VISIT_PERCENTAGE_FIELD.getPreferredName() + "] must be between 0 and 100"); + } + return new KnnVectorQueryBuilder(field, queryVector, numCands, numCands, visitPercentage, null, null); } @Override @@ -266,13 +281,14 @@ public boolean equals(Object o) { KnnSearch that = (KnnSearch) o; return k == that.k && numCands == that.numCands + && Objects.equals(visitPercentage, that.visitPercentage) && Objects.equals(field, that.field) && Arrays.equals(queryVector, that.queryVector); } @Override public int hashCode() { - int result = Objects.hash(field, k, numCands); + int result = Objects.hash(field, k, numCands, visitPercentage); result = 31 * result + Arrays.hashCode(queryVector); return result; } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java index b76f56ceb2aa9..71c452858e1ae 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java @@ -48,6 +48,7 @@ import static org.elasticsearch.TransportVersions.KNN_QUERY_RESCORE_OVERSAMPLE; import static org.elasticsearch.common.Strings.format; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.IVF_FORMAT; import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; @@ -64,25 +65,40 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder PARSER = new ConstructingObjectParser<>( - "knn", - args -> new KnnVectorQueryBuilder( - (String) args[0], - (VectorData) args[1], - (QueryVectorBuilder) args[5], - null, - (Integer) args[2], - (Integer) args[3], - (RescoreVectorBuilder) args[6], - (Float) args[4] - ) - ); + public static final 
ConstructingObjectParser PARSER = new ConstructingObjectParser<>("knn", args -> { + if (IVF_FORMAT.isEnabled()) { + return new KnnVectorQueryBuilder( + (String) args[0], + (VectorData) args[1], + (QueryVectorBuilder) args[6], + null, + (Integer) args[2], + (Integer) args[3], + (Float) args[4], + (RescoreVectorBuilder) args[7], + (Float) args[5] + ); + } else { + return new KnnVectorQueryBuilder( + (String) args[0], + (VectorData) args[1], + (QueryVectorBuilder) args[5], + null, + (Integer) args[2], + (Integer) args[3], + null, + (RescoreVectorBuilder) args[6], + (Float) args[4] + ); + } + }); static { PARSER.declareString(constructorArg(), FIELD_FIELD); @@ -94,6 +110,9 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder filterQueries = new ArrayList<>(); private final Float vectorSimilarity; private final QueryVectorBuilder queryVectorBuilder; @@ -134,10 +154,21 @@ public KnnVectorQueryBuilder( float[] queryVector, Integer k, Integer numCands, + Float visitPercentage, RescoreVectorBuilder rescoreVectorBuilder, Float vectorSimilarity ) { - this(fieldName, VectorData.fromFloats(queryVector), null, null, k, numCands, rescoreVectorBuilder, vectorSimilarity); + this( + fieldName, + VectorData.fromFloats(queryVector), + null, + null, + k, + numCands, + visitPercentage, + rescoreVectorBuilder, + vectorSimilarity + ); } public KnnVectorQueryBuilder( @@ -145,9 +176,10 @@ public KnnVectorQueryBuilder( QueryVectorBuilder queryVectorBuilder, Integer k, Integer numCands, + Float visitPercentage, Float vectorSimilarity ) { - this(fieldName, null, queryVectorBuilder, null, k, numCands, null, vectorSimilarity); + this(fieldName, null, queryVectorBuilder, null, k, numCands, visitPercentage, null, vectorSimilarity); } public KnnVectorQueryBuilder( @@ -155,10 +187,21 @@ public KnnVectorQueryBuilder( byte[] queryVector, Integer k, Integer numCands, + Float visitPercentage, RescoreVectorBuilder rescoreVectorBuilder, Float vectorSimilarity ) { - this(fieldName, 
VectorData.fromBytes(queryVector), null, null, k, numCands, rescoreVectorBuilder, vectorSimilarity); + this( + fieldName, + VectorData.fromBytes(queryVector), + null, + null, + k, + numCands, + visitPercentage, + rescoreVectorBuilder, + vectorSimilarity + ); } public KnnVectorQueryBuilder( @@ -166,10 +209,11 @@ public KnnVectorQueryBuilder( VectorData queryVector, Integer k, Integer numCands, + Float visitPercentage, RescoreVectorBuilder rescoreVectorBuilder, Float vectorSimilarity ) { - this(fieldName, queryVector, null, null, k, numCands, rescoreVectorBuilder, vectorSimilarity); + this(fieldName, queryVector, null, null, k, numCands, visitPercentage, rescoreVectorBuilder, vectorSimilarity); } private KnnVectorQueryBuilder( @@ -179,6 +223,7 @@ private KnnVectorQueryBuilder( Supplier queryVectorSupplier, Integer k, Integer numCands, + Float visitPercentage, RescoreVectorBuilder rescoreVectorBuilder, Float vectorSimilarity ) { @@ -193,6 +238,9 @@ private KnnVectorQueryBuilder( "[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot be less than [" + K_FIELD.getPreferredName() + "]" ); } + if (visitPercentage != null && (visitPercentage < 0.0f || visitPercentage > 100.0f)) { + throw new IllegalArgumentException("[" + VISIT_PERCENTAGE_FIELD.getPreferredName() + "] must be between 0.0 and 100.0"); + } if (queryVector == null && queryVectorBuilder == null) { throw new IllegalArgumentException( format( @@ -214,6 +262,7 @@ private KnnVectorQueryBuilder( this.queryVector = queryVector; this.k = k; this.numCands = numCands; + this.visitPercentage = visitPercentage; this.vectorSimilarity = vectorSimilarity; this.queryVectorBuilder = queryVectorBuilder; this.queryVectorSupplier = queryVectorSupplier; @@ -233,6 +282,11 @@ public KnnVectorQueryBuilder(StreamInput in) throws IOException { } else { this.numCands = in.readVInt(); } + if (in.getTransportVersion().onOrAfter(TransportVersions.VISIT_PERCENTAGE)) { + this.visitPercentage = in.readOptionalFloat(); + } else { + 
this.visitPercentage = null; + } if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_14_0)) { this.queryVector = in.readOptionalWriteable(VectorData::new); } else { @@ -289,6 +343,10 @@ public Integer numCands() { return numCands; } + public Float visitPercentage() { + return visitPercentage; + } + public List filterQueries() { return filterQueries; } @@ -338,6 +396,9 @@ protected void doWriteTo(StreamOutput out) throws IOException { out.writeVInt(numCands); } } + if (out.getTransportVersion().onOrAfter(TransportVersions.VISIT_PERCENTAGE)) { + out.writeOptionalFloat(visitPercentage); + } if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_14_0)) { out.writeOptionalWriteable(queryVector); } else { @@ -389,6 +450,9 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep if (numCands != null) { builder.field(NUM_CANDS_FIELD.getPreferredName(), numCands); } + if (visitPercentage != null) { + builder.field(VISIT_PERCENTAGE_FIELD.getPreferredName(), visitPercentage); + } if (vectorSimilarity != null) { builder.field(VECTOR_SIMILARITY_FIELD.getPreferredName(), vectorSimilarity); } @@ -422,10 +486,15 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException { if (queryVectorSupplier.get() == null) { return this; } - return new KnnVectorQueryBuilder(fieldName, queryVectorSupplier.get(), k, numCands, rescoreVectorBuilder, vectorSimilarity) - .boost(boost) - .queryName(queryName) - .addFilterQueries(filterQueries); + return new KnnVectorQueryBuilder( + fieldName, + queryVectorSupplier.get(), + k, + numCands, + visitPercentage, + rescoreVectorBuilder, + vectorSimilarity + ).boost(boost).queryName(queryName).addFilterQueries(filterQueries); } if (queryVectorBuilder != null) { SetOnce toSet = new SetOnce<>(); @@ -452,6 +521,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException { toSet::get, k, numCands, + visitPercentage, rescoreVectorBuilder, vectorSimilarity 
).boost(boost).queryName(queryName).addFilterQueries(filterQueries); @@ -476,6 +546,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException { queryVectorSupplier, k, numCands, + visitPercentage, rescoreVectorBuilder, vectorSimilarity ).boost(boost).queryName(queryName).addFilterQueries(rewrittenQueries); @@ -512,6 +583,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { } } int adjustedNumCands = numCands == null ? Math.round(Math.min(NUM_CANDS_MULTIPLICATIVE_FACTOR * k, NUM_CANDS_LIMIT)) : numCands; + if (fieldType == null) { return new MatchNoDocsQuery(); } @@ -575,6 +647,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { queryVector, k, adjustedNumCands, + visitPercentage, oversample, filterQuery, vectorSimilarity, @@ -601,6 +674,7 @@ protected int doHashCode() { Objects.hashCode(queryVector), k, numCands, + visitPercentage, filterQueries, vectorSimilarity, queryVectorBuilder, @@ -614,6 +688,7 @@ protected boolean doEquals(KnnVectorQueryBuilder other) { && Objects.equals(queryVector, other.queryVector) && Objects.equals(k, other.k) && Objects.equals(numCands, other.numCands) + && Objects.equals(visitPercentage, other.visitPercentage) && Objects.equals(filterQueries, other.filterQueries) && Objects.equals(vectorSimilarity, other.vectorSimilarity) && Objects.equals(queryVectorBuilder, other.queryVectorBuilder) diff --git a/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java index a4f698d04b782..f468c0c346aa5 100644 --- a/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java @@ -353,8 +353,8 @@ public void testRewriteShardSearchRequestWithRank() { SearchSourceBuilder ssb = new SearchSourceBuilder().query(bm25) .knnSearch( List.of( - new KnnSearchBuilder("vector", new 
float[] { 0.0f }, 10, 100, null, null), - new KnnSearchBuilder("vector2", new float[] { 0.0f }, 10, 100, null, null) + new KnnSearchBuilder("vector", new float[] { 0.0f }, 10, 100, 10f, null, null), + new KnnSearchBuilder("vector2", new float[] { 0.0f }, 10, 100, 10f, null, null) ) ) .rankBuilder(new TestRankBuilder(100)); diff --git a/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java b/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java index 7c5ca2d9007b6..0114d1994caf9 100644 --- a/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java @@ -63,7 +63,7 @@ public void testKnnSearchRemovedVector() throws IOException { client().prepareUpdate("index", "0").setDoc("vector", (Object) null).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE).get(); float[] queryVector = randomVector(); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 20, 50, null, null).boost(5.0f); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 20, 50, 10f, null, null).boost(5.0f); assertResponse( client().prepareSearch("index") .setKnnSearch(List.of(knnSearch)) @@ -107,7 +107,7 @@ public void testKnnWithQuery() throws IOException { indicesAdmin().prepareRefresh("index").get(); float[] queryVector = randomVector(); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null, null).boost(5.0f).queryName("knn"); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, 10f, null, null).boost(5.0f).queryName("knn"); assertResponse( client().prepareSearch("index") .setKnnSearch(List.of(knnSearch)) @@ -156,7 +156,7 @@ public void testKnnFilter() throws IOException { indicesAdmin().prepareRefresh("index").get(); float[] queryVector = randomVector(); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 
50, null, null).addFilterQuery( + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, 10f, null, null).addFilterQuery( QueryBuilders.termsQuery("field", "second") ); assertResponse(client().prepareSearch("index").setKnnSearch(List.of(knnSearch)).addFetchField("*").setSize(10), response -> { @@ -199,7 +199,7 @@ public void testKnnFilterWithRewrite() throws IOException { indicesAdmin().prepareRefresh("index").get(); float[] queryVector = randomVector(); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null, null).addFilterQuery( + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, 10f, null, null).addFilterQuery( QueryBuilders.termsLookupQuery("field", new TermsLookup("index", "lookup-doc", "other-field")) ); assertResponse(client().prepareSearch("index").setKnnSearch(List.of(knnSearch)).setSize(10), response -> { @@ -246,8 +246,8 @@ public void testMultiKnnClauses() throws IOException { indicesAdmin().prepareRefresh("index").get(); float[] queryVector = randomVector(20f, 21f); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null, null).boost(5.0f); - KnnSearchBuilder knnSearch2 = new KnnSearchBuilder("vector_2", queryVector, 5, 50, null, null).boost(10.0f); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, 10f, null, null).boost(5.0f); + KnnSearchBuilder knnSearch2 = new KnnSearchBuilder("vector_2", queryVector, 5, 50, 10f, null, null).boost(10.0f); assertResponse( client().prepareSearch("index") .setKnnSearch(List.of(knnSearch, knnSearch2)) @@ -308,8 +308,8 @@ public void testMultiKnnClausesSameDoc() throws IOException { float[] queryVector = randomVector(); // Having the same query vector and same docs should mean our KNN scores are linearly combined if the same doc is matched - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null, null); - KnnSearchBuilder knnSearch2 = new 
KnnSearchBuilder("vector_2", queryVector, 5, 50, null, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, 10f, null, null); + KnnSearchBuilder knnSearch2 = new KnnSearchBuilder("vector_2", queryVector, 5, 50, 10f, null, null); assertResponse( client().prepareSearch("index") .setKnnSearch(List.of(knnSearch)) @@ -383,7 +383,7 @@ public void testKnnFilteredAlias() throws IOException { indicesAdmin().prepareRefresh("index").get(); float[] queryVector = randomVector(); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 10, 50, null, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 10, 50, 10f, null, null); final int expectedHitCount = expectedHits; assertResponse(client().prepareSearch("test-alias").setKnnSearch(List.of(knnSearch)).setSize(10), response -> { assertHitCount(response, expectedHitCount); @@ -420,7 +420,7 @@ public void testKnnSearchAction() throws IOException { float[] queryVector = randomVector(); assertResponse( client().prepareSearch("index1", "index2") - .setQuery(new KnnVectorQueryBuilder("vector", queryVector, 5, 5, null, null)) + .setQuery(new KnnVectorQueryBuilder("vector", queryVector, 5, 5, 10f, null, null)) .setSize(2), response -> { // The total hits is num_cands * num_shards, since the query gathers num_cands hits from each shard @@ -454,7 +454,7 @@ public void testKnnVectorsWith4096Dims() throws IOException { indicesAdmin().prepareRefresh("index").get(); float[] queryVector = randomVector(4096); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 3, 50, null, null).boost(5.0f); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 3, 50, 10f, null, null).boost(5.0f); assertResponse(client().prepareSearch("index").setKnnSearch(List.of(knnSearch)).addFetchField("*").setSize(10), response -> { assertHitCount(response, 3); assertEquals(3, response.getHits().getHits().length); diff --git 
a/server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java index 9d9132ecdffe8..25c4c1672e852 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java @@ -120,6 +120,7 @@ public void testSerializationMultiKNN() throws Exception { new float[] { 1, 2 }, 5, 10, + 10f, randomRescoreVectorBuilder(), randomBoolean() ? null : randomFloat() ), @@ -128,6 +129,7 @@ public void testSerializationMultiKNN() throws Exception { new float[] { 4, 12, 41 }, 3, 5, + 10f, randomRescoreVectorBuilder(), randomBoolean() ? null : randomFloat() ) @@ -151,6 +153,7 @@ public void testSerializationMultiKNN() throws Exception { new float[] { 1, 2 }, 5, 10, + 10f, randomRescoreVectorBuilder(), randomBoolean() ? null : randomFloat() ) @@ -474,7 +477,7 @@ public QueryBuilder topDocsQuery() { SearchRequest searchRequest = new SearchRequest().source( new SearchSourceBuilder().rankBuilder(new TestRankBuilder(100)) .query(QueryBuilders.termQuery("field", "term")) - .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null, null))) + .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, 10f, null, null))) .size(0) ); ActionRequestValidationException validationErrors = searchRequest.validate(); @@ -486,7 +489,7 @@ public QueryBuilder topDocsQuery() { SearchRequest searchRequest = new SearchRequest().source( new SearchSourceBuilder().rankBuilder(new TestRankBuilder(1)) .query(QueryBuilders.termQuery("field", "term")) - .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null, null))) + .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, 10f, null, null))) .size(2) ); ActionRequestValidationException validationErrors = searchRequest.validate(); @@ -513,7 +516,7 @@ public QueryBuilder 
topDocsQuery() { SearchRequest searchRequest = new SearchRequest().source( new SearchSourceBuilder().rankBuilder(new TestRankBuilder(100)) .query(QueryBuilders.termQuery("field", "term")) - .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null, null))) + .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, 10f, null, null))) ).scroll(new TimeValue(1000)); ActionRequestValidationException validationErrors = searchRequest.validate(); assertNotNull(validationErrors); @@ -524,7 +527,7 @@ public QueryBuilder topDocsQuery() { SearchRequest searchRequest = new SearchRequest().source( new SearchSourceBuilder().rankBuilder(new TestRankBuilder(9)) .query(QueryBuilders.termQuery("field", "term")) - .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null, null))) + .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, 10f, null, null))) ); ActionRequestValidationException validationErrors = searchRequest.validate(); assertNotNull(validationErrors); @@ -538,7 +541,7 @@ public QueryBuilder topDocsQuery() { SearchRequest searchRequest = new SearchRequest().source( new SearchSourceBuilder().rankBuilder(new TestRankBuilder(3)) .query(QueryBuilders.termQuery("field", "term")) - .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null, null))) + .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, 10f, null, null))) .size(3) .from(4) ); @@ -549,7 +552,7 @@ public QueryBuilder topDocsQuery() { SearchRequest searchRequest = new SearchRequest().source( new SearchSourceBuilder().rankBuilder(new TestRankBuilder(100)) .query(QueryBuilders.termQuery("field", "term")) - .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null, null))) + .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, 10f, null, null))) .addRescorer(new QueryRescorerBuilder(QueryBuilders.termQuery("rescore", 
"another term"))) ); ActionRequestValidationException validationErrors = searchRequest.validate(); @@ -561,7 +564,7 @@ public QueryBuilder topDocsQuery() { SearchRequest searchRequest = new SearchRequest().source( new SearchSourceBuilder().rankBuilder(new TestRankBuilder(100)) .query(QueryBuilders.termQuery("field", "term")) - .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null, null))) + .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, 10f, null, null))) .suggest(new SuggestBuilder().setGlobalText("test").addSuggestion("suggestion", new TermSuggestionBuilder("term"))) ); ActionRequestValidationException validationErrors = searchRequest.validate(); diff --git a/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java b/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java index fb79f9b462046..21d21acd7f9a2 100644 --- a/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java @@ -1395,7 +1395,7 @@ public void testShouldMinimizeRoundtrips() throws Exception { { SearchRequest searchRequest = new SearchRequest(); SearchSourceBuilder source = new SearchSourceBuilder(); - source.knnSearch(List.of(new KnnSearchBuilder("field", new float[] { 1, 2, 3 }, 10, 50, null, null))); + source.knnSearch(List.of(new KnnSearchBuilder("field", new float[] { 1, 2, 3 }, 10, 50, 10f, null, null))); searchRequest.source(source); searchRequest.setCcsMinimizeRoundtrips(true); @@ -1410,7 +1410,7 @@ public void testAdjustSearchType() { // If the search includes kNN, we should always use DFS_QUERY_THEN_FETCH SearchRequest searchRequest = new SearchRequest(); SearchSourceBuilder source = new SearchSourceBuilder(); - source.knnSearch(List.of(new KnnSearchBuilder("field", new float[] { 1, 2, 3 }, 10, 50, null, null))); + source.knnSearch(List.of(new 
KnnSearchBuilder("field", new float[] { 1, 2, 3 }, 10, 50, 10f, null, null))); searchRequest.source(source); TransportSearchAction.adjustSearchType(searchRequest, randomBoolean()); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index 00c9cb4e68ae8..11ffe2f9dc789 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -2555,6 +2555,7 @@ public void testByteVectorQueryBoundaries() throws IOException { VectorData.fromFloats(new float[] { 128, 0, 0 }), 3, 3, + 10f, null, null, null, @@ -2574,6 +2575,7 @@ public void testByteVectorQueryBoundaries() throws IOException { VectorData.fromFloats(new float[] { 0.0f, 0f, -129.0f }), 3, 3, + 10f, null, null, null, @@ -2593,6 +2595,7 @@ public void testByteVectorQueryBoundaries() throws IOException { VectorData.fromFloats(new float[] { 0.0f, 0.5f, 0.0f }), 3, 3, + 10f, null, null, null, @@ -2612,6 +2615,7 @@ public void testByteVectorQueryBoundaries() throws IOException { VectorData.fromFloats(new float[] { 0, 0.0f, -0.25f }), 3, 3, + 10f, null, null, null, @@ -2631,6 +2635,7 @@ public void testByteVectorQueryBoundaries() throws IOException { VectorData.fromFloats(new float[] { Float.NaN, 0f, 0.0f }), 3, 3, + 10f, null, null, null, @@ -2647,6 +2652,7 @@ public void testByteVectorQueryBoundaries() throws IOException { VectorData.fromFloats(new float[] { Float.POSITIVE_INFINITY, 0f, 0.0f }), 3, 3, + 10f, null, null, null, @@ -2666,6 +2672,7 @@ public void testByteVectorQueryBoundaries() throws IOException { VectorData.fromFloats(new float[] { 0, Float.NEGATIVE_INFINITY, 0.0f }), 3, 3, + 10f, null, null, null, @@ -2702,6 +2709,7 @@ public void testFloatVectorQueryBoundaries() throws IOException { VectorData.fromFloats(new float[] { 
Float.NaN, 0f, 0.0f }), 3, 3, + 10f, null, null, null, @@ -2718,6 +2726,7 @@ public void testFloatVectorQueryBoundaries() throws IOException { VectorData.fromFloats(new float[] { Float.POSITIVE_INFINITY, 0f, 0.0f }), 3, 3, + 10f, null, null, null, @@ -2737,6 +2746,7 @@ public void testFloatVectorQueryBoundaries() throws IOException { VectorData.fromFloats(new float[] { 0, Float.NEGATIVE_INFINITY, 0.0f }), 3, 3, + 10f, null, null, null, diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index 2524422ed8f90..d6206b845ae25 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -25,22 +25,29 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.VectorSimilarity; import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.vectors.DenseVectorQuery; +import org.elasticsearch.search.vectors.DiversifyingChildrenIVFKnnFloatVectorQuery; import org.elasticsearch.search.vectors.DiversifyingParentBlockQuery; import org.elasticsearch.search.vectors.ESKnnByteVectorQuery; import org.elasticsearch.search.vectors.ESKnnFloatVectorQuery; +import org.elasticsearch.search.vectors.IVFKnnFloatVectorQuery; import org.elasticsearch.search.vectors.RescoreKnnVectorQuery; import org.elasticsearch.search.vectors.VectorData; import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Set; import java.util.function.Function; +import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.MAX_VECTORS_PER_CLUSTER; +import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.MIN_VECTORS_PER_CLUSTER; import static 
org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.BBQ_MIN_DIMS; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.BIT; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.BYTE; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.FLOAT; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.IVF_FORMAT; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.OVERSAMPLE_LIMIT; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -66,36 +73,52 @@ private DenseVectorFieldMapper.DenseVectorIndexOptions randomIndexOptionsNonQuan } public static DenseVectorFieldMapper.DenseVectorIndexOptions randomIndexOptionsAll() { - return randomFrom( - new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000)), - new DenseVectorFieldMapper.Int8HnswIndexOptions( - randomIntBetween(1, 100), - randomIntBetween(1, 10_000), - randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), - randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) - ), - new DenseVectorFieldMapper.Int4HnswIndexOptions( - randomIntBetween(1, 100), - randomIntBetween(1, 10_000), - randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), - randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) - ), - new DenseVectorFieldMapper.FlatIndexOptions(), - new DenseVectorFieldMapper.Int8FlatIndexOptions( - randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), - randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) - ), - new DenseVectorFieldMapper.Int4FlatIndexOptions( - randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), - randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) - ), - new 
DenseVectorFieldMapper.BBQHnswIndexOptions( - randomIntBetween(1, 100), - randomIntBetween(1, 10_000), - randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) - ), - new DenseVectorFieldMapper.BBQFlatIndexOptions(randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector())) + List<DenseVectorFieldMapper.DenseVectorIndexOptions> options = new ArrayList<>( + Arrays.asList( + new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000)), + new DenseVectorFieldMapper.Int8HnswIndexOptions( + randomIntBetween(1, 100), + randomIntBetween(1, 10_000), + randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) + ), + new DenseVectorFieldMapper.Int4HnswIndexOptions( + randomIntBetween(1, 100), + randomIntBetween(1, 10_000), + randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) + ), + new DenseVectorFieldMapper.FlatIndexOptions(), + new DenseVectorFieldMapper.Int8FlatIndexOptions( + randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) + ), + new DenseVectorFieldMapper.Int4FlatIndexOptions( + randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) + ), + new DenseVectorFieldMapper.BBQHnswIndexOptions( + randomIntBetween(1, 100), + randomIntBetween(1, 10_000), + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) + ), + new DenseVectorFieldMapper.BBQFlatIndexOptions( + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) + ) + ) ); + + if (IVF_FORMAT.isEnabled()) { + options.add( + new DenseVectorFieldMapper.BBQIVFIndexOptions( + randomIntBetween(MIN_VECTORS_PER_CLUSTER, MAX_VECTORS_PER_CLUSTER),
+ randomFloatBetween(0.0f, 100.0f, true), + randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) + ) + ); + } + + return randomFrom(options); } private DenseVectorFieldMapper.DenseVectorIndexOptions randomIndexOptionsHnswQuantized() { @@ -230,6 +253,7 @@ public void testCreateNestedKnnQuery() { VectorData.fromFloats(queryVector), 10, 10, + 10f, null, null, null, @@ -243,7 +267,11 @@ public void testCreateNestedKnnQuery() { if (field.getIndexOptions().isFlat()) { assertThat(query, instanceOf(DiversifyingParentBlockQuery.class)); } else { - assertTrue(query instanceof DiversifyingChildrenFloatKnnVectorQuery || query instanceof PatienceKnnVectorQuery); + assertTrue( + query instanceof DiversifyingChildrenFloatKnnVectorQuery + || query instanceof PatienceKnnVectorQuery + || query instanceof DiversifyingChildrenIVFKnnFloatVectorQuery + ); } } { @@ -269,6 +297,7 @@ public void testCreateNestedKnnQuery() { vectorData, 10, 10, + 10f, null, null, null, @@ -287,6 +316,7 @@ public void testCreateNestedKnnQuery() { vectorData, 10, 10, + 10f, null, null, null, @@ -365,6 +395,7 @@ public void testFloatCreateKnnQuery() { VectorData.fromFloats(new float[] { 0.3f, 0.1f, 1.0f, 0.0f }), 10, 10, + 10f, null, null, null, @@ -396,6 +427,7 @@ public void testFloatCreateKnnQuery() { VectorData.fromFloats(queryVector), 10, 10, + 10f, null, null, null, @@ -423,6 +455,7 @@ public void testFloatCreateKnnQuery() { VectorData.fromFloats(new float[BBQ_MIN_DIMS]), 10, 10, + 10f, null, null, null, @@ -455,6 +488,7 @@ public void testCreateKnnQueryMaxDims() { VectorData.fromFloats(queryVector), 10, 10, + 10f, null, null, null, @@ -468,7 +502,11 @@ public void testCreateKnnQueryMaxDims() { if (fieldWith4096dims.getIndexOptions().isFlat()) { assertThat(query, instanceOf(DenseVectorQuery.Floats.class)); } else { - assertTrue(query instanceof KnnFloatVectorQuery || query instanceof PatienceKnnVectorQuery); + assertTrue( + query instanceof KnnFloatVectorQuery + || query 
instanceof PatienceKnnVectorQuery + || query instanceof IVFKnnFloatVectorQuery + ); } } @@ -493,6 +531,7 @@ public void testCreateKnnQueryMaxDims() { vectorData, 10, 10, + 10f, null, null, null, @@ -526,6 +565,7 @@ public void testByteCreateKnnQuery() { VectorData.fromFloats(new float[] { 0.3f, 0.1f, 1.0f }), 10, 10, + 10f, null, null, null, @@ -553,6 +593,7 @@ public void testByteCreateKnnQuery() { VectorData.fromFloats(new float[] { 0.0f, 0.0f, 0.0f }), 10, 10, + 10f, null, null, null, @@ -569,6 +610,7 @@ public void testByteCreateKnnQuery() { new VectorData(null, new byte[] { 0, 0, 0 }), 10, 10, + 10f, null, null, null, @@ -598,6 +640,7 @@ public void testRescoreOversampleUsedWithoutQuantization() { new VectorData(null, new byte[] { 1, 4, 10 }), 10, 100, + 10f, randomFloatBetween(1.0F, 10.0F, false), null, null, @@ -647,11 +690,11 @@ public void testRescoreOversampleModifiesNumCandidates() { ); // Total results is k, internal k is multiplied by oversample - checkRescoreQueryParameters(fieldType, 10, 200, 2.5F, 25, 200, 10); + checkRescoreQueryParameters(fieldType, 10, 200, 10f, 2.5F, 25, 200, 10); // If numCands < k, update numCands to k - checkRescoreQueryParameters(fieldType, 10, 20, 2.5F, 25, 25, 10); + checkRescoreQueryParameters(fieldType, 10, 20, 10f, 2.5F, 25, 25, 10); // Oversampling limits for k - checkRescoreQueryParameters(fieldType, 1000, 1000, 11.0F, OVERSAMPLE_LIMIT, OVERSAMPLE_LIMIT, 1000); + checkRescoreQueryParameters(fieldType, 1000, 1000, 10f, 11.0F, OVERSAMPLE_LIMIT, OVERSAMPLE_LIMIT, 1000); } public void testRescoreOversampleQueryOverrides() { @@ -671,6 +714,7 @@ public void testRescoreOversampleQueryOverrides() { VectorData.fromFloats(new float[] { 1, 4, 10 }), 10, 100, + 10f, 0f, null, null, @@ -700,6 +744,7 @@ public void testRescoreOversampleQueryOverrides() { VectorData.fromFloats(new float[] { 1, 4, 10 }), 10, 100, + 10f, 2f, null, null, @@ -740,6 +785,7 @@ public void testFilterSearchThreshold() { VectorData.fromFloats(new float[] { 
1, 4, 10 }), 10, 100, + 10f, 0f, null, null, @@ -756,6 +802,7 @@ public void testFilterSearchThreshold() { VectorData.fromFloats(new float[] { 1, 4, 10 }), 10, 100, + 10f, 0f, null, null, @@ -776,6 +823,7 @@ private static void checkRescoreQueryParameters( DenseVectorFieldType fieldType, int k, int candidates, + Float visitPercentage, float oversample, int expectedK, int expectedCandidates, @@ -785,6 +833,7 @@ private static void checkRescoreQueryParameters( VectorData.fromFloats(new float[] { 1, 4, 10 }), k, candidates, + visitPercentage, oversample, null, null, diff --git a/server/src/test/java/org/elasticsearch/index/query/NestedQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/index/query/NestedQueryBuilderTests.java index 6ab936bfab27c..39520db299f65 100644 --- a/server/src/test/java/org/elasticsearch/index/query/NestedQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/index/query/NestedQueryBuilderTests.java @@ -270,6 +270,7 @@ public void testKnnRewriteForInnerHits() throws IOException { new float[] { 1.0f, 2.0f, 3.0f }, null, 1, + 10f, null, null ); diff --git a/server/src/test/java/org/elasticsearch/rest/action/search/RestSearchActionTests.java b/server/src/test/java/org/elasticsearch/rest/action/search/RestSearchActionTests.java index 580dad8128494..ef620896e941d 100644 --- a/server/src/test/java/org/elasticsearch/rest/action/search/RestSearchActionTests.java +++ b/server/src/test/java/org/elasticsearch/rest/action/search/RestSearchActionTests.java @@ -83,7 +83,7 @@ public void testValidateSearchRequest() { .build(); SearchRequest searchRequest = new SearchRequest(); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", new float[] { 1, 1, 1 }, 10, 100, null, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", new float[] { 1, 1, 1 }, 10, 100, 10f, null, null); searchRequest.source(new SearchSourceBuilder().knnSearch(List.of(knnSearch))); Exception ex = expectThrows( diff --git 
a/server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java b/server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java index 83f19924ec3d6..dc6328862254a 100644 --- a/server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java @@ -826,7 +826,7 @@ public void testSearchSectionsUsageCollection() throws IOException { searchSourceBuilder.fetchField("field"); // these are not correct runtime mappings but they are counted compared to empty object searchSourceBuilder.runtimeMappings(Collections.singletonMap("field", "keyword")); - searchSourceBuilder.knnSearch(List.of(new KnnSearchBuilder("field", new float[] {}, 2, 5, null, null))); + searchSourceBuilder.knnSearch(List.of(new KnnSearchBuilder("field", new float[] {}, 2, 5, 10f, null, null))); searchSourceBuilder.pointInTimeBuilder(new PointInTimeBuilder(new BytesArray("pitid"))); searchSourceBuilder.docValueField("field"); searchSourceBuilder.storedField("field"); diff --git a/server/src/test/java/org/elasticsearch/search/dfs/DfsPhaseTests.java b/server/src/test/java/org/elasticsearch/search/dfs/DfsPhaseTests.java index dd2bcd7175976..ea5e7e6b7a488 100644 --- a/server/src/test/java/org/elasticsearch/search/dfs/DfsPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/search/dfs/DfsPhaseTests.java @@ -156,7 +156,7 @@ public DfsSearchResult dfsResult() { context.request() .source( new SearchSourceBuilder().knnSearch( - List.of(new KnnSearchBuilder("float_vector", new float[] { 0, 0, 0 }, numDocs, numDocs, null, null)) + List.of(new KnnSearchBuilder("float_vector", new float[] { 0, 0, 0 }, numDocs, numDocs, 100f, null, null)) ) ); context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); diff --git a/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java 
b/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java index 2724b86f9acd4..efe08cc62ab6e 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java @@ -52,6 +52,7 @@ public static KnnRetrieverBuilder createRandomKnnRetrieverBuilder() { float[] vector = randomVector(dim); int k = randomIntBetween(1, 100); int numCands = randomIntBetween(k + 20, 1000); + Float visitPercentage = randomBoolean() ? null : randomFloatBetween(0.0f, 100.0f, true); Float similarity = randomBoolean() ? null : randomFloat(); RescoreVectorBuilder rescoreVectorBuilder = randomBoolean() ? null @@ -63,6 +64,7 @@ public static KnnRetrieverBuilder createRandomKnnRetrieverBuilder() { null, k, numCands, + visitPercentage, rescoreVectorBuilder, similarity ); diff --git a/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java b/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java index 165ad9b2de183..772088296f76f 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java @@ -70,6 +70,7 @@ private List innerRetrievers(QueryRewriteContext queryRewriteC null, randomInt(10), randomIntBetween(10, 100), + randomBoolean() ? null : randomFloatBetween(0.0f, 100.0f, true), randomBoolean() ? 
null : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)), randomFloat() ); diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index a8d9b1259cb41..c6725e9710f78 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -89,6 +89,7 @@ abstract KnnVectorQueryBuilder createKnnVectorQueryBuilder( String fieldName, int k, int numCands, + Float visitPercentage, RescoreVectorBuilder rescoreVectorBuilder, Float similarity ); @@ -145,10 +146,12 @@ protected KnnVectorQueryBuilder doCreateTestQueryBuilder() { String fieldName = randomBoolean() ? VECTOR_FIELD : VECTOR_ALIAS_FIELD; int k = randomIntBetween(1, 100); int numCands = randomIntBetween(k + 20, 1000); + Float visitPercentage = randomBoolean() ? null : randomFloatBetween(0.0f, 100.0f, true); KnnVectorQueryBuilder queryBuilder = createKnnVectorQueryBuilder( fieldName, k, numCands, + visitPercentage, isIndextypeBBQ() ? 
randomBBQRescoreVectorBuilder() : randomRescoreVectorBuilder(), randomFloat() ); @@ -284,7 +287,7 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que public void testWrongDimension() { SearchExecutionContext context = createSearchExecutionContext(); - KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f }, 5, 10, null, null); + KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f }, 5, 10, 10f, null, null); IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> query.doToQuery(context)); assertThat( e.getMessage(), @@ -294,7 +297,7 @@ public void testWrongDimension() { public void testNonexistentField() { SearchExecutionContext context = createSearchExecutionContext(); - KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("nonexistent", new float[] { 1.0f, 1.0f, 1.0f }, 5, 10, null, null); + KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("nonexistent", new float[] { 1.0f, 1.0f, 1.0f }, 5, 10, 10f, null, null); context.setAllowUnmappedFields(false); QueryShardException e = expectThrows(QueryShardException.class, () -> query.doToQuery(context)); assertThat(e.getMessage(), containsString("No field mapping can be found for the field with name [nonexistent]")); @@ -302,7 +305,7 @@ public void testNonexistentField() { public void testNonexistentFieldReturnEmpty() throws IOException { SearchExecutionContext context = createSearchExecutionContext(); - KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("nonexistent", new float[] { 1.0f, 1.0f, 1.0f }, 5, 10, null, null); + KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("nonexistent", new float[] { 1.0f, 1.0f, 1.0f }, 5, 10, 10f, null, null); Query queryNone = query.doToQuery(context); assertThat(queryNone, instanceOf(MatchNoDocsQuery.class)); } @@ -314,6 +317,7 @@ public void testWrongFieldType() { new float[] { 1.0f, 1.0f, 1.0f }, 5, 10, + 10f, null, null ); 
@@ -326,14 +330,14 @@ public void testNumCandsLessThanK() { int numCands = 3; IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 1.0f, 1.0f }, k, numCands, null, null) + () -> new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 1.0f, 1.0f }, k, numCands, 10f, null, null) ); assertThat(e.getMessage(), containsString("[num_candidates] cannot be less than [k]")); } @Override public void testValidOutput() { - KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, null, 10, null, null); + KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, null, 10, 10f, null, null); String expected = """ { "knn" : { @@ -343,12 +347,13 @@ public void testValidOutput() { 2.0, 3.0 ], - "num_candidates" : 10 + "num_candidates" : 10, + "visit_percentage" : 10.0 } }"""; assertEquals(expected, query.toString()); - KnnVectorQueryBuilder query2 = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, 5, 10, null, null); + KnnVectorQueryBuilder query2 = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, 5, 10, 10f, null, null); String expected2 = """ { "knn" : { @@ -359,10 +364,27 @@ public void testValidOutput() { 3.0 ], "k" : 5, - "num_candidates" : 10 + "num_candidates" : 10, + "visit_percentage" : 10.0 } }"""; assertEquals(expected2, query2.toString()); + + KnnVectorQueryBuilder query3 = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, 5, 10, null, null, null); + String expected3 = """ + { + "knn" : { + "field" : "vector", + "query_vector" : [ + 1.0, + 2.0, + 3.0 + ], + "k" : 5, + "num_candidates" : 10 + } + }"""; + assertEquals(expected3, query3.toString()); } @Override @@ -376,6 +398,7 @@ public void testMustRewrite() throws IOException { vectorDimensions, null, null, + null, null ); query.addFilterQuery(termQuery); @@ -396,6 
+419,7 @@ public void testBWCVersionSerializationFilters() throws IOException { null, query.numCands(), null, + null, null ).queryName(query.queryName()).boost(query.boost()); TransportVersion beforeFilterVersion = TransportVersionUtils.randomVersionBetween( @@ -415,6 +439,7 @@ public void testBWCVersionSerializationSimilarity() throws IOException { null, query.numCands(), null, + null, null ).queryName(query.queryName()).boost(query.boost()).addFilterQueries(query.filterQueries()); assertBWCSerialization(query, queryNoSimilarity, TransportVersions.V_8_7_0); @@ -435,6 +460,7 @@ public void testBWCVersionSerializationQuery() throws IOException { null, query.numCands(), null, + null, similarity ).queryName(query.queryName()).boost(query.boost()).addFilterQueries(query.filterQueries()); assertBWCSerialization(query, queryOlderVersion, differentQueryVersion); @@ -457,6 +483,7 @@ public void testBWCVersionSerializationRescoreVector() throws IOException { k, query.numCands(), null, + null, query.getVectorSimilarity() ).queryName(query.queryName()).boost(query.boost()).addFilterQueries(query.filterQueries()); assertBWCSerialization(query, queryNoRescoreVector, version); @@ -510,6 +537,7 @@ public void testRewriteWithQueryVectorBuilder() throws Exception { new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(expectedArray), null, 5, + 10f, 1f ); knnVectorQueryBuilder.boost(randomFloat()); diff --git a/server/src/test/java/org/elasticsearch/search/vectors/DiversifyingParentBlockQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/DiversifyingParentBlockQueryTests.java index 8ff81cda6e8a0..4d9417c402700 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/DiversifyingParentBlockQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/DiversifyingParentBlockQueryTests.java @@ -116,6 +116,7 @@ public void testRandom() throws IOException { VectorData.fromFloats(queries[i]), 10, 10, + 10f, null, null, null, diff --git 
a/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java index 26066389c63f1..b473a83ffd3eb 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java @@ -22,6 +22,7 @@ protected KnnVectorQueryBuilder createKnnVectorQueryBuilder( String fieldName, int k, int numCands, + Float visitPercentage, RescoreVectorBuilder rescoreVectorBuilder, Float similarity ) { @@ -29,7 +30,7 @@ protected KnnVectorQueryBuilder createKnnVectorQueryBuilder( for (int i = 0; i < vector.length; i++) { vector[i] = randomByte(); } - return new KnnVectorQueryBuilder(fieldName, vector, k, numCands, rescoreVectorBuilder, similarity); + return new KnnVectorQueryBuilder(fieldName, vector, k, numCands, visitPercentage, rescoreVectorBuilder, similarity); } @Override diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java index 70d29ab525ef1..4544cb260e38c 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java @@ -22,6 +22,7 @@ KnnVectorQueryBuilder createKnnVectorQueryBuilder( String fieldName, int k, int numCands, + Float visitPercentage, RescoreVectorBuilder rescoreVectorBuilder, Float similarity ) { @@ -29,7 +30,7 @@ KnnVectorQueryBuilder createKnnVectorQueryBuilder( for (int i = 0; i < vector.length; i++) { vector[i] = randomFloat(); } - return new KnnVectorQueryBuilder(fieldName, vector, k, numCands, rescoreVectorBuilder, similarity); + return new KnnVectorQueryBuilder(fieldName, vector, k, numCands, visitPercentage, rescoreVectorBuilder, similarity); } 
@Override diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java index 33ab8324ffb96..328b295f1f220 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java @@ -52,6 +52,7 @@ public static KnnSearchBuilder randomTestInstance() { float[] vector = randomVector(dim); int k = randomIntBetween(1, 100); int numCands = randomIntBetween(k + 20, 1000); + Float visitPercentage = randomBoolean() ? null : randomFloatBetween(0.0f, 100.0f, true); RescoreVectorBuilder rescoreVectorBuilder = randomBoolean() ? null : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)); @@ -61,6 +62,7 @@ public static KnnSearchBuilder randomTestInstance() { vector, k, numCands, + visitPercentage, rescoreVectorBuilder, randomBoolean() ? null : randomFloat() ); @@ -110,7 +112,7 @@ protected KnnSearchBuilder createTestInstance() { @Override protected KnnSearchBuilder mutateInstance(KnnSearchBuilder instance) { - return switch (random().nextInt(8)) { + return switch (random().nextInt(9)) { case 0 -> { String newField = randomValueOtherThan(instance.field, () -> randomAlphaOfLength(5)); yield new KnnSearchBuilder( @@ -118,6 +120,7 @@ yield new KnnSearchBuilder( instance.queryVector, instance.k, instance.numCands, + instance.visitPercentage, instance.getRescoreVectorBuilder(), instance.similarity ).boost(instance.boost); @@ -129,6 +132,7 @@ yield new KnnSearchBuilder( newVector, instance.k, instance.numCands, + instance.visitPercentage, instance.getRescoreVectorBuilder(), instance.similarity ).boost(instance.boost); @@ -141,6 +145,7 @@ yield new KnnSearchBuilder( instance.queryVector, newK, instance.numCands, + instance.visitPercentage, instance.getRescoreVectorBuilder(), instance.similarity ).boost(instance.boost); @@ -152,6 +157,7 @@ yield new 
KnnSearchBuilder( instance.queryVector, instance.k, newNumCands, + instance.visitPercentage, instance.getRescoreVectorBuilder(), instance.similarity ).boost(instance.boost); @@ -161,6 +167,7 @@ yield new KnnSearchBuilder( instance.queryVector, instance.k, instance.numCands, + instance.visitPercentage, instance.getRescoreVectorBuilder(), instance.similarity ).addFilterQueries(instance.filterQueries) @@ -173,6 +180,7 @@ yield new KnnSearchBuilder( instance.queryVector, instance.k, instance.numCands, + instance.visitPercentage, instance.getRescoreVectorBuilder(), instance.similarity ).addFilterQueries(instance.filterQueries).boost(newBoost); @@ -182,6 +190,7 @@ yield new KnnSearchBuilder( instance.queryVector, instance.k, instance.numCands, + instance.visitPercentage, instance.getRescoreVectorBuilder(), randomValueOtherThan(instance.similarity, ESTestCase::randomFloat) ).addFilterQueries(instance.filterQueries).boost(instance.boost); @@ -190,12 +199,28 @@ yield new KnnSearchBuilder( instance.queryVector, instance.k, instance.numCands, + instance.visitPercentage, randomValueOtherThan( instance.getRescoreVectorBuilder(), () -> new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)) ), instance.similarity ).addFilterQueries(instance.filterQueries).boost(instance.boost); + case 8 -> { + Float newVisitPercentage = randomValueOtherThan( + instance.visitPercentage, + () -> ESTestCase.randomFloatBetween(0f, 100f, true) + ); + yield new KnnSearchBuilder( + instance.field, + instance.queryVector, + instance.k, + instance.numCands, + newVisitPercentage, + instance.getRescoreVectorBuilder(), + instance.similarity + ).boost(instance.boost); + } default -> throw new IllegalStateException(); }; } @@ -205,11 +230,12 @@ public void testToQueryBuilder() { float[] vector = randomVector(randomIntBetween(2, 30)); int k = randomIntBetween(1, 100); int numCands = randomIntBetween(k, 1000); + Float visitPercentage = randomBoolean() ? 
null : randomFloatBetween(0.0f, 100.0f, true); Float similarity = randomBoolean() ? null : randomFloat(); RescoreVectorBuilder rescoreVectorBuilder = randomBoolean() ? null : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)); - KnnSearchBuilder builder = new KnnSearchBuilder(field, vector, k, numCands, rescoreVectorBuilder, similarity); + KnnSearchBuilder builder = new KnnSearchBuilder(field, vector, k, numCands, visitPercentage, rescoreVectorBuilder, similarity); float boost = AbstractQueryBuilder.DEFAULT_BOOST; if (randomBoolean()) { @@ -225,16 +251,22 @@ public void testToQueryBuilder() { builder.addFilterQuery(filter); } - QueryBuilder expected = new KnnVectorQueryBuilder(field, vector, numCands, numCands, rescoreVectorBuilder, similarity) - .addFilterQueries(filterQueries) - .boost(boost); + QueryBuilder expected = new KnnVectorQueryBuilder( + field, + vector, + numCands, + numCands, + visitPercentage, + rescoreVectorBuilder, + similarity + ).addFilterQueries(filterQueries).boost(boost); assertEquals(expected, builder.toQueryBuilder()); } public void testNumCandsLessThanK() { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> new KnnSearchBuilder("field", randomVector(3), 50, 10, null, null) + () -> new KnnSearchBuilder("field", randomVector(3), 50, 10, 10f, null, null) ); assertThat(e.getMessage(), containsString("[num_candidates] cannot be less than [k]")); } @@ -242,15 +274,31 @@ public void testNumCandsLessThanK() { public void testNumCandsExceedsLimit() { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> new KnnSearchBuilder("field", randomVector(3), 100, 10002, null, null) + () -> new KnnSearchBuilder("field", randomVector(3), 100, 10002, 10f, null, null) ); assertThat(e.getMessage(), containsString("[num_candidates] cannot exceed [10000]")); } + public void testVisitPercentageLessThan0() { + IllegalArgumentException e = expectThrows( + IllegalArgumentException.class, + () -> 
new KnnSearchBuilder("field", randomVector(3), 50, 100, -190f, null, null) + ); + assertThat(e.getMessage(), containsString("[visit_percentage] must be between 0 and 100")); + } + + public void testVisitPercentageGreaterThan100() { + IllegalArgumentException e = expectThrows( + IllegalArgumentException.class, + () -> new KnnSearchBuilder("field", randomVector(3), 100, 1000, 100000f, null, null) + ); + assertThat(e.getMessage(), containsString("[visit_percentage] must be between 0 and 100")); + } + public void testInvalidK() { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> new KnnSearchBuilder("field", randomVector(3), 0, 100, null, null) + () -> new KnnSearchBuilder("field", randomVector(3), 0, 100, 10f, null, null) ); assertThat(e.getMessage(), containsString("[k] must be greater than 0")); } @@ -258,7 +306,7 @@ public void testInvalidK() { public void testInvalidRescoreVectorBuilder() { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> new KnnSearchBuilder("field", randomVector(3), 10, 100, new RescoreVectorBuilder(0.99F), null) + () -> new KnnSearchBuilder("field", randomVector(3), 10, 100, 10f, new RescoreVectorBuilder(0.99F), null) ); assertThat(e.getMessage(), containsString("[oversample] must be >= 1.0")); } @@ -271,6 +319,7 @@ public void testRewrite() throws Exception { new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(expectedArray), 5, 10, + 10f, expectedRescore, 1f ); diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchRequestParserTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchRequestParserTests.java index 4e4d2158a9574..38b5bd8b4b475 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchRequestParserTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchRequestParserTests.java @@ -110,9 +110,12 @@ public void testParseSourceString() throws IOException { 
.startObject(KnnSearchRequestParser.KNN_SECTION_FIELD.getPreferredName()) .field(KnnSearch.FIELD_FIELD.getPreferredName(), knnSearch.field) .field(KnnSearch.K_FIELD.getPreferredName(), knnSearch.k) - .field(KnnSearch.NUM_CANDS_FIELD.getPreferredName(), knnSearch.numCands) - .field(KnnSearch.QUERY_VECTOR_FIELD.getPreferredName(), knnSearch.queryVector) - .endObject(); + .field(KnnSearch.NUM_CANDS_FIELD.getPreferredName(), knnSearch.numCands); + if (knnSearch.visitPercentage != null) { + builder.field(KnnSearch.VISIT_PERCENTAGE_FIELD.getPreferredName(), knnSearch.visitPercentage); + } + builder.field(KnnSearch.QUERY_VECTOR_FIELD.getPreferredName(), knnSearch.queryVector); + builder.endObject(); builder.field(SearchSourceBuilder._SOURCE_FIELD.getPreferredName(), "some-field"); builder.endObject(); @@ -136,9 +139,12 @@ public void testParseSourceArray() throws IOException { .startObject(KnnSearchRequestParser.KNN_SECTION_FIELD.getPreferredName()) .field(KnnSearch.FIELD_FIELD.getPreferredName(), knnSearch.field) .field(KnnSearch.K_FIELD.getPreferredName(), knnSearch.k) - .field(KnnSearch.NUM_CANDS_FIELD.getPreferredName(), knnSearch.numCands) - .field(KnnSearch.QUERY_VECTOR_FIELD.getPreferredName(), knnSearch.queryVector) - .endObject(); + .field(KnnSearch.NUM_CANDS_FIELD.getPreferredName(), knnSearch.numCands); + if (knnSearch.visitPercentage != null) { + builder.field(KnnSearch.VISIT_PERCENTAGE_FIELD.getPreferredName(), knnSearch.visitPercentage); + } + builder.field(KnnSearch.QUERY_VECTOR_FIELD.getPreferredName(), knnSearch.queryVector); + builder.endObject(); builder.array(SearchSourceBuilder._SOURCE_FIELD.getPreferredName(), "field1", "field2", "field3"); builder.endObject(); @@ -171,6 +177,7 @@ public void testNumCandsLessThanK() throws IOException { .field(KnnSearch.FIELD_FIELD.getPreferredName(), "field") .field(KnnSearch.K_FIELD.getPreferredName(), 100) .field(KnnSearch.NUM_CANDS_FIELD.getPreferredName(), 80) + 
.field(KnnSearch.VISIT_PERCENTAGE_FIELD.getPreferredName(), 100.0f) .field(KnnSearch.QUERY_VECTOR_FIELD.getPreferredName(), new float[] { 1.0f, 2.0f, 3.0f }) .endObject() .endObject(); @@ -187,6 +194,7 @@ public void testNumCandsExceedsLimit() throws IOException { .field(KnnSearch.FIELD_FIELD.getPreferredName(), "field") .field(KnnSearch.K_FIELD.getPreferredName(), 100) .field(KnnSearch.NUM_CANDS_FIELD.getPreferredName(), 10002) + .field(KnnSearch.VISIT_PERCENTAGE_FIELD.getPreferredName(), 100.0f) .field(KnnSearch.QUERY_VECTOR_FIELD.getPreferredName(), new float[] { 1.0f, 2.0f, 3.0f }) .endObject() .endObject(); @@ -195,6 +203,40 @@ public void testNumCandsExceedsLimit() throws IOException { assertThat(e.getMessage(), containsString("[num_candidates] cannot exceed [10000]")); } + public void testVisitPercentageLessThan0() throws IOException { + XContentType xContentType = randomFrom(XContentType.values()); + XContentBuilder builder = XContentBuilder.builder(xContentType.xContent()) + .startObject() + .startObject(KnnSearchRequestParser.KNN_SECTION_FIELD.getPreferredName()) + .field(KnnSearch.FIELD_FIELD.getPreferredName(), "field") + .field(KnnSearch.K_FIELD.getPreferredName(), 100) + .field(KnnSearch.NUM_CANDS_FIELD.getPreferredName(), 1000) + .field(KnnSearch.VISIT_PERCENTAGE_FIELD.getPreferredName(), -100f) + .field(KnnSearch.QUERY_VECTOR_FIELD.getPreferredName(), new float[] { 1.0f, 2.0f, 3.0f }) + .endObject() + .endObject(); + + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> parseSearchRequest(builder)); + assertThat(e.getMessage(), containsString("[visit_percentage] must be between 0 and 100")); + } + + public void testVisitPercentageGreaterThan100() throws IOException { + XContentType xContentType = randomFrom(XContentType.values()); + XContentBuilder builder = XContentBuilder.builder(xContentType.xContent()) + .startObject() + .startObject(KnnSearchRequestParser.KNN_SECTION_FIELD.getPreferredName()) +
.field(KnnSearch.FIELD_FIELD.getPreferredName(), "field") + .field(KnnSearch.K_FIELD.getPreferredName(), 100) + .field(KnnSearch.NUM_CANDS_FIELD.getPreferredName(), 1000) + .field(KnnSearch.VISIT_PERCENTAGE_FIELD.getPreferredName(), 1000f) + .field(KnnSearch.QUERY_VECTOR_FIELD.getPreferredName(), new float[] { 1.0f, 2.0f, 3.0f }) + .endObject() + .endObject(); + + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> parseSearchRequest(builder)); + assertThat(e.getMessage(), containsString("[visit_percentage] must be between 0 and 100")); + } + public void testInvalidK() throws IOException { XContentType xContentType = randomFrom(XContentType.values()); XContentBuilder builder = XContentBuilder.builder(xContentType.xContent()) @@ -203,6 +245,7 @@ public void testInvalidK() throws IOException { .field(KnnSearch.FIELD_FIELD.getPreferredName(), "field") .field(KnnSearch.K_FIELD.getPreferredName(), 0) .field(KnnSearch.NUM_CANDS_FIELD.getPreferredName(), 10) + .field(KnnSearch.VISIT_PERCENTAGE_FIELD.getPreferredName(), 100.0f) .field(KnnSearch.QUERY_VECTOR_FIELD.getPreferredName(), new float[] { 1.0f, 2.0f, 3.0f }) .endObject() .endObject(); @@ -238,7 +281,8 @@ private KnnSearch randomKnnSearch() { int k = randomIntBetween(1, 100); int numCands = randomIntBetween(k, 1000); - return new KnnSearch(field, vector, k, numCands); + Float visitPercentage = randomBoolean() ? 
null : randomFloatBetween(0.0f, 100.0f, true); + return new KnnSearch(field, vector, k, numCands, visitPercentage); } private List randomFilterQueries() { @@ -260,9 +304,12 @@ private XContentBuilder createRequestBody(KnnSearch knnSearch, List planStr = new AtomicReference<>(); plan.forEachDown(EsQueryExec.class, result -> planStr.set(result.query().toString())); - var expectedQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 0.1f, 0.2f, 0.3f }, 10, null, null, null); + var expectedQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 0.1f, 0.2f, 0.3f }, 10, null, null, null, null); assertEquals(expectedQuery.toString(), planStr.get()); } @@ -1443,7 +1444,7 @@ public void testKnnKAndMinCandidatesLowerK() { AtomicReference planStr = new AtomicReference<>(); plan.forEachDown(EsQueryExec.class, result -> planStr.set(result.query().toString())); - var expectedQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 0.1f, 0.2f, 0.3f }, 50, 50, null, null); + var expectedQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 0.1f, 0.2f, 0.3f }, 50, 50, null, null, null); assertEquals(expectedQuery.toString(), planStr.get()); } @@ -1462,7 +1463,7 @@ public void testKnnKAndMinCandidatesHigherK() { AtomicReference planStr = new AtomicReference<>(); plan.forEachDown(EsQueryExec.class, result -> planStr.set(result.query().toString())); - var expectedQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 0.1f, 0.2f, 0.3f }, 50, 50, null, null); + var expectedQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 0.1f, 0.2f, 0.3f }, 50, 50, null, null, null); assertEquals(expectedQuery.toString(), planStr.get()); } @@ -1933,6 +1934,7 @@ public void testKnnPrefilters() { 1000, null, null, + null, null ).addFilterQuery(expectedFilterQueryBuilder); var expectedQuery = boolQuery().must(expectedKnnQueryBuilder).must(expectedFilterQueryBuilder); @@ -1969,6 +1971,7 @@ public void testKnnPrefiltersWithMultipleFilters() { 1000, 
null, null, + null, null ).addFilterQuery(expectedFilterQueryBuilder); var expectedQuery = boolQuery().must(expectedKnnQueryBuilder).must(integerFilter).must(keywordFilter); @@ -2004,6 +2007,7 @@ public void testPushDownConjunctionsToKnnPrefilter() { 1000, null, null, + null, null ).addFilterQuery(expectedFilterQueryBuilder); @@ -2041,6 +2045,7 @@ public void testPushDownNegatedConjunctionsToKnnPrefilter() { 1000, null, null, + null, null ).addFilterQuery(expectedFilterQueryBuilder); @@ -2065,7 +2070,15 @@ public void testNotPushDownDisjunctionsToKnnPrefilter() { var queryExec = as(field.child(), EsQueryExec.class); // The disjunction should not be pushed down to the KNN query - KnnVectorQueryBuilder knnQueryBuilder = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 1000, null, null, null); + KnnVectorQueryBuilder knnQueryBuilder = new KnnVectorQueryBuilder( + "dense_vector", + new float[] { 0, 1, 2 }, + 1000, + null, + null, + null, + null + ); QueryBuilder rangeQueryBuilder = wrapWithSingleQuery( query, unscore(rangeQuery("integer").gt(10)), @@ -2148,8 +2161,8 @@ and NOT ((keyword == "test") or knn(dense_vector, [4, 5, 6]))) new Source(2, 42, "NOT integer > 10") ); - KnnVectorQueryBuilder firstKnn = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 1000, null, null, null); - KnnVectorQueryBuilder secondKnn = new KnnVectorQueryBuilder("dense_vector", new float[] { 4, 5, 6 }, 1000, null, null, null); + KnnVectorQueryBuilder firstKnn = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 1000, null, null, null, null); + KnnVectorQueryBuilder secondKnn = new KnnVectorQueryBuilder("dense_vector", new float[] { 4, 5, 6 }, 1000, null, null, null, null); firstKnn.addFilterQuery(notKeywordFilter); secondKnn.addFilterQuery(notIntegerGt10); @@ -2177,7 +2190,15 @@ public void testMultipleKnnQueriesInPrefilters() { var field = as(project.child(), FieldExtractExec.class); var queryExec = as(field.child(), EsQueryExec.class); - 
KnnVectorQueryBuilder firstKnnQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 1000, null, null, null); + KnnVectorQueryBuilder firstKnnQuery = new KnnVectorQueryBuilder( + "dense_vector", + new float[] { 0, 1, 2 }, + 1000, + null, + null, + null, + null + ); // Integer range query (right side of first OR) QueryBuilder integerRangeQuery = wrapWithSingleQuery( query, @@ -2187,7 +2208,15 @@ public void testMultipleKnnQueriesInPrefilters() { ); // Second KNN query (right side of second OR) - KnnVectorQueryBuilder secondKnnQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 4, 5, 6 }, 1000, null, null, null); + KnnVectorQueryBuilder secondKnnQuery = new KnnVectorQueryBuilder( + "dense_vector", + new float[] { 4, 5, 6 }, + 1000, + null, + null, + null, + null + ); // Keyword term query (left side of second OR) QueryBuilder keywordQuery = wrapWithSingleQuery( @@ -2726,7 +2755,7 @@ private static Object randomVector() { @Override public QueryBuilder queryBuilder() { - return new KnnVectorQueryBuilder(fieldName(), (float[]) queryString(), k, null, null, null); + return new KnnVectorQueryBuilder(fieldName(), (float[]) queryString(), k, null, null, null, null); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 50ca0e61eaeb8..5b35c0384ce99 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -970,7 +970,7 @@ yield new SparseVectorQueryBuilder( k = Math.max(k, DEFAULT_SIZE); } - yield new KnnVectorQueryBuilder(inferenceResultsFieldName, inference, k, null, null, null); + yield new KnnVectorQueryBuilder(inferenceResultsFieldName, inference, k, null, null, null, 
null); } default -> throw new IllegalStateException( "Field [" diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java index b1f5c240371f8..afa2b0182a0a4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java @@ -141,6 +141,7 @@ private KnnVectorQueryBuilder addIndexFilterToKnnVectorQuery(Collection original.queryVectorBuilder(), original.k(), original.numCands(), + original.visitPercentage(), original.getVectorSimilarity() ); } else { @@ -149,6 +150,7 @@ private KnnVectorQueryBuilder addIndexFilterToKnnVectorQuery(Collection original.queryVector(), original.k(), original.numCands(), + original.visitPercentage(), original.rescoreVectorBuilder(), original.getVectorSimilarity() ); @@ -180,6 +182,7 @@ private KnnVectorQueryBuilder buildNewKnnVectorQuery( queryVectorBuilder, original.k(), original.numCands(), + original.visitPercentage(), original.getVectorSimilarity() ); } else { @@ -188,6 +191,7 @@ private KnnVectorQueryBuilder buildNewKnnVectorQuery( original.queryVector(), original.k(), original.numCands(), + original.visitPercentage(), original.rescoreVectorBuilder(), original.getVectorSimilarity() ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java index 1f0b56e3d6848..751f5280682bc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java @@ -60,7 +60,7 @@ public void testKnnQueryWithVectorBuilderIsInterceptedAndRewritten() throws IOEx ); QueryRewriteContext context = createQueryRewriteContext(inferenceFields); QueryVectorBuilder queryVectorBuilder = new TextEmbeddingQueryVectorBuilder(INFERENCE_ID, QUERY); - KnnVectorQueryBuilder original = new KnnVectorQueryBuilder(FIELD_NAME, queryVectorBuilder, 10, 100, null); + KnnVectorQueryBuilder original = new KnnVectorQueryBuilder(FIELD_NAME, queryVectorBuilder, 10, 100, 10f, null); if (randomBoolean()) { float boost = randomFloatBetween(1, 10, randomBoolean()); original.boost(boost); @@ -79,7 +79,7 @@ public void testKnnWithQueryBuilderWithoutInferenceIdIsInterceptedAndRewritten() ); QueryRewriteContext context = createQueryRewriteContext(inferenceFields); QueryVectorBuilder queryVectorBuilder = new TextEmbeddingQueryVectorBuilder(null, QUERY); - KnnVectorQueryBuilder original = new KnnVectorQueryBuilder(FIELD_NAME, queryVectorBuilder, 10, 100, null); + KnnVectorQueryBuilder original = new KnnVectorQueryBuilder(FIELD_NAME, queryVectorBuilder, 10, 100, 10f, null); if (randomBoolean()) { float boost = randomFloatBetween(1, 10, randomBoolean()); original.boost(boost); @@ -124,7 +124,7 @@ private void testRewrittenInferenceQuery(QueryRewriteContext context, KnnVectorQ public void testKnnVectorQueryOnNonInferenceFieldRemainsUnchanged() throws IOException { QueryRewriteContext context = createQueryRewriteContext(Map.of()); // No inference fields QueryVectorBuilder queryVectorBuilder = new TextEmbeddingQueryVectorBuilder(null, QUERY); - QueryBuilder original = new KnnVectorQueryBuilder(FIELD_NAME, queryVectorBuilder, 10, 100, null); + QueryBuilder original = new KnnVectorQueryBuilder(FIELD_NAME, queryVectorBuilder, 10, 100, 10f, null); QueryBuilder rewritten = original.rewrite(context); assertTrue( "Expected query to remain knn but was [" + 
rewritten.getClass().getName() + "]", diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighterTests.java index 67c6d8da52a88..8284907a1873c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighterTests.java @@ -99,7 +99,15 @@ public void testDenseVector() throws Exception { Map queryMap = (Map) queries.get("dense_vector_1"); float[] vector = readDenseVector(queryMap.get("embeddings")); var fieldType = (SemanticTextFieldMapper.SemanticTextFieldType) mapperService.mappingLookup().getFieldType(SEMANTIC_FIELD_E5); - KnnVectorQueryBuilder knnQuery = new KnnVectorQueryBuilder(fieldType.getEmbeddingsField().fullPath(), vector, 10, 10, null, null); + KnnVectorQueryBuilder knnQuery = new KnnVectorQueryBuilder( + fieldType.getEmbeddingsField().fullPath(), + vector, + 10, + 10, + 10f, + null, + null + ); NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(fieldType.getChunksField().fullPath(), knnQuery, ScoreMode.Max); var shardRequest = createShardSearchRequest(nestedQueryBuilder); var sourceToParse = new SourceToParse("0", readSampleDoc(useLegacyFormat), XContentType.JSON); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java index 7f6bc6117561b..32593ccd622db 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java @@ -101,7 +101,7 @@ public void testTelemetryForRRFRetriever() throws IOException { // search#1 - this will record 1 entry for "retriever" in `sections`, and 1 for "knn" under `retrievers` { performSearch( - new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null, null)) + new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, 10f, null, null)) ); } @@ -116,7 +116,7 @@ public void testTelemetryForRRFRetriever() throws IOException { { performSearch( new SearchSourceBuilder().retriever( - new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, null, null)) + new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, 10f, null, null)) ) ); } @@ -149,7 +149,7 @@ public void testTelemetryForRRFRetriever() throws IOException { // search#6 - this will record 1 entry for "knn" in `sections` { performSearch( - new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null, null))) + new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, 10f, null, null))) ); } diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java index b00af1713dcb6..86e835c21efe1 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java @@ -175,7 +175,16 @@ public void testLinearRetrieverWithAggs() { ); 
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( + VECTOR_FIELD, + new float[] { 2.0f }, + null, + 10, + 100, + null, + null, + null + ); // all requests would have an equal weight and use the identity normalizer source.retriever( @@ -233,7 +242,16 @@ public void testLinearWithCollapse() { standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 // with scores 1, 0.5, 0.05882353, 0.03846154 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( + VECTOR_FIELD, + new float[] { 2.0f }, + null, + 10, + 100, + null, + null, + null + ); // final ranking with no-normalizer would be: doc 2, 6, 1, 4, 7, 3 // doc 1: 10 // doc 2: 9 + 20 + 1 = 30 @@ -302,7 +320,16 @@ public void testLinearRetrieverWithCollapseAndAggs() { standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 // with scores 1, 0.5, 0.05882353, 0.03846154 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( + VECTOR_FIELD, + new float[] { 2.0f }, + null, + 10, + 100, + null, + null, + null + ); // final ranking with no-normalizer would be: doc 2, 6, 1, 4, 7, 3 // doc 1: 10 // doc 2: 9 + 20 + 1 = 30 @@ -393,7 +420,7 @@ public void testMultipleLinearRetrievers() { ), // this one bring just doc 7 which should be ranked first eventually 
with a score of 100 new CompoundRetrieverBuilder.RetrieverSource( - new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 7.0f }, null, 1, 100, null, null), + new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 7.0f }, null, 1, 100, null, null, null), null ) ), @@ -447,7 +474,16 @@ public void testLinearExplainWithNamedRetrievers() { standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 // with scores 1, 0.5, 0.05882353, 0.03846154 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( + VECTOR_FIELD, + new float[] { 2.0f }, + null, + 10, + 100, + null, + null, + null + ); // final ranking with no-normalizer would be: doc 2, 6, 1, 4, 7, 3 // doc 1: 10 // doc 2: 9 + 20 + 1 = 30 @@ -537,7 +573,16 @@ public void testLinearExplainWithAnotherNestedLinear() { standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 // with scores 1, 0.5, 0.05882353, 0.03846154 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( + VECTOR_FIELD, + new float[] { 2.0f }, + null, + 10, + 100, + null, + null, + null + ); // final ranking with no-normalizer would be: doc 2, 6, 1, 4, 7, 3 // doc 1: 10 // doc 2: 9 + 20 + 1 = 30 @@ -764,6 +809,7 @@ public void testLinearFiltersPropagatedToKnnQueryVectorBuilder() { 10, 10, null, + null, null ); source.retriever( @@ -816,8 +862,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws throw new IllegalStateException("Should not be called"); } }; - var knn = new KnnRetrieverBuilder("vector", null, vectorBuilder, 10, 10, null, null); - var 
standard = new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", vectorBuilder, 10, 10, null)); + var knn = new KnnRetrieverBuilder("vector", null, vectorBuilder, 10, 10, null, null, null); + var standard = new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", vectorBuilder, 10, 10, 10f, null)); var rrf = new LinearRetrieverBuilder( List.of(new CompoundRetrieverBuilder.RetrieverSource(knn, null), new CompoundRetrieverBuilder.RetrieverSource(standard, null)), 10 diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankMultiShardIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankMultiShardIT.java index 457c57410d168..016507ac9c369 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankMultiShardIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankMultiShardIT.java @@ -136,7 +136,7 @@ public void setupSuiteScopeCluster() throws Exception { public void testTotalDocsSmallerThanSize() { float[] queryVector = { 0.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 3, 3, null, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 3, 3, 10f, null, null); assertResponse( prepareSearch("tiny_index").setRankBuilder(new RRFRankBuilder(100, 1)) .setKnnSearch(List.of(knnSearch)) @@ -167,7 +167,7 @@ public void testTotalDocsSmallerThanSize() { public void testBM25AndKnn() { float[] queryVector = { 500.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, 10f, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(101, 1)) .setTrackTotalHits(false) @@ -208,8 +208,8 @@ public void testBM25AndKnn() { public void 
testMultipleOnlyKnn() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, 10f, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, 10f, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(51, 1)) .setTrackTotalHits(true) @@ -260,8 +260,8 @@ public void testMultipleOnlyKnn() { public void testBM25AndMultipleKnn() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, 10f, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, 10f, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(51, 1)) .setTrackTotalHits(false) @@ -332,7 +332,7 @@ public void testBM25AndMultipleKnn() { public void testBM25AndKnnWithBucketAggregation() { float[] queryVector = { 500.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, 10f, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(101, 1)) .setTrackTotalHits(true) @@ -389,8 +389,8 @@ public void testBM25AndKnnWithBucketAggregation() { public void testMultipleOnlyKnnWithAggregation() { float[] queryVectorAsc 
= { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, 10f, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, 10f, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(51, 1)) .setTrackTotalHits(false) @@ -457,8 +457,8 @@ public void testMultipleOnlyKnnWithAggregation() { public void testBM25AndMultipleKnnWithAggregation() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, 10f, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, 10f, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(51, 1)) .setTrackTotalHits(true) @@ -704,7 +704,7 @@ public void testMultiBM25WithAggregation() { public void testMultiBM25AndSingleKnn() { float[] queryVector = { 500.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, 10f, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(101, 1)) .setTrackTotalHits(false) @@ -762,7 +762,7 @@ public void testMultiBM25AndSingleKnn() { public void testMultiBM25AndSingleKnnWithAggregation() { float[] queryVector = { 500.0f }; - 
KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, 10f, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(101, 1)) .setTrackTotalHits(false) @@ -837,8 +837,8 @@ public void testMultiBM25AndSingleKnnWithAggregation() { public void testMultiBM25AndMultipleKnn() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 101, 1001, null, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 101, 1001, null, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 101, 1001, 10f, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 101, 1001, 10f, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(101, 1)) .setTrackTotalHits(false) @@ -899,8 +899,8 @@ public void testMultiBM25AndMultipleKnn() { public void testMultiBM25AndMultipleKnnWithAggregation() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 101, 1001, null, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 101, 1001, null, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 101, 1001, 10f, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 101, 1001, 10f, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(101, 1)) .setTrackTotalHits(false) @@ -979,7 +979,7 @@ public void testBasicRRFExplain() { // the first result should be the one present in both queries (i.e. 
doc with text0: 10 and vector: [10]) and the other ones // should only match the knn query float[] queryVector = { 9f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null).queryName("my_knn_search"); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, 10f, null, null).queryName("my_knn_search"); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(100, 1)) .setKnnSearch(List.of(knnSearch)) @@ -1045,7 +1045,7 @@ public void testRRFExplainUnknownField() { // in this test we try knn with a query on an unknown field that would be rewritten to MatchNoneQuery // so we expect results and explanations only for the first part float[] queryVector = { 9f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null).queryName("my_knn_search"); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, 10f, null, null).queryName("my_knn_search"); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(100, 1)) .setKnnSearch(List.of(knnSearch)) @@ -1112,7 +1112,7 @@ public void testRRFExplainOneUnknownFieldSubSearches() { // while the other one would produce a match. 
// So, we'd have a total of 3 queries, a (rewritten) MatchNoneQuery, a TermQuery, and a kNN query float[] queryVector = { 9f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null).queryName("my_knn_search"); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, 10f, null, null).queryName("my_knn_search"); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(100, 1)) .setKnnSearch(List.of(knnSearch)) diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankSingleShardIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankSingleShardIT.java index a4e7db3b3e3fe..364a241a643a4 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankSingleShardIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankSingleShardIT.java @@ -131,7 +131,7 @@ public void setupIndices() throws Exception { public void testTotalDocsSmallerThanSize() { float[] queryVector = { 0.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 3, 3, null, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 3, 3, 10f, null, null); assertResponse( client().prepareSearch("tiny_index") @@ -164,7 +164,7 @@ public void testTotalDocsSmallerThanSize() { public void testBM25AndKnn() { float[] queryVector = { 500.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, 10f, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(101, 1)) @@ -206,8 +206,8 @@ public void testBM25AndKnn() { public void testMultipleOnlyKnn() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc 
= { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, 10f, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, 10f, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(51, 1)) @@ -259,8 +259,8 @@ public void testMultipleOnlyKnn() { public void testBM25AndMultipleKnn() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, 10f, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, 10f, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(51, 1)) @@ -332,7 +332,7 @@ public void testBM25AndMultipleKnn() { public void testBM25AndKnnWithBucketAggregation() { float[] queryVector = { 500.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, 10f, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(101, 1)) @@ -390,8 +390,8 @@ public void testBM25AndKnnWithBucketAggregation() { public void testMultipleOnlyKnnWithAggregation() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", 
queryVectorAsc, 51, 1001, null, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, 10f, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, 10f, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(51, 1)) @@ -459,8 +459,8 @@ public void testMultipleOnlyKnnWithAggregation() { public void testBM25AndMultipleKnnWithAggregation() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, 10f, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, 10f, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(51, 1)) @@ -709,7 +709,7 @@ public void testMultiBM25WithAggregation() { public void testMultiBM25AndSingleKnn() { float[] queryVector = { 500.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, 10f, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(101, 1)) @@ -768,7 +768,7 @@ public void testMultiBM25AndSingleKnn() { public void testMultiBM25AndSingleKnnWithAggregation() { float[] queryVector = { 500.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 
101, 1001, 10f, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(101, 1)) @@ -844,8 +844,8 @@ public void testMultiBM25AndSingleKnnWithAggregation() { public void testMultiBM25AndMultipleKnn() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 101, 1001, null, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 101, 1001, null, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 101, 1001, 10f, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 101, 1001, 10f, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(101, 1)) @@ -907,8 +907,8 @@ public void testMultiBM25AndMultipleKnn() { public void testMultiBM25AndMultipleKnnWithAggregation() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 101, 1001, null, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 101, 1001, null, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 101, 1001, 10f, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 101, 1001, 10f, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(101, 1)) diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java index 6854fc436038f..1a5cd6bca3059 100644 --- 
a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java @@ -190,6 +190,7 @@ public void testRRFPagination() { 10, 100, null, + null, null ); source.retriever( @@ -241,7 +242,16 @@ public void testRRFWithAggs() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( + VECTOR_FIELD, + new float[] { 2.0f }, + null, + 10, + 100, + null, + null, + null + ); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -296,7 +306,16 @@ public void testRRFWithCollapse() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( + VECTOR_FIELD, + new float[] { 2.0f }, + null, + 10, + 100, + null, + null, + null + ); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -353,7 +372,16 @@ public void testRRFRetrieverWithCollapseAndAggs() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( + VECTOR_FIELD, + new float[] { 2.0f }, + null, + 10, + 100, + null, + null, + null + ); source.retriever( new 
RRFRetrieverBuilder( Arrays.asList( @@ -419,7 +447,16 @@ public void testMultipleRRFRetrievers() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( + VECTOR_FIELD, + new float[] { 2.0f }, + null, + 10, + 100, + null, + null, + null + ); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -438,7 +475,7 @@ public void testMultipleRRFRetrievers() { ), // this one bring just doc 7 which should be ranked first eventually new CompoundRetrieverBuilder.RetrieverSource( - new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 7.0f }, null, 1, 100, null, null), + new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 7.0f }, null, 1, 100, null, null, null), null ) ), @@ -485,7 +522,16 @@ public void testRRFExplainWithNamedRetrievers() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( + VECTOR_FIELD, + new float[] { 2.0f }, + null, + 10, + 100, + null, + null, + null + ); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -544,7 +590,16 @@ public void testRRFExplainWithAnotherNestedRRF() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( + VECTOR_FIELD, + new 
float[] { 2.0f }, + null, + 10, + 100, + null, + null, + null + ); RRFRetrieverBuilder nestedRRF = new RRFRetrieverBuilder( Arrays.asList( @@ -765,6 +820,7 @@ public void testRRFFiltersPropagatedToKnnQueryVectorBuilder() { 10, 10, null, + null, null ); source.retriever( @@ -818,8 +874,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws throw new IllegalStateException("Should not be called"); } }; - var knn = new KnnRetrieverBuilder("vector", null, vectorBuilder, 10, 10, null, null); - var standard = new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", vectorBuilder, 10, 10, null)); + var knn = new KnnRetrieverBuilder("vector", null, vectorBuilder, 10, 10, null, null, null); + var standard = new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", vectorBuilder, 10, 10, 10f, null)); var rrf = new RRFRetrieverBuilder( List.of(new CompoundRetrieverBuilder.RetrieverSource(knn, null), new CompoundRetrieverBuilder.RetrieverSource(standard, null)), 10, diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java index a00b940bbed62..21bbc6ab5180f 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java @@ -149,7 +149,16 @@ public void testRRFRetrieverWithNestedQuery() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 6 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 6.0f }, null, 1, 100, null, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( + VECTOR_FIELD, + 
new float[] { 6.0f }, + null, + 1, + 100, + null, + null, + null + ); source.retriever( new RRFRetrieverBuilder( Arrays.asList( diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverTelemetryIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverTelemetryIT.java index 556f8b87923db..2712ab84cf972 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverTelemetryIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverTelemetryIT.java @@ -104,7 +104,7 @@ public void testTelemetryForRRFRetriever() throws IOException { // search#1 - this will record 1 entry for "retriever" in `sections`, and 1 for "knn" under `retrievers` { performSearch( - new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null, null)) + new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, 10f, null, null)) ); } @@ -119,7 +119,7 @@ public void testTelemetryForRRFRetriever() throws IOException { { performSearch( new SearchSourceBuilder().retriever( - new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, null, null)) + new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, 10f, null, null)) ) ); } @@ -138,7 +138,7 @@ public void testTelemetryForRRFRetriever() throws IOException { new RRFRetrieverBuilder( Arrays.asList( new CompoundRetrieverBuilder.RetrieverSource( - new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null, null), + new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, 10f, null, null), null ), new CompoundRetrieverBuilder.RetrieverSource( @@ -156,7 +156,7 @@ public void testTelemetryForRRFRetriever() throws IOException { // search#6 - this will record 1 
entry for "knn" in `sections` { performSearch( - new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null, null))) + new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, 10f, null, null))) ); } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java index b6ffbf8f3301e..61ff39197a799 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java @@ -212,6 +212,7 @@ public RetrieverBuilder toRetriever(SearchSourceBuilder source, Predicate throw new UnsupportedOperationException("Unhandled task type [" + fieldModel.getTaskType() + "]"); };