diff --git a/docs/changelog/111834.yaml b/docs/changelog/111834.yaml new file mode 100644 index 0000000000000..4548dee5f91e5 --- /dev/null +++ b/docs/changelog/111834.yaml @@ -0,0 +1,5 @@ +pr: 111834 +summary: Add inner hits support to semantic query +area: Search +type: enhancement +issues: [] diff --git a/docs/reference/query-dsl/semantic-query.asciidoc b/docs/reference/query-dsl/semantic-query.asciidoc index 22b5e6c5e6aad..f3f6aca3fd07a 100644 --- a/docs/reference/query-dsl/semantic-query.asciidoc +++ b/docs/reference/query-dsl/semantic-query.asciidoc @@ -25,7 +25,7 @@ GET my-index-000001/_search } } ------------------------------------------------------------ -// TEST[skip:TBD] +// TEST[skip: Requires inference endpoints] [discrete] @@ -40,9 +40,209 @@ The `semantic_text` field to perform the query on. (Required, string) The query text to be searched for on the field. +`inner_hits`:: +(Optional, object) +Retrieves the specific passages that match the query. +See <> for more information. ++ +.Properties of `inner_hits` +[%collapsible%open] +==== +`from`:: +(Optional, integer) +The offset from the first matching passage to fetch. +Used to paginate through the passages. +Defaults to `0`. + +`size`:: +(Optional, integer) +The maximum number of matching passages to return. +Defaults to `3`. +==== Refer to <> to learn more about semantic search using `semantic_text` and `semantic` query. +[discrete] +[[semantic-query-passage-ranking]] +==== Passage ranking with the `semantic` query +The `inner_hits` parameter can be used for _passage ranking_, which allows you to determine which passages in the document best match the query. 
+For example, if you have a document that covers varying topics: + +[source,console] +------------------------------------------------------------ +POST my-index/_doc/lake_tahoe +{ + "inference_field": [ + "Lake Tahoe is the largest alpine lake in North America", + "When hiking in the area, please be on alert for bears" + ] +} +------------------------------------------------------------ +// TEST[skip: Requires inference endpoints] + +You can use passage ranking to find the passage that best matches your query: + +[source,console] +------------------------------------------------------------ +GET my-index/_search +{ + "query": { + "semantic": { + "field": "inference_field", + "query": "mountain lake", + "inner_hits": { } + } + } +} +------------------------------------------------------------ +// TEST[skip: Requires inference endpoints] + +[source,console-result] +------------------------------------------------------------ +{ + "took": 67, + "timed_out": false, + "_shards": { + "total": 1, + "successful": 1, + "skipped": 0, + "failed": 0 + }, + "hits": { + "total": { + "value": 1, + "relation": "eq" + }, + "max_score": 10.844536, + "hits": [ + { + "_index": "my-index", + "_id": "lake_tahoe", + "_score": 10.844536, + "_source": { + ... 
+ }, + "inner_hits": { <1> + "inference_field": { + "hits": { + "total": { + "value": 2, + "relation": "eq" + }, + "max_score": 10.844536, + "hits": [ + { + "_index": "my-index", + "_id": "lake_tahoe", + "_nested": { + "field": "inference_field.inference.chunks", + "offset": 0 + }, + "_score": 10.844536, + "_source": { + "text": "Lake Tahoe is the largest alpine lake in North America" + } + }, + { + "_index": "my-index", + "_id": "lake_tahoe", + "_nested": { + "field": "inference_field.inference.chunks", + "offset": 1 + }, + "_score": 3.2726858, + "_source": { + "text": "When hiking in the area, please be on alert for bears" + } + } + ] + } + } + } + } + ] + } +} +------------------------------------------------------------ +<1> Ranked passages will be returned using the <>, with `` set to the `semantic_text` field name. + +By default, the top three matching passages will be returned. +You can use the `size` parameter to control the number of passages returned and the `from` parameter to page through the matching passages: + +[source,console] +------------------------------------------------------------ +GET my-index/_search +{ + "query": { + "semantic": { + "field": "inference_field", + "query": "mountain lake", + "inner_hits": { + "from": 1, + "size": 1 + } + } + } +} +------------------------------------------------------------ +// TEST[skip: Requires inference endpoints] + +[source,console-result] +------------------------------------------------------------ +{ + "took": 42, + "timed_out": false, + "_shards": { + "total": 1, + "successful": 1, + "skipped": 0, + "failed": 0 + }, + "hits": { + "total": { + "value": 1, + "relation": "eq" + }, + "max_score": 10.844536, + "hits": [ + { + "_index": "my-index", + "_id": "lake_tahoe", + "_score": 10.844536, + "_source": { + ... 
+ }, + "inner_hits": { + "inference_field": { + "hits": { + "total": { + "value": 2, + "relation": "eq" + }, + "max_score": 10.844536, + "hits": [ + { + "_index": "my-index", + "_id": "lake_tahoe", + "_nested": { + "field": "inference_field.inference.chunks", + "offset": 1 + }, + "_score": 3.2726858, + "_source": { + "text": "When hiking in the area, please be on alert for bears" + } + } + ] + } + } + } + } + ] + } +} +------------------------------------------------------------ + [discrete] [[hybrid-search-semantic]] ==== Hybrid search with the `semantic` query @@ -79,7 +279,7 @@ POST my-index/_search } } ------------------------------------------------------------ -// TEST[skip:TBD] +// TEST[skip: Requires inference endpoints] You can also use semantic_text as part of <> to make ranking relevant results easier: @@ -116,12 +316,12 @@ GET my-index/_search } } ------------------------------------------------------------ -// TEST[skip:TBD] +// TEST[skip: Requires inference endpoints] [discrete] [[advanced-search]] -=== Advanced search on `semantic_text` fields +==== Advanced search on `semantic_text` fields The `semantic` query uses default settings for searching on `semantic_text` fields for ease of use. If you want to fine-tune a search on a `semantic_text` field, you need to know the task type used by the `inference_id` configured in `semantic_text`. @@ -135,7 +335,7 @@ on a `semantic_text` field, it is not supported to use the `semantic_query` on a [discrete] [[search-sparse-inference]] -==== Search with `sparse_embedding` inference +===== Search with `sparse_embedding` inference When the {infer} endpoint uses a `sparse_embedding` model, you can use a <> on a <> field in the following way: @@ -157,14 +357,14 @@ GET test-index/_search } } ------------------------------------------------------------ -// TEST[skip:TBD] +// TEST[skip: Requires inference endpoints] You can customize the `sparse_vector` query to include specific settings, like <>. 
[discrete] [[search-text-inferece]] -==== Search with `text_embedding` inference +===== Search with `text_embedding` inference When the {infer} endpoint uses a `text_embedding` model, you can use a <> on a `semantic_text` field in the following way: @@ -190,6 +390,6 @@ GET test-index/_search } } ------------------------------------------------------------ -// TEST[skip:TBD] +// TEST[skip: Requires inference endpoints] You can customize the `knn` query to include specific settings, like `num_candidates` and `k`. diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 2d0f526f64a69..b519e263dd387 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -225,6 +225,7 @@ static TransportVersion def(int id) { public static final TransportVersion ILM_ADD_SEARCHABLE_SNAPSHOT_TOTAL_SHARDS_PER_NODE = def(8_749_00_0); public static final TransportVersion SEMANTIC_TEXT_SEARCH_INFERENCE_ID = def(8_750_00_0); public static final TransportVersion ML_INFERENCE_CHUNKING_SETTINGS = def(8_751_00_0); + public static final TransportVersion SEMANTIC_QUERY_INNER_HITS = def(8_752_00_0); /* * STOP! READ THIS FIRST! 
No, really, diff --git a/server/src/main/java/org/elasticsearch/index/query/InnerHitBuilder.java b/server/src/main/java/org/elasticsearch/index/query/InnerHitBuilder.java index 4c861c2320ea5..806f28d72647a 100644 --- a/server/src/main/java/org/elasticsearch/index/query/InnerHitBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/InnerHitBuilder.java @@ -50,9 +50,9 @@ public final class InnerHitBuilder implements Writeable, ToXContentObject { public static final ParseField COLLAPSE_FIELD = new ParseField("collapse"); public static final ParseField FIELD_FIELD = new ParseField("field"); + public static final int DEFAULT_FROM = 0; + public static final int DEFAULT_SIZE = 3; private static final boolean DEFAULT_IGNORE_UNAMPPED = false; - private static final int DEFAULT_FROM = 0; - private static final int DEFAULT_SIZE = 3; private static final boolean DEFAULT_VERSION = false; private static final boolean DEFAULT_SEQ_NO_AND_PRIMARY_TERM = false; private static final boolean DEFAULT_EXPLAIN = false; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java index fd330a8cf6cc6..30ccb48d5c709 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java @@ -10,6 +10,7 @@ import org.elasticsearch.features.FeatureSpecification; import org.elasticsearch.features.NodeFeature; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; +import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder; import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder; @@ -25,7 +26,8 @@ public Set getFeatures() { return Set.of( 
TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_RETRIEVER_SUPPORTED, RandomRankRetrieverBuilder.RANDOM_RERANKER_RETRIEVER_SUPPORTED, - SemanticTextFieldMapper.SEMANTIC_TEXT_SEARCH_INFERENCE_ID + SemanticTextFieldMapper.SEMANTIC_TEXT_SEARCH_INFERENCE_ID, + SemanticQueryBuilder.SEMANTIC_TEXT_INNER_HITS ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 0483296cd2c6a..e0ad044f597ab 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -40,6 +40,7 @@ import org.elasticsearch.index.mapper.ValueFetcher; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; +import org.elasticsearch.index.query.InnerHitBuilder; import org.elasticsearch.index.query.MatchNoneQueryBuilder; import org.elasticsearch.index.query.NestedQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; @@ -54,6 +55,7 @@ import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; +import org.elasticsearch.xpack.inference.queries.SemanticQueryInnerHitBuilder; import java.io.IOException; import java.util.ArrayList; @@ -468,7 +470,12 @@ public boolean fieldHasValue(FieldInfos fieldInfos) { return fieldInfos.fieldInfo(getEmbeddingsFieldName(name())) != null; } - public QueryBuilder semanticQuery(InferenceResults inferenceResults, float boost, String queryName) { + public QueryBuilder semanticQuery( + InferenceResults inferenceResults, + float boost, + String queryName, + 
SemanticQueryInnerHitBuilder semanticInnerHitBuilder + ) { String nestedFieldPath = getChunksFieldName(name()); String inferenceResultsFieldName = getEmbeddingsFieldName(name()); QueryBuilder childQueryBuilder; @@ -524,7 +531,10 @@ public QueryBuilder semanticQuery(InferenceResults inferenceResults, float boost }; } - return new NestedQueryBuilder(nestedFieldPath, childQueryBuilder, ScoreMode.Max).boost(boost).queryName(queryName); + InnerHitBuilder innerHitBuilder = semanticInnerHitBuilder != null ? semanticInnerHitBuilder.toInnerHitBuilder() : null; + return new NestedQueryBuilder(nestedFieldPath, childQueryBuilder, ScoreMode.Max).boost(boost) + .queryName(queryName) + .innerHit(innerHitBuilder); } private String generateQueryInferenceResultsTypeMismatchMessage(InferenceResults inferenceResults, String expectedResultsType) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index 9f7fcb1ef407c..901de30145f7d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -16,6 +16,8 @@ import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.query.AbstractQueryBuilder; import org.elasticsearch.index.query.MatchNoneQueryBuilder; @@ -44,35 +46,46 @@ import java.util.Map; import java.util.Objects; +import static org.elasticsearch.TransportVersions.SEMANTIC_QUERY_INNER_HITS; import static 
org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; public class SemanticQueryBuilder extends AbstractQueryBuilder { + public static final NodeFeature SEMANTIC_TEXT_INNER_HITS = new NodeFeature("semantic_text.inner_hits"); + public static final String NAME = "semantic"; private static final ParseField FIELD_FIELD = new ParseField("field"); private static final ParseField QUERY_FIELD = new ParseField("query"); + private static final ParseField INNER_HITS_FIELD = new ParseField("inner_hits"); private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( NAME, false, - args -> new SemanticQueryBuilder((String) args[0], (String) args[1]) + args -> new SemanticQueryBuilder((String) args[0], (String) args[1], (SemanticQueryInnerHitBuilder) args[2]) ); static { PARSER.declareString(constructorArg(), FIELD_FIELD); PARSER.declareString(constructorArg(), QUERY_FIELD); + PARSER.declareObject(optionalConstructorArg(), (p, c) -> SemanticQueryInnerHitBuilder.fromXContent(p), INNER_HITS_FIELD); declareStandardFields(PARSER); } private final String fieldName; private final String query; + private final SemanticQueryInnerHitBuilder innerHitBuilder; private final SetOnce inferenceResultsSupplier; private final InferenceResults inferenceResults; private final boolean noInferenceResults; public SemanticQueryBuilder(String fieldName, String query) { + this(fieldName, query, null); + } + + public SemanticQueryBuilder(String fieldName, String query, @Nullable SemanticQueryInnerHitBuilder innerHitBuilder) { if (fieldName == null) { throw new IllegalArgumentException("[" + NAME + "] requires a " + FIELD_FIELD.getPreferredName() + " value"); } @@ -81,15 +94,25 @@ public SemanticQueryBuilder(String fieldName, String 
query) { } this.fieldName = fieldName; this.query = query; + this.innerHitBuilder = innerHitBuilder; this.inferenceResults = null; this.inferenceResultsSupplier = null; this.noInferenceResults = false; + + if (this.innerHitBuilder != null) { + this.innerHitBuilder.setFieldName(fieldName); + } } public SemanticQueryBuilder(StreamInput in) throws IOException { super(in); this.fieldName = in.readString(); this.query = in.readString(); + if (in.getTransportVersion().onOrAfter(SEMANTIC_QUERY_INNER_HITS)) { + this.innerHitBuilder = in.readOptionalWriteable(SemanticQueryInnerHitBuilder::new); + } else { + this.innerHitBuilder = null; + } this.inferenceResults = in.readOptionalNamedWriteable(InferenceResults.class); this.noInferenceResults = in.readBoolean(); this.inferenceResultsSupplier = null; @@ -102,6 +125,21 @@ protected void doWriteTo(StreamOutput out) throws IOException { } out.writeString(fieldName); out.writeString(query); + if (out.getTransportVersion().onOrAfter(SEMANTIC_QUERY_INNER_HITS)) { + out.writeOptionalWriteable(innerHitBuilder); + } else if (innerHitBuilder != null) { + throw new IllegalStateException( + "Transport version must be at least [" + + SEMANTIC_QUERY_INNER_HITS.toReleaseVersion() + + "] to use [" + + INNER_HITS_FIELD.getPreferredName() + + "] in [" + + NAME + + "], current transport version is [" + + out.getTransportVersion().toReleaseVersion() + + "]. Are you running a mixed-version cluster?" 
+ ); + } out.writeOptionalNamedWriteable(inferenceResults); out.writeBoolean(noInferenceResults); } @@ -114,6 +152,7 @@ private SemanticQueryBuilder( ) { this.fieldName = other.fieldName; this.query = other.query; + this.innerHitBuilder = other.innerHitBuilder; this.boost = other.boost; this.queryName = other.queryName; this.inferenceResultsSupplier = inferenceResultsSupplier; @@ -121,6 +160,10 @@ private SemanticQueryBuilder( this.noInferenceResults = noInferenceResults; } + public SemanticQueryInnerHitBuilder innerHit() { + return innerHitBuilder; + } + @Override public String getWriteableName() { return NAME; @@ -140,6 +183,9 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep builder.startObject(NAME); builder.field(FIELD_FIELD.getPreferredName(), fieldName); builder.field(QUERY_FIELD.getPreferredName(), query); + if (innerHitBuilder != null) { + builder.field(INNER_HITS_FIELD.getPreferredName(), innerHitBuilder); + } boostAndQueryNameToXContent(builder); builder.endObject(); } @@ -166,7 +212,7 @@ private QueryBuilder doRewriteBuildSemanticQuery(SearchExecutionContext searchEx ); } - return semanticTextFieldType.semanticQuery(inferenceResults, boost(), queryName()); + return semanticTextFieldType.semanticQuery(inferenceResults, boost(), queryName(), innerHitBuilder); } else { throw new IllegalArgumentException( "Field [" + fieldName + "] of type [" + fieldType.typeName() + "] does not support " + NAME + " queries" @@ -301,11 +347,12 @@ private static String getInferenceIdForForField(Collection indexM protected boolean doEquals(SemanticQueryBuilder other) { return Objects.equals(fieldName, other.fieldName) && Objects.equals(query, other.query) + && Objects.equals(innerHitBuilder, other.innerHitBuilder) && Objects.equals(inferenceResults, other.inferenceResults); } @Override protected int doHashCode() { - return Objects.hash(fieldName, query, inferenceResults); + return Objects.hash(fieldName, query, innerHitBuilder, 
inferenceResults); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryInnerHitBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryInnerHitBuilder.java new file mode 100644 index 0000000000000..776ce990665ac --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryInnerHitBuilder.java @@ -0,0 +1,132 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.queries; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.index.query.InnerHitBuilder; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.fetch.subphase.FetchSourceContext; +import org.elasticsearch.xcontent.ObjectParser; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.inference.mapper.SemanticTextField; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.index.query.InnerHitBuilder.DEFAULT_FROM; +import static org.elasticsearch.index.query.InnerHitBuilder.DEFAULT_SIZE; + +public class SemanticQueryInnerHitBuilder implements Writeable, ToXContentObject { + private static final ObjectParser PARSER = new ObjectParser<>( + "semantic_query_inner_hits", + SemanticQueryInnerHitBuilder::new + ); + + static { + PARSER.declareInt(SemanticQueryInnerHitBuilder::setFrom, SearchSourceBuilder.FROM_FIELD); + 
PARSER.declareInt(SemanticQueryInnerHitBuilder::setSize, SearchSourceBuilder.SIZE_FIELD); + } + + private String fieldName; + private int from = DEFAULT_FROM; + private int size = DEFAULT_SIZE; + + public SemanticQueryInnerHitBuilder() { + this.fieldName = null; + } + + public SemanticQueryInnerHitBuilder(StreamInput in) throws IOException { + fieldName = in.readOptionalString(); + from = in.readVInt(); + size = in.readVInt(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(fieldName); + out.writeVInt(from); + out.writeVInt(size); + } + + public String getFieldName() { + return fieldName; + } + + public void setFieldName(String fieldName) { + this.fieldName = fieldName; + } + + public int getFrom() { + return from; + } + + public SemanticQueryInnerHitBuilder setFrom(int from) { + this.from = from; + return this; + } + + public int getSize() { + return size; + } + + public SemanticQueryInnerHitBuilder setSize(int size) { + this.size = size; + return this; + } + + public InnerHitBuilder toInnerHitBuilder() { + if (fieldName == null) { + throw new IllegalStateException("fieldName must have a value"); + } + + return new InnerHitBuilder(fieldName).setFrom(from) + .setSize(size) + .setFetchSourceContext(FetchSourceContext.of(true, null, new String[] { SemanticTextField.getEmbeddingsFieldName(fieldName) })); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + // Don't include name in XContent because it is hard-coded + builder.startObject(); + if (from != DEFAULT_FROM) { + builder.field(SearchSourceBuilder.FROM_FIELD.getPreferredName(), from); + } + if (size != DEFAULT_SIZE) { + builder.field(SearchSourceBuilder.SIZE_FIELD.getPreferredName(), size); + } + builder.endObject(); + return builder; + } + + public static SemanticQueryInnerHitBuilder fromXContent(XContentParser parser) throws IOException { + return PARSER.parse(parser, new 
SemanticQueryInnerHitBuilder(), null); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SemanticQueryInnerHitBuilder that = (SemanticQueryInnerHitBuilder) o; + return from == that.from && size == that.size && Objects.equals(fieldName, that.fieldName); + } + + @Override + public int hashCode() { + return Objects.hash(fieldName, from, size); + } + + @Override + public String toString() { + return Strings.toString(this, true, true); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java index f54ce89183079..47ac33a5cf9ab 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java @@ -31,7 +31,9 @@ import org.elasticsearch.index.mapper.ParsedDocument; import org.elasticsearch.index.mapper.SourceToParse; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.index.query.InnerHitContextBuilder; import org.elasticsearch.index.query.MatchNoneQueryBuilder; +import org.elasticsearch.index.query.NestedQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.SearchExecutionContext; @@ -62,7 +64,9 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.HashMap; import java.util.List; +import java.util.Map; import static org.apache.lucene.search.BooleanClause.Occur.FILTER; import static org.apache.lucene.search.BooleanClause.Occur.MUST; @@ -165,7 +169,14 @@ protected SemanticQueryBuilder doCreateTestQueryBuilder() { 
queryTokens.add(randomAlphaOfLength(QUERY_TOKEN_LENGTH)); } - SemanticQueryBuilder builder = new SemanticQueryBuilder(SEMANTIC_TEXT_FIELD, String.join(" ", queryTokens)); + SemanticQueryInnerHitBuilder innerHitBuilder = null; + if (randomBoolean()) { + innerHitBuilder = new SemanticQueryInnerHitBuilder(); + innerHitBuilder.setFrom(randomIntBetween(0, 100)); + innerHitBuilder.setSize(randomIntBetween(0, 100)); + } + + SemanticQueryBuilder builder = new SemanticQueryBuilder(SEMANTIC_TEXT_FIELD, String.join(" ", queryTokens), innerHitBuilder); if (randomBoolean()) { builder.boost((float) randomDoubleBetween(0.1, 10.0, true)); } @@ -190,6 +201,21 @@ protected void doAssertLuceneQuery(SemanticQueryBuilder queryBuilder, Query quer case SPARSE_EMBEDDING -> assertSparseEmbeddingLuceneQuery(nestedQuery.getChildQuery()); case TEXT_EMBEDDING -> assertTextEmbeddingLuceneQuery(nestedQuery.getChildQuery()); } + + if (queryBuilder.innerHit() != null) { + // Rewrite to a nested query + QueryBuilder rewrittenQueryBuilder = rewriteQuery(queryBuilder, createQueryRewriteContext(), createSearchExecutionContext()); + assertThat(rewrittenQueryBuilder, instanceOf(NestedQueryBuilder.class)); + + NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) rewrittenQueryBuilder; + Map innerHitInternals = new HashMap<>(); + InnerHitContextBuilder.extractInnerHits(nestedQueryBuilder, innerHitInternals); + assertThat(innerHitInternals.size(), equalTo(1)); + + InnerHitContextBuilder innerHits = innerHitInternals.get(queryBuilder.innerHit().getFieldName()); + assertNotNull(innerHits); + assertThat(innerHits.innerHitBuilder(), equalTo(queryBuilder.innerHit().toInnerHitBuilder())); + } } private void assertSparseEmbeddingLuceneQuery(Query query) { @@ -312,6 +338,20 @@ public void testToXContent() throws IOException { "query": "bar" } }""", queryBuilder); + + SemanticQueryInnerHitBuilder innerHitBuilder = new SemanticQueryInnerHitBuilder().setFrom(1).setSize(2); + queryBuilder = new 
SemanticQueryBuilder("foo", "bar", innerHitBuilder); + checkGeneratedJson(""" + { + "semantic": { + "field": "foo", + "query": "bar", + "inner_hits": { + "from": 1, + "size": 2 + } + } + }""", queryBuilder); } public void testSerializingQueryWhenNoInferenceId() throws IOException { diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/40_semantic_text_query.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/40_semantic_text_query.yml index 2070b3752791a..4d90d8faeb3f3 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/40_semantic_text_query.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/40_semantic_text_query.yml @@ -122,6 +122,147 @@ setup: - close_to: { hits.hits.0._score: { value: 3.7837332e17, error: 1e10 } } - length: { hits.hits.0._source.inference_field.inference.chunks: 2 } +--- +"Query using a sparse embedding model and inner hits": + - requires: + cluster_features: "semantic_text.inner_hits" + reason: semantic_text inner hits support added in 8.16.0 + + - skip: + features: [ "headers", "close_to" ] + + - do: + index: + index: test-sparse-index + id: doc_1 + body: + inference_field: ["inference test", "another inference test", "yet another inference test"] + non_inference_field: "non inference test" + refresh: true + + - do: + headers: + # Force JSON content type so that we use a parser that interprets the floating-point score as a double + Content-Type: application/json + search: + index: test-sparse-index + body: + query: + semantic: + field: "inference_field" + query: "inference test" + inner_hits: {} + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - close_to: { hits.hits.0._score: { value: 3.7837332e17, error: 1e10 } } + - length: { hits.hits.0._source.inference_field.inference.chunks: 3 } + - match: { hits.hits.0.inner_hits.inference_field.hits.total.value: 3 } + - 
length: { hits.hits.0.inner_hits.inference_field.hits.hits: 3 } + - match: { hits.hits.0.inner_hits.inference_field.hits.hits.0._source.text: "another inference test" } + - not_exists: hits.hits.0.inner_hits.inference_field.hits.hits.0._source.embeddings + - match: { hits.hits.0.inner_hits.inference_field.hits.hits.1._source.text: "yet another inference test" } + - not_exists: hits.hits.0.inner_hits.inference_field.hits.hits.1._source.embeddings + - match: { hits.hits.0.inner_hits.inference_field.hits.hits.2._source.text: "inference test" } + - not_exists: hits.hits.0.inner_hits.inference_field.hits.hits.2._source.embeddings + + - do: + headers: + # Force JSON content type so that we use a parser that interprets the floating-point score as a double + Content-Type: application/json + search: + index: test-sparse-index + body: + query: + semantic: + field: "inference_field" + query: "inference test" + inner_hits: { + "size": 1 + } + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - close_to: { hits.hits.0._score: { value: 3.7837332e17, error: 1e10 } } + - length: { hits.hits.0._source.inference_field.inference.chunks: 3 } + - match: { hits.hits.0.inner_hits.inference_field.hits.total.value: 3 } + - length: { hits.hits.0.inner_hits.inference_field.hits.hits: 1 } + - match: { hits.hits.0.inner_hits.inference_field.hits.hits.0._source.text: "another inference test" } + - not_exists: hits.hits.0.inner_hits.inference_field.hits.hits.0._source.embeddings + + - do: + headers: + # Force JSON content type so that we use a parser that interprets the floating-point score as a double + Content-Type: application/json + search: + index: test-sparse-index + body: + query: + semantic: + field: "inference_field" + query: "inference test" + inner_hits: { + "from": 1 + } + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - close_to: { hits.hits.0._score: { value: 3.7837332e17, error: 1e10 } } + - length: { 
hits.hits.0._source.inference_field.inference.chunks: 3 } + - match: { hits.hits.0.inner_hits.inference_field.hits.total.value: 3 } + - length: { hits.hits.0.inner_hits.inference_field.hits.hits: 2 } + - match: { hits.hits.0.inner_hits.inference_field.hits.hits.0._source.text: "yet another inference test" } + - not_exists: hits.hits.0.inner_hits.inference_field.hits.hits.0._source.embeddings + - match: { hits.hits.0.inner_hits.inference_field.hits.hits.1._source.text: "inference test" } + - not_exists: hits.hits.0.inner_hits.inference_field.hits.hits.1._source.embeddings + + - do: + headers: + # Force JSON content type so that we use a parser that interprets the floating-point score as a double + Content-Type: application/json + search: + index: test-sparse-index + body: + query: + semantic: + field: "inference_field" + query: "inference test" + inner_hits: { + "from": 1, + "size": 1 + } + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - close_to: { hits.hits.0._score: { value: 3.7837332e17, error: 1e10 } } + - length: { hits.hits.0._source.inference_field.inference.chunks: 3 } + - match: { hits.hits.0.inner_hits.inference_field.hits.total.value: 3 } + - length: { hits.hits.0.inner_hits.inference_field.hits.hits: 1 } + - match: { hits.hits.0.inner_hits.inference_field.hits.hits.0._source.text: "yet another inference test" } + - not_exists: hits.hits.0.inner_hits.inference_field.hits.hits.0._source.embeddings + + - do: + headers: + # Force JSON content type so that we use a parser that interprets the floating-point score as a double + Content-Type: application/json + search: + index: test-sparse-index + body: + query: + semantic: + field: "inference_field" + query: "inference test" + inner_hits: { + "from": 3 + } + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - close_to: { hits.hits.0._score: { value: 3.7837332e17, error: 1e10 } } + - length: { hits.hits.0._source.inference_field.inference.chunks: 3 } + - 
match: { hits.hits.0.inner_hits.inference_field.hits.total.value: 0 } # Hits total drops to zero when you page off the end + - length: { hits.hits.0.inner_hits.inference_field.hits.hits: 0 } + --- "Numeric query using a sparse embedding model": - skip: @@ -250,6 +391,147 @@ setup: - close_to: { hits.hits.0._score: { value: 1.0, error: 0.0001 } } - length: { hits.hits.0._source.inference_field.inference.chunks: 2 } +--- +"Query using a dense embedding model and inner hits": + - requires: + cluster_features: "semantic_text.inner_hits" + reason: semantic_text inner hits support added in 8.16.0 + + - skip: + features: [ "headers", "close_to" ] + + - do: + index: + index: test-dense-index + id: doc_1 + body: + inference_field: ["inference test", "another inference test", "yet another inference test"] + non_inference_field: "non inference test" + refresh: true + + - do: + headers: + # Force JSON content type so that we use a parser that interprets the floating-point score as a double + Content-Type: application/json + search: + index: test-dense-index + body: + query: + semantic: + field: "inference_field" + query: "inference test" + inner_hits: {} + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - close_to: { hits.hits.0._score: { value: 1.0, error: 0.0001 } } + - length: { hits.hits.0._source.inference_field.inference.chunks: 3 } + - match: { hits.hits.0.inner_hits.inference_field.hits.total.value: 3 } + - length: { hits.hits.0.inner_hits.inference_field.hits.hits: 3 } + - match: { hits.hits.0.inner_hits.inference_field.hits.hits.0._source.text: "inference test" } + - not_exists: hits.hits.0.inner_hits.inference_field.hits.hits.0._source.embeddings + - match: { hits.hits.0.inner_hits.inference_field.hits.hits.1._source.text: "yet another inference test" } + - not_exists: hits.hits.0.inner_hits.inference_field.hits.hits.1._source.embeddings + - match: { hits.hits.0.inner_hits.inference_field.hits.hits.2._source.text: "another inference 
test" } + - not_exists: hits.hits.0.inner_hits.inference_field.hits.hits.2._source.embeddings + + - do: + headers: + # Force JSON content type so that we use a parser that interprets the floating-point score as a double + Content-Type: application/json + search: + index: test-dense-index + body: + query: + semantic: + field: "inference_field" + query: "inference test" + inner_hits: { + "size": 1 + } + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - close_to: { hits.hits.0._score: { value: 1.0, error: 0.0001 } } + - length: { hits.hits.0._source.inference_field.inference.chunks: 3 } + - match: { hits.hits.0.inner_hits.inference_field.hits.total.value: 3 } + - length: { hits.hits.0.inner_hits.inference_field.hits.hits: 1 } + - match: { hits.hits.0.inner_hits.inference_field.hits.hits.0._source.text: "inference test" } + - not_exists: hits.hits.0.inner_hits.inference_field.hits.hits.0._source.embeddings + + - do: + headers: + # Force JSON content type so that we use a parser that interprets the floating-point score as a double + Content-Type: application/json + search: + index: test-dense-index + body: + query: + semantic: + field: "inference_field" + query: "inference test" + inner_hits: { + "from": 1 + } + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - close_to: { hits.hits.0._score: { value: 1.0, error: 0.0001 } } + - length: { hits.hits.0._source.inference_field.inference.chunks: 3 } + - match: { hits.hits.0.inner_hits.inference_field.hits.total.value: 3 } + - length: { hits.hits.0.inner_hits.inference_field.hits.hits: 2 } + - match: { hits.hits.0.inner_hits.inference_field.hits.hits.0._source.text: "yet another inference test" } + - not_exists: hits.hits.0.inner_hits.inference_field.hits.hits.0._source.embeddings + - match: { hits.hits.0.inner_hits.inference_field.hits.hits.1._source.text: "another inference test" } + - not_exists: hits.hits.0.inner_hits.inference_field.hits.hits.1._source.embeddings 
+ + - do: + headers: + # Force JSON content type so that we use a parser that interprets the floating-point score as a double + Content-Type: application/json + search: + index: test-dense-index + body: + query: + semantic: + field: "inference_field" + query: "inference test" + inner_hits: { + "from": 1, + "size": 1 + } + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - close_to: { hits.hits.0._score: { value: 1.0, error: 0.0001 } } + - length: { hits.hits.0._source.inference_field.inference.chunks: 3 } + - match: { hits.hits.0.inner_hits.inference_field.hits.total.value: 3 } + - length: { hits.hits.0.inner_hits.inference_field.hits.hits: 1 } + - match: { hits.hits.0.inner_hits.inference_field.hits.hits.0._source.text: "yet another inference test" } + - not_exists: hits.hits.0.inner_hits.inference_field.hits.hits.0._source.embeddings + + - do: + headers: + # Force JSON content type so that we use a parser that interprets the floating-point score as a double + Content-Type: application/json + search: + index: test-dense-index + body: + query: + semantic: + field: "inference_field" + query: "inference test" + inner_hits: { + "from": 3 + } + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - close_to: { hits.hits.0._score: { value: 1.0, error: 0.0001 } } + - length: { hits.hits.0._source.inference_field.inference.chunks: 3 } + - match: { hits.hits.0.inner_hits.inference_field.hits.total.value: 0 } # Hits total drops to zero when you page off the end + - length: { hits.hits.0.inner_hits.inference_field.hits.hits: 0 } + --- "Numeric query using a dense embedding model": - skip: @@ -478,6 +760,101 @@ setup: - close_to: { hits.hits.0._score: { value: 3.7837332e17, error: 1e10 } } - length: { hits.hits.0._source.inference_field.inference.chunks: 2 } +--- +"Query multiple semantic text fields with inner hits": + - requires: + cluster_features: "semantic_text.inner_hits" + reason: semantic_text inner hits support 
added in 8.16.0 + + - do: + indices.create: + index: test-multi-semantic-text-field-index + body: + mappings: + properties: + inference_field_1: + type: semantic_text + inference_id: sparse-inference-id + inference_field_2: + type: semantic_text + inference_id: sparse-inference-id + + - do: + index: + index: test-multi-semantic-text-field-index + id: doc_1 + body: + inference_field_1: [ "inference test 1", "another inference test 1" ] + inference_field_2: [ "inference test 2", "another inference test 2", "yet another inference test 2" ] + refresh: true + + - do: + search: + index: test-multi-semantic-text-field-index + body: + query: + bool: + must: + - semantic: + field: "inference_field_1" + query: "inference test" + inner_hits: { } + - semantic: + field: "inference_field_2" + query: "inference test" + inner_hits: { } + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0._source.inference_field_1.inference.chunks: 2 } + - length: { hits.hits.0._source.inference_field_2.inference.chunks: 3 } + - match: { hits.hits.0.inner_hits.inference_field_1.hits.total.value: 2 } + - length: { hits.hits.0.inner_hits.inference_field_1.hits.hits: 2 } + - match: { hits.hits.0.inner_hits.inference_field_2.hits.total.value: 3 } + - length: { hits.hits.0.inner_hits.inference_field_2.hits.hits: 3 } + +--- +"Query semantic text field in object with inner hits": + - requires: + cluster_features: "semantic_text.inner_hits" + reason: semantic_text inner hits support added in 8.16.0 + + - do: + indices.create: + index: test-semantic-text-in-object-index + body: + mappings: + properties: + container: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + + - do: + index: + index: test-semantic-text-in-object-index + id: doc_1 + body: + container.inference_field: ["inference test", "another inference test", "yet another inference test"] + refresh: true + + - do: + search: + index: 
test-semantic-text-in-object-index + body: + query: + semantic: + field: "container.inference_field" + query: "inference test" + inner_hits: {} + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - exists: hits.hits.0.inner_hits.container\.inference_field + - match: { hits.hits.0.inner_hits.container\.inference_field.hits.total.value: 3 } + - length: { hits.hits.0.inner_hits.container\.inference_field.hits.hits: 3 } + --- "Query the wrong field type": - do: @@ -839,3 +1216,41 @@ setup: - match: { error.type: "resource_not_found_exception" } - match: { error.reason: "Inference endpoint not found [invalid-inference-id]" } + +--- +"Query using inner hits with invalid args": + - requires: + cluster_features: "semantic_text.inner_hits" + reason: semantic_text inner hits support added in 8.16.0 + + - do: + catch: bad_request + search: + index: test-sparse-index + body: + query: + semantic: + field: "inference_field" + query: "inference test" + inner_hits: { + "from": -1 + } + + - match: { error.root_cause.0.type: "illegal_argument_exception" } + - match: { error.root_cause.0.reason: "illegal from value, at least 0 or higher" } + + - do: + catch: bad_request + search: + index: test-sparse-index + body: + query: + semantic: + field: "inference_field" + query: "inference test" + inner_hits: { + "size": -1 + } + + - match: { error.root_cause.0.type: "illegal_argument_exception" } + - match: { error.root_cause.0.reason: "illegal size value, at least 0 or higher" }