Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
854bc26
Added inner hits builder to semantic query
Mikep86 Aug 12, 2024
733cae8
Pass inner hit builder to nested query builder
Mikep86 Aug 12, 2024
ada0b3a
Added InnerChunkBuilder
Mikep86 Aug 12, 2024
a8fac5f
Update InnerChunkBuilder to not inherit from InnerHitBuilder
Mikep86 Aug 12, 2024
557dc9a
Hard-code name in InnerChunkBuilder
Mikep86 Aug 12, 2024
8e94854
Updated semantic query builder tests
Mikep86 Aug 12, 2024
a860e69
Added YAML tests
Mikep86 Aug 13, 2024
dac2bd4
Resolved TODOs
Mikep86 Aug 13, 2024
dd67452
Update docs/changelog/111834.yaml
Mikep86 Aug 13, 2024
9d5fa1d
Fixed changelog
Mikep86 Aug 13, 2024
639adad
Set inner chunk builder name based on field name
Mikep86 Aug 13, 2024
4b8a62b
Add YAML test for querying multiple semantic text fields with inner c…
Mikep86 Aug 13, 2024
9311cc1
Fix YAML tests
Mikep86 Aug 13, 2024
df127b9
Rename inner_chunks to chunks
Mikep86 Aug 13, 2024
202314c
Fail the semantic query request if the transport version is not compa…
Mikep86 Aug 13, 2024
17e8edb
YAML test updates
Mikep86 Aug 13, 2024
91add83
Exclude embeddings from inner hit _source output
Mikep86 Aug 14, 2024
ae898fd
Updated YAML tests to check that embeddings are not in inner hits _so…
Mikep86 Aug 14, 2024
23d3344
Updated semantic query documentation
Mikep86 Aug 14, 2024
43b0a7f
Fix link
Mikep86 Aug 14, 2024
91f21f9
Merge branch 'main' into semantic-query_inner-hits
Mikep86 Aug 14, 2024
a5a03d9
Docs adjustments
Mikep86 Aug 14, 2024
1982eda
Fix headings
Mikep86 Aug 14, 2024
b0244f1
Merge branch 'main' into semantic-query_inner-hits
Mikep86 Aug 14, 2024
a5ee5d8
Merge branch 'main' into semantic-query_inner-hits
Mikep86 Sep 24, 2024
e28e72f
Added cluster feature for semantic text inner hits support
Mikep86 Sep 24, 2024
ee95981
Merge branch 'main' into semantic-query_inner-hits
Mikep86 Sep 24, 2024
8c73841
Rename chunks param to inner_hits
Mikep86 Sep 25, 2024
b779e29
Update documentation to address feedback and rename chunks to inner_hits
Mikep86 Sep 25, 2024
9f42742
Added reason for skipping doc tests
Mikep86 Sep 25, 2024
a19fd6e
Added "Query semantic text field in object with inner hits" YAML test
Mikep86 Sep 25, 2024
bb95eee
Merge branch 'main' into semantic-query_inner-hits
Mikep86 Sep 25, 2024
f62649d
Merge branch 'main' into semantic-query_inner-hits
Mikep86 Sep 25, 2024
3cbd7a5
PR feedback
Mikep86 Sep 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/111834.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 111834
summary: Add inner hits support to semantic query
area: Search
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_SINGLE_VALUE_QUERY_SOURCE = def(8_718_00_0);
public static final TransportVersion ESQL_ORIGINAL_INDICES = def(8_719_00_0);
public static final TransportVersion ML_INFERENCE_EIS_INTEGRATION_ADDED = def(8_720_00_0);
public static final TransportVersion SEMANTIC_QUERY_INNER_HITS = def(8_721_00_0);
/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ public final class InnerHitBuilder implements Writeable, ToXContentObject {
public static final ParseField COLLAPSE_FIELD = new ParseField("collapse");
public static final ParseField FIELD_FIELD = new ParseField("field");

public static final int DEFAULT_FROM = 0;
public static final int DEFAULT_SIZE = 3;
private static final boolean DEFAULT_IGNORE_UNAMPPED = false;
private static final int DEFAULT_FROM = 0;
private static final int DEFAULT_SIZE = 3;
private static final boolean DEFAULT_VERSION = false;
private static final boolean DEFAULT_SEQ_NO_AND_PRIMARY_TERM = false;
private static final boolean DEFAULT_EXPLAIN = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.elasticsearch.index.mapper.ValueFetcher;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper;
import org.elasticsearch.index.query.InnerHitBuilder;
import org.elasticsearch.index.query.MatchNoneQueryBuilder;
import org.elasticsearch.index.query.NestedQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
Expand All @@ -52,6 +53,7 @@
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.inference.queries.InnerChunkBuilder;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -383,7 +385,12 @@ public IndexFieldData.Builder fielddataBuilder(FieldDataContext fieldDataContext
throw new IllegalArgumentException("[semantic_text] fields do not support sorting, scripting or aggregating");
}

public QueryBuilder semanticQuery(InferenceResults inferenceResults, float boost, String queryName) {
public QueryBuilder semanticQuery(
InferenceResults inferenceResults,
float boost,
String queryName,
InnerChunkBuilder innerChunkBuilder
) {
String nestedFieldPath = getChunksFieldName(name());
String inferenceResultsFieldName = getEmbeddingsFieldName(name());
QueryBuilder childQueryBuilder;
Expand Down Expand Up @@ -459,7 +466,10 @@ public QueryBuilder semanticQuery(InferenceResults inferenceResults, float boost
};
}

return new NestedQueryBuilder(nestedFieldPath, childQueryBuilder, ScoreMode.Max).boost(boost).queryName(queryName);
InnerHitBuilder innerHitBuilder = innerChunkBuilder != null ? innerChunkBuilder.toInnerHitBuilder() : null;
return new NestedQueryBuilder(nestedFieldPath, childQueryBuilder, ScoreMode.Max).boost(boost)
.queryName(queryName)
.innerHit(innerHitBuilder);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.queries;

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.index.query.InnerHitBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;
import java.util.Objects;

import static org.elasticsearch.index.query.InnerHitBuilder.DEFAULT_FROM;
import static org.elasticsearch.index.query.InnerHitBuilder.DEFAULT_SIZE;

/**
 * Reduced set of inner-hit options ({@code from} and {@code size}) that can be attached to a semantic query.
 * <p>
 * The builder carries an optional {@code name}. The name is never parsed from nor rendered to XContent; it is
 * assigned programmatically through {@link #setName(String)} and is required only when the builder is converted
 * with {@link #toInnerHitBuilder()}.
 */
public class InnerChunkBuilder implements Writeable, ToXContentObject {
    private static final ObjectParser<InnerChunkBuilder, Void> PARSER = new ObjectParser<>("inner_chunks", InnerChunkBuilder::new);

    static {
        PARSER.declareInt(InnerChunkBuilder::setFrom, SearchSourceBuilder.FROM_FIELD);
        PARSER.declareInt(InnerChunkBuilder::setSize, SearchSourceBuilder.SIZE_FIELD);
    }

    // May be null until setName is called (e.g. right after parsing).
    private String name;
    private int from = DEFAULT_FROM;
    private int size = DEFAULT_SIZE;

    /**
     * Creates a builder with no name and the default {@code from} / {@code size} values.
     */
    public InnerChunkBuilder() {
        this.name = null;
    }

    /**
     * Reads a builder from the stream, mirroring the layout produced by {@link #writeTo(StreamOutput)}.
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    public InnerChunkBuilder(StreamInput in) throws IOException {
        // The name is optional on the wire because it may legitimately be unset.
        name = in.readOptionalString();
        from = in.readVInt();
        size = in.readVInt();
    }

    /**
     * Writes the name (as an optional string), {@code from} and {@code size} to the stream.
     *
     * @param out the stream to write to
     * @throws IOException if writing to the stream fails
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        // writeString would fail on a null name; a builder without a name must still be serializable.
        out.writeOptionalString(name);
        out.writeVInt(from);
        out.writeVInt(size);
    }

    /**
     * @return the inner hit name, or {@code null} if it has not been set
     */
    public String getName() {
        return name;
    }

    /**
     * @param name the name to use for the generated inner hit definition
     */
    public void setName(String name) {
        this.name = name;
    }

    /**
     * @return the offset of the first inner hit to return
     */
    public int getFrom() {
        return from;
    }

    /**
     * @param from the offset of the first inner hit to return
     * @return this builder
     */
    public InnerChunkBuilder setFrom(int from) {
        this.from = from;
        return this;
    }

    /**
     * @return the maximum number of inner hits to return
     */
    public int getSize() {
        return size;
    }

    /**
     * @param size the maximum number of inner hits to return
     * @return this builder
     */
    public InnerChunkBuilder setSize(int size) {
        this.size = size;
        return this;
    }

    /**
     * Converts this builder into an {@link InnerHitBuilder} carrying the same name, {@code from} and {@code size}.
     *
     * @return a new {@link InnerHitBuilder}
     * @throws IllegalStateException if no name has been set
     */
    public InnerHitBuilder toInnerHitBuilder() {
        if (name == null) {
            throw new IllegalStateException("name must have a value");
        }
        return new InnerHitBuilder(name).setFrom(from).setSize(size);
    }

    /**
     * Renders {@code from} and {@code size} as an object, omitting any value that equals its default.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        // The name is not part of the XContent representation: it is assigned via setName, not supplied in the request body
        builder.startObject();
        if (from != DEFAULT_FROM) {
            builder.field(SearchSourceBuilder.FROM_FIELD.getPreferredName(), from);
        }
        if (size != DEFAULT_SIZE) {
            builder.field(SearchSourceBuilder.SIZE_FIELD.getPreferredName(), size);
        }
        builder.endObject();
        return builder;
    }

    /**
     * Parses a builder from XContent. Only {@code from} and {@code size} are read; the name is left unset.
     *
     * @param parser the parser positioned at the object to read
     * @return the parsed builder
     * @throws IOException if parsing fails
     */
    public static InnerChunkBuilder fromXContent(XContentParser parser) throws IOException {
        return PARSER.parse(parser, new InnerChunkBuilder(), null);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        InnerChunkBuilder that = (InnerChunkBuilder) o;
        return from == that.from && size == that.size && Objects.equals(name, that.name);
    }

    @Override
    public int hashCode() {
        return Objects.hash(name, from, size);
    }

    @Override
    public String toString() {
        return Strings.toString(this, true, true);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.MatchNoneQueryBuilder;
Expand Down Expand Up @@ -44,7 +45,9 @@
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.TransportVersions.SEMANTIC_QUERY_INNER_HITS;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;

Expand All @@ -53,26 +56,33 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil

private static final ParseField FIELD_FIELD = new ParseField("field");
private static final ParseField QUERY_FIELD = new ParseField("query");
private static final ParseField INNER_CHUNKS_FIELD = new ParseField("inner_chunks");

private static final ConstructingObjectParser<SemanticQueryBuilder, Void> PARSER = new ConstructingObjectParser<>(
NAME,
false,
args -> new SemanticQueryBuilder((String) args[0], (String) args[1])
args -> new SemanticQueryBuilder((String) args[0], (String) args[1], (InnerChunkBuilder) args[2])
);

static {
PARSER.declareString(constructorArg(), FIELD_FIELD);
PARSER.declareString(constructorArg(), QUERY_FIELD);
PARSER.declareObject(optionalConstructorArg(), (p, c) -> InnerChunkBuilder.fromXContent(p), INNER_CHUNKS_FIELD);
declareStandardFields(PARSER);
}

private final String fieldName;
private final String query;
private final InnerChunkBuilder innerChunkBuilder;
private final SetOnce<InferenceServiceResults> inferenceResultsSupplier;
private final InferenceResults inferenceResults;
private final boolean noInferenceResults;

public SemanticQueryBuilder(String fieldName, String query) {
this(fieldName, query, null);
}

public SemanticQueryBuilder(String fieldName, String query, @Nullable InnerChunkBuilder innerChunkBuilder) {
if (fieldName == null) {
throw new IllegalArgumentException("[" + NAME + "] requires a " + FIELD_FIELD.getPreferredName() + " value");
}
Expand All @@ -81,15 +91,25 @@ public SemanticQueryBuilder(String fieldName, String query) {
}
this.fieldName = fieldName;
this.query = query;
this.innerChunkBuilder = innerChunkBuilder;
this.inferenceResults = null;
this.inferenceResultsSupplier = null;
this.noInferenceResults = false;

if (this.innerChunkBuilder != null) {
this.innerChunkBuilder.setName(fieldName);
}
}

public SemanticQueryBuilder(StreamInput in) throws IOException {
super(in);
this.fieldName = in.readString();
this.query = in.readString();
if (in.getTransportVersion().onOrAfter(SEMANTIC_QUERY_INNER_HITS)) {
this.innerChunkBuilder = in.readOptionalWriteable(InnerChunkBuilder::new);
} else {
this.innerChunkBuilder = null;
}
this.inferenceResults = in.readOptionalNamedWriteable(InferenceResults.class);
this.noInferenceResults = in.readBoolean();
this.inferenceResultsSupplier = null;
Expand All @@ -102,6 +122,9 @@ protected void doWriteTo(StreamOutput out) throws IOException {
}
out.writeString(fieldName);
out.writeString(query);
if (out.getTransportVersion().onOrAfter(SEMANTIC_QUERY_INNER_HITS)) {
out.writeOptionalWriteable(innerChunkBuilder);
}
out.writeOptionalNamedWriteable(inferenceResults);
out.writeBoolean(noInferenceResults);
}
Expand All @@ -114,13 +137,18 @@ private SemanticQueryBuilder(
) {
this.fieldName = other.fieldName;
this.query = other.query;
this.innerChunkBuilder = other.innerChunkBuilder;
this.boost = other.boost;
this.queryName = other.queryName;
this.inferenceResultsSupplier = inferenceResultsSupplier;
this.inferenceResults = inferenceResults;
this.noInferenceResults = noInferenceResults;
}

public InnerChunkBuilder innerChunk() {
return innerChunkBuilder;
}

@Override
public String getWriteableName() {
return NAME;
Expand All @@ -140,6 +168,9 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep
builder.startObject(NAME);
builder.field(FIELD_FIELD.getPreferredName(), fieldName);
builder.field(QUERY_FIELD.getPreferredName(), query);
if (innerChunkBuilder != null) {
builder.field(INNER_CHUNKS_FIELD.getPreferredName(), innerChunkBuilder);
}
boostAndQueryNameToXContent(builder);
builder.endObject();
}
Expand All @@ -166,7 +197,7 @@ private QueryBuilder doRewriteBuildSemanticQuery(SearchExecutionContext searchEx
);
}

return semanticTextFieldType.semanticQuery(inferenceResults, boost(), queryName());
return semanticTextFieldType.semanticQuery(inferenceResults, boost(), queryName(), innerChunkBuilder);
} else {
throw new IllegalArgumentException(
"Field [" + fieldName + "] of type [" + fieldType.typeName() + "] does not support " + NAME + " queries"
Expand Down Expand Up @@ -300,11 +331,12 @@ private static String getInferenceIdForForField(Collection<IndexMetadata> indexM
protected boolean doEquals(SemanticQueryBuilder other) {
return Objects.equals(fieldName, other.fieldName)
&& Objects.equals(query, other.query)
&& Objects.equals(innerChunkBuilder, other.innerChunkBuilder)
&& Objects.equals(inferenceResults, other.inferenceResults);
}

@Override
protected int doHashCode() {
return Objects.hash(fieldName, query, inferenceResults);
return Objects.hash(fieldName, query, innerChunkBuilder, inferenceResults);
}
}
Loading