diff --git a/x-pack/plugin/security/qa/rcs-extension/src/main/java/org/elasticsearch/xpack/security/rcs/extension/TestRemoteClusterSecurityExtension.java b/x-pack/plugin/security/qa/rcs-extension/src/main/java/org/elasticsearch/xpack/security/rcs/extension/TestRemoteClusterSecurityExtension.java index 3d27dcfb2d862..ddfcbd8c5242e 100644 --- a/x-pack/plugin/security/qa/rcs-extension/src/main/java/org/elasticsearch/xpack/security/rcs/extension/TestRemoteClusterSecurityExtension.java +++ b/x-pack/plugin/security/qa/rcs-extension/src/main/java/org/elasticsearch/xpack/security/rcs/extension/TestRemoteClusterSecurityExtension.java @@ -10,25 +10,20 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.DestructiveOperations; import org.elasticsearch.common.settings.Setting; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.ssl.SslConfiguration; -import org.elasticsearch.common.util.Maps; import org.elasticsearch.transport.Transport; import org.elasticsearch.transport.TransportInterceptor; import org.elasticsearch.transport.TransportRequest; -import org.elasticsearch.xpack.core.XPackSettings; import org.elasticsearch.xpack.core.security.SecurityContext; import org.elasticsearch.xpack.core.security.authc.Authentication; -import org.elasticsearch.xpack.core.ssl.SSLService; import org.elasticsearch.xpack.core.ssl.SslProfile; import org.elasticsearch.xpack.security.authc.RemoteClusterAuthenticationService; import org.elasticsearch.xpack.security.transport.RemoteClusterTransportInterceptor; import org.elasticsearch.xpack.security.transport.ServerTransportFilter; import org.elasticsearch.xpack.security.transport.extension.RemoteClusterSecurityExtension; -import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Optional; public class TestRemoteClusterSecurityExtension implements RemoteClusterSecurityExtension { @@ -58,40 +53,16 @@ public boolean isRemoteClusterConnection(Transport.Connection connection) { return false; } - public boolean hasRemoteClusterAccessHeadersInContext(SecurityContext securityContext) { - return false; - } - @Override - public Map getProfileTransportFilters( - Map profileConfigurations, + public Optional getRemoteProfileTransportFilter( + SslProfile sslProfile, DestructiveOperations destructiveOperations ) { - Map profileFilters = Maps.newMapWithExpectedSize(profileConfigurations.size() + 1); - Settings settings = components.settings(); - final boolean transportSSLEnabled = XPackSettings.TRANSPORT_SSL_ENABLED.get(settings); - - for (Map.Entry entry : profileConfigurations.entrySet()) { - final String profileName = entry.getKey(); - final SslProfile sslProfile = entry.getValue(); - final SslConfiguration profileConfiguration = sslProfile.configuration(); - profileFilters.put( - profileName, - new ServerTransportFilter( - components.authenticationService(), - components.authorizationService(), - components.threadPool().getThreadContext(), - transportSSLEnabled && SSLService.isSSLClientAuthEnabled(profileConfiguration), - destructiveOperations, - components.securityContext() - ) - ); - } - // We need to register here the default security - // server transport filter which ensures that all - // incoming transport requests are properly - // authenticated and authorized. - return Collections.unmodifiableMap(profileFilters); + return Optional.empty(); + } + + public boolean hasRemoteClusterAccessHeadersInContext(SecurityContext securityContext) { + return false; } }; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java index 6a87e3fde2b99..37e4ea024a858 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java @@ -1207,6 +1207,8 @@ Collection createComponents( new SecurityServerTransportInterceptor( settings, threadPool, + authcService.get(), + authzService, getSslService(), securityContext.get(), destructiveOperations, diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/CrossClusterAccessTransportInterceptor.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/CrossClusterAccessTransportInterceptor.java index da93d67b7c542..fff9d0c8a9c78 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/CrossClusterAccessTransportInterceptor.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/CrossClusterAccessTransportInterceptor.java @@ -15,7 +15,6 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.ssl.SslConfiguration; -import org.elasticsearch.common.util.Maps; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.XPackLicenseState; @@ -49,13 +48,11 @@ import org.elasticsearch.xpack.security.authc.CrossClusterAccessHeaders; import org.elasticsearch.xpack.security.authz.AuthorizationService; -import java.util.Collections; import java.util.Map; import java.util.Optional; import java.util.function.Function; import static org.elasticsearch.core.Strings.format; -import static org.elasticsearch.transport.RemoteClusterPortSettings.REMOTE_CLUSTER_PROFILE; import static org.elasticsearch.transport.RemoteClusterPortSettings.REMOTE_CLUSTER_SERVER_ENABLED; import static org.elasticsearch.transport.RemoteClusterPortSettings.TRANSPORT_VERSION_ADVANCED_REMOTE_CLUSTER_SECURITY; @@ -83,7 +80,6 @@ public class CrossClusterAccessTransportInterceptor implements RemoteClusterTran private final Function> remoteClusterCredentialsResolver; private final CrossClusterAccessAuthenticationService crossClusterAccessAuthcService; private final CrossClusterApiKeySignatureManager crossClusterApiKeySignatureManager; - private final AuthenticationService authcService; private final AuthorizationService authzService; private final XPackLicenseState licenseState; private final SecurityContext securityContext; @@ -128,7 +124,6 @@ public CrossClusterAccessTransportInterceptor( this.remoteClusterCredentialsResolver = remoteClusterCredentialsResolver; this.crossClusterAccessAuthcService = crossClusterAccessAuthcService; this.crossClusterApiKeySignatureManager = crossClusterApiKeySignatureManager; - this.authcService = authcService; this.authzService = authzService; this.licenseState = licenseState; this.securityContext = securityContext; @@ -328,51 +323,27 @@ public boolean isRemoteClusterConnection(Transport.Connection connection) { } @Override - public Map getProfileTransportFilters( - Map profileConfigurations, + public Optional getRemoteProfileTransportFilter( + SslProfile sslProfile, DestructiveOperations destructiveOperations ) { - Map profileFilters = Maps.newMapWithExpectedSize(profileConfigurations.size() + 1); - - final boolean transportSSLEnabled = XPackSettings.TRANSPORT_SSL_ENABLED.get(settings); - final boolean remoteClusterPortEnabled = REMOTE_CLUSTER_SERVER_ENABLED.get(settings); + final SslConfiguration profileConfiguration = sslProfile.configuration(); + final boolean remoteClusterServerEnabled = REMOTE_CLUSTER_SERVER_ENABLED.get(settings); final boolean remoteClusterServerSSLEnabled = XPackSettings.REMOTE_CLUSTER_SERVER_SSL_ENABLED.get(settings); - - for (Map.Entry entry : profileConfigurations.entrySet()) { - final String profileName = entry.getKey(); - final SslProfile sslProfile = entry.getValue(); - final SslConfiguration profileConfiguration = sslProfile.configuration(); - assert profileConfiguration != null : "Ssl Profile [" + sslProfile + "] for [" + profileName + "] has a null configuration"; - final boolean useRemoteClusterProfile = remoteClusterPortEnabled && profileName.equals(REMOTE_CLUSTER_PROFILE); - if (useRemoteClusterProfile) { - profileFilters.put( - profileName, - new CrossClusterAccessServerTransportFilter( - crossClusterAccessAuthcService, - authzService, - threadPool.getThreadContext(), - remoteClusterServerSSLEnabled && SSLService.isSSLClientAuthEnabled(profileConfiguration), - destructiveOperations, - securityContext, - licenseState - ) - ); - } else { - profileFilters.put( - profileName, - new ServerTransportFilter( - authcService, - authzService, - threadPool.getThreadContext(), - transportSSLEnabled && SSLService.isSSLClientAuthEnabled(profileConfiguration), - destructiveOperations, - securityContext - ) - ); - } + if (remoteClusterServerEnabled) { + return Optional.of( + new CrossClusterAccessServerTransportFilter( + crossClusterAccessAuthcService, + authzService, + threadPool.getThreadContext(), + remoteClusterServerSSLEnabled && SSLService.isSSLClientAuthEnabled(profileConfiguration), + destructiveOperations, + securityContext, + licenseState + ) + ); } - - return Collections.unmodifiableMap(profileFilters); + return Optional.empty(); } @Override diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/RemoteClusterTransportInterceptor.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/RemoteClusterTransportInterceptor.java index 5328c280a629f..51d5152ad70ea 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/RemoteClusterTransportInterceptor.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/RemoteClusterTransportInterceptor.java @@ -8,12 +8,13 @@ package org.elasticsearch.xpack.security.transport; import org.elasticsearch.action.support.DestructiveOperations; +import org.elasticsearch.transport.RemoteClusterPortSettings; import org.elasticsearch.transport.Transport; import org.elasticsearch.transport.TransportInterceptor; import org.elasticsearch.xpack.core.security.SecurityContext; import org.elasticsearch.xpack.core.ssl.SslProfile; -import java.util.Map; +import java.util.Optional; /** * Allows to provide remote cluster interception that's capable of intercepting remote connections @@ -32,16 +33,20 @@ public interface RemoteClusterTransportInterceptor { boolean isRemoteClusterConnection(Transport.Connection connection); /** - * Allows interceptors to provide a custom {@link ServerTransportFilter} implementations per transport profile. - * The transport filter is called on the receiver side to filter incoming requests - * and execute authentication and authorization for all requests. + * Allows interceptors to provide a custom {@link ServerTransportFilter} implementation + * for intercepting requests for {@link RemoteClusterPortSettings#REMOTE_CLUSTER_PROFILE} + * transport profile. + *

+ * The transport filter is called on the receiver side to filter incoming remote cluster requests + * and to execute authentication and authorization for all incoming requests. + *

+ * This method is only called when setting {@link RemoteClusterPortSettings#REMOTE_CLUSTER_SERVER_ENABLED} + * is set to {@code true}. * - * @return map of {@link ServerTransportFilter}s per transport profile name + * @return a custom {@link ServerTransportFilter}s for the given transport profile, + * or an empty optional to fall back to the default transport filter */ - Map getProfileTransportFilters( - Map profileConfigurations, - DestructiveOperations destructiveOperations - ); + Optional getRemoteProfileTransportFilter(SslProfile sslProfile, DestructiveOperations destructiveOperations); /** * Returns {@code true} if any of the remote cluster access headers are in the security context. diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptor.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptor.java index d5f802ae0f1d1..ec36354f44c0e 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptor.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptor.java @@ -12,6 +12,8 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.DestructiveOperations; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.ssl.SslConfiguration; +import org.elasticsearch.common.util.Maps; import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.RunOnce; @@ -29,44 +31,92 @@ import org.elasticsearch.transport.TransportResponseHandler; import org.elasticsearch.transport.TransportService; import org.elasticsearch.transport.TransportService.ContextRestoreResponseHandler; +import org.elasticsearch.xpack.core.XPackSettings; import org.elasticsearch.xpack.core.security.SecurityContext; import org.elasticsearch.xpack.core.security.transport.ProfileConfigurations; import org.elasticsearch.xpack.core.ssl.SSLService; import org.elasticsearch.xpack.core.ssl.SslProfile; +import org.elasticsearch.xpack.security.authc.AuthenticationService; +import org.elasticsearch.xpack.security.authz.AuthorizationService; import org.elasticsearch.xpack.security.authz.AuthorizationUtils; import org.elasticsearch.xpack.security.authz.PreAuthorizationUtils; +import java.util.Collections; import java.util.Map; import java.util.concurrent.Executor; import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.transport.RemoteClusterPortSettings.REMOTE_CLUSTER_PROFILE; public class SecurityServerTransportInterceptor implements TransportInterceptor { private static final Logger logger = LogManager.getLogger(SecurityServerTransportInterceptor.class); + private final AuthenticationService authcService; + private final AuthorizationService authzService; private final RemoteClusterTransportInterceptor remoteClusterTransportInterceptor; private final Map profileFilters; private final ThreadPool threadPool; private final SecurityContext securityContext; + private final Settings settings; public SecurityServerTransportInterceptor( Settings settings, ThreadPool threadPool, + AuthenticationService authcService, + AuthorizationService authzService, SSLService sslService, SecurityContext securityContext, DestructiveOperations destructiveOperations, RemoteClusterTransportInterceptor remoteClusterTransportInterceptor - ) { this.remoteClusterTransportInterceptor = remoteClusterTransportInterceptor; this.securityContext = securityContext; this.threadPool = threadPool; + this.settings = settings; + this.authcService = authcService; + this.authzService = authzService; final Map profileConfigurations = ProfileConfigurations.get(settings, sslService, false); - this.profileFilters = this.remoteClusterTransportInterceptor.getProfileTransportFilters( - profileConfigurations, - destructiveOperations - ); + this.profileFilters = initializeProfileFilters(profileConfigurations, destructiveOperations); + } + + private Map initializeProfileFilters( + final Map profileConfigurations, + final DestructiveOperations destructiveOperations + ) { + final Map profileFilters = Maps.newMapWithExpectedSize(profileConfigurations.size() + 1); + final boolean transportSSLEnabled = XPackSettings.TRANSPORT_SSL_ENABLED.get(settings); + + for (Map.Entry entry : profileConfigurations.entrySet()) { + final String profileName = entry.getKey(); + final SslProfile sslProfile = entry.getValue(); + if (profileName.equals(REMOTE_CLUSTER_PROFILE)) { + var remoteProfileTransportFilter = this.remoteClusterTransportInterceptor.getRemoteProfileTransportFilter( + sslProfile, + destructiveOperations + ); + if (remoteProfileTransportFilter.isPresent()) { + profileFilters.put(profileName, remoteProfileTransportFilter.get()); + continue; + } + } + + final SslConfiguration profileConfiguration = sslProfile.configuration(); + assert profileConfiguration != null : "SSL Profile [" + sslProfile + "] for [" + profileName + "] has a null configuration"; + profileFilters.put( + profileName, + new ServerTransportFilter( + authcService, + authzService, + threadPool.getThreadContext(), + transportSSLEnabled && SSLService.isSSLClientAuthEnabled(profileConfiguration), + destructiveOperations, + securityContext + ) + ); + } + + return Collections.unmodifiableMap(profileFilters); } @Override diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/ServerTransportFilter.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/ServerTransportFilter.java index 96b40e140980f..104c5a859de6d 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/ServerTransportFilter.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/ServerTransportFilter.java @@ -62,11 +62,12 @@ public ServerTransportFilter( } /** - * Called just after the given request was received by the transport. Any exception - * thrown by this method will stop the request from being handled and the error will - * be sent back to the sender. + * Called just after the given request was received by the transport service. + *

+ * Any exception thrown by this method will stop the request from being handled + * and the error will be sent back to the sender. */ - void inbound(String action, TransportRequest request, TransportChannel transportChannel, ActionListener listener) { + public void inbound(String action, TransportRequest request, TransportChannel transportChannel, ActionListener listener) { if (TransportCloseIndexAction.NAME.equals(action) || OpenIndexAction.NAME.equals(action) || TransportDeleteIndexAction.TYPE.name().equals(action)) { diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/AbstractServerTransportFilterTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/AbstractServerTransportFilterTests.java new file mode 100644 index 0000000000000..7b057bf49e8d3 --- /dev/null +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/AbstractServerTransportFilterTests.java @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.security.transport; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.security.authc.Authentication; +import org.mockito.stubbing.Answer; + +import static org.hamcrest.Matchers.arrayWithSize; + +public abstract class AbstractServerTransportFilterTests extends ESTestCase { + + protected static Answer> getAnswer(Authentication authentication) { + return getAnswer(authentication, false); + } + + protected static Answer> getAnswer(Authentication authentication, boolean crossClusterAccess) { + return i -> { + final Object[] args = i.getArguments(); + assertThat(args, arrayWithSize(crossClusterAccess ? 3 : 4)); + @SuppressWarnings("unchecked") + ActionListener callback = (ActionListener) args[args.length - 1]; + callback.onResponse(authentication); + return Void.TYPE; + }; + } + +} diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/AbstractServerTransportInterceptorTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/AbstractServerTransportInterceptorTests.java new file mode 100644 index 0000000000000..30483e64f9b17 --- /dev/null +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/AbstractServerTransportInterceptorTests.java @@ -0,0 +1,73 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.security.transport; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.common.ssl.DefaultJdkTrustConfig; +import org.elasticsearch.common.ssl.EmptyKeyConfig; +import org.elasticsearch.common.ssl.SslClientAuthenticationMode; +import org.elasticsearch.common.ssl.SslConfiguration; +import org.elasticsearch.common.ssl.SslVerificationMode; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.transport.RemoteConnectionManager; +import org.elasticsearch.transport.Transport; +import org.elasticsearch.xpack.core.ssl.SSLService; +import org.elasticsearch.xpack.core.ssl.SslProfile; + +import java.util.List; +import java.util.Optional; +import java.util.function.Consumer; +import java.util.function.Function; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public abstract class AbstractServerTransportInterceptorTests extends ESTestCase { + + protected static String[] randomRoles() { + return generateRandomStringArray(3, 10, false, true); + } + + protected + Function> + mockRemoteClusterCredentialsResolver(String remoteClusterAlias) { + return connection -> Optional.of( + new RemoteConnectionManager.RemoteClusterAliasWithCredentials( + remoteClusterAlias, + new SecureString(randomAlphaOfLengthBetween(10, 42).toCharArray()) + ) + ); + } + + protected static SSLService mockSslService() { + final SslConfiguration defaultConfiguration = new SslConfiguration( + "", + false, + DefaultJdkTrustConfig.DEFAULT_INSTANCE, + EmptyKeyConfig.INSTANCE, + SslVerificationMode.FULL, + SslClientAuthenticationMode.NONE, + List.of("TLS_AES_256_GCM_SHA384"), + List.of("TLSv1.3"), + randomLongBetween(1, 100000) + ); + final SslProfile defaultProfile = mock(SslProfile.class); + when(defaultProfile.configuration()).thenReturn(defaultConfiguration); + final SSLService sslService = mock(SSLService.class); + when(sslService.profile("xpack.security.transport.ssl")).thenReturn(defaultProfile); + when(sslService.profile("xpack.security.transport.ssl.")).thenReturn(defaultProfile); + return sslService; + } + + @SuppressWarnings("unchecked") + protected static Consumer anyConsumer() { + return any(Consumer.class); + } +} diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/CrossClusterAccessServerTransportFilterTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/CrossClusterAccessServerTransportFilterTests.java new file mode 100644 index 0000000000000..419c045352c71 --- /dev/null +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/CrossClusterAccessServerTransportFilterTests.java @@ -0,0 +1,295 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.security.transport; + +import org.elasticsearch.ElasticsearchSecurityException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.search.TransportSearchAction; +import org.elasticsearch.action.support.DestructiveOperations; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.ClusterSettings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.license.MockLicenseState; +import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.transport.TransportChannel; +import org.elasticsearch.transport.TransportRequest; +import org.elasticsearch.transport.TransportSettings; +import org.elasticsearch.xpack.core.security.SecurityContext; +import org.elasticsearch.xpack.core.security.authc.Authentication; +import org.elasticsearch.xpack.core.security.authc.AuthenticationTestHelper; +import org.elasticsearch.xpack.security.Security; +import org.elasticsearch.xpack.security.authc.AuthenticationService; +import org.elasticsearch.xpack.security.authc.CrossClusterAccessAuthenticationService; +import org.elasticsearch.xpack.security.authz.AuthorizationService; +import org.junit.Before; +import org.mockito.Mockito; + +import java.util.Collections; +import java.util.Set; + +import static org.elasticsearch.test.ActionListenerUtils.anyActionListener; +import static org.elasticsearch.xpack.core.ClientHelper.SECURITY_HEADER_FILTERS; +import static org.elasticsearch.xpack.core.security.authc.CrossClusterAccessSubjectInfo.CROSS_CLUSTER_ACCESS_SUBJECT_INFO_HEADER_KEY; +import static org.elasticsearch.xpack.core.security.support.Exceptions.authenticationError; +import static org.elasticsearch.xpack.core.security.support.Exceptions.authorizationError; +import static org.elasticsearch.xpack.security.authc.CrossClusterAccessHeaders.CROSS_CLUSTER_ACCESS_CREDENTIALS_HEADER_KEY; +import static org.hamcrest.Matchers.arrayWithSize; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class CrossClusterAccessServerTransportFilterTests extends AbstractServerTransportFilterTests { + + private AuthenticationService authcService; + private AuthorizationService authzService; + private TransportChannel channel; + private boolean failDestructiveOperations; + private DestructiveOperations destructiveOperations; + private CrossClusterAccessAuthenticationService crossClusterAccessAuthcService; + private MockLicenseState mockLicenseState; + + @Before + public void init() throws Exception { + authcService = mock(AuthenticationService.class); + authzService = mock(AuthorizationService.class); + channel = mock(TransportChannel.class); + when(channel.getProfileName()).thenReturn(TransportSettings.DEFAULT_PROFILE); + when(channel.getVersion()).thenReturn(TransportVersion.current()); + failDestructiveOperations = randomBoolean(); + Settings settings = Settings.builder().put(DestructiveOperations.REQUIRES_NAME_SETTING.getKey(), failDestructiveOperations).build(); + destructiveOperations = new DestructiveOperations( + settings, + new ClusterSettings(settings, Collections.singleton(DestructiveOperations.REQUIRES_NAME_SETTING)) + ); + crossClusterAccessAuthcService = mock(CrossClusterAccessAuthenticationService.class); + when(crossClusterAccessAuthcService.getAuthenticationService()).thenReturn(authcService); + mockLicenseState = MockLicenseState.createMock(); + Mockito.when(mockLicenseState.isAllowed(Security.ADVANCED_REMOTE_CLUSTER_SECURITY_FEATURE)).thenReturn(true); + } + + public void testCrossClusterAccessInbound() { + TransportRequest request = mock(TransportRequest.class); + Authentication authentication = AuthenticationTestHelper.builder().build(); + String action = randomAlphaOfLengthBetween(10, 20); + doAnswer(getAnswer(authentication)).when(authcService).authenticate(eq(action), eq(request), eq(true), anyActionListener()); + doAnswer(getAnswer(authentication, true)).when(crossClusterAccessAuthcService) + .authenticate(eq(action), eq(request), anyActionListener()); + CrossClusterAccessServerTransportFilter filter = getNodeCrossClusterAccessFilter(); + PlainActionFuture listener = spy(new PlainActionFuture<>()); + filter.inbound(action, request, channel, listener); + verify(authzService).authorize(eq(authentication), eq(action), eq(request), anyActionListener()); + verify(crossClusterAccessAuthcService).authenticate(anyString(), any(), anyActionListener()); + verify(authcService, never()).authenticate(anyString(), any(), anyBoolean(), anyActionListener()); + } + + public void testCrossClusterAccessInboundInvalidHeadersFail() { + TransportRequest request = mock(TransportRequest.class); + Authentication authentication = AuthenticationTestHelper.builder().build(); + String action = randomAlphaOfLengthBetween(10, 20); + doAnswer(getAnswer(authentication)).when(authcService).authenticate(eq(action), eq(request), eq(true), anyActionListener()); + doAnswer(getAnswer(authentication, true)).when(crossClusterAccessAuthcService) + .authenticate(eq(action), eq(request), anyActionListener()); + CrossClusterAccessServerTransportFilter filter = getNodeCrossClusterAccessFilter( + Set.copyOf(randomNonEmptySubsetOf(SECURITY_HEADER_FILTERS)) + ); + PlainActionFuture listener = new PlainActionFuture<>(); + filter.inbound(action, request, channel, listener); + var actual = expectThrows(IllegalArgumentException.class, listener::actionGet); + verifyNoMoreInteractions(authcService); + verifyNoMoreInteractions(authzService); + assertThat( + actual.getMessage(), + containsString("is not allowed for cross cluster requests through the dedicated remote cluster server port") + ); + verify(crossClusterAccessAuthcService, never()).authenticate(anyString(), any(), anyActionListener()); + } + + public void testCrossClusterAccessInboundMissingHeadersFail() { + TransportRequest request = mock(TransportRequest.class); + Authentication authentication = AuthenticationTestHelper.builder().build(); + String action = randomAlphaOfLengthBetween(10, 20); + doAnswer(getAnswer(authentication)).when(authcService).authenticate(eq(action), eq(request), eq(true), anyActionListener()); + doAnswer(getAnswer(authentication, true)).when(crossClusterAccessAuthcService) + .authenticate(eq(action), eq(request), anyActionListener()); + Settings settings = Settings.builder().put("path.home", createTempDir()).build(); + ThreadContext threadContext = new ThreadContext(settings); + String firstMissingHeader = CROSS_CLUSTER_ACCESS_CREDENTIALS_HEADER_KEY; + if (randomBoolean()) { + String headerToInclude = randomBoolean() + ? CROSS_CLUSTER_ACCESS_CREDENTIALS_HEADER_KEY + : CROSS_CLUSTER_ACCESS_SUBJECT_INFO_HEADER_KEY; + if (headerToInclude.equals(CROSS_CLUSTER_ACCESS_CREDENTIALS_HEADER_KEY)) { + firstMissingHeader = CROSS_CLUSTER_ACCESS_SUBJECT_INFO_HEADER_KEY; + } + threadContext.putHeader(headerToInclude, randomAlphaOfLength(42)); + } + CrossClusterAccessServerTransportFilter filter = new CrossClusterAccessServerTransportFilter( + crossClusterAccessAuthcService, + authzService, + threadContext, + false, + destructiveOperations, + new SecurityContext(settings, threadContext), + mockLicenseState + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + filter.inbound(action, request, channel, listener); + var actual = expectThrows(IllegalArgumentException.class, listener::actionGet); + + verifyNoMoreInteractions(authcService); + verifyNoMoreInteractions(authzService); + assertThat( + actual.getMessage(), + equalTo( + "Cross cluster requests through the dedicated remote cluster server port require transport header [" + + firstMissingHeader + + "] but none found. " + + "Please ensure you have configured remote cluster credentials on the cluster originating the request." + ) + ); + verify(crossClusterAccessAuthcService, never()).authenticate(anyString(), any(), anyActionListener()); + } + + public void testInboundAuthorizationException() { + CrossClusterAccessServerTransportFilter filter = getNodeCrossClusterAccessFilter(); + TransportRequest request = mock(TransportRequest.class); + Authentication authentication = AuthenticationTestHelper.builder().build(); + String action = TransportSearchAction.TYPE.name(); + doAnswer(getAnswer(authentication)).when(authcService).authenticate(eq(action), eq(request), eq(true), anyActionListener()); + doAnswer(getAnswer(authentication, true)).when(crossClusterAccessAuthcService) + .authenticate(eq(action), eq(request), anyActionListener()); + PlainActionFuture future = new PlainActionFuture<>(); + doThrow(authorizationError("authz failed")).when(authzService) + .authorize(eq(authentication), eq(action), eq(request), anyActionListener()); + ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, () -> { + filter.inbound(action, request, channel, future); + future.actionGet(); + }); + assertThat(e.getMessage(), equalTo("authz failed")); + verify(crossClusterAccessAuthcService).authenticate(anyString(), any(), anyActionListener()); + verify(authcService, never()).authenticate(anyString(), any(), anyBoolean(), anyActionListener()); + } + + public void testCrossClusterAccessInboundAuthenticationException() { + TransportRequest request = mock(TransportRequest.class); + Exception authE = authenticationError("authc failed"); + String action = randomAlphaOfLengthBetween(10, 20); + doAnswer(i -> { + final Object[] args = i.getArguments(); + assertThat(args, arrayWithSize(3)); + @SuppressWarnings("unchecked") + ActionListener callback = (ActionListener) args[args.length - 1]; + callback.onFailure(authE); + return Void.TYPE; + }).when(crossClusterAccessAuthcService).authenticate(eq(action), eq(request), anyActionListener()); + doAnswer(i -> { + final Object[] args = i.getArguments(); + assertThat(args, arrayWithSize(4)); + @SuppressWarnings("unchecked") + ActionListener callback = (ActionListener) args[args.length - 1]; + callback.onFailure(authE); + return Void.TYPE; + }).when(authcService).authenticate(eq(action), eq(request), eq(true), anyActionListener()); + CrossClusterAccessServerTransportFilter filter = getNodeCrossClusterAccessFilter(); + try { + PlainActionFuture future = new PlainActionFuture<>(); + filter.inbound(action, request, channel, future); + future.actionGet(); + fail("expected filter inbound to throw an authentication exception on authentication error"); + } catch (ElasticsearchSecurityException e) { + assertThat(e.getMessage(), equalTo("authc failed")); + } + verifyNoMoreInteractions(authzService); + verify(crossClusterAccessAuthcService).authenticate(anyString(), any(), anyActionListener()); + verify(authcService, never()).authenticate(anyString(), any(), anyBoolean(), anyActionListener()); + } + + public void testCrossClusterAccessInboundFailsWithUnsupportedLicense() { + final MockLicenseState unsupportedLicenseState = MockLicenseState.createMock(); + Mockito.when(unsupportedLicenseState.isAllowed(Security.ADVANCED_REMOTE_CLUSTER_SECURITY_FEATURE)).thenReturn(false); + + CrossClusterAccessServerTransportFilter crossClusterAccessFilter = getNodeCrossClusterAccessFilter(unsupportedLicenseState); + PlainActionFuture listener = new PlainActionFuture<>(); + String action = randomAlphaOfLengthBetween(10, 20); + crossClusterAccessFilter.inbound(action, mock(TransportRequest.class), channel, listener); + + ElasticsearchSecurityException actualException = expectThrows(ElasticsearchSecurityException.class, listener::actionGet); + assertThat( + actualException.getMessage(), + equalTo("current license is non-compliant for [" + Security.ADVANCED_REMOTE_CLUSTER_SECURITY_FEATURE.getName() + "]") + ); + + // License check should be executed first, hence we don't expect authc/authz to be even attempted. + verify(crossClusterAccessAuthcService, never()).authenticate(anyString(), any(), anyActionListener()); + verifyNoInteractions(authzService, authcService); + } + + private CrossClusterAccessServerTransportFilter getNodeCrossClusterAccessFilter() { + return getNodeCrossClusterAccessFilter(Collections.emptySet(), mockLicenseState); + } + + private CrossClusterAccessServerTransportFilter getNodeCrossClusterAccessFilter(Set additionalHeadersKeys) { + return getNodeCrossClusterAccessFilter(additionalHeadersKeys, mockLicenseState); + } + + private CrossClusterAccessServerTransportFilter getNodeCrossClusterAccessFilter(XPackLicenseState licenseState) { + return getNodeCrossClusterAccessFilter(Collections.emptySet(), licenseState); + } + + private CrossClusterAccessServerTransportFilter getNodeCrossClusterAccessFilter( + Set additionalHeadersKeys, + XPackLicenseState licenseState + ) { + Settings settings = Settings.builder().put("path.home", createTempDir()).build(); + ThreadContext threadContext = new ThreadContext(settings); + for (var header : additionalHeadersKeys) { + threadContext.putHeader(header, randomAlphaOfLength(20)); + } + // Randomly include valid headers + if (randomBoolean()) { + for (var validHeader : CrossClusterAccessServerTransportFilter.ALLOWED_TRANSPORT_HEADERS) { + // don't overwrite additionalHeadersKeys + if (false == additionalHeadersKeys.contains(validHeader)) { + threadContext.putHeader(validHeader, randomAlphaOfLength(20)); + } + } + } + var requiredHeaders = Set.of(CROSS_CLUSTER_ACCESS_CREDENTIALS_HEADER_KEY, CROSS_CLUSTER_ACCESS_SUBJECT_INFO_HEADER_KEY); + for (var header : requiredHeaders) { + // don't overwrite already present headers + if (threadContext.getHeader(header) == null) { + threadContext.putHeader(header, randomAlphaOfLength(20)); + } + } + return new CrossClusterAccessServerTransportFilter( + crossClusterAccessAuthcService, + authzService, + threadContext, + false, + destructiveOperations, + new SecurityContext(settings, threadContext), + licenseState + ); + } + +} diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/CrossClusterAccessTransportInterceptorTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/CrossClusterAccessTransportInterceptorTests.java new file mode 100644 index 0000000000000..a5b3d913fc3ea --- /dev/null +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/CrossClusterAccessTransportInterceptorTests.java @@ -0,0 +1,805 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.security.transport; + +import org.elasticsearch.ElasticsearchSecurityException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.admin.cluster.state.ClusterStateAction; +import org.elasticsearch.action.admin.cluster.state.ClusterStateRequest; +import org.elasticsearch.action.support.DestructiveOperations; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.settings.ClusterSettings; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.ssl.SslClientAuthenticationMode; +import org.elasticsearch.common.ssl.SslConfiguration; +import org.elasticsearch.common.ssl.SslKeyConfig; +import org.elasticsearch.common.ssl.SslTrustConfig; +import org.elasticsearch.common.ssl.SslVerificationMode; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.license.MockLicenseState; +import org.elasticsearch.test.ClusterServiceUtils; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.RemoteConnectionManager; +import org.elasticsearch.transport.SendRequestTransportException; +import org.elasticsearch.transport.Transport; +import org.elasticsearch.transport.TransportException; +import org.elasticsearch.transport.TransportInterceptor; +import org.elasticsearch.transport.TransportRequest; +import org.elasticsearch.transport.TransportRequestOptions; +import org.elasticsearch.transport.TransportResponse; +import org.elasticsearch.transport.TransportResponseHandler; +import org.elasticsearch.xpack.core.security.SecurityContext; +import org.elasticsearch.xpack.core.security.authc.Authentication; +import org.elasticsearch.xpack.core.security.authc.AuthenticationTestHelper; +import org.elasticsearch.xpack.core.security.authc.CrossClusterAccessSubjectInfo; +import org.elasticsearch.xpack.core.security.authz.RoleDescriptorsIntersection; +import org.elasticsearch.xpack.core.security.user.InternalUsers; +import org.elasticsearch.xpack.core.security.user.SystemUser; +import org.elasticsearch.xpack.core.security.user.User; +import org.elasticsearch.xpack.core.ssl.SSLService; +import org.elasticsearch.xpack.core.ssl.SslProfile; +import org.elasticsearch.xpack.security.Security; +import org.elasticsearch.xpack.security.audit.AuditUtil; +import org.elasticsearch.xpack.security.authc.ApiKeyService; +import org.elasticsearch.xpack.security.authc.AuthenticationService; +import org.elasticsearch.xpack.security.authc.CrossClusterAccessAuthenticationService; +import org.elasticsearch.xpack.security.authz.AuthorizationService; +import org.junit.After; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import static org.elasticsearch.test.ActionListenerUtils.anyActionListener; +import static org.elasticsearch.xpack.core.security.authc.CrossClusterAccessSubjectInfo.CROSS_CLUSTER_ACCESS_SUBJECT_INFO_HEADER_KEY; +import static org.elasticsearch.xpack.core.security.authz.RoleDescriptorTestHelper.randomUniquelyNamedRoleDescriptors; +import static org.elasticsearch.xpack.security.authc.CrossClusterAccessHeaders.CROSS_CLUSTER_ACCESS_CREDENTIALS_HEADER_KEY; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class CrossClusterAccessTransportInterceptorTests extends AbstractServerTransportInterceptorTests { + + private Settings settings; + private ThreadPool threadPool; + private ThreadContext threadContext; + private SecurityContext securityContext; + private ClusterService clusterService; + private MockLicenseState mockLicenseState; + private DestructiveOperations destructiveOperations; + private CrossClusterApiKeySignatureManager crossClusterApiKeySignatureManager; + + @Override + public void setUp() throws Exception { + super.setUp(); + settings = Settings.builder().put("path.home", createTempDir()).build(); + threadPool = new TestThreadPool(getTestName()); + clusterService = ClusterServiceUtils.createClusterService(threadPool); + threadContext = threadPool.getThreadContext(); + securityContext = spy(new SecurityContext(settings, threadPool.getThreadContext())); + mockLicenseState = MockLicenseState.createMock(); + Mockito.when(mockLicenseState.isAllowed(Security.ADVANCED_REMOTE_CLUSTER_SECURITY_FEATURE)).thenReturn(true); + destructiveOperations = new DestructiveOperations( + Settings.EMPTY, + new ClusterSettings(Settings.EMPTY, Collections.singleton(DestructiveOperations.REQUIRES_NAME_SETTING)) + ); + crossClusterApiKeySignatureManager = mock(CrossClusterApiKeySignatureManager.class); + } + + @After + public void stopThreadPool() throws Exception { + clusterService.close(); + terminate(threadPool); + } + + public void testSendWithCrossClusterAccessHeadersForSystemUserRegularAction() throws Exception { + final String action; + final TransportRequest request; + if (randomBoolean()) { + action = randomAlphaOfLengthBetween(5, 30); + request = mock(TransportRequest.class); + } else { + action = ClusterStateAction.NAME; + request = mock(ClusterStateRequest.class); + } + doTestSendWithCrossClusterAccessHeaders( + true, + action, + request, + AuthenticationTestHelper.builder().internal(InternalUsers.SYSTEM_USER).build() + ); + } + + public void testSendWithCrossClusterAccessHeadersForSystemUserCcrInternalAction() throws Exception { + final String action = randomFrom( + "internal:admin/ccr/restore/session/put", + "internal:admin/ccr/restore/session/clear", + "internal:admin/ccr/restore/file_chunk/get" + ); + final TransportRequest request = mock(TransportRequest.class); + doTestSendWithCrossClusterAccessHeaders( + true, + action, + request, + AuthenticationTestHelper.builder().internal(InternalUsers.SYSTEM_USER).build() + ); + } + + public void testSendWithCrossClusterAccessHeadersForRegularUserRegularAction() throws Exception { + final Authentication authentication = randomValueOtherThanMany( + authc -> authc.getAuthenticationType() == Authentication.AuthenticationType.INTERNAL, + () -> AuthenticationTestHelper.builder().build() + ); + final String action = randomAlphaOfLengthBetween(5, 30); + final TransportRequest request = mock(TransportRequest.class); + doTestSendWithCrossClusterAccessHeaders(false, action, request, authentication); + } + + public void testSendWithCrossClusterAccessHeadersForRegularUserClusterStateAction() throws Exception { + final Authentication authentication = randomValueOtherThanMany( + authc -> authc.getAuthenticationType() == Authentication.AuthenticationType.INTERNAL, + () -> AuthenticationTestHelper.builder().build() + ); + final String action = ClusterStateAction.NAME; + final TransportRequest request = mock(ClusterStateRequest.class); + doTestSendWithCrossClusterAccessHeaders(true, action, request, authentication); + } + + private void doTestSendWithCrossClusterAccessHeaders( + boolean shouldAssertForSystemUser, + String action, + TransportRequest request, + Authentication authentication + ) throws IOException { + authentication.writeToContext(threadContext); + final String expectedRequestId = AuditUtil.getOrGenerateRequestId(threadContext); + final String remoteClusterAlias = randomAlphaOfLengthBetween(5, 10); + final String encodedApiKey = randomAlphaOfLengthBetween(10, 42); + final String remoteClusterCredential = ApiKeyService.withApiKeyPrefix(encodedApiKey); + final AuthorizationService authzService = mock(AuthorizationService.class); + // We capture the listener so that we can complete the full flow, by calling onResponse further down + @SuppressWarnings("unchecked") + final ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + doAnswer(i -> null).when(authzService) + .getRoleDescriptorsIntersectionForRemoteCluster(any(), any(), any(), listenerCaptor.capture()); + + final SecurityServerTransportInterceptor interceptor = new SecurityServerTransportInterceptor( + settings, + threadPool, + mock(AuthenticationService.class), + authzService, + mockSslService(), + securityContext, + destructiveOperations, + new CrossClusterAccessTransportInterceptor( + settings, + threadPool, + mock(AuthenticationService.class), + authzService, + securityContext, + mock(CrossClusterAccessAuthenticationService.class), + crossClusterApiKeySignatureManager, + mockLicenseState, + ignored -> Optional.of( + new RemoteConnectionManager.RemoteClusterAliasWithCredentials( + remoteClusterAlias, + new SecureString(encodedApiKey.toCharArray()) + ) + ) + ) + ); + + final AtomicBoolean calledWrappedSender = new AtomicBoolean(false); + final AtomicReference sentAction = new AtomicReference<>(); + final AtomicReference sentCredential = new AtomicReference<>(); + final AtomicReference sentCrossClusterAccessSubjectInfo = new AtomicReference<>(); + final TransportInterceptor.AsyncSender sender = interceptor.interceptSender(new TransportInterceptor.AsyncSender() { + @Override + public void sendRequest( + Transport.Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + if (calledWrappedSender.compareAndSet(false, true) == false) { + fail("sender called more than once"); + } + assertThat(securityContext.getAuthentication(), nullValue()); + assertThat(AuditUtil.extractRequestId(securityContext.getThreadContext()), equalTo(expectedRequestId)); + sentAction.set(action); + sentCredential.set(securityContext.getThreadContext().getHeader(CROSS_CLUSTER_ACCESS_CREDENTIALS_HEADER_KEY)); + try { + sentCrossClusterAccessSubjectInfo.set( + CrossClusterAccessSubjectInfo.readFromContext(securityContext.getThreadContext()) + ); + } catch (IOException e) { + fail("no exceptions expected but got " + e); + } + handler.handleResponse(null); + } + }); + final Transport.Connection connection = mock(Transport.Connection.class); + when(connection.getTransportVersion()).thenReturn(TransportVersion.current()); + + sender.sendRequest(connection, action, request, null, new TransportResponseHandler<>() { + @Override + public Executor executor() { + return TransportResponseHandler.TRANSPORT_WORKER; + } + + @Override + public void handleResponse(TransportResponse response) { + // Headers should get restored before handle response is called + assertThat(securityContext.getAuthentication(), equalTo(authentication)); + assertThat(securityContext.getThreadContext().getHeader(CROSS_CLUSTER_ACCESS_SUBJECT_INFO_HEADER_KEY), nullValue()); + assertThat(securityContext.getThreadContext().getHeader(CROSS_CLUSTER_ACCESS_CREDENTIALS_HEADER_KEY), nullValue()); + } + + @Override + public void handleException(TransportException exp) { + fail("no exceptions expected but got " + exp); + } + + @Override + public TransportResponse read(StreamInput in) { + return null; + } + }); + if (shouldAssertForSystemUser) { + assertThat( + sentCrossClusterAccessSubjectInfo.get(), + equalTo( + SystemUser.crossClusterAccessSubjectInfo( + authentication.getEffectiveSubject().getTransportVersion(), + authentication.getEffectiveSubject().getRealm().getNodeName() + ) + ) + ); + verify(authzService, never()).getRoleDescriptorsIntersectionForRemoteCluster( + eq(remoteClusterAlias), + eq(TransportVersion.current()), + eq(authentication.getEffectiveSubject()), + anyActionListener() + ); + } else { + final RoleDescriptorsIntersection expectedRoleDescriptorsIntersection = new RoleDescriptorsIntersection( + randomList(1, 3, () -> Set.copyOf(randomUniquelyNamedRoleDescriptors(0, 1))) + ); + // Call listener to complete flow + listenerCaptor.getValue().onResponse(expectedRoleDescriptorsIntersection); + verify(authzService, times(1)).getRoleDescriptorsIntersectionForRemoteCluster( + eq(remoteClusterAlias), + eq(TransportVersion.current()), + eq(authentication.getEffectiveSubject()), + anyActionListener() + ); + assertThat( + sentCrossClusterAccessSubjectInfo.get(), + equalTo(new CrossClusterAccessSubjectInfo(authentication, expectedRoleDescriptorsIntersection)) + ); + } + assertTrue(calledWrappedSender.get()); + if (action.startsWith("internal:")) { + assertThat(sentAction.get(), equalTo("indices:internal/" + action.substring("internal:".length()))); + } else { + assertThat(sentAction.get(), equalTo(action)); + } + assertThat(sentCredential.get(), equalTo(remoteClusterCredential)); + verify(securityContext, never()).executeAsInternalUser(any(), any(), anyConsumer()); + assertThat(securityContext.getThreadContext().getHeader(CROSS_CLUSTER_ACCESS_SUBJECT_INFO_HEADER_KEY), nullValue()); + assertThat(securityContext.getThreadContext().getHeader(CROSS_CLUSTER_ACCESS_CREDENTIALS_HEADER_KEY), nullValue()); + assertThat(AuditUtil.extractRequestId(securityContext.getThreadContext()), equalTo(expectedRequestId)); + } + + public void testSendWithUserIfCrossClusterAccessHeadersConditionNotMet() throws Exception { + boolean noCredential = randomBoolean(); + final boolean notRemoteConnection = randomBoolean(); + // Ensure at least one condition fails + if (false == (notRemoteConnection || noCredential)) { + noCredential = true; + } + final boolean finalNoCredential = noCredential; + final String remoteClusterAlias = randomAlphaOfLengthBetween(5, 10); + final String encodedApiKey = randomAlphaOfLengthBetween(10, 42); + final AuthenticationTestHelper.AuthenticationTestBuilder builder = AuthenticationTestHelper.builder(); + final Authentication authentication = randomFrom( + builder.apiKey().build(), + builder.serviceAccount().build(), + builder.user(new User(randomAlphaOfLengthBetween(3, 10), randomRoles())).realm().build() + ); + authentication.writeToContext(threadContext); + + final AuthorizationService authzService = mock(AuthorizationService.class); + final SecurityServerTransportInterceptor interceptor = new SecurityServerTransportInterceptor( + settings, + threadPool, + mock(AuthenticationService.class), + authzService, + mockSslService(), + securityContext, + destructiveOperations, + new CrossClusterAccessTransportInterceptor( + settings, + threadPool, + mock(AuthenticationService.class), + authzService, + securityContext, + mock(CrossClusterAccessAuthenticationService.class), + crossClusterApiKeySignatureManager, + mockLicenseState, + ignored -> notRemoteConnection + ? Optional.empty() + : (finalNoCredential + ? Optional.of(new RemoteConnectionManager.RemoteClusterAliasWithCredentials(remoteClusterAlias, null)) + : Optional.of( + new RemoteConnectionManager.RemoteClusterAliasWithCredentials( + remoteClusterAlias, + new SecureString(encodedApiKey.toCharArray()) + ) + )) + ) + ); + + final AtomicBoolean calledWrappedSender = new AtomicBoolean(false); + final AtomicReference sentAuthentication = new AtomicReference<>(); + final TransportInterceptor.AsyncSender sender = interceptor.interceptSender(new TransportInterceptor.AsyncSender() { + @Override + public void sendRequest( + Transport.Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + if (calledWrappedSender.compareAndSet(false, true) == false) { + fail("sender called more than once"); + } + sentAuthentication.set(securityContext.getAuthentication()); + assertThat(securityContext.getThreadContext().getHeader(CROSS_CLUSTER_ACCESS_CREDENTIALS_HEADER_KEY), nullValue()); + assertThat(securityContext.getThreadContext().getHeader(CROSS_CLUSTER_ACCESS_SUBJECT_INFO_HEADER_KEY), nullValue()); + } + }); + final Transport.Connection connection = mock(Transport.Connection.class); + when(connection.getTransportVersion()).thenReturn(TransportVersion.current()); + sender.sendRequest(connection, "action", mock(TransportRequest.class), null, null); + assertTrue(calledWrappedSender.get()); + assertThat(sentAuthentication.get(), equalTo(authentication)); + verify(authzService, never()).getRoleDescriptorsIntersectionForRemoteCluster(any(), any(), any(), anyActionListener()); + assertThat(securityContext.getThreadContext().getHeader(CROSS_CLUSTER_ACCESS_SUBJECT_INFO_HEADER_KEY), nullValue()); + assertThat(securityContext.getThreadContext().getHeader(CROSS_CLUSTER_ACCESS_CREDENTIALS_HEADER_KEY), nullValue()); + } + + public void testSendRemoteRequestFailsIfUserHasNoRemoteIndicesPrivileges() throws Exception { + final Authentication authentication = AuthenticationTestHelper.builder() + .user(new User(randomAlphaOfLengthBetween(3, 10), randomRoles())) + .realm() + .build(); + authentication.writeToContext(threadContext); + final String remoteClusterAlias = randomAlphaOfLengthBetween(5, 10); + final String encodedApiKey = randomAlphaOfLengthBetween(10, 42); + final String remoteClusterCredential = ApiKeyService.withApiKeyPrefix(encodedApiKey); + final AuthorizationService authzService = mock(AuthorizationService.class); + + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + final var listener = (ActionListener) invocation.getArgument(3); + listener.onResponse(RoleDescriptorsIntersection.EMPTY); + return null; + }).when(authzService).getRoleDescriptorsIntersectionForRemoteCluster(any(), any(), any(), anyActionListener()); + + final SecurityServerTransportInterceptor interceptor = new SecurityServerTransportInterceptor( + settings, + threadPool, + mock(AuthenticationService.class), + authzService, + mockSslService(), + securityContext, + destructiveOperations, + new CrossClusterAccessTransportInterceptor( + settings, + threadPool, + mock(AuthenticationService.class), + authzService, + securityContext, + mock(CrossClusterAccessAuthenticationService.class), + crossClusterApiKeySignatureManager, + mockLicenseState, + ignored -> Optional.of( + new RemoteConnectionManager.RemoteClusterAliasWithCredentials( + remoteClusterAlias, + new SecureString(encodedApiKey.toCharArray()) + ) + ) + ) + ); + + final TransportInterceptor.AsyncSender sender = interceptor.interceptSender(new TransportInterceptor.AsyncSender() { + @Override + public void sendRequest( + Transport.Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + fail("request should have failed"); + } + }); + final Transport.Connection connection = mock(Transport.Connection.class); + when(connection.getTransportVersion()).thenReturn(TransportVersion.current()); + + final ElasticsearchSecurityException expectedException = new ElasticsearchSecurityException("remote action denied"); + when(authzService.remoteActionDenied(authentication, "action", remoteClusterAlias)).thenReturn(expectedException); + + final var actualException = new AtomicReference(); + sender.sendRequest(connection, "action", mock(TransportRequest.class), null, new TransportResponseHandler<>() { + @Override + public Executor executor() { + return TransportResponseHandler.TRANSPORT_WORKER; + } + + @Override + public void handleResponse(TransportResponse response) { + fail("should not success"); + } + + @Override + public void handleException(TransportException exp) { + actualException.set(exp.getCause()); + } + + @Override + public TransportResponse read(StreamInput in) { + return null; + } + }); + assertThat(actualException.get(), is(expectedException)); + assertThat(securityContext.getThreadContext().getHeader(CROSS_CLUSTER_ACCESS_SUBJECT_INFO_HEADER_KEY), nullValue()); + assertThat(securityContext.getThreadContext().getHeader(CROSS_CLUSTER_ACCESS_CREDENTIALS_HEADER_KEY), nullValue()); + } + + public void testSendWithCrossClusterAccessHeadersWithUnsupportedLicense() throws Exception { + final MockLicenseState unsupportedLicenseState = MockLicenseState.createMock(); + Mockito.when(unsupportedLicenseState.isAllowed(Security.ADVANCED_REMOTE_CLUSTER_SECURITY_FEATURE)).thenReturn(false); + + AuthenticationTestHelper.builder().build().writeToContext(threadContext); + final String remoteClusterAlias = randomAlphaOfLengthBetween(5, 10); + + final SecurityServerTransportInterceptor interceptor = new SecurityServerTransportInterceptor( + settings, + threadPool, + mock(AuthenticationService.class), + mock(AuthorizationService.class), + mockSslService(), + securityContext, + destructiveOperations, + new CrossClusterAccessTransportInterceptor( + settings, + threadPool, + mock(AuthenticationService.class), + mock(AuthorizationService.class), + securityContext, + mock(CrossClusterAccessAuthenticationService.class), + crossClusterApiKeySignatureManager, + unsupportedLicenseState, + mockRemoteClusterCredentialsResolver(remoteClusterAlias) + ) + ); + + final TransportInterceptor.AsyncSender sender = interceptor.interceptSender( + mock(TransportInterceptor.AsyncSender.class, ignored -> { + throw new AssertionError("sender should not be called"); + }) + ); + final Transport.Connection connection = mock(Transport.Connection.class); + when(connection.getTransportVersion()).thenReturn(TransportVersion.current()); + final AtomicBoolean calledHandleException = new AtomicBoolean(false); + final AtomicReference actualException = new AtomicReference<>(); + sender.sendRequest(connection, "action", mock(TransportRequest.class), null, new TransportResponseHandler<>() { + @Override + public Executor executor() { + return TransportResponseHandler.TRANSPORT_WORKER; + } + + @Override + public void handleResponse(TransportResponse response) { + fail("should not receive a response"); + } + + @Override + public void handleException(TransportException exp) { + if (calledHandleException.compareAndSet(false, true) == false) { + fail("handle exception called more than once"); + } + actualException.set(exp); + } + + @Override + public TransportResponse read(StreamInput in) { + fail("should not receive a response"); + return null; + } + }); + assertThat(actualException.get(), instanceOf(SendRequestTransportException.class)); + assertThat(actualException.get().getCause(), instanceOf(ElasticsearchSecurityException.class)); + assertThat( + actualException.get().getCause().getMessage(), + equalTo("current license is non-compliant for [" + Security.ADVANCED_REMOTE_CLUSTER_SECURITY_FEATURE.getName() + "]") + ); + assertThat(securityContext.getThreadContext().getHeader(CROSS_CLUSTER_ACCESS_SUBJECT_INFO_HEADER_KEY), nullValue()); + assertThat(securityContext.getThreadContext().getHeader(CROSS_CLUSTER_ACCESS_CREDENTIALS_HEADER_KEY), nullValue()); + } + + public void testProfileFiltersCreatedDifferentlyForDifferentTransportAndRemoteClusterSslSettings() { + // filters are created irrespective of ssl enabled + final boolean transportSslEnabled = randomBoolean(); + final boolean remoteClusterSslEnabled = randomBoolean(); + final Settings.Builder builder = Settings.builder() + .put(this.settings) + .put("xpack.security.transport.ssl.enabled", transportSslEnabled) + .put("remote_cluster_server.enabled", true) + .put("xpack.security.remote_cluster_server.ssl.enabled", remoteClusterSslEnabled); + if (randomBoolean()) { + builder.put("xpack.security.remote_cluster_client.ssl.enabled", randomBoolean()); // client SSL won't be processed + } + + final SslProfile defaultProfile = mock(SslProfile.class); + when(defaultProfile.configuration()).thenReturn( + new SslConfiguration( + "xpack.security.transport.ssl", + randomBoolean(), + mock(SslTrustConfig.class), + mock(SslKeyConfig.class), + randomFrom(SslVerificationMode.values()), + SslClientAuthenticationMode.REQUIRED, + List.of("TLS_AES_256_GCM_SHA384"), + List.of("TLSv1.3"), + randomLongBetween(1, 100000) + ) + ); + final SslProfile remoteProfile = mock(SslProfile.class); + when(remoteProfile.configuration()).thenReturn( + new SslConfiguration( + "xpack.security.remote_cluster_server.ssl", + randomBoolean(), + mock(SslTrustConfig.class), + mock(SslKeyConfig.class), + randomFrom(SslVerificationMode.values()), + SslClientAuthenticationMode.NONE, + List.of(Runtime.version().feature() < 24 ? "TLS_RSA_WITH_AES_256_GCM_SHA384" : "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"), + List.of("TLSv1.2"), + randomLongBetween(1, 100000) + ) + ); + + final SSLService sslService = mock(SSLService.class); + when(sslService.profile("xpack.security.transport.ssl.")).thenReturn(defaultProfile); + + when(sslService.profile("xpack.security.remote_cluster_server.ssl.")).thenReturn(remoteProfile); + doThrow(new AssertionError("profile filters should not be configured for remote cluster client")).when(sslService) + .profile("xpack.security.remote_cluster_client.ssl."); + + final AuthenticationService authcService = mock(AuthenticationService.class); + final AuthorizationService authzService = mock(AuthorizationService.class); + final var securityServerTransportInterceptor = new SecurityServerTransportInterceptor( + builder.build(), + threadPool, + authcService, + authzService, + sslService, + securityContext, + destructiveOperations, + new CrossClusterAccessTransportInterceptor( + builder.build(), + threadPool, + authcService, + authzService, + securityContext, + mock(CrossClusterAccessAuthenticationService.class), + crossClusterApiKeySignatureManager, + mockLicenseState + ) + ); + + final Map profileFilters = securityServerTransportInterceptor.getProfileFilters(); + assertThat(profileFilters.keySet(), containsInAnyOrder("default", "_remote_cluster")); + assertThat(profileFilters.get("default").isExtractClientCert(), is(transportSslEnabled)); + assertThat(profileFilters.get("default"), not(instanceOf(CrossClusterAccessServerTransportFilter.class))); + assertThat(profileFilters.get("_remote_cluster").isExtractClientCert(), is(false)); + assertThat(profileFilters.get("_remote_cluster"), instanceOf(CrossClusterAccessServerTransportFilter.class)); + } + + public void testNoProfileFilterForRemoteClusterWhenTheFeatureIsDisabled() { + final boolean transportSslEnabled = randomBoolean(); + final Settings.Builder builder = Settings.builder() + .put(this.settings) + .put("xpack.security.transport.ssl.enabled", transportSslEnabled) + .put("remote_cluster_server.enabled", false) + .put("xpack.security.remote_cluster_server.ssl.enabled", randomBoolean()); + if (randomBoolean()) { + builder.put("xpack.security.remote_cluster_client.ssl.enabled", randomBoolean()); // client SSL won't be processed + } + + final SslProfile profile = mock(SslProfile.class); + when(profile.configuration()).thenReturn( + new SslConfiguration( + "xpack.security.transport.ssl", + randomBoolean(), + mock(SslTrustConfig.class), + mock(SslKeyConfig.class), + randomFrom(SslVerificationMode.values()), + SslClientAuthenticationMode.REQUIRED, + List.of("TLS_AES_256_GCM_SHA384"), + List.of("TLSv1.3"), + randomLongBetween(1, 100000) + ) + ); + + final SSLService sslService = mock(SSLService.class); + when(sslService.profile("xpack.security.transport.ssl.")).thenReturn(profile); + + doThrow(new AssertionError("profile filters should not be configured for remote cluster server when the port is disabled")).when( + sslService + ).profile("xpack.security.remote_cluster_server.ssl."); + doThrow(new AssertionError("profile filters should not be configured for remote cluster client")).when(sslService) + .profile("xpack.security.remote_cluster_client.ssl."); + + final var securityServerTransportInterceptor = new SecurityServerTransportInterceptor( + builder.build(), + threadPool, + mock(AuthenticationService.class), + mock(AuthorizationService.class), + sslService, + securityContext, + destructiveOperations, + new CrossClusterAccessTransportInterceptor( + builder.build(), + threadPool, + mock(AuthenticationService.class), + mock(AuthorizationService.class), + securityContext, + mock(CrossClusterAccessAuthenticationService.class), + crossClusterApiKeySignatureManager, + mockLicenseState + ) + ); + + final Map profileFilters = securityServerTransportInterceptor.getProfileFilters(); + assertThat(profileFilters.keySet(), contains("default")); + assertThat(profileFilters.get("default").isExtractClientCert(), is(transportSslEnabled)); + } + + public void testGetRemoteProfileTransportFilter() { + final boolean remoteClusterSslEnabled = randomBoolean(); + final Settings.Builder builder = Settings.builder() + .put(this.settings) + .put("remote_cluster_server.enabled", true) + .put("xpack.security.remote_cluster_server.ssl.enabled", remoteClusterSslEnabled); + if (randomBoolean()) { + builder.put("xpack.security.remote_cluster_client.ssl.enabled", randomBoolean()); // client SSL won't be processed + } + + final SslProfile remoteProfile = mock(SslProfile.class); + when(remoteProfile.configuration()).thenReturn( + new SslConfiguration( + "xpack.security.remote_cluster_server.ssl", + randomBoolean(), + mock(SslTrustConfig.class), + mock(SslKeyConfig.class), + randomFrom(SslVerificationMode.values()), + SslClientAuthenticationMode.NONE, + List.of(Runtime.version().feature() < 24 ? "TLS_RSA_WITH_AES_256_GCM_SHA384" : "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"), + List.of("TLSv1.2"), + randomLongBetween(1, 100000) + ) + ); + + final SSLService sslService = mock(SSLService.class); + when(sslService.profile("xpack.security.remote_cluster_server.ssl.")).thenReturn(remoteProfile); + doThrow(new AssertionError("profile filters should not be configured for remote cluster client")).when(sslService) + .profile("xpack.security.remote_cluster_client.ssl."); + + final AuthenticationService authcService = mock(AuthenticationService.class); + final AuthorizationService authzService = mock(AuthorizationService.class); + CrossClusterAccessTransportInterceptor interceptor = new CrossClusterAccessTransportInterceptor( + builder.build(), + threadPool, + authcService, + authzService, + securityContext, + mock(CrossClusterAccessAuthenticationService.class), + crossClusterApiKeySignatureManager, + mockLicenseState + ); + + final Optional remoteProfileTransportFilter = interceptor.getRemoteProfileTransportFilter( + remoteProfile, + destructiveOperations + ); + assertThat(remoteProfileTransportFilter.isPresent(), is(true)); + assertThat(remoteProfileTransportFilter.get(), instanceOf(CrossClusterAccessServerTransportFilter.class)); + } + + public void testGetRemoteProfileTransportFilterWhenRemoteClusterServerIsDisabled() { + final boolean remoteClusterSslEnabled = randomBoolean(); + final Settings.Builder builder = Settings.builder() + .put(this.settings) + .put("remote_cluster_server.enabled", false) + .put("xpack.security.remote_cluster_server.ssl.enabled", remoteClusterSslEnabled); + if (randomBoolean()) { + builder.put("xpack.security.remote_cluster_client.ssl.enabled", randomBoolean()); // client SSL won't be processed + } + + final SslProfile remoteProfile = mock(SslProfile.class); + when(remoteProfile.configuration()).thenReturn( + new SslConfiguration( + "xpack.security.remote_cluster_server.ssl", + randomBoolean(), + mock(SslTrustConfig.class), + mock(SslKeyConfig.class), + randomFrom(SslVerificationMode.values()), + SslClientAuthenticationMode.NONE, + List.of(Runtime.version().feature() < 24 ? "TLS_RSA_WITH_AES_256_GCM_SHA384" : "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"), + List.of("TLSv1.2"), + randomLongBetween(1, 100000) + ) + ); + + final SSLService sslService = mock(SSLService.class); + when(sslService.profile("xpack.security.remote_cluster_server.ssl.")).thenReturn(remoteProfile); + doThrow(new AssertionError("profile filters should not be configured for remote cluster client")).when(sslService) + .profile("xpack.security.remote_cluster_client.ssl."); + + final AuthenticationService authcService = mock(AuthenticationService.class); + final AuthorizationService authzService = mock(AuthorizationService.class); + CrossClusterAccessTransportInterceptor interceptor = new CrossClusterAccessTransportInterceptor( + builder.build(), + threadPool, + authcService, + authzService, + securityContext, + mock(CrossClusterAccessAuthenticationService.class), + crossClusterApiKeySignatureManager, + mockLicenseState + ); + + final Optional remoteProfileTransportFilter = interceptor.getRemoteProfileTransportFilter( + remoteProfile, + destructiveOperations + ); + assertThat(remoteProfileTransportFilter.isPresent(), is(false)); + } + +} diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptorTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptorTests.java index 3bf995edee9eb..8b3c4b2059dbc 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptorTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptorTests.java @@ -6,37 +6,20 @@ */ package org.elasticsearch.xpack.security.transport; -import org.elasticsearch.ElasticsearchSecurityException; import org.elasticsearch.TransportVersion; -import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionResponse.Empty; -import org.elasticsearch.action.admin.cluster.state.ClusterStateAction; -import org.elasticsearch.action.admin.cluster.state.ClusterStateRequest; import org.elasticsearch.action.admin.indices.delete.DeleteIndexRequest; import org.elasticsearch.action.admin.indices.delete.TransportDeleteIndexAction; import org.elasticsearch.action.support.DestructiveOperations; import org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.settings.ClusterSettings; -import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.ssl.DefaultJdkTrustConfig; -import org.elasticsearch.common.ssl.EmptyKeyConfig; -import org.elasticsearch.common.ssl.SslClientAuthenticationMode; -import org.elasticsearch.common.ssl.SslConfiguration; -import org.elasticsearch.common.ssl.SslKeyConfig; -import org.elasticsearch.common.ssl.SslTrustConfig; -import org.elasticsearch.common.ssl.SslVerificationMode; import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.license.MockLicenseState; import org.elasticsearch.tasks.Task; import org.elasticsearch.test.ClusterServiceUtils; -import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.TransportVersionUtils; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.transport.RemoteConnectionManager.RemoteClusterAliasWithCredentials; -import org.elasticsearch.transport.SendRequestTransportException; import org.elasticsearch.transport.Transport; import org.elasticsearch.transport.Transport.Connection; import org.elasticsearch.transport.TransportChannel; @@ -51,75 +34,48 @@ import org.elasticsearch.xpack.core.security.authc.Authentication; import org.elasticsearch.xpack.core.security.authc.Authentication.RealmRef; import org.elasticsearch.xpack.core.security.authc.AuthenticationTestHelper; -import org.elasticsearch.xpack.core.security.authc.CrossClusterAccessSubjectInfo; import org.elasticsearch.xpack.core.security.authz.AuthorizationServiceField; -import org.elasticsearch.xpack.core.security.authz.RoleDescriptorsIntersection; import org.elasticsearch.xpack.core.security.user.InternalUser; import org.elasticsearch.xpack.core.security.user.InternalUsers; -import org.elasticsearch.xpack.core.security.user.SystemUser; import org.elasticsearch.xpack.core.security.user.User; -import org.elasticsearch.xpack.core.ssl.SSLService; import org.elasticsearch.xpack.core.ssl.SslProfile; -import org.elasticsearch.xpack.security.Security; -import org.elasticsearch.xpack.security.audit.AuditUtil; -import org.elasticsearch.xpack.security.authc.ApiKeyService; import org.elasticsearch.xpack.security.authc.AuthenticationService; -import org.elasticsearch.xpack.security.authc.CrossClusterAccessAuthenticationService; import org.elasticsearch.xpack.security.authz.AuthorizationService; import org.junit.After; -import org.mockito.ArgumentCaptor; -import org.mockito.Mockito; import java.io.IOException; import java.util.Collections; -import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; -import java.util.function.Function; import static org.elasticsearch.cluster.metadata.DataStreamLifecycle.DATA_STREAM_LIFECYCLE_ORIGIN; -import static org.elasticsearch.test.ActionListenerUtils.anyActionListener; import static org.elasticsearch.xpack.core.ClientHelper.ASYNC_SEARCH_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.SECURITY_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.SECURITY_PROFILE_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.TRANSFORM_ORIGIN; -import static org.elasticsearch.xpack.core.security.authc.CrossClusterAccessSubjectInfo.CROSS_CLUSTER_ACCESS_SUBJECT_INFO_HEADER_KEY; -import static org.elasticsearch.xpack.core.security.authz.RoleDescriptorTestHelper.randomUniquelyNamedRoleDescriptors; -import static org.elasticsearch.xpack.security.authc.CrossClusterAccessHeaders.CROSS_CLUSTER_ACCESS_CREDENTIALS_HEADER_KEY; -import static org.hamcrest.Matchers.contains; -import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.notNullValue; -import static org.hamcrest.Matchers.nullValue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -public class SecurityServerTransportInterceptorTests extends ESTestCase { +public class SecurityServerTransportInterceptorTests extends AbstractServerTransportInterceptorTests { private Settings settings; private ThreadPool threadPool; private ThreadContext threadContext; private SecurityContext securityContext; private ClusterService clusterService; - private MockLicenseState mockLicenseState; private DestructiveOperations destructiveOperations; - private CrossClusterApiKeySignatureManager crossClusterApiKeySignatureManager; @Override public void setUp() throws Exception { @@ -129,13 +85,10 @@ public void setUp() throws Exception { clusterService = ClusterServiceUtils.createClusterService(threadPool); threadContext = threadPool.getThreadContext(); securityContext = spy(new SecurityContext(settings, threadPool.getThreadContext())); - mockLicenseState = MockLicenseState.createMock(); - Mockito.when(mockLicenseState.isAllowed(Security.ADVANCED_REMOTE_CLUSTER_SECURITY_FEATURE)).thenReturn(true); destructiveOperations = new DestructiveOperations( Settings.EMPTY, new ClusterSettings(Settings.EMPTY, Collections.singleton(DestructiveOperations.REQUIRES_NAME_SETTING)) ); - crossClusterApiKeySignatureManager = mock(CrossClusterApiKeySignatureManager.class); } @After @@ -151,22 +104,16 @@ public void testSendAsync() throws Exception { .realmRef(new RealmRef("ldap", "foo", "node1")) .build(false); authentication.writeToContext(threadContext); + SecurityServerTransportInterceptor interceptor = new SecurityServerTransportInterceptor( settings, threadPool, + mock(AuthenticationService.class), + mock(AuthorizationService.class), mockSslService(), securityContext, destructiveOperations, - new CrossClusterAccessTransportInterceptor( - settings, - threadPool, - mock(AuthenticationService.class), - mock(AuthorizationService.class), - securityContext, - mock(CrossClusterAccessAuthenticationService.class), - crossClusterApiKeySignatureManager, - mockLicenseState - ) + new TestNoopRemoteClusterTransportInterceptor() ); ClusterServiceUtils.setState(clusterService, clusterService.state()); // force state update to trigger listener @@ -204,23 +151,15 @@ public void testSendAsyncSwitchToSystem() throws Exception { .build(false); authentication.writeToContext(threadContext); AuthorizationServiceField.ORIGINATING_ACTION_VALUE.set(threadContext, "indices:foo"); - SecurityServerTransportInterceptor interceptor = new SecurityServerTransportInterceptor( settings, threadPool, + mock(AuthenticationService.class), + mock(AuthorizationService.class), mockSslService(), securityContext, destructiveOperations, - new CrossClusterAccessTransportInterceptor( - settings, - threadPool, - mock(AuthenticationService.class), - mock(AuthorizationService.class), - securityContext, - mock(CrossClusterAccessAuthenticationService.class), - crossClusterApiKeySignatureManager, - mockLicenseState - ) + new TestNoopRemoteClusterTransportInterceptor() ); ClusterServiceUtils.setState(clusterService, clusterService.state()); // force state update to trigger listener @@ -255,19 +194,12 @@ public void testSendWithoutUser() throws Exception { SecurityServerTransportInterceptor interceptor = new SecurityServerTransportInterceptor( settings, threadPool, + mock(AuthenticationService.class), + mock(AuthorizationService.class), mockSslService(), securityContext, destructiveOperations, - new CrossClusterAccessTransportInterceptor( - settings, - threadPool, - mock(AuthenticationService.class), - mock(AuthorizationService.class), - securityContext, - mock(CrossClusterAccessAuthenticationService.class), - crossClusterApiKeySignatureManager, - mockLicenseState - ) + new TestNoopRemoteClusterTransportInterceptor() ) { @Override void assertNoAuthentication(String action) {} @@ -320,19 +252,12 @@ public void testSendToNewerVersionSetsCorrectVersion() throws Exception { SecurityServerTransportInterceptor interceptor = new SecurityServerTransportInterceptor( settings, threadPool, + mock(AuthenticationService.class), + mock(AuthorizationService.class), mockSslService(), securityContext, destructiveOperations, - new CrossClusterAccessTransportInterceptor( - settings, - threadPool, - mock(AuthenticationService.class), - mock(AuthorizationService.class), - securityContext, - mock(CrossClusterAccessAuthenticationService.class), - crossClusterApiKeySignatureManager, - mockLicenseState - ) + new TestNoopRemoteClusterTransportInterceptor() ); ClusterServiceUtils.setState(clusterService, clusterService.state()); // force state update to trigger listener @@ -391,19 +316,12 @@ public void testSendToOlderVersionSetsCorrectVersion() throws Exception { SecurityServerTransportInterceptor interceptor = new SecurityServerTransportInterceptor( settings, threadPool, + mock(AuthenticationService.class), + mock(AuthorizationService.class), mockSslService(), securityContext, destructiveOperations, - new CrossClusterAccessTransportInterceptor( - settings, - threadPool, - mock(AuthenticationService.class), - mock(AuthorizationService.class), - securityContext, - mock(CrossClusterAccessAuthenticationService.class), - crossClusterApiKeySignatureManager, - mockLicenseState - ) + new TestNoopRemoteClusterTransportInterceptor() ); ClusterServiceUtils.setState(clusterService, clusterService.state()); // force state update to trigger listener @@ -460,19 +378,12 @@ public void testSetUserBasedOnActionOrigin() { SecurityServerTransportInterceptor interceptor = new SecurityServerTransportInterceptor( settings, threadPool, + mock(AuthenticationService.class), + mock(AuthorizationService.class), mockSslService(), securityContext, destructiveOperations, - new CrossClusterAccessTransportInterceptor( - settings, - threadPool, - mock(AuthenticationService.class), - mock(AuthorizationService.class), - securityContext, - mock(CrossClusterAccessAuthenticationService.class), - crossClusterApiKeySignatureManager, - mockLicenseState - ) + new TestNoopRemoteClusterTransportInterceptor() ); final AtomicBoolean calledWrappedSender = new AtomicBoolean(false); @@ -620,594 +531,30 @@ public boolean decRef() { assertTrue(exceptionSent.get()); } - public void testSendWithCrossClusterAccessHeadersWithUnsupportedLicense() throws Exception { - final MockLicenseState unsupportedLicenseState = MockLicenseState.createMock(); - Mockito.when(unsupportedLicenseState.isAllowed(Security.ADVANCED_REMOTE_CLUSTER_SECURITY_FEATURE)).thenReturn(false); - - AuthenticationTestHelper.builder().build().writeToContext(threadContext); - final String remoteClusterAlias = randomAlphaOfLengthBetween(5, 10); - - final SecurityServerTransportInterceptor interceptor = new SecurityServerTransportInterceptor( - settings, - threadPool, - mockSslService(), - securityContext, - destructiveOperations, - new CrossClusterAccessTransportInterceptor( - settings, - threadPool, - mock(AuthenticationService.class), - mock(AuthorizationService.class), - securityContext, - mock(CrossClusterAccessAuthenticationService.class), - crossClusterApiKeySignatureManager, - unsupportedLicenseState, - mockRemoteClusterCredentialsResolver(remoteClusterAlias) - ) - ); - - final AsyncSender sender = interceptor.interceptSender(mock(AsyncSender.class, ignored -> { - throw new AssertionError("sender should not be called"); - })); - final Transport.Connection connection = mock(Transport.Connection.class); - when(connection.getTransportVersion()).thenReturn(TransportVersion.current()); - final AtomicBoolean calledHandleException = new AtomicBoolean(false); - final AtomicReference actualException = new AtomicReference<>(); - sender.sendRequest(connection, "action", mock(TransportRequest.class), null, new TransportResponseHandler<>() { - @Override - public Executor executor() { - return TransportResponseHandler.TRANSPORT_WORKER; - } - - @Override - public void handleResponse(TransportResponse response) { - fail("should not receive a response"); - } - - @Override - public void handleException(TransportException exp) { - if (calledHandleException.compareAndSet(false, true) == false) { - fail("handle exception called more than once"); - } - actualException.set(exp); - } - - @Override - public TransportResponse read(StreamInput in) { - fail("should not receive a response"); - return null; - } - }); - assertThat(actualException.get(), instanceOf(SendRequestTransportException.class)); - assertThat(actualException.get().getCause(), instanceOf(ElasticsearchSecurityException.class)); - assertThat( - actualException.get().getCause().getMessage(), - equalTo("current license is non-compliant for [" + Security.ADVANCED_REMOTE_CLUSTER_SECURITY_FEATURE.getName() + "]") - ); - assertThat(securityContext.getThreadContext().getHeader(CROSS_CLUSTER_ACCESS_SUBJECT_INFO_HEADER_KEY), nullValue()); - assertThat(securityContext.getThreadContext().getHeader(CROSS_CLUSTER_ACCESS_CREDENTIALS_HEADER_KEY), nullValue()); - } - - private Function> mockRemoteClusterCredentialsResolver( - String remoteClusterAlias - ) { - return connection -> Optional.of( - new RemoteClusterAliasWithCredentials(remoteClusterAlias, new SecureString(randomAlphaOfLengthBetween(10, 42).toCharArray())) - ); - } + private static class TestNoopRemoteClusterTransportInterceptor implements RemoteClusterTransportInterceptor { - public void testSendWithCrossClusterAccessHeadersForSystemUserRegularAction() throws Exception { - final String action; - final TransportRequest request; - if (randomBoolean()) { - action = randomAlphaOfLengthBetween(5, 30); - request = mock(TransportRequest.class); - } else { - action = ClusterStateAction.NAME; - request = mock(ClusterStateRequest.class); + @Override + public AsyncSender interceptSender(AsyncSender sender) { + return sender; } - doTestSendWithCrossClusterAccessHeaders( - true, - action, - request, - AuthenticationTestHelper.builder().internal(InternalUsers.SYSTEM_USER).build() - ); - } - - public void testSendWithCrossClusterAccessHeadersForSystemUserCcrInternalAction() throws Exception { - final String action = randomFrom( - "internal:admin/ccr/restore/session/put", - "internal:admin/ccr/restore/session/clear", - "internal:admin/ccr/restore/file_chunk/get" - ); - final TransportRequest request = mock(TransportRequest.class); - doTestSendWithCrossClusterAccessHeaders( - true, - action, - request, - AuthenticationTestHelper.builder().internal(InternalUsers.SYSTEM_USER).build() - ); - } - - public void testSendWithCrossClusterAccessHeadersForRegularUserRegularAction() throws Exception { - final Authentication authentication = randomValueOtherThanMany( - authc -> authc.getAuthenticationType() == Authentication.AuthenticationType.INTERNAL, - () -> AuthenticationTestHelper.builder().build() - ); - final String action = randomAlphaOfLengthBetween(5, 30); - final TransportRequest request = mock(TransportRequest.class); - doTestSendWithCrossClusterAccessHeaders(false, action, request, authentication); - } - - public void testSendWithCrossClusterAccessHeadersForRegularUserClusterStateAction() throws Exception { - final Authentication authentication = randomValueOtherThanMany( - authc -> authc.getAuthenticationType() == Authentication.AuthenticationType.INTERNAL, - () -> AuthenticationTestHelper.builder().build() - ); - final String action = ClusterStateAction.NAME; - final TransportRequest request = mock(ClusterStateRequest.class); - doTestSendWithCrossClusterAccessHeaders(true, action, request, authentication); - } - - private void doTestSendWithCrossClusterAccessHeaders( - boolean shouldAssertForSystemUser, - String action, - TransportRequest request, - Authentication authentication - ) throws IOException { - authentication.writeToContext(threadContext); - final String expectedRequestId = AuditUtil.getOrGenerateRequestId(threadContext); - final String remoteClusterAlias = randomAlphaOfLengthBetween(5, 10); - final String encodedApiKey = randomAlphaOfLengthBetween(10, 42); - final String remoteClusterCredential = ApiKeyService.withApiKeyPrefix(encodedApiKey); - final AuthorizationService authzService = mock(AuthorizationService.class); - // We capture the listener so that we can complete the full flow, by calling onResponse further down - @SuppressWarnings("unchecked") - final ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(ActionListener.class); - doAnswer(i -> null).when(authzService) - .getRoleDescriptorsIntersectionForRemoteCluster(any(), any(), any(), listenerCaptor.capture()); - final SecurityServerTransportInterceptor interceptor = new SecurityServerTransportInterceptor( - settings, - threadPool, - mockSslService(), - securityContext, - destructiveOperations, - new CrossClusterAccessTransportInterceptor( - settings, - threadPool, - mock(AuthenticationService.class), - authzService, - securityContext, - mock(CrossClusterAccessAuthenticationService.class), - crossClusterApiKeySignatureManager, - mockLicenseState, - ignored -> Optional.of( - new RemoteClusterAliasWithCredentials(remoteClusterAlias, new SecureString(encodedApiKey.toCharArray())) - ) - ) - ); - - final AtomicBoolean calledWrappedSender = new AtomicBoolean(false); - final AtomicReference sentAction = new AtomicReference<>(); - final AtomicReference sentCredential = new AtomicReference<>(); - final AtomicReference sentCrossClusterAccessSubjectInfo = new AtomicReference<>(); - final AsyncSender sender = interceptor.interceptSender(new AsyncSender() { - @Override - public void sendRequest( - Connection connection, - String action, - TransportRequest request, - TransportRequestOptions options, - TransportResponseHandler handler - ) { - if (calledWrappedSender.compareAndSet(false, true) == false) { - fail("sender called more than once"); - } - assertThat(securityContext.getAuthentication(), nullValue()); - assertThat(AuditUtil.extractRequestId(securityContext.getThreadContext()), equalTo(expectedRequestId)); - sentAction.set(action); - sentCredential.set(securityContext.getThreadContext().getHeader(CROSS_CLUSTER_ACCESS_CREDENTIALS_HEADER_KEY)); - try { - sentCrossClusterAccessSubjectInfo.set( - CrossClusterAccessSubjectInfo.readFromContext(securityContext.getThreadContext()) - ); - } catch (IOException e) { - fail("no exceptions expected but got " + e); - } - handler.handleResponse(null); - } - }); - final Connection connection = mock(Connection.class); - when(connection.getTransportVersion()).thenReturn(TransportVersion.current()); - - sender.sendRequest(connection, action, request, null, new TransportResponseHandler<>() { - @Override - public Executor executor() { - return TransportResponseHandler.TRANSPORT_WORKER; - } - - @Override - public void handleResponse(TransportResponse response) { - // Headers should get restored before handle response is called - assertThat(securityContext.getAuthentication(), equalTo(authentication)); - assertThat(securityContext.getThreadContext().getHeader(CROSS_CLUSTER_ACCESS_SUBJECT_INFO_HEADER_KEY), nullValue()); - assertThat(securityContext.getThreadContext().getHeader(CROSS_CLUSTER_ACCESS_CREDENTIALS_HEADER_KEY), nullValue()); - } - - @Override - public void handleException(TransportException exp) { - fail("no exceptions expected but got " + exp); - } - - @Override - public TransportResponse read(StreamInput in) { - return null; - } - }); - if (shouldAssertForSystemUser) { - assertThat( - sentCrossClusterAccessSubjectInfo.get(), - equalTo( - SystemUser.crossClusterAccessSubjectInfo( - authentication.getEffectiveSubject().getTransportVersion(), - authentication.getEffectiveSubject().getRealm().getNodeName() - ) - ) - ); - verify(authzService, never()).getRoleDescriptorsIntersectionForRemoteCluster( - eq(remoteClusterAlias), - eq(TransportVersion.current()), - eq(authentication.getEffectiveSubject()), - anyActionListener() - ); - } else { - final RoleDescriptorsIntersection expectedRoleDescriptorsIntersection = new RoleDescriptorsIntersection( - randomList(1, 3, () -> Set.copyOf(randomUniquelyNamedRoleDescriptors(0, 1))) - ); - // Call listener to complete flow - listenerCaptor.getValue().onResponse(expectedRoleDescriptorsIntersection); - verify(authzService, times(1)).getRoleDescriptorsIntersectionForRemoteCluster( - eq(remoteClusterAlias), - eq(TransportVersion.current()), - eq(authentication.getEffectiveSubject()), - anyActionListener() - ); - assertThat( - sentCrossClusterAccessSubjectInfo.get(), - equalTo(new CrossClusterAccessSubjectInfo(authentication, expectedRoleDescriptorsIntersection)) - ); - } - assertTrue(calledWrappedSender.get()); - if (action.startsWith("internal:")) { - assertThat(sentAction.get(), equalTo("indices:internal/" + action.substring("internal:".length()))); - } else { - assertThat(sentAction.get(), equalTo(action)); + @Override + public boolean isRemoteClusterConnection(Connection connection) { + return false; } - assertThat(sentCredential.get(), equalTo(remoteClusterCredential)); - verify(securityContext, never()).executeAsInternalUser(any(), any(), anyConsumer()); - assertThat(securityContext.getThreadContext().getHeader(CROSS_CLUSTER_ACCESS_SUBJECT_INFO_HEADER_KEY), nullValue()); - assertThat(securityContext.getThreadContext().getHeader(CROSS_CLUSTER_ACCESS_CREDENTIALS_HEADER_KEY), nullValue()); - assertThat(AuditUtil.extractRequestId(securityContext.getThreadContext()), equalTo(expectedRequestId)); - } - public void testSendWithUserIfCrossClusterAccessHeadersConditionNotMet() throws Exception { - boolean noCredential = randomBoolean(); - final boolean notRemoteConnection = randomBoolean(); - // Ensure at least one condition fails - if (false == (notRemoteConnection || noCredential)) { - noCredential = true; - } - final boolean finalNoCredential = noCredential; - final String remoteClusterAlias = randomAlphaOfLengthBetween(5, 10); - final String encodedApiKey = randomAlphaOfLengthBetween(10, 42); - final AuthenticationTestHelper.AuthenticationTestBuilder builder = AuthenticationTestHelper.builder(); - final Authentication authentication = randomFrom( - builder.apiKey().build(), - builder.serviceAccount().build(), - builder.user(new User(randomAlphaOfLengthBetween(3, 10), randomRoles())).realm().build() - ); - authentication.writeToContext(threadContext); - - final AuthorizationService authzService = mock(AuthorizationService.class); - final SecurityServerTransportInterceptor interceptor = new SecurityServerTransportInterceptor( - settings, - threadPool, - mockSslService(), - securityContext, - destructiveOperations, - new CrossClusterAccessTransportInterceptor( - settings, - threadPool, - mock(AuthenticationService.class), - authzService, - securityContext, - mock(CrossClusterAccessAuthenticationService.class), - crossClusterApiKeySignatureManager, - mockLicenseState, - ignored -> notRemoteConnection - ? Optional.empty() - : (finalNoCredential - ? Optional.of(new RemoteClusterAliasWithCredentials(remoteClusterAlias, null)) - : Optional.of( - new RemoteClusterAliasWithCredentials(remoteClusterAlias, new SecureString(encodedApiKey.toCharArray())) - )) - ) - ); - - final AtomicBoolean calledWrappedSender = new AtomicBoolean(false); - final AtomicReference sentAuthentication = new AtomicReference<>(); - final AsyncSender sender = interceptor.interceptSender(new AsyncSender() { - @Override - public void sendRequest( - Transport.Connection connection, - String action, - TransportRequest request, - TransportRequestOptions options, - TransportResponseHandler handler - ) { - if (calledWrappedSender.compareAndSet(false, true) == false) { - fail("sender called more than once"); - } - sentAuthentication.set(securityContext.getAuthentication()); - assertThat(securityContext.getThreadContext().getHeader(CROSS_CLUSTER_ACCESS_CREDENTIALS_HEADER_KEY), nullValue()); - assertThat(securityContext.getThreadContext().getHeader(CROSS_CLUSTER_ACCESS_SUBJECT_INFO_HEADER_KEY), nullValue()); - } - }); - final Transport.Connection connection = mock(Transport.Connection.class); - when(connection.getTransportVersion()).thenReturn(TransportVersion.current()); - sender.sendRequest(connection, "action", mock(TransportRequest.class), null, null); - assertTrue(calledWrappedSender.get()); - assertThat(sentAuthentication.get(), equalTo(authentication)); - verify(authzService, never()).getRoleDescriptorsIntersectionForRemoteCluster(any(), any(), any(), anyActionListener()); - assertThat(securityContext.getThreadContext().getHeader(CROSS_CLUSTER_ACCESS_SUBJECT_INFO_HEADER_KEY), nullValue()); - assertThat(securityContext.getThreadContext().getHeader(CROSS_CLUSTER_ACCESS_CREDENTIALS_HEADER_KEY), nullValue()); - } - - public void testSendRemoteRequestFailsIfUserHasNoRemoteIndicesPrivileges() throws Exception { - final Authentication authentication = AuthenticationTestHelper.builder() - .user(new User(randomAlphaOfLengthBetween(3, 10), randomRoles())) - .realm() - .build(); - authentication.writeToContext(threadContext); - final String remoteClusterAlias = randomAlphaOfLengthBetween(5, 10); - final String encodedApiKey = randomAlphaOfLengthBetween(10, 42); - final String remoteClusterCredential = ApiKeyService.withApiKeyPrefix(encodedApiKey); - final AuthorizationService authzService = mock(AuthorizationService.class); - - doAnswer(invocation -> { - @SuppressWarnings("unchecked") - final var listener = (ActionListener) invocation.getArgument(3); - listener.onResponse(RoleDescriptorsIntersection.EMPTY); - return null; - }).when(authzService).getRoleDescriptorsIntersectionForRemoteCluster(any(), any(), any(), anyActionListener()); - - final SecurityServerTransportInterceptor interceptor = new SecurityServerTransportInterceptor( - settings, - threadPool, - mockSslService(), - securityContext, - destructiveOperations, - new CrossClusterAccessTransportInterceptor( - settings, - threadPool, - mock(AuthenticationService.class), - authzService, - securityContext, - mock(CrossClusterAccessAuthenticationService.class), - crossClusterApiKeySignatureManager, - mockLicenseState, - ignored -> Optional.of( - new RemoteClusterAliasWithCredentials(remoteClusterAlias, new SecureString(encodedApiKey.toCharArray())) - ) - ) - ); - - final AsyncSender sender = interceptor.interceptSender(new AsyncSender() { - @Override - public void sendRequest( - Transport.Connection connection, - String action, - TransportRequest request, - TransportRequestOptions options, - TransportResponseHandler handler - ) { - fail("request should have failed"); - } - }); - final Transport.Connection connection = mock(Transport.Connection.class); - when(connection.getTransportVersion()).thenReturn(TransportVersion.current()); - - final ElasticsearchSecurityException expectedException = new ElasticsearchSecurityException("remote action denied"); - when(authzService.remoteActionDenied(authentication, "action", remoteClusterAlias)).thenReturn(expectedException); - - final var actualException = new AtomicReference(); - sender.sendRequest(connection, "action", mock(TransportRequest.class), null, new TransportResponseHandler<>() { - @Override - public Executor executor() { - return TransportResponseHandler.TRANSPORT_WORKER; - } - - @Override - public void handleResponse(TransportResponse response) { - fail("should not success"); - } - - @Override - public void handleException(TransportException exp) { - actualException.set(exp.getCause()); - } - - @Override - public TransportResponse read(StreamInput in) { - return null; - } - }); - assertThat(actualException.get(), is(expectedException)); - assertThat(securityContext.getThreadContext().getHeader(CROSS_CLUSTER_ACCESS_SUBJECT_INFO_HEADER_KEY), nullValue()); - assertThat(securityContext.getThreadContext().getHeader(CROSS_CLUSTER_ACCESS_CREDENTIALS_HEADER_KEY), nullValue()); - } - - public void testProfileFiltersCreatedDifferentlyForDifferentTransportAndRemoteClusterSslSettings() { - // filters are created irrespective of ssl enabled - final boolean transportSslEnabled = randomBoolean(); - final boolean remoteClusterSslEnabled = randomBoolean(); - final Settings.Builder builder = Settings.builder() - .put(this.settings) - .put("xpack.security.transport.ssl.enabled", transportSslEnabled) - .put("remote_cluster_server.enabled", true) - .put("xpack.security.remote_cluster_server.ssl.enabled", remoteClusterSslEnabled); - if (randomBoolean()) { - builder.put("xpack.security.remote_cluster_client.ssl.enabled", randomBoolean()); // client SSL won't be processed + @Override + public Optional getRemoteProfileTransportFilter( + SslProfile sslProfile, + DestructiveOperations destructiveOperations + ) { + return Optional.empty(); } - final SslProfile defaultProfile = mock(SslProfile.class); - when(defaultProfile.configuration()).thenReturn( - new SslConfiguration( - "xpack.security.transport.ssl", - randomBoolean(), - mock(SslTrustConfig.class), - mock(SslKeyConfig.class), - randomFrom(SslVerificationMode.values()), - SslClientAuthenticationMode.REQUIRED, - List.of("TLS_AES_256_GCM_SHA384"), - List.of("TLSv1.3"), - randomLongBetween(1, 100000) - ) - ); - final SslProfile remoteProfile = mock(SslProfile.class); - when(remoteProfile.configuration()).thenReturn( - new SslConfiguration( - "xpack.security.remote_cluster_server.ssl", - randomBoolean(), - mock(SslTrustConfig.class), - mock(SslKeyConfig.class), - randomFrom(SslVerificationMode.values()), - SslClientAuthenticationMode.NONE, - List.of(Runtime.version().feature() < 24 ? "TLS_RSA_WITH_AES_256_GCM_SHA384" : "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"), - List.of("TLSv1.2"), - randomLongBetween(1, 100000) - ) - ); - - final SSLService sslService = mock(SSLService.class); - when(sslService.profile("xpack.security.transport.ssl.")).thenReturn(defaultProfile); - - when(sslService.profile("xpack.security.remote_cluster_server.ssl.")).thenReturn(remoteProfile); - doThrow(new AssertionError("profile filters should not be configured for remote cluster client")).when(sslService) - .profile("xpack.security.remote_cluster_client.ssl."); - - final var securityServerTransportInterceptor = new SecurityServerTransportInterceptor( - builder.build(), - threadPool, - sslService, - securityContext, - destructiveOperations, - new CrossClusterAccessTransportInterceptor( - builder.build(), - threadPool, - mock(AuthenticationService.class), - mock(AuthorizationService.class), - securityContext, - mock(CrossClusterAccessAuthenticationService.class), - crossClusterApiKeySignatureManager, - mockLicenseState - ) - ); - - final Map profileFilters = securityServerTransportInterceptor.getProfileFilters(); - assertThat(profileFilters.keySet(), containsInAnyOrder("default", "_remote_cluster")); - assertThat(profileFilters.get("default").isExtractClientCert(), is(transportSslEnabled)); - assertThat(profileFilters.get("default"), not(instanceOf(CrossClusterAccessServerTransportFilter.class))); - assertThat(profileFilters.get("_remote_cluster").isExtractClientCert(), is(false)); - assertThat(profileFilters.get("_remote_cluster"), instanceOf(CrossClusterAccessServerTransportFilter.class)); - } - - public void testNoProfileFilterForRemoteClusterWhenTheFeatureIsDisabled() { - final boolean transportSslEnabled = randomBoolean(); - final Settings.Builder builder = Settings.builder() - .put(this.settings) - .put("xpack.security.transport.ssl.enabled", transportSslEnabled) - .put("remote_cluster_server.enabled", false) - .put("xpack.security.remote_cluster_server.ssl.enabled", randomBoolean()); - if (randomBoolean()) { - builder.put("xpack.security.remote_cluster_client.ssl.enabled", randomBoolean()); // client SSL won't be processed + @Override + public boolean hasRemoteClusterAccessHeadersInContext(SecurityContext securityContext) { + return false; } - - final SslProfile profile = mock(SslProfile.class); - when(profile.configuration()).thenReturn( - new SslConfiguration( - "xpack.security.transport.ssl", - randomBoolean(), - mock(SslTrustConfig.class), - mock(SslKeyConfig.class), - randomFrom(SslVerificationMode.values()), - SslClientAuthenticationMode.REQUIRED, - List.of("TLS_AES_256_GCM_SHA384"), - List.of("TLSv1.3"), - randomLongBetween(1, 100000) - ) - ); - - final SSLService sslService = mock(SSLService.class); - when(sslService.profile("xpack.security.transport.ssl.")).thenReturn(profile); - - doThrow(new AssertionError("profile filters should not be configured for remote cluster server when the port is disabled")).when( - sslService - ).profile("xpack.security.remote_cluster_server.ssl."); - doThrow(new AssertionError("profile filters should not be configured for remote cluster client")).when(sslService) - .profile("xpack.security.remote_cluster_client.ssl."); - - final var securityServerTransportInterceptor = new SecurityServerTransportInterceptor( - builder.build(), - threadPool, - sslService, - securityContext, - destructiveOperations, - new CrossClusterAccessTransportInterceptor( - builder.build(), - threadPool, - mock(AuthenticationService.class), - mock(AuthorizationService.class), - securityContext, - mock(CrossClusterAccessAuthenticationService.class), - crossClusterApiKeySignatureManager, - mockLicenseState - ) - ); - - final Map profileFilters = securityServerTransportInterceptor.getProfileFilters(); - assertThat(profileFilters.keySet(), contains("default")); - assertThat(profileFilters.get("default").isExtractClientCert(), is(transportSslEnabled)); - } - - private static SSLService mockSslService() { - final SslConfiguration defaultConfiguration = new SslConfiguration( - "", - false, - DefaultJdkTrustConfig.DEFAULT_INSTANCE, - EmptyKeyConfig.INSTANCE, - SslVerificationMode.FULL, - SslClientAuthenticationMode.NONE, - List.of("TLS_AES_256_GCM_SHA384"), - List.of("TLSv1.3"), - randomLongBetween(1, 100000) - ); - final SslProfile defaultProfile = mock(SslProfile.class); - when(defaultProfile.configuration()).thenReturn(defaultConfiguration); - final SSLService sslService = mock(SSLService.class); - when(sslService.profile("xpack.security.transport.ssl")).thenReturn(defaultProfile); - when(sslService.profile("xpack.security.transport.ssl.")).thenReturn(defaultProfile); - return sslService; - } - - private String[] randomRoles() { - return generateRandomStringArray(3, 10, false, true); - } - - @SuppressWarnings("unchecked") - private static Consumer anyConsumer() { - return any(Consumer.class); } } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/ServerTransportFilterTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/ServerTransportFilterTests.java index f6e5601c75c6a..59539ff30f9cb 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/ServerTransportFilterTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/ServerTransportFilterTests.java @@ -20,9 +20,6 @@ import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.license.MockLicenseState; -import org.elasticsearch.license.XPackLicenseState; -import org.elasticsearch.test.ESTestCase; import org.elasticsearch.transport.TransportChannel; import org.elasticsearch.transport.TransportRequest; import org.elasticsearch.transport.TransportSettings; @@ -31,25 +28,16 @@ import org.elasticsearch.xpack.core.security.authc.Authentication.RealmRef; import org.elasticsearch.xpack.core.security.authc.AuthenticationTestHelper; import org.elasticsearch.xpack.core.security.user.User; -import org.elasticsearch.xpack.security.Security; import org.elasticsearch.xpack.security.authc.AuthenticationService; -import org.elasticsearch.xpack.security.authc.CrossClusterAccessAuthenticationService; import org.elasticsearch.xpack.security.authz.AuthorizationService; import org.junit.Before; -import org.mockito.Mockito; -import org.mockito.stubbing.Answer; import java.util.Collections; -import java.util.Set; import static org.elasticsearch.test.ActionListenerUtils.anyActionListener; -import static org.elasticsearch.xpack.core.ClientHelper.SECURITY_HEADER_FILTERS; -import static org.elasticsearch.xpack.core.security.authc.CrossClusterAccessSubjectInfo.CROSS_CLUSTER_ACCESS_SUBJECT_INFO_HEADER_KEY; import static org.elasticsearch.xpack.core.security.support.Exceptions.authenticationError; import static org.elasticsearch.xpack.core.security.support.Exceptions.authorizationError; -import static org.elasticsearch.xpack.security.authc.CrossClusterAccessHeaders.CROSS_CLUSTER_ACCESS_CREDENTIALS_HEADER_KEY; import static org.hamcrest.Matchers.arrayWithSize; -import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; @@ -58,22 +46,18 @@ import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; -public class ServerTransportFilterTests extends ESTestCase { +public class ServerTransportFilterTests extends AbstractServerTransportFilterTests { private AuthenticationService authcService; private AuthorizationService authzService; private TransportChannel channel; private boolean failDestructiveOperations; private DestructiveOperations destructiveOperations; - private CrossClusterAccessAuthenticationService crossClusterAccessAuthcService; - private MockLicenseState mockLicenseState; @Before public void init() throws Exception { @@ -88,10 +72,6 @@ public void init() throws Exception { settings, new ClusterSettings(settings, Collections.singleton(DestructiveOperations.REQUIRES_NAME_SETTING)) ); - crossClusterAccessAuthcService = mock(CrossClusterAccessAuthenticationService.class); - when(crossClusterAccessAuthcService.getAuthenticationService()).thenReturn(authcService); - mockLicenseState = MockLicenseState.createMock(); - Mockito.when(mockLicenseState.isAllowed(Security.ADVANCED_REMOTE_CLUSTER_SECURITY_FEATURE)).thenReturn(true); } public void testInbound() { @@ -105,88 +85,6 @@ public void testInbound() { verify(authzService).authorize(eq(authentication), eq("_action"), eq(request), anyActionListener()); } - public void testCrossClusterAccessInbound() { - TransportRequest request = mock(TransportRequest.class); - Authentication authentication = AuthenticationTestHelper.builder().build(); - String action = randomAlphaOfLengthBetween(10, 20); - doAnswer(getAnswer(authentication)).when(authcService).authenticate(eq(action), eq(request), eq(true), anyActionListener()); - doAnswer(getAnswer(authentication, true)).when(crossClusterAccessAuthcService) - .authenticate(eq(action), eq(request), anyActionListener()); - ServerTransportFilter filter = getNodeCrossClusterAccessFilter(); - PlainActionFuture listener = spy(new PlainActionFuture<>()); - filter.inbound(action, request, channel, listener); - verify(authzService).authorize(eq(authentication), eq(action), eq(request), anyActionListener()); - verify(crossClusterAccessAuthcService).authenticate(anyString(), any(), anyActionListener()); - verify(authcService, never()).authenticate(anyString(), any(), anyBoolean(), anyActionListener()); - } - - public void testCrossClusterAccessInboundInvalidHeadersFail() { - TransportRequest request = mock(TransportRequest.class); - Authentication authentication = AuthenticationTestHelper.builder().build(); - String action = randomAlphaOfLengthBetween(10, 20); - doAnswer(getAnswer(authentication)).when(authcService).authenticate(eq(action), eq(request), eq(true), anyActionListener()); - doAnswer(getAnswer(authentication, true)).when(crossClusterAccessAuthcService) - .authenticate(eq(action), eq(request), anyActionListener()); - ServerTransportFilter filter = getNodeCrossClusterAccessFilter(Set.copyOf(randomNonEmptySubsetOf(SECURITY_HEADER_FILTERS))); - PlainActionFuture listener = new PlainActionFuture<>(); - filter.inbound(action, request, channel, listener); - var actual = expectThrows(IllegalArgumentException.class, listener::actionGet); - verifyNoMoreInteractions(authcService); - verifyNoMoreInteractions(authzService); - assertThat( - actual.getMessage(), - containsString("is not allowed for cross cluster requests through the dedicated remote cluster server port") - ); - verify(crossClusterAccessAuthcService, never()).authenticate(anyString(), any(), anyActionListener()); - } - - public void testCrossClusterAccessInboundMissingHeadersFail() { - TransportRequest request = mock(TransportRequest.class); - Authentication authentication = AuthenticationTestHelper.builder().build(); - String action = randomAlphaOfLengthBetween(10, 20); - doAnswer(getAnswer(authentication)).when(authcService).authenticate(eq(action), eq(request), eq(true), anyActionListener()); - doAnswer(getAnswer(authentication, true)).when(crossClusterAccessAuthcService) - .authenticate(eq(action), eq(request), anyActionListener()); - Settings settings = Settings.builder().put("path.home", createTempDir()).build(); - ThreadContext threadContext = new ThreadContext(settings); - String firstMissingHeader = CROSS_CLUSTER_ACCESS_CREDENTIALS_HEADER_KEY; - if (randomBoolean()) { - String headerToInclude = randomBoolean() - ? CROSS_CLUSTER_ACCESS_CREDENTIALS_HEADER_KEY - : CROSS_CLUSTER_ACCESS_SUBJECT_INFO_HEADER_KEY; - if (headerToInclude.equals(CROSS_CLUSTER_ACCESS_CREDENTIALS_HEADER_KEY)) { - firstMissingHeader = CROSS_CLUSTER_ACCESS_SUBJECT_INFO_HEADER_KEY; - } - threadContext.putHeader(headerToInclude, randomAlphaOfLength(42)); - } - ServerTransportFilter filter = new CrossClusterAccessServerTransportFilter( - crossClusterAccessAuthcService, - authzService, - threadContext, - false, - destructiveOperations, - new SecurityContext(settings, threadContext), - mockLicenseState - ); - - PlainActionFuture listener = new PlainActionFuture<>(); - filter.inbound(action, request, channel, listener); - var actual = expectThrows(IllegalArgumentException.class, listener::actionGet); - - verifyNoMoreInteractions(authcService); - verifyNoMoreInteractions(authzService); - assertThat( - actual.getMessage(), - equalTo( - "Cross cluster requests through the dedicated remote cluster server port require transport header [" - + firstMissingHeader - + "] but none found. " - + "Please ensure you have configured remote cluster credentials on the cluster originating the request." - ) - ); - verify(crossClusterAccessAuthcService, never()).authenticate(anyString(), any(), anyActionListener()); - } - public void testInboundDestructiveOperations() { String action = randomFrom(TransportCloseIndexAction.NAME, OpenIndexAction.NAME, TransportDeleteIndexAction.TYPE.name()); TransportRequest request = new MockIndicesRequest( @@ -229,49 +127,12 @@ public void testInboundAuthenticationException() { verifyNoMoreInteractions(authzService); } - public void testCrossClusterAccessInboundAuthenticationException() { - TransportRequest request = mock(TransportRequest.class); - Exception authE = authenticationError("authc failed"); - String action = randomAlphaOfLengthBetween(10, 20); - doAnswer(i -> { - final Object[] args = i.getArguments(); - assertThat(args, arrayWithSize(3)); - @SuppressWarnings("unchecked") - ActionListener callback = (ActionListener) args[args.length - 1]; - callback.onFailure(authE); - return Void.TYPE; - }).when(crossClusterAccessAuthcService).authenticate(eq(action), eq(request), anyActionListener()); - doAnswer(i -> { - final Object[] args = i.getArguments(); - assertThat(args, arrayWithSize(4)); - @SuppressWarnings("unchecked") - ActionListener callback = (ActionListener) args[args.length - 1]; - callback.onFailure(authE); - return Void.TYPE; - }).when(authcService).authenticate(eq(action), eq(request), eq(true), anyActionListener()); - ServerTransportFilter filter = getNodeCrossClusterAccessFilter(); - try { - PlainActionFuture future = new PlainActionFuture<>(); - filter.inbound(action, request, channel, future); - future.actionGet(); - fail("expected filter inbound to throw an authentication exception on authentication error"); - } catch (ElasticsearchSecurityException e) { - assertThat(e.getMessage(), equalTo("authc failed")); - } - verifyNoMoreInteractions(authzService); - verify(crossClusterAccessAuthcService).authenticate(anyString(), any(), anyActionListener()); - verify(authcService, never()).authenticate(anyString(), any(), anyBoolean(), anyActionListener()); - } - public void testInboundAuthorizationException() { - boolean crossClusterAccess = randomBoolean(); - ServerTransportFilter filter = crossClusterAccess ? getNodeCrossClusterAccessFilter() : getNodeFilter(); + ServerTransportFilter filter = getNodeFilter(); TransportRequest request = mock(TransportRequest.class); Authentication authentication = AuthenticationTestHelper.builder().build(); String action = TransportSearchAction.TYPE.name(); doAnswer(getAnswer(authentication)).when(authcService).authenticate(eq(action), eq(request), eq(true), anyActionListener()); - doAnswer(getAnswer(authentication, true)).when(crossClusterAccessAuthcService) - .authenticate(eq(action), eq(request), anyActionListener()); PlainActionFuture future = new PlainActionFuture<>(); doThrow(authorizationError("authz failed")).when(authzService) .authorize(eq(authentication), eq(action), eq(request), anyActionListener()); @@ -280,33 +141,7 @@ public void testInboundAuthorizationException() { future.actionGet(); }); assertThat(e.getMessage(), equalTo("authz failed")); - if (crossClusterAccess) { - verify(crossClusterAccessAuthcService).authenticate(anyString(), any(), anyActionListener()); - verify(authcService, never()).authenticate(anyString(), any(), anyBoolean(), anyActionListener()); - } else { - verify(authcService).authenticate(anyString(), any(), anyBoolean(), anyActionListener()); - verify(crossClusterAccessAuthcService, never()).authenticate(anyString(), any(), anyActionListener()); - } - } - - public void testCrossClusterAccessInboundFailsWithUnsupportedLicense() { - final MockLicenseState unsupportedLicenseState = MockLicenseState.createMock(); - Mockito.when(unsupportedLicenseState.isAllowed(Security.ADVANCED_REMOTE_CLUSTER_SECURITY_FEATURE)).thenReturn(false); - - ServerTransportFilter crossClusterAccessFilter = getNodeCrossClusterAccessFilter(unsupportedLicenseState); - PlainActionFuture listener = new PlainActionFuture<>(); - String action = randomAlphaOfLengthBetween(10, 20); - crossClusterAccessFilter.inbound(action, mock(TransportRequest.class), channel, listener); - - ElasticsearchSecurityException actualException = expectThrows(ElasticsearchSecurityException.class, listener::actionGet); - assertThat( - actualException.getMessage(), - equalTo("current license is non-compliant for [" + Security.ADVANCED_REMOTE_CLUSTER_SECURITY_FEATURE.getName() + "]") - ); - - // License check should be executed first, hence we don't expect authc/authz to be even attempted. - verify(crossClusterAccessAuthcService, never()).authenticate(anyString(), any(), anyActionListener()); - verifyNoInteractions(authzService, authcService); + verify(authcService).authenticate(anyString(), any(), anyBoolean(), anyActionListener()); } public void testAllowsNodeActions() { @@ -332,21 +167,6 @@ public void testAllowsNodeActions() { verifyNoMoreInteractions(authcService, authzService); } - private static Answer> getAnswer(Authentication authentication) { - return getAnswer(authentication, false); - } - - private static Answer> getAnswer(Authentication authentication, boolean crossClusterAccess) { - return i -> { - final Object[] args = i.getArguments(); - assertThat(args, arrayWithSize(crossClusterAccess ? 3 : 4)); - @SuppressWarnings("unchecked") - ActionListener callback = (ActionListener) args[args.length - 1]; - callback.onResponse(authentication); - return Void.TYPE; - }; - } - private ServerTransportFilter getNodeFilter() { Settings settings = Settings.builder().put("path.home", createTempDir()).build(); ThreadContext threadContext = new ThreadContext(settings); @@ -360,51 +180,4 @@ private ServerTransportFilter getNodeFilter() { ); } - private CrossClusterAccessServerTransportFilter getNodeCrossClusterAccessFilter() { - return getNodeCrossClusterAccessFilter(Collections.emptySet(), mockLicenseState); - } - - private CrossClusterAccessServerTransportFilter getNodeCrossClusterAccessFilter(Set additionalHeadersKeys) { - return getNodeCrossClusterAccessFilter(additionalHeadersKeys, mockLicenseState); - } - - private CrossClusterAccessServerTransportFilter getNodeCrossClusterAccessFilter(XPackLicenseState licenseState) { - return getNodeCrossClusterAccessFilter(Collections.emptySet(), licenseState); - } - - private CrossClusterAccessServerTransportFilter getNodeCrossClusterAccessFilter( - Set additionalHeadersKeys, - XPackLicenseState licenseState - ) { - Settings settings = Settings.builder().put("path.home", createTempDir()).build(); - ThreadContext threadContext = new ThreadContext(settings); - for (var header : additionalHeadersKeys) { - threadContext.putHeader(header, randomAlphaOfLength(20)); - } - // Randomly include valid headers - if (randomBoolean()) { - for (var validHeader : CrossClusterAccessServerTransportFilter.ALLOWED_TRANSPORT_HEADERS) { - // don't overwrite additionalHeadersKeys - if (false == additionalHeadersKeys.contains(validHeader)) { - threadContext.putHeader(validHeader, randomAlphaOfLength(20)); - } - } - } - var requiredHeaders = Set.of(CROSS_CLUSTER_ACCESS_CREDENTIALS_HEADER_KEY, CROSS_CLUSTER_ACCESS_SUBJECT_INFO_HEADER_KEY); - for (var header : requiredHeaders) { - // don't overwrite already present headers - if (threadContext.getHeader(header) == null) { - threadContext.putHeader(header, randomAlphaOfLength(20)); - } - } - return new CrossClusterAccessServerTransportFilter( - crossClusterAccessAuthcService, - authzService, - threadContext, - false, - destructiveOperations, - new SecurityContext(settings, threadContext), - licenseState - ); - } }