diff --git a/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java b/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java index d15140be0c8a0..48cb9893c4f31 100644 --- a/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java +++ b/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java @@ -12,15 +12,18 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ResolvedIndices; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.client.internal.RemoteClusterClient; import org.elasticsearch.cluster.metadata.DataStream; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.routing.allocation.DataTier; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.TriConsumer; import org.elasticsearch.common.collect.Iterators; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.regex.Regex; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.CountDown; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.Nullable; import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexSettings; @@ -36,11 +39,13 @@ import org.elasticsearch.search.aggregations.support.ValuesSourceRegistry; import org.elasticsearch.search.builder.PointInTimeBuilder; import org.elasticsearch.transport.RemoteClusterAware; +import org.elasticsearch.transport.RemoteClusterService; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -52,6 +57,8 @@ import java.util.function.Predicate; import java.util.stream.Collectors; +import static org.elasticsearch.threadpool.ThreadPool.Names.SEARCH_COORDINATION; + /** * Context object used to rewrite {@link QueryBuilder} instances into simplified version. */ @@ -72,6 +79,8 @@ public class QueryRewriteContext { protected final Client client; protected final LongSupplier nowInMillis; private final List>> asyncActions = new ArrayList<>(); + private final Map>>> remoteAsyncActions = + new HashMap<>(); protected boolean allowUnmappedFields; protected boolean mapUnmappedFieldAsString; protected Predicate allowedFields; @@ -357,11 +366,22 @@ public void registerAsyncAction(BiConsumer> asyncActio asyncActions.add(asyncAction); } + public void registerRemoteAsyncAction( + String clusterAlias, + TriConsumer> asyncAction + ) { + List>> asyncActions = remoteAsyncActions.computeIfAbsent( + clusterAlias, + k -> new ArrayList<>() + ); + asyncActions.add(asyncAction); + } + /** * Returns true if there are any registered async actions. */ public boolean hasAsyncActions() { - return asyncActions.isEmpty() == false; + return asyncActions.isEmpty() == false || remoteAsyncActions.isEmpty() == false; } /** @@ -369,10 +389,15 @@ public boolean hasAsyncActions() { * null. The list of registered actions is cleared once this method returns. */ public void executeAsyncActions(ActionListener listener) { - if (asyncActions.isEmpty()) { + if (asyncActions.isEmpty() && remoteAsyncActions.isEmpty()) { listener.onResponse(null); } else { - CountDown countDown = new CountDown(asyncActions.size()); + int actionCount = asyncActions.size(); + for (var actionList : remoteAsyncActions.values()) { + actionCount += actionList.size(); + } + + CountDown countDown = new CountDown(actionCount); ActionListener internalListener = new ActionListener<>() { @Override public void onResponse(Object o) { @@ -388,12 +413,30 @@ public void onFailure(Exception e) { } } }; + // make a copy to prevent concurrent modification exception List>> biConsumers = new ArrayList<>(asyncActions); asyncActions.clear(); for (BiConsumer> action : biConsumers) { action.accept(client, internalListener); } + + var remoteAsyncActionsCopy = new HashMap<>(remoteAsyncActions); + remoteAsyncActions.clear(); + for (var entry : remoteAsyncActionsCopy.entrySet()) { + String clusterAlias = entry.getKey(); + List>> remoteTriConsumers = entry.getValue(); + + RemoteClusterClient remoteClient = client.getRemoteClusterClient( + clusterAlias, + client.threadPool().executor(SEARCH_COORDINATION), + RemoteClusterService.DisconnectedStrategy.RECONNECT_UNLESS_SKIP_UNAVAILABLE + ); + ThreadContext threadContext = client.threadPool().getThreadContext(); + for (var action : remoteTriConsumers) { + action.apply(remoteClient, threadContext, internalListener); + } + } } } diff --git a/test/framework/src/main/java/org/elasticsearch/test/AbstractMultiClustersTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/AbstractMultiClustersTestCase.java index e918f05e50816..b9162487a528b 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/AbstractMultiClustersTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/AbstractMultiClustersTestCase.java @@ -15,6 +15,7 @@ import org.elasticsearch.action.admin.cluster.remote.TransportRemoteInfoAction; import org.elasticsearch.action.admin.cluster.settings.ClusterUpdateSettingsResponse; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.client.internal.OriginSettingClient; import org.elasticsearch.common.network.NetworkModule; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; @@ -94,6 +95,24 @@ protected boolean reuseClusters() { return true; } + protected NodeConfigurationSource nodeConfigurationSource() { + return null; + } + + protected String internalClientOrigin() { + return null; + } + + private Client internalClient() { + return internalClient(LOCAL_CLUSTER); + } + + private Client internalClient(String clusterAlias) { + String internalClientOrigin = internalClientOrigin(); + Client client = client(clusterAlias); + return internalClientOrigin != null ? new OriginSettingClient(client, internalClientOrigin) : client; + } + @Before public final void startClusters() throws Exception { if (clusterGroup != null && reuseClusters()) { @@ -129,7 +148,7 @@ public final void startClusters() throws Exception { mockPlugins, Function.identity(), TEST_ENTITLEMENTS::addEntitledNodePaths - ); + ).internalClientOrigin(internalClientOrigin()); try { cluster.beforeTest(random()); } catch (Exception e) { @@ -170,7 +189,11 @@ protected void removeRemoteCluster(String clusterAlias) throws Exception { settings.putNull("cluster.remote." + clusterAlias + ".seeds"); settings.putNull("cluster.remote." + clusterAlias + ".mode"); settings.putNull("cluster.remote." + clusterAlias + ".proxy_address"); - client().admin().cluster().prepareUpdateSettings(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT).setPersistentSettings(settings).get(); + internalClient().admin() + .cluster() + .prepareUpdateSettings(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT) + .setPersistentSettings(settings) + .get(); assertBusy(() -> { for (TransportService transportService : cluster(LOCAL_CLUSTER).getInstances(TransportService.class)) { assertThat(transportService.getRemoteClusterService().getRegisteredRemoteClusterNames(), not(contains(clusterAlias))); @@ -222,7 +245,7 @@ protected void configureRemoteClusterWithSeedAddresses(String clusterAlias, Coll } builder.build(); - ClusterUpdateSettingsResponse resp = client().admin() + ClusterUpdateSettingsResponse resp = internalClient().admin() .cluster() .prepareUpdateSettings(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT) .setPersistentSettings(settings) @@ -233,7 +256,10 @@ protected void configureRemoteClusterWithSeedAddresses(String clusterAlias, Coll } assertBusy(() -> { - List remoteConnectionInfos = client().execute(TransportRemoteInfoAction.TYPE, new RemoteInfoRequest()) + List remoteConnectionInfos = internalClient().execute( + TransportRemoteInfoAction.TYPE, + new RemoteInfoRequest() + ) .actionGet() .getInfos() .stream() @@ -265,27 +291,40 @@ public void close() throws IOException { } } - static NodeConfigurationSource nodeConfigurationSource(Settings nodeSettings, Collection> nodePlugins) { + private NodeConfigurationSource nodeConfigurationSource(Settings nodeSettings, Collection> nodePlugins) { final Settings.Builder builder = Settings.builder(); builder.putList(DISCOVERY_SEED_HOSTS_SETTING.getKey()); // empty list disables a port scan for other nodes builder.putList(DISCOVERY_SEED_PROVIDERS_SETTING.getKey(), "file"); builder.put(NetworkModule.TRANSPORT_TYPE_KEY, getTestTransportType()); - builder.put(nodeSettings); + NodeConfigurationSource nodeConfigurationSource = nodeConfigurationSource(); return new NodeConfigurationSource() { @Override public Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { + if (nodeConfigurationSource != null) { + builder.put(nodeConfigurationSource.nodeSettings(nodeOrdinal, otherSettings)); + } + builder.put(nodeSettings); + return builder.build(); } @Override public Path nodeConfigPath(int nodeOrdinal) { - return null; + return nodeConfigurationSource != null ? nodeConfigurationSource.nodeConfigPath(nodeOrdinal) : null; } @Override public Collection> nodePlugins() { - return nodePlugins; + Collection> plugins; + if (nodeConfigurationSource != null) { + plugins = new ArrayList<>(nodeConfigurationSource.nodePlugins()); + plugins.addAll(nodePlugins); + } else { + plugins = nodePlugins; + } + + return plugins; } }; } diff --git a/test/framework/src/main/java/org/elasticsearch/test/ExternalTestCluster.java b/test/framework/src/main/java/org/elasticsearch/test/ExternalTestCluster.java index d025b3bcc4a95..c62ac7e50c180 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ExternalTestCluster.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ExternalTestCluster.java @@ -161,6 +161,11 @@ public Client client() { return client; } + @Override + protected Client internalClient() { + return client; + } + @Override public int size() { return httpAddresses.length; @@ -189,7 +194,13 @@ public void close() throws IOException { @Override public void ensureEstimatedStats() { if (size() > 0) { - NodesStatsResponse nodeStats = client().admin().cluster().prepareNodesStats().clear().setBreaker(true).setIndices(true).get(); + NodesStatsResponse nodeStats = internalClient().admin() + .cluster() + .prepareNodesStats() + .clear() + .setBreaker(true) + .setIndices(true) + .get(); for (NodeStats stats : nodeStats.getNodes()) { assertThat( "Fielddata breaker not reset to 0 on node: " + stats.getNode(), diff --git a/test/framework/src/main/java/org/elasticsearch/test/InternalTestCluster.java b/test/framework/src/main/java/org/elasticsearch/test/InternalTestCluster.java index 59bf3fddf13ba..bf98180339ccf 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/InternalTestCluster.java +++ b/test/framework/src/main/java/org/elasticsearch/test/InternalTestCluster.java @@ -31,6 +31,7 @@ import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.action.support.replication.TransportReplicationAction; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.client.internal.OriginSettingClient; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.NodeConnectionsService; @@ -278,6 +279,8 @@ public String toString() { private final int numDataPaths; + private String internalClientOrigin = null; + /** * All nodes started by the cluster will have their name set to nodePrefix followed by a positive number */ @@ -754,6 +757,11 @@ public synchronized void ensureAtMostNumDataNodes(int n) throws IOException { } } + public InternalTestCluster internalClientOrigin(String origin) { + this.internalClientOrigin = origin; + return this; + } + private Settings getNodeSettings(final int nodeId, final long seed, final Settings extraSettings) { final Settings settings = getSettings(nodeId, seed, extraSettings); @@ -874,6 +882,20 @@ public Client client() { return c.client(); } + @Override + protected Client internalClient() { + return internalClient(null); + } + + private Client internalClient(@Nullable String nodeName) { + Client client = nodeName != null ? client(nodeName) : client(); + return makeInternal(client); + } + + private Client makeInternal(Client client) { + return internalClientOrigin != null ? new OriginSettingClient(client, internalClientOrigin) : client; + } + /** * Returns a node client to a data node in the cluster. * Note: use this with care tests should not rely on a certain nodes client. @@ -1282,7 +1304,7 @@ public synchronized void validateClusterFormed() { try { assertBusy(() -> { try { - final boolean timeout = client().admin() + final boolean timeout = internalClient().admin() .cluster() .prepareHealth(TEST_REQUEST_TIMEOUT) .setWaitForEvents(Priority.LANGUID) @@ -1552,7 +1574,7 @@ public void assertSeqNos() throws Exception { */ public void assertSameDocIdsOnShards() throws Exception { assertBusy(() -> { - ClusterState state = client().admin().cluster().prepareState(TEST_REQUEST_TIMEOUT).get().getState(); + ClusterState state = internalClient().admin().cluster().prepareState(TEST_REQUEST_TIMEOUT).get().getState(); for (var indexRoutingTable : state.routingTable().indicesRouting().values()) { for (int i = 0; i < indexRoutingTable.size(); i++) { IndexShardRoutingTable indexShardRoutingTable = indexRoutingTable.shard(i); @@ -1999,7 +2021,7 @@ private Set excludeMasters(Collection nodeAndClients) { logger.info("adding voting config exclusions {} prior to restart/shutdown", excludedNodeNames); try { - client().execute( + internalClient().execute( TransportAddVotingConfigExclusionsAction.TYPE, new AddVotingConfigExclusionsRequest(TEST_REQUEST_TIMEOUT, excludedNodeNames.toArray(Strings.EMPTY_ARRAY)) ).get(); @@ -2016,7 +2038,7 @@ private void removeExclusions(Set excludedNodeIds) { if (autoManageVotingExclusions && excludedNodeIds.isEmpty() == false) { logger.info("removing voting config exclusions for {} after restart/shutdown", excludedNodeIds); try { - Client client = getRandomNodeAndClient(node -> excludedNodeIds.contains(node.name) == false).client(); + Client client = makeInternal(getRandomNodeAndClient(node -> excludedNodeIds.contains(node.name) == false).client()); client.execute( TransportClearVotingConfigExclusionsAction.TYPE, new ClearVotingConfigExclusionsRequest(TEST_REQUEST_TIMEOUT) @@ -2080,7 +2102,7 @@ public String getMasterName(@Nullable String viaNode) { } try { ClusterServiceUtils.awaitClusterState(state -> state.nodes().getMasterNode() != null, clusterService(viaNode)); - final ClusterState state = client(viaNode).admin() + final ClusterState state = internalClient(viaNode).admin() .cluster() .prepareState(TEST_REQUEST_TIMEOUT) .clear() diff --git a/test/framework/src/main/java/org/elasticsearch/test/TestCluster.java b/test/framework/src/main/java/org/elasticsearch/test/TestCluster.java index 4f4d933162733..190e8d90d4f9d 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/TestCluster.java +++ b/test/framework/src/main/java/org/elasticsearch/test/TestCluster.java @@ -85,7 +85,7 @@ public void wipe(Set excludeTemplates) { SubscribableListener .newForked( - l -> client().execute( + l -> internalClient().execute( DeleteDataStreamAction.INSTANCE, new DeleteDataStreamAction.Request(TEST_REQUEST_TIMEOUT, "*").indicesOptions( IndicesOptions.LENIENT_EXPAND_OPEN_CLOSED_HIDDEN @@ -116,7 +116,7 @@ public void wipe(Set excludeTemplates) { private void deleteTemplates(Set excludeTemplates, ActionListener listener) { final SubscribableListener getComposableTemplates = SubscribableListener.newForked( - l -> client().execute( + l -> internalClient().execute( GetComposableIndexTemplateAction.INSTANCE, new GetComposableIndexTemplateAction.Request(TEST_REQUEST_TIMEOUT, "*"), l @@ -124,7 +124,11 @@ private void deleteTemplates(Set excludeTemplates, ActionListener ); final SubscribableListener getComponentTemplates = SubscribableListener.newForked( - l -> client().execute(GetComponentTemplateAction.INSTANCE, new GetComponentTemplateAction.Request(TEST_REQUEST_TIMEOUT, "*"), l) + l -> internalClient().execute( + GetComponentTemplateAction.INSTANCE, + new GetComponentTemplateAction.Request(TEST_REQUEST_TIMEOUT, "*"), + l + ) ); SubscribableListener @@ -143,7 +147,7 @@ private void deleteTemplates(Set excludeTemplates, ActionListener l.onResponse(AcknowledgedResponse.TRUE); } else { var request = new TransportDeleteComposableIndexTemplateAction.Request(templates); - client().execute(TransportDeleteComposableIndexTemplateAction.TYPE, request, l); + internalClient().execute(TransportDeleteComposableIndexTemplateAction.TYPE, request, l); } }) .andThenAccept(ElasticsearchAssertions::assertAcked) @@ -159,7 +163,7 @@ private void deleteTemplates(Set excludeTemplates, ActionListener if (componentTemplates.length == 0) { l.onResponse(AcknowledgedResponse.TRUE); } else { - client().execute( + internalClient().execute( TransportDeleteComponentTemplateAction.TYPE, new TransportDeleteComponentTemplateAction.Request(componentTemplates), l @@ -194,6 +198,11 @@ public void assertAfterTest() throws Exception { */ public abstract Client client(); + /** + * Returns a client connected to any node in the cluster that is authorized to perform cluster management actions + */ + protected abstract Client internalClient(); + /** * Returns the number of nodes in the cluster. */ @@ -236,7 +245,7 @@ private void wipeIndicesAsync(String[] indices, ActionListener listener) { assert indices != null && indices.length > 0; logger.info("---- wiping indices [{}]", Strings.arrayToCommaDelimitedString(indices)); SubscribableListener.newForked( - l -> client().admin() + l -> internalClient().admin() .indices() .prepareDelete(indices) .setIndicesOptions( @@ -260,14 +269,14 @@ private void handleWipeIndicesFailure(Exception exception, boolean wipingAllIndi logger.info("---- retry wiping indices using their concrete names", exception); SubscribableListener - .newForked(l -> client().admin().cluster().prepareState(TEST_REQUEST_TIMEOUT).execute(l)) + .newForked(l -> internalClient().admin().cluster().prepareState(TEST_REQUEST_TIMEOUT).execute(l)) .andThen((l, clusterStateResponse) -> { ArrayList concreteIndices = new ArrayList<>(); for (IndexMetadata indexMetadata : clusterStateResponse.getState().metadata().getProject()) { concreteIndices.add(indexMetadata.getIndex().getName()); } if (concreteIndices.isEmpty() == false) { - client().admin().indices().prepareDelete(concreteIndices.toArray(Strings.EMPTY_ARRAY)).execute(l); + internalClient().admin().indices().prepareDelete(concreteIndices.toArray(Strings.EMPTY_ARRAY)).execute(l); } else { l.onResponse(AcknowledgedResponse.TRUE); } @@ -296,13 +305,15 @@ private void handleWipeIndicesFailure(Exception exception, boolean wipingAllIndi private void wipeAllTemplates(Set exclude, RefCountingListener listeners) { SubscribableListener - .newForked(l -> client().admin().indices().prepareGetTemplates(TEST_REQUEST_TIMEOUT).execute(l)) + .newForked( + l -> internalClient().admin().indices().prepareGetTemplates(TEST_REQUEST_TIMEOUT).execute(l) + ) .andThenAccept(response -> { for (IndexTemplateMetadata indexTemplate : response.getIndexTemplates()) { if (exclude.contains(indexTemplate.getName())) { continue; } - client().admin() + internalClient().admin() .indices() .prepareDeleteTemplate(indexTemplate.getName()) .execute(listeners.acquire(ElasticsearchAssertions::assertAcked).delegateResponse((l, e) -> { @@ -330,7 +341,7 @@ public void wipeTemplates(String... templates) { } for (String template : templates) { try { - client().admin().indices().prepareDeleteTemplate(template).get(); + internalClient().admin().indices().prepareDeleteTemplate(template).get(); } catch (IndexTemplateMissingException e) { // ignore } @@ -345,7 +356,7 @@ private void wipeRepositories(ActionListener listener) { SubscribableListener .newForked( - l -> client().admin() + l -> internalClient().admin() .cluster() .prepareDeleteRepository(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT, "*") .execute(l.delegateResponse((ll, e) -> { diff --git a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/QueryRewriteContextMultiClustersIT.java b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/QueryRewriteContextMultiClustersIT.java new file mode 100644 index 0000000000000..01c1d78425a95 --- /dev/null +++ b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/QueryRewriteContextMultiClustersIT.java @@ -0,0 +1,601 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.integration; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Query; +import org.elasticsearch.ElasticsearchSecurityException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionFuture; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.OriginalIndices; +import org.elasticsearch.action.RemoteClusterActionType; +import org.elasticsearch.action.ResolvedIndices; +import org.elasticsearch.action.admin.indices.delete.DeleteIndexRequest; +import org.elasticsearch.action.admin.indices.get.GetIndexRequest; +import org.elasticsearch.action.admin.indices.get.GetIndexResponse; +import org.elasticsearch.action.search.SearchRequestBuilder; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.GroupedActionListener; +import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.client.internal.OriginSettingClient; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.query.AbstractQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryRewriteContext; +import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.injection.guice.Inject; +import org.elasticsearch.plugins.ActionPlugin; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.plugins.SearchPlugin; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.test.AbstractMultiClustersTestCase; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.NodeConfigurationSource; +import org.elasticsearch.test.SecuritySettingsSource; +import org.elasticsearch.test.SecuritySettingsSourceField; +import org.elasticsearch.transport.NoSuchRemoteClusterException; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.stream.Collectors; + +import static org.elasticsearch.test.SecuritySettingsSource.TEST_USER_NAME; +import static org.elasticsearch.test.SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse; +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; +import static org.elasticsearch.xpack.core.ClientHelper.MONITORING_ORIGIN; +import static org.elasticsearch.xpack.core.ClientHelper.SECURITY_ORIGIN; +import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; +import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.BASIC_AUTH_HEADER; +import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue; +import static org.elasticsearch.xpack.security.support.SecuritySystemIndices.SECURITY_MAIN_ALIAS; +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; + +@ESTestCase.WithoutEntitlements +public class QueryRewriteContextMultiClustersIT extends AbstractMultiClustersTestCase { + private static final String REMOTE_CLUSTER_A = "cluster-a"; + private static final String REMOTE_CLUSTER_B = "cluster-b"; + + private static final String INDEX_1 = "index-1"; + private static final String INDEX_2 = "index-2"; + + private static final ConcurrentHashMap INSTRUMENTED_ACTION_CALL_MAP = new ConcurrentHashMap<>(); + + private final boolean securityEnabled; + + @ParametersFactory + public static Iterable parameters() { + return List.of(new Object[] { true }, new Object[] { false }); + } + + @Override + protected List remoteClusterAlias() { + return List.of(REMOTE_CLUSTER_A, REMOTE_CLUSTER_B); + } + + @Override + protected boolean reuseClusters() { + return false; + } + + @Override + protected Map skipUnavailableForRemoteClusters() { + return Map.of(REMOTE_CLUSTER_A, true, REMOTE_CLUSTER_B, false); + } + + @Override + protected Collection> nodePlugins(String clusterAlias) { + return List.of(TestPlugin.class); + } + + @Override + protected NodeConfigurationSource nodeConfigurationSource() { + return securityEnabled ? new CustomSecuritySettingsSource(false, createTempDir(), ESIntegTestCase.Scope.TEST) : null; + } + + @Override + protected String internalClientOrigin() { + return MONITORING_ORIGIN; + } + + @Before + @Override + public void setUp() throws Exception { + super.setUp(); + INSTRUMENTED_ACTION_CALL_MAP.clear(); + setupClusters(); + } + + @After + public void cleanupSecurityIndex() { + if (securityEnabled) { + deleteSecurityIndex(LOCAL_CLUSTER); + for (String clusterAlias : remoteClusterAlias()) { + deleteSecurityIndex(clusterAlias); + } + } + } + + public QueryRewriteContextMultiClustersIT(boolean securityEnabled) { + this.securityEnabled = securityEnabled; + } + + public void testCallRemoteAsyncActionWithOrigin() { + SearchRequestBuilder allClustersAllIndicesRequest = buildSearchRequest( + List.of(INDEX_1, INDEX_2), + List.of(REMOTE_CLUSTER_A, REMOTE_CLUSTER_B), + ML_ORIGIN + ); + assertSearchResponse(allClustersAllIndicesRequest); + assertInstrumentedActionCalls(2, 2); + + SearchRequestBuilder allClustersSingleIndexRequest = buildSearchRequest( + List.of(INDEX_1), + List.of(REMOTE_CLUSTER_A, REMOTE_CLUSTER_B), + ML_ORIGIN + ); + assertSearchResponse(allClustersSingleIndexRequest); + assertInstrumentedActionCalls(3, 3); + + SearchRequestBuilder singleClusterSingleIndexRequest = buildSearchRequest(List.of(INDEX_1), List.of(REMOTE_CLUSTER_A), ML_ORIGIN); + assertSearchResponse(singleClusterSingleIndexRequest); + assertInstrumentedActionCalls(4, 3); + } + + public void testCallRemoteAsyncActionWithoutOrigin() { + Consumer assertSecurityEnabled = r -> { + assertSearchFailure( + r, + ElasticsearchSecurityException.class, + "action [cluster:internal/test/instrumented] is unauthorized for user [test_user] with effective roles [user]" + ); + assertInstrumentedActionCalls(0, 0); + }; + + SearchRequestBuilder allClustersAllIndicesRequest = buildSearchRequest( + List.of(INDEX_1, INDEX_2), + List.of(REMOTE_CLUSTER_A, REMOTE_CLUSTER_B), + null + ); + if (securityEnabled) { + assertSecurityEnabled.accept(allClustersAllIndicesRequest); + } else { + assertSearchResponse(allClustersAllIndicesRequest); + assertInstrumentedActionCalls(2, 2); + } + + SearchRequestBuilder allClustersSingleIndexRequest = buildSearchRequest( + List.of(INDEX_1), + List.of(REMOTE_CLUSTER_A, REMOTE_CLUSTER_B), + null + ); + if (securityEnabled) { + assertSecurityEnabled.accept(allClustersSingleIndexRequest); + } else { + assertSearchResponse(allClustersSingleIndexRequest); + assertInstrumentedActionCalls(3, 3); + } + + SearchRequestBuilder singleClusterSingleIndexRequest = buildSearchRequest(List.of(INDEX_1), List.of(REMOTE_CLUSTER_A), null); + if (securityEnabled) { + assertSecurityEnabled.accept(singleClusterSingleIndexRequest); + } else { + assertSearchResponse(singleClusterSingleIndexRequest); + assertInstrumentedActionCalls(4, 3); + } + } + + public void testInvalidClusterAlias() { + // TODO: Enable this test for all cases when bug is fixed + assumeFalse("Test is currently broken when security is enabled due to unrelated bug", securityEnabled); + SearchRequestBuilder request = buildSearchRequest( + List.of(INDEX_1, INDEX_2), + List.of(REMOTE_CLUSTER_A, REMOTE_CLUSTER_B, "missing-cluster-alias"), + null + ); + assertSearchFailure(request, NoSuchRemoteClusterException.class, "no such remote cluster: [missing-cluster-alias]"); + assertInstrumentedActionCalls(0, 0); + } + + private void setupClusters() { + setupCluster(LOCAL_CLUSTER); + setupCluster(REMOTE_CLUSTER_A); + setupCluster(REMOTE_CLUSTER_B); + } + + private void setupCluster(String clusterAlias) { + final Client client = client(clusterAlias); + assertAcked(client.admin().indices().prepareCreate(INDEX_1)); + assertAcked(client.admin().indices().prepareCreate(INDEX_2)); + } + + private void deleteSecurityIndex(String clusterAlias) { + final Client client = new OriginSettingClient(client(clusterAlias), SECURITY_ORIGIN); + + GetIndexRequest getIndexRequest = new GetIndexRequest(TEST_REQUEST_TIMEOUT); + getIndexRequest.indices(SECURITY_MAIN_ALIAS); + getIndexRequest.indicesOptions(IndicesOptions.lenientExpandOpen()); + GetIndexResponse getIndexResponse = client.admin().indices().getIndex(getIndexRequest).actionGet(TEST_REQUEST_TIMEOUT); + + if (getIndexResponse.getIndices().length > 0) { + DeleteIndexRequest deleteIndexRequest = new DeleteIndexRequest(getIndexResponse.getIndices()); + assertAcked(client.admin().indices().delete(deleteIndexRequest).actionGet(TEST_REQUEST_TIMEOUT)); + } + } + + private SearchRequestBuilder buildSearchRequest(List indices, List clusterAliases, @Nullable String origin) { + Client client = client(); + if (securityEnabled) { + client = client.filterWithHeader(Map.of(BASIC_AUTH_HEADER, basicAuthHeaderValue(TEST_USER_NAME, TEST_PASSWORD_SECURE_STRING))); + } + + return client.prepareSearch(generateFullyQualifiedIndices(indices, clusterAliases)).setQuery(new TestQueryBuilder(origin)); + } + + private static String[] generateFullyQualifiedIndices(List indices, List clusterAliases) { + String[] fullyQualifiedIndices = new String[indices.size() * clusterAliases.size()]; + + int idx = 0; + for (String clusterAlias : clusterAliases) { + for (String index : indices) { + StringBuilder fullyQualifiedIndex = new StringBuilder(); + if (LOCAL_CLUSTER.equals(clusterAlias) == false) { + fullyQualifiedIndex.append(clusterAlias); + fullyQualifiedIndex.append(":"); + } + fullyQualifiedIndex.append(index); + + fullyQualifiedIndices[idx++] = fullyQualifiedIndex.toString(); + } + } + + return fullyQualifiedIndices; + } + + private static void assertSearchResponse(SearchRequestBuilder searchRequest) { + assertResponse(searchRequest, response -> { + assertThat(response.getHits().getTotalHits().value(), equalTo(0L)); + assertThat(response.getSuccessfulShards(), equalTo(response.getTotalShards())); + assertThat( + response.getClusters().getClusterStateCount(SearchResponse.Cluster.Status.SUCCESSFUL), + equalTo(response.getClusters().getTotal()) + ); + }); + } + + private static void assertSearchFailure( + SearchRequestBuilder searchRequest, + Class expectedExceptionClass, + String expectedMessage + ) { + T actualException = assertThrows(expectedExceptionClass, () -> assertResponse(searchRequest, response -> {})); + assertThat(actualException.getMessage(), containsString(expectedMessage)); + } + + private static void assertInstrumentedActionCalls(int expectedClusterACalls, int expectedClusterBCalls) { + Map expected = new HashMap<>(); + if (expectedClusterACalls > 0) { + expected.put(REMOTE_CLUSTER_A, expectedClusterACalls); + } + if (expectedClusterBCalls > 0) { + expected.put(REMOTE_CLUSTER_B, expectedClusterBCalls); + } + + Map actual = INSTRUMENTED_ACTION_CALL_MAP.entrySet() + .stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().get())); + + assertThat(actual, equalTo(expected)); + } + + private static class TestQueryBuilder extends AbstractQueryBuilder { + private static final String NAME = "test"; + + private final String origin; + private final Boolean actionsAcknowledged; + private final ActionFuture actionsAcknowledgedSupplier; + + private static TestQueryBuilder fromXContent(XContentParser parser) { + return new TestQueryBuilder(); + } + + private TestQueryBuilder() { + this((String) null); + } + + private TestQueryBuilder(@Nullable String origin) { + this.origin = origin; + this.actionsAcknowledged = null; + this.actionsAcknowledgedSupplier = null; + } + + private TestQueryBuilder(StreamInput in) throws IOException { + super(in); + this.origin = in.readOptionalString(); + this.actionsAcknowledged = in.readOptionalBoolean(); + this.actionsAcknowledgedSupplier = null; + } + + private TestQueryBuilder(TestQueryBuilder other, Boolean actionsAcknowledged, ActionFuture actionsAcknowledgedSupplier) { + this.origin = other.origin; + this.actionsAcknowledged = actionsAcknowledged; + this.actionsAcknowledgedSupplier = actionsAcknowledgedSupplier; + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + if (actionsAcknowledgedSupplier != null) { + throw new IllegalStateException( + "actionsAcknowledgedSupplier must be null, can't serialize suppliers, missing a rewriteAndFetch?" + ); + } + + out.writeOptionalString(origin); + out.writeOptionalBoolean(this.actionsAcknowledged); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(NAME); + builder.endObject(); + } + + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + ResolvedIndices resolvedIndices = queryRewriteContext.getResolvedIndices(); + if (resolvedIndices != null) { + TestQueryBuilder rewritten = this; + + if (actionsAcknowledgedSupplier != null) { + Boolean actionsAcknowledged = actionsAcknowledgedSupplier.isDone() ? actionsAcknowledgedSupplier.actionGet() : null; + if (actionsAcknowledged != null) { + rewritten = new TestQueryBuilder(this, actionsAcknowledged, null); + } + } else if (actionsAcknowledged == null) { + ActionFuture actionsAcknowledgedSupplier = registerActions(queryRewriteContext, origin); + rewritten = new TestQueryBuilder(this, null, actionsAcknowledgedSupplier); + } + + return rewritten; + } + + return this; + } + + @Override + protected Query doToQuery(SearchExecutionContext context) { + assertThat(actionsAcknowledged, is(true)); + return new MatchNoDocsQuery(); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.current(); + } + + @Override + protected boolean doEquals(TestQueryBuilder other) { + return Objects.equals(this.origin, other.origin) + && Objects.equals(this.actionsAcknowledged, other.actionsAcknowledged) + && Objects.equals(this.actionsAcknowledgedSupplier, other.actionsAcknowledgedSupplier); + } + + @Override + protected int doHashCode() { + return Objects.hash(origin, actionsAcknowledged, actionsAcknowledgedSupplier); + } + + private static ActionFuture registerActions(QueryRewriteContext queryRewriteContext, String origin) { + var remoteClusterIndices = queryRewriteContext.getResolvedIndices().getRemoteClusterIndices(); + + int requestCount = 0; + Map> clusterRequestMap = new HashMap<>(); + for (var entry : remoteClusterIndices.entrySet()) { + String clusterAlias = entry.getKey(); + OriginalIndices originalIndices = entry.getValue(); + + int indicesCount = originalIndices.indices().length; + List clusterRequestList = new ArrayList<>(indicesCount); + for (int i = 0; i < indicesCount; i++) { + clusterRequestList.add(new InstrumentedAction.Request()); + } + + requestCount += indicesCount; + clusterRequestMap.put(clusterAlias, clusterRequestList); + } + + PlainActionFuture actionsAcknowledgedSupplier = new PlainActionFuture<>(); + GroupedActionListener gal = new GroupedActionListener<>( + requestCount, + ActionListener.wrap(c -> actionsAcknowledgedSupplier.onResponse(true), actionsAcknowledgedSupplier::onFailure) + ); + + for (var entry : clusterRequestMap.entrySet()) { + String clusterAlias = entry.getKey(); + List clusterRequestList = clusterRequestMap.get(clusterAlias); + + for (InstrumentedAction.Request clusterRequest : clusterRequestList) { + queryRewriteContext.registerRemoteAsyncAction(clusterAlias, (client, threadContext, listener) -> { + ActionListener wrappedListener = listener.delegateFailureAndWrap((l, r) -> { + if (r.isAcknowledged()) { + gal.onResponse(null); + l.onResponse(null); + } else { + l.onFailure(new IllegalStateException("Unacknowledged response from cluster [" + clusterAlias + "]")); + } + }); + BiConsumer> requestConsumer = ( + r, + l) -> client.execute(InstrumentedAction.REMOTE_TYPE, r, l); + + if (origin != null) { + executeAsyncWithOrigin(threadContext, origin, clusterRequest, wrappedListener, requestConsumer); + } else { + requestConsumer.accept(clusterRequest, wrappedListener); + } + }); + } + } + + return actionsAcknowledgedSupplier; + } + } + + public static class InstrumentedAction extends ActionType { + private static final InstrumentedAction INSTANCE = new InstrumentedAction(); + private static final RemoteClusterActionType REMOTE_TYPE = new RemoteClusterActionType<>(INSTANCE.name(), Response::new); + + private static final String NAME = "cluster:internal/test/instrumented"; + + private InstrumentedAction() { + super(NAME); + } + + public static class Request extends ActionRequest { + public Request() {} + + public Request(StreamInput in) throws IOException { + super(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + return true; + } + + @Override + public int hashCode() { + return 0; + } + } + + public static class Response extends AcknowledgedResponse { + public Response() { + super(true); + } + + public Response(StreamInput in) throws IOException { + super(in); + } + } + } + + public static class TransportInstrumentedAction extends HandledTransportAction< + InstrumentedAction.Request, + InstrumentedAction.Response> { + private final ClusterService clusterService; + + @Inject + public TransportInstrumentedAction(TransportService transportService, ActionFilters actionFilters, ClusterService clusterService) { + super( + InstrumentedAction.NAME, + transportService, + actionFilters, + InstrumentedAction.Request::new, + EsExecutors.DIRECT_EXECUTOR_SERVICE + ); + this.clusterService = clusterService; + } + + @Override + protected void doExecute(Task task, InstrumentedAction.Request request, ActionListener listener) { + String clusterName = clusterService.getClusterName().value(); + AtomicInteger callCounter = INSTRUMENTED_ACTION_CALL_MAP.computeIfAbsent(clusterName, k -> new AtomicInteger()); + callCounter.incrementAndGet(); + + listener.onResponse(new InstrumentedAction.Response()); + } + } + + public static class TestPlugin extends Plugin implements ActionPlugin, SearchPlugin { + public TestPlugin() {} + + @Override + public Collection getActions() { + return List.of(new ActionHandler(InstrumentedAction.INSTANCE, TransportInstrumentedAction.class)); + } + + @Override + public List> getQueries() { + return List.of(new QuerySpec(TestQueryBuilder.NAME, TestQueryBuilder::new, TestQueryBuilder::fromXContent)); + } + } + + private static class CustomSecuritySettingsSource extends SecuritySettingsSource { + private static final String TEST_ROLE_YML = """ + user: + cluster: [ NONE ] + indices: + - names: 'index-*' + allow_restricted_indices: false + privileges: [ ALL ] + """; + + private static final String CONFIG_STANDARD_ROLES_YML = TEST_ROLE_YML + "\n" + SecuritySettingsSourceField.ES_TEST_ROOT_ROLE_YML; + + private CustomSecuritySettingsSource(boolean sslEnabled, Path parentFolder, ESIntegTestCase.Scope scope) { + super(sslEnabled, parentFolder, scope); + } + + @Override + protected String configRoles() { + return CONFIG_STANDARD_ROLES_YML; + } + } +}