diff --git a/docs/changelog/133675.yaml b/docs/changelog/133675.yaml new file mode 100644 index 0000000000000..03962ef4aa64e --- /dev/null +++ b/docs/changelog/133675.yaml @@ -0,0 +1,5 @@ +pr: 133675 +summary: Support using the semantic query across multiple inference IDs +area: Vector Search +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 7c5ae10bbedbc..641312f10ac4c 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -356,6 +356,7 @@ static TransportVersion def(int id) { public static final TransportVersion PROJECT_RESERVED_STATE_MOVE_TO_REGISTRY = def(9_147_0_00); public static final TransportVersion STREAMS_ENDPOINT_PARAM_RESTRICTIONS = def(9_148_0_00); public static final TransportVersion RESOLVE_INDEX_MODE_FILTER = def(9_149_0_00); + public static final TransportVersion SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS = def(9_150_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/SemanticMatchTestCase.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/SemanticMatchTestCase.java index 3a4c4bfd98f84..aada75f151d66 100644 --- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/SemanticMatchTestCase.java +++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/SemanticMatchTestCase.java @@ -34,15 +34,28 @@ public abstract class SemanticMatchTestCase extends ESRestTestCase { public void testWithMultipleInferenceIds() throws IOException { assumeTrue("semantic text capability not available", EsqlCapabilities.Cap.SEMANTIC_TEXT_FIELD_CAPS.isEnabled()); + var request1 = new Request("POST", "/test-semantic1/_doc/id-1"); + request1.addParameter("refresh", "true"); + request1.setJsonEntity("{\"semantic_text_field\": \"inference test 1\"}"); + assertEquals(201, adminClient().performRequest(request1).getStatusLine().getStatusCode()); + + var request2 = new Request("POST", "/test-semantic2/_doc/id-2"); + request2.addParameter("refresh", "true"); + request2.setJsonEntity("{\"semantic_text_field\": \"inference test 2\"}"); + assertEquals(201, adminClient().performRequest(request2).getStatusLine().getStatusCode()); + String query = """ from test-semantic1,test-semantic2 | where match(semantic_text_field, "something") + | SORT semantic_text_field ASC """; - ResponseException re = expectThrows(ResponseException.class, () -> runEsqlQuery(query)); - - assertThat(re.getMessage(), containsString("Field [semantic_text_field] has multiple inference IDs associated with it")); + Map result = runEsqlQuery(query); - assertEquals(400, re.getResponse().getStatusLine().getStatusCode()); + assertResultMap( + result, + matchesList().item(matchesMap().entry("name", "semantic_text_field").entry("type", "text")), + matchesList(List.of(List.of("inference test 1"), List.of("inference test 2"))) + ); } public void testWithInferenceNotConfigured() { @@ -128,6 +141,28 @@ public void setUpIndices() throws IOException { createIndex(adminClient(), "test-semantic4", settings, mapping4); } + @Before + public void setUpSparseEmbeddingInferenceEndpoint() throws IOException { + Request request = new Request("PUT", "_inference/sparse_embedding/test_sparse_inference"); + request.setJsonEntity(""" + { + "service": "test_service", + "service_settings": { + "model": "my_model", + "api_key": "abc64" + }, + "task_settings": { + } + } + """); + try { + adminClient().performRequest(request); + } catch (ResponseException exc) { + // in case the removal failed + assertThat(exc.getResponse().getStatusLine().getStatusCode(), equalTo(400)); + } + } + @Before public void setUpTextEmbeddingInferenceEndpoint() throws IOException { Request request = new Request("PUT", "_inference/text_embedding/test_dense_inference"); @@ -155,6 +190,15 @@ public void setUpTextEmbeddingInferenceEndpoint() throws IOException { public void wipeData() throws IOException { adminClient().performRequest(new Request("DELETE", "*")); + try { + adminClient().performRequest(new Request("DELETE", "_inference/test_sparse_inference")); + } catch (ResponseException e) { + // 404 here means the endpoint was not created + if (e.getResponse().getStatusLine().getStatusCode() != 404) { + throw e; + } + } + try { adminClient().performRequest(new Request("DELETE", "_inference/test_dense_inference")); } catch (ResponseException e) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java index fd160ae10fa6f..d5f8cebba3167 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java @@ -11,6 +11,7 @@ import org.elasticsearch.features.NodeFeature; import org.elasticsearch.xpack.inference.mapper.SemanticInferenceMetadataFieldsMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; +import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder; import java.util.HashSet; @@ -82,7 +83,8 @@ public Set getTestFeatures() { SEMANTIC_QUERY_REWRITE_INTERCEPTORS_PROPAGATE_BOOST_AND_QUERY_NAME_FIX, SEMANTIC_TEXT_HIGHLIGHTING_FLAT, SEMANTIC_TEXT_SPARSE_VECTOR_INDEX_OPTIONS, - SEMANTIC_TEXT_FIELDS_CHUNKS_FORMAT + SEMANTIC_TEXT_FIELDS_CHUNKS_FORMAT, + SemanticQueryBuilder.SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS ) ); if (RERANK_SNIPPETS.isEnabled()) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index f12e50674e5f0..7a4e7ef1306f2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -8,7 +8,6 @@ package org.elasticsearch.xpack.inference.queries; import org.apache.lucene.search.Query; -import org.apache.lucene.util.SetOnce; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ResolvedIndices; @@ -16,6 +15,7 @@ import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.query.AbstractQueryBuilder; import org.elasticsearch.index.query.MatchNoneQueryBuilder; @@ -35,13 +35,17 @@ import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; +import org.elasticsearch.xpack.inference.InferenceException; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import java.io.IOException; import java.util.Collection; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; @@ -51,6 +55,11 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder { public static final String NAME = "semantic"; + public static final NodeFeature SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS = new NodeFeature("semantic_query.multiple_inference_ids"); + + // Use a placeholder inference ID that will never overlap with a real inference endpoint (user-created or internal) + private static final String PLACEHOLDER_INFERENCE_ID = "$PLACEHOLDER"; + private static final ParseField FIELD_FIELD = new ParseField("field"); private static final ParseField QUERY_FIELD = new ParseField("query"); private static final ParseField LENIENT_FIELD = new ParseField("lenient"); @@ -70,9 +79,7 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder inferenceResultsSupplier; - private final InferenceResults inferenceResults; - private final boolean noInferenceResults; + private final Map inferenceResultsMap; private final Boolean lenient; public SemanticQueryBuilder(String fieldName, String query) { @@ -80,6 +87,10 @@ public SemanticQueryBuilder(String fieldName, String query) { } public SemanticQueryBuilder(String fieldName, String query, Boolean lenient) { + this(fieldName, query, lenient, null); + } + + protected SemanticQueryBuilder(String fieldName, String query, Boolean lenient, Map inferenceResultsMap) { if (fieldName == null) { throw new IllegalArgumentException("[" + NAME + "] requires a " + FIELD_FIELD.getPreferredName() + " value"); } @@ -88,9 +99,7 @@ public SemanticQueryBuilder(String fieldName, String query, Boolean lenient) { } this.fieldName = fieldName; this.query = query; - this.inferenceResults = null; - this.inferenceResultsSupplier = null; - this.noInferenceResults = false; + this.inferenceResultsMap = inferenceResultsMap != null ? Map.copyOf(inferenceResultsMap) : null; this.lenient = lenient; } @@ -98,9 +107,13 @@ public SemanticQueryBuilder(StreamInput in) throws IOException { super(in); this.fieldName = in.readString(); this.query = in.readString(); - this.inferenceResults = in.readOptionalNamedWriteable(InferenceResults.class); - this.noInferenceResults = in.readBoolean(); - this.inferenceResultsSupplier = null; + if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS)) { + this.inferenceResultsMap = in.readOptional(i1 -> i1.readImmutableMap(i2 -> i2.readNamedWriteable(InferenceResults.class))); + } else { + InferenceResults inferenceResults = in.readOptionalNamedWriteable(InferenceResults.class); + this.inferenceResultsMap = inferenceResults != null ? buildBwcInferenceResultsMap(inferenceResults) : null; + in.readBoolean(); // Discard noInferenceResults, it is no longer necessary + } if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_QUERY_LENIENT)) { this.lenient = in.readOptionalBoolean(); } else { @@ -110,31 +123,35 @@ public SemanticQueryBuilder(StreamInput in) throws IOException { @Override protected void doWriteTo(StreamOutput out) throws IOException { - if (inferenceResultsSupplier != null) { - throw new IllegalStateException("Inference results supplier is set. Missing a rewriteAndFetch?"); - } out.writeString(fieldName); out.writeString(query); - out.writeOptionalNamedWriteable(inferenceResults); - out.writeBoolean(noInferenceResults); + if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS)) { + out.writeOptional((o, v) -> o.writeMap(v, StreamOutput::writeNamedWriteable), inferenceResultsMap); + } else { + InferenceResults inferenceResults = null; + if (inferenceResultsMap != null) { + if (inferenceResultsMap.size() > 1) { + throw new IllegalArgumentException("Cannot query multiple inference IDs in a mixed-version cluster"); + } else if (inferenceResultsMap.size() == 1) { + inferenceResults = inferenceResultsMap.values().iterator().next(); + } + } + + out.writeOptionalNamedWriteable(inferenceResults); + out.writeBoolean(inferenceResults == null); + } if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_QUERY_LENIENT)) { out.writeOptionalBoolean(lenient); } } - private SemanticQueryBuilder( - SemanticQueryBuilder other, - SetOnce inferenceResultsSupplier, - InferenceResults inferenceResults, - boolean noInferenceResults - ) { + private SemanticQueryBuilder(SemanticQueryBuilder other, Map inferenceResultsMap) { this.fieldName = other.fieldName; this.query = other.query; this.boost = other.boost; this.queryName = other.queryName; - this.inferenceResultsSupplier = inferenceResultsSupplier; - this.inferenceResults = inferenceResults; - this.noInferenceResults = noInferenceResults; + // No need to copy the map here since this is only called internally. We can safely assume that the caller will not modify the map. + this.inferenceResultsMap = inferenceResultsMap; this.lenient = other.lenient; } @@ -160,6 +177,27 @@ public static SemanticQueryBuilder fromXContent(XContentParser parser) throws IO return PARSER.apply(parser, null); } + /** + * Build an inference results map to store a single inference result that is not associated with an inference ID. + * + * @param inferenceResults The inference result + * @return An inference results map + */ + protected static Map buildBwcInferenceResultsMap(InferenceResults inferenceResults) { + return Map.of(PLACEHOLDER_INFERENCE_ID, inferenceResults); + } + + /** + * Extract an inference result not associated with an inference ID from an inference results map. Returns null if no such inference + * result exists in the map. + * + * @param inferenceResultsMap The inference results map + * @return The inference result + */ + private static InferenceResults getBwcInferenceResults(Map inferenceResultsMap) { + return inferenceResultsMap.get(PLACEHOLDER_INFERENCE_ID); + } + @Override protected void doXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(NAME); @@ -187,13 +225,31 @@ private QueryBuilder doRewriteBuildSemanticQuery(SearchExecutionContext searchEx if (fieldType == null) { return new MatchNoneQueryBuilder(); } else if (fieldType instanceof SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType) { - if (inferenceResults == null) { + if (inferenceResultsMap == null) { // This should never happen, but throw on it in case it ever does throw new IllegalStateException( "No inference results set for [" + semanticTextFieldType.typeName() + "] field [" + fieldName + "]" ); } + String inferenceId = semanticTextFieldType.getSearchInferenceId(); + InferenceResults inferenceResults = getBwcInferenceResults(inferenceResultsMap); + if (inferenceResults == null) { + inferenceResults = inferenceResultsMap.get(inferenceId); + } + + if (inferenceResults == null) { + throw new IllegalStateException( + "No inference results set for [" + + semanticTextFieldType.typeName() + + "] field [" + + fieldName + + "] with inference ID [" + + inferenceId + + "]" + ); + } + return semanticTextFieldType.semanticQuery(inferenceResults, searchExecutionContext.requestSize(), boost(), queryName()); } else if (lenient != null && lenient) { return new MatchNoneQueryBuilder(); @@ -205,15 +261,11 @@ private QueryBuilder doRewriteBuildSemanticQuery(SearchExecutionContext searchEx } private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewriteContext) { - if (inferenceResults != null || noInferenceResults) { + if (inferenceResultsMap != null) { + inferenceResultsErrorCheck(); return this; } - if (inferenceResultsSupplier != null) { - InferenceResults inferenceResults = validateAndConvertInferenceResults(inferenceResultsSupplier, fieldName); - return inferenceResults != null ? new SemanticQueryBuilder(this, null, inferenceResults, noInferenceResults) : this; - } - ResolvedIndices resolvedIndices = queryRewriteContext.getResolvedIndices(); if (resolvedIndices == null) { throw new IllegalStateException( @@ -223,10 +275,9 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu throw new IllegalArgumentException(NAME + " query does not support cross-cluster search"); } - String inferenceId = getInferenceIdForForField(resolvedIndices.getConcreteLocalIndicesMetadata().values(), fieldName); - SetOnce inferenceResultsSupplier = new SetOnce<>(); - boolean noInferenceResults = false; - if (inferenceId != null) { + Map inferenceResultsMap = new ConcurrentHashMap<>(); + Set inferenceIds = getInferenceIdsForForField(resolvedIndices.getConcreteLocalIndicesMetadata().values(), fieldName); + for (String inferenceId : inferenceIds) { InferenceAction.Request inferenceRequest = new InferenceAction.Request( TaskType.ANY, inferenceId, @@ -247,53 +298,57 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu InferenceAction.INSTANCE, inferenceRequest, listener.delegateFailureAndWrap((l, inferenceResponse) -> { - inferenceResultsSupplier.set(inferenceResponse.getResults()); + inferenceResultsMap.put( + inferenceId, + validateAndConvertInferenceResults(inferenceResponse.getResults(), fieldName, inferenceId) + ); l.onResponse(null); }) ) ); - } else { - // The inference ID can be null if either the field name or index name(s) are invalid (or both). - // If this happens, we set the "no inference results" flag to true so the rewrite process can continue. - // Invalid index names will be handled in the transport layer, when the query is sent to the shard. - // Invalid field names will be handled when the query is re-written on the shard, where we have access to the index mappings. - noInferenceResults = true; } - return new SemanticQueryBuilder(this, noInferenceResults ? null : inferenceResultsSupplier, null, noInferenceResults); + return new SemanticQueryBuilder(this, inferenceResultsMap); } private static InferenceResults validateAndConvertInferenceResults( - SetOnce inferenceResultsSupplier, - String fieldName + InferenceServiceResults inferenceServiceResults, + String fieldName, + String inferenceId ) { - InferenceServiceResults inferenceServiceResults = inferenceResultsSupplier.get(); - if (inferenceServiceResults == null) { - return null; - } - List inferenceResultsList = inferenceServiceResults.transformToCoordinationFormat(); if (inferenceResultsList.isEmpty()) { - throw new IllegalArgumentException("No inference results retrieved for field [" + fieldName + "]"); + return new ErrorInferenceResults( + new IllegalArgumentException( + "No inference results retrieved for field [" + fieldName + "] with inference ID [" + inferenceId + "]" + ) + ); } else if (inferenceResultsList.size() > 1) { - // The inference call should truncate if the query is too large. + // We don't chunk queries, so there should always be one inference result. // Thus, if we receive more than one inference result, it is a server-side error. - throw new IllegalStateException(inferenceResultsList.size() + " inference results retrieved for field [" + fieldName + "]"); + return new ErrorInferenceResults( + new IllegalStateException( + inferenceResultsList.size() + + " inference results retrieved for field [" + + fieldName + + "] with inference ID [" + + inferenceId + + "]" + ) + ); } - InferenceResults inferenceResults = inferenceResultsList.get(0); - if (inferenceResults instanceof ErrorInferenceResults errorInferenceResults) { - throw new IllegalStateException( - "Field [" + fieldName + "] query inference error: " + errorInferenceResults.getException().getMessage(), - errorInferenceResults.getException() - ); - } else if (inferenceResults instanceof WarningInferenceResults warningInferenceResults) { - throw new IllegalStateException("Field [" + fieldName + "] query inference warning: " + warningInferenceResults.getWarning()); - } else if (inferenceResults instanceof TextExpansionResults == false - && inferenceResults instanceof MlTextEmbeddingResults == false) { - throw new IllegalArgumentException( + InferenceResults inferenceResults = inferenceResultsList.getFirst(); + if (inferenceResults instanceof TextExpansionResults == false + && inferenceResults instanceof MlTextEmbeddingResults == false + && inferenceResults instanceof ErrorInferenceResults == false + && inferenceResults instanceof WarningInferenceResults == false) { + return new ErrorInferenceResults( + new IllegalArgumentException( "Field [" + fieldName + + "] with inference ID [" + + inferenceId + "] expected query inference results to be of type [" + TextExpansionResults.NAME + "] or [" @@ -301,44 +356,64 @@ private static InferenceResults validateAndConvertInferenceResults( + "], got [" + inferenceResults.getWriteableName() + "]. Has the inference endpoint configuration changed?" - ); - } + ) + ); + } return inferenceResults; } + private void inferenceResultsErrorCheck() { + for (var entry : inferenceResultsMap.entrySet()) { + String inferenceId = entry.getKey(); + InferenceResults inferenceResults = entry.getValue(); + + if (inferenceResults instanceof ErrorInferenceResults errorInferenceResults) { + // Use InferenceException here so that the status code is set by the cause + throw new InferenceException( + "Field [" + fieldName + "] with inference ID [" + inferenceId + "] query inference error", + errorInferenceResults.getException() + ); + } else if (inferenceResults instanceof WarningInferenceResults warningInferenceResults) { + throw new IllegalStateException( + "Field [" + + fieldName + + "] with inference ID [" + + inferenceId + + "] query inference warning: " + + warningInferenceResults.getWarning() + ); + } + } + } + @Override protected Query doToQuery(SearchExecutionContext context) throws IOException { throw new IllegalStateException(NAME + " should have been rewritten to another query type"); } - private static String getInferenceIdForForField(Collection indexMetadataCollection, String fieldName) { - String inferenceId = null; + private static Set getInferenceIdsForForField(Collection indexMetadataCollection, String fieldName) { + Set inferenceIds = new HashSet<>(); for (IndexMetadata indexMetadata : indexMetadataCollection) { InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(fieldName); String indexInferenceId = inferenceFieldMetadata != null ? inferenceFieldMetadata.getSearchInferenceId() : null; if (indexInferenceId != null) { - if (inferenceId != null && inferenceId.equals(indexInferenceId) == false) { - throw new IllegalArgumentException("Field [" + fieldName + "] has multiple inference IDs associated with it"); - } - - inferenceId = indexInferenceId; + inferenceIds.add(indexInferenceId); } } - return inferenceId; + return inferenceIds; } @Override protected boolean doEquals(SemanticQueryBuilder other) { return Objects.equals(fieldName, other.fieldName) && Objects.equals(query, other.query) - && Objects.equals(inferenceResults, other.inferenceResults) - && Objects.equals(inferenceResultsSupplier, other.inferenceResultsSupplier); + && Objects.equals(inferenceResultsMap, other.inferenceResultsMap); } @Override protected int doHashCode() { - return Objects.hash(fieldName, query, inferenceResults, inferenceResultsSupplier); + return Objects.hash(fieldName, query, inferenceResultsMap); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java index a7eda7112723e..eab1ba1b76767 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java @@ -20,6 +20,8 @@ import org.apache.lucene.search.Query; import org.apache.lucene.search.join.ScoreMode; import org.apache.lucene.tests.index.RandomIndexWriter; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionType; @@ -27,10 +29,14 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.common.CheckedBiConsumer; import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.compress.CompressedXContent; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.Nullable; @@ -46,6 +52,7 @@ import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.index.search.ESToParentBlockJoinQuery; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.SimilarityMeasure; @@ -55,6 +62,7 @@ import org.elasticsearch.search.vectors.SparseVectorQueryWrapper; import org.elasticsearch.test.AbstractQueryTestCase; import org.elasticsearch.test.ClusterServiceUtils; +import org.elasticsearch.test.TransportVersionUtils; import org.elasticsearch.test.client.NoOpClient; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.xcontent.XContentBuilder; @@ -79,6 +87,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.function.Supplier; @@ -86,6 +95,7 @@ import static org.apache.lucene.search.BooleanClause.Occur.FILTER; import static org.apache.lucene.search.BooleanClause.Occur.MUST; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.notNullValue; @@ -361,6 +371,103 @@ public void testIllegalValues() { } } + public void testSerializationBwc() throws IOException { + InferenceResults inferenceResults1 = new TextExpansionResults( + DEFAULT_RESULTS_FIELD, + List.of(new WeightedToken("foo", 1.0f)), + false + ); + InferenceResults inferenceResults2 = new TextExpansionResults( + DEFAULT_RESULTS_FIELD, + List.of(new WeightedToken("bar", 2.0f)), + false + ); + + // Single inference result + CheckedBiConsumer assertSingleInferenceResult = (inferenceResults, version) -> { + String fieldName = randomAlphaOfLength(5); + String query = randomAlphaOfLength(5); + + SemanticQueryBuilder originalQuery = new SemanticQueryBuilder( + fieldName, + query, + null, + Map.of(randomAlphaOfLength(5), inferenceResults) + ); + SemanticQueryBuilder bwcQuery = new SemanticQueryBuilder( + fieldName, + query, + null, + SemanticQueryBuilder.buildBwcInferenceResultsMap(inferenceResults) + ); + + try (BytesStreamOutput output = new BytesStreamOutput()) { + output.setTransportVersion(version); + output.writeNamedWriteable(originalQuery); + try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), namedWriteableRegistry())) { + in.setTransportVersion(version); + QueryBuilder deserializedQuery = in.readNamedWriteable(QueryBuilder.class); + + SemanticQueryBuilder expectedQuery = version.onOrAfter(TransportVersions.SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS) + ? originalQuery + : bwcQuery; + assertThat(deserializedQuery, equalTo(expectedQuery)); + } + } + }; + + for (int i = 0; i < 100; i++) { + TransportVersion transportVersion = TransportVersionUtils.randomVersionBetween( + random(), + TransportVersions.V_8_15_0, + TransportVersion.current() + ); + assertSingleInferenceResult.accept(inferenceResults1, transportVersion); + } + + // Multiple inference results + CheckedBiConsumer, TransportVersion, IOException> assertMultipleInferenceResults = ( + inferenceResultsList, + version) -> { + Map inferenceResultsMap = new HashMap<>(inferenceResultsList.size()); + inferenceResultsList.forEach(result -> inferenceResultsMap.put(randomAlphaOfLength(5), result)); + SemanticQueryBuilder originalQuery = new SemanticQueryBuilder( + randomAlphaOfLength(5), + randomAlphaOfLength(5), + null, + inferenceResultsMap + ); + + try (BytesStreamOutput output = new BytesStreamOutput()) { + output.setTransportVersion(version); + + if (version.onOrAfter(TransportVersions.SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS)) { + output.writeNamedWriteable(originalQuery); + try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), namedWriteableRegistry())) { + in.setTransportVersion(version); + QueryBuilder deserializedQuery = in.readNamedWriteable(QueryBuilder.class); + assertThat(deserializedQuery, equalTo(originalQuery)); + } + } else { + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> output.writeNamedWriteable(originalQuery) + ); + assertThat(e.getMessage(), containsString("Cannot query multiple inference IDs in a mixed-version cluster")); + } + } + }; + + for (int i = 0; i < 100; i++) { + TransportVersion transportVersion = TransportVersionUtils.randomVersionBetween( + random(), + TransportVersions.V_8_15_0, + TransportVersion.current() + ); + assertMultipleInferenceResults.accept(List.of(inferenceResults1, inferenceResults2), transportVersion); + } + } + public void testToXContent() throws IOException { QueryBuilder queryBuilder = new SemanticQueryBuilder("foo", "bar"); checkGeneratedJson(""" diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/40_semantic_text_query.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/40_semantic_text_query.yml index 8b1617a58fe09..0b1a611bcdf72 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/40_semantic_text_query.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/40_semantic_text_query.yml @@ -641,8 +641,29 @@ setup: --- "Query multiple indices": + - requires: + cluster_features: "semantic_query.multiple_inference_ids" + reason: Semantic query support for querying multiple inference IDs + + - do: + index: + index: test-sparse-index + id: doc_1 + body: + inference_field: [ "inference test", "another inference test" ] + non_inference_field: "non inference test" + refresh: true + + - do: + index: + index: test-dense-index + id: doc_2 + body: + inference_field: [ "inference test", "another inference test" ] + non_inference_field: "non inference test" + refresh: true + - do: - catch: bad_request search: index: - test-sparse-index @@ -653,12 +674,12 @@ setup: field: "inference_field" query: "inference test" - - match: { error.type: "illegal_argument_exception" } - - match: { error.reason: "Field [inference_field] has multiple inference IDs associated with it" } + - match: { hits.total.value: 2 } + - match: { hits.hits.0._id: "doc_1" } + - match: { hits.hits.1._id: "doc_2" } # Test wildcard resolution - do: - catch: bad_request search: index: test-* body: @@ -667,8 +688,9 @@ setup: field: "inference_field" query: "inference test" - - match: { error.type: "illegal_argument_exception" } - - match: { error.reason: "Field [inference_field] has multiple inference IDs associated with it" } + - match: { hits.total.value: 2 } + - match: { hits.hits.0._id: "doc_1" } + - match: { hits.hits.1._id: "doc_2" } # Test querying an index alias that resolves to multiple indices - do: @@ -679,7 +701,6 @@ setup: name: my-alias - do: - catch: bad_request search: index: my-alias body: @@ -688,10 +709,11 @@ setup: field: "inference_field" query: "inference test" - - match: { error.type: "illegal_argument_exception" } - - match: { error.reason: "Field [inference_field] has multiple inference IDs associated with it" } + - match: { hits.total.value: 2 } + - match: { hits.hits.0._id: "doc_1" } + - match: { hits.hits.1._id: "doc_2" } - # Test querying multiple indices that use the same inference ID - this should work + # Test querying multiple indices that use the same inference ID - do: indices.create: index: test-sparse-index-2 @@ -704,18 +726,10 @@ setup: non_inference_field: type: text - - do: - index: - index: test-sparse-index - id: doc_1 - body: - inference_field: "inference test" - refresh: true - - do: index: index: test-sparse-index-2 - id: doc_2 + id: doc_3 body: inference_field: "another inference test" refresh: true @@ -730,8 +744,8 @@ setup: query: "inference test" - match: { hits.total.value: 2 } - - match: { hits.hits.0._id: "doc_2" } - - match: { hits.hits.1._id: "doc_1" } + - match: { hits.hits.0._id: "doc_1" } + - match: { hits.hits.1._id: "doc_3" } --- "Query a field that has no indexed inference results":