diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 5a026b6e1660b..6e652102c5de9 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -365,6 +365,7 @@ static TransportVersion def(int id) { public static final TransportVersion ESQL_LOOKUP_JOIN_ON_MANY_FIELDS = def(9_139_0_00); public static final TransportVersion SIMULATE_INGEST_EFFECTIVE_MAPPING = def(9_140_0_00); public static final TransportVersion RESOLVE_INDEX_MODE_ADDED = def(9_141_0_00); + public static final TransportVersion SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS = def(9_142_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/validate/query/TransportValidateQueryAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/validate/query/TransportValidateQueryAction.java index 1d9ba297c35c6..e330bef2b756b 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/validate/query/TransportValidateQueryAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/validate/query/TransportValidateQueryAction.java @@ -130,7 +130,7 @@ protected void doExecute(Task task, ValidateQueryRequest request, ActionListener } else { Rewriteable.rewriteAndFetch( request.query(), - searchService.getRewriteContext(timeProvider, resolvedIndices, null), + searchService.getRewriteContext(timeProvider, resolvedIndices, null, null), rewriteListener ); } diff --git a/server/src/main/java/org/elasticsearch/action/explain/TransportExplainAction.java b/server/src/main/java/org/elasticsearch/action/explain/TransportExplainAction.java index 1a3d4ac70831f..009234fad56a2 100644 --- a/server/src/main/java/org/elasticsearch/action/explain/TransportExplainAction.java +++ b/server/src/main/java/org/elasticsearch/action/explain/TransportExplainAction.java @@ -103,7 +103,11 @@ protected void doExecute(Task task, ExplainRequest request, ActionListener request.nowInMillis; - Rewriteable.rewriteAndFetch(request.query(), searchService.getRewriteContext(timeProvider, resolvedIndices, null), rewriteListener); + Rewriteable.rewriteAndFetch( + request.query(), + searchService.getRewriteContext(timeProvider, resolvedIndices, null, null), + rewriteListener + ); } @Override diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java index bf85075781bc8..693fc340e6189 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java @@ -552,6 +552,7 @@ public void onFailure(Exception e) { timeProvider::absoluteStartMillis, resolvedIndices, original.pointInTimeBuilder(), + shouldMinimizeRoundtrips(original), isExplain ), rewriteListener @@ -787,6 +788,7 @@ public void onFailure(Exception e) { String clusterAlias = entry.getKey(); boolean skipUnavailable = remoteClusterService.isSkipUnavailable(clusterAlias); OriginalIndices indices = entry.getValue(); + SearchRequest ccsSearchRequest = SearchRequest.subSearchRequest( parentTaskId, searchRequest, diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchShardsAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchShardsAction.java index d12847ec8bf7f..eeecda05ca5d0 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchShardsAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchShardsAction.java @@ -127,7 +127,7 @@ public void searchShards(Task task, SearchShardsRequest searchShardsRequest, Act Rewriteable.rewriteAndFetch( original, - searchService.getRewriteContext(timeProvider::absoluteStartMillis, resolvedIndices, null), + searchService.getRewriteContext(timeProvider::absoluteStartMillis, resolvedIndices, null, original.isCcsMinimizeRoundtrips()), listener.delegateFailureAndWrap((delegate, searchRequest) -> { Index[] concreteIndices = resolvedIndices.getConcreteLocalIndices(); final Set indicesAndAliases = indexNameExpressionResolver.resolveExpressions( diff --git a/server/src/main/java/org/elasticsearch/index/IndexService.java b/server/src/main/java/org/elasticsearch/index/IndexService.java index 2ab766a1253a7..c74930fb6fbe6 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexService.java +++ b/server/src/main/java/org/elasticsearch/index/IndexService.java @@ -836,6 +836,7 @@ public QueryRewriteContext newQueryRewriteContext( null, null, null, + null, false ); } diff --git a/server/src/main/java/org/elasticsearch/index/query/CoordinatorRewriteContext.java b/server/src/main/java/org/elasticsearch/index/query/CoordinatorRewriteContext.java index e0cfd86c10c6d..520cf6f6bf16f 100644 --- a/server/src/main/java/org/elasticsearch/index/query/CoordinatorRewriteContext.java +++ b/server/src/main/java/org/elasticsearch/index/query/CoordinatorRewriteContext.java @@ -120,6 +120,7 @@ public CoordinatorRewriteContext( null, null, null, + null, false ); this.dateFieldRangeInfo = dateFieldRangeInfo; diff --git a/server/src/main/java/org/elasticsearch/index/query/InterceptedQueryBuilderWrapper.java b/server/src/main/java/org/elasticsearch/index/query/InterceptedQueryBuilderWrapper.java index b1030e4a76d97..389c9bfa837af 100644 --- a/server/src/main/java/org/elasticsearch/index/query/InterceptedQueryBuilderWrapper.java +++ b/server/src/main/java/org/elasticsearch/index/query/InterceptedQueryBuilderWrapper.java @@ -98,8 +98,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws @Override public boolean equals(Object o) { if (this == o) return true; - if (o instanceof InterceptedQueryBuilderWrapper == false) return false; - return Objects.equals(queryBuilder, ((InterceptedQueryBuilderWrapper) o).queryBuilder); + if (o == null || getClass() != o.getClass()) return false; + InterceptedQueryBuilderWrapper that = (InterceptedQueryBuilderWrapper) o; + return Objects.equals(queryBuilder, that.queryBuilder); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java b/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java index bc14a31978c18..d4d445f479070 100644 --- a/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java +++ b/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java @@ -72,6 +72,7 @@ public class QueryRewriteContext { private final ResolvedIndices resolvedIndices; private final PointInTimeBuilder pit; private QueryRewriteInterceptor queryRewriteInterceptor; + private final Boolean ccsMinimizeRoundtrips; private final boolean isExplain; public QueryRewriteContext( @@ -91,6 +92,7 @@ public QueryRewriteContext( final ResolvedIndices resolvedIndices, final PointInTimeBuilder pit, final QueryRewriteInterceptor queryRewriteInterceptor, + final Boolean ccsMinimizeRoundtrips, final boolean isExplain ) { @@ -111,6 +113,7 @@ public QueryRewriteContext( this.resolvedIndices = resolvedIndices; this.pit = pit; this.queryRewriteInterceptor = queryRewriteInterceptor; + this.ccsMinimizeRoundtrips = ccsMinimizeRoundtrips; this.isExplain = isExplain; } @@ -132,6 +135,7 @@ public QueryRewriteContext(final XContentParserConfiguration parserConfiguration null, null, null, + null, false ); } @@ -142,9 +146,10 @@ public QueryRewriteContext( final LongSupplier nowInMillis, final ResolvedIndices resolvedIndices, final PointInTimeBuilder pit, - final QueryRewriteInterceptor queryRewriteInterceptor + final QueryRewriteInterceptor queryRewriteInterceptor, + final Boolean ccsMinimizeRoundtrips ) { - this(parserConfiguration, client, nowInMillis, resolvedIndices, pit, queryRewriteInterceptor, false); + this(parserConfiguration, client, nowInMillis, resolvedIndices, pit, queryRewriteInterceptor, ccsMinimizeRoundtrips, false); } public QueryRewriteContext( @@ -154,6 +159,7 @@ public QueryRewriteContext( final ResolvedIndices resolvedIndices, final PointInTimeBuilder pit, final QueryRewriteInterceptor queryRewriteInterceptor, + final Boolean ccsMinimizeRoundtrips, final boolean isExplain ) { this( @@ -173,6 +179,7 @@ public QueryRewriteContext( resolvedIndices, pit, queryRewriteInterceptor, + ccsMinimizeRoundtrips, isExplain ); } @@ -279,6 +286,10 @@ public void setMapUnmappedFieldAsString(boolean mapUnmappedFieldAsString) { this.mapUnmappedFieldAsString = mapUnmappedFieldAsString; } + public Boolean isCcsMinimizeRoundtrips() { + return ccsMinimizeRoundtrips; + } + public boolean isExplain() { return this.isExplain; } diff --git a/server/src/main/java/org/elasticsearch/index/query/SearchExecutionContext.java b/server/src/main/java/org/elasticsearch/index/query/SearchExecutionContext.java index 56e136801e128..c2ba3025b5a5f 100644 --- a/server/src/main/java/org/elasticsearch/index/query/SearchExecutionContext.java +++ b/server/src/main/java/org/elasticsearch/index/query/SearchExecutionContext.java @@ -276,6 +276,7 @@ private SearchExecutionContext( null, null, null, + null, false ); this.shardId = shardId; diff --git a/server/src/main/java/org/elasticsearch/indices/IndicesService.java b/server/src/main/java/org/elasticsearch/indices/IndicesService.java index 528601f201fee..caa076eb057f1 100644 --- a/server/src/main/java/org/elasticsearch/indices/IndicesService.java +++ b/server/src/main/java/org/elasticsearch/indices/IndicesService.java @@ -1849,9 +1849,19 @@ public QueryRewriteContext getRewriteContext( LongSupplier nowInMillis, ResolvedIndices resolvedIndices, PointInTimeBuilder pit, + final Boolean ccsMinimizeRoundtrips, final boolean isExplain ) { - return new QueryRewriteContext(parserConfig, client, nowInMillis, resolvedIndices, pit, queryRewriteInterceptor, isExplain); + return new QueryRewriteContext( + parserConfig, + client, + nowInMillis, + resolvedIndices, + pit, + queryRewriteInterceptor, + ccsMinimizeRoundtrips, + isExplain + ); } public DataRewriteContext getDataRewriteContext(LongSupplier nowInMillis) { diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java index 9cca55f2ec748..4e819ef739a2a 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchService.java +++ b/server/src/main/java/org/elasticsearch/search/SearchService.java @@ -2127,8 +2127,13 @@ private void rewriteAndFetchShardRequest(IndexShard shard, ShardSearchRequest re /** * Returns a new {@link QueryRewriteContext} with the given {@code now} provider */ - public QueryRewriteContext getRewriteContext(LongSupplier nowInMillis, ResolvedIndices resolvedIndices, PointInTimeBuilder pit) { - return getRewriteContext(nowInMillis, resolvedIndices, pit, false); + public QueryRewriteContext getRewriteContext( + LongSupplier nowInMillis, + ResolvedIndices resolvedIndices, + PointInTimeBuilder pit, + final Boolean ccsMinimizeRoundtrips + ) { + return getRewriteContext(nowInMillis, resolvedIndices, pit, ccsMinimizeRoundtrips, false); } /** @@ -2138,9 +2143,10 @@ public QueryRewriteContext getRewriteContext( LongSupplier nowInMillis, ResolvedIndices resolvedIndices, PointInTimeBuilder pit, + final Boolean ccsMinimizeRoundtrips, final boolean isExplain ) { - return indicesService.getRewriteContext(nowInMillis, resolvedIndices, pit, isExplain); + return indicesService.getRewriteContext(nowInMillis, resolvedIndices, pit, ccsMinimizeRoundtrips, isExplain); } public CoordinatorRewriteContextProvider getCoordinatorRewriteContextProvider(LongSupplier nowInMillis) { diff --git a/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java b/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java index 9f286efe28083..d20df0f2154ce 100644 --- a/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java @@ -1784,8 +1784,8 @@ protected void doWriteTo(StreamOutput out) throws IOException { NodeClient client = new NodeClient(settings, threadPool, TestProjectResolvers.alwaysThrow()); SearchService searchService = mock(SearchService.class); - when(searchService.getRewriteContext(any(), any(), any(), anyBoolean())).thenReturn( - new QueryRewriteContext(null, null, null, null, null, null) + when(searchService.getRewriteContext(any(), any(), any(), anyBoolean(), anyBoolean())).thenReturn( + new QueryRewriteContext(null, null, null, null, null, null, null) ); ClusterService clusterService = new ClusterService( settings, diff --git a/server/src/test/java/org/elasticsearch/index/query/QueryRewriteContextTests.java b/server/src/test/java/org/elasticsearch/index/query/QueryRewriteContextTests.java index b997ac4747a07..a0f88e61963c3 100644 --- a/server/src/test/java/org/elasticsearch/index/query/QueryRewriteContextTests.java +++ b/server/src/test/java/org/elasticsearch/index/query/QueryRewriteContextTests.java @@ -54,6 +54,7 @@ public void testGetTierPreference() { null, null, null, + null, false ); @@ -83,6 +84,7 @@ public void testGetTierPreference() { null, null, null, + null, false ); diff --git a/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java index be30dbe9823d4..5597d01ba0b42 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java @@ -637,6 +637,7 @@ QueryRewriteContext createQueryRewriteContext() { createMockResolvedIndices(), null, createMockQueryRewriteInterceptor(), + null, false ); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/QueryBuilderResolver.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/QueryBuilderResolver.java index ef3828a3f2fbb..1f08595c71cbb 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/QueryBuilderResolver.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/QueryBuilderResolver.java @@ -61,7 +61,7 @@ private static QueryRewriteContext queryRewriteContext(TransportActionServices s System.currentTimeMillis() ); - return services.searchService().getRewriteContext(System::currentTimeMillis, resolvedIndices, null); + return services.searchService().getRewriteContext(System::currentTimeMillis, resolvedIndices, null, null); } private static Set indexNames(LogicalPlan plan) { diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java new file mode 100644 index 0000000000000..a14048546c0fe --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java @@ -0,0 +1,258 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.search.ccs; + +import org.elasticsearch.action.search.OpenPointInTimeRequest; +import org.elasticsearch.action.search.OpenPointInTimeResponse; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.TransportOpenPointInTimeAction; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.settings.Setting; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.license.LicenseSettings; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.search.builder.PointInTimeBuilder; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.test.AbstractMultiClustersTestCase; +import org.elasticsearch.test.InternalTestCluster; +import org.elasticsearch.transport.RemoteClusterAware; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; +import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension; +import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin; +import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; +import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; + +import java.io.IOException; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse; +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.equalTo; + +public class SemanticCrossClusterSearchIT extends AbstractMultiClustersTestCase { + private static final String REMOTE_CLUSTER = "cluster_a"; + private static final String INFERENCE_FIELD = "inference_field"; + + private static final Map TEXT_EMBEDDING_SERVICE_SETTINGS_1 = Map.of( + "model", + "my_model", + "dimensions", + 256, + "similarity", + "cosine", + "api_key", + "my_api_key" + ); + + private static final Map TEXT_EMBEDDING_SERVICE_SETTINGS_2 = Map.of( + "model", + "my_model", + "dimensions", + 384, + "similarity", + "cosine", + "api_key", + "my_api_key" + ); + + @Override + protected List remoteClusterAlias() { + return List.of(REMOTE_CLUSTER); + } + + @Override + protected Map skipUnavailableForRemoteClusters() { + return Map.of(REMOTE_CLUSTER, randomBoolean()); + } + + @Override + protected boolean reuseClusters() { + return false; + } + + @Override + protected Settings nodeSettings() { + return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build(); + } + + @Override + protected Collection> nodePlugins(String clusterAlias) { + return List.of(LocalStateInferencePlugin.class, TestInferenceServicePlugin.class, FakeMlPlugin.class); + } + + public void testSemanticCrossClusterSearch() throws Exception { + Map testClusterInfo = setupTwoClusters(); + String localIndex = (String) testClusterInfo.get("local.index"); + String remoteIndex = (String) testClusterInfo.get("remote.index"); + + ModelRegistry modelRegistry = cluster(LOCAL_CLUSTER).getCurrentMasterNodeInstance(ModelRegistry.class); + SemanticQueryBuilder queryBuilder = new SemanticQueryBuilder(INFERENCE_FIELD, "foo"); + queryBuilder.setModelRegistrySupplier(() -> modelRegistry); + + SearchRequest searchRequest = new SearchRequest(localIndex, REMOTE_CLUSTER + ":" + remoteIndex); + searchRequest.source(new SearchSourceBuilder().query(queryBuilder).size(10)); + searchRequest.setCcsMinimizeRoundtrips(true); + + assertResponse(client(LOCAL_CLUSTER).search(searchRequest), response -> { + assertNotNull(response); + assertEquals(10, response.getHits().getHits().length); + }); + } + + public void testSemanticCrossClusterSearchWithPIT() throws Exception { + Map testClusterInfo = setupTwoClusters(); + String localIndex = (String) testClusterInfo.get("local.index"); + String remoteIndex = (String) testClusterInfo.get("remote.index"); + + BytesReference pitId = openPointInTime( + new String[] { localIndex, REMOTE_CLUSTER + ":" + remoteIndex }, + TimeValue.timeValueMinutes(2) + ); + + ModelRegistry modelRegistry = cluster(LOCAL_CLUSTER).getCurrentMasterNodeInstance(ModelRegistry.class); + SemanticQueryBuilder queryBuilder = new SemanticQueryBuilder(INFERENCE_FIELD, "foo"); + queryBuilder.setModelRegistrySupplier(() -> modelRegistry); + + SearchRequest searchRequest = new SearchRequest(); + searchRequest.source(new SearchSourceBuilder().query(queryBuilder).pointInTimeBuilder(new PointInTimeBuilder(pitId)).size(10)); + + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> client(LOCAL_CLUSTER).search(searchRequest).actionGet(TEST_REQUEST_TIMEOUT) + ); + assertThat(e.getMessage(), containsString("semantic query supports CCS only when ccs_minimize_roundtrips=true")); + } + + private Map setupTwoClusters(String[] localIndices, String[] remoteIndices) throws IOException { + final String localInferenceId = "local_inference_id"; + final String remoteInferenceId = "remote_inference_id"; + createInferenceEndpoint(client(LOCAL_CLUSTER), TaskType.TEXT_EMBEDDING, localInferenceId, TEXT_EMBEDDING_SERVICE_SETTINGS_1); + createInferenceEndpoint(client(REMOTE_CLUSTER), TaskType.TEXT_EMBEDDING, remoteInferenceId, TEXT_EMBEDDING_SERVICE_SETTINGS_2); + + int numShardsLocal = randomIntBetween(2, 10); + Settings localSettings = indexSettings(numShardsLocal, randomIntBetween(0, 1)).build(); + for (String localIndex : localIndices) { + assertAcked( + client(LOCAL_CLUSTER).admin() + .indices() + .prepareCreate(localIndex) + .setSettings(localSettings) + .setMapping(INFERENCE_FIELD, "type=semantic_text,inference_id=" + localInferenceId) + ); + indexDocs(client(LOCAL_CLUSTER), localIndex); + } + + int numShardsRemote = randomIntBetween(2, 10); + final InternalTestCluster remoteCluster = cluster(REMOTE_CLUSTER); + remoteCluster.ensureAtLeastNumDataNodes(randomIntBetween(1, 3)); + for (String remoteIndex : remoteIndices) { + assertAcked( + client(REMOTE_CLUSTER).admin() + .indices() + .prepareCreate(remoteIndex) + .setSettings(indexSettings(numShardsRemote, randomIntBetween(0, 1))) + .setMapping(INFERENCE_FIELD, "type=semantic_text,inference_id=" + remoteInferenceId) + ); + assertFalse( + client(REMOTE_CLUSTER).admin() + .cluster() + .prepareHealth(TEST_REQUEST_TIMEOUT, remoteIndex) + .setWaitForYellowStatus() + .setTimeout(TimeValue.timeValueSeconds(10)) + .get() + .isTimedOut() + ); + indexDocs(client(REMOTE_CLUSTER), remoteIndex); + } + + String skipUnavailableKey = Strings.format("cluster.remote.%s.skip_unavailable", REMOTE_CLUSTER); + Setting skipUnavailableSetting = cluster(REMOTE_CLUSTER).clusterService().getClusterSettings().get(skipUnavailableKey); + boolean skipUnavailable = (boolean) cluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY).clusterService() + .getClusterSettings() + .get(skipUnavailableSetting); + + Map clusterInfo = new HashMap<>(); + clusterInfo.put("local.num_shards", numShardsLocal); + clusterInfo.put("remote.num_shards", numShardsRemote); + clusterInfo.put("remote.skip_unavailable", skipUnavailable); + return clusterInfo; + } + + private Map setupTwoClusters() throws IOException { + var clusterInfo = setupTwoClusters(new String[] { "demo" }, new String[] { "prod" }); + clusterInfo.put("local.index", "demo"); + clusterInfo.put("remote.index", "prod"); + return clusterInfo; + } + + private void createInferenceEndpoint(Client client, TaskType taskType, String inferenceId, Map serviceSettings) + throws IOException { + final String service = switch (taskType) { + case TEXT_EMBEDDING -> TestDenseInferenceServiceExtension.TestInferenceService.NAME; + case SPARSE_EMBEDDING -> TestSparseInferenceServiceExtension.TestInferenceService.NAME; + default -> throw new IllegalArgumentException("Unhandled task type [" + taskType + "]"); + }; + + final BytesReference content; + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + builder.startObject(); + builder.field("service", service); + builder.field("service_settings", serviceSettings); + builder.endObject(); + + content = BytesReference.bytes(builder); + } + + PutInferenceModelAction.Request request = new PutInferenceModelAction.Request( + taskType, + inferenceId, + content, + XContentType.JSON, + TEST_REQUEST_TIMEOUT + ); + var responseFuture = client.execute(PutInferenceModelAction.INSTANCE, request); + assertThat(responseFuture.actionGet(TEST_REQUEST_TIMEOUT).getModel().getInferenceEntityId(), equalTo(inferenceId)); + } + + private int indexDocs(Client client, String index) { + int numDocs = between(5, 10); + for (int i = 0; i < numDocs; i++) { + client.prepareIndex(index).setSource(INFERENCE_FIELD, randomAlphaOfLength(10)).get(); + } + client.admin().indices().prepareRefresh(index).get(); + return numDocs; + } + + private BytesReference openPointInTime(String[] indices, TimeValue keepAlive) { + OpenPointInTimeRequest request = new OpenPointInTimeRequest(indices).keepAlive(keepAlive); + final OpenPointInTimeResponse response = client().execute(TransportOpenPointInTimeAction.TYPE, request).actionGet(); + return response.getPointInTimeId(); + } + + public static class FakeMlPlugin extends Plugin { + @Override + public List getNamedWriteables() { + return new MlInferenceNamedXContentProvider().getNamedWriteables(); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index f8fb375022abb..f99f6e8734e8a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -13,6 +13,7 @@ import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.SecretSettings; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskSettings; @@ -187,6 +188,9 @@ public static List getNamedWriteables() { namedWriteables.addAll(DeepSeekChatCompletionModel.namedWriteables()); namedWriteables.addAll(SageMakerModel.namedWriteables()); namedWriteables.addAll(SageMakerSchemas.namedWriteables()); + namedWriteables.add( + new NamedWriteableRegistry.Entry(MinimalServiceSettings.class, MinimalServiceSettings.NAME, MinimalServiceSettings::new) + ); return namedWriteables; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index c3ae4f0d9d6d6..0f88716bcc900 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -93,10 +93,13 @@ import org.elasticsearch.xpack.inference.mapper.OffsetSourceFieldMapper; import org.elasticsearch.xpack.inference.mapper.SemanticInferenceMetadataFieldsMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; +import org.elasticsearch.xpack.inference.queries.EmbeddingsProvider; +import org.elasticsearch.xpack.inference.queries.MapEmbeddingsProvider; import org.elasticsearch.xpack.inference.queries.SemanticKnnVectorQueryRewriteInterceptor; import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor; import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; import org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor; +import org.elasticsearch.xpack.inference.queries.SingleEmbeddingsProvider; import org.elasticsearch.xpack.inference.rank.random.RandomRankBuilder; import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder; import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankBuilder; @@ -428,6 +431,10 @@ public List getNamedWriteables() { entries.add(new NamedWriteableRegistry.Entry(RankDoc.class, TextSimilarityRankDoc.NAME, TextSimilarityRankDoc::new)); entries.add(new NamedWriteableRegistry.Entry(Metadata.ProjectCustom.class, ModelRegistryMetadata.TYPE, ModelRegistryMetadata::new)); entries.add(new NamedWriteableRegistry.Entry(NamedDiff.class, ModelRegistryMetadata.TYPE, ModelRegistryMetadata::readDiffFrom)); + entries.add(new NamedWriteableRegistry.Entry(EmbeddingsProvider.class, MapEmbeddingsProvider.NAME, MapEmbeddingsProvider::new)); + entries.add( + new NamedWriteableRegistry.Entry(EmbeddingsProvider.class, SingleEmbeddingsProvider.NAME, SingleEmbeddingsProvider::new) + ); return entries; } @@ -573,7 +580,15 @@ public Collection getMappedActionFilters() { } public List> getQueries() { - return List.of(new QuerySpec<>(SemanticQueryBuilder.NAME, SemanticQueryBuilder::new, SemanticQueryBuilder::fromXContent)); + return List.of(new QuerySpec<>(SemanticQueryBuilder.NAME, i -> { + SemanticQueryBuilder queryBuilder = new SemanticQueryBuilder(i); + queryBuilder.setModelRegistrySupplier(getModelRegistry()); + return queryBuilder; + }, p -> { + SemanticQueryBuilder queryBuilder = SemanticQueryBuilder.fromXContent(p); + queryBuilder.setModelRegistrySupplier(getModelRegistry()); + return queryBuilder; + })); } @Override @@ -604,7 +619,6 @@ public Map getHighlighters() { @Override public void onNodeStarted() { var registry = inferenceServiceRegistry.get(); - if (registry != null) { registry.onNodeStarted(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/EmbeddingsProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/EmbeddingsProvider.java new file mode 100644 index 0000000000000..3538c75a49ee4 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/EmbeddingsProvider.java @@ -0,0 +1,15 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.queries; + +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.inference.InferenceResults; + +public interface EmbeddingsProvider extends NamedWriteable { + InferenceResults getEmbeddings(InferenceEndpointKey key); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InferenceEndpointKey.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InferenceEndpointKey.java new file mode 100644 index 0000000000000..8f80e81f0f506 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InferenceEndpointKey.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.queries; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.MinimalServiceSettings; + +import java.io.IOException; +import java.util.Objects; + +public class InferenceEndpointKey implements Writeable { + private final String inferenceId; + private final MinimalServiceSettings serviceSettings; + + public InferenceEndpointKey(String inferenceId, MinimalServiceSettings serviceSettings) { + this.inferenceId = inferenceId; + this.serviceSettings = serviceSettings; + } + + public InferenceEndpointKey(StreamInput in) throws IOException { + this.inferenceId = in.readString(); + this.serviceSettings = in.readNamedWriteable(MinimalServiceSettings.class); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(inferenceId); + out.writeNamedWriteable(serviceSettings); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + InferenceEndpointKey that = (InferenceEndpointKey) o; + return Objects.equals(inferenceId, that.inferenceId) && Objects.equals(serviceSettings, that.serviceSettings); + } + + @Override + public int hashCode() { + return Objects.hash(inferenceId, serviceSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/MapEmbeddingsProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/MapEmbeddingsProvider.java new file mode 100644 index 0000000000000..5e6be7ad37eeb --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/MapEmbeddingsProvider.java @@ -0,0 +1,63 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.queries; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.InferenceResults; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +public class MapEmbeddingsProvider implements EmbeddingsProvider { + public static final String NAME = "map_embeddings_provider"; + + private final Map embeddings; + + public MapEmbeddingsProvider() { + this.embeddings = new HashMap<>(); + } + + public MapEmbeddingsProvider(StreamInput in) throws IOException { + this.embeddings = in.readMap(InferenceEndpointKey::new, i -> i.readNamedWriteable(InferenceResults.class)); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeMap(embeddings, StreamOutput::writeWriteable, StreamOutput::writeNamedWriteable); + } + + @Override + public InferenceResults getEmbeddings(InferenceEndpointKey key) { + return embeddings.get(key); + } + + public void addEmbeddings(InferenceEndpointKey key, InferenceResults embeddings) { + this.embeddings.put(key, embeddings); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + MapEmbeddingsProvider that = (MapEmbeddingsProvider) o; + return Objects.equals(embeddings, that.embeddings); + } + + @Override + public int hashCode() { + return Objects.hashCode(embeddings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index f12e50674e5f0..c597650d3ce37 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -8,7 +8,6 @@ package org.elasticsearch.xpack.inference.queries; import org.apache.lucene.search.Query; -import org.apache.lucene.util.SetOnce; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ResolvedIndices; @@ -25,6 +24,7 @@ import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; @@ -36,18 +36,24 @@ import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.io.IOException; import java.util.Collection; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; +import java.util.function.Supplier; import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; +// TODO: Remove noInferenceResults + public class SemanticQueryBuilder extends AbstractQueryBuilder { public static final String NAME = "semantic"; @@ -70,11 +76,12 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder inferenceResultsSupplier; - private final InferenceResults inferenceResults; + private final EmbeddingsProvider embeddingsProvider; private final boolean noInferenceResults; private final Boolean lenient; + private Supplier modelRegistrySupplier = () -> null; + public SemanticQueryBuilder(String fieldName, String query) { this(fieldName, query, null); } @@ -88,8 +95,7 @@ public SemanticQueryBuilder(String fieldName, String query, Boolean lenient) { } this.fieldName = fieldName; this.query = query; - this.inferenceResults = null; - this.inferenceResultsSupplier = null; + this.embeddingsProvider = null; this.noInferenceResults = false; this.lenient = lenient; } @@ -98,9 +104,17 @@ public SemanticQueryBuilder(StreamInput in) throws IOException { super(in); this.fieldName = in.readString(); this.query = in.readString(); - this.inferenceResults = in.readOptionalNamedWriteable(InferenceResults.class); + if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS)) { + this.embeddingsProvider = in.readOptionalNamedWriteable(EmbeddingsProvider.class); + } else { + InferenceResults inferenceResults = in.readOptionalNamedWriteable(InferenceResults.class); + if (inferenceResults != null) { + this.embeddingsProvider = new SingleEmbeddingsProvider(inferenceResults); + } else { + this.embeddingsProvider = null; + } + } this.noInferenceResults = in.readBoolean(); - this.inferenceResultsSupplier = null; if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_QUERY_LENIENT)) { this.lenient = in.readOptionalBoolean(); } else { @@ -108,34 +122,35 @@ public SemanticQueryBuilder(StreamInput in) throws IOException { } } + public void setModelRegistrySupplier(Supplier supplier) { + modelRegistrySupplier = supplier; + } + @Override protected void doWriteTo(StreamOutput out) throws IOException { - if (inferenceResultsSupplier != null) { - throw new IllegalStateException("Inference results supplier is set. Missing a rewriteAndFetch?"); - } out.writeString(fieldName); out.writeString(query); - out.writeOptionalNamedWriteable(inferenceResults); + if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS)) { + out.writeOptionalNamedWriteable(embeddingsProvider); + } else { + // TODO: Handle multiple inference IDs in a mixed-version cluster + throw new UnsupportedOperationException("Handle old transport versions"); + } out.writeBoolean(noInferenceResults); if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_QUERY_LENIENT)) { out.writeOptionalBoolean(lenient); } } - private SemanticQueryBuilder( - SemanticQueryBuilder other, - SetOnce inferenceResultsSupplier, - InferenceResults inferenceResults, - boolean noInferenceResults - ) { + private SemanticQueryBuilder(SemanticQueryBuilder other, EmbeddingsProvider embeddingsProvider, boolean noInferenceResults) { this.fieldName = other.fieldName; this.query = other.query; this.boost = other.boost; this.queryName = other.queryName; - this.inferenceResultsSupplier = inferenceResultsSupplier; - this.inferenceResults = inferenceResults; + this.embeddingsProvider = embeddingsProvider; this.noInferenceResults = noInferenceResults; this.lenient = other.lenient; + this.modelRegistrySupplier = other.modelRegistrySupplier; } @Override @@ -187,13 +202,36 @@ private QueryBuilder doRewriteBuildSemanticQuery(SearchExecutionContext searchEx if (fieldType == null) { return new MatchNoneQueryBuilder(); } else if (fieldType instanceof SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType) { - if (inferenceResults == null) { + if (embeddingsProvider == null) { // This should never happen, but throw on it in case it ever does throw new IllegalStateException( "No inference results set for [" + semanticTextFieldType.typeName() + "] field [" + fieldName + "]" ); } + ModelRegistry modelRegistry = modelRegistrySupplier.get(); + if (modelRegistry == null) { + throw new IllegalStateException("Model registry has not been set"); + } + + String inferenceId = semanticTextFieldType.getSearchInferenceId(); + MinimalServiceSettings serviceSettings = modelRegistry.getMinimalServiceSettings(inferenceId); + InferenceEndpointKey inferenceEndpointKey = new InferenceEndpointKey(inferenceId, serviceSettings); + InferenceResults inferenceResults = embeddingsProvider.getEmbeddings(inferenceEndpointKey); + + // TODO: Handle ErrorInferenceResults and WarningInferenceResults + if (inferenceResults == null) { + throw new IllegalStateException( + "No inference results set for [" + + semanticTextFieldType.typeName() + + "] field [" + + fieldName + + "] with inference ID [" + + inferenceId + + "]" + ); + } + return semanticTextFieldType.semanticQuery(inferenceResults, searchExecutionContext.requestSize(), boost(), queryName()); } else if (lenient != null && lenient) { return new MatchNoneQueryBuilder(); @@ -205,95 +243,122 @@ private QueryBuilder doRewriteBuildSemanticQuery(SearchExecutionContext searchEx } private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewriteContext) { - if (inferenceResults != null || noInferenceResults) { + // Check that we are performing a coordinator node rewrite + // TODO: Clean up how we perform this check + if (queryRewriteContext.getClass() != QueryRewriteContext.class) { return this; } - if (inferenceResultsSupplier != null) { - InferenceResults inferenceResults = validateAndConvertInferenceResults(inferenceResultsSupplier, fieldName); - return inferenceResults != null ? new SemanticQueryBuilder(this, null, inferenceResults, noInferenceResults) : this; - } - ResolvedIndices resolvedIndices = queryRewriteContext.getResolvedIndices(); if (resolvedIndices == null) { throw new IllegalStateException( "Rewriting on the coordinator node requires a query rewrite context with non-null resolved indices" ); } else if (resolvedIndices.getRemoteClusterIndices().isEmpty() == false) { - throw new IllegalArgumentException(NAME + " query does not support cross-cluster search"); + if (queryRewriteContext.isCcsMinimizeRoundtrips() != true) { + throw new IllegalArgumentException(NAME + " query supports CCS only when ccs_minimize_roundtrips=true"); + } } - String inferenceId = getInferenceIdForForField(resolvedIndices.getConcreteLocalIndicesMetadata().values(), fieldName); - SetOnce inferenceResultsSupplier = new SetOnce<>(); - boolean noInferenceResults = false; - if (inferenceId != null) { - InferenceAction.Request inferenceRequest = new InferenceAction.Request( - TaskType.ANY, - inferenceId, - null, - null, - null, - List.of(query), - Map.of(), - InputType.INTERNAL_SEARCH, - null, - false - ); - - queryRewriteContext.registerAsyncAction( - (client, listener) -> executeAsyncWithOrigin( - client, - ML_ORIGIN, - InferenceAction.INSTANCE, - inferenceRequest, - listener.delegateFailureAndWrap((l, inferenceResponse) -> { - inferenceResultsSupplier.set(inferenceResponse.getResults()); - l.onResponse(null); - }) - ) - ); + MapEmbeddingsProvider currentEmbeddingsProvider; + if (embeddingsProvider != null) { + if (embeddingsProvider instanceof MapEmbeddingsProvider mapEmbeddingsProvider) { + currentEmbeddingsProvider = mapEmbeddingsProvider; + } else { + throw new IllegalStateException("Current embeddings provider should be a MapEmbeddingsProvider"); + } } else { - // The inference ID can be null if either the field name or index name(s) are invalid (or both). - // If this happens, we set the "no inference results" flag to true so the rewrite process can continue. - // Invalid index names will be handled in the transport layer, when the query is sent to the shard. - // Invalid field names will be handled when the query is re-written on the shard, where we have access to the index mappings. - noInferenceResults = true; + currentEmbeddingsProvider = new MapEmbeddingsProvider(); } - return new SemanticQueryBuilder(this, noInferenceResults ? null : inferenceResultsSupplier, null, noInferenceResults); + boolean modified = false; + if (queryRewriteContext.hasAsyncActions() == false) { + ModelRegistry modelRegistry = modelRegistrySupplier.get(); + if (modelRegistry == null) { + throw new IllegalStateException("Model registry has not been set"); + } + + Set inferenceIds = getInferenceIdsForForField(resolvedIndices.getConcreteLocalIndicesMetadata().values(), fieldName); + for (String inferenceId : inferenceIds) { + MinimalServiceSettings serviceSettings = modelRegistry.getMinimalServiceSettings(inferenceId); + InferenceEndpointKey inferenceEndpointKey = new InferenceEndpointKey(inferenceId, serviceSettings); + + if (currentEmbeddingsProvider.getEmbeddings(inferenceEndpointKey) == null) { + InferenceAction.Request inferenceRequest = new InferenceAction.Request( + TaskType.ANY, + inferenceId, + null, + null, + null, + List.of(query), + Map.of(), + InputType.INTERNAL_SEARCH, + null, + false + ); + + queryRewriteContext.registerAsyncAction( + (client, listener) -> executeAsyncWithOrigin( + client, + ML_ORIGIN, + InferenceAction.INSTANCE, + inferenceRequest, + listener.delegateFailureAndWrap((l, inferenceResponse) -> { + currentEmbeddingsProvider.addEmbeddings( + inferenceEndpointKey, + validateAndConvertInferenceResults(inferenceResponse.getResults(), fieldName, inferenceId) + ); + l.onResponse(null); + }) + ) + ); + + modified = true; + } + } + } + + return modified ? new SemanticQueryBuilder(this, currentEmbeddingsProvider, false) : this; } private static InferenceResults validateAndConvertInferenceResults( - SetOnce inferenceResultsSupplier, - String fieldName + InferenceServiceResults inferenceServiceResults, + String fieldName, + String inferenceId ) { - InferenceServiceResults inferenceServiceResults = inferenceResultsSupplier.get(); - if (inferenceServiceResults == null) { - return null; - } - List inferenceResultsList = inferenceServiceResults.transformToCoordinationFormat(); if (inferenceResultsList.isEmpty()) { - throw new IllegalArgumentException("No inference results retrieved for field [" + fieldName + "]"); + return new ErrorInferenceResults( + new IllegalArgumentException( + "No inference results retrieved for field [" + fieldName + "] with inference ID [" + inferenceId + "]" + ) + ); } else if (inferenceResultsList.size() > 1) { // The inference call should truncate if the query is too large. // Thus, if we receive more than one inference result, it is a server-side error. - throw new IllegalStateException(inferenceResultsList.size() + " inference results retrieved for field [" + fieldName + "]"); + return new ErrorInferenceResults( + new IllegalStateException( + inferenceResultsList.size() + + " inference results retrieved for field [" + + fieldName + + "] with inference ID [" + + inferenceId + + "]" + ) + ); } InferenceResults inferenceResults = inferenceResultsList.get(0); - if (inferenceResults instanceof ErrorInferenceResults errorInferenceResults) { - throw new IllegalStateException( - "Field [" + fieldName + "] query inference error: " + errorInferenceResults.getException().getMessage(), - errorInferenceResults.getException() - ); - } else if (inferenceResults instanceof WarningInferenceResults warningInferenceResults) { - throw new IllegalStateException("Field [" + fieldName + "] query inference warning: " + warningInferenceResults.getWarning()); - } else if (inferenceResults instanceof TextExpansionResults == false - && inferenceResults instanceof MlTextEmbeddingResults == false) { - throw new IllegalArgumentException( + if (inferenceResults instanceof TextExpansionResults == false + && inferenceResults instanceof MlTextEmbeddingResults == false + && inferenceResults instanceof ErrorInferenceResults == false + && inferenceResults instanceof WarningInferenceResults == false) { + return new ErrorInferenceResults( + new IllegalArgumentException( "Field [" + fieldName + + "] with inference ID [" + + inferenceId + "] expected query inference results to be of type [" + TextExpansionResults.NAME + "] or [" @@ -301,8 +366,9 @@ private static InferenceResults validateAndConvertInferenceResults( + "], got [" + inferenceResults.getWriteableName() + "]. Has the inference endpoint configuration changed?" - ); - } + ) + ); + } return inferenceResults; } @@ -312,33 +378,30 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { throw new IllegalStateException(NAME + " should have been rewritten to another query type"); } - private static String getInferenceIdForForField(Collection indexMetadataCollection, String fieldName) { - String inferenceId = null; + private static Set getInferenceIdsForForField(Collection indexMetadataCollection, String fieldName) { + Set inferenceIds = new HashSet<>(); for (IndexMetadata indexMetadata : indexMetadataCollection) { InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(fieldName); String indexInferenceId = inferenceFieldMetadata != null ? inferenceFieldMetadata.getSearchInferenceId() : null; if (indexInferenceId != null) { - if (inferenceId != null && inferenceId.equals(indexInferenceId) == false) { - throw new IllegalArgumentException("Field [" + fieldName + "] has multiple inference IDs associated with it"); - } - - inferenceId = indexInferenceId; + inferenceIds.add(indexInferenceId); } } - return inferenceId; + return inferenceIds; } @Override protected boolean doEquals(SemanticQueryBuilder other) { return Objects.equals(fieldName, other.fieldName) && Objects.equals(query, other.query) - && Objects.equals(inferenceResults, other.inferenceResults) - && Objects.equals(inferenceResultsSupplier, other.inferenceResultsSupplier); + && Objects.equals(embeddingsProvider, other.embeddingsProvider) + && Objects.equals(noInferenceResults, other.noInferenceResults) + && Objects.equals(lenient, other.lenient); } @Override protected int doHashCode() { - return Objects.hash(fieldName, query, inferenceResults, inferenceResultsSupplier); + return Objects.hash(fieldName, query, embeddingsProvider, noInferenceResults, lenient); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SingleEmbeddingsProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SingleEmbeddingsProvider.java new file mode 100644 index 0000000000000..5b7f4535d23dc --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SingleEmbeddingsProvider.java @@ -0,0 +1,57 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.queries; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.InferenceResults; + +import java.io.IOException; +import java.util.Objects; + +public class SingleEmbeddingsProvider implements EmbeddingsProvider { + public static final String NAME = "single_embeddings_provider"; + + private final InferenceResults embeddings; + + public SingleEmbeddingsProvider(InferenceResults embeddings) { + this.embeddings = embeddings; + } + + public SingleEmbeddingsProvider(StreamInput in) throws IOException { + this.embeddings = in.readNamedWriteable(InferenceResults.class); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeNamedWriteable(embeddings); + } + + @Override + public InferenceResults getEmbeddings(InferenceEndpointKey key) { + return embeddings; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SingleEmbeddingsProvider that = (SingleEmbeddingsProvider) o; + return Objects.equals(embeddings, that.embeddings); + } + + @Override + public int hashCode() { + return Objects.hashCode(embeddings); + } +} diff --git a/x-pack/plugin/inference/src/main/plugin-metadata/entitlement-policy.yaml b/x-pack/plugin/inference/src/main/plugin-metadata/entitlement-policy.yaml index 36ac851acf1ea..cd046c4a6cb23 100644 --- a/x-pack/plugin/inference/src/main/plugin-metadata/entitlement-policy.yaml +++ b/x-pack/plugin/inference/src/main/plugin-metadata/entitlement-policy.yaml @@ -12,6 +12,7 @@ software.amazon.awssdk.http.nio.netty: io.netty.common: - outbound_network - manage_threads + - inbound_network - files: - path: "/etc/os-release" mode: "read" @@ -22,3 +23,4 @@ io.netty.common: io.netty.transport: - manage_threads - outbound_network + - inbound_network diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java index 1f0b56e3d6848..0d9669e9e2819 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java @@ -151,7 +151,7 @@ private QueryRewriteContext createQueryRewriteContext(Map