diff --git a/docs/changelog/113949.yaml b/docs/changelog/113949.yaml new file mode 100644 index 0000000000000..10ea17cd07ac4 --- /dev/null +++ b/docs/changelog/113949.yaml @@ -0,0 +1,7 @@ +pr: 113949 +summary: Support kNN filter on nested metadata +area: Vector Search +type: enhancement +issues: + - 128803 + - 106994 diff --git a/docs/reference/query-languages/query-dsl/query-dsl-knn-query.md b/docs/reference/query-languages/query-dsl/query-dsl-knn-query.md index f7a1b52ee67a1..4a73e93ccf365 100644 --- a/docs/reference/query-languages/query-dsl/query-dsl-knn-query.md +++ b/docs/reference/query-languages/query-dsl/query-dsl-knn-query.md @@ -203,10 +203,19 @@ POST my-image-index/_search `knn` query can be used inside a nested query. The behaviour here is similar to [top level nested kNN search](docs-content://solutions/search/vector/knn.md#nested-knn-search): * kNN search over nested dense_vectors diversifies the top results over the top-level document -* `filter` over the top-level document metadata is supported and acts as a pre-filter -* `filter` over `nested` field metadata is not supported +* `filter` over both the top-level document metadata and `nested` field metadata is supported and acts as a pre-filter + +::::{note} +To ensure correct results, each individual filter must be either over +the top-level metadata or `nested` metadata. However, a single knn query +supports multiple filters, where some filters can be over the top-level +metadata and some over `nested` metadata. +:::: -A sample query can look like below: + +Below is a sample query with a filter over nested metadata. +For scoring parent documents, this query only considers vectors that +have "paragraph.language" set to "EN". 
```json { @@ -215,12 +224,46 @@ A sample query can look like below: "path" : "paragraph", "query" : { "knn": { - "query_vector": [ - 0.45, - 45 - ], + "query_vector": [0.45, 0.50], "field": "paragraph.vector", - "num_candidates": 2 + "filter": { + "match": { + "paragraph.language": "EN" + } + } + } + } + } + } +} +``` + +Below is a sample query with two filters: one over nested metadata +and another over the top-level metadata. For scoring parent documents, +this query only considers vectors whose parent's title contains the +word "essay" and that have "paragraph.language" set to "EN". + +```json +{ + "query" : { + "nested" : { + "path" : "paragraph", + "query" : { + "knn": { + "query_vector": [0.45, 0.50], + "field": "paragraph.vector", + "filter": [ + { + "match": { + "paragraph.language": "EN" + } + }, + { + "match": { + "title": "essay" + } + } + ] } } } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/100_knn_nested_search.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/100_knn_nested_search.yml index d627be2fb15c3..df66831ba94ff 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/100_knn_nested_search.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/100_knn_nested_search.yml @@ -16,6 +16,8 @@ setup: nested: type: nested properties: + language: + type: keyword paragraph_id: type: keyword vector: @@ -27,6 +29,13 @@ setup: type: hnsw m: 16 ef_construction: 200 + nested2: + type: nested + properties: + key: + type: keyword + value: + type: keyword - do: index: @@ -37,8 +46,16 @@ setup: nested: - paragraph_id: 0 vector: [230.0, 300.33, -34.8988, 15.555, -200.0] + language: EN - paragraph_id: 1 vector: [240.0, 300, -3, 1, -20] + language: FR + nested2: + - key: "category" + value: "domestic" + - key: "level" + value: "beginner" + - do: index: @@ -49,10 +66,18 @@ setup: nested: - paragraph_id: 0 vector: [-0.5, 100.0, -13, 14.8, 
-156.0] + language: EN - paragraph_id: 2 vector: [0, 100.0, 0, 14.8, -156.0] + language: EN - paragraph_id: 3 vector: [0, 1.0, 0, 1.8, -15.0] + language: FR + nested2: + - key: "category" + value: "wild" + - key: "level" + value: "beginner" - do: index: @@ -63,6 +88,12 @@ setup: nested: - paragraph_id: 0 vector: [0.5, 111.3, -13.0, 14.8, -156.0] + language: FR + nested2: + - key: "category" + value: "domestic" + - key: "level" + value: "advanced" - do: indices.refresh: {} @@ -461,3 +492,125 @@ setup: - match: {hits.hits.0._id: "2"} - length: {hits.hits.0.inner_hits.nested.hits.hits: 1} - match: {hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0"} + + +--- +"Filter on nested fields": + - requires: + capabilities: + - method: POST + path: /_search + capabilities: [ knn_filter_on_nested_fields ] + test_runner_features: ["capabilities", "close_to"] + reason: "Capability for filtering on nested fields required" + + - do: + search: + index: test + body: + _source: false + knn: + boost: 2 + field: nested.vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 3 + filter: { match: { nested.language: "EN" } } + inner_hits: { size: 3, "fields": [ "nested.paragraph_id", "nested.language"], _source: false } + + - match: { hits.total.value: 2 } + - match: { hits.hits.0._id: "2" } + - match: { hits.hits.0.inner_hits.nested.hits.total.value: 2 } + - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" } + - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "EN" } + - match: { hits.hits.0.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "2" } + - match: { hits.hits.0.inner_hits.nested.hits.hits.1.fields.nested.0.language.0: "EN" } + - close_to: { hits.hits.0._score: { value: 0.0182, error: 0.0001 } } + - close_to: { hits.hits.0.inner_hits.nested.hits.hits.0._score: { value: 0.0182, error: 0.0001 } } + - match: { hits.hits.1._id: "1" } + - match: { 
hits.hits.1.inner_hits.nested.hits.total.value: 1 } + - match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" } + - match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "EN" } + + + - do: + search: + index: test + body: + _source: false + knn: + boost: 2 + field: nested.vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 3 + filter: { match: { nested.language: "FR" } } + inner_hits: { size: 3, "fields": [ "nested.paragraph_id", "nested.language"], _source: false } + + - match: { hits.total.value: 3 } + - match: { hits.hits.0._id: "3" } + - match: { hits.hits.0.inner_hits.nested.hits.total.value: 1 } + - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" } + - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "FR" } + - close_to: { hits.hits.0._score: { value: 0.0043, error: 0.0001 } } + - close_to: { hits.hits.0.inner_hits.nested.hits.hits.0._score: { value: 0.0043, error: 0.0001 } } + - match: { hits.hits.1._id: "2" } + - match: { hits.hits.1.inner_hits.nested.hits.total.value: 1 } + - match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "3" } + - match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "FR" } + - match: { hits.hits.2._id: "1" } + - match: { hits.hits.2.inner_hits.nested.hits.total.value: 1 } + - match: { hits.hits.2.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "1" } + - match: { hits.hits.2.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "FR" } + + # filter on both nested and parent metadata with 2 different filters + - do: + search: + index: test + body: + _source: false + knn: + boost: 2 + field: nested.vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 3 + num_candidates: 10 + filter: [{ match: { nested.language: "FR" }}, {term: {name: "rabbit.jpg"}} ] + inner_hits: { size: 3, "fields": [ "nested.paragraph_id", 
"nested.language"], _source: false } + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "3" } + - match: { hits.hits.0.inner_hits.nested.hits.total.value: 1 } + - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" } + - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "FR" } + - close_to: { hits.hits.0._score: { value: 0.0043, error: 0.0001 } } + - close_to: { hits.hits.0.inner_hits.nested.hits.hits.0._score: { value: 0.0043, error: 0.0001 } } + + +--- +"Test filter on sibling nested fields works": + - requires: + capabilities: + - method: POST + path: /_search + capabilities: [ knn_filter_on_nested_fields ] + test_runner_features: ["capabilities", "close_to"] + reason: "Capability for filtering on nested fields required" + + - do: + search: + index: test + body: + _source: false + knn: + field: nested.vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + filter: + nested: + path: nested2 + query: + bool: + filter: + - match: + nested2.key: "category" + - match: + nested2.value: "domestic" + - match: { hits.total.value: 2} diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/130_knn_query_nested_search.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/130_knn_query_nested_search.yml index bf07144975650..2416689c285fd 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/130_knn_query_nested_search.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/130_knn_query_nested_search.yml @@ -16,6 +16,8 @@ setup: nested: type: nested properties: + language: + type: keyword paragraph_id: type: keyword vector: @@ -23,6 +25,17 @@ setup: dims: 5 index: true similarity: l2_norm + index_options: + type: hnsw + m: 16 + ef_construction: 200 + nested2: + type: nested + properties: + key: + type: keyword + value: + type: keyword aliases: my_alias: filter: @@ -38,8 
+51,15 @@ setup: nested: - paragraph_id: 0 vector: [230.0, 300.33, -34.8988, 15.555, -200.0] + language: EN - paragraph_id: 1 vector: [240.0, 300, -3, 1, -20] + language: FR + nested2: + - key: "category" + value: "domestic" + - key: "level" + value: "beginner" - do: index: @@ -50,10 +70,19 @@ setup: nested: - paragraph_id: 0 vector: [-0.5, 100.0, -13, 14.8, -156.0] + language: EN - paragraph_id: 2 vector: [0, 100.0, 0, 14.8, -156.0] + language: EN - paragraph_id: 3 vector: [0, 1.0, 0, 1.8, -15.0] + language: FR + nested2: + - key: "category" + value: "wild" + - key: "level" + value: "beginner" + - do: index: @@ -64,6 +93,12 @@ setup: nested: - paragraph_id: 0 vector: [0.5, 111.3, -13.0, 14.8, -156.0] + language: FR + nested2: + - key: "category" + value: "domestic" + - key: "level" + value: "advanced" - do: indices.refresh: {} @@ -408,3 +443,147 @@ setup: - match: {hits.total.value: 1} - match: {hits.hits.0._id: "2"} + + +--- +"Filter on nested fields": + - requires: + capabilities: + - method: POST + path: /_search + capabilities: [ knn_filter_on_nested_fields ] + test_runner_features: ["capabilities", "close_to"] + reason: "Capability for filtering on nested fields required" + + - do: + search: + index: test + body: + _source: false + query: + nested: + path: nested + query: + knn: + boost: 2 + field: nested.vector + query_vector: [-0.5, 90.0, -10, 14.8, -156.0] + k: 3 + filter: + match: + nested.language: "EN" + inner_hits: { size: 3, "fields": [ "nested.paragraph_id", "nested.language"], _source: false } + + - match: {hits.total.value: 2} + - match: {hits.hits.0._id: "2"} + - match: { hits.hits.0.inner_hits.nested.hits.total.value: 2 } + - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" } + - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "EN" } + - match: { hits.hits.0.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "2" } + - match: { 
hits.hits.0.inner_hits.nested.hits.hits.1.fields.nested.0.language.0: "EN" } + - close_to: { hits.hits.0._score: { value: 0.0182, error: 0.0001 } } + - close_to: { hits.hits.0.inner_hits.nested.hits.hits.0._score: { value: 0.0182, error: 0.0001 } } + - match: {hits.hits.1._id: "1"} + - match: { hits.hits.1.inner_hits.nested.hits.total.value: 1 } + - match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" } + - match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "EN" } + + - do: + search: + index: test + body: + _source: false + query: + nested: + path: nested + query: + knn: + boost: 2 + field: nested.vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 3 + filter: + match: + nested.language: "FR" + inner_hits: { size: 3, "fields": [ "nested.paragraph_id", "nested.language" ], _source: false } + + - match: { hits.total.value: 3 } + - match: { hits.hits.0._id: "3" } + - match: { hits.hits.0.inner_hits.nested.hits.total.value: 1 } + - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" } + - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "FR" } + - close_to: { hits.hits.0._score: { value: 0.0043, error: 0.0001 } } + - close_to: { hits.hits.0.inner_hits.nested.hits.hits.0._score: { value: 0.0043, error: 0.0001 } } + - match: { hits.hits.1._id: "2" } + - match: { hits.hits.1.inner_hits.nested.hits.total.value: 1 } + - match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "3" } + - match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "FR" } + - match: { hits.hits.2._id: "1" } + - match: { hits.hits.2.inner_hits.nested.hits.total.value: 1 } + - match: { hits.hits.2.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "1" } + - match: { hits.hits.2.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "FR" } + + # filter on both nested and parent metadata + - do: + search: 
+ index: test + body: + _source: false + query: + nested: + path: nested + query: + knn: + boost: 2 + field: nested.vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 10 + filter: [{ match: { nested.language: "FR" }}, {term: {name: "rabbit.jpg"}} ] + inner_hits: { size: 3, "fields": [ "nested.paragraph_id", "nested.language" ], _source: false } + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "3" } + - match: { hits.hits.0.inner_hits.nested.hits.total.value: 1 } + - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" } + - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "FR" } + - close_to: { hits.hits.0._score: { value: 0.0043, error: 0.0001 } } + - close_to: { hits.hits.0.inner_hits.nested.hits.hits.0._score: { value: 0.0043, error: 0.0001 } } + + +--- +"Test filter on sibling nested fields doesn't work": + - requires: + capabilities: + - method: POST + path: /_search + capabilities: [ knn_filter_on_nested_fields ] + test_runner_features: ["capabilities", "close_to"] + reason: "Capability for filtering on nested fields required" + + - do: + search: + index: test + body: + _source: false + query: + nested: + path: nested + query: + knn: + field: nested.vector + query_vector: [-0.5, 90.0, -10, 14.8, -156.0] + k: 10 + filter: + nested: + path: nested2 + query: + bool: + filter: + - match: + nested2.key: "category" + - match: + nested2.value: "domestic" + inner_hits: { size: 3, "fields": [ "nested.paragraph_id", "nested.language"], _source: false } + + - match: { hits.total.value: 0 } + diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 569556a7aa2e6..e57cb485361b6 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -354,6 +354,7 @@ static TransportVersion def(int id) { public static final 
TransportVersion RERANK_SNIPPETS = def(9_130_0_00); public static final TransportVersion PIPELINE_TRACKING_INFO = def(9_131_0_00); public static final TransportVersion COMPONENT_TEMPLATE_TRACKING_INFO = def(9_132_0_00); + public static final TransportVersion TO_CHILD_BLOCK_JOIN_QUERY = def(9_133_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java b/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java index d67e656773495..71eb94459548c 100644 --- a/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java @@ -165,7 +165,8 @@ ShardSearchRequest rewriteShardSearchRequest(List knnResults, Sha scoreDocs.toArray(Lucene.EMPTY_SCORE_DOCS), source.knnSearch().get(i).getField(), source.knnSearch().get(i).getQueryVector(), - source.knnSearch().get(i).getSimilarity() + source.knnSearch().get(i).getSimilarity(), + source.knnSearch().get(i).getFilterQueries() ).boost(source.knnSearch().get(i).boost()).queryName(source.knnSearch().get(i).queryName()); if (nestedPath != null) { query = new NestedQueryBuilder(nestedPath, query, ScoreMode.Max).innerHit(source.knnSearch().get(i).innerHit()); diff --git a/server/src/main/java/org/elasticsearch/index/query/ToChildBlockJoinQueryBuilder.java b/server/src/main/java/org/elasticsearch/index/query/ToChildBlockJoinQueryBuilder.java new file mode 100644 index 0000000000000..1e6e6feee3f42 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/query/ToChildBlockJoinQueryBuilder.java @@ -0,0 +1,113 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.query; + +import org.apache.lucene.search.Query; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.ToChildBlockJoinQuery; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.lucene.search.Queries; +import org.elasticsearch.index.mapper.NestedObjectMapper; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +/** + * A query that returns child documents whose parent matches the provided query. + * This query is used only for internal purposes and is not exposed to a user. 
+ */ +public class ToChildBlockJoinQueryBuilder extends AbstractQueryBuilder { + public static final String NAME = "to_child_block_join"; + private final QueryBuilder parentQueryBuilder; + + public ToChildBlockJoinQueryBuilder(QueryBuilder parentQueryBuilder) { + this.parentQueryBuilder = parentQueryBuilder; + } + + public ToChildBlockJoinQueryBuilder(StreamInput in) throws IOException { + super(in); + parentQueryBuilder = in.readNamedWriteable(QueryBuilder.class); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeNamedWriteable(parentQueryBuilder); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(NAME); + builder.field("query"); + parentQueryBuilder.toXContent(builder, params); + boostAndQueryNameToXContent(builder); + builder.endObject(); + } + + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + QueryBuilder rewritten = parentQueryBuilder.rewrite(queryRewriteContext); + if (rewritten instanceof MatchNoneQueryBuilder) { + return rewritten; + } + if (rewritten != parentQueryBuilder) { + return new ToChildBlockJoinQueryBuilder(rewritten); + } + return this; + } + + @Override + protected Query doToQuery(SearchExecutionContext context) throws IOException { + final Query parentFilter; + NestedObjectMapper originalObjectMapper = context.nestedScope().getObjectMapper(); + if (originalObjectMapper != null) { + try { + // we are in a nested context, to get the parent filter we need to go up one level + context.nestedScope().previousLevel(); + NestedObjectMapper objectMapper = context.nestedScope().getObjectMapper(); + parentFilter = objectMapper == null + ? 
Queries.newNonNestedFilter(context.indexVersionCreated()) + : objectMapper.nestedTypeFilter(); + } finally { + context.nestedScope().nextLevel(originalObjectMapper); + } + } else { + // we are NOT in a nested context, coming from the top level knn search + parentFilter = Queries.newNonNestedFilter(context.indexVersionCreated()); + } + final BitSetProducer parentBitSet = context.bitsetFilter(parentFilter); + Query parentQuery = parentQueryBuilder.toQuery(context); + // ensure that parentQuery only applies to parent docs by adding parentFilter + return new ToChildBlockJoinQuery(Queries.filtered(parentQuery, parentFilter), parentBitSet); + } + + @Override + protected boolean doEquals(ToChildBlockJoinQueryBuilder other) { + return Objects.equals(parentQueryBuilder, other.parentQueryBuilder); + } + + @Override + protected int doHashCode() { + return Objects.hash(parentQueryBuilder); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.TO_CHILD_BLOCK_JOIN_QUERY; + } +} diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java index 66648b7126514..43c2332b8e4a5 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java +++ b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java @@ -57,6 +57,7 @@ private SearchCapabilities() {} private static final String FIELD_EXISTS_QUERY_FOR_TEXT_FIELDS_NO_INDEX_OR_DV = "field_exists_query_for_text_fields_no_index_or_dv"; private static final String SYNTHETIC_VECTORS_SETTING = "synthetic_vectors_setting"; private static final String UPDATE_FIELD_TO_BBQ_DISK = "update_field_to_bbq_disk"; + private static final String KNN_FILTER_ON_NESTED_FIELDS_CAPABILITY = "knn_filter_on_nested_fields"; public static final Set CAPABILITIES; static { @@ 
-82,6 +83,7 @@ private SearchCapabilities() {} capabilities.add(DENSE_VECTOR_UPDATABLE_BBQ); capabilities.add(FIELD_EXISTS_QUERY_FOR_TEXT_FIELDS_NO_INDEX_OR_DV); capabilities.add(UPDATE_FIELD_TO_BBQ_DISK); + capabilities.add(KNN_FILTER_ON_NESTED_FIELDS_CAPABILITY); if (SYNTHETIC_VECTORS) { capabilities.add(SYNTHETIC_VECTORS_SETTING); } diff --git a/server/src/main/java/org/elasticsearch/search/SearchModule.java b/server/src/main/java/org/elasticsearch/search/SearchModule.java index 56b203700b362..7a5ae6f25a632 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchModule.java +++ b/server/src/main/java/org/elasticsearch/search/SearchModule.java @@ -66,6 +66,7 @@ import org.elasticsearch.index.query.TermQueryBuilder; import org.elasticsearch.index.query.TermsQueryBuilder; import org.elasticsearch.index.query.TermsSetQueryBuilder; +import org.elasticsearch.index.query.ToChildBlockJoinQueryBuilder; import org.elasticsearch.index.query.WildcardQueryBuilder; import org.elasticsearch.index.query.WrapperQueryBuilder; import org.elasticsearch.index.query.functionscore.ExponentialDecayFunctionBuilder; @@ -1187,6 +1188,9 @@ private void registerQueryParsers(List plugins) { registerQuery(new QuerySpec<>(ExactKnnQueryBuilder.NAME, ExactKnnQueryBuilder::new, parser -> { throw new IllegalArgumentException("[exact_knn] queries cannot be provided directly"); })); + registerQuery(new QuerySpec<>(ToChildBlockJoinQueryBuilder.NAME, ToChildBlockJoinQueryBuilder::new, parser -> { + throw new IllegalArgumentException("[to_child_block_join] queries cannot be provided directly"); + })); registerQuery( new QuerySpec<>(RandomSamplingQueryBuilder.NAME, RandomSamplingQueryBuilder::new, RandomSamplingQueryBuilder::fromXContent) ); diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java index 1a81f4b984e93..17403bdbb05c9 100644 --- 
a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java @@ -17,13 +17,16 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.index.query.AbstractQueryBuilder; +import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.MatchNoneQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.index.query.ToChildBlockJoinQueryBuilder; import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; +import java.util.List; import java.util.Objects; /** @@ -37,6 +40,7 @@ public class KnnScoreDocQueryBuilder extends AbstractQueryBuilder filterQueries; /** * Creates a query builder. @@ -44,11 +48,18 @@ public class KnnScoreDocQueryBuilder extends AbstractQueryBuilder filterQueries + ) { this.scoreDocs = scoreDocs; this.fieldName = fieldName; this.queryVector = queryVector; this.vectorSimilarity = vectorSimilarity; + this.filterQueries = filterQueries; } public KnnScoreDocQueryBuilder(StreamInput in) throws IOException { @@ -74,6 +85,11 @@ public KnnScoreDocQueryBuilder(StreamInput in) throws IOException { } else { this.vectorSimilarity = null; } + if (in.getTransportVersion().onOrAfter(TransportVersions.TO_CHILD_BLOCK_JOIN_QUERY)) { + this.filterQueries = readQueries(in); + } else { + this.filterQueries = List.of(); + } } @Override @@ -116,6 +132,9 @@ protected void doWriteTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) { out.writeOptionalFloat(vectorSimilarity); } + if (out.getTransportVersion().onOrAfter(TransportVersions.TO_CHILD_BLOCK_JOIN_QUERY)) { + writeQueries(out, filterQueries); + } } @Override @@ -135,6 +154,13 @@ 
protected void doXContent(XContentBuilder builder, Params params) throws IOExcep if (vectorSimilarity != null) { builder.field("similarity", vectorSimilarity); } + if (filterQueries.isEmpty() == false) { + builder.startArray("filter"); + for (QueryBuilder filterQuery : filterQueries) { + filterQuery.toXContent(builder, params); + } + builder.endArray(); + } boostAndQueryNameToXContent(builder); builder.endObject(); } @@ -150,7 +176,20 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws return new MatchNoneQueryBuilder("The \"" + getName() + "\" query was rewritten to a \"match_none\" query."); } if (queryRewriteContext.convertToInnerHitsRewriteContext() != null && queryVector != null && fieldName != null) { - return new ExactKnnQueryBuilder(queryVector, fieldName, vectorSimilarity); + QueryBuilder exactKnnQuery = new ExactKnnQueryBuilder(queryVector, fieldName, vectorSimilarity); + if (filterQueries.isEmpty()) { + return exactKnnQuery; + } else { + BoolQueryBuilder boolQuery = new BoolQueryBuilder(); + boolQuery.must(exactKnnQuery); + for (QueryBuilder filter : this.filterQueries) { + // filter can be both over parents or nested docs, so add them as should clauses to a filter + BoolQueryBuilder adjustedFilter = new BoolQueryBuilder().should(filter) + .should(new ToChildBlockJoinQueryBuilder(filter)); + boolQuery.filter(adjustedFilter); + } + return boolQuery; + } } return super.doRewrite(queryRewriteContext); } @@ -173,7 +212,8 @@ protected boolean doEquals(KnnScoreDocQueryBuilder other) { } return Objects.equals(fieldName, other.fieldName) && Objects.equals(queryVector, other.queryVector) - && Objects.equals(vectorSimilarity, other.vectorSimilarity); + && Objects.equals(vectorSimilarity, other.vectorSimilarity) + && Objects.equals(filterQueries, other.filterQueries); } @Override @@ -183,7 +223,7 @@ protected int doHashCode() { int hashCode = Objects.hash(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex); result = 31 * result + 
hashCode; } - return Objects.hash(result, fieldName, vectorSimilarity, Objects.hashCode(queryVector)); + return Objects.hash(result, fieldName, vectorSimilarity, Objects.hashCode(queryVector), filterQueries); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java index ea0c15642eb74..b76f56ceb2aa9 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java @@ -27,10 +27,12 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.DenseVectorFieldType; import org.elasticsearch.index.query.AbstractQueryBuilder; +import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.MatchNoneQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.index.query.ToChildBlockJoinQueryBuilder; import org.elasticsearch.index.search.NestedHelper; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ObjectParser; @@ -454,9 +456,6 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException { vectorSimilarity ).boost(boost).queryName(queryName).addFilterQueries(filterQueries); } - if (ctx.convertToInnerHitsRewriteContext() != null) { - return new ExactKnnQueryBuilder(queryVector, fieldName, vectorSimilarity).boost(boost).queryName(queryName); - } boolean changed = false; List rewrittenQueries = new ArrayList<>(filterQueries.size()); for (QueryBuilder query : filterQueries) { @@ -481,6 +480,22 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException { vectorSimilarity 
).boost(boost).queryName(queryName).addFilterQueries(rewrittenQueries); } + if (ctx.convertToInnerHitsRewriteContext() != null) { + QueryBuilder exactKnnQuery = new ExactKnnQueryBuilder(queryVector, fieldName, vectorSimilarity); + if (filterQueries.isEmpty()) { + return exactKnnQuery; + } else { + BoolQueryBuilder boolQuery = new BoolQueryBuilder(); + boolQuery.must(exactKnnQuery); + for (QueryBuilder filter : this.filterQueries) { + // filter can be both over parents or nested docs, so add them as should clauses to a filter + BoolQueryBuilder adjustedFilter = new BoolQueryBuilder().should(filter) + .should(new ToChildBlockJoinQueryBuilder(filter)); + boolQuery.filter(adjustedFilter); + } + return boolQuery; + } + } return this; } @@ -500,29 +515,27 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { if (fieldType == null) { return new MatchNoDocsQuery(); } - if (fieldType instanceof DenseVectorFieldType == false) { throw new IllegalArgumentException( "[" + NAME + "] queries are only supported on [" + DenseVectorFieldMapper.CONTENT_TYPE + "] fields" ); } + DenseVectorFieldType vectorFieldType = (DenseVectorFieldType) fieldType; - BooleanQuery.Builder builder = new BooleanQuery.Builder(); + List filtersInitial = new ArrayList<>(filterQueries.size()); for (QueryBuilder query : this.filterQueries) { - builder.add(query.toQuery(context), BooleanClause.Occur.FILTER); + filtersInitial.add(query.toQuery(context)); } if (context.getAliasFilter() != null) { - builder.add(context.getAliasFilter().toQuery(context), BooleanClause.Occur.FILTER); + filtersInitial.add(context.getAliasFilter().toQuery(context)); } - BooleanQuery booleanQuery = builder.build(); - Query filterQuery = booleanQuery.clauses().isEmpty() ? null : booleanQuery; - DenseVectorFieldType vectorFieldType = (DenseVectorFieldType) fieldType; String parentPath = context.nestedLookup().getNestedParent(fieldName); - Float oversample = rescoreVectorBuilder() == null ? 
null : rescoreVectorBuilder.oversample(); - BitSetProducer parentBitSet = null; - if (parentPath != null) { + Query filterQuery; + if (parentPath == null) { + filterQuery = buildFilterQuery(filtersInitial); + } else { final Query parentFilter; NestedObjectMapper originalObjectMapper = context.nestedScope().getObjectMapper(); if (originalObjectMapper != null) { @@ -541,19 +554,23 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { parentFilter = Queries.newNonNestedFilter(context.indexVersionCreated()); } parentBitSet = context.bitsetFilter(parentFilter); - if (filterQuery != null) { - // We treat the provided filter as a filter over PARENT documents, so if it might match nested documents - // we need to adjust it. - if (NestedHelper.mightMatchNestedDocs(filterQuery, context)) { - // Ensure that the query only returns parent documents matching `filterQuery` - filterQuery = Queries.filtered(filterQuery, parentFilter); + List filterAdjusted = new ArrayList<>(filtersInitial.size()); + for (Query f : filtersInitial) { + // If filter matches non-nested docs, we assume this is a filter over parents docs, + // so we will modify it accordingly: matching parents docs with join to its child docs + if (NestedHelper.mightMatchNonNestedDocs(f, parentPath, context)) { + // Ensure that the query only returns parent documents matching filter + f = Queries.filtered(f, parentFilter); + f = new ToChildBlockJoinQuery(f, parentBitSet); } - // Now join the filterQuery & parentFilter to provide the matching blocks of children - filterQuery = new ToChildBlockJoinQuery(filterQuery, parentBitSet); + filterAdjusted.add(f); } + filterQuery = buildFilterQuery(filterAdjusted); } + DenseVectorFieldMapper.FilterHeuristic heuristic = context.getIndexSettings().getHnswFilterHeuristic(); boolean hnswEarlyTermination = context.getIndexSettings().getHnswEarlyTermination(); + Float oversample = rescoreVectorBuilder() == null ? 
null : rescoreVectorBuilder.oversample(); return vectorFieldType.createKnnQuery( queryVector, k, @@ -567,6 +584,16 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { ); } + private static Query buildFilterQuery(List filters) { + BooleanQuery.Builder builder = new BooleanQuery.Builder(); + for (Query f : filters) { + builder.add(f, BooleanClause.Occur.FILTER); + } + BooleanQuery booleanQuery = builder.build(); + Query filterQuery = booleanQuery.clauses().isEmpty() ? null : booleanQuery; + return filterQuery; + } + @Override protected int doHashCode() { return Objects.hash( diff --git a/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java index 43292c4f65245..a4f698d04b782 100644 --- a/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java @@ -367,13 +367,15 @@ public void testRewriteShardSearchRequestWithRank() { new ScoreDoc[] { new ScoreDoc(1, 3.0f, 1), new ScoreDoc(4, 1.5f, 1) }, "vector", VectorData.fromFloats(new float[] { 0.0f }), - null + null, + List.of() ); KnnScoreDocQueryBuilder ksdqb1 = new KnnScoreDocQueryBuilder( new ScoreDoc[] { new ScoreDoc(1, 2.0f, 1) }, "vector2", VectorData.fromFloats(new float[] { 0.0f }), - null + null, + List.of() ); assertEquals( List.of(bm25, ksdqb0, ksdqb1), diff --git a/server/src/test/java/org/elasticsearch/index/query/ToChildBlockJoinQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/index/query/ToChildBlockJoinQueryBuilderTests.java new file mode 100644 index 0000000000000..89d0474461ace --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/query/ToChildBlockJoinQueryBuilderTests.java @@ -0,0 +1,53 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.query; + +import org.apache.lucene.search.Query; +import org.apache.lucene.search.join.ToChildBlockJoinQuery; +import org.elasticsearch.test.AbstractQueryTestCase; + +import java.io.IOException; + +import static org.hamcrest.CoreMatchers.instanceOf; + +public class ToChildBlockJoinQueryBuilderTests extends AbstractQueryTestCase { + @Override + protected ToChildBlockJoinQueryBuilder doCreateTestQueryBuilder() { + String filterFieldName = randomBoolean() ? KEYWORD_FIELD_NAME : TEXT_FIELD_NAME; + return new ToChildBlockJoinQueryBuilder(QueryBuilders.termQuery(filterFieldName, randomAlphaOfLength(10))); + } + + @Override + protected void doAssertLuceneQuery(ToChildBlockJoinQueryBuilder queryBuilder, Query query, SearchExecutionContext context) + throws IOException { + assertThat(query, instanceOf(ToChildBlockJoinQuery.class)); + } + + @Override + public void testUnknownField() throws IOException { + // Test isn't relevant, since query is never parsed from xContent + } + + @Override + public void testUnknownObjectException() { + // Test isn't relevant, since query is never parsed from xContent + } + + @Override + public void testFromXContent() throws IOException { + // Test isn't relevant, since query is never parsed from xContent + } + + @Override + public void testValidOutput() { + // Test isn't relevant, since query is never parsed from xContent + } + +} diff --git a/server/src/test/java/org/elasticsearch/search/SearchModuleTests.java b/server/src/test/java/org/elasticsearch/search/SearchModuleTests.java index 1e638f8e7b30e..ef02a0405c88f 100644 --- 
a/server/src/test/java/org/elasticsearch/search/SearchModuleTests.java +++ b/server/src/test/java/org/elasticsearch/search/SearchModuleTests.java @@ -463,7 +463,8 @@ public CheckedBiConsumer getReque "terms_set", "wildcard", "wrapper", - "distance_feature" }; + "distance_feature", + "to_child_block_join" }; // add here deprecated queries to make sure we log a deprecation warnings when they are used private static final String[] DEPRECATED_QUERIES = new String[] { "field_masking_span", "geo_polygon" }; diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index a15372bc1e8ef..a8d9b1259cb41 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -25,6 +25,7 @@ import org.elasticsearch.index.IndexVersions; import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.InnerHitsRewriteContext; import org.elasticsearch.index.query.MatchNoneQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; @@ -482,12 +483,19 @@ public void testRewriteForInnerHits() throws IOException { queryBuilder.boost(randomFloat()); queryBuilder.queryName(randomAlphaOfLength(10)); QueryBuilder rewritten = queryBuilder.rewrite(innerHitsRewriteContext); + float queryBoost = rewritten.boost(); + String queryName = rewritten.queryName(); + if (queryBuilder.filterQueries().isEmpty() == false) { + assertTrue(rewritten instanceof BoolQueryBuilder); + BoolQueryBuilder boolQueryBuilder = (BoolQueryBuilder) rewritten; + rewritten = boolQueryBuilder.must().get(0); + } assertTrue(rewritten instanceof ExactKnnQueryBuilder); 
ExactKnnQueryBuilder exactKnnQueryBuilder = (ExactKnnQueryBuilder) rewritten; assertEquals(queryBuilder.queryVector(), exactKnnQueryBuilder.getQuery()); assertEquals(queryBuilder.getFieldName(), exactKnnQueryBuilder.getField()); - assertEquals(queryBuilder.boost(), exactKnnQueryBuilder.boost(), 0.0001f); - assertEquals(queryBuilder.queryName(), exactKnnQueryBuilder.queryName()); + assertEquals(queryBuilder.boost(), queryBoost, 0.0001f); + assertEquals(queryBuilder.queryName(), queryName); assertEquals(queryBuilder.getVectorSimilarity(), exactKnnQueryBuilder.vectorSimilarity()); } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilderTests.java index bef0bbfd27ff6..cf94fd41c1171 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilderTests.java @@ -24,9 +24,11 @@ import org.apache.lucene.search.Weight; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; +import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.InnerHitsRewriteContext; import org.elasticsearch.index.query.MatchNoneQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.test.AbstractQueryTestCase; @@ -54,11 +56,20 @@ protected KnnScoreDocQueryBuilder doCreateTestQueryBuilder() { for (int doc = 0; doc < numDocs; doc++) { scoreDocs.add(new ScoreDoc(doc, randomFloat())); } + List filters = new ArrayList<>(); + if (randomBoolean()) { + int numFilters = randomIntBetween(1, 5); + for (int i = 0; i < numFilters; i++) { + String filterFieldName = randomBoolean() ? 
KEYWORD_FIELD_NAME : TEXT_FIELD_NAME; + filters.add(QueryBuilders.termQuery(filterFieldName, randomAlphaOfLength(10))); + } + } return new KnnScoreDocQueryBuilder( scoreDocs.toArray(new ScoreDoc[0]), randomBoolean() ? "field" : null, randomBoolean() ? VectorData.fromFloats(randomVector(10)) : null, - randomBoolean() ? randomFloat() : null + randomBoolean() ? randomFloat() : null, + filters ); } @@ -68,7 +79,8 @@ public void testValidOutput() { new ScoreDoc[] { new ScoreDoc(0, 4.25f), new ScoreDoc(5, 1.6f) }, "field", VectorData.fromFloats(new float[] { 1.0f, 2.0f }), - null + null, + List.of() ); String expected = """ { @@ -159,7 +171,8 @@ public void testRewriteToMatchNone() throws IOException { new ScoreDoc[0], randomBoolean() ? "field" : null, randomBoolean() ? VectorData.fromFloats(randomVector(10)) : null, - randomBoolean() ? randomFloat() : null + randomBoolean() ? randomFloat() : null, + List.of() ); QueryRewriteContext context = randomBoolean() ? new InnerHitsRewriteContext(createSearchExecutionContext().getParserConfig(), System::currentTimeMillis) @@ -170,21 +183,41 @@ public void testRewriteToMatchNone() throws IOException { public void testRewriteForInnerHits() throws IOException { SearchExecutionContext context = createSearchExecutionContext(); InnerHitsRewriteContext innerHitsRewriteContext = new InnerHitsRewriteContext(context.getParserConfig(), System::currentTimeMillis); + List filters = new ArrayList<>(); + boolean hasFilters = randomBoolean(); + if (hasFilters) { + int numFilters = randomIntBetween(1, 5); + for (int i = 0; i < numFilters; i++) { + String filterFieldName = randomBoolean() ? KEYWORD_FIELD_NAME : TEXT_FIELD_NAME; + filters.add(QueryBuilders.termQuery(filterFieldName, randomAlphaOfLength(10))); + } + } + KnnScoreDocQueryBuilder queryBuilder = new KnnScoreDocQueryBuilder( new ScoreDoc[] { new ScoreDoc(0, 4.25f), new ScoreDoc(5, 1.6f) }, randomAlphaOfLength(10), VectorData.fromFloats(randomVector(10)), - randomBoolean() ? 
randomFloat() : null + randomBoolean() ? randomFloat() : null, + filters ); queryBuilder.boost(randomFloat()); queryBuilder.queryName(randomAlphaOfLength(10)); QueryBuilder rewritten = queryBuilder.rewrite(innerHitsRewriteContext); + float queryBoost = rewritten.boost(); + String queryName = rewritten.queryName(); + + if (hasFilters) { + assertTrue(rewritten instanceof BoolQueryBuilder); + BoolQueryBuilder boolQueryBuilder = (BoolQueryBuilder) rewritten; + rewritten = boolQueryBuilder.must().get(0); + } + assertTrue(rewritten instanceof ExactKnnQueryBuilder); ExactKnnQueryBuilder exactKnnQueryBuilder = (ExactKnnQueryBuilder) rewritten; assertEquals(queryBuilder.queryVector(), exactKnnQueryBuilder.getQuery()); assertEquals(queryBuilder.fieldName(), exactKnnQueryBuilder.getField()); - assertEquals(queryBuilder.boost(), exactKnnQueryBuilder.boost(), 0.0001f); - assertEquals(queryBuilder.queryName(), exactKnnQueryBuilder.queryName()); + assertEquals(queryBuilder.boost(), queryBoost, 0.0001f); + assertEquals(queryBuilder.queryName(), queryName); assertEquals(queryBuilder.vectorSimilarity(), exactKnnQueryBuilder.vectorSimilarity()); } @@ -228,7 +261,8 @@ public void testScoreDocQueryWeightCount() throws IOException { scoreDocs, "field", VectorData.fromFloats(randomVector(10)), - null + null, + List.of() ); Query query = queryBuilder.doToQuery(context); final Weight w = query.createWeight(searcher, ScoreMode.TOP_SCORES, 1.0f); @@ -276,7 +310,8 @@ public void testScoreDocQuery() throws IOException { scoreDocs, "field", VectorData.fromFloats(randomVector(10)), - null + null, + List.of() ); final Query query = queryBuilder.doToQuery(context); final Weight w = query.createWeight(searcher, ScoreMode.TOP_SCORES, 1.0f);