diff --git a/docs/changelog/135309.yaml b/docs/changelog/135309.yaml new file mode 100644 index 0000000000000..20c50553c2eb8 --- /dev/null +++ b/docs/changelog/135309.yaml @@ -0,0 +1,5 @@ +pr: 135309 +summary: Enable semantic search CCS when ccs_minimize_roundtrips=true +area: Vector Search +type: enhancement +issues: [] diff --git a/qa/multi-cluster-search/src/test/resources/rest-api-spec/test/multi_cluster/110_semantic_query.yml b/qa/multi-cluster-search/src/test/resources/rest-api-spec/test/multi_cluster/110_semantic_query.yml deleted file mode 100644 index 0155175f0e54a..0000000000000 --- a/qa/multi-cluster-search/src/test/resources/rest-api-spec/test/multi_cluster/110_semantic_query.yml +++ /dev/null @@ -1,37 +0,0 @@ ---- -setup: - - requires: - cluster_features: "gte_v8.15.0" - reason: semantic query introduced in 8.15.0 - - - do: - indices.create: - index: test-index - body: - settings: - index: - number_of_shards: 1 - number_of_replicas: 0 ---- -teardown: - - - do: - indices.delete: - index: test-index - ignore_unavailable: true - ---- -"Test that semantic query does not support cross-cluster search": - - do: - catch: bad_request - search: - index: "test-index,my_remote_cluster:test-index" - body: - query: - semantic: - field: "field" - query: "test query" - - - - match: { error.type: "illegal_argument_exception" } - - match: { error.reason: "semantic query does not support cross-cluster search" } diff --git a/server/src/main/resources/transport/definitions/referable/semantic_search_ccs_support.csv b/server/src/main/resources/transport/definitions/referable/semantic_search_ccs_support.csv new file mode 100644 index 0000000000000..35154103cd0da --- /dev/null +++ b/server/src/main/resources/transport/definitions/referable/semantic_search_ccs_support.csv @@ -0,0 +1 @@ +9174000 diff --git a/server/src/main/resources/transport/upper_bounds/9.2.csv b/server/src/main/resources/transport/upper_bounds/9.2.csv index e60434a3e2189..57900e0428e01 100644 --- a/server/src/main/resources/transport/upper_bounds/9.2.csv +++ b/server/src/main/resources/transport/upper_bounds/9.2.csv @@ -1 +1 @@ -sampling_configuration,9173000 +semantic_search_ccs_support,9174000 diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/LocalStateCompositeXPackPlugin.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/LocalStateCompositeXPackPlugin.java index 117141708ed43..366b10e125604 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/LocalStateCompositeXPackPlugin.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/LocalStateCompositeXPackPlugin.java @@ -79,6 +79,8 @@ import org.elasticsearch.plugins.ShutdownAwarePlugin; import org.elasticsearch.plugins.SystemIndexPlugin; import org.elasticsearch.plugins.interceptor.RestServerActionPlugin; +import org.elasticsearch.plugins.internal.InternalSearchPlugin; +import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor; import org.elasticsearch.repositories.RepositoriesMetrics; import org.elasticsearch.repositories.Repository; import org.elasticsearch.repositories.SnapshotMetrics; @@ -135,6 +137,7 @@ public class LocalStateCompositeXPackPlugin extends XPackPlugin IndexStorePlugin, SystemIndexPlugin, SearchPlugin, + InternalSearchPlugin, ShutdownAwarePlugin, RestServerActionPlugin { @@ -291,6 +294,15 @@ public List> getQueries() { return querySpecs; } + @Override + public List getQueryRewriteInterceptors() { + List interceptors = new ArrayList<>(); + filterPlugins(InternalSearchPlugin.class).stream() + .flatMap(p -> p.getQueryRewriteInterceptors().stream()) + .forEach(interceptors::add); + return interceptors; + } + @Override public List getNamedXContent() { List entries = new ArrayList<>(super.getNamedXContent()); diff --git a/x-pack/plugin/inference/build.gradle b/x-pack/plugin/inference/build.gradle index 9486d239e5de5..eb9372e675831 100644 --- a/x-pack/plugin/inference/build.gradle +++ b/x-pack/plugin/inference/build.gradle @@ -36,6 +36,7 @@ dependencies { testImplementation(project(':x-pack:plugin:inference:qa:test-service-plugin')) testImplementation project(':modules:reindex') testImplementation project(':modules:mapper-extras') + testImplementation project(':x-pack:plugin:ml') clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin') api "com.ibm.icu:icu4j:${versions.icu4j}" diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/AbstractSemanticCrossClusterSearchTestCase.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/AbstractSemanticCrossClusterSearchTestCase.java new file mode 100644 index 0000000000000..685453fa77c78 --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/AbstractSemanticCrossClusterSearchTestCase.java @@ -0,0 +1,337 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.search.ccs; + +import org.elasticsearch.action.DocWriteResponse; +import org.elasticsearch.action.search.OpenPointInTimeRequest; +import org.elasticsearch.action.search.OpenPointInTimeResponse; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.search.TransportOpenPointInTimeAction; +import org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.action.support.broadcast.BroadcastResponse; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.inference.MinimalServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.license.LicenseSettings; +import org.elasticsearch.plugins.ActionPlugin; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.plugins.SearchPlugin; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.test.AbstractMultiClustersTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; +import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder; +import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; +import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; +import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension; +import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin; +import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; +import org.elasticsearch.xpack.ml.action.TransportCoordinatedInferenceAction; + +import java.io.IOException; +import java.util.Collection; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Consumer; +import java.util.stream.Collectors; + +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; + +public abstract class AbstractSemanticCrossClusterSearchTestCase extends AbstractMultiClustersTestCase { + protected static final String REMOTE_CLUSTER = "cluster_a"; + + @Override + protected List remoteClusterAlias() { + return List.of(REMOTE_CLUSTER); + } + + @Override + protected Map skipUnavailableForRemoteClusters() { + return Map.of(REMOTE_CLUSTER, DEFAULT_SKIP_UNAVAILABLE); + } + + @Override + protected boolean reuseClusters() { + return false; + } + + @Override + protected Settings nodeSettings() { + return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build(); + } + + @Override + protected Collection> nodePlugins(String clusterAlias) { + return List.of(LocalStateInferencePlugin.class, TestInferenceServicePlugin.class, FakeMlPlugin.class); + } + + protected void setupTwoClusters(TestIndexInfo localIndexInfo, TestIndexInfo remoteIndexInfo) throws IOException { + setupCluster(LOCAL_CLUSTER, localIndexInfo); + setupCluster(REMOTE_CLUSTER, remoteIndexInfo); + } + + protected void setupCluster(String clusterAlias, TestIndexInfo indexInfo) throws IOException { + final Client client = client(clusterAlias); + final String indexName = indexInfo.name(); + + for (var entry : indexInfo.inferenceEndpoints().entrySet()) { + String inferenceId = entry.getKey(); + MinimalServiceSettings minimalServiceSettings = entry.getValue(); + + Map serviceSettings = new HashMap<>(); + serviceSettings.put("model", randomAlphaOfLength(5)); + serviceSettings.put("api_key", randomAlphaOfLength(5)); + if (minimalServiceSettings.taskType() == TaskType.TEXT_EMBEDDING) { + serviceSettings.put("dimensions", minimalServiceSettings.dimensions()); + serviceSettings.put("similarity", minimalServiceSettings.similarity()); + serviceSettings.put("element_type", minimalServiceSettings.elementType()); + } + + createInferenceEndpoint(client, minimalServiceSettings.taskType(), inferenceId, serviceSettings); + } + + Settings indexSettings = indexSettings(randomIntBetween(2, 5), randomIntBetween(0, 1)).build(); + assertAcked(client.admin().indices().prepareCreate(indexName).setSettings(indexSettings).setMapping(indexInfo.mappings())); + assertFalse( + client.admin() + .cluster() + .prepareHealth(TEST_REQUEST_TIMEOUT, indexName) + .setWaitForYellowStatus() + .setTimeout(TimeValue.timeValueSeconds(10)) + .get() + .isTimedOut() + ); + + for (var entry : indexInfo.docs().entrySet()) { + String docId = entry.getKey(); + Map doc = entry.getValue(); + + DocWriteResponse response = client.prepareIndex(indexName).setId(docId).setSource(doc).execute().actionGet(); + assertThat(response.getResult(), equalTo(DocWriteResponse.Result.CREATED)); + } + BroadcastResponse refreshResponse = client.admin().indices().prepareRefresh(indexName).execute().actionGet(); + assertThat(refreshResponse.getStatus(), is(RestStatus.OK)); + } + + protected BytesReference openPointInTime(String[] indices, TimeValue keepAlive) { + OpenPointInTimeRequest request = new OpenPointInTimeRequest(indices).keepAlive(keepAlive); + final OpenPointInTimeResponse response = client().execute(TransportOpenPointInTimeAction.TYPE, request).actionGet(); + return response.getPointInTimeId(); + } + + protected static void createInferenceEndpoint(Client client, TaskType taskType, String inferenceId, Map serviceSettings) + throws IOException { + final String service = switch (taskType) { + case TEXT_EMBEDDING -> TestDenseInferenceServiceExtension.TestInferenceService.NAME; + case SPARSE_EMBEDDING -> TestSparseInferenceServiceExtension.TestInferenceService.NAME; + default -> throw new IllegalArgumentException("Unhandled task type [" + taskType + "]"); + }; + + final BytesReference content; + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + builder.startObject(); + builder.field("service", service); + builder.field("service_settings", serviceSettings); + builder.endObject(); + + content = BytesReference.bytes(builder); + } + + PutInferenceModelAction.Request request = new PutInferenceModelAction.Request( + taskType, + inferenceId, + content, + XContentType.JSON, + TEST_REQUEST_TIMEOUT + ); + var responseFuture = client.execute(PutInferenceModelAction.INSTANCE, request); + assertThat(responseFuture.actionGet(TEST_REQUEST_TIMEOUT).getModel().getInferenceEntityId(), equalTo(inferenceId)); + } + + protected void assertSearchResponse(QueryBuilder queryBuilder, List indices, List expectedSearchResults) + throws Exception { + assertSearchResponse(queryBuilder, indices, expectedSearchResults, null, null); + } + + protected void assertSearchResponse( + QueryBuilder queryBuilder, + List indices, + List expectedSearchResults, + ClusterFailure expectedRemoteFailure, + Consumer searchRequestModifier + ) throws Exception { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(queryBuilder).size(expectedSearchResults.size()); + indices.forEach(i -> searchSourceBuilder.indexBoost(i.index(), i.boost())); + + SearchRequest searchRequest = new SearchRequest(convertToArray(indices), searchSourceBuilder); + searchRequest.indicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN); + if (searchRequestModifier != null) { + searchRequestModifier.accept(searchRequest); + } + + assertResponse(client().search(searchRequest), response -> { + SearchHit[] hits = response.getHits().getHits(); + assertThat(hits.length, equalTo(expectedSearchResults.size())); + + Iterator searchResultIterator = expectedSearchResults.iterator(); + for (int i = 0; i < hits.length; i++) { + SearchResult expectedSearchResult = searchResultIterator.next(); + SearchHit actualSearchResult = hits[i]; + + assertThat(actualSearchResult.getClusterAlias(), equalTo(expectedSearchResult.clusterAlias())); + assertThat(actualSearchResult.getIndex(), equalTo(expectedSearchResult.index())); + assertThat(actualSearchResult.getId(), equalTo(expectedSearchResult.id())); + } + + SearchResponse.Clusters clusters = response.getClusters(); + assertThat(clusters.getCluster(LOCAL_CLUSTER).getStatus(), equalTo(SearchResponse.Cluster.Status.SUCCESSFUL)); + assertThat(clusters.getCluster(LOCAL_CLUSTER).getFailures().isEmpty(), is(true)); + + SearchResponse.Cluster remoteCluster = clusters.getCluster(REMOTE_CLUSTER); + if (expectedRemoteFailure != null) { + assertThat(remoteCluster.getStatus(), equalTo(expectedRemoteFailure.status())); + + Set expectedFailures = expectedRemoteFailure.failures(); + Set actualFailures = remoteCluster.getFailures() + .stream() + .map(f -> new FailureCause(f.getCause().getClass(), f.getCause().getMessage())) + .collect(Collectors.toSet()); + assertThat(actualFailures, equalTo(expectedFailures)); + } else { + assertThat(remoteCluster.getStatus(), equalTo(SearchResponse.Cluster.Status.SUCCESSFUL)); + assertThat(remoteCluster.getFailures().isEmpty(), is(true)); + } + }); + } + + protected static MinimalServiceSettings sparseEmbeddingServiceSettings() { + return new MinimalServiceSettings(null, TaskType.SPARSE_EMBEDDING, null, null, null); + } + + protected static MinimalServiceSettings textEmbeddingServiceSettings( + int dimensions, + SimilarityMeasure similarity, + DenseVectorFieldMapper.ElementType elementType + ) { + return new MinimalServiceSettings(null, TaskType.TEXT_EMBEDDING, dimensions, similarity, elementType); + } + + protected static Map semanticTextMapping(String inferenceId) { + return Map.of("type", SemanticTextFieldMapper.CONTENT_TYPE, "inference_id", inferenceId); + } + + protected static Map textMapping() { + return Map.of("type", "text"); + } + + protected static Map denseVectorMapping(int dimensions) { + return Map.of("type", DenseVectorFieldMapper.CONTENT_TYPE, "dims", dimensions); + } + + protected static Map sparseVectorMapping() { + return Map.of("type", SparseVectorFieldMapper.CONTENT_TYPE); + } + + protected static String fullyQualifiedIndexName(String clusterAlias, String indexName) { + return clusterAlias + ":" + indexName; + } + + protected static float[] generateDenseVectorFieldValue(int dimensions, DenseVectorFieldMapper.ElementType elementType, float value) { + if (elementType == DenseVectorFieldMapper.ElementType.BIT) { + assert dimensions % 8 == 0; + dimensions /= 8; + } + + float[] vector = new float[dimensions]; + for (int i = 0; i < dimensions; i++) { + // Use a constant value so that relevance is consistent + vector[i] = value; + } + + return vector; + } + + protected static Map generateSparseVectorFieldValue(float weight) { + // Generate values that have the same recall behavior as those produced by TestSparseInferenceServiceExtension. Use a constant token + // weight so that relevance is consistent. + return Map.of("feature_0", weight); + } + + protected static String[] convertToArray(List indices) { + return indices.stream().map(IndexWithBoost::index).toArray(String[]::new); + } + + public static class FakeMlPlugin extends Plugin implements ActionPlugin, SearchPlugin { + @Override + public List getNamedWriteables() { + return new MlInferenceNamedXContentProvider().getNamedWriteables(); + } + + @Override + public List> getQueryVectorBuilders() { + return List.of( + new QueryVectorBuilderSpec<>( + TextEmbeddingQueryVectorBuilder.NAME, + TextEmbeddingQueryVectorBuilder::new, + TextEmbeddingQueryVectorBuilder.PARSER + ) + ); + } + + @Override + public Collection getActions() { + return List.of(new ActionHandler(CoordinatedInferenceAction.INSTANCE, TransportCoordinatedInferenceAction.class)); + } + } + + protected record TestIndexInfo( + String name, + Map inferenceEndpoints, + Map mappings, + Map> docs + ) { + @Override + public Map mappings() { + return Map.of("properties", mappings); + } + } + + protected record SearchResult(@Nullable String clusterAlias, String index, String id) {} + + protected record FailureCause(Class causeClass, String message) {} + + protected record ClusterFailure(SearchResponse.Cluster.Status status, Set failures) {} + + protected record IndexWithBoost(String index, float boost) { + public IndexWithBoost(String index) { + this(index, 1.0f); + } + } +} diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/KnnVectorQueryBuilderCrossClusterSearchIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/KnnVectorQueryBuilderCrossClusterSearchIT.java new file mode 100644 index 0000000000000..b54d7afe08714 --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/KnnVectorQueryBuilderCrossClusterSearchIT.java @@ -0,0 +1,279 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.search.ccs; + +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.index.query.QueryShardException; +import org.elasticsearch.inference.MinimalServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; +import org.elasticsearch.search.vectors.VectorData; +import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.BiConsumer; + +import static org.hamcrest.Matchers.equalTo; + +public class KnnVectorQueryBuilderCrossClusterSearchIT extends AbstractSemanticCrossClusterSearchTestCase { + private static final String LOCAL_INDEX_NAME = "local-index"; + private static final String REMOTE_INDEX_NAME = "remote-index"; + private static final List QUERY_INDICES = List.of( + new IndexWithBoost(LOCAL_INDEX_NAME), + new IndexWithBoost(fullyQualifiedIndexName(REMOTE_CLUSTER, REMOTE_INDEX_NAME)) + ); + + public void testKnnQuery() throws Exception { + final String commonInferenceId = "common-inference-id"; + final String localInferenceId = "local-inference-id"; + + final String commonInferenceIdField = "common-inference-id-field"; + final String mixedTypeField1 = "mixed-type-field-1"; + final String mixedTypeField2 = "mixed-type-field-2"; + + final TestIndexInfo localIndexInfo = new TestIndexInfo( + LOCAL_INDEX_NAME, + Map.of( + commonInferenceId, + textEmbeddingServiceSettings(256, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT), + localInferenceId, + textEmbeddingServiceSettings(384, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT) + ), + Map.of( + commonInferenceIdField, + semanticTextMapping(commonInferenceId), + mixedTypeField1, + denseVectorMapping(384), + mixedTypeField2, + semanticTextMapping(localInferenceId) + ), + Map.of( + "local_doc_1", + Map.of(commonInferenceIdField, "a"), + "local_doc_2", + Map.of(mixedTypeField1, generateDenseVectorFieldValue(384, DenseVectorFieldMapper.ElementType.FLOAT, -128.0f)), + "local_doc_3", + Map.of(mixedTypeField2, "c") + ) + ); + final TestIndexInfo remoteIndexInfo = new TestIndexInfo( + REMOTE_INDEX_NAME, + Map.of( + commonInferenceId, + textEmbeddingServiceSettings(384, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT) + ), + Map.of( + commonInferenceIdField, + semanticTextMapping(commonInferenceId), + mixedTypeField1, + semanticTextMapping(commonInferenceId), + mixedTypeField2, + denseVectorMapping(384) + ), + Map.of( + "remote_doc_1", + Map.of(commonInferenceIdField, "x"), + "remote_doc_2", + Map.of(mixedTypeField1, "y"), + "remote_doc_3", + Map.of(mixedTypeField2, generateDenseVectorFieldValue(384, DenseVectorFieldMapper.ElementType.FLOAT, -128.0f)) + ) + ); + setupTwoClusters(localIndexInfo, remoteIndexInfo); + + // Query a field has the same inference ID value across clusters, but with different backing inference services + assertSearchResponse( + new KnnVectorQueryBuilder(commonInferenceIdField, new TextEmbeddingQueryVectorBuilder(null, "a"), 10, 100, 10f, null), + QUERY_INDICES, + List.of( + new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, "local_doc_1"), + new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, "remote_doc_1") + ) + ); + + // Query a field that has mixed types across clusters + assertSearchResponse( + new KnnVectorQueryBuilder(mixedTypeField1, new TextEmbeddingQueryVectorBuilder(localInferenceId, "y"), 10, 100, 10f, null), + QUERY_INDICES, + List.of( + new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, "remote_doc_2"), + new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, "local_doc_2") + ) + ); + assertSearchResponse( + new KnnVectorQueryBuilder(mixedTypeField2, new TextEmbeddingQueryVectorBuilder(localInferenceId, "c"), 10, 100, 10f, null), + QUERY_INDICES, + List.of( + new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, "local_doc_3"), + new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, "remote_doc_3") + ) + ); + + // Query a field that has mixed types across clusters using a query vector + final VectorData queryVector = new VectorData( + generateDenseVectorFieldValue(384, DenseVectorFieldMapper.ElementType.FLOAT, -128.0f) + ); + assertSearchResponse( + new KnnVectorQueryBuilder(mixedTypeField1, queryVector, 10, 100, 10f, null, null), + QUERY_INDICES, + List.of( + new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, "local_doc_2"), + new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, "remote_doc_2") + ) + ); + assertSearchResponse( + new KnnVectorQueryBuilder(mixedTypeField2, queryVector, 10, 100, 10f, null, null), + QUERY_INDICES, + List.of( + new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, "remote_doc_3"), + new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, "local_doc_3") + ) + ); + + // Check that omitting the inference ID when querying a remote dense vector field leads to the expected partial failure + assertSearchResponse( + new KnnVectorQueryBuilder(mixedTypeField2, new TextEmbeddingQueryVectorBuilder(null, "c"), 10, 100, 10f, null), + QUERY_INDICES, + List.of(new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, "local_doc_3")), + new ClusterFailure( + SearchResponse.Cluster.Status.SKIPPED, + Set.of(new FailureCause(IllegalArgumentException.class, "[model_id] must not be null.")) + ), + null + ); + } + + public void testKnnQueryWithCcsMinimizeRoundTripsFalse() throws Exception { + final BiConsumer assertCcsMinimizeRoundTripsFalseFailure = (f, qvb) -> { + KnnVectorQueryBuilder queryBuilder = new KnnVectorQueryBuilder(f, qvb, 10, 100, 10f, null); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(queryBuilder); + SearchRequest searchRequest = new SearchRequest(convertToArray(QUERY_INDICES), searchSourceBuilder); + searchRequest.setCcsMinimizeRoundtrips(false); + + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> client().search(searchRequest).actionGet(TEST_REQUEST_TIMEOUT) + ); + assertThat( + e.getMessage(), + equalTo( + "knn query does not support cross-cluster search when querying a [semantic_text] field when " + + "[ccs_minimize_roundtrips] is false" + ) + ); + }; + + final int dimensions = 256; + final String commonInferenceId = "common-inference-id"; + final MinimalServiceSettings commonInferenceIdServiceSettings = textEmbeddingServiceSettings( + dimensions, + SimilarityMeasure.COSINE, + DenseVectorFieldMapper.ElementType.FLOAT + ); + + final String commonInferenceIdField = "common-inference-id-field"; + final String mixedTypeField1 = "mixed-type-field-1"; + final String mixedTypeField2 = "mixed-type-field-2"; + final String denseVectorField = "dense-vector-field"; + + final TestIndexInfo localIndexInfo = new TestIndexInfo( + LOCAL_INDEX_NAME, + Map.of(commonInferenceId, commonInferenceIdServiceSettings), + Map.of( + commonInferenceIdField, + semanticTextMapping(commonInferenceId), + mixedTypeField1, + semanticTextMapping(commonInferenceId), + mixedTypeField2, + denseVectorMapping(dimensions), + denseVectorField, + denseVectorMapping(dimensions) + ), + Map.of( + mixedTypeField2 + "_doc", + Map.of(mixedTypeField2, generateDenseVectorFieldValue(dimensions, DenseVectorFieldMapper.ElementType.FLOAT, -128.0f)), + denseVectorField + "_doc", + Map.of(denseVectorField, generateDenseVectorFieldValue(dimensions, DenseVectorFieldMapper.ElementType.FLOAT, 1.0f)) + ) + ); + final TestIndexInfo remoteIndexInfo = new TestIndexInfo( + REMOTE_INDEX_NAME, + Map.of(commonInferenceId, commonInferenceIdServiceSettings), + Map.of( + commonInferenceIdField, + semanticTextMapping(commonInferenceId), + mixedTypeField1, + denseVectorMapping(dimensions), + mixedTypeField2, + semanticTextMapping(commonInferenceId), + denseVectorField, + denseVectorMapping(dimensions) + ), + Map.of( + mixedTypeField2 + "_doc", + Map.of(mixedTypeField2, "a"), + denseVectorField + "_doc", + Map.of(denseVectorField, generateDenseVectorFieldValue(dimensions, DenseVectorFieldMapper.ElementType.FLOAT, -128.0f)) + ) + ); + setupTwoClusters(localIndexInfo, remoteIndexInfo); + + // Validate that expected cases fail + assertCcsMinimizeRoundTripsFalseFailure.accept( + commonInferenceIdField, + new TextEmbeddingQueryVectorBuilder(null, randomAlphaOfLength(5)) + ); + assertCcsMinimizeRoundTripsFalseFailure.accept( + mixedTypeField1, + new TextEmbeddingQueryVectorBuilder(commonInferenceId, randomAlphaOfLength(5)) + ); + + // Validate the expected ccs_minimize_roundtrips=false detection gap and failure mode when querying non-inference fields locally + assertSearchResponse( + new KnnVectorQueryBuilder(mixedTypeField2, new TextEmbeddingQueryVectorBuilder(commonInferenceId, "foo"), 10, 100, 10f, null), + QUERY_INDICES, + List.of(new SearchResult(null, LOCAL_INDEX_NAME, mixedTypeField2 + "_doc")), + new ClusterFailure( + SearchResponse.Cluster.Status.SKIPPED, + Set.of( + new FailureCause( + QueryShardException.class, + "failed to create query: [knn] queries are only supported on [dense_vector] fields" + ) + ) + ), + s -> s.setCcsMinimizeRoundtrips(false) + ); + + // Validate that a CCS knn query functions when only dense vector fields are queried + assertSearchResponse( + new KnnVectorQueryBuilder( + denseVectorField, + generateDenseVectorFieldValue(dimensions, DenseVectorFieldMapper.ElementType.FLOAT, 1.0f), + 10, + 100, + 10f, + null, + null + ), + QUERY_INDICES, + List.of( + new SearchResult(null, LOCAL_INDEX_NAME, denseVectorField + "_doc"), + new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, denseVectorField + "_doc") + ), + null, + s -> s.setCcsMinimizeRoundtrips(false) + ); + } +} diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/MatchQueryBuilderCrossClusterSearchIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/MatchQueryBuilderCrossClusterSearchIT.java new file mode 100644 index 0000000000000..a83f7fa80e461 --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/MatchQueryBuilderCrossClusterSearchIT.java @@ -0,0 +1,217 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.search.ccs; + +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.index.query.MatchQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryShardException; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Consumer; + +import static org.hamcrest.Matchers.equalTo; + +public class MatchQueryBuilderCrossClusterSearchIT extends AbstractSemanticCrossClusterSearchTestCase { + private static final String LOCAL_INDEX_NAME = "local-index"; + private static final String REMOTE_INDEX_NAME = "remote-index"; + + // Boost the local index so that we can use the same doc values for local and remote indices and have consistent relevance + private static final List QUERY_INDICES = List.of( + new IndexWithBoost(LOCAL_INDEX_NAME, 10.0f), + new IndexWithBoost(fullyQualifiedIndexName(REMOTE_CLUSTER, REMOTE_INDEX_NAME)) + ); + + private static final String COMMON_INFERENCE_ID_FIELD = "common-inference-id-field"; + private static final String VARIABLE_INFERENCE_ID_FIELD = "variable-inference-id-field"; + private static final String MIXED_TYPE_FIELD_1 = "mixed-type-field-1"; + private static final String MIXED_TYPE_FIELD_2 = "mixed-type-field-2"; + private static final String TEXT_FIELD = "text-field"; + + boolean clustersConfigured = false; + + @Override + protected boolean reuseClusters() { + return true; + } + + @Before + @Override + public void setUp() throws Exception { + super.setUp(); + if (clustersConfigured == false) { + configureClusters(); + clustersConfigured = true; + } + } + + public void testMatchQuery() throws Exception { + // Query a field has the same inference ID value across clusters, but with different backing inference services + assertSearchResponse( + new MatchQueryBuilder(COMMON_INFERENCE_ID_FIELD, "a"), + QUERY_INDICES, + List.of( + new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, getDocId(COMMON_INFERENCE_ID_FIELD)), + new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, getDocId(COMMON_INFERENCE_ID_FIELD)) + ) + ); + + // Query a field that has different inference ID values across clusters + assertSearchResponse( + new MatchQueryBuilder(VARIABLE_INFERENCE_ID_FIELD, "b"), + QUERY_INDICES, + List.of( + new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, getDocId(VARIABLE_INFERENCE_ID_FIELD)), + new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, getDocId(VARIABLE_INFERENCE_ID_FIELD)) + ) + ); + + // Query a field that has mixed types across clusters + assertSearchResponse( + new MatchQueryBuilder(MIXED_TYPE_FIELD_1, "c"), + QUERY_INDICES, + List.of( + new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, getDocId(MIXED_TYPE_FIELD_1)), + new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, getDocId(MIXED_TYPE_FIELD_1)) + ) + ); + assertSearchResponse( + new MatchQueryBuilder(MIXED_TYPE_FIELD_2, "d"), + QUERY_INDICES, + List.of( + new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, getDocId(MIXED_TYPE_FIELD_2)), + new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, getDocId(MIXED_TYPE_FIELD_2)) + ) + ); + } + + public void testMatchQueryWithCcsMinimizeRoundTripsFalse() throws Exception { + final Consumer assertCcsMinimizeRoundTripsFalseFailure = q -> { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(q); + SearchRequest searchRequest = new SearchRequest(convertToArray(QUERY_INDICES), searchSourceBuilder); + searchRequest.setCcsMinimizeRoundtrips(false); + + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> client().search(searchRequest).actionGet(TEST_REQUEST_TIMEOUT) + ); + assertThat( + e.getMessage(), + equalTo( + "match query does not support cross-cluster search when querying a [semantic_text] field when " + + "[ccs_minimize_roundtrips] is false" + ) + ); + }; + + // Validate that expected cases fail + assertCcsMinimizeRoundTripsFalseFailure.accept(new MatchQueryBuilder(COMMON_INFERENCE_ID_FIELD, randomAlphaOfLength(5))); + assertCcsMinimizeRoundTripsFalseFailure.accept(new MatchQueryBuilder(MIXED_TYPE_FIELD_1, randomAlphaOfLength(5))); + + // Validate the expected ccs_minimize_roundtrips=false detection gap and failure mode when querying non-inference fields locally + assertSearchResponse( + new MatchQueryBuilder(MIXED_TYPE_FIELD_2, "d"), + QUERY_INDICES, + List.of(new SearchResult(null, LOCAL_INDEX_NAME, getDocId(MIXED_TYPE_FIELD_2))), + new ClusterFailure( + SearchResponse.Cluster.Status.SKIPPED, + Set.of( + new FailureCause( + QueryShardException.class, + "failed to create query: Field [mixed-type-field-2] of type [semantic_text] does not support match queries" + ) + ) + ), + s -> s.setCcsMinimizeRoundtrips(false) + ); + + // Validate that a CCS match query functions when only text fields are queried + assertSearchResponse( + new MatchQueryBuilder(TEXT_FIELD, "e"), + QUERY_INDICES, + List.of( + new SearchResult(null, LOCAL_INDEX_NAME, getDocId(TEXT_FIELD)), + new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, getDocId(TEXT_FIELD)) + ), + null, + s -> s.setCcsMinimizeRoundtrips(false) + ); + } + + private void configureClusters() throws IOException { + final String commonInferenceId = "common-inference-id"; + final String localInferenceId = "local-inference-id"; + final String remoteInferenceId = "remote-inference-id"; + + final Map> docs = Map.of( + getDocId(COMMON_INFERENCE_ID_FIELD), + Map.of(COMMON_INFERENCE_ID_FIELD, "a"), + getDocId(VARIABLE_INFERENCE_ID_FIELD), + Map.of(VARIABLE_INFERENCE_ID_FIELD, "b"), + getDocId(MIXED_TYPE_FIELD_1), + Map.of(MIXED_TYPE_FIELD_1, "c"), + getDocId(MIXED_TYPE_FIELD_2), + Map.of(MIXED_TYPE_FIELD_2, "d"), + getDocId(TEXT_FIELD), + Map.of(TEXT_FIELD, "e") + ); + + final TestIndexInfo localIndexInfo = new TestIndexInfo( + LOCAL_INDEX_NAME, + Map.of(commonInferenceId, sparseEmbeddingServiceSettings(), localInferenceId, sparseEmbeddingServiceSettings()), + Map.of( + COMMON_INFERENCE_ID_FIELD, + semanticTextMapping(commonInferenceId), + VARIABLE_INFERENCE_ID_FIELD, + semanticTextMapping(localInferenceId), + MIXED_TYPE_FIELD_1, + semanticTextMapping(localInferenceId), + MIXED_TYPE_FIELD_2, + textMapping(), + TEXT_FIELD, + textMapping() + ), + docs + ); + final TestIndexInfo remoteIndexInfo = new TestIndexInfo( + REMOTE_INDEX_NAME, + Map.of( + commonInferenceId, + textEmbeddingServiceSettings(256, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT), + remoteInferenceId, + textEmbeddingServiceSettings(384, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT) + ), + Map.of( + COMMON_INFERENCE_ID_FIELD, + semanticTextMapping(commonInferenceId), + VARIABLE_INFERENCE_ID_FIELD, + semanticTextMapping(remoteInferenceId), + MIXED_TYPE_FIELD_1, + textMapping(), + MIXED_TYPE_FIELD_2, + semanticTextMapping(remoteInferenceId), + TEXT_FIELD, + textMapping() + ), + docs + ); + setupTwoClusters(localIndexInfo, remoteIndexInfo); + } + + private static String getDocId(String field) { + return field + "_doc"; + } +} diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticQueryBuilderCrossClusterSearchIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticQueryBuilderCrossClusterSearchIT.java new file mode 100644 index 0000000000000..4b3b616f93bb0 --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticQueryBuilderCrossClusterSearchIT.java @@ -0,0 +1,121 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.search.ccs; + +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.search.builder.PointInTimeBuilder; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; + +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + +import static org.hamcrest.Matchers.equalTo; + +public class SemanticQueryBuilderCrossClusterSearchIT extends AbstractSemanticCrossClusterSearchTestCase { + private static final String LOCAL_INDEX_NAME = "local-index"; + private static final String REMOTE_INDEX_NAME = "remote-index"; + private static final List QUERY_INDICES = List.of( + new IndexWithBoost(LOCAL_INDEX_NAME), + new IndexWithBoost(fullyQualifiedIndexName(REMOTE_CLUSTER, REMOTE_INDEX_NAME)) + ); + + public void testSemanticQuery() throws Exception { + final String commonInferenceId = "common-inference-id"; + final String localInferenceId = "local-inference-id"; + final String remoteInferenceId = "remote-inference-id"; + + final String commonInferenceIdField = "common-inference-id-field"; + final String variableInferenceIdField = "variable-inference-id-field"; + + final TestIndexInfo localIndexInfo = new TestIndexInfo( + LOCAL_INDEX_NAME, + Map.of(commonInferenceId, sparseEmbeddingServiceSettings(), localInferenceId, sparseEmbeddingServiceSettings()), + Map.of( + commonInferenceIdField, + semanticTextMapping(commonInferenceId), + variableInferenceIdField, + semanticTextMapping(localInferenceId) + ), + Map.of("local_doc_1", Map.of(commonInferenceIdField, "a"), "local_doc_2", Map.of(variableInferenceIdField, "b")) + ); + final TestIndexInfo remoteIndexInfo = new TestIndexInfo( + REMOTE_INDEX_NAME, + Map.of( + commonInferenceId, + textEmbeddingServiceSettings(256, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT), + remoteInferenceId, + textEmbeddingServiceSettings(384, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT) + ), + Map.of( + commonInferenceIdField, + semanticTextMapping(commonInferenceId), + variableInferenceIdField, + semanticTextMapping(remoteInferenceId) + ), + Map.of("remote_doc_1", Map.of(commonInferenceIdField, "x"), "remote_doc_2", Map.of(variableInferenceIdField, "y")) + ); + setupTwoClusters(localIndexInfo, remoteIndexInfo); + + // Query a field has the same inference ID value across clusters, but with different backing inference services + assertSearchResponse( + new SemanticQueryBuilder(commonInferenceIdField, "a"), + QUERY_INDICES, + List.of( + new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, "local_doc_1"), + new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, "remote_doc_1") + ) + ); + + // Query a field that has different inference ID values across clusters + assertSearchResponse( + new SemanticQueryBuilder(variableInferenceIdField, "b"), + QUERY_INDICES, + List.of( + new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, "local_doc_2"), + new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, "remote_doc_2") + ) + ); + } + + public void testSemanticQueryWithCcMinimizeRoundTripsFalse() throws Exception { + final SemanticQueryBuilder queryBuilder = new SemanticQueryBuilder("foo", "bar"); + final Consumer assertCcsMinimizeRoundTripsFalseFailure = s -> { + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> client().search(s).actionGet(TEST_REQUEST_TIMEOUT) + ); + assertThat( + e.getMessage(), + equalTo("semantic query does not support cross-cluster search when [ccs_minimize_roundtrips] is false") + ); + }; + + final TestIndexInfo localIndexInfo = new TestIndexInfo(LOCAL_INDEX_NAME, Map.of(), Map.of(), Map.of()); + final TestIndexInfo remoteIndexInfo = new TestIndexInfo(REMOTE_INDEX_NAME, Map.of(), Map.of(), Map.of()); + setupTwoClusters(localIndexInfo, remoteIndexInfo); + + // Explicitly set ccs_minimize_roundtrips=false in the search request + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(queryBuilder); + SearchRequest searchRequestWithCcMinimizeRoundTripsFalse = new SearchRequest(convertToArray(QUERY_INDICES), searchSourceBuilder); + searchRequestWithCcMinimizeRoundTripsFalse.setCcsMinimizeRoundtrips(false); + assertCcsMinimizeRoundTripsFalseFailure.accept(searchRequestWithCcMinimizeRoundTripsFalse); + + // Using a point in time implicitly sets ccs_minimize_roundtrips=false + BytesReference pitId = openPointInTime(convertToArray(QUERY_INDICES), TimeValue.timeValueMinutes(2)); + SearchSourceBuilder searchSourceBuilderWithPit = new SearchSourceBuilder().query(queryBuilder) + .pointInTimeBuilder(new PointInTimeBuilder(pitId)); + SearchRequest searchRequestWithPit = new SearchRequest().source(searchSourceBuilderWithPit); + assertCcsMinimizeRoundTripsFalseFailure.accept(searchRequestWithPit); + } +} diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SparseVectorQueryBuilderCrossClusterSearchIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SparseVectorQueryBuilderCrossClusterSearchIT.java new file mode 100644 index 0000000000000..be9183722a48c --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SparseVectorQueryBuilderCrossClusterSearchIT.java @@ -0,0 +1,248 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.search.ccs; + +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryShardException; +import org.elasticsearch.inference.WeightedToken; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Consumer; + +import static org.hamcrest.Matchers.equalTo; + +public class SparseVectorQueryBuilderCrossClusterSearchIT extends AbstractSemanticCrossClusterSearchTestCase { + private static final String LOCAL_INDEX_NAME = "local-index"; + private static final String REMOTE_INDEX_NAME = "remote-index"; + private static final List QUERY_INDICES = List.of( + new IndexWithBoost(LOCAL_INDEX_NAME), + new IndexWithBoost(fullyQualifiedIndexName(REMOTE_CLUSTER, REMOTE_INDEX_NAME)) + ); + + public void testSparseVectorQuery() throws Exception { + final String commonInferenceId = "common-inference-id"; + + final String commonInferenceIdField = "common-inference-id-field"; + final String mixedTypeField1 = "mixed-type-field-1"; + final String mixedTypeField2 = "mixed-type-field-2"; + + final TestIndexInfo localIndexInfo = new TestIndexInfo( + LOCAL_INDEX_NAME, + Map.of(commonInferenceId, sparseEmbeddingServiceSettings()), + Map.of( + commonInferenceIdField, + semanticTextMapping(commonInferenceId), + mixedTypeField1, + sparseVectorMapping(), + mixedTypeField2, + semanticTextMapping(commonInferenceId) + ), + Map.of( + "local_doc_1", + Map.of(commonInferenceIdField, "a"), + "local_doc_2", + Map.of(mixedTypeField1, generateSparseVectorFieldValue(1.0f)), + "local_doc_3", + Map.of(mixedTypeField2, "c") + ) + ); + final TestIndexInfo remoteIndexInfo = new TestIndexInfo( + REMOTE_INDEX_NAME, + Map.of(commonInferenceId, sparseEmbeddingServiceSettings()), + Map.of( + commonInferenceIdField, + semanticTextMapping(commonInferenceId), + mixedTypeField1, + semanticTextMapping(commonInferenceId), + mixedTypeField2, + sparseVectorMapping() + ), + Map.of( + "remote_doc_1", + Map.of(commonInferenceIdField, "x"), + "remote_doc_2", + Map.of(mixedTypeField1, "y"), + "remote_doc_3", + Map.of(mixedTypeField2, generateSparseVectorFieldValue(1.0f)) + ) + ); + setupTwoClusters(localIndexInfo, remoteIndexInfo); + + // Query a field has the same inference ID value across clusters, but with different backing inference services + assertSearchResponse( + new SparseVectorQueryBuilder(commonInferenceIdField, null, "a"), + QUERY_INDICES, + List.of( + new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, "remote_doc_1"), + new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, "local_doc_1") + ) + ); + + // Query a field that has mixed types across clusters + assertSearchResponse( + new SparseVectorQueryBuilder(mixedTypeField1, commonInferenceId, "b"), + QUERY_INDICES, + List.of( + new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, "remote_doc_2"), + new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, "local_doc_2") + ) + ); + assertSearchResponse( + new SparseVectorQueryBuilder(mixedTypeField2, commonInferenceId, "c"), + QUERY_INDICES, + List.of( + new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, "local_doc_3"), + new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, "remote_doc_3") + ) + ); + + // Query a field that has mixed types across clusters using a query vector + final List queryVector = generateSparseVectorFieldValue(1.0f).entrySet() + .stream() + .map(e -> new WeightedToken(e.getKey(), e.getValue())) + .toList(); + assertSearchResponse( + new SparseVectorQueryBuilder(mixedTypeField1, queryVector, null, null, null, null), + QUERY_INDICES, + List.of( + new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, "remote_doc_2"), + new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, "local_doc_2") + ) + ); + assertSearchResponse( + new SparseVectorQueryBuilder(mixedTypeField2, queryVector, null, null, null, null), + QUERY_INDICES, + List.of( + new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, "local_doc_3"), + new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, "remote_doc_3") + ) + ); + + // Check that omitting the inference ID when querying a remote sparse vector field leads to the expected partial failure + assertSearchResponse( + new SparseVectorQueryBuilder(mixedTypeField2, null, "c"), + QUERY_INDICES, + List.of(new SearchResult(LOCAL_CLUSTER, LOCAL_INDEX_NAME, "local_doc_3")), + new ClusterFailure( + SearchResponse.Cluster.Status.SKIPPED, + Set.of(new FailureCause(IllegalArgumentException.class, "inference_id required to perform vector search on query string")) + ), + null + ); + } + + public void testSparseVectorQueryWithCcsMinimizeRoundTripsFalse() throws Exception { + final Consumer assertCcsMinimizeRoundTripsFalseFailure = q -> { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(q); + SearchRequest searchRequest = new SearchRequest(convertToArray(QUERY_INDICES), searchSourceBuilder); + searchRequest.setCcsMinimizeRoundtrips(false); + + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> client().search(searchRequest).actionGet(TEST_REQUEST_TIMEOUT) + ); + assertThat( + e.getMessage(), + equalTo( + "sparse_vector query does not support cross-cluster search when querying a [semantic_text] field when " + + "[ccs_minimize_roundtrips] is false" + ) + ); + }; + + final String commonInferenceId = "common-inference-id"; + + final String commonInferenceIdField = "common-inference-id-field"; + final String mixedTypeField1 = "mixed-type-field-1"; + final String mixedTypeField2 = "mixed-type-field-2"; + final String sparseVectorField = "sparse-vector-field"; + + final TestIndexInfo localIndexInfo = new TestIndexInfo( + LOCAL_INDEX_NAME, + Map.of(commonInferenceId, sparseEmbeddingServiceSettings()), + Map.of( + commonInferenceIdField, + semanticTextMapping(commonInferenceId), + mixedTypeField1, + semanticTextMapping(commonInferenceId), + mixedTypeField2, + sparseVectorMapping(), + sparseVectorField, + sparseVectorMapping() + ), + Map.of( + mixedTypeField2 + "_doc", + Map.of(mixedTypeField2, generateSparseVectorFieldValue(1.0f)), + sparseVectorField + "_doc", + Map.of(sparseVectorField, generateSparseVectorFieldValue(1.0f)) + ) + ); + final TestIndexInfo remoteIndexInfo = new TestIndexInfo( + REMOTE_INDEX_NAME, + Map.of(commonInferenceId, sparseEmbeddingServiceSettings()), + Map.of( + commonInferenceIdField, + semanticTextMapping(commonInferenceId), + mixedTypeField1, + sparseVectorMapping(), + mixedTypeField2, + semanticTextMapping(commonInferenceId), + sparseVectorField, + sparseVectorMapping() + ), + Map.of( + mixedTypeField2 + "_doc", + Map.of(mixedTypeField2, "a"), + sparseVectorField + "_doc", + Map.of(sparseVectorField, generateSparseVectorFieldValue(0.5f)) + ) + ); + setupTwoClusters(localIndexInfo, remoteIndexInfo); + + // Validate that expected cases fail + assertCcsMinimizeRoundTripsFalseFailure.accept(new SparseVectorQueryBuilder(commonInferenceIdField, null, randomAlphaOfLength(5))); + assertCcsMinimizeRoundTripsFalseFailure.accept( + new SparseVectorQueryBuilder(mixedTypeField1, commonInferenceId, randomAlphaOfLength(5)) + ); + + // Validate the expected ccs_minimize_roundtrips=false detection gap and failure mode when querying non-inference fields locally + assertSearchResponse( + new SparseVectorQueryBuilder(mixedTypeField2, commonInferenceId, "foo"), + QUERY_INDICES, + List.of(new SearchResult(null, LOCAL_INDEX_NAME, mixedTypeField2 + "_doc")), + new ClusterFailure( + SearchResponse.Cluster.Status.SKIPPED, + Set.of( + new FailureCause( + QueryShardException.class, + "failed to create query: field [mixed-type-field-2] must be type [sparse_vector] but is type [semantic_text]" + ) + ) + ), + s -> s.setCcsMinimizeRoundtrips(false) + ); + + // Validate that a CCS sparse vector query functions when only sparse vector fields are queried + assertSearchResponse( + new SparseVectorQueryBuilder(sparseVectorField, commonInferenceId, "foo"), + QUERY_INDICES, + List.of( + new SearchResult(null, LOCAL_INDEX_NAME, sparseVectorField + "_doc"), + new SearchResult(REMOTE_CLUSTER, REMOTE_INDEX_NAME, sparseVectorField + "_doc") + ), + null, + s -> s.setCcsMinimizeRoundtrips(false) + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/FullyQualifiedInferenceId.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/FullyQualifiedInferenceId.java index 5ee9d00da4abb..d1cc7aa3aa5c7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/FullyQualifiedInferenceId.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/FullyQualifiedInferenceId.java @@ -29,4 +29,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(clusterAlias); out.writeString(inferenceId); } + + @Override + public String toString() { + return "{clusterAlias=" + clusterAlias + ", inferenceId=" + inferenceId + "}"; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java index bf6b9d534d52e..0696185e11650 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java @@ -34,6 +34,8 @@ import java.util.Collection; import java.util.Map; +import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY; + public class InterceptedInferenceKnnVectorQueryBuilder extends InterceptedInferenceQueryBuilder { public static final String NAME = "intercepted_inference_knn"; @@ -44,15 +46,23 @@ public InterceptedInferenceKnnVectorQueryBuilder(KnnVectorQueryBuilder originalQ super(originalQuery); } + public InterceptedInferenceKnnVectorQueryBuilder( + KnnVectorQueryBuilder originalQuery, + Map inferenceResultsMap + ) { + super(originalQuery, inferenceResultsMap); + } + public InterceptedInferenceKnnVectorQueryBuilder(StreamInput in) throws IOException { super(in); } - InterceptedInferenceKnnVectorQueryBuilder( + private InterceptedInferenceKnnVectorQueryBuilder( InterceptedInferenceQueryBuilder other, - Map inferenceResultsMap + Map inferenceResultsMap, + boolean ccsRequest ) { - super(other, inferenceResultsMap); + super(other, inferenceResultsMap, ccsRequest); } @Override @@ -72,8 +82,9 @@ protected String getQuery() { } @Override - protected String getInferenceIdOverride() { - return getQueryVectorBuilderModelId(); + protected FullyQualifiedInferenceId getInferenceIdOverride() { + String modelId = getQueryVectorBuilderModelId(); + return modelId != null ? new FullyQualifiedInferenceId(LOCAL_CLUSTER_GROUP_KEY, modelId) : null; } @Override @@ -114,8 +125,8 @@ protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) { } @Override - protected QueryBuilder copy(Map inferenceResultsMap) { - return new InterceptedInferenceKnnVectorQueryBuilder(this, inferenceResultsMap); + protected QueryBuilder copy(Map inferenceResultsMap, boolean ccsRequest) { + return new InterceptedInferenceKnnVectorQueryBuilder(this, inferenceResultsMap, ccsRequest); } @Override @@ -131,7 +142,7 @@ protected QueryBuilder queryFields( } else if (fieldType instanceof SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType) { rewritten = querySemanticTextField(indexMetadataContext.getLocalClusterAlias(), semanticTextFieldType); } else { - rewritten = queryNonSemanticTextField(indexMetadataContext.getLocalClusterAlias()); + rewritten = queryNonSemanticTextField(); } return rewritten; @@ -177,12 +188,12 @@ private QueryBuilder querySemanticTextField(String clusterAlias, SemanticTextFie VectorData queryVector = originalQuery.queryVector(); if (queryVector == null) { - String inferenceId = getQueryVectorBuilderModelId(); - if (inferenceId == null) { - inferenceId = semanticTextFieldType.getSearchInferenceId(); + FullyQualifiedInferenceId fullyQualifiedInferenceId = getInferenceIdOverride(); + if (fullyQualifiedInferenceId == null) { + fullyQualifiedInferenceId = new FullyQualifiedInferenceId(clusterAlias, semanticTextFieldType.getSearchInferenceId()); } - MlTextEmbeddingResults textEmbeddingResults = getTextEmbeddingResults(clusterAlias, inferenceId); + MlTextEmbeddingResults textEmbeddingResults = getTextEmbeddingResults(fullyQualifiedInferenceId); queryVector = new VectorData(textEmbeddingResults.getInferenceAsFloat()); } @@ -202,18 +213,18 @@ private QueryBuilder querySemanticTextField(String clusterAlias, SemanticTextFie .queryName(originalQuery.queryName()); } - private QueryBuilder queryNonSemanticTextField(String clusterAlias) { + private QueryBuilder queryNonSemanticTextField() { VectorData queryVector = originalQuery.queryVector(); if (queryVector == null) { - String modelId = getQueryVectorBuilderModelId(); - if (modelId == null) { + FullyQualifiedInferenceId fullyQualifiedInferenceId = getInferenceIdOverride(); + if (fullyQualifiedInferenceId == null) { // This should never happen because we validate that either query vector or a valid query vector builder is specified in: // - The KnnVectorQueryBuilder constructor // - coordinatorNodeValidate throw new IllegalStateException("No query vector or query vector builder model ID specified"); } - MlTextEmbeddingResults textEmbeddingResults = getTextEmbeddingResults(clusterAlias, modelId); + MlTextEmbeddingResults textEmbeddingResults = getTextEmbeddingResults(fullyQualifiedInferenceId); queryVector = new VectorData(textEmbeddingResults.getInferenceAsFloat()); } @@ -231,10 +242,10 @@ private QueryBuilder queryNonSemanticTextField(String clusterAlias) { return knnQuery; } - private MlTextEmbeddingResults getTextEmbeddingResults(String clusterAlias, String inferenceId) { - InferenceResults inferenceResults = inferenceResultsMap.get(new FullyQualifiedInferenceId(clusterAlias, inferenceId)); + private MlTextEmbeddingResults getTextEmbeddingResults(FullyQualifiedInferenceId fullyQualifiedInferenceId) { + InferenceResults inferenceResults = inferenceResultsMap.get(fullyQualifiedInferenceId); if (inferenceResults == null) { - throw new IllegalStateException("Could not find inference results from inference endpoint [" + inferenceId + "]"); + throw new IllegalStateException("Could not find inference results from inference endpoint [" + fullyQualifiedInferenceId + "]"); } else if (inferenceResults instanceof MlTextEmbeddingResults == false) { throw new IllegalArgumentException( "Expected query inference results to be of type [" diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilder.java index eecc08acebb4d..69cbf665cc1f8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilder.java @@ -31,15 +31,23 @@ public InterceptedInferenceMatchQueryBuilder(MatchQueryBuilder originalQuery) { super(originalQuery); } + public InterceptedInferenceMatchQueryBuilder( + MatchQueryBuilder originalQuery, + Map inferenceResultsMap + ) { + super(originalQuery, inferenceResultsMap); + } + public InterceptedInferenceMatchQueryBuilder(StreamInput in) throws IOException { super(in); } - InterceptedInferenceMatchQueryBuilder( + private InterceptedInferenceMatchQueryBuilder( InterceptedInferenceQueryBuilder other, - Map inferenceResultsMap + Map inferenceResultsMap, + boolean ccsRequest ) { - super(other, inferenceResultsMap); + super(other, inferenceResultsMap, ccsRequest); } @Override @@ -63,8 +71,8 @@ protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) { } @Override - protected QueryBuilder copy(Map inferenceResultsMap) { - return new InterceptedInferenceMatchQueryBuilder(this, inferenceResultsMap); + protected QueryBuilder copy(Map inferenceResultsMap, boolean ccsRequest) { + return new InterceptedInferenceMatchQueryBuilder(this, inferenceResultsMap, ccsRequest); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java index a1b5f34aab848..8774d35f17ade 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java @@ -40,6 +40,7 @@ import static org.elasticsearch.index.IndexSettings.DEFAULT_FIELD_SETTING; import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY; +import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.SEMANTIC_SEARCH_CCS_SUPPORT; import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.convertFromBwcInferenceResultsMap; /** @@ -66,11 +67,17 @@ public abstract class InterceptedInferenceQueryBuilder inferenceResultsMap; + protected final boolean ccsRequest; protected InterceptedInferenceQueryBuilder(T originalQuery) { + this(originalQuery, null); + } + + protected InterceptedInferenceQueryBuilder(T originalQuery, Map inferenceResultsMap) { Objects.requireNonNull(originalQuery, "original query must not be null"); this.originalQuery = originalQuery; - this.inferenceResultsMap = null; + this.inferenceResultsMap = inferenceResultsMap != null ? Map.copyOf(inferenceResultsMap) : null; + this.ccsRequest = false; } @SuppressWarnings("unchecked") @@ -86,14 +93,21 @@ protected InterceptedInferenceQueryBuilder(StreamInput in) throws IOException { in.readOptional(i1 -> i1.readImmutableMap(i2 -> i2.readNamedWriteable(InferenceResults.class))) ); } + if (in.getTransportVersion().supports(SEMANTIC_SEARCH_CCS_SUPPORT)) { + this.ccsRequest = in.readBoolean(); + } else { + this.ccsRequest = false; + } } protected InterceptedInferenceQueryBuilder( InterceptedInferenceQueryBuilder other, - Map inferenceResultsMap + Map inferenceResultsMap, + boolean ccsRequest ) { this.originalQuery = other.originalQuery; this.inferenceResultsMap = inferenceResultsMap; + this.ccsRequest = ccsRequest; } /** @@ -133,9 +147,10 @@ protected InterceptedInferenceQueryBuilder( * Generate a copy of {@code this} using the provided inference results map. * * @param inferenceResultsMap The inference results map + * @param ccsRequest Flag indicating if this is a CCS request * @return A copy of {@code this} with the provided inference results map */ - protected abstract QueryBuilder copy(Map inferenceResultsMap); + protected abstract QueryBuilder copy(Map inferenceResultsMap, boolean ccsRequest); /** * Rewrite to a {@link QueryBuilder} appropriate for a specific index's mappings. The implementation can use @@ -167,7 +182,7 @@ protected abstract QueryBuilder queryFields( /** * Get the query-time inference ID override. If not applicable or available, {@code null} should be returned. */ - protected String getInferenceIdOverride() { + protected FullyQualifiedInferenceId getInferenceIdOverride() { return null; } @@ -194,6 +209,19 @@ protected void doWriteTo(StreamOutput out) throws IOException { o2.writeString(id.inferenceId()); }, StreamOutput::writeNamedWriteable), inferenceResultsMap); } + if (out.getTransportVersion().supports(SEMANTIC_SEARCH_CCS_SUPPORT)) { + out.writeBoolean(ccsRequest); + } else if (ccsRequest) { + throw new IllegalArgumentException( + "One or more nodes does not support " + + originalQuery.getName() + + " query cross-cluster search when querying a [" + + SemanticTextFieldMapper.CONTENT_TYPE + + "] field. Please update all nodes to at least Elasticsearch " + + SEMANTIC_SEARCH_CCS_SUPPORT.toReleaseVersion() + + "." + ); + } } @Override @@ -208,12 +236,14 @@ protected Query doToQuery(SearchExecutionContext context) { @Override protected boolean doEquals(InterceptedInferenceQueryBuilder other) { - return Objects.equals(originalQuery, other.originalQuery) && Objects.equals(inferenceResultsMap, other.inferenceResultsMap); + return Objects.equals(originalQuery, other.originalQuery) + && Objects.equals(inferenceResultsMap, other.inferenceResultsMap) + && Objects.equals(ccsRequest, other.ccsRequest); } @Override protected int doHashCode() { - return Objects.hash(originalQuery, inferenceResultsMap); + return Objects.hash(originalQuery, inferenceResultsMap, ccsRequest); } @Override @@ -261,14 +291,19 @@ private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewri // In this case, the remote data node will receive the original query, which will in turn result in an error about querying an // unsupported field type. ResolvedIndices resolvedIndices = queryRewriteContext.getResolvedIndices(); - Set inferenceIds = getInferenceIdsForFields( + Set inferenceIds = getInferenceIdsForFields( resolvedIndices.getConcreteLocalIndicesMetadata().values(), + queryRewriteContext.getLocalClusterAlias(), getFields(), resolveWildcards(), useDefaultFields() ); - if (inferenceIds.isEmpty()) { + // If we are handling a CCS request, always retain the intercepted query logic so that we can get inference results generated on + // the local cluster from the inference results map when rewriting on remote cluster data nodes. This can be necessary when: + // - A query specifies an inference ID override + // - Only non-inference fields are queried on the remote cluster + if (inferenceIds.isEmpty() && this.ccsRequest == false) { // Not querying a semantic text field return originalQuery; } @@ -276,17 +311,17 @@ private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewri // Validate early to prevent partial failures coordinatorNodeValidate(resolvedIndices); - // TODO: Check for supported CCS mode here (once we support CCS) - if (resolvedIndices.getRemoteClusterIndices().isEmpty() == false) { + boolean ccsRequest = this.ccsRequest || resolvedIndices.getRemoteClusterIndices().isEmpty() == false; + if (ccsRequest && queryRewriteContext.isCcsMinimizeRoundTrips() == false) { throw new IllegalArgumentException( originalQuery.getName() + " query does not support cross-cluster search when querying a [" + SemanticTextFieldMapper.CONTENT_TYPE - + "] field" + + "] field when [ccs_minimize_roundtrips] is false" ); } - String inferenceIdOverride = getInferenceIdOverride(); + FullyQualifiedInferenceId inferenceIdOverride = getInferenceIdOverride(); if (inferenceIdOverride != null) { inferenceIds = Set.of(inferenceIdOverride); } @@ -307,20 +342,21 @@ private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewri // The inference results map is fully populated, so we can perform error checking inferenceResultsErrorCheck(modifiedInferenceResultsMap); } else { - rewritten = copy(modifiedInferenceResultsMap); + rewritten = copy(modifiedInferenceResultsMap, ccsRequest); } } return rewritten; } - private static Set getInferenceIdsForFields( + private static Set getInferenceIdsForFields( Collection indexMetadataCollection, + String clusterAlias, Map fields, boolean resolveWildcards, boolean useDefaultFields ) { - Set inferenceIds = new HashSet<>(); + Set fullyQualifiedInferenceIds = new HashSet<>(); for (IndexMetadata indexMetadata : indexMetadataCollection) { final Map indexQueryFields = (useDefaultFields && fields.isEmpty()) ? getDefaultFields(indexMetadata.getSettings()) @@ -331,23 +367,34 @@ private static Set getInferenceIdsForFields( if (indexInferenceFields.containsKey(indexQueryField)) { // No wildcards in field name InferenceFieldMetadata inferenceFieldMetadata = indexInferenceFields.get(indexQueryField); - inferenceIds.add(inferenceFieldMetadata.getSearchInferenceId()); + fullyQualifiedInferenceIds.add( + new FullyQualifiedInferenceId(clusterAlias, inferenceFieldMetadata.getSearchInferenceId()) + ); continue; } if (resolveWildcards) { if (Regex.isMatchAllPattern(indexQueryField)) { - indexInferenceFields.values().forEach(ifm -> inferenceIds.add(ifm.getSearchInferenceId())); + indexInferenceFields.values() + .forEach( + ifm -> fullyQualifiedInferenceIds.add( + new FullyQualifiedInferenceId(clusterAlias, ifm.getSearchInferenceId()) + ) + ); } else if (Regex.isSimpleMatchPattern(indexQueryField)) { indexInferenceFields.values() .stream() .filter(ifm -> Regex.simpleMatch(indexQueryField, ifm.getName())) - .forEach(ifm -> inferenceIds.add(ifm.getSearchInferenceId())); + .forEach( + ifm -> fullyQualifiedInferenceIds.add( + new FullyQualifiedInferenceId(clusterAlias, ifm.getSearchInferenceId()) + ) + ); } } } } - return inferenceIds; + return fullyQualifiedInferenceIds; } private static Map getInferenceFieldsMap( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilder.java index 655fb3f790cd9..dab789c0223e7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilder.java @@ -33,6 +33,8 @@ import java.util.List; import java.util.Map; +import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY; + public class InterceptedInferenceSparseVectorQueryBuilder extends InterceptedInferenceQueryBuilder { public static final String NAME = "intercepted_inference_sparse_vector"; @@ -43,15 +45,23 @@ public InterceptedInferenceSparseVectorQueryBuilder(SparseVectorQueryBuilder ori super(originalQuery); } + public InterceptedInferenceSparseVectorQueryBuilder( + SparseVectorQueryBuilder originalQuery, + Map inferenceResultsMap + ) { + super(originalQuery, inferenceResultsMap); + } + public InterceptedInferenceSparseVectorQueryBuilder(StreamInput in) throws IOException { super(in); } - InterceptedInferenceSparseVectorQueryBuilder( + private InterceptedInferenceSparseVectorQueryBuilder( InterceptedInferenceQueryBuilder other, - Map inferenceResultsMap + Map inferenceResultsMap, + boolean ccsRequest ) { - super(other, inferenceResultsMap); + super(other, inferenceResultsMap, ccsRequest); } @Override @@ -65,8 +75,14 @@ protected String getQuery() { } @Override - protected String getInferenceIdOverride() { - return originalQuery.getInferenceId(); + protected FullyQualifiedInferenceId getInferenceIdOverride() { + FullyQualifiedInferenceId override = null; + String originalInferenceId = originalQuery.getInferenceId(); + if (originalInferenceId != null) { + override = new FullyQualifiedInferenceId(LOCAL_CLUSTER_GROUP_KEY, originalInferenceId); + } + + return override; } @Override @@ -96,8 +112,8 @@ protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) { } @Override - protected QueryBuilder copy(Map inferenceResultsMap) { - return new InterceptedInferenceSparseVectorQueryBuilder(this, inferenceResultsMap); + protected QueryBuilder copy(Map inferenceResultsMap, boolean ccsRequest) { + return new InterceptedInferenceSparseVectorQueryBuilder(this, inferenceResultsMap, ccsRequest); } @Override @@ -113,7 +129,7 @@ protected QueryBuilder queryFields( } else if (fieldType instanceof SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType) { rewritten = querySemanticTextField(indexMetadataContext.getLocalClusterAlias(), semanticTextFieldType); } else { - rewritten = queryNonSemanticTextField(indexMetadataContext.getLocalClusterAlias()); + rewritten = queryNonSemanticTextField(); } return rewritten; @@ -149,12 +165,12 @@ private QueryBuilder querySemanticTextField(String clusterAlias, SemanticTextFie List queryVector = originalQuery.getQueryVectors(); if (queryVector == null) { - String inferenceId = originalQuery.getInferenceId(); - if (inferenceId == null) { - inferenceId = semanticTextFieldType.getSearchInferenceId(); + FullyQualifiedInferenceId fullyQualifiedInferenceId = getInferenceIdOverride(); + if (fullyQualifiedInferenceId == null) { + fullyQualifiedInferenceId = new FullyQualifiedInferenceId(clusterAlias, semanticTextFieldType.getSearchInferenceId()); } - queryVector = getQueryVector(clusterAlias, inferenceId); + queryVector = getQueryVector(fullyQualifiedInferenceId); } SparseVectorQueryBuilder innerSparseVectorQuery = new SparseVectorQueryBuilder( @@ -171,15 +187,15 @@ private QueryBuilder querySemanticTextField(String clusterAlias, SemanticTextFie .queryName(originalQuery.queryName()); } - private QueryBuilder queryNonSemanticTextField(String clusterAlias) { + private QueryBuilder queryNonSemanticTextField() { List queryVector = originalQuery.getQueryVectors(); if (queryVector == null) { - String inferenceId = originalQuery.getInferenceId(); - if (inferenceId == null) { + FullyQualifiedInferenceId fullyQualifiedInferenceId = getInferenceIdOverride(); + if (fullyQualifiedInferenceId == null) { throw new IllegalArgumentException("Either query vector or inference ID must be specified"); } - queryVector = getQueryVector(clusterAlias, inferenceId); + queryVector = getQueryVector(fullyQualifiedInferenceId); } return new SparseVectorQueryBuilder( @@ -192,10 +208,10 @@ private QueryBuilder queryNonSemanticTextField(String clusterAlias) { ).boost(originalQuery.boost()).queryName(originalQuery.queryName()); } - private List getQueryVector(String clusterAlias, String inferenceId) { - InferenceResults inferenceResults = inferenceResultsMap.get(new FullyQualifiedInferenceId(clusterAlias, inferenceId)); + private List getQueryVector(FullyQualifiedInferenceId fullyQualifiedInferenceId) { + InferenceResults inferenceResults = inferenceResultsMap.get(fullyQualifiedInferenceId); if (inferenceResults == null) { - throw new IllegalStateException("Could not find inference results from inference endpoint [" + inferenceId + "]"); + throw new IllegalStateException("Could not find inference results from inference endpoint [" + fullyQualifiedInferenceId + "]"); } else if (inferenceResults instanceof TextExpansionResults == false) { throw new IllegalArgumentException( "Expected query inference results to be of type [" diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/LegacySemanticQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/LegacySemanticQueryRewriteInterceptor.java index 670d846c8d4a9..4052e93559437 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/LegacySemanticQueryRewriteInterceptor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/LegacySemanticQueryRewriteInterceptor.java @@ -16,6 +16,7 @@ import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.TermsQueryBuilder; import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor; +import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import java.util.ArrayList; import java.util.Collection; @@ -24,6 +25,8 @@ import java.util.Map; import java.util.stream.Collectors; +import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.SEMANTIC_SEARCH_CCS_SUPPORT; + /** * Intercepts and adapts a query to be rewritten to work seamlessly on a semantic_text field. */ @@ -46,15 +49,26 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde if (indexInformation.getInferenceIndices().isEmpty()) { // No inference fields were identified, so return the original query. return queryBuilder; - } else if (indexInformation.nonInferenceIndices().isEmpty() == false) { - // Combined case where the field name requested by this query contains both - // semantic_text and non-inference fields, so we have to combine queries per index - // containing each field type. - return buildCombinedInferenceAndNonInferenceQuery(queryBuilder, indexInformation); + } else if (resolvedIndices.getRemoteClusterIndices().isEmpty()) { + if (indexInformation.nonInferenceIndices().isEmpty() == false) { + // Combined case where the field name requested by this query contains both + // semantic_text and non-inference fields, so we have to combine queries per index + // containing each field type. + return buildCombinedInferenceAndNonInferenceQuery(queryBuilder, indexInformation); + } else { + // The only fields we've identified are inference fields (e.g. semantic_text), + // so rewrite the entire query to work on a semantic_text field. + return buildInferenceQuery(queryBuilder, indexInformation); + } } else { - // The only fields we've identified are inference fields (e.g. semantic_text), - // so rewrite the entire query to work on a semantic_text field. - return buildInferenceQuery(queryBuilder, indexInformation); + throw new IllegalArgumentException( + getQueryName() + + " query does not support cross-cluster search when querying a [" + + SemanticTextFieldMapper.CONTENT_TYPE + + "] field in a mixed-version cluster. Please update all nodes to at least Elasticsearch " + + SEMANTIC_SEARCH_CCS_SUPPORT.toReleaseVersion() + + "." + ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index d00b3804a9920..97d0caef98d0e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -68,6 +68,7 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder inferenceResultsMap; private final Boolean lenient; + // ccsRequest is only used on the local cluster coordinator node to detect when: + // - The request references a remote index + // - The remote cluster is too old to support semantic search CCS + // It doesn't technically need to be serialized since it is only used for this purpose, but we do so to keep its behavior in line with + // standard query member variables. + private final boolean ccsRequest; + public SemanticQueryBuilder(String fieldName, String query) { this(fieldName, query, null); } public SemanticQueryBuilder(String fieldName, String query, Boolean lenient) { - this(fieldName, query, lenient, null); + this(fieldName, query, lenient, null, false); } protected SemanticQueryBuilder( @@ -107,6 +115,16 @@ protected SemanticQueryBuilder( String query, Boolean lenient, Map inferenceResultsMap + ) { + this(fieldName, query, lenient, inferenceResultsMap, false); + } + + protected SemanticQueryBuilder( + String fieldName, + String query, + Boolean lenient, + Map inferenceResultsMap, + boolean ccsRequest ) { if (fieldName == null) { throw new IllegalArgumentException("[" + NAME + "] requires a " + FIELD_FIELD.getPreferredName() + " value"); @@ -118,6 +136,7 @@ protected SemanticQueryBuilder( this.query = query; this.inferenceResultsMap = inferenceResultsMap != null ? Map.copyOf(inferenceResultsMap) : null; this.lenient = lenient; + this.ccsRequest = ccsRequest; } public SemanticQueryBuilder(StreamInput in) throws IOException { @@ -142,6 +161,11 @@ public SemanticQueryBuilder(StreamInput in) throws IOException { } else { this.lenient = null; } + if (in.getTransportVersion().supports(SEMANTIC_SEARCH_CCS_SUPPORT)) { + this.ccsRequest = in.readBoolean(); + } else { + this.ccsRequest = false; + } } @Override @@ -176,9 +200,24 @@ protected void doWriteTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_QUERY_LENIENT)) { out.writeOptionalBoolean(lenient); } + if (out.getTransportVersion().supports(SEMANTIC_SEARCH_CCS_SUPPORT)) { + out.writeBoolean(ccsRequest); + } else if (ccsRequest) { + throw new IllegalArgumentException( + "One or more nodes does not support " + + NAME + + " query cross-cluster search. Please update all nodes to at least Elasticsearch " + + SEMANTIC_SEARCH_CCS_SUPPORT.toReleaseVersion() + + "." + ); + } } - private SemanticQueryBuilder(SemanticQueryBuilder other, Map inferenceResultsMap) { + private SemanticQueryBuilder( + SemanticQueryBuilder other, + Map inferenceResultsMap, + boolean ccsRequest + ) { this.fieldName = other.fieldName; this.query = other.query; this.boost = other.boost; @@ -186,6 +225,7 @@ private SemanticQueryBuilder(SemanticQueryBuilder other, Map - * Get inference results for the provided query using the provided inference IDs. The inference IDs are fully qualified by the - * cluster alias in the provided {@link QueryRewriteContext}. + * Get inference results for the provided query using the provided fully qualified inference IDs. *

*

* This method will return an inference results map that will be asynchronously populated with inference results. If the provided @@ -222,14 +261,14 @@ public static SemanticQueryBuilder fromXContent(XContentParser parser) throws IO *

* * @param queryRewriteContext The query rewrite context - * @param inferenceIds The inference IDs to use to generate inference results + * @param fullyQualifiedInferenceIds The fully qualified inference IDs to use to generate inference results * @param inferenceResultsMap The initial inference results map * @param query The query to generate inference results for * @return An inference results map */ static Map getInferenceResults( QueryRewriteContext queryRewriteContext, - Set inferenceIds, + Set fullyQualifiedInferenceIds, @Nullable Map inferenceResultsMap, @Nullable String query ) { @@ -239,12 +278,19 @@ static Map getInferenceResults( : Map.of(); if (query != null) { - for (String inferenceId : inferenceIds) { - FullyQualifiedInferenceId fullyQualifiedInferenceId = new FullyQualifiedInferenceId( - queryRewriteContext.getLocalClusterAlias(), - inferenceId - ); + for (FullyQualifiedInferenceId fullyQualifiedInferenceId : fullyQualifiedInferenceIds) { if (currentInferenceResultsMap.containsKey(fullyQualifiedInferenceId) == false) { + if (fullyQualifiedInferenceId.clusterAlias().equals(queryRewriteContext.getLocalClusterAlias()) == false) { + // Catch if we are missing inference results that should have been generated on another cluster + throw new IllegalStateException( + "Cannot get inference results for inference endpoint [" + + fullyQualifiedInferenceId + + "] on cluster [" + + queryRewriteContext.getLocalClusterAlias() + + "]" + ); + } + if (modifiedInferenceResultsMap == false) { // Copy the inference results map to ensure it is mutable and thread safe currentInferenceResultsMap = new ConcurrentHashMap<>(currentInferenceResultsMap); @@ -255,7 +301,7 @@ static Map getInferenceResults( queryRewriteContext, ((ConcurrentHashMap) currentInferenceResultsMap), query, - inferenceId + fullyQualifiedInferenceId.inferenceId() ); } } @@ -406,16 +452,23 @@ private QueryBuilder doRewriteBuildSemanticQuery(SearchExecutionContext searchEx private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewriteContext) { ResolvedIndices resolvedIndices = queryRewriteContext.getResolvedIndices(); - if (resolvedIndices.getRemoteClusterIndices().isEmpty() == false) { - throw new IllegalArgumentException(NAME + " query does not support cross-cluster search"); + boolean ccsRequest = resolvedIndices.getRemoteClusterIndices().isEmpty() == false; + if (ccsRequest && queryRewriteContext.isCcsMinimizeRoundTrips() == false) { + throw new IllegalArgumentException( + NAME + " query does not support cross-cluster search when [ccs_minimize_roundtrips] is false" + ); } SemanticQueryBuilder rewritten = this; if (queryRewriteContext.hasAsyncActions() == false) { - Set inferenceIds = getInferenceIdsForForField(resolvedIndices.getConcreteLocalIndicesMetadata().values(), fieldName); + Set fullyQualifiedInferenceIds = getInferenceIdsForField( + resolvedIndices.getConcreteLocalIndicesMetadata().values(), + queryRewriteContext.getLocalClusterAlias(), + fieldName + ); Map modifiedInferenceResultsMap = getInferenceResults( queryRewriteContext, - inferenceIds, + fullyQualifiedInferenceIds, inferenceResultsMap, query ); @@ -424,7 +477,7 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu // The inference results map is fully populated, so we can perform error checking inferenceResultsErrorCheck(modifiedInferenceResultsMap); } else { - rewritten = new SemanticQueryBuilder(this, modifiedInferenceResultsMap); + rewritten = new SemanticQueryBuilder(this, modifiedInferenceResultsMap, ccsRequest); } } @@ -502,28 +555,33 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { throw new IllegalStateException(NAME + " should have been rewritten to another query type"); } - private static Set getInferenceIdsForForField(Collection indexMetadataCollection, String fieldName) { - Set inferenceIds = new HashSet<>(); + private static Set getInferenceIdsForField( + Collection indexMetadataCollection, + String clusterAlias, + String fieldName + ) { + Set fullyQualifiedInferenceIds = new HashSet<>(); for (IndexMetadata indexMetadata : indexMetadataCollection) { InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(fieldName); String indexInferenceId = inferenceFieldMetadata != null ? inferenceFieldMetadata.getSearchInferenceId() : null; if (indexInferenceId != null) { - inferenceIds.add(indexInferenceId); + fullyQualifiedInferenceIds.add(new FullyQualifiedInferenceId(clusterAlias, indexInferenceId)); } } - return inferenceIds; + return fullyQualifiedInferenceIds; } @Override protected boolean doEquals(SemanticQueryBuilder other) { return Objects.equals(fieldName, other.fieldName) && Objects.equals(query, other.query) - && Objects.equals(inferenceResultsMap, other.inferenceResultsMap); + && Objects.equals(inferenceResultsMap, other.inferenceResultsMap) + && Objects.equals(ccsRequest, other.ccsRequest); } @Override protected int doHashCode() { - return Objects.hash(fieldName, query, inferenceResultsMap); + return Objects.hash(fieldName, query, inferenceResultsMap, ccsRequest); } } diff --git a/x-pack/plugin/inference/src/main/plugin-metadata/entitlement-policy.yaml b/x-pack/plugin/inference/src/main/plugin-metadata/entitlement-policy.yaml index 36ac851acf1ea..cd046c4a6cb23 100644 --- a/x-pack/plugin/inference/src/main/plugin-metadata/entitlement-policy.yaml +++ b/x-pack/plugin/inference/src/main/plugin-metadata/entitlement-policy.yaml @@ -12,6 +12,7 @@ software.amazon.awssdk.http.nio.netty: io.netty.common: - outbound_network - manage_threads + - inbound_network - files: - path: "/etc/os-release" mode: "read" @@ -22,3 +23,4 @@ io.netty.common: io.netty.transport: - manage_threads - outbound_network + - inbound_network diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/AbstractInterceptedInferenceQueryBuilderTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/AbstractInterceptedInferenceQueryBuilderTestCase.java index e83a88e714a98..f999c0f89ae90 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/AbstractInterceptedInferenceQueryBuilderTestCase.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/AbstractInterceptedInferenceQueryBuilderTestCase.java @@ -9,7 +9,6 @@ import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.TransportVersion; -import org.elasticsearch.TransportVersions; import org.elasticsearch.action.MockResolvedIndices; import org.elasticsearch.action.OriginalIndices; import org.elasticsearch.action.ResolvedIndices; @@ -23,6 +22,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.regex.Regex; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.CheckedRunnable; import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.IndexVersion; @@ -65,13 +65,15 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.BiConsumer; import java.util.function.Supplier; +import static org.elasticsearch.TransportVersions.NEW_SEMANTIC_QUERY_INTERCEPTORS; +import static org.elasticsearch.TransportVersions.V_8_15_0; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD; import static org.elasticsearch.xpack.inference.queries.InterceptedInferenceQueryBuilder.INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS; -import static org.hamcrest.Matchers.containsString; +import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.SEMANTIC_SEARCH_CCS_SUPPORT; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.spy; @@ -169,34 +171,105 @@ public void testBwCSerialization() throws Exception { } } - public void testCcs() throws Exception { - final String field = "semantic_field"; - final QueryRewriteContext queryRewriteContext = createQueryRewriteContext( - Map.of("local-index", Map.of(field, SPARSE_INFERENCE_ID)), + public void testCcsSerialization() throws Exception { + final String inferenceField = "semantic_field"; + final T inferenceFieldQuery = createQueryBuilder(inferenceField); + final T nonInferenceFieldQuery = createQueryBuilder("non_inference_field"); + + // Test with the current transport version. This simulates sending the query to a remote cluster that supports semantic search CCS. + final QueryRewriteContext contextCurrent = createQueryRewriteContext( + Map.of("local-index", Map.of(inferenceField, SPARSE_INFERENCE_ID)), Map.of("remote-alias", "remote-index"), - TransportVersion.current() + TransportVersion.current(), + true ); - // Test querying a semantic text field - final T semanticFieldQuery = createQueryBuilder(field); - IllegalArgumentException e = assertThrows( - IllegalArgumentException.class, - () -> rewriteAndFetch(semanticFieldQuery, queryRewriteContext) + assertRewriteAndSerializeOnInferenceField(inferenceFieldQuery, contextCurrent, null, null); + assertRewriteAndSerializeOnNonInferenceField(nonInferenceFieldQuery, contextCurrent); + } + + public void testCcsSerializationWithMinimizeRoundTripsFalse() throws Exception { + final String inferenceField = "semantic_field"; + final T inferenceFieldQuery = createQueryBuilder(inferenceField); + final T nonInferenceFieldQuery = createQueryBuilder("non_inference_field"); + + final QueryRewriteContext minimizeRoundTripsFalseContext = createQueryRewriteContext( + Map.of("local-index", Map.of(inferenceField, SPARSE_INFERENCE_ID)), + Map.of("remote-alias", "remote-index"), + TransportVersion.current(), + false ); - assertThat( - e.getMessage(), - containsString( - semanticFieldQuery.getName() + " query does not support cross-cluster search when querying a [semantic_text] field" - ) + + assertRewriteAndSerializeOnInferenceField( + inferenceFieldQuery, + minimizeRoundTripsFalseContext, + new IllegalArgumentException( + inferenceFieldQuery.getName() + + " query does not support cross-cluster search when querying a [" + + SemanticTextFieldMapper.CONTENT_TYPE + + "] field when [ccs_minimize_roundtrips] is false" + ), + null ); + assertRewriteAndSerializeOnNonInferenceField(nonInferenceFieldQuery, minimizeRoundTripsFalseContext); + } - // Test querying a non-inference field + public void testCcsBwCSerialization() throws Exception { + final String inferenceField = "semantic_field"; + final T inferenceFieldQuery = createQueryBuilder(inferenceField); final T nonInferenceFieldQuery = createQueryBuilder("non_inference_field"); - QueryBuilder coordinatorRewritten = rewriteAndFetch(nonInferenceFieldQuery, queryRewriteContext); - // Use a serialization cycle to strip InterceptedQueryBuilderWrapper - coordinatorRewritten = copyNamedWriteable(coordinatorRewritten, writableRegistry(), QueryBuilder.class); - assertCoordinatorNodeRewriteOnNonInferenceField(nonInferenceFieldQuery, coordinatorRewritten); + for (int i = 0; i < 100; i++) { + TransportVersion transportVersion = TransportVersionUtils.randomVersionBetween( + random(), + V_8_15_0, + TransportVersionUtils.getPreviousVersion(TransportVersion.current()) + ); + + QueryRewriteContext queryRewriteContext = createQueryRewriteContext( + Map.of("local-index", Map.of(inferenceField, SPARSE_INFERENCE_ID)), + Map.of("remote-alias", "remote-index"), + transportVersion, + true + ); + + Exception expectedRewriteException = null; + Exception expectedSerializationException = null; + if (transportVersion.supports(SEMANTIC_SEARCH_CCS_SUPPORT) == false) { + if (transportVersion.supports(NEW_SEMANTIC_QUERY_INTERCEPTORS)) { + // Transport version is new enough to support the new interceptors, but not new enough to support CCS. This simulates if + // one of the local or remote cluster data nodes is out of date. + expectedSerializationException = new IllegalArgumentException( + "One or more nodes does not support " + + inferenceFieldQuery.getName() + + " query cross-cluster search when querying a [" + + SemanticTextFieldMapper.CONTENT_TYPE + + "] field. Please update all nodes to at least Elasticsearch " + + SEMANTIC_SEARCH_CCS_SUPPORT.toReleaseVersion() + + "." + ); + } else { + // Transport version indicates usage of the legacy interceptors. This simulates if one of the local cluster data nodes + // is out of date to the point that it can't use the new interceptors. + expectedRewriteException = new IllegalArgumentException( + inferenceFieldQuery.getName() + + " query does not support cross-cluster search when querying a [" + + SemanticTextFieldMapper.CONTENT_TYPE + + "] field in a mixed-version cluster. Please update all nodes to at least Elasticsearch " + + SEMANTIC_SEARCH_CCS_SUPPORT.toReleaseVersion() + + "." + ); + } + } + + assertRewriteAndSerializeOnInferenceField( + inferenceFieldQuery, + queryRewriteContext, + expectedRewriteException, + expectedSerializationException + ); + assertRewriteAndSerializeOnNonInferenceField(nonInferenceFieldQuery, queryRewriteContext); + } } public void testSerializationRemoteClusterInferenceResults() throws Exception { @@ -230,7 +303,7 @@ public void testSerializationRemoteClusterInferenceResults() throws Exception { // Test with a transport version prior to cluster alias support, which should fail TransportVersion transportVersion = TransportVersionUtils.randomVersionBetween( random(), - TransportVersions.NEW_SEMANTIC_QUERY_INTERCEPTORS, + NEW_SEMANTIC_QUERY_INTERCEPTORS, TransportVersionUtils.getPreviousVersion(INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS) ); IllegalArgumentException e = assertThrows( @@ -270,7 +343,7 @@ protected abstract void assertCoordinatorNodeRewriteOnInferenceField( QueryBuilder rewritten, TransportVersion transportVersion, QueryRewriteContext queryRewriteContext - ); + ) throws Exception; protected abstract void assertCoordinatorNodeRewriteOnNonInferenceField(QueryBuilder original, QueryBuilder rewritten); @@ -291,48 +364,28 @@ protected void serializationTestCase(TransportVersion transportVersion) throws E final QueryRewriteContext queryRewriteContext = createQueryRewriteContext( Map.of(testIndex1.name(), testIndex1.semanticTextFields(), testIndex2.name(), testIndex2.semanticTextFields()), Map.of(), - transportVersion + transportVersion, + null ); - // Disable query interception when checking the results of coordinator node rewrite so that the query rewrite context can be used - // to populate inference results without triggering another query interception. In production this is achieved by wrapping with - // InterceptedQueryBuilderWrapper, but we do not have access to that in this test. - final BiConsumer disableQueryInterception = (c, r) -> { - QueryRewriteInterceptor interceptor = c.getQueryRewriteInterceptor(); - c.setQueryRewriteInterceptor(null); - r.run(); - c.setQueryRewriteInterceptor(interceptor); - }; - // Query a semantic text field in both indices QueryBuilder originalSemantic = createQueryBuilder(semanticField); - QueryBuilder rewrittenSemantic = rewriteAndFetch(originalSemantic, queryRewriteContext); - QueryBuilder serializedSemantic = copyNamedWriteable(rewrittenSemantic, writableRegistry(), QueryBuilder.class); - disableQueryInterception.accept( - queryRewriteContext, - () -> assertCoordinatorNodeRewriteOnInferenceField(originalSemantic, serializedSemantic, transportVersion, queryRewriteContext) - ); + assertRewriteAndSerializeOnInferenceField(originalSemantic, queryRewriteContext, null, null); // Query a field that is a semantic text field in one index QueryBuilder originalMixed = createQueryBuilder(mixedField); - QueryBuilder rewrittenMixed = rewriteAndFetch(originalMixed, queryRewriteContext); - QueryBuilder serializedMixed = copyNamedWriteable(rewrittenMixed, writableRegistry(), QueryBuilder.class); - disableQueryInterception.accept( - queryRewriteContext, - () -> assertCoordinatorNodeRewriteOnInferenceField(originalMixed, serializedMixed, transportVersion, queryRewriteContext) - ); + assertRewriteAndSerializeOnInferenceField(originalMixed, queryRewriteContext, null, null); // Query a text field in both indices QueryBuilder originalText = createQueryBuilder(textField); - QueryBuilder rewrittenText = rewriteAndFetch(originalText, queryRewriteContext); - QueryBuilder serializedText = copyNamedWriteable(rewrittenText, writableRegistry(), QueryBuilder.class); - assertCoordinatorNodeRewriteOnNonInferenceField(originalText, serializedText); + assertRewriteAndSerializeOnNonInferenceField(originalText, queryRewriteContext); } protected QueryRewriteContext createQueryRewriteContext( Map> localIndexInferenceFields, Map remoteIndexNames, - TransportVersion minTransportVersion + TransportVersion minTransportVersion, + Boolean ccsMinimizeRoundTrips ) { Map indexMetadata = new HashMap<>(); for (var indexEntry : localIndexInferenceFields.entrySet()) { @@ -385,7 +438,7 @@ protected QueryRewriteContext createQueryRewriteContext( resolvedIndices, null, QueryRewriteInterceptor.multi(interceptorMap), - null + ccsMinimizeRoundTrips ); } @@ -465,12 +518,90 @@ protected QueryRewriteContext createIndexMetadataContext( } } + protected void assertRewriteAndSerializeOnInferenceField( + QueryBuilder originalQuery, + QueryRewriteContext queryRewriteContext, + Exception expectedRewriteException, + Exception expectedSerializationException + ) throws Exception { + if (expectedRewriteException != null) { + Exception actualException = assertThrows(Exception.class, () -> rewriteAndFetch(originalQuery, queryRewriteContext)); + assertThat(actualException, instanceOf(expectedRewriteException.getClass())); + assertThat(actualException.getMessage(), equalTo(expectedRewriteException.getMessage())); + return; + } + QueryBuilder rewrittenQuery = rewriteAndFetch(originalQuery, queryRewriteContext); + + TransportVersion serializationTransportVersion = queryRewriteContext.getMinTransportVersion(); + if (expectedSerializationException != null) { + Exception actualException = assertThrows( + Exception.class, + () -> copyNamedWriteable(rewrittenQuery, writableRegistry(), QueryBuilder.class, serializationTransportVersion) + ); + assertThat(actualException, instanceOf(expectedSerializationException.getClass())); + assertThat(actualException.getMessage(), equalTo(expectedSerializationException.getMessage())); + return; + } + QueryBuilder serializedQuery = copyNamedWriteable( + rewrittenQuery, + writableRegistry(), + QueryBuilder.class, + serializationTransportVersion + ); + + // Run the original query through a serialization cycle to account for any BwC logic applied through the transport version + QueryBuilder originalSerializedQuery = copyNamedWriteable( + originalQuery, + writableRegistry(), + QueryBuilder.class, + serializationTransportVersion + ); + + // Disable query interception when checking the results of coordinator node rewrite so that the query rewrite context can be used + // to populate inference results without triggering another query interception. In production this is achieved by wrapping with + // InterceptedQueryBuilderWrapper, but we do not have access to that in this test. + disableQueryInterception( + queryRewriteContext, + () -> assertCoordinatorNodeRewriteOnInferenceField( + originalSerializedQuery, + serializedQuery, + queryRewriteContext.getMinTransportVersion(), + queryRewriteContext + ) + ); + } + + protected void assertRewriteAndSerializeOnNonInferenceField(QueryBuilder originalQuery, QueryRewriteContext queryRewriteContext) + throws IOException { + TransportVersion serializationVersion = queryRewriteContext.getMinTransportVersion(); + + // Run the original query through a serialization cycle to account for any BwC logic applied through the transport version + QueryBuilder originalSerializedQuery = copyNamedWriteable( + originalQuery, + writableRegistry(), + QueryBuilder.class, + serializationVersion + ); + + QueryBuilder rewrittenQuery = rewriteAndFetch(originalQuery, queryRewriteContext); + QueryBuilder serializedQuery = copyNamedWriteable(rewrittenQuery, writableRegistry(), QueryBuilder.class, serializationVersion); + assertCoordinatorNodeRewriteOnNonInferenceField(originalSerializedQuery, serializedQuery); + } + protected static QueryBuilder rewriteAndFetch(QueryBuilder queryBuilder, QueryRewriteContext queryRewriteContext) { PlainActionFuture future = new PlainActionFuture<>(); Rewriteable.rewriteAndFetch(queryBuilder, queryRewriteContext, future); return future.actionGet(); } + protected static void disableQueryInterception(QueryRewriteContext queryRewriteContext, CheckedRunnable runnable) + throws Exception { + QueryRewriteInterceptor interceptor = queryRewriteContext.getQueryRewriteInterceptor(); + queryRewriteContext.setQueryRewriteInterceptor(null); + runnable.run(); + queryRewriteContext.setQueryRewriteInterceptor(interceptor); + } + private static ModelRegistry createModelRegistry(ThreadPool threadPool) { ClusterService clusterService = ClusterServiceUtils.createClusterService(threadPool); ModelRegistry modelRegistry = spy(new ModelRegistry(clusterService, new NoOpClient(threadPool))); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilderTests.java index dcbfff1ff99bf..329444595ef3e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilderTests.java @@ -61,10 +61,7 @@ protected InterceptedInferenceQueryBuilder createIntercep KnnVectorQueryBuilder originalQuery, Map inferenceResultsMap ) { - return new InterceptedInferenceKnnVectorQueryBuilder( - new InterceptedInferenceKnnVectorQueryBuilder(originalQuery), - inferenceResultsMap - ); + return new InterceptedInferenceKnnVectorQueryBuilder(originalQuery, inferenceResultsMap); } @Override @@ -160,7 +157,8 @@ public void testInterceptAndRewrite() throws Exception { final QueryRewriteContext queryRewriteContext = createQueryRewriteContext( Map.of(testIndex1.name(), testIndex1.semanticTextFields(), testIndex2.name(), testIndex2.semanticTextFields()), Map.of(), - TransportVersion.current() + TransportVersion.current(), + null ); QueryBuilder coordinatorRewritten = rewriteAndFetch(knnQuery, queryRewriteContext); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilderTests.java index abf12b2c876d4..c7a680f15d6d1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilderTests.java @@ -33,7 +33,7 @@ protected InterceptedInferenceQueryBuilder createInterceptedQ MatchQueryBuilder originalQuery, Map inferenceResultsMap ) { - return new InterceptedInferenceMatchQueryBuilder(new InterceptedInferenceMatchQueryBuilder(originalQuery), inferenceResultsMap); + return new InterceptedInferenceMatchQueryBuilder(originalQuery, inferenceResultsMap); } @Override @@ -52,7 +52,7 @@ protected void assertCoordinatorNodeRewriteOnInferenceField( QueryBuilder rewritten, TransportVersion transportVersion, QueryRewriteContext queryRewriteContext - ) { + ) throws Exception { assertThat(original, instanceOf(MatchQueryBuilder.class)); if (transportVersion.onOrAfter(TransportVersions.NEW_SEMANTIC_QUERY_INTERCEPTORS)) { assertThat(rewritten, instanceOf(InterceptedInferenceMatchQueryBuilder.class)); @@ -69,7 +69,16 @@ protected void assertCoordinatorNodeRewriteOnInferenceField( original ); QueryBuilder expectedLegacyRewritten = rewriteAndFetch(expectedLegacyIntercepted, queryRewriteContext); - assertThat(rewritten, equalTo(expectedLegacyRewritten)); + + // Run the expected query through a serialization cycle to align the inference results map representations + QueryBuilder expectedLegacySerialized = copyNamedWriteable( + expectedLegacyRewritten, + writableRegistry(), + QueryBuilder.class, + transportVersion + ); + + assertThat(rewritten, equalTo(expectedLegacySerialized)); } } @@ -98,7 +107,8 @@ public void testInterceptAndRewrite() throws Exception { testIndex3.semanticTextFields() ), Map.of(), - TransportVersion.current() + TransportVersion.current(), + null ); QueryBuilder coordinatorRewritten = rewriteAndFetch(matchQuery, queryRewriteContext); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilderTests.java index a0066f5da130d..9a44222b16cc3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilderTests.java @@ -61,10 +61,7 @@ protected InterceptedInferenceQueryBuilder createInter SparseVectorQueryBuilder originalQuery, Map inferenceResultsMap ) { - return new InterceptedInferenceSparseVectorQueryBuilder( - new InterceptedInferenceSparseVectorQueryBuilder(originalQuery), - inferenceResultsMap - ); + return new InterceptedInferenceSparseVectorQueryBuilder(originalQuery, inferenceResultsMap); } @Override @@ -133,7 +130,8 @@ public void testInterceptAndRewrite() throws Exception { final QueryRewriteContext queryRewriteContext = createQueryRewriteContext( Map.of(testIndex1.name(), testIndex1.semanticTextFields(), testIndex2.name(), testIndex2.semanticTextFields()), Map.of(), - TransportVersion.current() + TransportVersion.current(), + null ); QueryBuilder coordinatorRewritten = rewriteAndFetch(sparseVectorQuery, queryRewriteContext); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java index 05de3ba7d69ac..b2d7218720a57 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java @@ -95,6 +95,8 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD; import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS; import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS_TV; +import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.SEMANTIC_SEARCH_CCS_SUPPORT; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.notNullValue; @@ -496,6 +498,40 @@ public void testSerializationBwc() throws IOException { } } + public void testSerializationCcs() throws Exception { + SemanticQueryBuilder originalQuery = new SemanticQueryBuilder(randomAlphaOfLength(5), randomAlphaOfLength(5), null, Map.of(), true); + QueryBuilder deserializedQuery = copyNamedWriteable(originalQuery, namedWriteableRegistry(), QueryBuilder.class); + assertThat(deserializedQuery, equalTo(originalQuery)); + } + + public void testSerializationCcsBwc() throws Exception { + SemanticQueryBuilder originalQuery = new SemanticQueryBuilder(randomAlphaOfLength(5), randomAlphaOfLength(5), null, Map.of(), true); + + for (int i = 0; i < 100; i++) { + TransportVersion transportVersion = TransportVersionUtils.randomVersionBetween( + random(), + originalQuery.getMinimalSupportedVersion(), + TransportVersionUtils.getPreviousVersion(TransportVersion.current()) + ); + + if (transportVersion.supports(SEMANTIC_SEARCH_CCS_SUPPORT)) { + QueryBuilder deserializedQuery = copyNamedWriteable( + originalQuery, + namedWriteableRegistry(), + QueryBuilder.class, + transportVersion + ); + assertThat(deserializedQuery, equalTo(originalQuery)); + } else { + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> copyNamedWriteable(originalQuery, namedWriteableRegistry(), QueryBuilder.class, transportVersion) + ); + assertThat(e.getMessage(), containsString("One or more nodes does not support semantic query cross-cluster search")); + } + } + } + public void testToXContent() throws IOException { QueryBuilder queryBuilder = new SemanticQueryBuilder("foo", "bar"); checkGeneratedJson("""