diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java index eb6049dc81367..ba7da3151d90a 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java @@ -56,6 +56,7 @@ import org.elasticsearch.common.util.Maps; import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.core.CheckedConsumer; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.core.TimeValue; @@ -70,6 +71,7 @@ import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; @@ -535,63 +537,72 @@ private void getTokenDocById( } final GetRequest getRequest = client.prepareGet(securityTokensIndex.aliasName(), getTokenDocumentId(tokenId)).request(); final Consumer onFailure = ex -> listener.onFailure(traceLog("get token from id", tokenId, ex)); + + CheckedConsumer checkedConsumer = response -> { + assert ThreadPool.assertCurrentThreadPool(ThreadPool.Names.GENERIC); + if (response.isExists() == false) { + // The chances of a random token string decoding to something that we can read is minimal, so + // we assume that this was a token we have created but is now expired/revoked and deleted + logger.trace("The token [{}] probably expired and has already been deleted", tokenId); + listener.onResponse(null); + return; + } + Map accessSource = (Map) response.getSource().get("access_token"); + Map refreshSource = (Map) response.getSource().get("refresh_token"); + boolean versionGetForRefresh = tokenVersion.onOrAfter(VERSION_GET_TOKEN_DOC_FOR_REFRESH); + if (accessSource == null) { + onFailure.accept(new IllegalStateException("token document is missing the access_token field")); + } else if (accessSource.containsKey("user_token") == false) { + onFailure.accept(new IllegalStateException("token document is missing the user_token field")); + } else if (versionGetForRefresh && accessSource.containsKey("token") == false) { + onFailure.accept(new IllegalStateException("token document is missing the user_token.token field")); + } else if (versionGetForRefresh && refreshSource != null && refreshSource.containsKey("token") == false) { + onFailure.accept(new IllegalStateException("token document is missing the refresh_token.token field")); + } else if (storedAccessToken != null && storedAccessToken.equals(accessSource.get("token")) == false) { + logger.error("The stored access token [{}] for token doc id [{}] could not be verified", storedAccessToken, tokenId); + listener.onResponse(null); + } else if (storedRefreshToken != null + && (refreshSource == null || storedRefreshToken.equals(refreshSource.get("token")) == false)) { + logger.error("The stored refresh token [{}] for token doc id [{}] could not be verified", storedRefreshToken, tokenId); + listener.onResponse(null); + } else { + listener.onResponse(new Doc(response)); + } + }; + + Consumer exceptionConsumer = e -> { + // if the index or the shard is not there / available we assume that + // the token is not valid + if (isShardNotAvailableException(e)) { + logger.warn("failed to get token doc [{}] because index [{}] is not available", tokenId, securityTokensIndex.aliasName()); + } else { + logger.error(() -> "failed to get token doc [" + tokenId + "]", e); + } + listener.onFailure(e); + }; + + // this wrapper handles a situation where the current flow (executing on the generic thread) + // finds itself unintentionally forking over to a transport_worker thread, which is a consequence of using + // client::get when calling executeAsyncWithOrigin (see NodeClient issue: https://github.com/elastic/elasticsearch/issues/97916). + // If you follow the implementation of client::get within NodeClient, you'll spot the execution happening on a transport thread. + // The wrapper below handles this situation by introducing a second fork to make sure post-processing of the GetResponse + // returns to the generic thread pool. + CheckedConsumer wrappedCheckedConsumer = resp -> client.threadPool().generic().execute(() -> { + try { + checkedConsumer.accept(resp); + } catch (Exception e) { + // reuse the exception consumer already defined + exceptionConsumer.accept(e); + } + }); + projectSecurityIndex.checkIndexVersionThenExecute( ex -> listener.onFailure(traceLog("prepare tokens index [" + securityTokensIndex.aliasName() + "]", tokenId, ex)), () -> executeAsyncWithOrigin( client.threadPool().getThreadContext(), SECURITY_ORIGIN, getRequest, - ActionListener.wrap(response -> { - if (response.isExists() == false) { - // The chances of a random token string decoding to something that we can read is minimal, so - // we assume that this was a token we have created but is now expired/revoked and deleted - logger.trace("The token [{}] probably expired and has already been deleted", tokenId); - listener.onResponse(null); - return; - } - Map accessSource = (Map) response.getSource().get("access_token"); - Map refreshSource = (Map) response.getSource().get("refresh_token"); - boolean versionGetForRefresh = tokenVersion.onOrAfter(VERSION_GET_TOKEN_DOC_FOR_REFRESH); - if (accessSource == null) { - onFailure.accept(new IllegalStateException("token document is missing the access_token field")); - } else if (accessSource.containsKey("user_token") == false) { - onFailure.accept(new IllegalStateException("token document is missing the user_token field")); - } else if (versionGetForRefresh && accessSource.containsKey("token") == false) { - onFailure.accept(new IllegalStateException("token document is missing the user_token.token field")); - } else if (versionGetForRefresh && refreshSource != null && refreshSource.containsKey("token") == false) { - onFailure.accept(new IllegalStateException("token document is missing the refresh_token.token field")); - } else if (storedAccessToken != null && storedAccessToken.equals(accessSource.get("token")) == false) { - logger.error( - "The stored access token [{}] for token doc id [{}] could not be verified", - storedAccessToken, - tokenId - ); - listener.onResponse(null); - } else if (storedRefreshToken != null - && (refreshSource == null || storedRefreshToken.equals(refreshSource.get("token")) == false)) { - logger.error( - "The stored refresh token [{}] for token doc id [{}] could not be verified", - storedRefreshToken, - tokenId - ); - listener.onResponse(null); - } else { - listener.onResponse(new Doc(response)); - } - }, e -> { - // if the index or the shard is not there / available we assume that - // the token is not valid - if (isShardNotAvailableException(e)) { - logger.warn( - "failed to get token doc [{}] because index [{}] is not available", - tokenId, - securityTokensIndex.aliasName() - ); - } else { - logger.error(() -> "failed to get token doc [" + tokenId + "]", e); - } - listener.onFailure(e); - }), + ActionListener.wrap(wrappedCheckedConsumer, exceptionConsumer), client::get ) ); @@ -881,6 +892,130 @@ private void indexInvalidation( tokensIndexManager.aliasName() ) ); + + CheckedConsumer checkedConsumer = bulkResponse -> { + ArrayList retryTokenDocIds = new ArrayList<>(); + ArrayList failedRequestResponses = new ArrayList<>(); + ArrayList previouslyInvalidated = new ArrayList<>(); + ArrayList invalidated = new ArrayList<>(); + if (null != previousResult) { + failedRequestResponses.addAll((previousResult.getErrors())); + previouslyInvalidated.addAll(previousResult.getPreviouslyInvalidatedTokens()); + invalidated.addAll(previousResult.getInvalidatedTokens()); + } + for (BulkItemResponse bulkItemResponse : bulkResponse.getItems()) { + if (bulkItemResponse.isFailed()) { + Throwable cause = bulkItemResponse.getFailure().getCause(); + final String failedTokenDocId = getTokenIdFromDocumentId(bulkItemResponse.getFailure().getId()); + if (isShardNotAvailableException(cause)) { + retryTokenDocIds.add(failedTokenDocId); + } else { + traceLog("invalidate access token", failedTokenDocId, cause); + failedRequestResponses.add(new ElasticsearchException("Error invalidating " + srcPrefix + ": ", cause)); + } + } else { + UpdateResponse updateResponse = bulkItemResponse.getResponse(); + if (updateResponse.getResult() == DocWriteResponse.Result.UPDATED) { + logger.debug(() -> format("Invalidated [%s] for doc [%s]", srcPrefix, updateResponse.getGetResult().getId())); + invalidated.add(updateResponse.getGetResult().getId()); + } else if (updateResponse.getResult() == DocWriteResponse.Result.NOOP) { + previouslyInvalidated.add(updateResponse.getGetResult().getId()); + } + } + } + if (retryTokenDocIds.isEmpty() == false && backoff.hasNext()) { + logger.debug( + "failed to invalidate [{}] tokens out of [{}], retrying to invalidate these too", + retryTokenDocIds.size(), + tokenIds.size() + ); + final TokensInvalidationResult incompleteResult = new TokensInvalidationResult( + invalidated, + previouslyInvalidated, + failedRequestResponses, + RestStatus.OK + ); + client.threadPool() + .schedule( + () -> indexInvalidation( + retryTokenDocIds, + tokensIndexManager, + backoff, + srcPrefix, + incompleteResult, + refreshPolicy, + listener + ), + backoff.next(), + client.threadPool().generic() + ); + } else { + if (retryTokenDocIds.isEmpty() == false) { + logger.warn( + "failed to invalidate [{}] tokens out of [{}] after all retries", + retryTokenDocIds.size(), + tokenIds.size() + ); + for (String retryTokenDocId : retryTokenDocIds) { + failedRequestResponses.add( + new ElasticsearchException( + "Error invalidating [{}] with doc id [{}] after retries exhausted", + srcPrefix, + retryTokenDocId + ) + ); + } + } + final TokensInvalidationResult result = new TokensInvalidationResult( + invalidated, + previouslyInvalidated, + failedRequestResponses, + RestStatus.OK + ); + listener.onResponse(result); + } + }; + Consumer exceptionConsumer = e -> { + Throwable cause = ExceptionsHelper.unwrapCause(e); + traceLog("invalidate tokens", cause); + if (isShardNotAvailableException(cause) && backoff.hasNext()) { + logger.debug("failed to invalidate tokens, retrying "); + client.threadPool() + .schedule( + () -> indexInvalidation( + tokenIds, + tokensIndexManager, + backoff, + srcPrefix, + previousResult, + refreshPolicy, + listener + ), + backoff.next(), + client.threadPool().generic() + ); + } else { + listener.onFailure(e); + } + }; + + // this wrapper handles a situation where the current flow (executing on the generic thread) + // finds itself unintentionally forking over to a transport_worker thread, which is a consequence of using + // client::bulk when calling executeAsyncWithOrigin + // (see NodeClient issue: https://github.com/elastic/elasticsearch/issues/97916). If you follow the implementation of + // client::bulk within NodeClient, you'll spot the execution happening on a transport thread. + // The wrapper below handles this situation by introducing a second fork to make sure post-processing of the BulkResponse + // returns us to the generic thread pool. + CheckedConsumer wrappedCheckedConsumer = bulkItemResponses -> client.threadPool() + .generic() + .execute(() -> { + try { + checkedConsumer.accept(bulkItemResponses); + } catch (Exception e) { + // re-use the exception consumer + exceptionConsumer.accept(e); + } + }); bulkRequestBuilder.setRefreshPolicy(refreshPolicy); tokensIndexManager.forCurrentProject() .prepareIndexIfNeededThenExecute( @@ -889,114 +1024,7 @@ private void indexInvalidation( client.threadPool().getThreadContext(), SECURITY_ORIGIN, bulkRequestBuilder.request(), - ActionListener.wrap(bulkResponse -> { - ArrayList retryTokenDocIds = new ArrayList<>(); - ArrayList failedRequestResponses = new ArrayList<>(); - ArrayList previouslyInvalidated = new ArrayList<>(); - ArrayList invalidated = new ArrayList<>(); - if (null != previousResult) { - failedRequestResponses.addAll((previousResult.getErrors())); - previouslyInvalidated.addAll(previousResult.getPreviouslyInvalidatedTokens()); - invalidated.addAll(previousResult.getInvalidatedTokens()); - } - for (BulkItemResponse bulkItemResponse : bulkResponse.getItems()) { - if (bulkItemResponse.isFailed()) { - Throwable cause = bulkItemResponse.getFailure().getCause(); - final String failedTokenDocId = getTokenIdFromDocumentId(bulkItemResponse.getFailure().getId()); - if (isShardNotAvailableException(cause)) { - retryTokenDocIds.add(failedTokenDocId); - } else { - traceLog("invalidate access token", failedTokenDocId, cause); - failedRequestResponses.add( - new ElasticsearchException("Error invalidating " + srcPrefix + ": ", cause) - ); - } - } else { - UpdateResponse updateResponse = bulkItemResponse.getResponse(); - if (updateResponse.getResult() == DocWriteResponse.Result.UPDATED) { - logger.debug( - () -> format("Invalidated [%s] for doc [%s]", srcPrefix, updateResponse.getGetResult().getId()) - ); - invalidated.add(updateResponse.getGetResult().getId()); - } else if (updateResponse.getResult() == DocWriteResponse.Result.NOOP) { - previouslyInvalidated.add(updateResponse.getGetResult().getId()); - } - } - } - if (retryTokenDocIds.isEmpty() == false && backoff.hasNext()) { - logger.debug( - "failed to invalidate [{}] tokens out of [{}], retrying to invalidate these too", - retryTokenDocIds.size(), - tokenIds.size() - ); - final TokensInvalidationResult incompleteResult = new TokensInvalidationResult( - invalidated, - previouslyInvalidated, - failedRequestResponses, - RestStatus.OK - ); - client.threadPool() - .schedule( - () -> indexInvalidation( - retryTokenDocIds, - tokensIndexManager, - backoff, - srcPrefix, - incompleteResult, - refreshPolicy, - listener - ), - backoff.next(), - client.threadPool().generic() - ); - } else { - if (retryTokenDocIds.isEmpty() == false) { - logger.warn( - "failed to invalidate [{}] tokens out of [{}] after all retries", - retryTokenDocIds.size(), - tokenIds.size() - ); - for (String retryTokenDocId : retryTokenDocIds) { - failedRequestResponses.add( - new ElasticsearchException( - "Error invalidating [{}] with doc id [{}] after retries exhausted", - srcPrefix, - retryTokenDocId - ) - ); - } - } - final TokensInvalidationResult result = new TokensInvalidationResult( - invalidated, - previouslyInvalidated, - failedRequestResponses, - RestStatus.OK - ); - listener.onResponse(result); - } - }, e -> { - Throwable cause = ExceptionsHelper.unwrapCause(e); - traceLog("invalidate tokens", cause); - if (isShardNotAvailableException(cause) && backoff.hasNext()) { - logger.debug("failed to invalidate tokens, retrying "); - client.threadPool() - .schedule( - () -> indexInvalidation( - tokenIds, - tokensIndexManager, - backoff, - srcPrefix, - previousResult, - refreshPolicy, - listener - ), - backoff.next(), - client.threadPool().generic() - ); - } else { - listener.onFailure(e); - } - }), + ActionListener.wrap(wrappedCheckedConsumer, exceptionConsumer), client::bulk ) ); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/oidc/TransportOpenIdConnectLogoutActionTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/oidc/TransportOpenIdConnectLogoutActionTests.java index 84c1544b3b334..383a19c1afc85 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/oidc/TransportOpenIdConnectLogoutActionTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/oidc/TransportOpenIdConnectLogoutActionTests.java @@ -25,6 +25,7 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.Tuple; import org.elasticsearch.env.Environment; @@ -105,6 +106,7 @@ public void setup() throws Exception { final var defaultContext = threadContext.newStoredContext(); final ThreadPool threadPool = mock(ThreadPool.class); when(threadPool.getThreadContext()).thenReturn(threadContext); + when(threadPool.generic()).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE); AuthenticationTestHelper.builder() .user(new User("kibana")) .realmRef(new Authentication.RealmRef("realm", "type", "node")) diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java index 20daf13a45ac6..e412f4a550b6d 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java @@ -122,6 +122,7 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Consumer; @@ -1940,6 +1941,7 @@ public void testAuthenticateWithToken() throws Exception { when(projectIndex.isAvailable(SecurityIndexManager.Availability.PRIMARY_SHARDS)).thenReturn(true); when(projectIndex.isAvailable(SecurityIndexManager.Availability.SEARCH_SHARDS)).thenReturn(true); when(projectIndex.indexExists()).thenReturn(true); + CountDownLatch latch = new CountDownLatch(1); try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { threadContext.putHeader("Authorization", "Bearer " + token); boolean requestIdAlreadyPresent = randomBoolean(); @@ -1962,8 +1964,10 @@ public void testAuthenticateWithToken() throws Exception { verify(operatorPrivilegesService).maybeMarkOperatorUser(eq(result), eq(threadContext)); setCompletedToTrue(completed); verify(auditTrail).authenticationSuccess(eq(reqId.get()), eq(result), eq("_action"), same(transportRequest)); + latch.countDown(); }, this::logAndFail)); } + latch.await(1, TimeUnit.SECONDS); assertTrue(completed.get()); verifyNoMoreInteractions(auditTrail); } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/SecondaryAuthenticatorTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/SecondaryAuthenticatorTests.java index 260702cb36fa0..e6b5b62e6974f 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/SecondaryAuthenticatorTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/SecondaryAuthenticatorTests.java @@ -68,7 +68,9 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; @@ -334,9 +336,12 @@ public void testAuthenticateUsingBearerToken() throws Exception { SecurityMocks.mockGetRequest(client, SecuritySystemIndices.SECURITY_TOKENS_ALIAS, tokenDocId.get(), tokenSource.get()); final TransportRequest request = AuthenticateRequest.INSTANCE; + CountDownLatch latch = new CountDownLatch(1); final PlainActionFuture future = new PlainActionFuture<>(); + ActionListener.runAfter(future, latch::countDown); authenticator.authenticate(AuthenticateAction.NAME, request, future); + latch.await(1, TimeUnit.SECONDS); final SecondaryAuthentication secondaryAuthentication = future.result(); assertThat(secondaryAuthentication, Matchers.notNullValue()); assertThat(secondaryAuthentication.getAuthentication(), Matchers.notNullValue());