From 482aad83a0c9830096dbbc224efa337c05c8e92f Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Fri, 18 Jul 2025 09:15:39 -0400 Subject: [PATCH 01/35] Added test and instrumentation --- .../search/ccs/CrossClusterSearchIT.java | 18 ++++++++++++++++++ .../action/search/TransportSearchAction.java | 4 ++-- .../index/query/MatchQueryBuilder.java | 11 +++++++++++ 3 files changed, 31 insertions(+), 2 deletions(-) diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java index d4f60a868dcd4..042d677837457 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java @@ -25,6 +25,7 @@ import org.elasticsearch.common.util.CollectionUtils; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.query.MatchAllQueryBuilder; +import org.elasticsearch.index.query.MatchQueryBuilder; import org.elasticsearch.index.query.RangeQueryBuilder; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.SearchPlugin; @@ -161,6 +162,23 @@ public void testClusterDetailsAfterSuccessfulCCS() throws Exception { }); } + public void testCCSQueryRewrite() throws Exception { + Map testClusterInfo = setupTwoClusters(); + String localIndex = (String) testClusterInfo.get("local.index"); + String remoteIndex = (String) testClusterInfo.get("remote.index"); + int localNumShards = (Integer) testClusterInfo.get("local.num_shards"); + int remoteNumShards = (Integer) testClusterInfo.get("remote.num_shards"); + + SearchRequest searchRequest = new SearchRequest(localIndex, REMOTE_CLUSTER + ":" + remoteIndex); + //searchRequest.setCcsMinimizeRoundtrips(false); + + searchRequest.source(new SearchSourceBuilder().query(new MatchQueryBuilder("foo", "bar")).size(10)); + + assertResponse(client(LOCAL_CLUSTER).search(searchRequest), response -> { + assertNotNull(response); + }); + } + // CCS with a search where the timestamp of the query cannot match so should be SUCCESSFUL with all shards skipped // during can-match public void testCCSClusterDetailsWhereAllShardsSkippedInCanMatch() throws Exception { diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java index 69260bcac105c..2c3c077b2a940 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java @@ -732,7 +732,7 @@ public void onFailure(Exception e) { OriginalIndices indices = entry.getValue(); SearchRequest ccsSearchRequest = SearchRequest.subSearchRequest( parentTaskId, - searchRequest, + searchRequest, // TODO: Need to prep the request here by stripping inference results? indices.indices(), clusterAlias, timeProvider.absoluteStartMillis(), @@ -859,7 +859,7 @@ Map createFinalResponse() { SearchShardsRequest searchShardsRequest = new SearchShardsRequest( indices, indicesOptions, - query, + query, // TODO: Need to prep the query here by stripping inference results? routing, preference, allowPartialResults, diff --git a/server/src/main/java/org/elasticsearch/index/query/MatchQueryBuilder.java b/server/src/main/java/org/elasticsearch/index/query/MatchQueryBuilder.java index fd704d39ca384..eb5f1bbbc3230 100644 --- a/server/src/main/java/org/elasticsearch/index/query/MatchQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/MatchQueryBuilder.java @@ -14,6 +14,7 @@ import org.apache.lucene.search.Query; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ResolvedIndices; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -364,6 +365,16 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio builder.endObject(); } + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + ResolvedIndices resolvedIndices = queryRewriteContext.getResolvedIndices(); + if (resolvedIndices != null) { + return this; + } else { + return super.doRewrite(queryRewriteContext); + } + } + @Override protected QueryBuilder doIndexMetadataRewrite(QueryRewriteContext context) throws IOException { if (fuzziness != null || lenient) { From 43174f14d6b51ec665329fe87466732ca9727551 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Fri, 18 Jul 2025 15:01:55 -0400 Subject: [PATCH 02/35] Store original query in InterceptedQueryBuilderWrapper --- .../index/query/AbstractQueryBuilder.java | 2 +- .../index/query/InnerHitContextBuilder.java | 2 +- .../query/InterceptedQueryBuilderWrapper.java | 43 +++++++++++-------- .../InterceptedQueryBuilderWrapperTests.java | 8 ++-- ...KnnVectorQueryRewriteInterceptorTests.java | 4 +- ...nticMatchQueryRewriteInterceptorTests.java | 4 +- ...rseVectorQueryRewriteInterceptorTests.java | 8 ++-- 7 files changed, 39 insertions(+), 32 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/query/AbstractQueryBuilder.java b/server/src/main/java/org/elasticsearch/index/query/AbstractQueryBuilder.java index 05262798bac2a..ccb3e076c5798 100644 --- a/server/src/main/java/org/elasticsearch/index/query/AbstractQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/AbstractQueryBuilder.java @@ -283,7 +283,7 @@ public final QueryBuilder rewrite(QueryRewriteContext queryRewriteContext) throw if (queryRewriteInterceptor != null) { var rewritten = queryRewriteInterceptor.interceptAndRewrite(queryRewriteContext, this); if (rewritten != this) { - return new InterceptedQueryBuilderWrapper(rewritten); + return new InterceptedQueryBuilderWrapper(rewritten, this); } } diff --git a/server/src/main/java/org/elasticsearch/index/query/InnerHitContextBuilder.java b/server/src/main/java/org/elasticsearch/index/query/InnerHitContextBuilder.java index 31bc7dddacb7f..947c862d5b2b4 100644 --- a/server/src/main/java/org/elasticsearch/index/query/InnerHitContextBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/InnerHitContextBuilder.java @@ -68,7 +68,7 @@ public static void extractInnerHits(QueryBuilder query, Map) query).extractInnerHitBuilders(innerHitBuilders); } else if (query instanceof InterceptedQueryBuilderWrapper interceptedQuery) { // Unwrap an intercepted query here - extractInnerHits(interceptedQuery.queryBuilder, innerHitBuilders); + extractInnerHits(interceptedQuery.rewritten, innerHitBuilders); } else { throw new IllegalStateException( "provided query builder [" + query.getClass() + "] class should inherit from AbstractQueryBuilder, but it doesn't" diff --git a/server/src/main/java/org/elasticsearch/index/query/InterceptedQueryBuilderWrapper.java b/server/src/main/java/org/elasticsearch/index/query/InterceptedQueryBuilderWrapper.java index b1030e4a76d97..9c72675e476db 100644 --- a/server/src/main/java/org/elasticsearch/index/query/InterceptedQueryBuilderWrapper.java +++ b/server/src/main/java/org/elasticsearch/index/query/InterceptedQueryBuilderWrapper.java @@ -24,11 +24,17 @@ */ class InterceptedQueryBuilderWrapper implements QueryBuilder { - protected final QueryBuilder queryBuilder; + protected final QueryBuilder original; + protected final QueryBuilder rewritten; - InterceptedQueryBuilderWrapper(QueryBuilder queryBuilder) { + InterceptedQueryBuilderWrapper(QueryBuilder rewritten, QueryBuilder original) { super(); - this.queryBuilder = queryBuilder; + this.original = original; + this.rewritten = rewritten; + } + + public QueryBuilder getOriginal() { + return original; } @Override @@ -36,8 +42,8 @@ public QueryBuilder rewrite(QueryRewriteContext queryRewriteContext) throws IOEx QueryRewriteInterceptor queryRewriteInterceptor = queryRewriteContext.getQueryRewriteInterceptor(); try { queryRewriteContext.setQueryRewriteInterceptor(null); - QueryBuilder rewritten = queryBuilder.rewrite(queryRewriteContext); - return rewritten != queryBuilder ? new InterceptedQueryBuilderWrapper(rewritten) : this; + QueryBuilder rewritten = this.rewritten.rewrite(queryRewriteContext); + return rewritten != this.rewritten ? new InterceptedQueryBuilderWrapper(rewritten, original) : this; } finally { queryRewriteContext.setQueryRewriteInterceptor(queryRewriteInterceptor); } @@ -45,65 +51,66 @@ public QueryBuilder rewrite(QueryRewriteContext queryRewriteContext) throws IOEx @Override public String getWriteableName() { - return queryBuilder.getWriteableName(); + return rewritten.getWriteableName(); } @Override public TransportVersion getMinimalSupportedVersion() { - return queryBuilder.getMinimalSupportedVersion(); + return rewritten.getMinimalSupportedVersion(); } @Override public Query toQuery(SearchExecutionContext context) throws IOException { - return queryBuilder.toQuery(context); + return rewritten.toQuery(context); } @Override public QueryBuilder queryName(String queryName) { - queryBuilder.queryName(queryName); + rewritten.queryName(queryName); return this; } @Override public String queryName() { - return queryBuilder.queryName(); + return rewritten.queryName(); } @Override public float boost() { - return queryBuilder.boost(); + return rewritten.boost(); } @Override public QueryBuilder boost(float boost) { - queryBuilder.boost(boost); + rewritten.boost(boost); return this; } @Override public String getName() { - return queryBuilder.getName(); + return rewritten.getName(); } @Override public void writeTo(StreamOutput out) throws IOException { - queryBuilder.writeTo(out); + rewritten.writeTo(out); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - return queryBuilder.toXContent(builder, params); + return rewritten.toXContent(builder, params); } @Override public boolean equals(Object o) { if (this == o) return true; - if (o instanceof InterceptedQueryBuilderWrapper == false) return false; - return Objects.equals(queryBuilder, ((InterceptedQueryBuilderWrapper) o).queryBuilder); + if (o == null || getClass() != o.getClass()) return false; + InterceptedQueryBuilderWrapper that = (InterceptedQueryBuilderWrapper) o; + return Objects.equals(original, that.original) && Objects.equals(rewritten, that.rewritten); } @Override public int hashCode() { - return Objects.hashCode(queryBuilder); + return Objects.hash(original, rewritten); } } diff --git a/server/src/test/java/org/elasticsearch/index/query/InterceptedQueryBuilderWrapperTests.java b/server/src/test/java/org/elasticsearch/index/query/InterceptedQueryBuilderWrapperTests.java index 6c570e0e71725..7a8a580bbbca9 100644 --- a/server/src/test/java/org/elasticsearch/index/query/InterceptedQueryBuilderWrapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/query/InterceptedQueryBuilderWrapperTests.java @@ -36,7 +36,7 @@ public void cleanup() { public void testQueryNameReturnsWrappedQueryBuilder() { MatchAllQueryBuilder matchAllQueryBuilder = new MatchAllQueryBuilder(); - InterceptedQueryBuilderWrapper interceptedQueryBuilderWrapper = new InterceptedQueryBuilderWrapper(matchAllQueryBuilder); + InterceptedQueryBuilderWrapper interceptedQueryBuilderWrapper = new InterceptedQueryBuilderWrapper(matchAllQueryBuilder, matchAllQueryBuilder); String queryName = randomAlphaOfLengthBetween(5, 10); QueryBuilder namedQuery = interceptedQueryBuilderWrapper.queryName(queryName); assertTrue(namedQuery instanceof InterceptedQueryBuilderWrapper); @@ -45,7 +45,7 @@ public void testQueryNameReturnsWrappedQueryBuilder() { public void testQueryBoostReturnsWrappedQueryBuilder() { MatchAllQueryBuilder matchAllQueryBuilder = new MatchAllQueryBuilder(); - InterceptedQueryBuilderWrapper interceptedQueryBuilderWrapper = new InterceptedQueryBuilderWrapper(matchAllQueryBuilder); + InterceptedQueryBuilderWrapper interceptedQueryBuilderWrapper = new InterceptedQueryBuilderWrapper(matchAllQueryBuilder, matchAllQueryBuilder); float boost = randomFloat(); QueryBuilder boostedQuery = interceptedQueryBuilderWrapper.boost(boost); assertTrue(boostedQuery instanceof InterceptedQueryBuilderWrapper); @@ -65,8 +65,8 @@ public void testRewrite() throws IOException { MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder("field", "value"); rewritten = matchQueryBuilder.rewrite(context); assertTrue(rewritten instanceof InterceptedQueryBuilderWrapper); - assertTrue(((InterceptedQueryBuilderWrapper) rewritten).queryBuilder instanceof MatchQueryBuilder); - MatchQueryBuilder rewrittenMatchQueryBuilder = (MatchQueryBuilder) ((InterceptedQueryBuilderWrapper) rewritten).queryBuilder; + assertTrue(((InterceptedQueryBuilderWrapper) rewritten).rewritten instanceof MatchQueryBuilder); + MatchQueryBuilder rewrittenMatchQueryBuilder = (MatchQueryBuilder) ((InterceptedQueryBuilderWrapper) rewritten).rewritten; assertEquals("intercepted", rewrittenMatchQueryBuilder.value()); // An additional rewrite on an already intercepted query returns the same query diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java index 270cdba6d3469..0169b4ad5d87d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java @@ -82,8 +82,8 @@ private void testRewrittenInferenceQuery(QueryRewriteContext context, KnnVectorQ rewritten instanceof InterceptedQueryBuilderWrapper ); InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten; - assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder); - NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder; + assertTrue(intercepted.rewritten instanceof NestedQueryBuilder); + NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.rewritten; assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path()); QueryBuilder innerQuery = nestedQueryBuilder.query(); assertTrue(innerQuery instanceof KnnVectorQueryBuilder); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java index 6987ef33ed63d..108df54b627a2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java @@ -62,8 +62,8 @@ public void testMatchQueryOnInferenceFieldIsInterceptedAndRewrittenToSemanticQue rewritten instanceof InterceptedQueryBuilderWrapper ); InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten; - assertTrue(intercepted.queryBuilder instanceof SemanticQueryBuilder); - SemanticQueryBuilder semanticQueryBuilder = (SemanticQueryBuilder) intercepted.queryBuilder; + assertTrue(intercepted.rewritten instanceof SemanticQueryBuilder); + SemanticQueryBuilder semanticQueryBuilder = (SemanticQueryBuilder) intercepted.rewritten; assertEquals(FIELD_NAME, semanticQueryBuilder.getFieldName()); assertEquals(VALUE, semanticQueryBuilder.getQuery()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticSparseVectorQueryRewriteInterceptorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticSparseVectorQueryRewriteInterceptorTests.java index 075955766a0a9..8d5452b28cc8d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticSparseVectorQueryRewriteInterceptorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticSparseVectorQueryRewriteInterceptorTests.java @@ -64,8 +64,8 @@ public void testSparseVectorQueryOnInferenceFieldIsInterceptedAndRewritten() thr rewritten instanceof InterceptedQueryBuilderWrapper ); InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten; - assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder); - NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder; + assertTrue(intercepted.rewritten instanceof NestedQueryBuilder); + NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.rewritten; assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path()); QueryBuilder innerQuery = nestedQueryBuilder.query(); assertTrue(innerQuery instanceof SparseVectorQueryBuilder); @@ -88,8 +88,8 @@ public void testSparseVectorQueryOnInferenceFieldWithoutInferenceIdIsIntercepted rewritten instanceof InterceptedQueryBuilderWrapper ); InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten; - assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder); - NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder; + assertTrue(intercepted.rewritten instanceof NestedQueryBuilder); + NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.rewritten; assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path()); QueryBuilder innerQuery = nestedQueryBuilder.query(); assertTrue(innerQuery instanceof SparseVectorQueryBuilder); From a922ed500614add4214305c293d133b86e91c720 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Fri, 18 Jul 2025 15:02:16 -0400 Subject: [PATCH 03/35] Spotless --- .../elasticsearch/search/ccs/CrossClusterSearchIT.java | 6 ++---- .../query/InterceptedQueryBuilderWrapperTests.java | 10 ++++++++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java index 042d677837457..4c7f402c6636e 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java @@ -170,13 +170,11 @@ public void testCCSQueryRewrite() throws Exception { int remoteNumShards = (Integer) testClusterInfo.get("remote.num_shards"); SearchRequest searchRequest = new SearchRequest(localIndex, REMOTE_CLUSTER + ":" + remoteIndex); - //searchRequest.setCcsMinimizeRoundtrips(false); + // searchRequest.setCcsMinimizeRoundtrips(false); searchRequest.source(new SearchSourceBuilder().query(new MatchQueryBuilder("foo", "bar")).size(10)); - assertResponse(client(LOCAL_CLUSTER).search(searchRequest), response -> { - assertNotNull(response); - }); + assertResponse(client(LOCAL_CLUSTER).search(searchRequest), response -> { assertNotNull(response); }); } // CCS with a search where the timestamp of the query cannot match so should be SUCCESSFUL with all shards skipped diff --git a/server/src/test/java/org/elasticsearch/index/query/InterceptedQueryBuilderWrapperTests.java b/server/src/test/java/org/elasticsearch/index/query/InterceptedQueryBuilderWrapperTests.java index 7a8a580bbbca9..891a3b850bf56 100644 --- a/server/src/test/java/org/elasticsearch/index/query/InterceptedQueryBuilderWrapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/query/InterceptedQueryBuilderWrapperTests.java @@ -36,7 +36,10 @@ public void cleanup() { public void testQueryNameReturnsWrappedQueryBuilder() { MatchAllQueryBuilder matchAllQueryBuilder = new MatchAllQueryBuilder(); - InterceptedQueryBuilderWrapper interceptedQueryBuilderWrapper = new InterceptedQueryBuilderWrapper(matchAllQueryBuilder, matchAllQueryBuilder); + InterceptedQueryBuilderWrapper interceptedQueryBuilderWrapper = new InterceptedQueryBuilderWrapper( + matchAllQueryBuilder, + matchAllQueryBuilder + ); String queryName = randomAlphaOfLengthBetween(5, 10); QueryBuilder namedQuery = interceptedQueryBuilderWrapper.queryName(queryName); assertTrue(namedQuery instanceof InterceptedQueryBuilderWrapper); @@ -45,7 +48,10 @@ public void testQueryNameReturnsWrappedQueryBuilder() { public void testQueryBoostReturnsWrappedQueryBuilder() { MatchAllQueryBuilder matchAllQueryBuilder = new MatchAllQueryBuilder(); - InterceptedQueryBuilderWrapper interceptedQueryBuilderWrapper = new InterceptedQueryBuilderWrapper(matchAllQueryBuilder, matchAllQueryBuilder); + InterceptedQueryBuilderWrapper interceptedQueryBuilderWrapper = new InterceptedQueryBuilderWrapper( + matchAllQueryBuilder, + matchAllQueryBuilder + ); float boost = randomFloat(); QueryBuilder boostedQuery = interceptedQueryBuilderWrapper.boost(boost); assertTrue(boostedQuery instanceof InterceptedQueryBuilderWrapper); From d9f37b74e1eda9ca546622328564acb9630b5017 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Fri, 18 Jul 2025 15:37:25 -0400 Subject: [PATCH 04/35] Added SemanticCrossClusterSearchIT --- .../ccs/SemanticCrossClusterSearchIT.java | 178 ++++++++++++++++++ 1 file changed, 178 insertions(+) create mode 100644 x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java new file mode 100644 index 0000000000000..8a842ec3973fb --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java @@ -0,0 +1,178 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.search.ccs; + +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.Setting; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.CollectionUtils; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.license.LicenseSettings; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.test.AbstractMultiClustersTestCase; +import org.elasticsearch.test.InternalTestCluster; +import org.elasticsearch.transport.RemoteClusterAware; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; +import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; +import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension; +import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; + +import java.io.IOException; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.hamcrest.CoreMatchers.equalTo; + +public class SemanticCrossClusterSearchIT extends AbstractMultiClustersTestCase { + private static final String REMOTE_CLUSTER = "cluster_a"; + private static final String INFERENCE_FIELD = "inference_field"; + + private static final Map BBQ_COMPATIBLE_SERVICE_SETTINGS = Map.of( + "model", + "my_model", + "dimensions", + 256, + "similarity", + "cosine", + "api_key", + "my_api_key" + ); + + @Override + protected List remoteClusterAlias() { + return List.of(REMOTE_CLUSTER); + } + + @Override + protected Map skipUnavailableForRemoteClusters() { + return Map.of(REMOTE_CLUSTER, randomBoolean()); + } + + @Override + protected boolean reuseClusters() { + return false; + } + + @Override + protected Settings nodeSettings() { + return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build(); + } + + @Override + protected Collection> nodePlugins(String clusterAlias) { + return CollectionUtils.appendToCopy(super.nodePlugins(clusterAlias), LocalStateInferencePlugin.class); + } + + private Map setupTwoClusters(String[] localIndices, String[] remoteIndices) throws IOException { + final String inferenceId = "test_inference_id"; + createInferenceEndpoint(client(LOCAL_CLUSTER), TaskType.TEXT_EMBEDDING, inferenceId, BBQ_COMPATIBLE_SERVICE_SETTINGS); + createInferenceEndpoint(client(REMOTE_CLUSTER), TaskType.TEXT_EMBEDDING, inferenceId, BBQ_COMPATIBLE_SERVICE_SETTINGS); + + int numShardsLocal = randomIntBetween(2, 10); + Settings localSettings = indexSettings(numShardsLocal, randomIntBetween(0, 1)).build(); + for (String localIndex : localIndices) { + assertAcked( + client(LOCAL_CLUSTER).admin() + .indices() + .prepareCreate(localIndex) + .setSettings(localSettings) + .setMapping(INFERENCE_FIELD, "type=semantic_text", "f", "type=text,inference_id=" + inferenceId) + ); + indexDocs(client(LOCAL_CLUSTER), localIndex); + } + + int numShardsRemote = randomIntBetween(2, 10); + final InternalTestCluster remoteCluster = cluster(REMOTE_CLUSTER); + remoteCluster.ensureAtLeastNumDataNodes(randomIntBetween(1, 3)); + for (String remoteIndex : remoteIndices) { + assertAcked( + client(REMOTE_CLUSTER).admin() + .indices() + .prepareCreate(remoteIndex) + .setSettings(indexSettings(numShardsRemote, randomIntBetween(0, 1))) + .setMapping(INFERENCE_FIELD, "type=semantic_text", "f", "type=text,inference_id=" + inferenceId) + ); + assertFalse( + client(REMOTE_CLUSTER).admin() + .cluster() + .prepareHealth(TEST_REQUEST_TIMEOUT, remoteIndex) + .setWaitForYellowStatus() + .setTimeout(TimeValue.timeValueSeconds(10)) + .get() + .isTimedOut() + ); + indexDocs(client(REMOTE_CLUSTER), remoteIndex); + } + + String skipUnavailableKey = Strings.format("cluster.remote.%s.skip_unavailable", REMOTE_CLUSTER); + Setting skipUnavailableSetting = cluster(REMOTE_CLUSTER).clusterService().getClusterSettings().get(skipUnavailableKey); + boolean skipUnavailable = (boolean) cluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY).clusterService() + .getClusterSettings() + .get(skipUnavailableSetting); + + Map clusterInfo = new HashMap<>(); + clusterInfo.put("local.num_shards", numShardsLocal); + clusterInfo.put("remote.num_shards", numShardsRemote); + clusterInfo.put("remote.skip_unavailable", skipUnavailable); + return clusterInfo; + } + + private Map setupTwoClusters() throws IOException { + var clusterInfo = setupTwoClusters(new String[] { "demo" }, new String[] { "prod" }); + clusterInfo.put("local.index", "demo"); + clusterInfo.put("remote.index", "prod"); + return clusterInfo; + } + + private void createInferenceEndpoint(Client client, TaskType taskType, String inferenceId, Map serviceSettings) + throws IOException { + final String service = switch (taskType) { + case TEXT_EMBEDDING -> TestDenseInferenceServiceExtension.TestInferenceService.NAME; + case SPARSE_EMBEDDING -> TestSparseInferenceServiceExtension.TestInferenceService.NAME; + default -> throw new IllegalArgumentException("Unhandled task type [" + taskType + "]"); + }; + + final BytesReference content; + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + builder.startObject(); + builder.field("service", service); + builder.field("service_settings", serviceSettings); + builder.endObject(); + + content = BytesReference.bytes(builder); + } + + PutInferenceModelAction.Request request = new PutInferenceModelAction.Request( + taskType, + inferenceId, + content, + XContentType.JSON, + TEST_REQUEST_TIMEOUT + ); + var responseFuture = client.execute(PutInferenceModelAction.INSTANCE, request); + assertThat(responseFuture.actionGet(TEST_REQUEST_TIMEOUT).getModel().getInferenceEntityId(), equalTo(inferenceId)); + } + + private int indexDocs(Client client, String index) { + int numDocs = between(5, 10); + for (int i = 0; i < numDocs; i++) { + client.prepareIndex(index).setSource(INFERENCE_FIELD, randomAlphaOfLength(10)).get(); + } + client.admin().indices().prepareRefresh(index).get(); + return numDocs; + } +} From b94b41c135f2ab94472f82e0310a635e388d75b4 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Fri, 18 Jul 2025 15:47:59 -0400 Subject: [PATCH 05/35] test development --- .../ccs/SemanticCrossClusterSearchIT.java | 27 ++++++++++++++++--- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java index 8a842ec3973fb..c28fd05dd268e 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java @@ -7,16 +7,17 @@ package org.elasticsearch.search.ccs; +import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.CollectionUtils; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.TaskType; import org.elasticsearch.license.LicenseSettings; import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.test.AbstractMultiClustersTestCase; import org.elasticsearch.test.InternalTestCluster; import org.elasticsearch.transport.RemoteClusterAware; @@ -26,7 +27,9 @@ import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension; +import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin; import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; +import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; import java.io.IOException; import java.util.Collection; @@ -35,6 +38,7 @@ import java.util.Map; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse; import static org.hamcrest.CoreMatchers.equalTo; public class SemanticCrossClusterSearchIT extends AbstractMultiClustersTestCase { @@ -74,7 +78,22 @@ protected Settings nodeSettings() { @Override protected Collection> nodePlugins(String clusterAlias) { - return CollectionUtils.appendToCopy(super.nodePlugins(clusterAlias), LocalStateInferencePlugin.class); + return List.of(LocalStateInferencePlugin.class, TestInferenceServicePlugin.class); + } + + public void testSemanticCrossClusterSearch() throws Exception { + Map testClusterInfo = setupTwoClusters(); + String localIndex = (String) testClusterInfo.get("local.index"); + String remoteIndex = (String) testClusterInfo.get("remote.index"); + + SearchRequest searchRequest = new SearchRequest(localIndex, REMOTE_CLUSTER + ":" + remoteIndex); + searchRequest.source(new SearchSourceBuilder().query(new SemanticQueryBuilder(INFERENCE_FIELD, "foo")).size(10)); + // searchRequest.setCcsMinimizeRoundtrips(false); + + assertResponse(client(LOCAL_CLUSTER).search(searchRequest), response -> { + assertNotNull(response); + assertEquals(10, response.getHits().getHits().length); + }); } private Map setupTwoClusters(String[] localIndices, String[] remoteIndices) throws IOException { @@ -90,7 +109,7 @@ private Map setupTwoClusters(String[] localIndices, String[] rem .indices() .prepareCreate(localIndex) .setSettings(localSettings) - .setMapping(INFERENCE_FIELD, "type=semantic_text", "f", "type=text,inference_id=" + inferenceId) + .setMapping(INFERENCE_FIELD, "type=semantic_text,inference_id=" + inferenceId) ); indexDocs(client(LOCAL_CLUSTER), localIndex); } @@ -104,7 +123,7 @@ private Map setupTwoClusters(String[] localIndices, String[] rem .indices() .prepareCreate(remoteIndex) .setSettings(indexSettings(numShardsRemote, randomIntBetween(0, 1))) - .setMapping(INFERENCE_FIELD, "type=semantic_text", "f", "type=text,inference_id=" + inferenceId) + .setMapping(INFERENCE_FIELD, "type=semantic_text,inference_id=" + inferenceId) ); assertFalse( client(REMOTE_CLUSTER).admin() From ed58e786a03ed411fb79ae57ba9b71fe5a4ec9fe Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Fri, 18 Jul 2025 15:56:12 -0400 Subject: [PATCH 06/35] Allow cross-cluster search --- .../xpack/inference/queries/SemanticQueryBuilder.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index 182c083ef1c26..e1e5fed4474d4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -220,8 +220,6 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu throw new IllegalStateException( "Rewriting on the coordinator node requires a query rewrite context with non-null resolved indices" ); - } else if (resolvedIndices.getRemoteClusterIndices().isEmpty() == false) { - throw new IllegalArgumentException(NAME + " query does not support cross-cluster search"); } String inferenceId = getInferenceIdForForField(resolvedIndices.getConcreteLocalIndicesMetadata().values(), fieldName); From bc1ef8997a1a075c7c911c80398e039abefe86f0 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Fri, 18 Jul 2025 15:56:24 -0400 Subject: [PATCH 07/35] Fix test --- .../search/ccs/SemanticCrossClusterSearchIT.java | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java index c28fd05dd268e..2a1b2fab039df 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java @@ -11,6 +11,7 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; @@ -25,6 +26,7 @@ import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension; import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin; @@ -78,7 +80,7 @@ protected Settings nodeSettings() { @Override protected Collection> nodePlugins(String clusterAlias) { - return List.of(LocalStateInferencePlugin.class, TestInferenceServicePlugin.class); + return List.of(LocalStateInferencePlugin.class, TestInferenceServicePlugin.class, FakeMlPlugin.class); } public void testSemanticCrossClusterSearch() throws Exception { @@ -194,4 +196,11 @@ private int indexDocs(Client client, String index) { client.admin().indices().prepareRefresh(index).get(); return numDocs; } + + public static class FakeMlPlugin extends Plugin { + @Override + public List getNamedWriteables() { + return new MlInferenceNamedXContentProvider().getNamedWriteables(); + } + } } From 9a7d6f4a9218b3fa1488f65b2d191ae7ab9c9aef Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Fri, 18 Jul 2025 16:10:44 -0400 Subject: [PATCH 08/35] Send pre-intercepted request to remote cluster --- .../elasticsearch/action/search/TransportSearchAction.java | 7 +++++++ .../index/query/InterceptedQueryBuilderWrapper.java | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java index 2c3c077b2a940..2ed955ee6a541 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java @@ -64,6 +64,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexNotFoundException; +import org.elasticsearch.index.query.InterceptedQueryBuilderWrapper; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.index.shard.ShardId; @@ -78,6 +79,7 @@ import org.elasticsearch.search.aggregations.AggregationReduceContext; import org.elasticsearch.search.builder.PointInTimeBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.builder.SubSearchSourceBuilder; import org.elasticsearch.search.internal.AliasFilter; import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.profile.SearchProfileResults; @@ -730,6 +732,11 @@ public void onFailure(Exception e) { String clusterAlias = entry.getKey(); boolean skipUnavailable = remoteClusterService.isSkipUnavailable(clusterAlias); OriginalIndices indices = entry.getValue(); + + if (searchRequest.source().query() instanceof InterceptedQueryBuilderWrapper interceptedQuery) { + searchRequest.source().subSearches(List.of(new SubSearchSourceBuilder(interceptedQuery.getOriginal()))); + } + SearchRequest ccsSearchRequest = SearchRequest.subSearchRequest( parentTaskId, searchRequest, // TODO: Need to prep the request here by stripping inference results? diff --git a/server/src/main/java/org/elasticsearch/index/query/InterceptedQueryBuilderWrapper.java b/server/src/main/java/org/elasticsearch/index/query/InterceptedQueryBuilderWrapper.java index 9c72675e476db..b3d4ae7226f27 100644 --- a/server/src/main/java/org/elasticsearch/index/query/InterceptedQueryBuilderWrapper.java +++ b/server/src/main/java/org/elasticsearch/index/query/InterceptedQueryBuilderWrapper.java @@ -22,7 +22,7 @@ * Wrapper for instances of {@link QueryBuilder} that have been intercepted using the {@link QueryRewriteInterceptor} to * break out of the rewrite phase. These instances are unwrapped on serialization. */ -class InterceptedQueryBuilderWrapper implements QueryBuilder { +public class InterceptedQueryBuilderWrapper implements QueryBuilder { protected final QueryBuilder original; protected final QueryBuilder rewritten; From 4740f0c3540a277fc5c9cfdff7769e49846a9bed Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Fri, 18 Jul 2025 16:17:41 -0400 Subject: [PATCH 09/35] Added match query test --- .../search/ccs/SemanticCrossClusterSearchIT.java | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java index 2a1b2fab039df..10e68e77dfdcb 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.query.MatchQueryBuilder; import org.elasticsearch.inference.TaskType; import org.elasticsearch.license.LicenseSettings; import org.elasticsearch.plugins.Plugin; @@ -98,6 +99,21 @@ public void testSemanticCrossClusterSearch() throws Exception { }); } + public void testMatchCrossClusterSearch() throws Exception { + Map testClusterInfo = setupTwoClusters(); + String localIndex = (String) testClusterInfo.get("local.index"); + String remoteIndex = (String) testClusterInfo.get("remote.index"); + + SearchRequest searchRequest = new SearchRequest(localIndex, REMOTE_CLUSTER + ":" + remoteIndex); + searchRequest.source(new SearchSourceBuilder().query(new MatchQueryBuilder(INFERENCE_FIELD, "foo")).size(10)); + // searchRequest.setCcsMinimizeRoundtrips(false); + + assertResponse(client(LOCAL_CLUSTER).search(searchRequest), response -> { + assertNotNull(response); + assertEquals(10, response.getHits().getHits().length); + }); + } + private Map setupTwoClusters(String[] localIndices, String[] remoteIndices) throws IOException { final String inferenceId = "test_inference_id"; createInferenceEndpoint(client(LOCAL_CLUSTER), TaskType.TEXT_EMBEDDING, inferenceId, BBQ_COMPATIBLE_SERVICE_SETTINGS); From 83923c4a83d746a7b14678151e2f63646a620f69 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Fri, 18 Jul 2025 18:05:11 -0400 Subject: [PATCH 10/35] Added stub classes for match query builder wrapper --- .../MatchQueryBuilderWithEmbeddings.java | 30 +++++++++++++++++++ .../PreComputedEmbeddingsProvider.java | 16 ++++++++++ 2 files changed, 46 insertions(+) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/MatchQueryBuilderWithEmbeddings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/PreComputedEmbeddingsProvider.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/MatchQueryBuilderWithEmbeddings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/MatchQueryBuilderWithEmbeddings.java new file mode 100644 index 0000000000000..a1a0d039ec3ac --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/MatchQueryBuilderWithEmbeddings.java @@ -0,0 +1,30 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.queries; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.index.query.MatchQueryBuilder; +import org.elasticsearch.inference.InferenceResults; + +import java.io.IOException; +import java.util.Map; + +public class MatchQueryBuilderWithEmbeddings extends MatchQueryBuilder implements PreComputedEmbeddingsProvider { + public MatchQueryBuilderWithEmbeddings(String fieldName, Object value) { + super(fieldName, value); + } + + public MatchQueryBuilderWithEmbeddings(StreamInput in) throws IOException { + super(in); + } + + @Override + public Map getEmbeddingsForField(String fieldName) { + return Map.of(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/PreComputedEmbeddingsProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/PreComputedEmbeddingsProvider.java new file mode 100644 index 0000000000000..994b6757b2c10 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/PreComputedEmbeddingsProvider.java @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.queries; + +import org.elasticsearch.inference.InferenceResults; + +import java.util.Map; + +public interface PreComputedEmbeddingsProvider { + Map getEmbeddingsForField(String fieldName); +} From f00a5207652d78678c733b523d26b1231b6697a2 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Thu, 31 Jul 2025 15:51:36 -0400 Subject: [PATCH 11/35] Fix build error --- .../query/SemanticMatchQueryRewriteInterceptorTests.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java index 6d27395e440c3..adfac6011886d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java @@ -98,8 +98,8 @@ public void testBoostAndQueryNameInMatchQueryRewrite() throws IOException { InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten; assertEquals(BOOST, intercepted.boost(), 0.0f); assertEquals(QUERY_NAME, intercepted.queryName()); - assertTrue(intercepted.queryBuilder instanceof SemanticQueryBuilder); - SemanticQueryBuilder semanticQueryBuilder = (SemanticQueryBuilder) intercepted.queryBuilder; + assertTrue(intercepted.rewritten instanceof SemanticQueryBuilder); + SemanticQueryBuilder semanticQueryBuilder = (SemanticQueryBuilder) intercepted.rewritten; assertEquals(FIELD_NAME, semanticQueryBuilder.getFieldName()); assertEquals(VALUE, semanticQueryBuilder.getQuery()); } From 87a1eaf818c80cc36b99cce67820d3a198f51d21 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Thu, 31 Jul 2025 16:00:25 -0400 Subject: [PATCH 12/35] Fixed entitlement policy --- .../inference/src/main/plugin-metadata/entitlement-policy.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/x-pack/plugin/inference/src/main/plugin-metadata/entitlement-policy.yaml b/x-pack/plugin/inference/src/main/plugin-metadata/entitlement-policy.yaml index 36ac851acf1ea..cd046c4a6cb23 100644 --- a/x-pack/plugin/inference/src/main/plugin-metadata/entitlement-policy.yaml +++ b/x-pack/plugin/inference/src/main/plugin-metadata/entitlement-policy.yaml @@ -12,6 +12,7 @@ software.amazon.awssdk.http.nio.netty: io.netty.common: - outbound_network - manage_threads + - inbound_network - files: - path: "/etc/os-release" mode: "read" @@ -22,3 +23,4 @@ io.netty.common: io.netty.transport: - manage_threads - outbound_network + - inbound_network From 38af314ce333d3e0aeab10bc0f48897cfda92b96 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Fri, 1 Aug 2025 10:17:26 -0400 Subject: [PATCH 13/35] Add PIT integration test --- .../ccs/SemanticCrossClusterSearchIT.java | 33 ++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java index 10e68e77dfdcb..fc9a8f47af46f 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java @@ -7,7 +7,10 @@ package org.elasticsearch.search.ccs; +import org.elasticsearch.action.search.OpenPointInTimeRequest; +import org.elasticsearch.action.search.OpenPointInTimeResponse; import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.TransportOpenPointInTimeAction; import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; @@ -19,6 +22,7 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.license.LicenseSettings; import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.search.builder.PointInTimeBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.test.AbstractMultiClustersTestCase; import org.elasticsearch.test.InternalTestCluster; @@ -91,7 +95,28 @@ public void testSemanticCrossClusterSearch() throws Exception { SearchRequest searchRequest = new SearchRequest(localIndex, REMOTE_CLUSTER + ":" + remoteIndex); searchRequest.source(new SearchSourceBuilder().query(new SemanticQueryBuilder(INFERENCE_FIELD, "foo")).size(10)); - // searchRequest.setCcsMinimizeRoundtrips(false); + searchRequest.setCcsMinimizeRoundtrips(false); + searchRequest.pointInTimeBuilder(); + + assertResponse(client(LOCAL_CLUSTER).search(searchRequest), response -> { + assertNotNull(response); + assertEquals(10, response.getHits().getHits().length); + }); + } + + public void testSemanticCrossClusterSearchWithPIT() throws Exception { + Map testClusterInfo = setupTwoClusters(); + String localIndex = (String) testClusterInfo.get("local.index"); + String remoteIndex = (String) testClusterInfo.get("remote.index"); + + BytesReference pitId = openPointInTime(new String[]{localIndex, REMOTE_CLUSTER + ":" + remoteIndex}, TimeValue.timeValueMinutes(2)); + + SearchRequest searchRequest = new SearchRequest(); + searchRequest.source(new SearchSourceBuilder() + .query(new SemanticQueryBuilder(INFERENCE_FIELD, "foo")) + .pointInTimeBuilder(new PointInTimeBuilder(pitId)) + .size(10) + ); assertResponse(client(LOCAL_CLUSTER).search(searchRequest), response -> { assertNotNull(response); @@ -213,6 +238,12 @@ private int indexDocs(Client client, String index) { return numDocs; } + private BytesReference openPointInTime(String[] indices, TimeValue keepAlive) { + OpenPointInTimeRequest request = new OpenPointInTimeRequest(indices).keepAlive(keepAlive); + final OpenPointInTimeResponse response = client().execute(TransportOpenPointInTimeAction.TYPE, request).actionGet(); + return response.getPointInTimeId(); + } + public static class FakeMlPlugin extends Plugin { @Override public List getNamedWriteables() { From 49355155ac501371280acb30c3c65d0d6eddae8d Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Fri, 1 Aug 2025 13:33:43 -0400 Subject: [PATCH 14/35] Added CCS minimize round-trips to query rewrite context --- .../query/TransportValidateQueryAction.java | 2 +- .../action/explain/TransportExplainAction.java | 2 +- .../action/search/TransportSearchAction.java | 1 + .../search/TransportSearchShardsAction.java | 2 +- .../index/query/QueryRewriteContext.java | 15 +++++++++++++-- .../org/elasticsearch/indices/IndicesService.java | 3 ++- .../org/elasticsearch/search/SearchService.java | 7 ++++--- .../action/search/TransportSearchActionTests.java | 4 ++-- .../function/fulltext/QueryBuilderResolver.java | 2 +- 9 files changed, 26 insertions(+), 12 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/validate/query/TransportValidateQueryAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/validate/query/TransportValidateQueryAction.java index 1d9ba297c35c6..c1c7384688df6 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/validate/query/TransportValidateQueryAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/validate/query/TransportValidateQueryAction.java @@ -130,7 +130,7 @@ protected void doExecute(Task task, ValidateQueryRequest request, ActionListener } else { Rewriteable.rewriteAndFetch( request.query(), - searchService.getRewriteContext(timeProvider, resolvedIndices, null), + searchService.getRewriteContext(timeProvider, resolvedIndices, null, true), rewriteListener ); } diff --git a/server/src/main/java/org/elasticsearch/action/explain/TransportExplainAction.java b/server/src/main/java/org/elasticsearch/action/explain/TransportExplainAction.java index 1a3d4ac70831f..c52e1a9121ec4 100644 --- a/server/src/main/java/org/elasticsearch/action/explain/TransportExplainAction.java +++ b/server/src/main/java/org/elasticsearch/action/explain/TransportExplainAction.java @@ -103,7 +103,7 @@ protected void doExecute(Task task, ExplainRequest request, ActionListener request.nowInMillis; - Rewriteable.rewriteAndFetch(request.query(), searchService.getRewriteContext(timeProvider, resolvedIndices, null), rewriteListener); + Rewriteable.rewriteAndFetch(request.query(), searchService.getRewriteContext(timeProvider, resolvedIndices, null, true), rewriteListener); } @Override diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java index 2ae7b7152a225..55b1df707f430 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java @@ -554,6 +554,7 @@ public void onFailure(Exception e) { timeProvider::absoluteStartMillis, resolvedIndices, original.pointInTimeBuilder(), + original.isCcsMinimizeRoundtrips(), isExplain ), rewriteListener diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchShardsAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchShardsAction.java index d12847ec8bf7f..eeecda05ca5d0 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchShardsAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchShardsAction.java @@ -127,7 +127,7 @@ public void searchShards(Task task, SearchShardsRequest searchShardsRequest, Act Rewriteable.rewriteAndFetch( original, - searchService.getRewriteContext(timeProvider::absoluteStartMillis, resolvedIndices, null), + searchService.getRewriteContext(timeProvider::absoluteStartMillis, resolvedIndices, null, original.isCcsMinimizeRoundtrips()), listener.delegateFailureAndWrap((delegate, searchRequest) -> { Index[] concreteIndices = resolvedIndices.getConcreteLocalIndices(); final Set indicesAndAliases = indexNameExpressionResolver.resolveExpressions( diff --git a/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java b/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java index bc14a31978c18..fde94d4077154 100644 --- a/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java +++ b/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java @@ -72,6 +72,7 @@ public class QueryRewriteContext { private final ResolvedIndices resolvedIndices; private final PointInTimeBuilder pit; private QueryRewriteInterceptor queryRewriteInterceptor; + private final boolean ccsMinimizeRoundtrips; private final boolean isExplain; public QueryRewriteContext( @@ -91,6 +92,7 @@ public QueryRewriteContext( final ResolvedIndices resolvedIndices, final PointInTimeBuilder pit, final QueryRewriteInterceptor queryRewriteInterceptor, + final boolean ccsMinimizeRoundtrips, final boolean isExplain ) { @@ -111,6 +113,7 @@ public QueryRewriteContext( this.resolvedIndices = resolvedIndices; this.pit = pit; this.queryRewriteInterceptor = queryRewriteInterceptor; + this.ccsMinimizeRoundtrips = ccsMinimizeRoundtrips; this.isExplain = isExplain; } @@ -132,6 +135,7 @@ public QueryRewriteContext(final XContentParserConfiguration parserConfiguration null, null, null, + true, false ); } @@ -142,9 +146,10 @@ public QueryRewriteContext( final LongSupplier nowInMillis, final ResolvedIndices resolvedIndices, final PointInTimeBuilder pit, - final QueryRewriteInterceptor queryRewriteInterceptor + final QueryRewriteInterceptor queryRewriteInterceptor, + final boolean ccsMinimizeRoundtrips ) { - this(parserConfiguration, client, nowInMillis, resolvedIndices, pit, queryRewriteInterceptor, false); + this(parserConfiguration, client, nowInMillis, resolvedIndices, pit, queryRewriteInterceptor, ccsMinimizeRoundtrips, false); } public QueryRewriteContext( @@ -154,6 +159,7 @@ public QueryRewriteContext( final ResolvedIndices resolvedIndices, final PointInTimeBuilder pit, final QueryRewriteInterceptor queryRewriteInterceptor, + final boolean ccsMinimizeRoundtrips, final boolean isExplain ) { this( @@ -173,6 +179,7 @@ public QueryRewriteContext( resolvedIndices, pit, queryRewriteInterceptor, + ccsMinimizeRoundtrips, isExplain ); } @@ -279,6 +286,10 @@ public void setMapUnmappedFieldAsString(boolean mapUnmappedFieldAsString) { this.mapUnmappedFieldAsString = mapUnmappedFieldAsString; } + public boolean isCcsMinimizeRoundtrips() { + return ccsMinimizeRoundtrips; + } + public boolean isExplain() { return this.isExplain; } diff --git a/server/src/main/java/org/elasticsearch/indices/IndicesService.java b/server/src/main/java/org/elasticsearch/indices/IndicesService.java index 528601f201fee..427f305718e57 100644 --- a/server/src/main/java/org/elasticsearch/indices/IndicesService.java +++ b/server/src/main/java/org/elasticsearch/indices/IndicesService.java @@ -1849,9 +1849,10 @@ public QueryRewriteContext getRewriteContext( LongSupplier nowInMillis, ResolvedIndices resolvedIndices, PointInTimeBuilder pit, + final boolean ccsMinimizeRoundtrips, final boolean isExplain ) { - return new QueryRewriteContext(parserConfig, client, nowInMillis, resolvedIndices, pit, queryRewriteInterceptor, isExplain); + return new QueryRewriteContext(parserConfig, client, nowInMillis, resolvedIndices, pit, queryRewriteInterceptor, ccsMinimizeRoundtrips, isExplain); } public DataRewriteContext getDataRewriteContext(LongSupplier nowInMillis) { diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java index 9cca55f2ec748..16fe419e2401e 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchService.java +++ b/server/src/main/java/org/elasticsearch/search/SearchService.java @@ -2127,8 +2127,8 @@ private void rewriteAndFetchShardRequest(IndexShard shard, ShardSearchRequest re /** * Returns a new {@link QueryRewriteContext} with the given {@code now} provider */ - public QueryRewriteContext getRewriteContext(LongSupplier nowInMillis, ResolvedIndices resolvedIndices, PointInTimeBuilder pit) { - return getRewriteContext(nowInMillis, resolvedIndices, pit, false); + public QueryRewriteContext getRewriteContext(LongSupplier nowInMillis, ResolvedIndices resolvedIndices, PointInTimeBuilder pit, final boolean ccsMinimizeRoundtrips) { + return getRewriteContext(nowInMillis, resolvedIndices, pit, ccsMinimizeRoundtrips, false); } /** @@ -2138,9 +2138,10 @@ public QueryRewriteContext getRewriteContext( LongSupplier nowInMillis, ResolvedIndices resolvedIndices, PointInTimeBuilder pit, + final boolean ccsMinimizeRoundtrips, final boolean isExplain ) { - return indicesService.getRewriteContext(nowInMillis, resolvedIndices, pit, isExplain); + return indicesService.getRewriteContext(nowInMillis, resolvedIndices, pit, ccsMinimizeRoundtrips, isExplain); } public CoordinatorRewriteContextProvider getCoordinatorRewriteContextProvider(LongSupplier nowInMillis) { diff --git a/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java b/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java index 9f286efe28083..fd97b4b95fe35 100644 --- a/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java @@ -1784,8 +1784,8 @@ protected void doWriteTo(StreamOutput out) throws IOException { NodeClient client = new NodeClient(settings, threadPool, TestProjectResolvers.alwaysThrow()); SearchService searchService = mock(SearchService.class); - when(searchService.getRewriteContext(any(), any(), any(), anyBoolean())).thenReturn( - new QueryRewriteContext(null, null, null, null, null, null) + when(searchService.getRewriteContext(any(), any(), any(), anyBoolean(), anyBoolean())).thenReturn( + new QueryRewriteContext(null, null, null, null, null, null, true) ); ClusterService clusterService = new ClusterService( settings, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/QueryBuilderResolver.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/QueryBuilderResolver.java index ef3828a3f2fbb..2b6fe947994b3 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/QueryBuilderResolver.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/QueryBuilderResolver.java @@ -61,7 +61,7 @@ private static QueryRewriteContext queryRewriteContext(TransportActionServices s System.currentTimeMillis() ); - return services.searchService().getRewriteContext(System::currentTimeMillis, resolvedIndices, null); + return services.searchService().getRewriteContext(System::currentTimeMillis, resolvedIndices, null, true); } private static Set indexNames(LogicalPlan plan) { From 9c251dcd380c08595685a19031e7c744e5d51afb Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Fri, 1 Aug 2025 13:33:59 -0400 Subject: [PATCH 15/35] Code cleanup --- .../elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java | 1 - 1 file changed, 1 deletion(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java index fc9a8f47af46f..8b47049932d10 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java @@ -96,7 +96,6 @@ public void testSemanticCrossClusterSearch() throws Exception { SearchRequest searchRequest = new SearchRequest(localIndex, REMOTE_CLUSTER + ":" + remoteIndex); searchRequest.source(new SearchSourceBuilder().query(new SemanticQueryBuilder(INFERENCE_FIELD, "foo")).size(10)); searchRequest.setCcsMinimizeRoundtrips(false); - searchRequest.pointInTimeBuilder(); assertResponse(client(LOCAL_CLUSTER).search(searchRequest), response -> { assertNotNull(response); From 1e5cb6f1c94aa4337176f36569de3e69ef5f8906 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Fri, 1 Aug 2025 13:36:09 -0400 Subject: [PATCH 16/35] Spotless --- .../action/explain/TransportExplainAction.java | 6 +++++- .../org/elasticsearch/indices/IndicesService.java | 11 ++++++++++- .../org/elasticsearch/search/SearchService.java | 7 ++++++- .../search/ccs/SemanticCrossClusterSearchIT.java | 13 ++++++++----- 4 files changed, 29 insertions(+), 8 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/explain/TransportExplainAction.java b/server/src/main/java/org/elasticsearch/action/explain/TransportExplainAction.java index c52e1a9121ec4..23c4c136ca853 100644 --- a/server/src/main/java/org/elasticsearch/action/explain/TransportExplainAction.java +++ b/server/src/main/java/org/elasticsearch/action/explain/TransportExplainAction.java @@ -103,7 +103,11 @@ protected void doExecute(Task task, ExplainRequest request, ActionListener request.nowInMillis; - Rewriteable.rewriteAndFetch(request.query(), searchService.getRewriteContext(timeProvider, resolvedIndices, null, true), rewriteListener); + Rewriteable.rewriteAndFetch( + request.query(), + searchService.getRewriteContext(timeProvider, resolvedIndices, null, true), + rewriteListener + ); } @Override diff --git a/server/src/main/java/org/elasticsearch/indices/IndicesService.java b/server/src/main/java/org/elasticsearch/indices/IndicesService.java index 427f305718e57..081df95d94527 100644 --- a/server/src/main/java/org/elasticsearch/indices/IndicesService.java +++ b/server/src/main/java/org/elasticsearch/indices/IndicesService.java @@ -1852,7 +1852,16 @@ public QueryRewriteContext getRewriteContext( final boolean ccsMinimizeRoundtrips, final boolean isExplain ) { - return new QueryRewriteContext(parserConfig, client, nowInMillis, resolvedIndices, pit, queryRewriteInterceptor, ccsMinimizeRoundtrips, isExplain); + return new QueryRewriteContext( + parserConfig, + client, + nowInMillis, + resolvedIndices, + pit, + queryRewriteInterceptor, + ccsMinimizeRoundtrips, + isExplain + ); } public DataRewriteContext getDataRewriteContext(LongSupplier nowInMillis) { diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java index 16fe419e2401e..49798f49de3b6 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchService.java +++ b/server/src/main/java/org/elasticsearch/search/SearchService.java @@ -2127,7 +2127,12 @@ private void rewriteAndFetchShardRequest(IndexShard shard, ShardSearchRequest re /** * Returns a new {@link QueryRewriteContext} with the given {@code now} provider */ - public QueryRewriteContext getRewriteContext(LongSupplier nowInMillis, ResolvedIndices resolvedIndices, PointInTimeBuilder pit, final boolean ccsMinimizeRoundtrips) { + public QueryRewriteContext getRewriteContext( + LongSupplier nowInMillis, + ResolvedIndices resolvedIndices, + PointInTimeBuilder pit, + final boolean ccsMinimizeRoundtrips + ) { return getRewriteContext(nowInMillis, resolvedIndices, pit, ccsMinimizeRoundtrips, false); } diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java index 8b47049932d10..11a68f7a8ce80 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java @@ -108,13 +108,16 @@ public void testSemanticCrossClusterSearchWithPIT() throws Exception { String localIndex = (String) testClusterInfo.get("local.index"); String remoteIndex = (String) testClusterInfo.get("remote.index"); - BytesReference pitId = openPointInTime(new String[]{localIndex, REMOTE_CLUSTER + ":" + remoteIndex}, TimeValue.timeValueMinutes(2)); + BytesReference pitId = openPointInTime( + new String[] { localIndex, REMOTE_CLUSTER + ":" + remoteIndex }, + TimeValue.timeValueMinutes(2) + ); SearchRequest searchRequest = new SearchRequest(); - searchRequest.source(new SearchSourceBuilder() - .query(new SemanticQueryBuilder(INFERENCE_FIELD, "foo")) - .pointInTimeBuilder(new PointInTimeBuilder(pitId)) - .size(10) + searchRequest.source( + new SearchSourceBuilder().query(new SemanticQueryBuilder(INFERENCE_FIELD, "foo")) + .pointInTimeBuilder(new PointInTimeBuilder(pitId)) + .size(10) ); assertResponse(client(LOCAL_CLUSTER).search(searchRequest), response -> { From 8db87aba08920b061f3eab2a8aeaa97264263e72 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Fri, 1 Aug 2025 13:58:47 -0400 Subject: [PATCH 17/35] Add model registry to semantic query builder --- .../elasticsearch/xpack/inference/InferencePlugin.java | 3 ++- .../xpack/inference/queries/SemanticQueryBuilder.java | 8 ++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 374770ad25eb1..80c0f5f846b68 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -592,10 +592,11 @@ public Map getHighlighters() { @Override public void onNodeStarted() { var registry = inferenceServiceRegistry.get(); - if (registry != null) { registry.onNodeStarted(); } + + SemanticQueryBuilder.setModelRegistrySupplier(getModelRegistry()); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index 6d097518c4466..8981f0b87ca2d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -36,12 +36,14 @@ import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.io.IOException; import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.function.Supplier; import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; @@ -68,6 +70,12 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder MODEL_REGISTRY_SUPPLIER = () -> null; + + public static void setModelRegistrySupplier(Supplier supplier) { + MODEL_REGISTRY_SUPPLIER = supplier; + } + private final String fieldName; private final String query; private final SetOnce inferenceResultsSupplier; From 0b8a1a9f4d39ecd2dfbe09a9aebe52dbf64b0e23 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Fri, 1 Aug 2025 14:01:40 -0400 Subject: [PATCH 18/35] Remove unused code --- .../MatchQueryBuilderWithEmbeddings.java | 30 ------------------- .../PreComputedEmbeddingsProvider.java | 16 ---------- 2 files changed, 46 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/MatchQueryBuilderWithEmbeddings.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/PreComputedEmbeddingsProvider.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/MatchQueryBuilderWithEmbeddings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/MatchQueryBuilderWithEmbeddings.java deleted file mode 100644 index a1a0d039ec3ac..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/MatchQueryBuilderWithEmbeddings.java +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.queries; - -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.index.query.MatchQueryBuilder; -import org.elasticsearch.inference.InferenceResults; - -import java.io.IOException; -import java.util.Map; - -public class MatchQueryBuilderWithEmbeddings extends MatchQueryBuilder implements PreComputedEmbeddingsProvider { - public MatchQueryBuilderWithEmbeddings(String fieldName, Object value) { - super(fieldName, value); - } - - public MatchQueryBuilderWithEmbeddings(StreamInput in) throws IOException { - super(in); - } - - @Override - public Map getEmbeddingsForField(String fieldName) { - return Map.of(); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/PreComputedEmbeddingsProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/PreComputedEmbeddingsProvider.java deleted file mode 100644 index 994b6757b2c10..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/PreComputedEmbeddingsProvider.java +++ /dev/null @@ -1,16 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.queries; - -import org.elasticsearch.inference.InferenceResults; - -import java.util.Map; - -public interface PreComputedEmbeddingsProvider { - Map getEmbeddingsForField(String fieldName); -} From c55628c6d0ed2c3284eb21e3f9dfcd028559eaca Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Fri, 1 Aug 2025 14:35:51 -0400 Subject: [PATCH 19/35] Added the map embeddings provider --- .../xpack/inference/InferencePlugin.java | 3 ++ .../inference/queries/EmbeddingsProvider.java | 15 ++++++ .../queries/InferenceEndpointKey.java | 50 +++++++++++++++++++ .../queries/MapEmbeddingsProvider.java | 49 ++++++++++++++++++ 4 files changed, 117 insertions(+) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/EmbeddingsProvider.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InferenceEndpointKey.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/MapEmbeddingsProvider.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 80c0f5f846b68..bff61341b7377 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -93,6 +93,8 @@ import org.elasticsearch.xpack.inference.mapper.OffsetSourceFieldMapper; import org.elasticsearch.xpack.inference.mapper.SemanticInferenceMetadataFieldsMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; +import org.elasticsearch.xpack.inference.queries.EmbeddingsProvider; +import org.elasticsearch.xpack.inference.queries.MapEmbeddingsProvider; import org.elasticsearch.xpack.inference.queries.SemanticKnnVectorQueryRewriteInterceptor; import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor; import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; @@ -426,6 +428,7 @@ public List getNamedWriteables() { entries.add(new NamedWriteableRegistry.Entry(RankDoc.class, TextSimilarityRankDoc.NAME, TextSimilarityRankDoc::new)); entries.add(new NamedWriteableRegistry.Entry(Metadata.ProjectCustom.class, ModelRegistryMetadata.TYPE, ModelRegistryMetadata::new)); entries.add(new NamedWriteableRegistry.Entry(NamedDiff.class, ModelRegistryMetadata.TYPE, ModelRegistryMetadata::readDiffFrom)); + entries.add(new NamedWriteableRegistry.Entry(EmbeddingsProvider.class, MapEmbeddingsProvider.NAME, MapEmbeddingsProvider::new)); return entries; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/EmbeddingsProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/EmbeddingsProvider.java new file mode 100644 index 0000000000000..3538c75a49ee4 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/EmbeddingsProvider.java @@ -0,0 +1,15 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.queries; + +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.inference.InferenceResults; + +public interface EmbeddingsProvider extends NamedWriteable { + InferenceResults getEmbeddings(InferenceEndpointKey key); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InferenceEndpointKey.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InferenceEndpointKey.java new file mode 100644 index 0000000000000..8f80e81f0f506 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InferenceEndpointKey.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.queries; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.MinimalServiceSettings; + +import java.io.IOException; +import java.util.Objects; + +public class InferenceEndpointKey implements Writeable { + private final String inferenceId; + private final MinimalServiceSettings serviceSettings; + + public InferenceEndpointKey(String inferenceId, MinimalServiceSettings serviceSettings) { + this.inferenceId = inferenceId; + this.serviceSettings = serviceSettings; + } + + public InferenceEndpointKey(StreamInput in) throws IOException { + this.inferenceId = in.readString(); + this.serviceSettings = in.readNamedWriteable(MinimalServiceSettings.class); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(inferenceId); + out.writeNamedWriteable(serviceSettings); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + InferenceEndpointKey that = (InferenceEndpointKey) o; + return Objects.equals(inferenceId, that.inferenceId) && Objects.equals(serviceSettings, that.serviceSettings); + } + + @Override + public int hashCode() { + return Objects.hash(inferenceId, serviceSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/MapEmbeddingsProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/MapEmbeddingsProvider.java new file mode 100644 index 0000000000000..8fe1630dceaa0 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/MapEmbeddingsProvider.java @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.queries; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.InferenceResults; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +public class MapEmbeddingsProvider implements EmbeddingsProvider { + public static final String NAME = "map_embeddings_provider"; + + private final Map embeddings; + + public MapEmbeddingsProvider() { + this.embeddings = new HashMap<>(); + } + + public MapEmbeddingsProvider(StreamInput in) throws IOException { + embeddings = in.readMap(InferenceEndpointKey::new, i -> i.readNamedWriteable(InferenceResults.class)); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeMap(embeddings); + } + + @Override + public InferenceResults getEmbeddings(InferenceEndpointKey key) { + return embeddings.get(key); + } + + public void addEmbeddings(InferenceEndpointKey key, InferenceResults embeddings) { + this.embeddings.put(key, embeddings); + } +} From 65b20bd4bd99f9ab8c15d8bbe74e52e6c4f9e4df Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Fri, 1 Aug 2025 14:46:12 -0400 Subject: [PATCH 20/35] Added the single embeddings provider --- .../xpack/inference/InferencePlugin.java | 4 ++ .../queries/MapEmbeddingsProvider.java | 2 +- .../queries/SingleEmbeddingsProvider.java | 43 +++++++++++++++++++ 3 files changed, 48 insertions(+), 1 deletion(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SingleEmbeddingsProvider.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index bff61341b7377..e7730d79af492 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -99,6 +99,7 @@ import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor; import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; import org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor; +import org.elasticsearch.xpack.inference.queries.SingleEmbeddingsProvider; import org.elasticsearch.xpack.inference.rank.random.RandomRankBuilder; import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder; import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankBuilder; @@ -429,6 +430,9 @@ public List getNamedWriteables() { entries.add(new NamedWriteableRegistry.Entry(Metadata.ProjectCustom.class, ModelRegistryMetadata.TYPE, ModelRegistryMetadata::new)); entries.add(new NamedWriteableRegistry.Entry(NamedDiff.class, ModelRegistryMetadata.TYPE, ModelRegistryMetadata::readDiffFrom)); entries.add(new NamedWriteableRegistry.Entry(EmbeddingsProvider.class, MapEmbeddingsProvider.NAME, MapEmbeddingsProvider::new)); + entries.add( + new NamedWriteableRegistry.Entry(EmbeddingsProvider.class, SingleEmbeddingsProvider.NAME, SingleEmbeddingsProvider::new) + ); return entries; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/MapEmbeddingsProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/MapEmbeddingsProvider.java index 8fe1630dceaa0..5e2efa856d80c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/MapEmbeddingsProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/MapEmbeddingsProvider.java @@ -25,7 +25,7 @@ public MapEmbeddingsProvider() { } public MapEmbeddingsProvider(StreamInput in) throws IOException { - embeddings = in.readMap(InferenceEndpointKey::new, i -> i.readNamedWriteable(InferenceResults.class)); + this.embeddings = in.readMap(InferenceEndpointKey::new, i -> i.readNamedWriteable(InferenceResults.class)); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SingleEmbeddingsProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SingleEmbeddingsProvider.java new file mode 100644 index 0000000000000..78c66294cef6c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SingleEmbeddingsProvider.java @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.queries; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.InferenceResults; + +import java.io.IOException; + +public class SingleEmbeddingsProvider implements EmbeddingsProvider { + public static final String NAME = "single_embeddings_provider"; + + private final InferenceResults embeddings; + + public SingleEmbeddingsProvider(InferenceResults embeddings) { + this.embeddings = embeddings; + } + + public SingleEmbeddingsProvider(StreamInput in) throws IOException { + this.embeddings = in.readNamedWriteable(InferenceResults.class); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeNamedWriteable(embeddings); + } + + @Override + public InferenceResults getEmbeddings(InferenceEndpointKey key) { + return embeddings; + } +} From a335f34e30c11a88e7483e6090719b0378147a65 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Fri, 1 Aug 2025 15:44:24 -0400 Subject: [PATCH 21/35] Update semantic query builder to use embeddings providers --- .../org/elasticsearch/TransportVersions.java | 1 + .../queries/MapEmbeddingsProvider.java | 14 ++ .../queries/SemanticQueryBuilder.java | 169 ++++++++++-------- .../queries/SingleEmbeddingsProvider.java | 14 ++ 4 files changed, 128 insertions(+), 70 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index e57cb485361b6..8f4da31c47175 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -355,6 +355,7 @@ static TransportVersion def(int id) { public static final TransportVersion PIPELINE_TRACKING_INFO = def(9_131_0_00); public static final TransportVersion COMPONENT_TEMPLATE_TRACKING_INFO = def(9_132_0_00); public static final TransportVersion TO_CHILD_BLOCK_JOIN_QUERY = def(9_133_0_00); + public static final TransportVersion SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS = def(9_134_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/MapEmbeddingsProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/MapEmbeddingsProvider.java index 5e2efa856d80c..68edf16bcc80a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/MapEmbeddingsProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/MapEmbeddingsProvider.java @@ -14,6 +14,7 @@ import java.io.IOException; import java.util.HashMap; import java.util.Map; +import java.util.Objects; public class MapEmbeddingsProvider implements EmbeddingsProvider { public static final String NAME = "map_embeddings_provider"; @@ -46,4 +47,17 @@ public InferenceResults getEmbeddings(InferenceEndpointKey key) { public void addEmbeddings(InferenceEndpointKey key, InferenceResults embeddings) { this.embeddings.put(key, embeddings); } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + MapEmbeddingsProvider that = (MapEmbeddingsProvider) o; + return Objects.equals(embeddings, that.embeddings); + } + + @Override + public int hashCode() { + return Objects.hashCode(embeddings); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index 8981f0b87ca2d..dbddd7492b65d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -8,7 +8,6 @@ package org.elasticsearch.xpack.inference.queries; import org.apache.lucene.search.Query; -import org.apache.lucene.util.SetOnce; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ResolvedIndices; @@ -25,6 +24,7 @@ import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; @@ -40,9 +40,11 @@ import java.io.IOException; import java.util.Collection; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.function.Supplier; import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; @@ -78,8 +80,7 @@ public static void setModelRegistrySupplier(Supplier supplier) { private final String fieldName; private final String query; - private final SetOnce inferenceResultsSupplier; - private final InferenceResults inferenceResults; + private final EmbeddingsProvider embeddingsProvider; private final boolean noInferenceResults; private final Boolean lenient; @@ -96,8 +97,7 @@ public SemanticQueryBuilder(String fieldName, String query, Boolean lenient) { } this.fieldName = fieldName; this.query = query; - this.inferenceResults = null; - this.inferenceResultsSupplier = null; + this.embeddingsProvider = null; this.noInferenceResults = false; this.lenient = lenient; } @@ -106,9 +106,17 @@ public SemanticQueryBuilder(StreamInput in) throws IOException { super(in); this.fieldName = in.readString(); this.query = in.readString(); - this.inferenceResults = in.readOptionalNamedWriteable(InferenceResults.class); + if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS)) { + this.embeddingsProvider = in.readOptionalNamedWriteable(EmbeddingsProvider.class); + } else { + InferenceResults inferenceResults = in.readOptionalNamedWriteable(InferenceResults.class); + if (inferenceResults != null) { + this.embeddingsProvider = new SingleEmbeddingsProvider(inferenceResults); + } else { + this.embeddingsProvider = null; + } + } this.noInferenceResults = in.readBoolean(); - this.inferenceResultsSupplier = null; if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_QUERY_LENIENT)) { this.lenient = in.readOptionalBoolean(); } else { @@ -118,30 +126,26 @@ public SemanticQueryBuilder(StreamInput in) throws IOException { @Override protected void doWriteTo(StreamOutput out) throws IOException { - if (inferenceResultsSupplier != null) { - throw new IllegalStateException("Inference results supplier is set. Missing a rewriteAndFetch?"); - } out.writeString(fieldName); out.writeString(query); - out.writeOptionalNamedWriteable(inferenceResults); + if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS)) { + out.writeOptionalNamedWriteable(embeddingsProvider); + } else { + // TODO: Handle multiple inference IDs in a mixed-version cluster + throw new UnsupportedOperationException("Handle old transport versions"); + } out.writeBoolean(noInferenceResults); if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_QUERY_LENIENT)) { out.writeOptionalBoolean(lenient); } } - private SemanticQueryBuilder( - SemanticQueryBuilder other, - SetOnce inferenceResultsSupplier, - InferenceResults inferenceResults, - boolean noInferenceResults - ) { + private SemanticQueryBuilder(SemanticQueryBuilder other, EmbeddingsProvider embeddingsProvider, boolean noInferenceResults) { this.fieldName = other.fieldName; this.query = other.query; this.boost = other.boost; this.queryName = other.queryName; - this.inferenceResultsSupplier = inferenceResultsSupplier; - this.inferenceResults = inferenceResults; + this.embeddingsProvider = embeddingsProvider; this.noInferenceResults = noInferenceResults; this.lenient = other.lenient; } @@ -195,13 +199,32 @@ private QueryBuilder doRewriteBuildSemanticQuery(SearchExecutionContext searchEx if (fieldType == null) { return new MatchNoneQueryBuilder(); } else if (fieldType instanceof SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType) { - if (inferenceResults == null) { + if (embeddingsProvider == null) { // This should never happen, but throw on it in case it ever does throw new IllegalStateException( "No inference results set for [" + semanticTextFieldType.typeName() + "] field [" + fieldName + "]" ); } + // TODO: Check that model registry supplier has been set + String inferenceId = semanticTextFieldType.getSearchInferenceId(); + MinimalServiceSettings serviceSettings = MODEL_REGISTRY_SUPPLIER.get().getMinimalServiceSettings(inferenceId); + InferenceEndpointKey inferenceEndpointKey = new InferenceEndpointKey(inferenceId, serviceSettings); + InferenceResults inferenceResults = embeddingsProvider.getEmbeddings(inferenceEndpointKey); + + // TODO: Handle ErrorInferenceResults and WarningInferenceResults + if (inferenceResults == null) { + throw new IllegalStateException( + "No inference results set for [" + + semanticTextFieldType.typeName() + + "] field [" + + fieldName + + "] with inference ID [" + + inferenceId + + "]" + ); + } + return semanticTextFieldType.semanticQuery(inferenceResults, searchExecutionContext.requestSize(), boost(), queryName()); } else if (lenient != null && lenient) { return new MatchNoneQueryBuilder(); @@ -213,15 +236,10 @@ private QueryBuilder doRewriteBuildSemanticQuery(SearchExecutionContext searchEx } private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewriteContext) { - if (inferenceResults != null || noInferenceResults) { + if (embeddingsProvider != null || noInferenceResults) { return this; } - if (inferenceResultsSupplier != null) { - InferenceResults inferenceResults = validateAndConvertInferenceResults(inferenceResultsSupplier, fieldName); - return inferenceResults != null ? new SemanticQueryBuilder(this, null, inferenceResults, noInferenceResults) : this; - } - ResolvedIndices resolvedIndices = queryRewriteContext.getResolvedIndices(); if (resolvedIndices == null) { throw new IllegalStateException( @@ -229,10 +247,16 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu ); } - String inferenceId = getInferenceIdForForField(resolvedIndices.getConcreteLocalIndicesMetadata().values(), fieldName); - SetOnce inferenceResultsSupplier = new SetOnce<>(); - boolean noInferenceResults = false; - if (inferenceId != null) { + Set inferenceIds = getInferenceIdsForForField(resolvedIndices.getConcreteLocalIndicesMetadata().values(), fieldName); + MapEmbeddingsProvider mapEmbeddingsProvider = new MapEmbeddingsProvider(); + + // The inference ID set can be empty if either the field name or index name(s) are invalid (or both). + // If this happens, we set the "no inference results" flag to true so the rewrite process can continue. + // Invalid index names will be handled in the transport layer, when the query is sent to the shard. + // Invalid field names will be handled when the query is re-written on the shard, where we have access to the index mappings. + boolean noInferenceResults = inferenceIds.isEmpty(); + + for (String inferenceId : inferenceIds) { InferenceAction.Request inferenceRequest = new InferenceAction.Request( TaskType.ANY, inferenceId, @@ -246,6 +270,9 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu false ); + // TODO: Check that model registry supplier has been set + MinimalServiceSettings serviceSettings = MODEL_REGISTRY_SUPPLIER.get().getMinimalServiceSettings(inferenceId); + InferenceEndpointKey inferenceEndpointKey = new InferenceEndpointKey(inferenceId, serviceSettings); queryRewriteContext.registerAsyncAction( (client, listener) -> executeAsyncWithOrigin( client, @@ -253,53 +280,57 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu InferenceAction.INSTANCE, inferenceRequest, listener.delegateFailureAndWrap((l, inferenceResponse) -> { - inferenceResultsSupplier.set(inferenceResponse.getResults()); + mapEmbeddingsProvider.addEmbeddings( + inferenceEndpointKey, + validateAndConvertInferenceResults(inferenceResponse.getResults(), fieldName, inferenceId) + ); l.onResponse(null); }) ) ); - } else { - // The inference ID can be null if either the field name or index name(s) are invalid (or both). - // If this happens, we set the "no inference results" flag to true so the rewrite process can continue. - // Invalid index names will be handled in the transport layer, when the query is sent to the shard. - // Invalid field names will be handled when the query is re-written on the shard, where we have access to the index mappings. - noInferenceResults = true; } - return new SemanticQueryBuilder(this, noInferenceResults ? null : inferenceResultsSupplier, null, noInferenceResults); + return new SemanticQueryBuilder(this, noInferenceResults ? null : mapEmbeddingsProvider, noInferenceResults); } private static InferenceResults validateAndConvertInferenceResults( - SetOnce inferenceResultsSupplier, - String fieldName + InferenceServiceResults inferenceServiceResults, + String fieldName, + String inferenceId ) { - InferenceServiceResults inferenceServiceResults = inferenceResultsSupplier.get(); - if (inferenceServiceResults == null) { - return null; - } - List inferenceResultsList = inferenceServiceResults.transformToCoordinationFormat(); if (inferenceResultsList.isEmpty()) { - throw new IllegalArgumentException("No inference results retrieved for field [" + fieldName + "]"); + return new ErrorInferenceResults( + new IllegalArgumentException( + "No inference results retrieved for field [" + fieldName + "] with inference ID [" + inferenceId + "]" + ) + ); } else if (inferenceResultsList.size() > 1) { // The inference call should truncate if the query is too large. // Thus, if we receive more than one inference result, it is a server-side error. - throw new IllegalStateException(inferenceResultsList.size() + " inference results retrieved for field [" + fieldName + "]"); + return new ErrorInferenceResults( + new IllegalStateException( + inferenceResultsList.size() + + " inference results retrieved for field [" + + fieldName + + "] with inference ID [" + + inferenceId + + "]" + ) + ); } InferenceResults inferenceResults = inferenceResultsList.get(0); - if (inferenceResults instanceof ErrorInferenceResults errorInferenceResults) { - throw new IllegalStateException( - "Field [" + fieldName + "] query inference error: " + errorInferenceResults.getException().getMessage(), - errorInferenceResults.getException() - ); - } else if (inferenceResults instanceof WarningInferenceResults warningInferenceResults) { - throw new IllegalStateException("Field [" + fieldName + "] query inference warning: " + warningInferenceResults.getWarning()); - } else if (inferenceResults instanceof TextExpansionResults == false - && inferenceResults instanceof MlTextEmbeddingResults == false) { - throw new IllegalArgumentException( + if (inferenceResults instanceof TextExpansionResults == false + && inferenceResults instanceof MlTextEmbeddingResults == false + && inferenceResults instanceof ErrorInferenceResults == false + && inferenceResults instanceof WarningInferenceResults == false) { + return new ErrorInferenceResults( + new IllegalArgumentException( "Field [" + fieldName + + "] with inference ID [" + + inferenceId + "] expected query inference results to be of type [" + TextExpansionResults.NAME + "] or [" @@ -307,8 +338,9 @@ private static InferenceResults validateAndConvertInferenceResults( + "], got [" + inferenceResults.getWriteableName() + "]. Has the inference endpoint configuration changed?" - ); - } + ) + ); + } return inferenceResults; } @@ -318,33 +350,30 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { throw new IllegalStateException(NAME + " should have been rewritten to another query type"); } - private static String getInferenceIdForForField(Collection indexMetadataCollection, String fieldName) { - String inferenceId = null; + private static Set getInferenceIdsForForField(Collection indexMetadataCollection, String fieldName) { + Set inferenceIds = new HashSet<>(); for (IndexMetadata indexMetadata : indexMetadataCollection) { InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(fieldName); String indexInferenceId = inferenceFieldMetadata != null ? inferenceFieldMetadata.getSearchInferenceId() : null; if (indexInferenceId != null) { - if (inferenceId != null && inferenceId.equals(indexInferenceId) == false) { - throw new IllegalArgumentException("Field [" + fieldName + "] has multiple inference IDs associated with it"); - } - - inferenceId = indexInferenceId; + inferenceIds.add(indexInferenceId); } } - return inferenceId; + return inferenceIds; } @Override protected boolean doEquals(SemanticQueryBuilder other) { return Objects.equals(fieldName, other.fieldName) && Objects.equals(query, other.query) - && Objects.equals(inferenceResults, other.inferenceResults) - && Objects.equals(inferenceResultsSupplier, other.inferenceResultsSupplier); + && Objects.equals(embeddingsProvider, other.embeddingsProvider) + && Objects.equals(noInferenceResults, other.noInferenceResults) + && Objects.equals(lenient, other.lenient); } @Override protected int doHashCode() { - return Objects.hash(fieldName, query, inferenceResults, inferenceResultsSupplier); + return Objects.hash(fieldName, query, embeddingsProvider, noInferenceResults, lenient); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SingleEmbeddingsProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SingleEmbeddingsProvider.java index 78c66294cef6c..5b7f4535d23dc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SingleEmbeddingsProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SingleEmbeddingsProvider.java @@ -12,6 +12,7 @@ import org.elasticsearch.inference.InferenceResults; import java.io.IOException; +import java.util.Objects; public class SingleEmbeddingsProvider implements EmbeddingsProvider { public static final String NAME = "single_embeddings_provider"; @@ -40,4 +41,17 @@ public void writeTo(StreamOutput out) throws IOException { public InferenceResults getEmbeddings(InferenceEndpointKey key) { return embeddings; } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SingleEmbeddingsProvider that = (SingleEmbeddingsProvider) o; + return Objects.equals(embeddings, that.embeddings); + } + + @Override + public int hashCode() { + return Objects.hashCode(embeddings); + } } From 071f6c54a02f24416522846d3bd6e9e0147917db Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Fri, 1 Aug 2025 16:06:10 -0400 Subject: [PATCH 22/35] Added TODOs --- .../xpack/inference/queries/SemanticQueryBuilder.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index dbddd7492b65d..7be6fa746abc6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -52,6 +52,9 @@ import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; +// TODO: Only allow CCS when ccs_minimize_roundtrips=true +// TODO: Add flag to perform inference again during remote cluster coordinator rewrite + public class SemanticQueryBuilder extends AbstractQueryBuilder { public static final String NAME = "semantic"; From 0b945d4f82605fcd7ac82e619ddff3888ca6ed86 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Fri, 1 Aug 2025 16:24:46 -0400 Subject: [PATCH 23/35] Make ccsMinimizeRoundtrips nullable --- .../validate/query/TransportValidateQueryAction.java | 2 +- .../action/explain/TransportExplainAction.java | 2 +- .../index/query/QueryRewriteContext.java | 12 ++++++------ .../org/elasticsearch/indices/IndicesService.java | 2 +- .../java/org/elasticsearch/search/SearchService.java | 4 ++-- .../action/search/TransportSearchActionTests.java | 2 +- .../function/fulltext/QueryBuilderResolver.java | 2 +- 7 files changed, 13 insertions(+), 13 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/validate/query/TransportValidateQueryAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/validate/query/TransportValidateQueryAction.java index c1c7384688df6..e330bef2b756b 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/validate/query/TransportValidateQueryAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/validate/query/TransportValidateQueryAction.java @@ -130,7 +130,7 @@ protected void doExecute(Task task, ValidateQueryRequest request, ActionListener } else { Rewriteable.rewriteAndFetch( request.query(), - searchService.getRewriteContext(timeProvider, resolvedIndices, null, true), + searchService.getRewriteContext(timeProvider, resolvedIndices, null, null), rewriteListener ); } diff --git a/server/src/main/java/org/elasticsearch/action/explain/TransportExplainAction.java b/server/src/main/java/org/elasticsearch/action/explain/TransportExplainAction.java index 23c4c136ca853..009234fad56a2 100644 --- a/server/src/main/java/org/elasticsearch/action/explain/TransportExplainAction.java +++ b/server/src/main/java/org/elasticsearch/action/explain/TransportExplainAction.java @@ -105,7 +105,7 @@ protected void doExecute(Task task, ExplainRequest request, ActionListener request.nowInMillis; Rewriteable.rewriteAndFetch( request.query(), - searchService.getRewriteContext(timeProvider, resolvedIndices, null, true), + searchService.getRewriteContext(timeProvider, resolvedIndices, null, null), rewriteListener ); } diff --git a/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java b/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java index fde94d4077154..d4d445f479070 100644 --- a/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java +++ b/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java @@ -72,7 +72,7 @@ public class QueryRewriteContext { private final ResolvedIndices resolvedIndices; private final PointInTimeBuilder pit; private QueryRewriteInterceptor queryRewriteInterceptor; - private final boolean ccsMinimizeRoundtrips; + private final Boolean ccsMinimizeRoundtrips; private final boolean isExplain; public QueryRewriteContext( @@ -92,7 +92,7 @@ public QueryRewriteContext( final ResolvedIndices resolvedIndices, final PointInTimeBuilder pit, final QueryRewriteInterceptor queryRewriteInterceptor, - final boolean ccsMinimizeRoundtrips, + final Boolean ccsMinimizeRoundtrips, final boolean isExplain ) { @@ -135,7 +135,7 @@ public QueryRewriteContext(final XContentParserConfiguration parserConfiguration null, null, null, - true, + null, false ); } @@ -147,7 +147,7 @@ public QueryRewriteContext( final ResolvedIndices resolvedIndices, final PointInTimeBuilder pit, final QueryRewriteInterceptor queryRewriteInterceptor, - final boolean ccsMinimizeRoundtrips + final Boolean ccsMinimizeRoundtrips ) { this(parserConfiguration, client, nowInMillis, resolvedIndices, pit, queryRewriteInterceptor, ccsMinimizeRoundtrips, false); } @@ -159,7 +159,7 @@ public QueryRewriteContext( final ResolvedIndices resolvedIndices, final PointInTimeBuilder pit, final QueryRewriteInterceptor queryRewriteInterceptor, - final boolean ccsMinimizeRoundtrips, + final Boolean ccsMinimizeRoundtrips, final boolean isExplain ) { this( @@ -286,7 +286,7 @@ public void setMapUnmappedFieldAsString(boolean mapUnmappedFieldAsString) { this.mapUnmappedFieldAsString = mapUnmappedFieldAsString; } - public boolean isCcsMinimizeRoundtrips() { + public Boolean isCcsMinimizeRoundtrips() { return ccsMinimizeRoundtrips; } diff --git a/server/src/main/java/org/elasticsearch/indices/IndicesService.java b/server/src/main/java/org/elasticsearch/indices/IndicesService.java index 081df95d94527..caa076eb057f1 100644 --- a/server/src/main/java/org/elasticsearch/indices/IndicesService.java +++ b/server/src/main/java/org/elasticsearch/indices/IndicesService.java @@ -1849,7 +1849,7 @@ public QueryRewriteContext getRewriteContext( LongSupplier nowInMillis, ResolvedIndices resolvedIndices, PointInTimeBuilder pit, - final boolean ccsMinimizeRoundtrips, + final Boolean ccsMinimizeRoundtrips, final boolean isExplain ) { return new QueryRewriteContext( diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java index 49798f49de3b6..4e819ef739a2a 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchService.java +++ b/server/src/main/java/org/elasticsearch/search/SearchService.java @@ -2131,7 +2131,7 @@ public QueryRewriteContext getRewriteContext( LongSupplier nowInMillis, ResolvedIndices resolvedIndices, PointInTimeBuilder pit, - final boolean ccsMinimizeRoundtrips + final Boolean ccsMinimizeRoundtrips ) { return getRewriteContext(nowInMillis, resolvedIndices, pit, ccsMinimizeRoundtrips, false); } @@ -2143,7 +2143,7 @@ public QueryRewriteContext getRewriteContext( LongSupplier nowInMillis, ResolvedIndices resolvedIndices, PointInTimeBuilder pit, - final boolean ccsMinimizeRoundtrips, + final Boolean ccsMinimizeRoundtrips, final boolean isExplain ) { return indicesService.getRewriteContext(nowInMillis, resolvedIndices, pit, ccsMinimizeRoundtrips, isExplain); diff --git a/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java b/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java index fd97b4b95fe35..d20df0f2154ce 100644 --- a/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java @@ -1785,7 +1785,7 @@ protected void doWriteTo(StreamOutput out) throws IOException { SearchService searchService = mock(SearchService.class); when(searchService.getRewriteContext(any(), any(), any(), anyBoolean(), anyBoolean())).thenReturn( - new QueryRewriteContext(null, null, null, null, null, null, true) + new QueryRewriteContext(null, null, null, null, null, null, null) ); ClusterService clusterService = new ClusterService( settings, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/QueryBuilderResolver.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/QueryBuilderResolver.java index 2b6fe947994b3..1f08595c71cbb 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/QueryBuilderResolver.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/QueryBuilderResolver.java @@ -61,7 +61,7 @@ private static QueryRewriteContext queryRewriteContext(TransportActionServices s System.currentTimeMillis() ); - return services.searchService().getRewriteContext(System::currentTimeMillis, resolvedIndices, null, true); + return services.searchService().getRewriteContext(System::currentTimeMillis, resolvedIndices, null, null); } private static Set indexNames(LogicalPlan plan) { From 33d1be14c5ab76aabb657fa0bb207271d7d2fe05 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Fri, 1 Aug 2025 16:37:09 -0400 Subject: [PATCH 24/35] Fix build errors --- server/src/main/java/org/elasticsearch/index/IndexService.java | 1 + .../elasticsearch/index/query/CoordinatorRewriteContext.java | 1 + .../org/elasticsearch/index/query/SearchExecutionContext.java | 1 + .../org/elasticsearch/index/query/QueryRewriteContextTests.java | 2 ++ .../java/org/elasticsearch/test/AbstractBuilderTestCase.java | 1 + .../query/SemanticKnnVectorQueryRewriteInterceptorTests.java | 2 +- .../index/query/SemanticMatchQueryRewriteInterceptorTests.java | 2 +- .../query/SemanticSparseVectorQueryRewriteInterceptorTests.java | 2 +- 8 files changed, 9 insertions(+), 3 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/IndexService.java b/server/src/main/java/org/elasticsearch/index/IndexService.java index 2ab766a1253a7..c74930fb6fbe6 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexService.java +++ b/server/src/main/java/org/elasticsearch/index/IndexService.java @@ -836,6 +836,7 @@ public QueryRewriteContext newQueryRewriteContext( null, null, null, + null, false ); } diff --git a/server/src/main/java/org/elasticsearch/index/query/CoordinatorRewriteContext.java b/server/src/main/java/org/elasticsearch/index/query/CoordinatorRewriteContext.java index e0cfd86c10c6d..520cf6f6bf16f 100644 --- a/server/src/main/java/org/elasticsearch/index/query/CoordinatorRewriteContext.java +++ b/server/src/main/java/org/elasticsearch/index/query/CoordinatorRewriteContext.java @@ -120,6 +120,7 @@ public CoordinatorRewriteContext( null, null, null, + null, false ); this.dateFieldRangeInfo = dateFieldRangeInfo; diff --git a/server/src/main/java/org/elasticsearch/index/query/SearchExecutionContext.java b/server/src/main/java/org/elasticsearch/index/query/SearchExecutionContext.java index b2c0cdab8d16e..45ac597b084fb 100644 --- a/server/src/main/java/org/elasticsearch/index/query/SearchExecutionContext.java +++ b/server/src/main/java/org/elasticsearch/index/query/SearchExecutionContext.java @@ -275,6 +275,7 @@ private SearchExecutionContext( null, null, null, + null, false ); this.shardId = shardId; diff --git a/server/src/test/java/org/elasticsearch/index/query/QueryRewriteContextTests.java b/server/src/test/java/org/elasticsearch/index/query/QueryRewriteContextTests.java index b997ac4747a07..a0f88e61963c3 100644 --- a/server/src/test/java/org/elasticsearch/index/query/QueryRewriteContextTests.java +++ b/server/src/test/java/org/elasticsearch/index/query/QueryRewriteContextTests.java @@ -54,6 +54,7 @@ public void testGetTierPreference() { null, null, null, + null, false ); @@ -83,6 +84,7 @@ public void testGetTierPreference() { null, null, null, + null, false ); diff --git a/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java index be30dbe9823d4..5597d01ba0b42 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java @@ -637,6 +637,7 @@ QueryRewriteContext createQueryRewriteContext() { createMockResolvedIndices(), null, createMockQueryRewriteInterceptor(), + null, false ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java index 901a056ac5c47..a923309756509 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java @@ -151,7 +151,7 @@ private QueryRewriteContext createQueryRewriteContext(Map Date: Mon, 4 Aug 2025 09:27:27 -0400 Subject: [PATCH 25/35] Fix test errors --- .../search/ccs/SemanticCrossClusterSearchIT.java | 2 +- .../xpack/inference/InferenceNamedWriteablesProvider.java | 4 ++++ .../xpack/inference/queries/MapEmbeddingsProvider.java | 2 +- .../xpack/inference/LocalStateInferencePlugin.java | 5 +++++ 4 files changed, 11 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java index 11a68f7a8ce80..af3fe85fbae9d 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java @@ -95,7 +95,7 @@ public void testSemanticCrossClusterSearch() throws Exception { SearchRequest searchRequest = new SearchRequest(localIndex, REMOTE_CLUSTER + ":" + remoteIndex); searchRequest.source(new SearchSourceBuilder().query(new SemanticQueryBuilder(INFERENCE_FIELD, "foo")).size(10)); - searchRequest.setCcsMinimizeRoundtrips(false); + searchRequest.setCcsMinimizeRoundtrips(true); assertResponse(client(LOCAL_CLUSTER).search(searchRequest), response -> { assertNotNull(response); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 6fd07cd4c2831..51a7ef867daf6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -13,6 +13,7 @@ import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.SecretSettings; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskSettings; @@ -185,6 +186,9 @@ public static List getNamedWriteables() { namedWriteables.addAll(DeepSeekChatCompletionModel.namedWriteables()); namedWriteables.addAll(SageMakerModel.namedWriteables()); namedWriteables.addAll(SageMakerSchemas.namedWriteables()); + namedWriteables.add( + new NamedWriteableRegistry.Entry(MinimalServiceSettings.class, MinimalServiceSettings.NAME, MinimalServiceSettings::new) + ); return namedWriteables; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/MapEmbeddingsProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/MapEmbeddingsProvider.java index 68edf16bcc80a..5e6be7ad37eeb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/MapEmbeddingsProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/MapEmbeddingsProvider.java @@ -36,7 +36,7 @@ public String getWriteableName() { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeMap(embeddings); + out.writeMap(embeddings, StreamOutput::writeWriteable, StreamOutput::writeNamedWriteable); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/LocalStateInferencePlugin.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/LocalStateInferencePlugin.java index 5aa42520d74bd..553af8b5b16b4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/LocalStateInferencePlugin.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/LocalStateInferencePlugin.java @@ -73,4 +73,9 @@ public Map getHighlighters() { public Collection getMappedActionFilters() { return inferencePlugin.getMappedActionFilters(); } + + @Override + public void onNodeStarted() { + inferencePlugin.onNodeStarted(); + } } From 43dda15b89d02d7a2cb676578b0d06e7e60befe2 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Mon, 4 Aug 2025 09:42:04 -0400 Subject: [PATCH 26/35] Check that model registry is set --- .../inference/queries/SemanticQueryBuilder.java | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index 7be6fa746abc6..d1237864e92d9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -209,9 +209,13 @@ private QueryBuilder doRewriteBuildSemanticQuery(SearchExecutionContext searchEx ); } - // TODO: Check that model registry supplier has been set + ModelRegistry modelRegistry = MODEL_REGISTRY_SUPPLIER.get(); + if (modelRegistry == null) { + throw new IllegalStateException("Model registry has not been set"); + } + String inferenceId = semanticTextFieldType.getSearchInferenceId(); - MinimalServiceSettings serviceSettings = MODEL_REGISTRY_SUPPLIER.get().getMinimalServiceSettings(inferenceId); + MinimalServiceSettings serviceSettings = modelRegistry.getMinimalServiceSettings(inferenceId); InferenceEndpointKey inferenceEndpointKey = new InferenceEndpointKey(inferenceId, serviceSettings); InferenceResults inferenceResults = embeddingsProvider.getEmbeddings(inferenceEndpointKey); @@ -273,8 +277,12 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu false ); - // TODO: Check that model registry supplier has been set - MinimalServiceSettings serviceSettings = MODEL_REGISTRY_SUPPLIER.get().getMinimalServiceSettings(inferenceId); + ModelRegistry modelRegistry = MODEL_REGISTRY_SUPPLIER.get(); + if (modelRegistry == null) { + throw new IllegalStateException("Model registry has not been set"); + } + + MinimalServiceSettings serviceSettings = modelRegistry.getMinimalServiceSettings(inferenceId); InferenceEndpointKey inferenceEndpointKey = new InferenceEndpointKey(inferenceId, serviceSettings); queryRewriteContext.registerAsyncAction( (client, listener) -> executeAsyncWithOrigin( From cd878bd414442c052c1a5e2da27260ad41307daa Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Mon, 4 Aug 2025 09:47:37 -0400 Subject: [PATCH 27/35] Allow CCS only when ccs_minimize_roundtrips=true --- .../xpack/inference/queries/SemanticQueryBuilder.java | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index d1237864e92d9..0f65a79ae7188 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -52,7 +52,6 @@ import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; -// TODO: Only allow CCS when ccs_minimize_roundtrips=true // TODO: Add flag to perform inference again during remote cluster coordinator rewrite public class SemanticQueryBuilder extends AbstractQueryBuilder { @@ -252,6 +251,10 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu throw new IllegalStateException( "Rewriting on the coordinator node requires a query rewrite context with non-null resolved indices" ); + } else if (resolvedIndices.getRemoteClusterIndices().isEmpty() == false) { + if (queryRewriteContext.isCcsMinimizeRoundtrips() != true) { + throw new IllegalArgumentException(NAME + " query supports CCS only when ccs_minimize_roundtrips=true"); + } } Set inferenceIds = getInferenceIdsForForField(resolvedIndices.getConcreteLocalIndicesMetadata().values(), fieldName); From d3937c13689d5cb4cbdb649efbdf4038c74ef030 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Mon, 4 Aug 2025 10:44:16 -0400 Subject: [PATCH 28/35] Perform inference on remote cluster when necessary --- .../queries/SemanticQueryBuilder.java | 97 +++++++++++-------- 1 file changed, 55 insertions(+), 42 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index 0f65a79ae7188..e01a7aa6185c8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -52,7 +52,7 @@ import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; -// TODO: Add flag to perform inference again during remote cluster coordinator rewrite +// TODO: Remove noInferenceResults public class SemanticQueryBuilder extends AbstractQueryBuilder { public static final String NAME = "semantic"; @@ -242,7 +242,9 @@ private QueryBuilder doRewriteBuildSemanticQuery(SearchExecutionContext searchEx } private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewriteContext) { - if (embeddingsProvider != null || noInferenceResults) { + // Check that we are performing a coordinator node rewrite + // TODO: Clean up how we perform this check + if (queryRewriteContext.getClass() != QueryRewriteContext.class) { return this; } @@ -257,54 +259,65 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu } } - Set inferenceIds = getInferenceIdsForForField(resolvedIndices.getConcreteLocalIndicesMetadata().values(), fieldName); - MapEmbeddingsProvider mapEmbeddingsProvider = new MapEmbeddingsProvider(); - - // The inference ID set can be empty if either the field name or index name(s) are invalid (or both). - // If this happens, we set the "no inference results" flag to true so the rewrite process can continue. - // Invalid index names will be handled in the transport layer, when the query is sent to the shard. - // Invalid field names will be handled when the query is re-written on the shard, where we have access to the index mappings. - boolean noInferenceResults = inferenceIds.isEmpty(); - - for (String inferenceId : inferenceIds) { - InferenceAction.Request inferenceRequest = new InferenceAction.Request( - TaskType.ANY, - inferenceId, - null, - null, - null, - List.of(query), - Map.of(), - InputType.INTERNAL_SEARCH, - null, - false - ); + MapEmbeddingsProvider currentEmbeddingsProvider; + if (embeddingsProvider != null) { + if (embeddingsProvider instanceof MapEmbeddingsProvider mapEmbeddingsProvider) { + currentEmbeddingsProvider = mapEmbeddingsProvider; + } else { + throw new IllegalStateException("Current embeddings provider should be a MapEmbeddingsProvider"); + } + } else { + currentEmbeddingsProvider = new MapEmbeddingsProvider(); + } + boolean modified = false; + if (queryRewriteContext.hasAsyncActions() == false) { ModelRegistry modelRegistry = MODEL_REGISTRY_SUPPLIER.get(); if (modelRegistry == null) { throw new IllegalStateException("Model registry has not been set"); } - MinimalServiceSettings serviceSettings = modelRegistry.getMinimalServiceSettings(inferenceId); - InferenceEndpointKey inferenceEndpointKey = new InferenceEndpointKey(inferenceId, serviceSettings); - queryRewriteContext.registerAsyncAction( - (client, listener) -> executeAsyncWithOrigin( - client, - ML_ORIGIN, - InferenceAction.INSTANCE, - inferenceRequest, - listener.delegateFailureAndWrap((l, inferenceResponse) -> { - mapEmbeddingsProvider.addEmbeddings( - inferenceEndpointKey, - validateAndConvertInferenceResults(inferenceResponse.getResults(), fieldName, inferenceId) - ); - l.onResponse(null); - }) - ) - ); + Set inferenceIds = getInferenceIdsForForField(resolvedIndices.getConcreteLocalIndicesMetadata().values(), fieldName); + for (String inferenceId : inferenceIds) { + MinimalServiceSettings serviceSettings = modelRegistry.getMinimalServiceSettings(inferenceId); + InferenceEndpointKey inferenceEndpointKey = new InferenceEndpointKey(inferenceId, serviceSettings); + + if (currentEmbeddingsProvider.getEmbeddings(inferenceEndpointKey) == null) { + InferenceAction.Request inferenceRequest = new InferenceAction.Request( + TaskType.ANY, + inferenceId, + null, + null, + null, + List.of(query), + Map.of(), + InputType.INTERNAL_SEARCH, + null, + false + ); + + queryRewriteContext.registerAsyncAction( + (client, listener) -> executeAsyncWithOrigin( + client, + ML_ORIGIN, + InferenceAction.INSTANCE, + inferenceRequest, + listener.delegateFailureAndWrap((l, inferenceResponse) -> { + currentEmbeddingsProvider.addEmbeddings( + inferenceEndpointKey, + validateAndConvertInferenceResults(inferenceResponse.getResults(), fieldName, inferenceId) + ); + l.onResponse(null); + }) + ) + ); + + modified = true; + } + } } - return new SemanticQueryBuilder(this, noInferenceResults ? null : mapEmbeddingsProvider, noInferenceResults); + return modified ? new SemanticQueryBuilder(this, currentEmbeddingsProvider, false) : this; } private static InferenceResults validateAndConvertInferenceResults( From 659a79f44539098098d219e8412b7f69942f8d19 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Mon, 4 Aug 2025 11:49:41 -0400 Subject: [PATCH 29/35] Update integration test to use different inference IDs across clusters --- .../ccs/SemanticCrossClusterSearchIT.java | 29 +++++++++++++++---- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java index af3fe85fbae9d..c0f00835192e6 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java @@ -52,7 +52,7 @@ public class SemanticCrossClusterSearchIT extends AbstractMultiClustersTestCase private static final String REMOTE_CLUSTER = "cluster_a"; private static final String INFERENCE_FIELD = "inference_field"; - private static final Map BBQ_COMPATIBLE_SERVICE_SETTINGS = Map.of( + private static final Map TEXT_EMBEDDING_SERVICE_SETTINGS_1 = Map.of( "model", "my_model", "dimensions", @@ -63,6 +63,17 @@ public class SemanticCrossClusterSearchIT extends AbstractMultiClustersTestCase "my_api_key" ); + private static final Map TEXT_EMBEDDING_SERVICE_SETTINGS_2 = Map.of( + "model", + "my_model", + "dimensions", + 384, + "similarity", + "cosine", + "api_key", + "my_api_key" + ); + @Override protected List remoteClusterAlias() { return List.of(REMOTE_CLUSTER); @@ -142,9 +153,15 @@ public void testMatchCrossClusterSearch() throws Exception { } private Map setupTwoClusters(String[] localIndices, String[] remoteIndices) throws IOException { - final String inferenceId = "test_inference_id"; - createInferenceEndpoint(client(LOCAL_CLUSTER), TaskType.TEXT_EMBEDDING, inferenceId, BBQ_COMPATIBLE_SERVICE_SETTINGS); - createInferenceEndpoint(client(REMOTE_CLUSTER), TaskType.TEXT_EMBEDDING, inferenceId, BBQ_COMPATIBLE_SERVICE_SETTINGS); + final String localInferenceId = "local_inference_id"; + final String remoteInferenceId = "remote_inference_id"; + + // TODO: Resolve bug where remote model registry overwrites local model registry in SemanticQueryBuilder + createInferenceEndpoint(client(LOCAL_CLUSTER), TaskType.TEXT_EMBEDDING, localInferenceId, TEXT_EMBEDDING_SERVICE_SETTINGS_1); + createInferenceEndpoint(client(LOCAL_CLUSTER), TaskType.TEXT_EMBEDDING, remoteInferenceId, TEXT_EMBEDDING_SERVICE_SETTINGS_2); + + createInferenceEndpoint(client(REMOTE_CLUSTER), TaskType.TEXT_EMBEDDING, localInferenceId, TEXT_EMBEDDING_SERVICE_SETTINGS_1); + createInferenceEndpoint(client(REMOTE_CLUSTER), TaskType.TEXT_EMBEDDING, remoteInferenceId, TEXT_EMBEDDING_SERVICE_SETTINGS_2); int numShardsLocal = randomIntBetween(2, 10); Settings localSettings = indexSettings(numShardsLocal, randomIntBetween(0, 1)).build(); @@ -154,7 +171,7 @@ private Map setupTwoClusters(String[] localIndices, String[] rem .indices() .prepareCreate(localIndex) .setSettings(localSettings) - .setMapping(INFERENCE_FIELD, "type=semantic_text,inference_id=" + inferenceId) + .setMapping(INFERENCE_FIELD, "type=semantic_text,inference_id=" + localInferenceId) ); indexDocs(client(LOCAL_CLUSTER), localIndex); } @@ -168,7 +185,7 @@ private Map setupTwoClusters(String[] localIndices, String[] rem .indices() .prepareCreate(remoteIndex) .setSettings(indexSettings(numShardsRemote, randomIntBetween(0, 1))) - .setMapping(INFERENCE_FIELD, "type=semantic_text,inference_id=" + inferenceId) + .setMapping(INFERENCE_FIELD, "type=semantic_text,inference_id=" + remoteInferenceId) ); assertFalse( client(REMOTE_CLUSTER).admin() From b407a4f79971e7f8b3a3983b20ca16c69c69cbbf Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Mon, 4 Aug 2025 13:06:14 -0400 Subject: [PATCH 30/35] Set model registry for each semantic query builder instance --- .../ccs/SemanticCrossClusterSearchIT.java | 12 ++++++------ .../xpack/inference/InferencePlugin.java | 12 +++++++++--- .../inference/queries/SemanticQueryBuilder.java | 17 +++++++++-------- .../inference/LocalStateInferencePlugin.java | 5 ----- 4 files changed, 24 insertions(+), 22 deletions(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java index c0f00835192e6..ae05e0c2b9c64 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java @@ -37,6 +37,7 @@ import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin; import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.io.IOException; import java.util.Collection; @@ -104,8 +105,12 @@ public void testSemanticCrossClusterSearch() throws Exception { String localIndex = (String) testClusterInfo.get("local.index"); String remoteIndex = (String) testClusterInfo.get("remote.index"); + ModelRegistry modelRegistry = cluster(LOCAL_CLUSTER).getCurrentMasterNodeInstance(ModelRegistry.class); + SemanticQueryBuilder queryBuilder = new SemanticQueryBuilder(INFERENCE_FIELD, "foo"); + queryBuilder.setModelRegistrySupplier(() -> modelRegistry); + SearchRequest searchRequest = new SearchRequest(localIndex, REMOTE_CLUSTER + ":" + remoteIndex); - searchRequest.source(new SearchSourceBuilder().query(new SemanticQueryBuilder(INFERENCE_FIELD, "foo")).size(10)); + searchRequest.source(new SearchSourceBuilder().query(queryBuilder).size(10)); searchRequest.setCcsMinimizeRoundtrips(true); assertResponse(client(LOCAL_CLUSTER).search(searchRequest), response -> { @@ -155,12 +160,7 @@ public void testMatchCrossClusterSearch() throws Exception { private Map setupTwoClusters(String[] localIndices, String[] remoteIndices) throws IOException { final String localInferenceId = "local_inference_id"; final String remoteInferenceId = "remote_inference_id"; - - // TODO: Resolve bug where remote model registry overwrites local model registry in SemanticQueryBuilder createInferenceEndpoint(client(LOCAL_CLUSTER), TaskType.TEXT_EMBEDDING, localInferenceId, TEXT_EMBEDDING_SERVICE_SETTINGS_1); - createInferenceEndpoint(client(LOCAL_CLUSTER), TaskType.TEXT_EMBEDDING, remoteInferenceId, TEXT_EMBEDDING_SERVICE_SETTINGS_2); - - createInferenceEndpoint(client(REMOTE_CLUSTER), TaskType.TEXT_EMBEDDING, localInferenceId, TEXT_EMBEDDING_SERVICE_SETTINGS_1); createInferenceEndpoint(client(REMOTE_CLUSTER), TaskType.TEXT_EMBEDDING, remoteInferenceId, TEXT_EMBEDDING_SERVICE_SETTINGS_2); int numShardsLocal = randomIntBetween(2, 10); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index e7730d79af492..0ccea65d53e56 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -568,7 +568,15 @@ public Collection getMappedActionFilters() { } public List> getQueries() { - return List.of(new QuerySpec<>(SemanticQueryBuilder.NAME, SemanticQueryBuilder::new, SemanticQueryBuilder::fromXContent)); + return List.of(new QuerySpec<>(SemanticQueryBuilder.NAME, i -> { + SemanticQueryBuilder queryBuilder = new SemanticQueryBuilder(i); + queryBuilder.setModelRegistrySupplier(getModelRegistry()); + return queryBuilder; + }, p -> { + SemanticQueryBuilder queryBuilder = SemanticQueryBuilder.fromXContent(p); + queryBuilder.setModelRegistrySupplier(getModelRegistry()); + return queryBuilder; + })); } @Override @@ -602,8 +610,6 @@ public void onNodeStarted() { if (registry != null) { registry.onNodeStarted(); } - - SemanticQueryBuilder.setModelRegistrySupplier(getModelRegistry()); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index e01a7aa6185c8..c597650d3ce37 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -74,18 +74,14 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder MODEL_REGISTRY_SUPPLIER = () -> null; - - public static void setModelRegistrySupplier(Supplier supplier) { - MODEL_REGISTRY_SUPPLIER = supplier; - } - private final String fieldName; private final String query; private final EmbeddingsProvider embeddingsProvider; private final boolean noInferenceResults; private final Boolean lenient; + private Supplier modelRegistrySupplier = () -> null; + public SemanticQueryBuilder(String fieldName, String query) { this(fieldName, query, null); } @@ -126,6 +122,10 @@ public SemanticQueryBuilder(StreamInput in) throws IOException { } } + public void setModelRegistrySupplier(Supplier supplier) { + modelRegistrySupplier = supplier; + } + @Override protected void doWriteTo(StreamOutput out) throws IOException { out.writeString(fieldName); @@ -150,6 +150,7 @@ private SemanticQueryBuilder(SemanticQueryBuilder other, EmbeddingsProvider embe this.embeddingsProvider = embeddingsProvider; this.noInferenceResults = noInferenceResults; this.lenient = other.lenient; + this.modelRegistrySupplier = other.modelRegistrySupplier; } @Override @@ -208,7 +209,7 @@ private QueryBuilder doRewriteBuildSemanticQuery(SearchExecutionContext searchEx ); } - ModelRegistry modelRegistry = MODEL_REGISTRY_SUPPLIER.get(); + ModelRegistry modelRegistry = modelRegistrySupplier.get(); if (modelRegistry == null) { throw new IllegalStateException("Model registry has not been set"); } @@ -272,7 +273,7 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu boolean modified = false; if (queryRewriteContext.hasAsyncActions() == false) { - ModelRegistry modelRegistry = MODEL_REGISTRY_SUPPLIER.get(); + ModelRegistry modelRegistry = modelRegistrySupplier.get(); if (modelRegistry == null) { throw new IllegalStateException("Model registry has not been set"); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/LocalStateInferencePlugin.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/LocalStateInferencePlugin.java index 553af8b5b16b4..5aa42520d74bd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/LocalStateInferencePlugin.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/LocalStateInferencePlugin.java @@ -73,9 +73,4 @@ public Map getHighlighters() { public Collection getMappedActionFilters() { return inferencePlugin.getMappedActionFilters(); } - - @Override - public void onNodeStarted() { - inferencePlugin.onNodeStarted(); - } } From 10ec96dabcdde140e5420dc946dc4e96b8222e4f Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Mon, 4 Aug 2025 13:13:05 -0400 Subject: [PATCH 31/35] Revert CrossClusterSearchIT changes --- .../search/ccs/CrossClusterSearchIT.java | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java index 4c7f402c6636e..d4f60a868dcd4 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java @@ -25,7 +25,6 @@ import org.elasticsearch.common.util.CollectionUtils; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.query.MatchAllQueryBuilder; -import org.elasticsearch.index.query.MatchQueryBuilder; import org.elasticsearch.index.query.RangeQueryBuilder; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.SearchPlugin; @@ -162,21 +161,6 @@ public void testClusterDetailsAfterSuccessfulCCS() throws Exception { }); } - public void testCCSQueryRewrite() throws Exception { - Map testClusterInfo = setupTwoClusters(); - String localIndex = (String) testClusterInfo.get("local.index"); - String remoteIndex = (String) testClusterInfo.get("remote.index"); - int localNumShards = (Integer) testClusterInfo.get("local.num_shards"); - int remoteNumShards = (Integer) testClusterInfo.get("remote.num_shards"); - - SearchRequest searchRequest = new SearchRequest(localIndex, REMOTE_CLUSTER + ":" + remoteIndex); - // searchRequest.setCcsMinimizeRoundtrips(false); - - searchRequest.source(new SearchSourceBuilder().query(new MatchQueryBuilder("foo", "bar")).size(10)); - - assertResponse(client(LOCAL_CLUSTER).search(searchRequest), response -> { assertNotNull(response); }); - } - // CCS with a search where the timestamp of the query cannot match so should be SUCCESSFUL with all shards skipped // during can-match public void testCCSClusterDetailsWhereAllShardsSkippedInCanMatch() throws Exception { From 7e2d973e532931189ce451331284adc9ddc68407 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Mon, 4 Aug 2025 13:29:09 -0400 Subject: [PATCH 32/35] Revert InterceptedQueryBuilderWrapper changes --- .../action/search/TransportSearchAction.java | 10 +---- .../index/query/AbstractQueryBuilder.java | 2 +- .../index/query/InnerHitContextBuilder.java | 2 +- .../query/InterceptedQueryBuilderWrapper.java | 42 ++++++++----------- .../InterceptedQueryBuilderWrapperTests.java | 14 ++----- ...KnnVectorQueryRewriteInterceptorTests.java | 4 +- ...nticMatchQueryRewriteInterceptorTests.java | 8 ++-- ...rseVectorQueryRewriteInterceptorTests.java | 4 +- 8 files changed, 34 insertions(+), 52 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java index 55b1df707f430..e0d254a00c4f2 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java @@ -65,7 +65,6 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexNotFoundException; -import org.elasticsearch.index.query.InterceptedQueryBuilderWrapper; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.index.shard.ShardId; @@ -80,7 +79,6 @@ import org.elasticsearch.search.aggregations.AggregationReduceContext; import org.elasticsearch.search.builder.PointInTimeBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; -import org.elasticsearch.search.builder.SubSearchSourceBuilder; import org.elasticsearch.search.internal.AliasFilter; import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.profile.SearchProfileResults; @@ -791,13 +789,9 @@ public void onFailure(Exception e) { boolean skipUnavailable = remoteClusterService.isSkipUnavailable(clusterAlias); OriginalIndices indices = entry.getValue(); - if (searchRequest.source().query() instanceof InterceptedQueryBuilderWrapper interceptedQuery) { - searchRequest.source().subSearches(List.of(new SubSearchSourceBuilder(interceptedQuery.getOriginal()))); - } - SearchRequest ccsSearchRequest = SearchRequest.subSearchRequest( parentTaskId, - searchRequest, // TODO: Need to prep the request here by stripping inference results? + searchRequest, indices.indices(), clusterAlias, timeProvider.absoluteStartMillis(), @@ -948,7 +942,7 @@ Map createFinalResponse() { SearchShardsRequest searchShardsRequest = new SearchShardsRequest( indices, indicesOptions, - query, // TODO: Need to prep the query here by stripping inference results? + query, routing, preference, allowPartialResults, diff --git a/server/src/main/java/org/elasticsearch/index/query/AbstractQueryBuilder.java b/server/src/main/java/org/elasticsearch/index/query/AbstractQueryBuilder.java index ccb3e076c5798..05262798bac2a 100644 --- a/server/src/main/java/org/elasticsearch/index/query/AbstractQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/AbstractQueryBuilder.java @@ -283,7 +283,7 @@ public final QueryBuilder rewrite(QueryRewriteContext queryRewriteContext) throw if (queryRewriteInterceptor != null) { var rewritten = queryRewriteInterceptor.interceptAndRewrite(queryRewriteContext, this); if (rewritten != this) { - return new InterceptedQueryBuilderWrapper(rewritten, this); + return new InterceptedQueryBuilderWrapper(rewritten); } } diff --git a/server/src/main/java/org/elasticsearch/index/query/InnerHitContextBuilder.java b/server/src/main/java/org/elasticsearch/index/query/InnerHitContextBuilder.java index 947c862d5b2b4..31bc7dddacb7f 100644 --- a/server/src/main/java/org/elasticsearch/index/query/InnerHitContextBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/InnerHitContextBuilder.java @@ -68,7 +68,7 @@ public static void extractInnerHits(QueryBuilder query, Map) query).extractInnerHitBuilders(innerHitBuilders); } else if (query instanceof InterceptedQueryBuilderWrapper interceptedQuery) { // Unwrap an intercepted query here - extractInnerHits(interceptedQuery.rewritten, innerHitBuilders); + extractInnerHits(interceptedQuery.queryBuilder, innerHitBuilders); } else { throw new IllegalStateException( "provided query builder [" + query.getClass() + "] class should inherit from AbstractQueryBuilder, but it doesn't" diff --git a/server/src/main/java/org/elasticsearch/index/query/InterceptedQueryBuilderWrapper.java b/server/src/main/java/org/elasticsearch/index/query/InterceptedQueryBuilderWrapper.java index b3d4ae7226f27..389c9bfa837af 100644 --- a/server/src/main/java/org/elasticsearch/index/query/InterceptedQueryBuilderWrapper.java +++ b/server/src/main/java/org/elasticsearch/index/query/InterceptedQueryBuilderWrapper.java @@ -22,19 +22,13 @@ * Wrapper for instances of {@link QueryBuilder} that have been intercepted using the {@link QueryRewriteInterceptor} to * break out of the rewrite phase. These instances are unwrapped on serialization. */ -public class InterceptedQueryBuilderWrapper implements QueryBuilder { +class InterceptedQueryBuilderWrapper implements QueryBuilder { - protected final QueryBuilder original; - protected final QueryBuilder rewritten; + protected final QueryBuilder queryBuilder; - InterceptedQueryBuilderWrapper(QueryBuilder rewritten, QueryBuilder original) { + InterceptedQueryBuilderWrapper(QueryBuilder queryBuilder) { super(); - this.original = original; - this.rewritten = rewritten; - } - - public QueryBuilder getOriginal() { - return original; + this.queryBuilder = queryBuilder; } @Override @@ -42,8 +36,8 @@ public QueryBuilder rewrite(QueryRewriteContext queryRewriteContext) throws IOEx QueryRewriteInterceptor queryRewriteInterceptor = queryRewriteContext.getQueryRewriteInterceptor(); try { queryRewriteContext.setQueryRewriteInterceptor(null); - QueryBuilder rewritten = this.rewritten.rewrite(queryRewriteContext); - return rewritten != this.rewritten ? new InterceptedQueryBuilderWrapper(rewritten, original) : this; + QueryBuilder rewritten = queryBuilder.rewrite(queryRewriteContext); + return rewritten != queryBuilder ? new InterceptedQueryBuilderWrapper(rewritten) : this; } finally { queryRewriteContext.setQueryRewriteInterceptor(queryRewriteInterceptor); } @@ -51,54 +45,54 @@ public QueryBuilder rewrite(QueryRewriteContext queryRewriteContext) throws IOEx @Override public String getWriteableName() { - return rewritten.getWriteableName(); + return queryBuilder.getWriteableName(); } @Override public TransportVersion getMinimalSupportedVersion() { - return rewritten.getMinimalSupportedVersion(); + return queryBuilder.getMinimalSupportedVersion(); } @Override public Query toQuery(SearchExecutionContext context) throws IOException { - return rewritten.toQuery(context); + return queryBuilder.toQuery(context); } @Override public QueryBuilder queryName(String queryName) { - rewritten.queryName(queryName); + queryBuilder.queryName(queryName); return this; } @Override public String queryName() { - return rewritten.queryName(); + return queryBuilder.queryName(); } @Override public float boost() { - return rewritten.boost(); + return queryBuilder.boost(); } @Override public QueryBuilder boost(float boost) { - rewritten.boost(boost); + queryBuilder.boost(boost); return this; } @Override public String getName() { - return rewritten.getName(); + return queryBuilder.getName(); } @Override public void writeTo(StreamOutput out) throws IOException { - rewritten.writeTo(out); + queryBuilder.writeTo(out); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - return rewritten.toXContent(builder, params); + return queryBuilder.toXContent(builder, params); } @Override @@ -106,11 +100,11 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; InterceptedQueryBuilderWrapper that = (InterceptedQueryBuilderWrapper) o; - return Objects.equals(original, that.original) && Objects.equals(rewritten, that.rewritten); + return Objects.equals(queryBuilder, that.queryBuilder); } @Override public int hashCode() { - return Objects.hash(original, rewritten); + return Objects.hashCode(queryBuilder); } } diff --git a/server/src/test/java/org/elasticsearch/index/query/InterceptedQueryBuilderWrapperTests.java b/server/src/test/java/org/elasticsearch/index/query/InterceptedQueryBuilderWrapperTests.java index 891a3b850bf56..6c570e0e71725 100644 --- a/server/src/test/java/org/elasticsearch/index/query/InterceptedQueryBuilderWrapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/query/InterceptedQueryBuilderWrapperTests.java @@ -36,10 +36,7 @@ public void cleanup() { public void testQueryNameReturnsWrappedQueryBuilder() { MatchAllQueryBuilder matchAllQueryBuilder = new MatchAllQueryBuilder(); - InterceptedQueryBuilderWrapper interceptedQueryBuilderWrapper = new InterceptedQueryBuilderWrapper( - matchAllQueryBuilder, - matchAllQueryBuilder - ); + InterceptedQueryBuilderWrapper interceptedQueryBuilderWrapper = new InterceptedQueryBuilderWrapper(matchAllQueryBuilder); String queryName = randomAlphaOfLengthBetween(5, 10); QueryBuilder namedQuery = interceptedQueryBuilderWrapper.queryName(queryName); assertTrue(namedQuery instanceof InterceptedQueryBuilderWrapper); @@ -48,10 +45,7 @@ public void testQueryNameReturnsWrappedQueryBuilder() { public void testQueryBoostReturnsWrappedQueryBuilder() { MatchAllQueryBuilder matchAllQueryBuilder = new MatchAllQueryBuilder(); - InterceptedQueryBuilderWrapper interceptedQueryBuilderWrapper = new InterceptedQueryBuilderWrapper( - matchAllQueryBuilder, - matchAllQueryBuilder - ); + InterceptedQueryBuilderWrapper interceptedQueryBuilderWrapper = new InterceptedQueryBuilderWrapper(matchAllQueryBuilder); float boost = randomFloat(); QueryBuilder boostedQuery = interceptedQueryBuilderWrapper.boost(boost); assertTrue(boostedQuery instanceof InterceptedQueryBuilderWrapper); @@ -71,8 +65,8 @@ public void testRewrite() throws IOException { MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder("field", "value"); rewritten = matchQueryBuilder.rewrite(context); assertTrue(rewritten instanceof InterceptedQueryBuilderWrapper); - assertTrue(((InterceptedQueryBuilderWrapper) rewritten).rewritten instanceof MatchQueryBuilder); - MatchQueryBuilder rewrittenMatchQueryBuilder = (MatchQueryBuilder) ((InterceptedQueryBuilderWrapper) rewritten).rewritten; + assertTrue(((InterceptedQueryBuilderWrapper) rewritten).queryBuilder instanceof MatchQueryBuilder); + MatchQueryBuilder rewrittenMatchQueryBuilder = (MatchQueryBuilder) ((InterceptedQueryBuilderWrapper) rewritten).queryBuilder; assertEquals("intercepted", rewrittenMatchQueryBuilder.value()); // An additional rewrite on an already intercepted query returns the same query diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java index a923309756509..0d9669e9e2819 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java @@ -100,9 +100,9 @@ private void testRewrittenInferenceQuery(QueryRewriteContext context, KnnVectorQ InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten; assertEquals(original.boost(), intercepted.boost(), 0.0f); assertEquals(original.queryName(), intercepted.queryName()); - assertTrue(intercepted.rewritten instanceof NestedQueryBuilder); + assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder); - NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.rewritten; + NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder; assertEquals(original.boost(), nestedQueryBuilder.boost(), 0.0f); assertEquals(original.queryName(), nestedQueryBuilder.queryName()); assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java index e8a3a45cff637..2226acefd2357 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java @@ -64,8 +64,8 @@ public void testMatchQueryOnInferenceFieldIsInterceptedAndRewrittenToSemanticQue rewritten instanceof InterceptedQueryBuilderWrapper ); InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten; - assertTrue(intercepted.rewritten instanceof SemanticQueryBuilder); - SemanticQueryBuilder semanticQueryBuilder = (SemanticQueryBuilder) intercepted.rewritten; + assertTrue(intercepted.queryBuilder instanceof SemanticQueryBuilder); + SemanticQueryBuilder semanticQueryBuilder = (SemanticQueryBuilder) intercepted.queryBuilder; assertEquals(FIELD_NAME, semanticQueryBuilder.getFieldName()); assertEquals(VALUE, semanticQueryBuilder.getQuery()); } @@ -98,8 +98,8 @@ public void testBoostAndQueryNameInMatchQueryRewrite() throws IOException { InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten; assertEquals(BOOST, intercepted.boost(), 0.0f); assertEquals(QUERY_NAME, intercepted.queryName()); - assertTrue(intercepted.rewritten instanceof SemanticQueryBuilder); - SemanticQueryBuilder semanticQueryBuilder = (SemanticQueryBuilder) intercepted.rewritten; + assertTrue(intercepted.queryBuilder instanceof SemanticQueryBuilder); + SemanticQueryBuilder semanticQueryBuilder = (SemanticQueryBuilder) intercepted.queryBuilder; assertEquals(FIELD_NAME, semanticQueryBuilder.getFieldName()); assertEquals(VALUE, semanticQueryBuilder.getQuery()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticSparseVectorQueryRewriteInterceptorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticSparseVectorQueryRewriteInterceptorTests.java index 420c83b6bce7c..7f3a1eb504039 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticSparseVectorQueryRewriteInterceptorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticSparseVectorQueryRewriteInterceptorTests.java @@ -108,8 +108,8 @@ private void testRewrittenInferenceQuery(QueryRewriteContext context, QueryBuild assertEquals(original.boost(), intercepted.boost(), 0.0f); assertEquals(original.queryName(), intercepted.queryName()); - assertTrue(intercepted.rewritten instanceof NestedQueryBuilder); - NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.rewritten; + assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder); + NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder; assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path()); assertEquals(original.boost(), nestedQueryBuilder.boost(), 0.0f); assertEquals(original.queryName(), nestedQueryBuilder.queryName()); From c39adbe688de72b0d1223b67f641291faa9aeafe Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Mon, 4 Aug 2025 14:58:45 -0400 Subject: [PATCH 33/35] Revert MatchQueryBuilder changes --- .../elasticsearch/index/query/MatchQueryBuilder.java | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/query/MatchQueryBuilder.java b/server/src/main/java/org/elasticsearch/index/query/MatchQueryBuilder.java index eb5f1bbbc3230..d68eaa925500d 100644 --- a/server/src/main/java/org/elasticsearch/index/query/MatchQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/MatchQueryBuilder.java @@ -365,16 +365,6 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio builder.endObject(); } - @Override - protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { - ResolvedIndices resolvedIndices = queryRewriteContext.getResolvedIndices(); - if (resolvedIndices != null) { - return this; - } else { - return super.doRewrite(queryRewriteContext); - } - } - @Override protected QueryBuilder doIndexMetadataRewrite(QueryRewriteContext context) throws IOException { if (fuzziness != null || lenient) { From 9c58c1128c48e8d7c647472d710233d0f8b3e88b Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Mon, 4 Aug 2025 15:00:49 -0400 Subject: [PATCH 34/35] Adjust PIT test --- .../action/search/TransportSearchAction.java | 2 +- .../index/query/MatchQueryBuilder.java | 1 - .../ccs/SemanticCrossClusterSearchIT.java | 20 ++++++++++--------- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java index e0d254a00c4f2..693fc340e6189 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java @@ -552,7 +552,7 @@ public void onFailure(Exception e) { timeProvider::absoluteStartMillis, resolvedIndices, original.pointInTimeBuilder(), - original.isCcsMinimizeRoundtrips(), + shouldMinimizeRoundtrips(original), isExplain ), rewriteListener diff --git a/server/src/main/java/org/elasticsearch/index/query/MatchQueryBuilder.java b/server/src/main/java/org/elasticsearch/index/query/MatchQueryBuilder.java index d68eaa925500d..fd704d39ca384 100644 --- a/server/src/main/java/org/elasticsearch/index/query/MatchQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/MatchQueryBuilder.java @@ -14,7 +14,6 @@ import org.apache.lucene.search.Query; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; -import org.elasticsearch.action.ResolvedIndices; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java index ae05e0c2b9c64..b6265be8937c6 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java @@ -47,6 +47,7 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse; +import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.equalTo; public class SemanticCrossClusterSearchIT extends AbstractMultiClustersTestCase { @@ -129,17 +130,18 @@ public void testSemanticCrossClusterSearchWithPIT() throws Exception { TimeValue.timeValueMinutes(2) ); + ModelRegistry modelRegistry = cluster(LOCAL_CLUSTER).getCurrentMasterNodeInstance(ModelRegistry.class); + SemanticQueryBuilder queryBuilder = new SemanticQueryBuilder(INFERENCE_FIELD, "foo"); + queryBuilder.setModelRegistrySupplier(() -> modelRegistry); + SearchRequest searchRequest = new SearchRequest(); - searchRequest.source( - new SearchSourceBuilder().query(new SemanticQueryBuilder(INFERENCE_FIELD, "foo")) - .pointInTimeBuilder(new PointInTimeBuilder(pitId)) - .size(10) - ); + searchRequest.source(new SearchSourceBuilder().query(queryBuilder).pointInTimeBuilder(new PointInTimeBuilder(pitId)).size(10)); - assertResponse(client(LOCAL_CLUSTER).search(searchRequest), response -> { - assertNotNull(response); - assertEquals(10, response.getHits().getHits().length); - }); + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> client(LOCAL_CLUSTER).search(searchRequest).actionGet(TEST_REQUEST_TIMEOUT) + ); + assertThat(e.getMessage(), containsString("semantic query supports CCS only when ccs_minimize_roundtrips=true")); } public void testMatchCrossClusterSearch() throws Exception { From 4fb6f42ddfa397ba841d692dcd3da9400524d525 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Mon, 4 Aug 2025 15:01:40 -0400 Subject: [PATCH 35/35] Remove match query test --- .../search/ccs/SemanticCrossClusterSearchIT.java | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java index b6265be8937c6..a14048546c0fe 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java @@ -18,7 +18,6 @@ import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.index.query.MatchQueryBuilder; import org.elasticsearch.inference.TaskType; import org.elasticsearch.license.LicenseSettings; import org.elasticsearch.plugins.Plugin; @@ -144,21 +143,6 @@ public void testSemanticCrossClusterSearchWithPIT() throws Exception { assertThat(e.getMessage(), containsString("semantic query supports CCS only when ccs_minimize_roundtrips=true")); } - public void testMatchCrossClusterSearch() throws Exception { - Map testClusterInfo = setupTwoClusters(); - String localIndex = (String) testClusterInfo.get("local.index"); - String remoteIndex = (String) testClusterInfo.get("remote.index"); - - SearchRequest searchRequest = new SearchRequest(localIndex, REMOTE_CLUSTER + ":" + remoteIndex); - searchRequest.source(new SearchSourceBuilder().query(new MatchQueryBuilder(INFERENCE_FIELD, "foo")).size(10)); - // searchRequest.setCcsMinimizeRoundtrips(false); - - assertResponse(client(LOCAL_CLUSTER).search(searchRequest), response -> { - assertNotNull(response); - assertEquals(10, response.getHits().getHits().length); - }); - } - private Map setupTwoClusters(String[] localIndices, String[] remoteIndices) throws IOException { final String localInferenceId = "local_inference_id"; final String remoteInferenceId = "remote_inference_id";