diff --git a/docs/changelog/135309.yaml b/docs/changelog/135309.yaml
new file mode 100644
index 0000000000000..20c50553c2eb8
--- /dev/null
+++ b/docs/changelog/135309.yaml
@@ -0,0 +1,5 @@
+pr: 135309
+summary: Enable semantic search CCS when ccs_minimize_roundtrips=true
+area: Vector Search
+type: enhancement
+issues: []
diff --git a/qa/multi-cluster-search/src/test/resources/rest-api-spec/test/multi_cluster/110_semantic_query.yml b/qa/multi-cluster-search/src/test/resources/rest-api-spec/test/multi_cluster/110_semantic_query.yml
deleted file mode 100644
index 0155175f0e54a..0000000000000
--- a/qa/multi-cluster-search/src/test/resources/rest-api-spec/test/multi_cluster/110_semantic_query.yml
+++ /dev/null
@@ -1,37 +0,0 @@
----
-setup:
- - requires:
- cluster_features: "gte_v8.15.0"
- reason: semantic query introduced in 8.15.0
-
- - do:
- indices.create:
- index: test-index
- body:
- settings:
- index:
- number_of_shards: 1
- number_of_replicas: 0
----
-teardown:
-
- - do:
- indices.delete:
- index: test-index
- ignore_unavailable: true
-
----
-"Test that semantic query does not support cross-cluster search":
- - do:
- catch: bad_request
- search:
- index: "test-index,my_remote_cluster:test-index"
- body:
- query:
- semantic:
- field: "field"
- query: "test query"
-
-
- - match: { error.type: "illegal_argument_exception" }
- - match: { error.reason: "semantic query does not support cross-cluster search" }
diff --git a/server/src/main/resources/transport/definitions/referable/semantic_search_ccs_support.csv b/server/src/main/resources/transport/definitions/referable/semantic_search_ccs_support.csv
new file mode 100644
index 0000000000000..35154103cd0da
--- /dev/null
+++ b/server/src/main/resources/transport/definitions/referable/semantic_search_ccs_support.csv
@@ -0,0 +1 @@
+9174000
diff --git a/server/src/main/resources/transport/upper_bounds/9.2.csv b/server/src/main/resources/transport/upper_bounds/9.2.csv
index e60434a3e2189..57900e0428e01 100644
--- a/server/src/main/resources/transport/upper_bounds/9.2.csv
+++ b/server/src/main/resources/transport/upper_bounds/9.2.csv
@@ -1 +1 @@
-sampling_configuration,9173000
+semantic_search_ccs_support,9174000
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/LocalStateCompositeXPackPlugin.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/LocalStateCompositeXPackPlugin.java
index 117141708ed43..366b10e125604 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/LocalStateCompositeXPackPlugin.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/LocalStateCompositeXPackPlugin.java
@@ -79,6 +79,8 @@
import org.elasticsearch.plugins.ShutdownAwarePlugin;
import org.elasticsearch.plugins.SystemIndexPlugin;
import org.elasticsearch.plugins.interceptor.RestServerActionPlugin;
+import org.elasticsearch.plugins.internal.InternalSearchPlugin;
+import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
import org.elasticsearch.repositories.RepositoriesMetrics;
import org.elasticsearch.repositories.Repository;
import org.elasticsearch.repositories.SnapshotMetrics;
@@ -135,6 +137,7 @@ public class LocalStateCompositeXPackPlugin extends XPackPlugin
IndexStorePlugin,
SystemIndexPlugin,
SearchPlugin,
+ InternalSearchPlugin,
ShutdownAwarePlugin,
RestServerActionPlugin {
@@ -291,6 +294,15 @@ public List> getQueries() {
return querySpecs;
}
+ @Override
+ public List getQueryRewriteInterceptors() {
+ List interceptors = new ArrayList<>();
+ filterPlugins(InternalSearchPlugin.class).stream()
+ .flatMap(p -> p.getQueryRewriteInterceptors().stream())
+ .forEach(interceptors::add);
+ return interceptors;
+ }
+
@Override
public List getNamedXContent() {
List entries = new ArrayList<>(super.getNamedXContent());
diff --git a/x-pack/plugin/inference/build.gradle b/x-pack/plugin/inference/build.gradle
index 9486d239e5de5..eb9372e675831 100644
--- a/x-pack/plugin/inference/build.gradle
+++ b/x-pack/plugin/inference/build.gradle
@@ -36,6 +36,7 @@ dependencies {
testImplementation(project(':x-pack:plugin:inference:qa:test-service-plugin'))
testImplementation project(':modules:reindex')
testImplementation project(':modules:mapper-extras')
+ testImplementation project(':x-pack:plugin:ml')
clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin')
api "com.ibm.icu:icu4j:${versions.icu4j}"
diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/AbstractSemanticCrossClusterSearchTestCase.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/AbstractSemanticCrossClusterSearchTestCase.java
new file mode 100644
index 0000000000000..685453fa77c78
--- /dev/null
+++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/AbstractSemanticCrossClusterSearchTestCase.java
@@ -0,0 +1,337 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.search.ccs;
+
+import org.elasticsearch.action.DocWriteResponse;
+import org.elasticsearch.action.search.OpenPointInTimeRequest;
+import org.elasticsearch.action.search.OpenPointInTimeResponse;
+import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.action.search.TransportOpenPointInTimeAction;
+import org.elasticsearch.action.support.IndicesOptions;
+import org.elasticsearch.action.support.broadcast.BroadcastResponse;
+import org.elasticsearch.client.internal.Client;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.inference.MinimalServiceSettings;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.license.LicenseSettings;
+import org.elasticsearch.plugins.ActionPlugin;
+import org.elasticsearch.plugins.Plugin;
+import org.elasticsearch.plugins.SearchPlugin;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.elasticsearch.test.AbstractMultiClustersTestCase;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
+import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder;
+import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
+import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
+import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
+import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin;
+import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
+import org.elasticsearch.xpack.ml.action.TransportCoordinatedInferenceAction;
+
+import java.io.IOException;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Consumer;
+import java.util.stream.Collectors;
+
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
+
+public abstract class AbstractSemanticCrossClusterSearchTestCase extends AbstractMultiClustersTestCase {
+ protected static final String REMOTE_CLUSTER = "cluster_a";
+
+ @Override
+ protected List remoteClusterAlias() {
+ return List.of(REMOTE_CLUSTER);
+ }
+
+ @Override
+ protected Map skipUnavailableForRemoteClusters() {
+ return Map.of(REMOTE_CLUSTER, DEFAULT_SKIP_UNAVAILABLE);
+ }
+
+ @Override
+ protected boolean reuseClusters() {
+ return false;
+ }
+
+ @Override
+ protected Settings nodeSettings() {
+ return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build();
+ }
+
+ @Override
+ protected Collection> nodePlugins(String clusterAlias) {
+ return List.of(LocalStateInferencePlugin.class, TestInferenceServicePlugin.class, FakeMlPlugin.class);
+ }
+
+ protected void setupTwoClusters(TestIndexInfo localIndexInfo, TestIndexInfo remoteIndexInfo) throws IOException {
+ setupCluster(LOCAL_CLUSTER, localIndexInfo);
+ setupCluster(REMOTE_CLUSTER, remoteIndexInfo);
+ }
+
+ protected void setupCluster(String clusterAlias, TestIndexInfo indexInfo) throws IOException {
+ final Client client = client(clusterAlias);
+ final String indexName = indexInfo.name();
+
+ for (var entry : indexInfo.inferenceEndpoints().entrySet()) {
+ String inferenceId = entry.getKey();
+ MinimalServiceSettings minimalServiceSettings = entry.getValue();
+
+ Map serviceSettings = new HashMap<>();
+ serviceSettings.put("model", randomAlphaOfLength(5));
+ serviceSettings.put("api_key", randomAlphaOfLength(5));
+ if (minimalServiceSettings.taskType() == TaskType.TEXT_EMBEDDING) {
+ serviceSettings.put("dimensions", minimalServiceSettings.dimensions());
+ serviceSettings.put("similarity", minimalServiceSettings.similarity());
+ serviceSettings.put("element_type", minimalServiceSettings.elementType());
+ }
+
+ createInferenceEndpoint(client, minimalServiceSettings.taskType(), inferenceId, serviceSettings);
+ }
+
+ Settings indexSettings = indexSettings(randomIntBetween(2, 5), randomIntBetween(0, 1)).build();
+ assertAcked(client.admin().indices().prepareCreate(indexName).setSettings(indexSettings).setMapping(indexInfo.mappings()));
+ assertFalse(
+ client.admin()
+ .cluster()
+ .prepareHealth(TEST_REQUEST_TIMEOUT, indexName)
+ .setWaitForYellowStatus()
+ .setTimeout(TimeValue.timeValueSeconds(10))
+ .get()
+ .isTimedOut()
+ );
+
+ for (var entry : indexInfo.docs().entrySet()) {
+ String docId = entry.getKey();
+ Map doc = entry.getValue();
+
+ DocWriteResponse response = client.prepareIndex(indexName).setId(docId).setSource(doc).execute().actionGet();
+ assertThat(response.getResult(), equalTo(DocWriteResponse.Result.CREATED));
+ }
+ BroadcastResponse refreshResponse = client.admin().indices().prepareRefresh(indexName).execute().actionGet();
+ assertThat(refreshResponse.getStatus(), is(RestStatus.OK));
+ }
+
+ protected BytesReference openPointInTime(String[] indices, TimeValue keepAlive) {
+ OpenPointInTimeRequest request = new OpenPointInTimeRequest(indices).keepAlive(keepAlive);
+ final OpenPointInTimeResponse response = client().execute(TransportOpenPointInTimeAction.TYPE, request).actionGet();
+ return response.getPointInTimeId();
+ }
+
+ protected static void createInferenceEndpoint(Client client, TaskType taskType, String inferenceId, Map serviceSettings)
+ throws IOException {
+ final String service = switch (taskType) {
+ case TEXT_EMBEDDING -> TestDenseInferenceServiceExtension.TestInferenceService.NAME;
+ case SPARSE_EMBEDDING -> TestSparseInferenceServiceExtension.TestInferenceService.NAME;
+ default -> throw new IllegalArgumentException("Unhandled task type [" + taskType + "]");
+ };
+
+ final BytesReference content;
+ try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
+ builder.startObject();
+ builder.field("service", service);
+ builder.field("service_settings", serviceSettings);
+ builder.endObject();
+
+ content = BytesReference.bytes(builder);
+ }
+
+ PutInferenceModelAction.Request request = new PutInferenceModelAction.Request(
+ taskType,
+ inferenceId,
+ content,
+ XContentType.JSON,
+ TEST_REQUEST_TIMEOUT
+ );
+ var responseFuture = client.execute(PutInferenceModelAction.INSTANCE, request);
+ assertThat(responseFuture.actionGet(TEST_REQUEST_TIMEOUT).getModel().getInferenceEntityId(), equalTo(inferenceId));
+ }
+
+ protected void assertSearchResponse(QueryBuilder queryBuilder, List indices, List expectedSearchResults)
+ throws Exception {
+ assertSearchResponse(queryBuilder, indices, expectedSearchResults, null, null);
+ }
+
+ protected void assertSearchResponse(
+ QueryBuilder queryBuilder,
+ List indices,
+ List expectedSearchResults,
+ ClusterFailure expectedRemoteFailure,
+ Consumer searchRequestModifier
+ ) throws Exception {
+ SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(queryBuilder).size(expectedSearchResults.size());
+ indices.forEach(i -> searchSourceBuilder.indexBoost(i.index(), i.boost()));
+
+ SearchRequest searchRequest = new SearchRequest(convertToArray(indices), searchSourceBuilder);
+ searchRequest.indicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN);
+ if (searchRequestModifier != null) {
+ searchRequestModifier.accept(searchRequest);
+ }
+
+ assertResponse(client().search(searchRequest), response -> {
+ SearchHit[] hits = response.getHits().getHits();
+ assertThat(hits.length, equalTo(expectedSearchResults.size()));
+
+ Iterator searchResultIterator = expectedSearchResults.iterator();
+ for (int i = 0; i < hits.length; i++) {
+ SearchResult expectedSearchResult = searchResultIterator.next();
+ SearchHit actualSearchResult = hits[i];
+
+ assertThat(actualSearchResult.getClusterAlias(), equalTo(expectedSearchResult.clusterAlias()));
+ assertThat(actualSearchResult.getIndex(), equalTo(expectedSearchResult.index()));
+ assertThat(actualSearchResult.getId(), equalTo(expectedSearchResult.id()));
+ }
+
+ SearchResponse.Clusters clusters = response.getClusters();
+ assertThat(clusters.getCluster(LOCAL_CLUSTER).getStatus(), equalTo(SearchResponse.Cluster.Status.SUCCESSFUL));
+ assertThat(clusters.getCluster(LOCAL_CLUSTER).getFailures().isEmpty(), is(true));
+
+ SearchResponse.Cluster remoteCluster = clusters.getCluster(REMOTE_CLUSTER);
+ if (expectedRemoteFailure != null) {
+ assertThat(remoteCluster.getStatus(), equalTo(expectedRemoteFailure.status()));
+
+ Set expectedFailures = expectedRemoteFailure.failures();
+ Set actualFailures = remoteCluster.getFailures()
+ .stream()
+ .map(f -> new FailureCause(f.getCause().getClass(), f.getCause().getMessage()))
+ .collect(Collectors.toSet());
+ assertThat(actualFailures, equalTo(expectedFailures));
+ } else {
+ assertThat(remoteCluster.getStatus(), equalTo(SearchResponse.Cluster.Status.SUCCESSFUL));
+ assertThat(remoteCluster.getFailures().isEmpty(), is(true));
+ }
+ });
+ }
+
+ protected static MinimalServiceSettings sparseEmbeddingServiceSettings() {
+ return new MinimalServiceSettings(null, TaskType.SPARSE_EMBEDDING, null, null, null);
+ }
+
+ protected static MinimalServiceSettings textEmbeddingServiceSettings(
+ int dimensions,
+ SimilarityMeasure similarity,
+ DenseVectorFieldMapper.ElementType elementType
+ ) {
+ return new MinimalServiceSettings(null, TaskType.TEXT_EMBEDDING, dimensions, similarity, elementType);
+ }
+
+ protected static Map semanticTextMapping(String inferenceId) {
+ return Map.of("type", SemanticTextFieldMapper.CONTENT_TYPE, "inference_id", inferenceId);
+ }
+
+ protected static Map textMapping() {
+ return Map.of("type", "text");
+ }
+
+ protected static Map denseVectorMapping(int dimensions) {
+ return Map.of("type", DenseVectorFieldMapper.CONTENT_TYPE, "dims", dimensions);
+ }
+
+ protected static Map sparseVectorMapping() {
+ return Map.of("type", SparseVectorFieldMapper.CONTENT_TYPE);
+ }
+
+ protected static String fullyQualifiedIndexName(String clusterAlias, String indexName) {
+ return clusterAlias + ":" + indexName;
+ }
+
+ protected static float[] generateDenseVectorFieldValue(int dimensions, DenseVectorFieldMapper.ElementType elementType, float value) {
+ if (elementType == DenseVectorFieldMapper.ElementType.BIT) {
+ assert dimensions % 8 == 0;
+ dimensions /= 8;
+ }
+
+ float[] vector = new float[dimensions];
+ for (int i = 0; i < dimensions; i++) {
+ // Use a constant value so that relevance is consistent
+ vector[i] = value;
+ }
+
+ return vector;
+ }
+
+ protected static Map generateSparseVectorFieldValue(float weight) {
+ // Generate values that have the same recall behavior as those produced by TestSparseInferenceServiceExtension. Use a constant token
+ // weight so that relevance is consistent.
+ return Map.of("feature_0", weight);
+ }
+
+ protected static String[] convertToArray(List indices) {
+ return indices.stream().map(IndexWithBoost::index).toArray(String[]::new);
+ }
+
+ public static class FakeMlPlugin extends Plugin implements ActionPlugin, SearchPlugin {
+ @Override
+ public List getNamedWriteables() {
+ return new MlInferenceNamedXContentProvider().getNamedWriteables();
+ }
+
+ @Override
+ public List> getQueryVectorBuilders() {
+ return List.of(
+ new QueryVectorBuilderSpec<>(
+ TextEmbeddingQueryVectorBuilder.NAME,
+ TextEmbeddingQueryVectorBuilder::new,
+ TextEmbeddingQueryVectorBuilder.PARSER
+ )
+ );
+ }
+
+ @Override
+ public Collection getActions() {
+ return List.of(new ActionHandler(CoordinatedInferenceAction.INSTANCE, TransportCoordinatedInferenceAction.class));
+ }
+ }
+
+ protected record TestIndexInfo(
+ String name,
+ Map inferenceEndpoints,
+ Map mappings,
+ Map> docs
+ ) {
+ @Override
+ public Map mappings() {
+ return Map.of("properties", mappings);
+ }
+ }
+
+ protected record SearchResult(@Nullable String clusterAlias, String index, String id) {}
+
+ protected record FailureCause(Class extends Throwable> causeClass, String message) {}
+
+ protected record ClusterFailure(SearchResponse.Cluster.Status status, Set failures) {}
+
+ protected record IndexWithBoost(String index, float boost) {
+ public IndexWithBoost(String index) {
+ this(index, 1.0f);
+ }
+ }
+}
diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/KnnVectorQueryBuilderCrossClusterSearchIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/KnnVectorQueryBuilderCrossClusterSearchIT.java
new file mode 100644
index 0000000000000..b54d7afe08714
--- /dev/null
+++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/KnnVectorQueryBuilderCrossClusterSearchIT.java
@@ -0,0 +1,279 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.search.ccs;
+
+import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.index.query.QueryShardException;
+import org.elasticsearch.inference.MinimalServiceSettings;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
+import org.elasticsearch.search.vectors.VectorData;
+import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.BiConsumer;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class KnnVectorQueryBuilderCrossClusterSearchIT extends AbstractSemanticCrossClusterSearchTestCase {
+ private static final String LOCAL_INDEX_NAME = "local-index";
+ private static final String REMOTE_INDEX_NAME = "remote-index";
+ private static final List QUERY_INDICES = List.of(
+ new IndexWithBoost(LOCAL_INDEX_NAME),
+ new IndexWithBoost(fullyQualifiedIndexName(REMOTE_CLUSTER, REMOTE_INDEX_NAME))
+ );
+
+ public void testKnnQuery() throws Exception {
+ final String commonInferenceId = "common-inference-id";
+ final String localInferenceId = "local-inference-id";
+
+ final String commonInferenceIdField = "common-inference-id-field";
+ final String mixedTypeField1 = "mixed-type-field-1";
+ final String mixedTypeField2 = "mixed-type-field-2";
+
+ final TestIndexInfo localIndexInfo = new TestIndexInfo(
+ LOCAL_INDEX_NAME,
+ Map.of(
+ commonInferenceId,
+ textEmbeddingServiceSettings(256, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT),
+ localInferenceId,
+ textEmbeddingServiceSettings(384, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT)
+ ),
+ Map.of(
+ commonInferenceIdField,
+ semanticTextMapping(commonInferenceId),
+ mixedTypeField1,
+ denseVectorMapping(384),
+ mixedTypeField2,
+ semanticTextMapping(localInferenceId)
+ ),
+ Map.of(
+ "local_doc_1",
+ Map.of(commonInferenceIdField, "a"),
+ "local_doc_2",
+ Map.of(mixedTypeField1, generateDenseVectorFieldValue(384, DenseVectorFieldMapper.ElementType.FLOAT, -128.0f)),
+ "local_doc_3",
+ Map.of(mixedTypeField2, "c")
+ )
+ );
+ final TestIndexInfo remoteIndexInfo = new TestIndexInfo(
+ REMOTE_INDEX_NAME,
+ Map.of(
+ commonInferenceId,
+ textEmbeddingServiceSettings(384, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT)
+ ),
+ Map.of(
+ commonInferenceIdField,
+ semanticTextMapping(commonInferenceId),
+ mixedTypeField1,
+ semanticTextMapping(commonInferenceId),
+ mixedTypeField2,
+ denseVectorMapping(384)
+ ),
+ Map.of(
+ "remote_doc_1",
+ Map.of(commonInferenceIdField, "x"),
+ "remote_doc_2",
+ Map.of(mixedTypeField1, "y"),
+ "remote_doc_3",
+ Map.of(mixedTypeField2, generateDenseVectorFieldValue(384, DenseVectorFieldMapper.ElementType.FLOAT, -128.0f))
+ )
+ );
+ setupTwoClusters(localIndexInfo, remoteIndexInfo);
+
+ // Query a field has the same inference ID value across clusters, but with different backing inference services
+ assertSearchResponse(
+ new KnnVectorQueryBuilder(commonInferenceIdField, new TextEmbeddingQueryVectorBuilder(null, "a"), 10, 100, 10f, null),
+ QUERY_INDICES,
+ List.of(
+ new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, "local_doc_1"),
+ new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, "remote_doc_1")
+ )
+ );
+
+ // Query a field that has mixed types across clusters
+ assertSearchResponse(
+ new KnnVectorQueryBuilder(mixedTypeField1, new TextEmbeddingQueryVectorBuilder(localInferenceId, "y"), 10, 100, 10f, null),
+ QUERY_INDICES,
+ List.of(
+ new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, "remote_doc_2"),
+ new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, "local_doc_2")
+ )
+ );
+ assertSearchResponse(
+ new KnnVectorQueryBuilder(mixedTypeField2, new TextEmbeddingQueryVectorBuilder(localInferenceId, "c"), 10, 100, 10f, null),
+ QUERY_INDICES,
+ List.of(
+ new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, "local_doc_3"),
+ new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, "remote_doc_3")
+ )
+ );
+
+ // Query a field that has mixed types across clusters using a query vector
+ final VectorData queryVector = new VectorData(
+ generateDenseVectorFieldValue(384, DenseVectorFieldMapper.ElementType.FLOAT, -128.0f)
+ );
+ assertSearchResponse(
+ new KnnVectorQueryBuilder(mixedTypeField1, queryVector, 10, 100, 10f, null, null),
+ QUERY_INDICES,
+ List.of(
+ new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, "local_doc_2"),
+ new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, "remote_doc_2")
+ )
+ );
+ assertSearchResponse(
+ new KnnVectorQueryBuilder(mixedTypeField2, queryVector, 10, 100, 10f, null, null),
+ QUERY_INDICES,
+ List.of(
+ new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, "remote_doc_3"),
+ new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, "local_doc_3")
+ )
+ );
+
+ // Check that omitting the inference ID when querying a remote dense vector field leads to the expected partial failure
+ assertSearchResponse(
+ new KnnVectorQueryBuilder(mixedTypeField2, new TextEmbeddingQueryVectorBuilder(null, "c"), 10, 100, 10f, null),
+ QUERY_INDICES,
+ List.of(new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, "local_doc_3")),
+ new ClusterFailure(
+ SearchResponse.Cluster.Status.SKIPPED,
+ Set.of(new FailureCause(IllegalArgumentException.class, "[model_id] must not be null."))
+ ),
+ null
+ );
+ }
+
+ public void testKnnQueryWithCcsMinimizeRoundTripsFalse() throws Exception {
+ final BiConsumer assertCcsMinimizeRoundTripsFalseFailure = (f, qvb) -> {
+ KnnVectorQueryBuilder queryBuilder = new KnnVectorQueryBuilder(f, qvb, 10, 100, 10f, null);
+
+ SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(queryBuilder);
+ SearchRequest searchRequest = new SearchRequest(convertToArray(QUERY_INDICES), searchSourceBuilder);
+ searchRequest.setCcsMinimizeRoundtrips(false);
+
+ IllegalArgumentException e = assertThrows(
+ IllegalArgumentException.class,
+ () -> client().search(searchRequest).actionGet(TEST_REQUEST_TIMEOUT)
+ );
+ assertThat(
+ e.getMessage(),
+ equalTo(
+ "knn query does not support cross-cluster search when querying a [semantic_text] field when "
+ + "[ccs_minimize_roundtrips] is false"
+ )
+ );
+ };
+
+ final int dimensions = 256;
+ final String commonInferenceId = "common-inference-id";
+ final MinimalServiceSettings commonInferenceIdServiceSettings = textEmbeddingServiceSettings(
+ dimensions,
+ SimilarityMeasure.COSINE,
+ DenseVectorFieldMapper.ElementType.FLOAT
+ );
+
+ final String commonInferenceIdField = "common-inference-id-field";
+ final String mixedTypeField1 = "mixed-type-field-1";
+ final String mixedTypeField2 = "mixed-type-field-2";
+ final String denseVectorField = "dense-vector-field";
+
+ final TestIndexInfo localIndexInfo = new TestIndexInfo(
+ LOCAL_INDEX_NAME,
+ Map.of(commonInferenceId, commonInferenceIdServiceSettings),
+ Map.of(
+ commonInferenceIdField,
+ semanticTextMapping(commonInferenceId),
+ mixedTypeField1,
+ semanticTextMapping(commonInferenceId),
+ mixedTypeField2,
+ denseVectorMapping(dimensions),
+ denseVectorField,
+ denseVectorMapping(dimensions)
+ ),
+ Map.of(
+ mixedTypeField2 + "_doc",
+ Map.of(mixedTypeField2, generateDenseVectorFieldValue(dimensions, DenseVectorFieldMapper.ElementType.FLOAT, -128.0f)),
+ denseVectorField + "_doc",
+ Map.of(denseVectorField, generateDenseVectorFieldValue(dimensions, DenseVectorFieldMapper.ElementType.FLOAT, 1.0f))
+ )
+ );
+ final TestIndexInfo remoteIndexInfo = new TestIndexInfo(
+ REMOTE_INDEX_NAME,
+ Map.of(commonInferenceId, commonInferenceIdServiceSettings),
+ Map.of(
+ commonInferenceIdField,
+ semanticTextMapping(commonInferenceId),
+ mixedTypeField1,
+ denseVectorMapping(dimensions),
+ mixedTypeField2,
+ semanticTextMapping(commonInferenceId),
+ denseVectorField,
+ denseVectorMapping(dimensions)
+ ),
+ Map.of(
+ mixedTypeField2 + "_doc",
+ Map.of(mixedTypeField2, "a"),
+ denseVectorField + "_doc",
+ Map.of(denseVectorField, generateDenseVectorFieldValue(dimensions, DenseVectorFieldMapper.ElementType.FLOAT, -128.0f))
+ )
+ );
+ setupTwoClusters(localIndexInfo, remoteIndexInfo);
+
+ // Validate that expected cases fail
+ assertCcsMinimizeRoundTripsFalseFailure.accept(
+ commonInferenceIdField,
+ new TextEmbeddingQueryVectorBuilder(null, randomAlphaOfLength(5))
+ );
+ assertCcsMinimizeRoundTripsFalseFailure.accept(
+ mixedTypeField1,
+ new TextEmbeddingQueryVectorBuilder(commonInferenceId, randomAlphaOfLength(5))
+ );
+
+ // Validate the expected ccs_minimize_roundtrips=false detection gap and failure mode when querying non-inference fields locally
+ assertSearchResponse(
+ new KnnVectorQueryBuilder(mixedTypeField2, new TextEmbeddingQueryVectorBuilder(commonInferenceId, "foo"), 10, 100, 10f, null),
+ QUERY_INDICES,
+ List.of(new SearchResult(null, LOCAL_INDEX_NAME, mixedTypeField2 + "_doc")),
+ new ClusterFailure(
+ SearchResponse.Cluster.Status.SKIPPED,
+ Set.of(
+ new FailureCause(
+ QueryShardException.class,
+ "failed to create query: [knn] queries are only supported on [dense_vector] fields"
+ )
+ )
+ ),
+ s -> s.setCcsMinimizeRoundtrips(false)
+ );
+
+ // Validate that a CCS knn query functions when only dense vector fields are queried
+ assertSearchResponse(
+ new KnnVectorQueryBuilder(
+ denseVectorField,
+ generateDenseVectorFieldValue(dimensions, DenseVectorFieldMapper.ElementType.FLOAT, 1.0f),
+ 10,
+ 100,
+ 10f,
+ null,
+ null
+ ),
+ QUERY_INDICES,
+ List.of(
+ new SearchResult(null, LOCAL_INDEX_NAME, denseVectorField + "_doc"),
+ new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, denseVectorField + "_doc")
+ ),
+ null,
+ s -> s.setCcsMinimizeRoundtrips(false)
+ );
+ }
+}
diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/MatchQueryBuilderCrossClusterSearchIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/MatchQueryBuilderCrossClusterSearchIT.java
new file mode 100644
index 0000000000000..a83f7fa80e461
--- /dev/null
+++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/MatchQueryBuilderCrossClusterSearchIT.java
@@ -0,0 +1,217 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.search.ccs;
+
+import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.index.query.MatchQueryBuilder;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryShardException;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Consumer;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class MatchQueryBuilderCrossClusterSearchIT extends AbstractSemanticCrossClusterSearchTestCase {
+ private static final String LOCAL_INDEX_NAME = "local-index";
+ private static final String REMOTE_INDEX_NAME = "remote-index";
+
+ // Boost the local index so that we can use the same doc values for local and remote indices and have consistent relevance
+ private static final List QUERY_INDICES = List.of(
+ new IndexWithBoost(LOCAL_INDEX_NAME, 10.0f),
+ new IndexWithBoost(fullyQualifiedIndexName(REMOTE_CLUSTER, REMOTE_INDEX_NAME))
+ );
+
+ private static final String COMMON_INFERENCE_ID_FIELD = "common-inference-id-field";
+ private static final String VARIABLE_INFERENCE_ID_FIELD = "variable-inference-id-field";
+ private static final String MIXED_TYPE_FIELD_1 = "mixed-type-field-1";
+ private static final String MIXED_TYPE_FIELD_2 = "mixed-type-field-2";
+ private static final String TEXT_FIELD = "text-field";
+
+ boolean clustersConfigured = false;
+
+ @Override
+ protected boolean reuseClusters() {
+ return true;
+ }
+
+ @Before
+ @Override
+ public void setUp() throws Exception {
+ super.setUp();
+ if (clustersConfigured == false) {
+ configureClusters();
+ clustersConfigured = true;
+ }
+ }
+
+ public void testMatchQuery() throws Exception {
+ // Query a field has the same inference ID value across clusters, but with different backing inference services
+ assertSearchResponse(
+ new MatchQueryBuilder(COMMON_INFERENCE_ID_FIELD, "a"),
+ QUERY_INDICES,
+ List.of(
+ new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, getDocId(COMMON_INFERENCE_ID_FIELD)),
+ new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, getDocId(COMMON_INFERENCE_ID_FIELD))
+ )
+ );
+
+ // Query a field that has different inference ID values across clusters
+ assertSearchResponse(
+ new MatchQueryBuilder(VARIABLE_INFERENCE_ID_FIELD, "b"),
+ QUERY_INDICES,
+ List.of(
+ new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, getDocId(VARIABLE_INFERENCE_ID_FIELD)),
+ new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, getDocId(VARIABLE_INFERENCE_ID_FIELD))
+ )
+ );
+
+ // Query a field that has mixed types across clusters
+ assertSearchResponse(
+ new MatchQueryBuilder(MIXED_TYPE_FIELD_1, "c"),
+ QUERY_INDICES,
+ List.of(
+ new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, getDocId(MIXED_TYPE_FIELD_1)),
+ new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, getDocId(MIXED_TYPE_FIELD_1))
+ )
+ );
+ assertSearchResponse(
+ new MatchQueryBuilder(MIXED_TYPE_FIELD_2, "d"),
+ QUERY_INDICES,
+ List.of(
+ new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, getDocId(MIXED_TYPE_FIELD_2)),
+ new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, getDocId(MIXED_TYPE_FIELD_2))
+ )
+ );
+ }
+
+ public void testMatchQueryWithCcsMinimizeRoundTripsFalse() throws Exception {
+ final Consumer assertCcsMinimizeRoundTripsFalseFailure = q -> {
+ SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(q);
+ SearchRequest searchRequest = new SearchRequest(convertToArray(QUERY_INDICES), searchSourceBuilder);
+ searchRequest.setCcsMinimizeRoundtrips(false);
+
+ IllegalArgumentException e = assertThrows(
+ IllegalArgumentException.class,
+ () -> client().search(searchRequest).actionGet(TEST_REQUEST_TIMEOUT)
+ );
+ assertThat(
+ e.getMessage(),
+ equalTo(
+ "match query does not support cross-cluster search when querying a [semantic_text] field when "
+ + "[ccs_minimize_roundtrips] is false"
+ )
+ );
+ };
+
+ // Validate that expected cases fail
+ assertCcsMinimizeRoundTripsFalseFailure.accept(new MatchQueryBuilder(COMMON_INFERENCE_ID_FIELD, randomAlphaOfLength(5)));
+ assertCcsMinimizeRoundTripsFalseFailure.accept(new MatchQueryBuilder(MIXED_TYPE_FIELD_1, randomAlphaOfLength(5)));
+
+ // Validate the expected ccs_minimize_roundtrips=false detection gap and failure mode when querying non-inference fields locally
+ assertSearchResponse(
+ new MatchQueryBuilder(MIXED_TYPE_FIELD_2, "d"),
+ QUERY_INDICES,
+ List.of(new SearchResult(null, LOCAL_INDEX_NAME, getDocId(MIXED_TYPE_FIELD_2))),
+ new ClusterFailure(
+ SearchResponse.Cluster.Status.SKIPPED,
+ Set.of(
+ new FailureCause(
+ QueryShardException.class,
+ "failed to create query: Field [mixed-type-field-2] of type [semantic_text] does not support match queries"
+ )
+ )
+ ),
+ s -> s.setCcsMinimizeRoundtrips(false)
+ );
+
+ // Validate that a CCS match query functions when only text fields are queried
+ assertSearchResponse(
+ new MatchQueryBuilder(TEXT_FIELD, "e"),
+ QUERY_INDICES,
+ List.of(
+ new SearchResult(null, LOCAL_INDEX_NAME, getDocId(TEXT_FIELD)),
+ new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, getDocId(TEXT_FIELD))
+ ),
+ null,
+ s -> s.setCcsMinimizeRoundtrips(false)
+ );
+ }
+
+ private void configureClusters() throws IOException {
+ final String commonInferenceId = "common-inference-id";
+ final String localInferenceId = "local-inference-id";
+ final String remoteInferenceId = "remote-inference-id";
+
+ final Map> docs = Map.of(
+ getDocId(COMMON_INFERENCE_ID_FIELD),
+ Map.of(COMMON_INFERENCE_ID_FIELD, "a"),
+ getDocId(VARIABLE_INFERENCE_ID_FIELD),
+ Map.of(VARIABLE_INFERENCE_ID_FIELD, "b"),
+ getDocId(MIXED_TYPE_FIELD_1),
+ Map.of(MIXED_TYPE_FIELD_1, "c"),
+ getDocId(MIXED_TYPE_FIELD_2),
+ Map.of(MIXED_TYPE_FIELD_2, "d"),
+ getDocId(TEXT_FIELD),
+ Map.of(TEXT_FIELD, "e")
+ );
+
+ final TestIndexInfo localIndexInfo = new TestIndexInfo(
+ LOCAL_INDEX_NAME,
+ Map.of(commonInferenceId, sparseEmbeddingServiceSettings(), localInferenceId, sparseEmbeddingServiceSettings()),
+ Map.of(
+ COMMON_INFERENCE_ID_FIELD,
+ semanticTextMapping(commonInferenceId),
+ VARIABLE_INFERENCE_ID_FIELD,
+ semanticTextMapping(localInferenceId),
+ MIXED_TYPE_FIELD_1,
+ semanticTextMapping(localInferenceId),
+ MIXED_TYPE_FIELD_2,
+ textMapping(),
+ TEXT_FIELD,
+ textMapping()
+ ),
+ docs
+ );
+ final TestIndexInfo remoteIndexInfo = new TestIndexInfo(
+ REMOTE_INDEX_NAME,
+ Map.of(
+ commonInferenceId,
+ textEmbeddingServiceSettings(256, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT),
+ remoteInferenceId,
+ textEmbeddingServiceSettings(384, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT)
+ ),
+ Map.of(
+ COMMON_INFERENCE_ID_FIELD,
+ semanticTextMapping(commonInferenceId),
+ VARIABLE_INFERENCE_ID_FIELD,
+ semanticTextMapping(remoteInferenceId),
+ MIXED_TYPE_FIELD_1,
+ textMapping(),
+ MIXED_TYPE_FIELD_2,
+ semanticTextMapping(remoteInferenceId),
+ TEXT_FIELD,
+ textMapping()
+ ),
+ docs
+ );
+ setupTwoClusters(localIndexInfo, remoteIndexInfo);
+ }
+
+ private static String getDocId(String field) {
+ return field + "_doc";
+ }
+}
diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticQueryBuilderCrossClusterSearchIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticQueryBuilderCrossClusterSearchIT.java
new file mode 100644
index 0000000000000..4b3b616f93bb0
--- /dev/null
+++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticQueryBuilderCrossClusterSearchIT.java
@@ -0,0 +1,121 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.search.ccs;
+
+import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.search.builder.PointInTimeBuilder;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
+
+import java.util.List;
+import java.util.Map;
+import java.util.function.Consumer;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class SemanticQueryBuilderCrossClusterSearchIT extends AbstractSemanticCrossClusterSearchTestCase {
+ private static final String LOCAL_INDEX_NAME = "local-index";
+ private static final String REMOTE_INDEX_NAME = "remote-index";
+ private static final List QUERY_INDICES = List.of(
+ new IndexWithBoost(LOCAL_INDEX_NAME),
+ new IndexWithBoost(fullyQualifiedIndexName(REMOTE_CLUSTER, REMOTE_INDEX_NAME))
+ );
+
+ public void testSemanticQuery() throws Exception {
+ final String commonInferenceId = "common-inference-id";
+ final String localInferenceId = "local-inference-id";
+ final String remoteInferenceId = "remote-inference-id";
+
+ final String commonInferenceIdField = "common-inference-id-field";
+ final String variableInferenceIdField = "variable-inference-id-field";
+
+ final TestIndexInfo localIndexInfo = new TestIndexInfo(
+ LOCAL_INDEX_NAME,
+ Map.of(commonInferenceId, sparseEmbeddingServiceSettings(), localInferenceId, sparseEmbeddingServiceSettings()),
+ Map.of(
+ commonInferenceIdField,
+ semanticTextMapping(commonInferenceId),
+ variableInferenceIdField,
+ semanticTextMapping(localInferenceId)
+ ),
+ Map.of("local_doc_1", Map.of(commonInferenceIdField, "a"), "local_doc_2", Map.of(variableInferenceIdField, "b"))
+ );
+ final TestIndexInfo remoteIndexInfo = new TestIndexInfo(
+ REMOTE_INDEX_NAME,
+ Map.of(
+ commonInferenceId,
+ textEmbeddingServiceSettings(256, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT),
+ remoteInferenceId,
+ textEmbeddingServiceSettings(384, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT)
+ ),
+ Map.of(
+ commonInferenceIdField,
+ semanticTextMapping(commonInferenceId),
+ variableInferenceIdField,
+ semanticTextMapping(remoteInferenceId)
+ ),
+ Map.of("remote_doc_1", Map.of(commonInferenceIdField, "x"), "remote_doc_2", Map.of(variableInferenceIdField, "y"))
+ );
+ setupTwoClusters(localIndexInfo, remoteIndexInfo);
+
+ // Query a field has the same inference ID value across clusters, but with different backing inference services
+ assertSearchResponse(
+ new SemanticQueryBuilder(commonInferenceIdField, "a"),
+ QUERY_INDICES,
+ List.of(
+ new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, "local_doc_1"),
+ new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, "remote_doc_1")
+ )
+ );
+
+ // Query a field that has different inference ID values across clusters
+ assertSearchResponse(
+ new SemanticQueryBuilder(variableInferenceIdField, "b"),
+ QUERY_INDICES,
+ List.of(
+ new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, "local_doc_2"),
+ new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, "remote_doc_2")
+ )
+ );
+ }
+
+ public void testSemanticQueryWithCcMinimizeRoundTripsFalse() throws Exception {
+ final SemanticQueryBuilder queryBuilder = new SemanticQueryBuilder("foo", "bar");
+ final Consumer assertCcsMinimizeRoundTripsFalseFailure = s -> {
+ IllegalArgumentException e = assertThrows(
+ IllegalArgumentException.class,
+ () -> client().search(s).actionGet(TEST_REQUEST_TIMEOUT)
+ );
+ assertThat(
+ e.getMessage(),
+ equalTo("semantic query does not support cross-cluster search when [ccs_minimize_roundtrips] is false")
+ );
+ };
+
+ final TestIndexInfo localIndexInfo = new TestIndexInfo(LOCAL_INDEX_NAME, Map.of(), Map.of(), Map.of());
+ final TestIndexInfo remoteIndexInfo = new TestIndexInfo(REMOTE_INDEX_NAME, Map.of(), Map.of(), Map.of());
+ setupTwoClusters(localIndexInfo, remoteIndexInfo);
+
+ // Explicitly set ccs_minimize_roundtrips=false in the search request
+ SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(queryBuilder);
+ SearchRequest searchRequestWithCcMinimizeRoundTripsFalse = new SearchRequest(convertToArray(QUERY_INDICES), searchSourceBuilder);
+ searchRequestWithCcMinimizeRoundTripsFalse.setCcsMinimizeRoundtrips(false);
+ assertCcsMinimizeRoundTripsFalseFailure.accept(searchRequestWithCcMinimizeRoundTripsFalse);
+
+ // Using a point in time implicitly sets ccs_minimize_roundtrips=false
+ BytesReference pitId = openPointInTime(convertToArray(QUERY_INDICES), TimeValue.timeValueMinutes(2));
+ SearchSourceBuilder searchSourceBuilderWithPit = new SearchSourceBuilder().query(queryBuilder)
+ .pointInTimeBuilder(new PointInTimeBuilder(pitId));
+ SearchRequest searchRequestWithPit = new SearchRequest().source(searchSourceBuilderWithPit);
+ assertCcsMinimizeRoundTripsFalseFailure.accept(searchRequestWithPit);
+ }
+}
diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SparseVectorQueryBuilderCrossClusterSearchIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SparseVectorQueryBuilderCrossClusterSearchIT.java
new file mode 100644
index 0000000000000..be9183722a48c
--- /dev/null
+++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SparseVectorQueryBuilderCrossClusterSearchIT.java
@@ -0,0 +1,248 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.search.ccs;
+
+import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryShardException;
+import org.elasticsearch.inference.WeightedToken;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Consumer;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class SparseVectorQueryBuilderCrossClusterSearchIT extends AbstractSemanticCrossClusterSearchTestCase {
+ private static final String LOCAL_INDEX_NAME = "local-index";
+ private static final String REMOTE_INDEX_NAME = "remote-index";
+ private static final List QUERY_INDICES = List.of(
+ new IndexWithBoost(LOCAL_INDEX_NAME),
+ new IndexWithBoost(fullyQualifiedIndexName(REMOTE_CLUSTER, REMOTE_INDEX_NAME))
+ );
+
+ public void testSparseVectorQuery() throws Exception {
+ final String commonInferenceId = "common-inference-id";
+
+ final String commonInferenceIdField = "common-inference-id-field";
+ final String mixedTypeField1 = "mixed-type-field-1";
+ final String mixedTypeField2 = "mixed-type-field-2";
+
+ final TestIndexInfo localIndexInfo = new TestIndexInfo(
+ LOCAL_INDEX_NAME,
+ Map.of(commonInferenceId, sparseEmbeddingServiceSettings()),
+ Map.of(
+ commonInferenceIdField,
+ semanticTextMapping(commonInferenceId),
+ mixedTypeField1,
+ sparseVectorMapping(),
+ mixedTypeField2,
+ semanticTextMapping(commonInferenceId)
+ ),
+ Map.of(
+ "local_doc_1",
+ Map.of(commonInferenceIdField, "a"),
+ "local_doc_2",
+ Map.of(mixedTypeField1, generateSparseVectorFieldValue(1.0f)),
+ "local_doc_3",
+ Map.of(mixedTypeField2, "c")
+ )
+ );
+ final TestIndexInfo remoteIndexInfo = new TestIndexInfo(
+ REMOTE_INDEX_NAME,
+ Map.of(commonInferenceId, sparseEmbeddingServiceSettings()),
+ Map.of(
+ commonInferenceIdField,
+ semanticTextMapping(commonInferenceId),
+ mixedTypeField1,
+ semanticTextMapping(commonInferenceId),
+ mixedTypeField2,
+ sparseVectorMapping()
+ ),
+ Map.of(
+ "remote_doc_1",
+ Map.of(commonInferenceIdField, "x"),
+ "remote_doc_2",
+ Map.of(mixedTypeField1, "y"),
+ "remote_doc_3",
+ Map.of(mixedTypeField2, generateSparseVectorFieldValue(1.0f))
+ )
+ );
+ setupTwoClusters(localIndexInfo, remoteIndexInfo);
+
+ // Query a field has the same inference ID value across clusters, but with different backing inference services
+ assertSearchResponse(
+ new SparseVectorQueryBuilder(commonInferenceIdField, null, "a"),
+ QUERY_INDICES,
+ List.of(
+ new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, "remote_doc_1"),
+ new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, "local_doc_1")
+ )
+ );
+
+ // Query a field that has mixed types across clusters
+ assertSearchResponse(
+ new SparseVectorQueryBuilder(mixedTypeField1, commonInferenceId, "b"),
+ QUERY_INDICES,
+ List.of(
+ new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, "remote_doc_2"),
+ new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, "local_doc_2")
+ )
+ );
+ assertSearchResponse(
+ new SparseVectorQueryBuilder(mixedTypeField2, commonInferenceId, "c"),
+ QUERY_INDICES,
+ List.of(
+ new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, "local_doc_3"),
+ new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, "remote_doc_3")
+ )
+ );
+
+ // Query a field that has mixed types across clusters using a query vector
+ final List queryVector = generateSparseVectorFieldValue(1.0f).entrySet()
+ .stream()
+ .map(e -> new WeightedToken(e.getKey(), e.getValue()))
+ .toList();
+ assertSearchResponse(
+ new SparseVectorQueryBuilder(mixedTypeField1, queryVector, null, null, null, null),
+ QUERY_INDICES,
+ List.of(
+ new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, "remote_doc_2"),
+ new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, "local_doc_2")
+ )
+ );
+ assertSearchResponse(
+ new SparseVectorQueryBuilder(mixedTypeField2, queryVector, null, null, null, null),
+ QUERY_INDICES,
+ List.of(
+ new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, "local_doc_3"),
+ new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, "remote_doc_3")
+ )
+ );
+
+ // Check that omitting the inference ID when querying a remote sparse vector field leads to the expected partial failure
+ assertSearchResponse(
+ new SparseVectorQueryBuilder(mixedTypeField2, null, "c"),
+ QUERY_INDICES,
+ List.of(new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, "local_doc_3")),
+ new ClusterFailure(
+ SearchResponse.Cluster.Status.SKIPPED,
+ Set.of(new FailureCause(IllegalArgumentException.class, "inference_id required to perform vector search on query string"))
+ ),
+ null
+ );
+ }
+
+ public void testSparseVectorQueryWithCcsMinimizeRoundTripsFalse() throws Exception {
+ final Consumer assertCcsMinimizeRoundTripsFalseFailure = q -> {
+ SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(q);
+ SearchRequest searchRequest = new SearchRequest(convertToArray(QUERY_INDICES), searchSourceBuilder);
+ searchRequest.setCcsMinimizeRoundtrips(false);
+
+ IllegalArgumentException e = assertThrows(
+ IllegalArgumentException.class,
+ () -> client().search(searchRequest).actionGet(TEST_REQUEST_TIMEOUT)
+ );
+ assertThat(
+ e.getMessage(),
+ equalTo(
+ "sparse_vector query does not support cross-cluster search when querying a [semantic_text] field when "
+ + "[ccs_minimize_roundtrips] is false"
+ )
+ );
+ };
+
+ final String commonInferenceId = "common-inference-id";
+
+ final String commonInferenceIdField = "common-inference-id-field";
+ final String mixedTypeField1 = "mixed-type-field-1";
+ final String mixedTypeField2 = "mixed-type-field-2";
+ final String sparseVectorField = "sparse-vector-field";
+
+ final TestIndexInfo localIndexInfo = new TestIndexInfo(
+ LOCAL_INDEX_NAME,
+ Map.of(commonInferenceId, sparseEmbeddingServiceSettings()),
+ Map.of(
+ commonInferenceIdField,
+ semanticTextMapping(commonInferenceId),
+ mixedTypeField1,
+ semanticTextMapping(commonInferenceId),
+ mixedTypeField2,
+ sparseVectorMapping(),
+ sparseVectorField,
+ sparseVectorMapping()
+ ),
+ Map.of(
+ mixedTypeField2 + "_doc",
+ Map.of(mixedTypeField2, generateSparseVectorFieldValue(1.0f)),
+ sparseVectorField + "_doc",
+ Map.of(sparseVectorField, generateSparseVectorFieldValue(1.0f))
+ )
+ );
+ final TestIndexInfo remoteIndexInfo = new TestIndexInfo(
+ REMOTE_INDEX_NAME,
+ Map.of(commonInferenceId, sparseEmbeddingServiceSettings()),
+ Map.of(
+ commonInferenceIdField,
+ semanticTextMapping(commonInferenceId),
+ mixedTypeField1,
+ sparseVectorMapping(),
+ mixedTypeField2,
+ semanticTextMapping(commonInferenceId),
+ sparseVectorField,
+ sparseVectorMapping()
+ ),
+ Map.of(
+ mixedTypeField2 + "_doc",
+ Map.of(mixedTypeField2, "a"),
+ sparseVectorField + "_doc",
+ Map.of(sparseVectorField, generateSparseVectorFieldValue(0.5f))
+ )
+ );
+ setupTwoClusters(localIndexInfo, remoteIndexInfo);
+
+ // Validate that expected cases fail
+ assertCcsMinimizeRoundTripsFalseFailure.accept(new SparseVectorQueryBuilder(commonInferenceIdField, null, randomAlphaOfLength(5)));
+ assertCcsMinimizeRoundTripsFalseFailure.accept(
+ new SparseVectorQueryBuilder(mixedTypeField1, commonInferenceId, randomAlphaOfLength(5))
+ );
+
+ // Validate the expected ccs_minimize_roundtrips=false detection gap and failure mode when querying non-inference fields locally
+ assertSearchResponse(
+ new SparseVectorQueryBuilder(mixedTypeField2, commonInferenceId, "foo"),
+ QUERY_INDICES,
+ List.of(new SearchResult(null, LOCAL_INDEX_NAME, mixedTypeField2 + "_doc")),
+ new ClusterFailure(
+ SearchResponse.Cluster.Status.SKIPPED,
+ Set.of(
+ new FailureCause(
+ QueryShardException.class,
+ "failed to create query: field [mixed-type-field-2] must be type [sparse_vector] but is type [semantic_text]"
+ )
+ )
+ ),
+ s -> s.setCcsMinimizeRoundtrips(false)
+ );
+
+ // Validate that a CCS sparse vector query functions when only sparse vector fields are queried
+ assertSearchResponse(
+ new SparseVectorQueryBuilder(sparseVectorField, commonInferenceId, "foo"),
+ QUERY_INDICES,
+ List.of(
+ new SearchResult(null, LOCAL_INDEX_NAME, sparseVectorField + "_doc"),
+ new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, sparseVectorField + "_doc")
+ ),
+ null,
+ s -> s.setCcsMinimizeRoundtrips(false)
+ );
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/FullyQualifiedInferenceId.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/FullyQualifiedInferenceId.java
index 5ee9d00da4abb..d1cc7aa3aa5c7 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/FullyQualifiedInferenceId.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/FullyQualifiedInferenceId.java
@@ -29,4 +29,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(clusterAlias);
out.writeString(inferenceId);
}
+
+ @Override
+ public String toString() {
+ return "{clusterAlias=" + clusterAlias + ", inferenceId=" + inferenceId + "}";
+ }
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java
index bf6b9d534d52e..0696185e11650 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java
@@ -34,6 +34,8 @@
import java.util.Collection;
import java.util.Map;
+import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY;
+
public class InterceptedInferenceKnnVectorQueryBuilder extends InterceptedInferenceQueryBuilder {
public static final String NAME = "intercepted_inference_knn";
@@ -44,15 +46,23 @@ public InterceptedInferenceKnnVectorQueryBuilder(KnnVectorQueryBuilder originalQ
super(originalQuery);
}
+ public InterceptedInferenceKnnVectorQueryBuilder(
+ KnnVectorQueryBuilder originalQuery,
+ Map inferenceResultsMap
+ ) {
+ super(originalQuery, inferenceResultsMap);
+ }
+
public InterceptedInferenceKnnVectorQueryBuilder(StreamInput in) throws IOException {
super(in);
}
- InterceptedInferenceKnnVectorQueryBuilder(
+ private InterceptedInferenceKnnVectorQueryBuilder(
InterceptedInferenceQueryBuilder other,
- Map inferenceResultsMap
+ Map inferenceResultsMap,
+ boolean ccsRequest
) {
- super(other, inferenceResultsMap);
+ super(other, inferenceResultsMap, ccsRequest);
}
@Override
@@ -72,8 +82,9 @@ protected String getQuery() {
}
@Override
- protected String getInferenceIdOverride() {
- return getQueryVectorBuilderModelId();
+ protected FullyQualifiedInferenceId getInferenceIdOverride() {
+ String modelId = getQueryVectorBuilderModelId();
+ return modelId != null ? new FullyQualifiedInferenceId(LOCAL_CLUSTER_GROUP_KEY, modelId) : null;
}
@Override
@@ -114,8 +125,8 @@ protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) {
}
@Override
- protected QueryBuilder copy(Map inferenceResultsMap) {
- return new InterceptedInferenceKnnVectorQueryBuilder(this, inferenceResultsMap);
+ protected QueryBuilder copy(Map inferenceResultsMap, boolean ccsRequest) {
+ return new InterceptedInferenceKnnVectorQueryBuilder(this, inferenceResultsMap, ccsRequest);
}
@Override
@@ -131,7 +142,7 @@ protected QueryBuilder queryFields(
} else if (fieldType instanceof SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType) {
rewritten = querySemanticTextField(indexMetadataContext.getLocalClusterAlias(), semanticTextFieldType);
} else {
- rewritten = queryNonSemanticTextField(indexMetadataContext.getLocalClusterAlias());
+ rewritten = queryNonSemanticTextField();
}
return rewritten;
@@ -177,12 +188,12 @@ private QueryBuilder querySemanticTextField(String clusterAlias, SemanticTextFie
VectorData queryVector = originalQuery.queryVector();
if (queryVector == null) {
- String inferenceId = getQueryVectorBuilderModelId();
- if (inferenceId == null) {
- inferenceId = semanticTextFieldType.getSearchInferenceId();
+ FullyQualifiedInferenceId fullyQualifiedInferenceId = getInferenceIdOverride();
+ if (fullyQualifiedInferenceId == null) {
+ fullyQualifiedInferenceId = new FullyQualifiedInferenceId(clusterAlias, semanticTextFieldType.getSearchInferenceId());
}
- MlTextEmbeddingResults textEmbeddingResults = getTextEmbeddingResults(clusterAlias, inferenceId);
+ MlTextEmbeddingResults textEmbeddingResults = getTextEmbeddingResults(fullyQualifiedInferenceId);
queryVector = new VectorData(textEmbeddingResults.getInferenceAsFloat());
}
@@ -202,18 +213,18 @@ private QueryBuilder querySemanticTextField(String clusterAlias, SemanticTextFie
.queryName(originalQuery.queryName());
}
- private QueryBuilder queryNonSemanticTextField(String clusterAlias) {
+ private QueryBuilder queryNonSemanticTextField() {
VectorData queryVector = originalQuery.queryVector();
if (queryVector == null) {
- String modelId = getQueryVectorBuilderModelId();
- if (modelId == null) {
+ FullyQualifiedInferenceId fullyQualifiedInferenceId = getInferenceIdOverride();
+ if (fullyQualifiedInferenceId == null) {
// This should never happen because we validate that either query vector or a valid query vector builder is specified in:
// - The KnnVectorQueryBuilder constructor
// - coordinatorNodeValidate
throw new IllegalStateException("No query vector or query vector builder model ID specified");
}
- MlTextEmbeddingResults textEmbeddingResults = getTextEmbeddingResults(clusterAlias, modelId);
+ MlTextEmbeddingResults textEmbeddingResults = getTextEmbeddingResults(fullyQualifiedInferenceId);
queryVector = new VectorData(textEmbeddingResults.getInferenceAsFloat());
}
@@ -231,10 +242,10 @@ private QueryBuilder queryNonSemanticTextField(String clusterAlias) {
return knnQuery;
}
- private MlTextEmbeddingResults getTextEmbeddingResults(String clusterAlias, String inferenceId) {
- InferenceResults inferenceResults = inferenceResultsMap.get(new FullyQualifiedInferenceId(clusterAlias, inferenceId));
+ private MlTextEmbeddingResults getTextEmbeddingResults(FullyQualifiedInferenceId fullyQualifiedInferenceId) {
+ InferenceResults inferenceResults = inferenceResultsMap.get(fullyQualifiedInferenceId);
if (inferenceResults == null) {
- throw new IllegalStateException("Could not find inference results from inference endpoint [" + inferenceId + "]");
+ throw new IllegalStateException("Could not find inference results from inference endpoint [" + fullyQualifiedInferenceId + "]");
} else if (inferenceResults instanceof MlTextEmbeddingResults == false) {
throw new IllegalArgumentException(
"Expected query inference results to be of type ["
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilder.java
index eecc08acebb4d..69cbf665cc1f8 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilder.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilder.java
@@ -31,15 +31,23 @@ public InterceptedInferenceMatchQueryBuilder(MatchQueryBuilder originalQuery) {
super(originalQuery);
}
+ public InterceptedInferenceMatchQueryBuilder(
+ MatchQueryBuilder originalQuery,
+ Map inferenceResultsMap
+ ) {
+ super(originalQuery, inferenceResultsMap);
+ }
+
public InterceptedInferenceMatchQueryBuilder(StreamInput in) throws IOException {
super(in);
}
- InterceptedInferenceMatchQueryBuilder(
+ private InterceptedInferenceMatchQueryBuilder(
InterceptedInferenceQueryBuilder other,
- Map inferenceResultsMap
+ Map inferenceResultsMap,
+ boolean ccsRequest
) {
- super(other, inferenceResultsMap);
+ super(other, inferenceResultsMap, ccsRequest);
}
@Override
@@ -63,8 +71,8 @@ protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) {
}
@Override
- protected QueryBuilder copy(Map inferenceResultsMap) {
- return new InterceptedInferenceMatchQueryBuilder(this, inferenceResultsMap);
+ protected QueryBuilder copy(Map inferenceResultsMap, boolean ccsRequest) {
+ return new InterceptedInferenceMatchQueryBuilder(this, inferenceResultsMap, ccsRequest);
}
@Override
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java
index a1b5f34aab848..8774d35f17ade 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java
@@ -40,6 +40,7 @@
import static org.elasticsearch.index.IndexSettings.DEFAULT_FIELD_SETTING;
import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY;
+import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.SEMANTIC_SEARCH_CCS_SUPPORT;
import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.convertFromBwcInferenceResultsMap;
/**
@@ -66,11 +67,17 @@ public abstract class InterceptedInferenceQueryBuilder inferenceResultsMap;
+ protected final boolean ccsRequest;
protected InterceptedInferenceQueryBuilder(T originalQuery) {
+ this(originalQuery, null);
+ }
+
+ protected InterceptedInferenceQueryBuilder(T originalQuery, Map inferenceResultsMap) {
Objects.requireNonNull(originalQuery, "original query must not be null");
this.originalQuery = originalQuery;
- this.inferenceResultsMap = null;
+ this.inferenceResultsMap = inferenceResultsMap != null ? Map.copyOf(inferenceResultsMap) : null;
+ this.ccsRequest = false;
}
@SuppressWarnings("unchecked")
@@ -86,14 +93,21 @@ protected InterceptedInferenceQueryBuilder(StreamInput in) throws IOException {
in.readOptional(i1 -> i1.readImmutableMap(i2 -> i2.readNamedWriteable(InferenceResults.class)))
);
}
+ if (in.getTransportVersion().supports(SEMANTIC_SEARCH_CCS_SUPPORT)) {
+ this.ccsRequest = in.readBoolean();
+ } else {
+ this.ccsRequest = false;
+ }
}
protected InterceptedInferenceQueryBuilder(
InterceptedInferenceQueryBuilder other,
- Map inferenceResultsMap
+ Map inferenceResultsMap,
+ boolean ccsRequest
) {
this.originalQuery = other.originalQuery;
this.inferenceResultsMap = inferenceResultsMap;
+ this.ccsRequest = ccsRequest;
}
/**
@@ -133,9 +147,10 @@ protected InterceptedInferenceQueryBuilder(
* Generate a copy of {@code this} using the provided inference results map.
*
* @param inferenceResultsMap The inference results map
+ * @param ccsRequest Flag indicating if this is a CCS request
* @return A copy of {@code this} with the provided inference results map
*/
- protected abstract QueryBuilder copy(Map inferenceResultsMap);
+ protected abstract QueryBuilder copy(Map inferenceResultsMap, boolean ccsRequest);
/**
* Rewrite to a {@link QueryBuilder} appropriate for a specific index's mappings. The implementation can use
@@ -167,7 +182,7 @@ protected abstract QueryBuilder queryFields(
/**
* Get the query-time inference ID override. If not applicable or available, {@code null} should be returned.
*/
- protected String getInferenceIdOverride() {
+ protected FullyQualifiedInferenceId getInferenceIdOverride() {
return null;
}
@@ -194,6 +209,19 @@ protected void doWriteTo(StreamOutput out) throws IOException {
o2.writeString(id.inferenceId());
}, StreamOutput::writeNamedWriteable), inferenceResultsMap);
}
+ if (out.getTransportVersion().supports(SEMANTIC_SEARCH_CCS_SUPPORT)) {
+ out.writeBoolean(ccsRequest);
+ } else if (ccsRequest) {
+ throw new IllegalArgumentException(
+ "One or more nodes does not support "
+ + originalQuery.getName()
+ + " query cross-cluster search when querying a ["
+ + SemanticTextFieldMapper.CONTENT_TYPE
+ + "] field. Please update all nodes to at least Elasticsearch "
+ + SEMANTIC_SEARCH_CCS_SUPPORT.toReleaseVersion()
+ + "."
+ );
+ }
}
@Override
@@ -208,12 +236,14 @@ protected Query doToQuery(SearchExecutionContext context) {
@Override
protected boolean doEquals(InterceptedInferenceQueryBuilder other) {
- return Objects.equals(originalQuery, other.originalQuery) && Objects.equals(inferenceResultsMap, other.inferenceResultsMap);
+ return Objects.equals(originalQuery, other.originalQuery)
+ && Objects.equals(inferenceResultsMap, other.inferenceResultsMap)
+ && Objects.equals(ccsRequest, other.ccsRequest);
}
@Override
protected int doHashCode() {
- return Objects.hash(originalQuery, inferenceResultsMap);
+ return Objects.hash(originalQuery, inferenceResultsMap, ccsRequest);
}
@Override
@@ -261,14 +291,19 @@ private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewri
// In this case, the remote data node will receive the original query, which will in turn result in an error about querying an
// unsupported field type.
ResolvedIndices resolvedIndices = queryRewriteContext.getResolvedIndices();
- Set inferenceIds = getInferenceIdsForFields(
+ Set inferenceIds = getInferenceIdsForFields(
resolvedIndices.getConcreteLocalIndicesMetadata().values(),
+ queryRewriteContext.getLocalClusterAlias(),
getFields(),
resolveWildcards(),
useDefaultFields()
);
- if (inferenceIds.isEmpty()) {
+ // If we are handling a CCS request, always retain the intercepted query logic so that we can get inference results generated on
+ // the local cluster from the inference results map when rewriting on remote cluster data nodes. This can be necessary when:
+ // - A query specifies an inference ID override
+ // - Only non-inference fields are queried on the remote cluster
+ if (inferenceIds.isEmpty() && this.ccsRequest == false) {
// Not querying a semantic text field
return originalQuery;
}
@@ -276,17 +311,17 @@ private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewri
// Validate early to prevent partial failures
coordinatorNodeValidate(resolvedIndices);
- // TODO: Check for supported CCS mode here (once we support CCS)
- if (resolvedIndices.getRemoteClusterIndices().isEmpty() == false) {
+ boolean ccsRequest = this.ccsRequest || resolvedIndices.getRemoteClusterIndices().isEmpty() == false;
+ if (ccsRequest && queryRewriteContext.isCcsMinimizeRoundTrips() == false) {
throw new IllegalArgumentException(
originalQuery.getName()
+ " query does not support cross-cluster search when querying a ["
+ SemanticTextFieldMapper.CONTENT_TYPE
- + "] field"
+ + "] field when [ccs_minimize_roundtrips] is false"
);
}
- String inferenceIdOverride = getInferenceIdOverride();
+ FullyQualifiedInferenceId inferenceIdOverride = getInferenceIdOverride();
if (inferenceIdOverride != null) {
inferenceIds = Set.of(inferenceIdOverride);
}
@@ -307,20 +342,21 @@ private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewri
// The inference results map is fully populated, so we can perform error checking
inferenceResultsErrorCheck(modifiedInferenceResultsMap);
} else {
- rewritten = copy(modifiedInferenceResultsMap);
+ rewritten = copy(modifiedInferenceResultsMap, ccsRequest);
}
}
return rewritten;
}
- private static Set getInferenceIdsForFields(
+ private static Set getInferenceIdsForFields(
Collection indexMetadataCollection,
+ String clusterAlias,
Map fields,
boolean resolveWildcards,
boolean useDefaultFields
) {
- Set inferenceIds = new HashSet<>();
+ Set fullyQualifiedInferenceIds = new HashSet<>();
for (IndexMetadata indexMetadata : indexMetadataCollection) {
final Map indexQueryFields = (useDefaultFields && fields.isEmpty())
? getDefaultFields(indexMetadata.getSettings())
@@ -331,23 +367,34 @@ private static Set getInferenceIdsForFields(
if (indexInferenceFields.containsKey(indexQueryField)) {
// No wildcards in field name
InferenceFieldMetadata inferenceFieldMetadata = indexInferenceFields.get(indexQueryField);
- inferenceIds.add(inferenceFieldMetadata.getSearchInferenceId());
+ fullyQualifiedInferenceIds.add(
+ new FullyQualifiedInferenceId(clusterAlias, inferenceFieldMetadata.getSearchInferenceId())
+ );
continue;
}
if (resolveWildcards) {
if (Regex.isMatchAllPattern(indexQueryField)) {
- indexInferenceFields.values().forEach(ifm -> inferenceIds.add(ifm.getSearchInferenceId()));
+ indexInferenceFields.values()
+ .forEach(
+ ifm -> fullyQualifiedInferenceIds.add(
+ new FullyQualifiedInferenceId(clusterAlias, ifm.getSearchInferenceId())
+ )
+ );
} else if (Regex.isSimpleMatchPattern(indexQueryField)) {
indexInferenceFields.values()
.stream()
.filter(ifm -> Regex.simpleMatch(indexQueryField, ifm.getName()))
- .forEach(ifm -> inferenceIds.add(ifm.getSearchInferenceId()));
+ .forEach(
+ ifm -> fullyQualifiedInferenceIds.add(
+ new FullyQualifiedInferenceId(clusterAlias, ifm.getSearchInferenceId())
+ )
+ );
}
}
}
}
- return inferenceIds;
+ return fullyQualifiedInferenceIds;
}
private static Map getInferenceFieldsMap(
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilder.java
index 655fb3f790cd9..dab789c0223e7 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilder.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilder.java
@@ -33,6 +33,8 @@
import java.util.List;
import java.util.Map;
+import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY;
+
public class InterceptedInferenceSparseVectorQueryBuilder extends InterceptedInferenceQueryBuilder {
public static final String NAME = "intercepted_inference_sparse_vector";
@@ -43,15 +45,23 @@ public InterceptedInferenceSparseVectorQueryBuilder(SparseVectorQueryBuilder ori
super(originalQuery);
}
+ public InterceptedInferenceSparseVectorQueryBuilder(
+ SparseVectorQueryBuilder originalQuery,
+ Map inferenceResultsMap
+ ) {
+ super(originalQuery, inferenceResultsMap);
+ }
+
public InterceptedInferenceSparseVectorQueryBuilder(StreamInput in) throws IOException {
super(in);
}
- InterceptedInferenceSparseVectorQueryBuilder(
+ private InterceptedInferenceSparseVectorQueryBuilder(
InterceptedInferenceQueryBuilder other,
- Map inferenceResultsMap
+ Map inferenceResultsMap,
+ boolean ccsRequest
) {
- super(other, inferenceResultsMap);
+ super(other, inferenceResultsMap, ccsRequest);
}
@Override
@@ -65,8 +75,14 @@ protected String getQuery() {
}
@Override
- protected String getInferenceIdOverride() {
- return originalQuery.getInferenceId();
+ protected FullyQualifiedInferenceId getInferenceIdOverride() {
+ FullyQualifiedInferenceId override = null;
+ String originalInferenceId = originalQuery.getInferenceId();
+ if (originalInferenceId != null) {
+ override = new FullyQualifiedInferenceId(LOCAL_CLUSTER_GROUP_KEY, originalInferenceId);
+ }
+
+ return override;
}
@Override
@@ -96,8 +112,8 @@ protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) {
}
@Override
- protected QueryBuilder copy(Map inferenceResultsMap) {
- return new InterceptedInferenceSparseVectorQueryBuilder(this, inferenceResultsMap);
+ protected QueryBuilder copy(Map inferenceResultsMap, boolean ccsRequest) {
+ return new InterceptedInferenceSparseVectorQueryBuilder(this, inferenceResultsMap, ccsRequest);
}
@Override
@@ -113,7 +129,7 @@ protected QueryBuilder queryFields(
} else if (fieldType instanceof SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType) {
rewritten = querySemanticTextField(indexMetadataContext.getLocalClusterAlias(), semanticTextFieldType);
} else {
- rewritten = queryNonSemanticTextField(indexMetadataContext.getLocalClusterAlias());
+ rewritten = queryNonSemanticTextField();
}
return rewritten;
@@ -149,12 +165,12 @@ private QueryBuilder querySemanticTextField(String clusterAlias, SemanticTextFie
List queryVector = originalQuery.getQueryVectors();
if (queryVector == null) {
- String inferenceId = originalQuery.getInferenceId();
- if (inferenceId == null) {
- inferenceId = semanticTextFieldType.getSearchInferenceId();
+ FullyQualifiedInferenceId fullyQualifiedInferenceId = getInferenceIdOverride();
+ if (fullyQualifiedInferenceId == null) {
+ fullyQualifiedInferenceId = new FullyQualifiedInferenceId(clusterAlias, semanticTextFieldType.getSearchInferenceId());
}
- queryVector = getQueryVector(clusterAlias, inferenceId);
+ queryVector = getQueryVector(fullyQualifiedInferenceId);
}
SparseVectorQueryBuilder innerSparseVectorQuery = new SparseVectorQueryBuilder(
@@ -171,15 +187,15 @@ private QueryBuilder querySemanticTextField(String clusterAlias, SemanticTextFie
.queryName(originalQuery.queryName());
}
- private QueryBuilder queryNonSemanticTextField(String clusterAlias) {
+ private QueryBuilder queryNonSemanticTextField() {
List queryVector = originalQuery.getQueryVectors();
if (queryVector == null) {
- String inferenceId = originalQuery.getInferenceId();
- if (inferenceId == null) {
+ FullyQualifiedInferenceId fullyQualifiedInferenceId = getInferenceIdOverride();
+ if (fullyQualifiedInferenceId == null) {
throw new IllegalArgumentException("Either query vector or inference ID must be specified");
}
- queryVector = getQueryVector(clusterAlias, inferenceId);
+ queryVector = getQueryVector(fullyQualifiedInferenceId);
}
return new SparseVectorQueryBuilder(
@@ -192,10 +208,10 @@ private QueryBuilder queryNonSemanticTextField(String clusterAlias) {
).boost(originalQuery.boost()).queryName(originalQuery.queryName());
}
- private List getQueryVector(String clusterAlias, String inferenceId) {
- InferenceResults inferenceResults = inferenceResultsMap.get(new FullyQualifiedInferenceId(clusterAlias, inferenceId));
+ private List getQueryVector(FullyQualifiedInferenceId fullyQualifiedInferenceId) {
+ InferenceResults inferenceResults = inferenceResultsMap.get(fullyQualifiedInferenceId);
if (inferenceResults == null) {
- throw new IllegalStateException("Could not find inference results from inference endpoint [" + inferenceId + "]");
+ throw new IllegalStateException("Could not find inference results from inference endpoint [" + fullyQualifiedInferenceId + "]");
} else if (inferenceResults instanceof TextExpansionResults == false) {
throw new IllegalArgumentException(
"Expected query inference results to be of type ["
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/LegacySemanticQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/LegacySemanticQueryRewriteInterceptor.java
index 670d846c8d4a9..4052e93559437 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/LegacySemanticQueryRewriteInterceptor.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/LegacySemanticQueryRewriteInterceptor.java
@@ -16,6 +16,7 @@
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.TermsQueryBuilder;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
+import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
import java.util.ArrayList;
import java.util.Collection;
@@ -24,6 +25,8 @@
import java.util.Map;
import java.util.stream.Collectors;
+import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.SEMANTIC_SEARCH_CCS_SUPPORT;
+
/**
* Intercepts and adapts a query to be rewritten to work seamlessly on a semantic_text field.
*/
@@ -46,15 +49,26 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde
if (indexInformation.getInferenceIndices().isEmpty()) {
// No inference fields were identified, so return the original query.
return queryBuilder;
- } else if (indexInformation.nonInferenceIndices().isEmpty() == false) {
- // Combined case where the field name requested by this query contains both
- // semantic_text and non-inference fields, so we have to combine queries per index
- // containing each field type.
- return buildCombinedInferenceAndNonInferenceQuery(queryBuilder, indexInformation);
+ } else if (resolvedIndices.getRemoteClusterIndices().isEmpty()) {
+ if (indexInformation.nonInferenceIndices().isEmpty() == false) {
+ // Combined case where the field name requested by this query contains both
+ // semantic_text and non-inference fields, so we have to combine queries per index
+ // containing each field type.
+ return buildCombinedInferenceAndNonInferenceQuery(queryBuilder, indexInformation);
+ } else {
+ // The only fields we've identified are inference fields (e.g. semantic_text),
+ // so rewrite the entire query to work on a semantic_text field.
+ return buildInferenceQuery(queryBuilder, indexInformation);
+ }
} else {
- // The only fields we've identified are inference fields (e.g. semantic_text),
- // so rewrite the entire query to work on a semantic_text field.
- return buildInferenceQuery(queryBuilder, indexInformation);
+ throw new IllegalArgumentException(
+ getQueryName()
+ + " query does not support cross-cluster search when querying a ["
+ + SemanticTextFieldMapper.CONTENT_TYPE
+ + "] field in a mixed-version cluster. Please update all nodes to at least Elasticsearch "
+ + SEMANTIC_SEARCH_CCS_SUPPORT.toReleaseVersion()
+ + "."
+ );
}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java
index d00b3804a9920..97d0caef98d0e 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java
@@ -68,6 +68,7 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder inferenceResultsMap;
private final Boolean lenient;
+ // ccsRequest is only used on the local cluster coordinator node to detect when:
+ // - The request references a remote index
+ // - The remote cluster is too old to support semantic search CCS
+ // It doesn't technically need to be serialized since it is only used for this purpose, but we do so to keep its behavior in line with
+ // standard query member variables.
+ private final boolean ccsRequest;
+
public SemanticQueryBuilder(String fieldName, String query) {
this(fieldName, query, null);
}
public SemanticQueryBuilder(String fieldName, String query, Boolean lenient) {
- this(fieldName, query, lenient, null);
+ this(fieldName, query, lenient, null, false);
}
protected SemanticQueryBuilder(
@@ -107,6 +115,16 @@ protected SemanticQueryBuilder(
String query,
Boolean lenient,
Map inferenceResultsMap
+ ) {
+ this(fieldName, query, lenient, inferenceResultsMap, false);
+ }
+
+ protected SemanticQueryBuilder(
+ String fieldName,
+ String query,
+ Boolean lenient,
+ Map inferenceResultsMap,
+ boolean ccsRequest
) {
if (fieldName == null) {
throw new IllegalArgumentException("[" + NAME + "] requires a " + FIELD_FIELD.getPreferredName() + " value");
@@ -118,6 +136,7 @@ protected SemanticQueryBuilder(
this.query = query;
this.inferenceResultsMap = inferenceResultsMap != null ? Map.copyOf(inferenceResultsMap) : null;
this.lenient = lenient;
+ this.ccsRequest = ccsRequest;
}
public SemanticQueryBuilder(StreamInput in) throws IOException {
@@ -142,6 +161,11 @@ public SemanticQueryBuilder(StreamInput in) throws IOException {
} else {
this.lenient = null;
}
+ if (in.getTransportVersion().supports(SEMANTIC_SEARCH_CCS_SUPPORT)) {
+ this.ccsRequest = in.readBoolean();
+ } else {
+ this.ccsRequest = false;
+ }
}
@Override
@@ -176,9 +200,24 @@ protected void doWriteTo(StreamOutput out) throws IOException {
if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_QUERY_LENIENT)) {
out.writeOptionalBoolean(lenient);
}
+ if (out.getTransportVersion().supports(SEMANTIC_SEARCH_CCS_SUPPORT)) {
+ out.writeBoolean(ccsRequest);
+ } else if (ccsRequest) {
+ throw new IllegalArgumentException(
+ "One or more nodes does not support "
+ + NAME
+ + " query cross-cluster search. Please update all nodes to at least Elasticsearch "
+ + SEMANTIC_SEARCH_CCS_SUPPORT.toReleaseVersion()
+ + "."
+ );
+ }
}
- private SemanticQueryBuilder(SemanticQueryBuilder other, Map inferenceResultsMap) {
+ private SemanticQueryBuilder(
+ SemanticQueryBuilder other,
+ Map inferenceResultsMap,
+ boolean ccsRequest
+ ) {
this.fieldName = other.fieldName;
this.query = other.query;
this.boost = other.boost;
@@ -186,6 +225,7 @@ private SemanticQueryBuilder(SemanticQueryBuilder other, Map
- * Get inference results for the provided query using the provided inference IDs. The inference IDs are fully qualified by the
- * cluster alias in the provided {@link QueryRewriteContext}.
+ * Get inference results for the provided query using the provided fully qualified inference IDs.
*
*
* This method will return an inference results map that will be asynchronously populated with inference results. If the provided
@@ -222,14 +261,14 @@ public static SemanticQueryBuilder fromXContent(XContentParser parser) throws IO
*
*
* @param queryRewriteContext The query rewrite context
- * @param inferenceIds The inference IDs to use to generate inference results
+ * @param fullyQualifiedInferenceIds The fully qualified inference IDs to use to generate inference results
* @param inferenceResultsMap The initial inference results map
* @param query The query to generate inference results for
* @return An inference results map
*/
static Map getInferenceResults(
QueryRewriteContext queryRewriteContext,
- Set inferenceIds,
+ Set fullyQualifiedInferenceIds,
@Nullable Map inferenceResultsMap,
@Nullable String query
) {
@@ -239,12 +278,19 @@ static Map getInferenceResults(
: Map.of();
if (query != null) {
- for (String inferenceId : inferenceIds) {
- FullyQualifiedInferenceId fullyQualifiedInferenceId = new FullyQualifiedInferenceId(
- queryRewriteContext.getLocalClusterAlias(),
- inferenceId
- );
+ for (FullyQualifiedInferenceId fullyQualifiedInferenceId : fullyQualifiedInferenceIds) {
if (currentInferenceResultsMap.containsKey(fullyQualifiedInferenceId) == false) {
+ if (fullyQualifiedInferenceId.clusterAlias().equals(queryRewriteContext.getLocalClusterAlias()) == false) {
+ // Catch if we are missing inference results that should have been generated on another cluster
+ throw new IllegalStateException(
+ "Cannot get inference results for inference endpoint ["
+ + fullyQualifiedInferenceId
+ + "] on cluster ["
+ + queryRewriteContext.getLocalClusterAlias()
+ + "]"
+ );
+ }
+
if (modifiedInferenceResultsMap == false) {
// Copy the inference results map to ensure it is mutable and thread safe
currentInferenceResultsMap = new ConcurrentHashMap<>(currentInferenceResultsMap);
@@ -255,7 +301,7 @@ static Map getInferenceResults(
queryRewriteContext,
((ConcurrentHashMap) currentInferenceResultsMap),
query,
- inferenceId
+ fullyQualifiedInferenceId.inferenceId()
);
}
}
@@ -406,16 +452,23 @@ private QueryBuilder doRewriteBuildSemanticQuery(SearchExecutionContext searchEx
private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewriteContext) {
ResolvedIndices resolvedIndices = queryRewriteContext.getResolvedIndices();
- if (resolvedIndices.getRemoteClusterIndices().isEmpty() == false) {
- throw new IllegalArgumentException(NAME + " query does not support cross-cluster search");
+ boolean ccsRequest = resolvedIndices.getRemoteClusterIndices().isEmpty() == false;
+ if (ccsRequest && queryRewriteContext.isCcsMinimizeRoundTrips() == false) {
+ throw new IllegalArgumentException(
+ NAME + " query does not support cross-cluster search when [ccs_minimize_roundtrips] is false"
+ );
}
SemanticQueryBuilder rewritten = this;
if (queryRewriteContext.hasAsyncActions() == false) {
- Set inferenceIds = getInferenceIdsForForField(resolvedIndices.getConcreteLocalIndicesMetadata().values(), fieldName);
+ Set fullyQualifiedInferenceIds = getInferenceIdsForField(
+ resolvedIndices.getConcreteLocalIndicesMetadata().values(),
+ queryRewriteContext.getLocalClusterAlias(),
+ fieldName
+ );
Map modifiedInferenceResultsMap = getInferenceResults(
queryRewriteContext,
- inferenceIds,
+ fullyQualifiedInferenceIds,
inferenceResultsMap,
query
);
@@ -424,7 +477,7 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu
// The inference results map is fully populated, so we can perform error checking
inferenceResultsErrorCheck(modifiedInferenceResultsMap);
} else {
- rewritten = new SemanticQueryBuilder(this, modifiedInferenceResultsMap);
+ rewritten = new SemanticQueryBuilder(this, modifiedInferenceResultsMap, ccsRequest);
}
}
@@ -502,28 +555,33 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
throw new IllegalStateException(NAME + " should have been rewritten to another query type");
}
- private static Set getInferenceIdsForForField(Collection indexMetadataCollection, String fieldName) {
- Set inferenceIds = new HashSet<>();
+ private static Set getInferenceIdsForField(
+ Collection indexMetadataCollection,
+ String clusterAlias,
+ String fieldName
+ ) {
+ Set fullyQualifiedInferenceIds = new HashSet<>();
for (IndexMetadata indexMetadata : indexMetadataCollection) {
InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(fieldName);
String indexInferenceId = inferenceFieldMetadata != null ? inferenceFieldMetadata.getSearchInferenceId() : null;
if (indexInferenceId != null) {
- inferenceIds.add(indexInferenceId);
+ fullyQualifiedInferenceIds.add(new FullyQualifiedInferenceId(clusterAlias, indexInferenceId));
}
}
- return inferenceIds;
+ return fullyQualifiedInferenceIds;
}
@Override
protected boolean doEquals(SemanticQueryBuilder other) {
return Objects.equals(fieldName, other.fieldName)
&& Objects.equals(query, other.query)
- && Objects.equals(inferenceResultsMap, other.inferenceResultsMap);
+ && Objects.equals(inferenceResultsMap, other.inferenceResultsMap)
+ && Objects.equals(ccsRequest, other.ccsRequest);
}
@Override
protected int doHashCode() {
- return Objects.hash(fieldName, query, inferenceResultsMap);
+ return Objects.hash(fieldName, query, inferenceResultsMap, ccsRequest);
}
}
diff --git a/x-pack/plugin/inference/src/main/plugin-metadata/entitlement-policy.yaml b/x-pack/plugin/inference/src/main/plugin-metadata/entitlement-policy.yaml
index 36ac851acf1ea..cd046c4a6cb23 100644
--- a/x-pack/plugin/inference/src/main/plugin-metadata/entitlement-policy.yaml
+++ b/x-pack/plugin/inference/src/main/plugin-metadata/entitlement-policy.yaml
@@ -12,6 +12,7 @@ software.amazon.awssdk.http.nio.netty:
io.netty.common:
- outbound_network
- manage_threads
+ - inbound_network
- files:
- path: "/etc/os-release"
mode: "read"
@@ -22,3 +23,4 @@ io.netty.common:
io.netty.transport:
- manage_threads
- outbound_network
+ - inbound_network
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/AbstractInterceptedInferenceQueryBuilderTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/AbstractInterceptedInferenceQueryBuilderTestCase.java
index e83a88e714a98..f999c0f89ae90 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/AbstractInterceptedInferenceQueryBuilderTestCase.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/AbstractInterceptedInferenceQueryBuilderTestCase.java
@@ -9,7 +9,6 @@
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.TransportVersion;
-import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.MockResolvedIndices;
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.action.ResolvedIndices;
@@ -23,6 +22,7 @@
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.regex.Regex;
import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.CheckedRunnable;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.IndexSettings;
import org.elasticsearch.index.IndexVersion;
@@ -65,13 +65,15 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
-import java.util.function.BiConsumer;
import java.util.function.Supplier;
+import static org.elasticsearch.TransportVersions.NEW_SEMANTIC_QUERY_INTERCEPTORS;
+import static org.elasticsearch.TransportVersions.V_8_15_0;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;
import static org.elasticsearch.xpack.inference.queries.InterceptedInferenceQueryBuilder.INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS;
-import static org.hamcrest.Matchers.containsString;
+import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.SEMANTIC_SEARCH_CCS_SUPPORT;
import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.spy;
@@ -169,34 +171,105 @@ public void testBwCSerialization() throws Exception {
}
}
- public void testCcs() throws Exception {
- final String field = "semantic_field";
- final QueryRewriteContext queryRewriteContext = createQueryRewriteContext(
- Map.of("local-index", Map.of(field, SPARSE_INFERENCE_ID)),
+ public void testCcsSerialization() throws Exception {
+ final String inferenceField = "semantic_field";
+ final T inferenceFieldQuery = createQueryBuilder(inferenceField);
+ final T nonInferenceFieldQuery = createQueryBuilder("non_inference_field");
+
+ // Test with the current transport version. This simulates sending the query to a remote cluster that supports semantic search CCS.
+ final QueryRewriteContext contextCurrent = createQueryRewriteContext(
+ Map.of("local-index", Map.of(inferenceField, SPARSE_INFERENCE_ID)),
Map.of("remote-alias", "remote-index"),
- TransportVersion.current()
+ TransportVersion.current(),
+ true
);
- // Test querying a semantic text field
- final T semanticFieldQuery = createQueryBuilder(field);
- IllegalArgumentException e = assertThrows(
- IllegalArgumentException.class,
- () -> rewriteAndFetch(semanticFieldQuery, queryRewriteContext)
+ assertRewriteAndSerializeOnInferenceField(inferenceFieldQuery, contextCurrent, null, null);
+ assertRewriteAndSerializeOnNonInferenceField(nonInferenceFieldQuery, contextCurrent);
+ }
+
+ public void testCcsSerializationWithMinimizeRoundTripsFalse() throws Exception {
+ final String inferenceField = "semantic_field";
+ final T inferenceFieldQuery = createQueryBuilder(inferenceField);
+ final T nonInferenceFieldQuery = createQueryBuilder("non_inference_field");
+
+ final QueryRewriteContext minimizeRoundTripsFalseContext = createQueryRewriteContext(
+ Map.of("local-index", Map.of(inferenceField, SPARSE_INFERENCE_ID)),
+ Map.of("remote-alias", "remote-index"),
+ TransportVersion.current(),
+ false
);
- assertThat(
- e.getMessage(),
- containsString(
- semanticFieldQuery.getName() + " query does not support cross-cluster search when querying a [semantic_text] field"
- )
+
+ assertRewriteAndSerializeOnInferenceField(
+ inferenceFieldQuery,
+ minimizeRoundTripsFalseContext,
+ new IllegalArgumentException(
+ inferenceFieldQuery.getName()
+ + " query does not support cross-cluster search when querying a ["
+ + SemanticTextFieldMapper.CONTENT_TYPE
+ + "] field when [ccs_minimize_roundtrips] is false"
+ ),
+ null
);
+ assertRewriteAndSerializeOnNonInferenceField(nonInferenceFieldQuery, minimizeRoundTripsFalseContext);
+ }
- // Test querying a non-inference field
+ public void testCcsBwCSerialization() throws Exception {
+ final String inferenceField = "semantic_field";
+ final T inferenceFieldQuery = createQueryBuilder(inferenceField);
final T nonInferenceFieldQuery = createQueryBuilder("non_inference_field");
- QueryBuilder coordinatorRewritten = rewriteAndFetch(nonInferenceFieldQuery, queryRewriteContext);
- // Use a serialization cycle to strip InterceptedQueryBuilderWrapper
- coordinatorRewritten = copyNamedWriteable(coordinatorRewritten, writableRegistry(), QueryBuilder.class);
- assertCoordinatorNodeRewriteOnNonInferenceField(nonInferenceFieldQuery, coordinatorRewritten);
+ for (int i = 0; i < 100; i++) {
+ TransportVersion transportVersion = TransportVersionUtils.randomVersionBetween(
+ random(),
+ V_8_15_0,
+ TransportVersionUtils.getPreviousVersion(TransportVersion.current())
+ );
+
+ QueryRewriteContext queryRewriteContext = createQueryRewriteContext(
+ Map.of("local-index", Map.of(inferenceField, SPARSE_INFERENCE_ID)),
+ Map.of("remote-alias", "remote-index"),
+ transportVersion,
+ true
+ );
+
+ Exception expectedRewriteException = null;
+ Exception expectedSerializationException = null;
+ if (transportVersion.supports(SEMANTIC_SEARCH_CCS_SUPPORT) == false) {
+ if (transportVersion.supports(NEW_SEMANTIC_QUERY_INTERCEPTORS)) {
+ // Transport version is new enough to support the new interceptors, but not new enough to support CCS. This simulates if
+ // one of the local or remote cluster data nodes is out of date.
+ expectedSerializationException = new IllegalArgumentException(
+ "One or more nodes does not support "
+ + inferenceFieldQuery.getName()
+ + " query cross-cluster search when querying a ["
+ + SemanticTextFieldMapper.CONTENT_TYPE
+ + "] field. Please update all nodes to at least Elasticsearch "
+ + SEMANTIC_SEARCH_CCS_SUPPORT.toReleaseVersion()
+ + "."
+ );
+ } else {
+ // Transport version indicates usage of the legacy interceptors. This simulates if one of the local cluster data nodes
+ // is out of date to the point that it can't use the new interceptors.
+ expectedRewriteException = new IllegalArgumentException(
+ inferenceFieldQuery.getName()
+ + " query does not support cross-cluster search when querying a ["
+ + SemanticTextFieldMapper.CONTENT_TYPE
+ + "] field in a mixed-version cluster. Please update all nodes to at least Elasticsearch "
+ + SEMANTIC_SEARCH_CCS_SUPPORT.toReleaseVersion()
+ + "."
+ );
+ }
+ }
+
+ assertRewriteAndSerializeOnInferenceField(
+ inferenceFieldQuery,
+ queryRewriteContext,
+ expectedRewriteException,
+ expectedSerializationException
+ );
+ assertRewriteAndSerializeOnNonInferenceField(nonInferenceFieldQuery, queryRewriteContext);
+ }
}
public void testSerializationRemoteClusterInferenceResults() throws Exception {
@@ -230,7 +303,7 @@ public void testSerializationRemoteClusterInferenceResults() throws Exception {
// Test with a transport version prior to cluster alias support, which should fail
TransportVersion transportVersion = TransportVersionUtils.randomVersionBetween(
random(),
- TransportVersions.NEW_SEMANTIC_QUERY_INTERCEPTORS,
+ NEW_SEMANTIC_QUERY_INTERCEPTORS,
TransportVersionUtils.getPreviousVersion(INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS)
);
IllegalArgumentException e = assertThrows(
@@ -270,7 +343,7 @@ protected abstract void assertCoordinatorNodeRewriteOnInferenceField(
QueryBuilder rewritten,
TransportVersion transportVersion,
QueryRewriteContext queryRewriteContext
- );
+ ) throws Exception;
protected abstract void assertCoordinatorNodeRewriteOnNonInferenceField(QueryBuilder original, QueryBuilder rewritten);
@@ -291,48 +364,28 @@ protected void serializationTestCase(TransportVersion transportVersion) throws E
final QueryRewriteContext queryRewriteContext = createQueryRewriteContext(
Map.of(testIndex1.name(), testIndex1.semanticTextFields(), testIndex2.name(), testIndex2.semanticTextFields()),
Map.of(),
- transportVersion
+ transportVersion,
+ null
);
- // Disable query interception when checking the results of coordinator node rewrite so that the query rewrite context can be used
- // to populate inference results without triggering another query interception. In production this is achieved by wrapping with
- // InterceptedQueryBuilderWrapper, but we do not have access to that in this test.
- final BiConsumer disableQueryInterception = (c, r) -> {
- QueryRewriteInterceptor interceptor = c.getQueryRewriteInterceptor();
- c.setQueryRewriteInterceptor(null);
- r.run();
- c.setQueryRewriteInterceptor(interceptor);
- };
-
// Query a semantic text field in both indices
QueryBuilder originalSemantic = createQueryBuilder(semanticField);
- QueryBuilder rewrittenSemantic = rewriteAndFetch(originalSemantic, queryRewriteContext);
- QueryBuilder serializedSemantic = copyNamedWriteable(rewrittenSemantic, writableRegistry(), QueryBuilder.class);
- disableQueryInterception.accept(
- queryRewriteContext,
- () -> assertCoordinatorNodeRewriteOnInferenceField(originalSemantic, serializedSemantic, transportVersion, queryRewriteContext)
- );
+ assertRewriteAndSerializeOnInferenceField(originalSemantic, queryRewriteContext, null, null);
// Query a field that is a semantic text field in one index
QueryBuilder originalMixed = createQueryBuilder(mixedField);
- QueryBuilder rewrittenMixed = rewriteAndFetch(originalMixed, queryRewriteContext);
- QueryBuilder serializedMixed = copyNamedWriteable(rewrittenMixed, writableRegistry(), QueryBuilder.class);
- disableQueryInterception.accept(
- queryRewriteContext,
- () -> assertCoordinatorNodeRewriteOnInferenceField(originalMixed, serializedMixed, transportVersion, queryRewriteContext)
- );
+ assertRewriteAndSerializeOnInferenceField(originalMixed, queryRewriteContext, null, null);
// Query a text field in both indices
QueryBuilder originalText = createQueryBuilder(textField);
- QueryBuilder rewrittenText = rewriteAndFetch(originalText, queryRewriteContext);
- QueryBuilder serializedText = copyNamedWriteable(rewrittenText, writableRegistry(), QueryBuilder.class);
- assertCoordinatorNodeRewriteOnNonInferenceField(originalText, serializedText);
+ assertRewriteAndSerializeOnNonInferenceField(originalText, queryRewriteContext);
}
protected QueryRewriteContext createQueryRewriteContext(
Map> localIndexInferenceFields,
Map remoteIndexNames,
- TransportVersion minTransportVersion
+ TransportVersion minTransportVersion,
+ Boolean ccsMinimizeRoundTrips
) {
Map indexMetadata = new HashMap<>();
for (var indexEntry : localIndexInferenceFields.entrySet()) {
@@ -385,7 +438,7 @@ protected QueryRewriteContext createQueryRewriteContext(
resolvedIndices,
null,
QueryRewriteInterceptor.multi(interceptorMap),
- null
+ ccsMinimizeRoundTrips
);
}
@@ -465,12 +518,90 @@ protected QueryRewriteContext createIndexMetadataContext(
}
}
+ protected void assertRewriteAndSerializeOnInferenceField(
+ QueryBuilder originalQuery,
+ QueryRewriteContext queryRewriteContext,
+ Exception expectedRewriteException,
+ Exception expectedSerializationException
+ ) throws Exception {
+ if (expectedRewriteException != null) {
+ Exception actualException = assertThrows(Exception.class, () -> rewriteAndFetch(originalQuery, queryRewriteContext));
+ assertThat(actualException, instanceOf(expectedRewriteException.getClass()));
+ assertThat(actualException.getMessage(), equalTo(expectedRewriteException.getMessage()));
+ return;
+ }
+ QueryBuilder rewrittenQuery = rewriteAndFetch(originalQuery, queryRewriteContext);
+
+ TransportVersion serializationTransportVersion = queryRewriteContext.getMinTransportVersion();
+ if (expectedSerializationException != null) {
+ Exception actualException = assertThrows(
+ Exception.class,
+ () -> copyNamedWriteable(rewrittenQuery, writableRegistry(), QueryBuilder.class, serializationTransportVersion)
+ );
+ assertThat(actualException, instanceOf(expectedSerializationException.getClass()));
+ assertThat(actualException.getMessage(), equalTo(expectedSerializationException.getMessage()));
+ return;
+ }
+ QueryBuilder serializedQuery = copyNamedWriteable(
+ rewrittenQuery,
+ writableRegistry(),
+ QueryBuilder.class,
+ serializationTransportVersion
+ );
+
+ // Run the original query through a serialization cycle to account for any BwC logic applied through the transport version
+ QueryBuilder originalSerializedQuery = copyNamedWriteable(
+ originalQuery,
+ writableRegistry(),
+ QueryBuilder.class,
+ serializationTransportVersion
+ );
+
+ // Disable query interception when checking the results of coordinator node rewrite so that the query rewrite context can be used
+ // to populate inference results without triggering another query interception. In production this is achieved by wrapping with
+ // InterceptedQueryBuilderWrapper, but we do not have access to that in this test.
+ disableQueryInterception(
+ queryRewriteContext,
+ () -> assertCoordinatorNodeRewriteOnInferenceField(
+ originalSerializedQuery,
+ serializedQuery,
+ queryRewriteContext.getMinTransportVersion(),
+ queryRewriteContext
+ )
+ );
+ }
+
+ protected void assertRewriteAndSerializeOnNonInferenceField(QueryBuilder originalQuery, QueryRewriteContext queryRewriteContext)
+ throws IOException {
+ TransportVersion serializationVersion = queryRewriteContext.getMinTransportVersion();
+
+ // Run the original query through a serialization cycle to account for any BwC logic applied through the transport version
+ QueryBuilder originalSerializedQuery = copyNamedWriteable(
+ originalQuery,
+ writableRegistry(),
+ QueryBuilder.class,
+ serializationVersion
+ );
+
+ QueryBuilder rewrittenQuery = rewriteAndFetch(originalQuery, queryRewriteContext);
+ QueryBuilder serializedQuery = copyNamedWriteable(rewrittenQuery, writableRegistry(), QueryBuilder.class, serializationVersion);
+ assertCoordinatorNodeRewriteOnNonInferenceField(originalSerializedQuery, serializedQuery);
+ }
+
protected static QueryBuilder rewriteAndFetch(QueryBuilder queryBuilder, QueryRewriteContext queryRewriteContext) {
PlainActionFuture future = new PlainActionFuture<>();
Rewriteable.rewriteAndFetch(queryBuilder, queryRewriteContext, future);
return future.actionGet();
}
+ protected static void disableQueryInterception(QueryRewriteContext queryRewriteContext, CheckedRunnable runnable)
+ throws Exception {
+ QueryRewriteInterceptor interceptor = queryRewriteContext.getQueryRewriteInterceptor();
+ queryRewriteContext.setQueryRewriteInterceptor(null);
+ runnable.run();
+ queryRewriteContext.setQueryRewriteInterceptor(interceptor);
+ }
+
private static ModelRegistry createModelRegistry(ThreadPool threadPool) {
ClusterService clusterService = ClusterServiceUtils.createClusterService(threadPool);
ModelRegistry modelRegistry = spy(new ModelRegistry(clusterService, new NoOpClient(threadPool)));
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilderTests.java
index dcbfff1ff99bf..329444595ef3e 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilderTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilderTests.java
@@ -61,10 +61,7 @@ protected InterceptedInferenceQueryBuilder createIntercep
KnnVectorQueryBuilder originalQuery,
Map inferenceResultsMap
) {
- return new InterceptedInferenceKnnVectorQueryBuilder(
- new InterceptedInferenceKnnVectorQueryBuilder(originalQuery),
- inferenceResultsMap
- );
+ return new InterceptedInferenceKnnVectorQueryBuilder(originalQuery, inferenceResultsMap);
}
@Override
@@ -160,7 +157,8 @@ public void testInterceptAndRewrite() throws Exception {
final QueryRewriteContext queryRewriteContext = createQueryRewriteContext(
Map.of(testIndex1.name(), testIndex1.semanticTextFields(), testIndex2.name(), testIndex2.semanticTextFields()),
Map.of(),
- TransportVersion.current()
+ TransportVersion.current(),
+ null
);
QueryBuilder coordinatorRewritten = rewriteAndFetch(knnQuery, queryRewriteContext);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilderTests.java
index abf12b2c876d4..c7a680f15d6d1 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilderTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilderTests.java
@@ -33,7 +33,7 @@ protected InterceptedInferenceQueryBuilder createInterceptedQ
MatchQueryBuilder originalQuery,
Map inferenceResultsMap
) {
- return new InterceptedInferenceMatchQueryBuilder(new InterceptedInferenceMatchQueryBuilder(originalQuery), inferenceResultsMap);
+ return new InterceptedInferenceMatchQueryBuilder(originalQuery, inferenceResultsMap);
}
@Override
@@ -52,7 +52,7 @@ protected void assertCoordinatorNodeRewriteOnInferenceField(
QueryBuilder rewritten,
TransportVersion transportVersion,
QueryRewriteContext queryRewriteContext
- ) {
+ ) throws Exception {
assertThat(original, instanceOf(MatchQueryBuilder.class));
if (transportVersion.onOrAfter(TransportVersions.NEW_SEMANTIC_QUERY_INTERCEPTORS)) {
assertThat(rewritten, instanceOf(InterceptedInferenceMatchQueryBuilder.class));
@@ -69,7 +69,16 @@ protected void assertCoordinatorNodeRewriteOnInferenceField(
original
);
QueryBuilder expectedLegacyRewritten = rewriteAndFetch(expectedLegacyIntercepted, queryRewriteContext);
- assertThat(rewritten, equalTo(expectedLegacyRewritten));
+
+ // Run the expected query through a serialization cycle to align the inference results map representations
+ QueryBuilder expectedLegacySerialized = copyNamedWriteable(
+ expectedLegacyRewritten,
+ writableRegistry(),
+ QueryBuilder.class,
+ transportVersion
+ );
+
+ assertThat(rewritten, equalTo(expectedLegacySerialized));
}
}
@@ -98,7 +107,8 @@ public void testInterceptAndRewrite() throws Exception {
testIndex3.semanticTextFields()
),
Map.of(),
- TransportVersion.current()
+ TransportVersion.current(),
+ null
);
QueryBuilder coordinatorRewritten = rewriteAndFetch(matchQuery, queryRewriteContext);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilderTests.java
index a0066f5da130d..9a44222b16cc3 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilderTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilderTests.java
@@ -61,10 +61,7 @@ protected InterceptedInferenceQueryBuilder createInter
SparseVectorQueryBuilder originalQuery,
Map inferenceResultsMap
) {
- return new InterceptedInferenceSparseVectorQueryBuilder(
- new InterceptedInferenceSparseVectorQueryBuilder(originalQuery),
- inferenceResultsMap
- );
+ return new InterceptedInferenceSparseVectorQueryBuilder(originalQuery, inferenceResultsMap);
}
@Override
@@ -133,7 +130,8 @@ public void testInterceptAndRewrite() throws Exception {
final QueryRewriteContext queryRewriteContext = createQueryRewriteContext(
Map.of(testIndex1.name(), testIndex1.semanticTextFields(), testIndex2.name(), testIndex2.semanticTextFields()),
Map.of(),
- TransportVersion.current()
+ TransportVersion.current(),
+ null
);
QueryBuilder coordinatorRewritten = rewriteAndFetch(sparseVectorQuery, queryRewriteContext);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java
index 05de3ba7d69ac..b2d7218720a57 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java
@@ -95,6 +95,8 @@
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;
import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS;
import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS_TV;
+import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.SEMANTIC_SEARCH_CCS_SUPPORT;
+import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.notNullValue;
@@ -496,6 +498,40 @@ public void testSerializationBwc() throws IOException {
}
}
+ public void testSerializationCcs() throws Exception {
+ SemanticQueryBuilder originalQuery = new SemanticQueryBuilder(randomAlphaOfLength(5), randomAlphaOfLength(5), null, Map.of(), true);
+ QueryBuilder deserializedQuery = copyNamedWriteable(originalQuery, namedWriteableRegistry(), QueryBuilder.class);
+ assertThat(deserializedQuery, equalTo(originalQuery));
+ }
+
+ public void testSerializationCcsBwc() throws Exception {
+ SemanticQueryBuilder originalQuery = new SemanticQueryBuilder(randomAlphaOfLength(5), randomAlphaOfLength(5), null, Map.of(), true);
+
+ for (int i = 0; i < 100; i++) {
+ TransportVersion transportVersion = TransportVersionUtils.randomVersionBetween(
+ random(),
+ originalQuery.getMinimalSupportedVersion(),
+ TransportVersionUtils.getPreviousVersion(TransportVersion.current())
+ );
+
+ if (transportVersion.supports(SEMANTIC_SEARCH_CCS_SUPPORT)) {
+ QueryBuilder deserializedQuery = copyNamedWriteable(
+ originalQuery,
+ namedWriteableRegistry(),
+ QueryBuilder.class,
+ transportVersion
+ );
+ assertThat(deserializedQuery, equalTo(originalQuery));
+ } else {
+ IllegalArgumentException e = assertThrows(
+ IllegalArgumentException.class,
+ () -> copyNamedWriteable(originalQuery, namedWriteableRegistry(), QueryBuilder.class, transportVersion)
+ );
+ assertThat(e.getMessage(), containsString("One or more nodes does not support semantic query cross-cluster search"));
+ }
+ }
+ }
+
public void testToXContent() throws IOException {
QueryBuilder queryBuilder = new SemanticQueryBuilder("foo", "bar");
checkGeneratedJson("""