diff --git a/docs/changelog/136386.yaml b/docs/changelog/136386.yaml new file mode 100644 index 0000000000000..c3de13a8c3e51 --- /dev/null +++ b/docs/changelog/136386.yaml @@ -0,0 +1,5 @@ +pr: 136386 +summary: Limit concurrent TLS handshakes +area: Network +type: enhancement +issues: [] diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java index def517d21f91e..cdc9e1a50cc79 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java @@ -38,6 +38,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.action.support.SubscribableListener; +import org.elasticsearch.common.component.Lifecycle; import org.elasticsearch.common.network.CloseableChannel; import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.network.ThreadWatchdog; @@ -107,6 +109,8 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport { private volatile ServerBootstrap serverBootstrap; private volatile SharedGroupFactory.SharedGroup sharedGroup; + private final TlsHandshakeThrottleManager tlsHandshakeThrottleManager; + public Netty4HttpServerTransport( Settings settings, NetworkService networkService, @@ -144,6 +148,8 @@ public Netty4HttpServerTransport( this.readTimeoutMillis = Math.toIntExact(SETTING_HTTP_READ_TIMEOUT.get(settings).getMillis()); + this.tlsHandshakeThrottleManager = new TlsHandshakeThrottleManager(clusterSettings, telemetryProvider.getMeterRegistry()); + ByteSizeValue receivePredictor = Netty4Plugin.SETTING_HTTP_NETTY_RECEIVE_PREDICTOR_SIZE.get(settings); recvByteBufAllocator = new 
FixedRecvByteBufAllocator(receivePredictor.bytesAsInt()); @@ -231,6 +237,9 @@ protected void startInternal() { if (acceptChannelPredicate != null) { acceptChannelPredicate.setBoundAddress(boundAddress()); } + + tlsHandshakeThrottleManager.start(); + success = true; } finally { if (success == false) { @@ -250,6 +259,9 @@ protected HttpServerChannel bind(InetSocketAddress socketAddress) throws Excepti @Override protected void stopInternal() { + if (tlsHandshakeThrottleManager.lifecycleState() != Lifecycle.State.INITIALIZED) { + tlsHandshakeThrottleManager.stop(); + } if (sharedGroup != null) { sharedGroup.shutdown(); sharedGroup = null; @@ -329,7 +341,29 @@ protected void initChannel(Channel ch) throws Exception { ); } if (tlsConfig.isTLSEnabled()) { - ch.pipeline().addLast("ssl", new SslHandler(tlsConfig.createServerSSLEngine())); + final var sslHandler = new SslHandler(tlsConfig.createServerSSLEngine()); + final var tlsHandshakeThrottle = transport.tlsHandshakeThrottleManager.getThrottleForCurrentThread(); + + if (tlsHandshakeThrottle == null) { + // throttling currently disabled + ch.pipeline().addLast("ssl", sslHandler); + } else { + final var handshakeCompletePromise = new SubscribableListener(); + ch.pipeline() + // accumulate data until the initial handshake + .addLast( + "initial-tls-handshake-throttle", + tlsHandshakeThrottle.newHandshakeThrottleHandler(handshakeCompletePromise) + ) + // actually do the TLS processing + .addLast("ssl", sslHandler) + // watch for the completion of this channel's initial handshake at which point we can release one for another + // channel + .addLast( + "initial-tls-handshake-completion-watcher", + tlsHandshakeThrottle.newHandshakeCompletionWatcher(handshakeCompletePromise) + ); + } } final var threadWatchdogActivityTracker = transport.threadWatchdog.getActivityTrackerForCurrentThread(); ch.pipeline() diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/TlsHandshakeThrottleManager.java 
b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/TlsHandshakeThrottleManager.java new file mode 100644 index 0000000000000..89a1a9a51891c --- /dev/null +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/TlsHandshakeThrottleManager.java @@ -0,0 +1,427 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the "Elastic License
 * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
 * Public License v 1"; you may not use this file except in compliance with, at
 * your election, the "Elastic License 2.0", the "GNU Affero General Public
 * License v3.0 only", or the "Server Side Public License, v 1".
 */

package org.elasticsearch.http.netty4;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleUserEventChannelHandler;
import io.netty.handler.ssl.SslClientHelloHandler;
import io.netty.handler.ssl.SslCompletionEvent;
import io.netty.util.ReferenceCounted;
import io.netty.util.concurrent.Future;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.component.AbstractLifecycleComponent;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.node.NodeClosedException;
import org.elasticsearch.telemetry.metric.LongWithAttributes;
import org.elasticsearch.telemetry.metric.MeterRegistry;
import org.elasticsearch.transport.NodeDisconnectedException;
import org.elasticsearch.transport.netty4.Netty4Plugin;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Supplier;
import java.util.function.ToLongFunction;

/**
 * Limits the number of in-flight TLS handshakes for inbound HTTPS connections processed by each event loop, protecting against a
 * thundering herd of fresh connections.
 * <p>
 * One {@link TlsHandshakeThrottle} is kept per calling thread (see {@link #getThrottleForCurrentThread()}). The two limits are bound to
 * their cluster settings in {@link #doStart()}, and the metrics registered there report sums over all per-thread throttles.
 */
class TlsHandshakeThrottleManager extends AbstractLifecycleComponent {

    private static final Logger logger = LogManager.getLogger(TlsHandshakeThrottleManager.class);

    private final ClusterSettings clusterSettings;
    private final MeterRegistry meterRegistry;

    // instruments registered in doStart() and closed again in doStop(); sized for the four instruments registered there
    private final List<AutoCloseable> metricsToClose = new ArrayList<>(4);

    // current limits, written by the settings callbacks installed in doStart(); zero in-progress handshakes means "no throttling"
    volatile int maxInProgressTlsHandshakes;
    volatile int maxDelayedTlsHandshakes;

    // one throttle per thread that has called getThrottleForCurrentThread()
    private final Map<Thread, TlsHandshakeThrottle> tlsHandshakeThrottles = ConcurrentCollections.newConcurrentMap();

    /**
     * @param clusterSettings source of the (dynamic) throttling limits
     * @param meterRegistry   registry with which the throttling metrics are registered on start
     */
    TlsHandshakeThrottleManager(ClusterSettings clusterSettings, MeterRegistry meterRegistry) {
        this.clusterSettings = clusterSettings;
        this.meterRegistry = meterRegistry;
    }

    /**
     * Looks up the instance of {@code setting} that is registered with {@code clusterSettings}, matching by key rather than identity.
     *
     * @throws NullPointerException if no setting with the same key is registered
     */
    @SuppressWarnings("unchecked")
    private static <T> Setting<T> getRegisteredInstance(ClusterSettings clusterSettings, Setting<T> setting) {
        // Netty4Plugin ends up loaded twice in different classloaders, so we have to look up the setting instances by name
        return (Setting<T>) Objects.requireNonNull(clusterSettings.get(setting.getKey()));
    }

    /**
     * Binds the two limits to their cluster settings and registers the four throttling metrics.
     */
    @Override
    protected void doStart() {
        clusterSettings.initializeAndWatch(
            getRegisteredInstance(clusterSettings, Netty4Plugin.SETTING_HTTP_NETTY_TLS_HANDSHAKES_MAX_IN_PROGRESS),
            maxInProgressTlsHandshakes -> this.maxInProgressTlsHandshakes = maxInProgressTlsHandshakes
        );
        clusterSettings.initializeAndWatch(
            getRegisteredInstance(clusterSettings, Netty4Plugin.SETTING_HTTP_NETTY_TLS_HANDSHAKES_MAX_DELAYED),
            maxDelayedTlsHandshakes -> this.maxDelayedTlsHandshakes = maxDelayedTlsHandshakes
        );

        metricsToClose.add(
            meterRegistry.registerLongGauge(
                "es.http.tls_handshakes.in_progress.current",
                "current number of in-progress TLS handshakes for HTTP connections",
                "count",
                getMetric(TlsHandshakeThrottle::getInProgressHandshakesCount)
            )
        );
        metricsToClose.add(
            meterRegistry.registerLongGauge(
                "es.http.tls_handshakes.delayed.current",
                "current number of delayed TLS handshakes for HTTP connections",
                "count",
                getMetric(TlsHandshakeThrottle::getCurrentDelayedHandshakesCount)
            )
        );
        metricsToClose.add(
            meterRegistry.registerLongAsyncCounter(
                "es.http.tls_handshakes.delayed.total",
                "total number of TLS handshakes for HTTP connections that were delayed due to throttling",
                "count",
                getMetric(TlsHandshakeThrottle::getTotalDelayedHandshakesCount)
            )
        );
        metricsToClose.add(
            meterRegistry.registerLongAsyncCounter(
                "es.http.tls_handshakes.dropped.total",
                "number of TLS handshakes for HTTP connections dropped due to throttling",
                "count",
                getMetric(TlsHandshakeThrottle::getDroppedHandshakesCount)
            )
        );
    }

    /**
     * Fails every handshake that is still delayed on any throttle, then closes the metrics registered in {@link #doStart()}.
     */
    @Override
    protected void doStop() {
        tlsHandshakeThrottles.values().forEach(TlsHandshakeThrottle::close);
        for (var metricToClose : metricsToClose) {
            try {
                metricToClose.close();
            } catch (Exception e) {
                // closing a metric is not expected to fail: trip assertions in tests, but only log in production
                assert false : e;
                logger.error(Strings.format("exception closing metric [%s]", metricToClose), e);
            }
        }
    }

    @Override
    protected void doClose() {}

    /**
     * Builds a metric observer that reports the sum of {@code metricFunction} over all per-thread throttles.
     */
    private Supplier<LongWithAttributes> getMetric(ToLongFunction<TlsHandshakeThrottle> metricFunction) {
        return () -> {
            long result = 0L;
            for (var tlsHandshakeThrottle : tlsHandshakeThrottles.values()) {
                result += metricFunction.applyAsLong(tlsHandshakeThrottle);
            }
            return new LongWithAttributes(result);
        };
    }

    /**
     * Returns the throttle belonging to the calling thread, creating it on first use.
     *
     * @return the calling thread's throttle, or {@code null} if the in-progress limit is currently zero (throttling disabled)
     * @throws IllegalStateException if this component is already stopped or closed
     */
    @Nullable // if throttling disabled
    TlsHandshakeThrottle getThrottleForCurrentThread() {
        // checked under the lifecycle mutex so that a throttle cannot be created once the component is stopped or closed
        synchronized (lifecycle) {
            if (lifecycle.stoppedOrClosed()) {
                throw new IllegalStateException("HTTP transport is already stopped");
            }
            if (maxInProgressTlsHandshakes == 0) {
                return null;
            }
            return tlsHandshakeThrottles.computeIfAbsent(Thread.currentThread(), ignored -> new TlsHandshakeThrottle());
        }
    }

    /**
     * A throttle on TLS handshakes for incoming HTTP connections for a single event loop thread.
     */
    class TlsHandshakeThrottle {

        // volatile for metrics
        private volatile int inProgressHandshakesCount = 0;

        // actions to run (or fail) to release a throttled handshake; newest entries are at the head of the deque
        private final ArrayDeque<AbstractRunnable> delayedHandshakes = new ArrayDeque<>();

        // delayedHandshakes.size() but tracked separately for metrics
        private volatile int delayedHandshakesCount = 0;

        // for metrics
        private volatile long totalDelayedHandshakesCount = 0;
        private volatile long droppedHandshakesCount = 0;

        /** Removes and returns the most recently delayed handshake. */
        private AbstractRunnable takeFirstDelayedHandshake() {
            final var result = delayedHandshakes.removeFirst();
            delayedHandshakesCount = delayedHandshakes.size();
            return result;
        }

        /** Removes and returns the handshake that has been delayed the longest. */
        private AbstractRunnable takeLastDelayedHandshake() {
            final var result = delayedHandshakes.removeLast();
            delayedHandshakesCount = delayedHandshakes.size();
            return result;
        }

        /** Pushes a newly delayed handshake onto the head of the deque and updates the metrics counters. */
        private void addDelayedHandshake(AbstractRunnable abstractRunnable) {
            delayedHandshakes.addFirst(abstractRunnable);
            // noinspection NonAtomicOperationOnVolatileField all writes are on this thread
            totalDelayedHandshakesCount += 1;
            delayedHandshakesCount = delayedHandshakes.size();
        }

        /**
         * Fails every handshake that is still delayed with a {@link NodeClosedException}.
         */
        void close() {
            // NOTE(review): every other access to delayedHandshakes is annotated as happening on the owning thread, whereas this is
            // reached from doStop() on whichever thread stops the component -- confirm no event-loop callback can still be running.
            while (delayedHandshakes.isEmpty() == false) {
                takeFirstDelayedHandshake().onFailure(new NodeClosedException((DiscoveryNode) null));
            }
        }

        /**
         * @param handshakeCompletePromise promise shared with the matching handler from {@link #newHandshakeCompletionWatcher}
         * @return a handler which holds back inbound data until this throttle permits the channel's handshake to proceed
         */
        ChannelHandler newHandshakeThrottleHandler(SubscribableListener<Void> handshakeCompletePromise) {
            return new HandshakeThrottleHandler(handshakeCompletePromise);
        }

        /**
         * @param handshakeCompletePromise promise shared with the matching handler from {@link #newHandshakeThrottleHandler}
         * @return a handler which completes the promise when the channel's handshake finishes or the channel closes
         */
        ChannelHandler newHandshakeCompletionWatcher(SubscribableListener<Void> handshakeCompletePromise) {
            return new HandshakeCompletionWatcher(handshakeCompletePromise);
        }

        /**
         * Called when an in-progress handshake finishes (successfully or not): either hands the freed slot to the most recently
         * delayed handshake, or, if none is waiting, decrements the in-progress count.
         */
        void handleHandshakeCompletion() {
            if (delayedHandshakes.isEmpty()) {
                // noinspection NonAtomicOperationOnVolatileField all writes are on this thread
                inProgressHandshakesCount -= 1;
            } else {
                // the slot passes straight to the delayed handshake, so the in-progress count is left unchanged
                takeFirstDelayedHandshake().run();
            }
        }

        public int getInProgressHandshakesCount() {
            return inProgressHandshakesCount;
        }

        public int getCurrentDelayedHandshakesCount() {
            return delayedHandshakesCount;
        }

        public long getTotalDelayedHandshakesCount() {
            return totalDelayedHandshakesCount;
        }

        public long getDroppedHandshakesCount() {
            return droppedHandshakesCount;
        }

        /**
         * A Netty pipeline handler that aggregates inbound messages until it receives a full TLS {@code ClientHello} and then either
         * passes all the received messages on down the pipeline (if not throttled) or else delays that work until another TLS handshake
         * completes (if too many such handshakes are already in flight).
         */
        private class HandshakeThrottleHandler extends SslClientHelloHandler<Void> {

            /**
             * Promise which accumulates the messages received until we receive a full handshake. Completed when we receive a full
             * handshake, at which point all the delayed messages are pushed down the pipeline for actual processing.
             */
            private final SubscribableListener<Void> handshakeStartedPromise = new SubscribableListener<>();

            /**
             * Promise which will be completed by the channel's matching {@link HandshakeCompletionWatcher} when the handshake we sent down
             * the pipeline has completed.
             */
            private final SubscribableListener<Void> handshakeCompletePromise;

            HandshakeThrottleHandler(SubscribableListener<Void> handshakeCompletePromise) {
                this.handshakeCompletePromise = handshakeCompletePromise;
            }

            /**
             * Retains {@code msg} and queues it on {@link #handshakeStartedPromise} so that it is forwarded once the handshake is
             * permitted to start, or released if the handshake never starts; the message is also passed to the superclass.
             */
            @Override
            public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
                if (msg instanceof ReferenceCounted referenceCounted) {
                    // extra reference for the deferred forward below; dropped again in onFailure if we never forward
                    referenceCounted.retain();
                }
                handshakeStartedPromise.addListener(new ActionListener<>() {
                    @Override
                    public void onResponse(Void unused) {
                        ctx.fireChannelRead(msg);
                    }

                    @Override
                    public void onFailure(Exception e) {
                        if (msg instanceof ReferenceCounted referenceCounted) {
                            referenceCounted.release();
                        }
                    }
                });
                super.channelRead(ctx, msg);
            }

            /**
             * Queues a {@code channelReadComplete} to follow the queued messages once the handshake is permitted to start.
             */
            @Override
            public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
                handshakeStartedPromise.addListener(new ActionListener<>() {
                    @Override
                    public void onResponse(Void unused) {
                        ctx.fireChannelReadComplete();
                    }

                    @Override
                    public void onFailure(Exception e) {}
                });
                super.channelReadComplete(ctx);
            }

            /**
             * Fails {@link #handshakeStartedPromise}, which releases any messages still queued on it.
             */
            @Override
            public void channelInactive(ChannelHandlerContext ctx) throws Exception {
                handshakeStartedPromise.onFailure(
                    new NodeDisconnectedException(null, "connection closed before handshake started", null, null)
                );
                super.channelInactive(ctx);
            }

            /**
             * Decides what to do with the channel's handshake: start it immediately if this throttle has capacity, otherwise enqueue it
             * and, if the queue is now over its limit, drop the oldest delayed handshakes.
             *
             * @param ctx         this handler's context
             * @param clientHello the {@code ClientHello}; a {@code null} value is treated as a failure and closes the channel
             * @return a failed future if there was no {@code ClientHello} or the channel is already inactive, otherwise a succeeded one
             */
            @Override
            protected Future<Void> lookup(ChannelHandlerContext ctx, ByteBuf clientHello) {
                if (clientHello == null) {
                    logger.debug("lookup with no ClientHello, closing [{}]", ctx.channel());
                    ctx.channel().close();
                    final var exception = new IllegalArgumentException(
                        "did not receive initial ClientHello on channel [" + ctx.channel() + "]"
                    );
                    handshakeStartedPromise.onFailure(exception);
                    return ctx.executor().newFailedFuture(exception);
                }

                if (ctx.channel().isActive() == false) {
                    logger.debug("lookup after channel inactive, ignoring [{}]", ctx.channel());
                    final var exception = new NodeDisconnectedException(
                        null,
                        "lookup after channel inactive [" + ctx.channel() + "]",
                        null,
                        null
                    );
                    handshakeStartedPromise.onFailure(exception);
                    return ctx.executor().newFailedFuture(exception);
                }

                final var maxInProgressTlsHandshakes = TlsHandshakeThrottleManager.this.maxInProgressTlsHandshakes; // single volatile read
                if (maxInProgressTlsHandshakes == 0 || inProgressHandshakesCount < maxInProgressTlsHandshakes) {
                    // noinspection NonAtomicOperationOnVolatileField all writes are on this thread
                    inProgressHandshakesCount += 1;
                    startHandshake(ctx);
                } else {
                    logger.debug(
                        "[{}] in-progress TLS handshakes already, enqueueing new handshake on [{}]",
                        inProgressHandshakesCount,
                        ctx.channel()
                    );
                    addDelayedHandshake(new AbstractRunnable() {
                        @Override
                        public void onFailure(Exception e) {
                            logger.debug(
                                "[{}] in-progress and [{}] delayed TLS handshakes, cancelling handshake on [{}]: {}",
                                inProgressHandshakesCount,
                                delayedHandshakes.size(),
                                ctx.channel(),
                                e.getMessage()
                            );
                            ctx.channel().close();
                        }

                        @Override
                        protected void doRun() {
                            logger.debug(
                                "[{}] in flight and [{}] delayed TLS handshakes, processing delayed handshake on [{}]",
                                inProgressHandshakesCount,
                                delayedHandshakes.size(),
                                ctx.channel()
                            );
                            // takes over the slot of the handshake that just completed, so no change to inProgressHandshakesCount
                            HandshakeThrottleHandler.this.startHandshake(ctx);
                        }

                        @Override
                        public String toString() {
                            return "delayed handshake on [" + ctx.channel() + "]";
                        }
                    });

                    final var maxDelayedTlsHandshakes = TlsHandshakeThrottleManager.this.maxDelayedTlsHandshakes; // single volatile read
                    while (delayedHandshakes.size() > maxDelayedTlsHandshakes) {
                        // new entries go on the head of the deque, so the tail is the one that has waited longest
                        final var lastDelayedHandshake = takeLastDelayedHandshake();
                        // noinspection NonAtomicOperationOnVolatileField all writes are on this thread
                        droppedHandshakesCount += 1;
                        lastDelayedHandshake.onFailure(new ElasticsearchException("too many in-flight TLS handshakes"));
                    }
                }

                ctx.read(); // auto-read is disabled but we must watch for client-close
                return ctx.executor().newSucceededFuture(null);
            }

            /**
             * Lets this channel's handshake proceed: subscribes the throttle to its completion, removes this handler from the pipeline
             * and completes {@link #handshakeStartedPromise} so that the queued messages are forwarded.
             */
            private void startHandshake(ChannelHandlerContext ctx) {
                handshakeCompletePromise.addListener(ActionListener.running(TlsHandshakeThrottle.this::handleHandshakeCompletion));
                ctx.pipeline().remove(HandshakeThrottleHandler.this);
                handshakeStartedPromise.onResponse(null);
            }

            @Override
            protected void onLookupComplete(ChannelHandlerContext ctx, Future<Void> future) {}
        }

        /**
         * A Netty pipeline handler that watches for a user event indicating a TLS handshake completed (or else the channel closed). On
         * completion of a handshake, this handler removes itself from the pipeline and completes a promise which will in turn call
         * {@link #handleHandshakeCompletion} to trigger either the processing of another handshake or a decrement of
         * {@link #inProgressHandshakesCount}.
         */
        private static class HandshakeCompletionWatcher extends SimpleUserEventChannelHandler<SslCompletionEvent> {
            private final SubscribableListener<Void> handshakeCompletePromise;

            HandshakeCompletionWatcher(SubscribableListener<Void> handshakeCompletePromise) {
                this.handshakeCompletePromise = handshakeCompletePromise;
            }

            @Override
            protected void eventReceived(ChannelHandlerContext ctx, SslCompletionEvent evt) {
                ctx.pipeline().remove(HandshakeCompletionWatcher.this);
                if (evt.isSuccess()) {
                    handshakeCompletePromise.onResponse(null);
                } else {
                    ExceptionsHelper.maybeDieOnAnotherThread(evt.cause());
                    // the promise accepts only Exceptions, so any other Throwable is wrapped
                    handshakeCompletePromise.onFailure(
                        evt.cause() instanceof Exception exception
                            ? exception
                            : new ElasticsearchException("TLS handshake failed", evt.cause())
                    );
                }
            }

            @Override
            public void channelInactive(ChannelHandlerContext ctx) throws Exception {
                // the channel closed without a completion event having completed the promise
                if (handshakeCompletePromise.isDone() == false) {
                    handshakeCompletePromise.onFailure(new ElasticsearchException("channel closed before TLS handshake completed"));
                }
                super.channelInactive(ctx);
            }
        }
    }
}
 diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Plugin.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Plugin.java index a69693c1c7af5..04355999d7263 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Plugin.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Plugin.java @@ -122,6 +122,53 @@ public class Netty4Plugin extends Plugin implements NetworkPlugin { Setting.Property.NodeScope ); + /* + * [NOTE: TLS Handshake Throttling] + * + * Each TLS handshake takes around 2.5ms of CPU to process, so each transport worker thread can process up to 400 handshakes per second. + * This is much slower than the rate at which we can accept new connections, so the handshakes can form a backlog of work. Clients + * typically impose a 10s timeout on TLS handshakes, so if we fall behind by more than 4000 handshakes then (even without any other + * CPU-bound work) each new client's handshake will take more than 10s to reach the head of the queue, and yet we will still attempt to + * complete it, delaying yet more client's handshake attempts, and ending up in a state where egregiously few new clients will be able + * to connect. + * + * We prevent this by restricting the number of handshakes in progress at once: by default we permit a backlog of up to 2000 handshakes + * per worker. 
This represents 5s of CPU time, half of the usual client timeout of 10s, which should be enough margin that we can work + * through this backlog before any of them time out (even in the -- likely -- situation that the CPU has something other than TLS + * handshakes to do). + * + * By default, the permitted 2000 handshakes are further divided into 1000 in-flight handshake tasks (2.5s of CPU time) enqueued on the + * Netty event loop as normal, and 1000 more delayed handshake tasks which are held in a separate queue and processed in LIFO order. The + * LIFO order yields better behaviour than FIFO in the situation that we cannot even spend 50% of CPU time on TLS handshakes, because in + * that case some of the enqueued handshakes will still hit the client timeout, so there's more value in focussing our limited attention + * on younger handshakes which we're more likely to complete before timing out. + * + * In future we may decide to adjust this division of work dynamically based on available CPU time, rather than relying on constant + * limits as described above. + */ + + /** + * Maximum number of in-flight TLS handshakes to permit on each event loop. + */ + public static final Setting SETTING_HTTP_NETTY_TLS_HANDSHAKES_MAX_IN_PROGRESS = intSetting( + "http.netty.tls_handshakes.max_in_progress", + 1000, // See [NOTE: TLS Handshake Throttling] above + 0, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + /** + * Maximum number of TLS handshakes to delay by holding in a queue on each event loop. 
+ */ + public static final Setting SETTING_HTTP_NETTY_TLS_HANDSHAKES_MAX_DELAYED = intSetting( + "http.netty.tls_handshakes.max_delayed", + 1000, // See [NOTE: TLS Handshake Throttling] above + 0, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + private final SetOnce groupFactory = new SetOnce<>(); @Override @@ -130,6 +177,8 @@ public List> getSettings() { SETTING_HTTP_NETTY_MAX_COMPOSITE_BUFFER_COMPONENTS, SETTING_HTTP_WORKER_COUNT, SETTING_HTTP_NETTY_RECEIVE_PREDICTOR_SIZE, + SETTING_HTTP_NETTY_TLS_HANDSHAKES_MAX_IN_PROGRESS, + SETTING_HTTP_NETTY_TLS_HANDSHAKES_MAX_DELAYED, WORKER_COUNT, NETTY_RECEIVE_PREDICTOR_SIZE, NETTY_RECEIVE_PREDICTOR_MIN, diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4BadRequestTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4BadRequestTests.java index 4e44c99921853..50dfeb6fb5a8b 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4BadRequestTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4BadRequestTests.java @@ -14,7 +14,6 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.common.network.NetworkService; -import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.util.concurrent.ThreadContext; @@ -38,6 +37,7 @@ import java.util.Collection; import java.util.Collections; +import static org.elasticsearch.http.netty4.Netty4TestUtils.randomClusterSettings; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -84,7 +84,7 @@ public void dispatchBadRequest(RestChannel channel, ThreadContext threadContext, threadPool, xContentRegistry(), dispatcher, - new ClusterSettings(Settings.EMPTY, 
ClusterSettings.BUILT_IN_CLUSTER_SETTINGS), + randomClusterSettings(), new SharedGroupFactory(Settings.EMPTY), TelemetryProvider.NOOP, TLSConfig.noTLS(), diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerPipeliningTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerPipeliningTests.java index e36c50c46b779..cd32e8c6f61e9 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerPipeliningTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerPipeliningTests.java @@ -17,7 +17,6 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.network.NetworkService; -import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.http.AggregatingDispatcher; @@ -43,6 +42,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import static org.elasticsearch.http.netty4.Netty4TestUtils.randomClusterSettings; import static org.hamcrest.Matchers.contains; /** @@ -104,7 +104,7 @@ class CustomNettyHttpServerTransport extends Netty4HttpServerTransport { Netty4HttpServerPipeliningTests.this.threadPool, xContentRegistry(), new AggregatingDispatcher(), - new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS), + randomClusterSettings(), new SharedGroupFactory(settings), TelemetryProvider.NOOP, TLSConfig.noTLS(), diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java index a6cd7ac13be24..bbcbc61ae2c43 100644 --- 
a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java @@ -66,7 +66,6 @@ import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.TimeValue; import org.elasticsearch.core.Tuple; -import org.elasticsearch.http.AbstractHttpServerTransportTestCase; import org.elasticsearch.http.AggregatingDispatcher; import org.elasticsearch.http.BindHttpException; import org.elasticsearch.http.CorsHandler; @@ -80,6 +79,7 @@ import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestResponse; import org.elasticsearch.telemetry.TelemetryProvider; +import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.rest.FakeRestRequest; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; @@ -113,6 +113,7 @@ import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN; import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED; import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_SERVER_SHUTDOWN_GRACE_PERIOD; +import static org.elasticsearch.http.netty4.Netty4TestUtils.randomClusterSettings; import static org.elasticsearch.rest.RestStatus.BAD_REQUEST; import static org.elasticsearch.rest.RestStatus.OK; import static org.elasticsearch.rest.RestStatus.UNAUTHORIZED; @@ -130,7 +131,7 @@ /** * Tests for the {@link Netty4HttpServerTransport} class. 
*/ -public class Netty4HttpServerTransportTests extends AbstractHttpServerTransportTestCase { +public class Netty4HttpServerTransportTests extends ESTestCase { private NetworkService networkService; private ThreadPool threadPool; diff --git a/test/framework/src/main/java/org/elasticsearch/http/AbstractHttpServerTransportTestCase.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4TestUtils.java similarity index 55% rename from test/framework/src/main/java/org/elasticsearch/http/AbstractHttpServerTransportTestCase.java rename to modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4TestUtils.java index fd260c015e505..68b4a99bba656 100644 --- a/test/framework/src/main/java/org/elasticsearch/http/AbstractHttpServerTransportTestCase.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4TestUtils.java @@ -6,18 +6,28 @@ * your election, the "Elastic License 2.0", the "GNU Affero General Public * License v3.0 only", or the "Server Side Public License, v 1". 
*/ -package org.elasticsearch.http; + +package org.elasticsearch.http.netty4; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.http.HttpTransportSettings; +import org.elasticsearch.transport.netty4.Netty4Plugin; + +import static org.elasticsearch.test.ESTestCase.randomBoolean; -public class AbstractHttpServerTransportTestCase extends ESTestCase { +public enum Netty4TestUtils { + ; - protected static ClusterSettings randomClusterSettings() { + public static ClusterSettings randomClusterSettings() { return new ClusterSettings( Settings.builder().put(HttpTransportSettings.SETTING_HTTP_CLIENT_STATS_ENABLED.getKey(), randomBoolean()).build(), - ClusterSettings.BUILT_IN_CLUSTER_SETTINGS + Sets.addToCopy( + ClusterSettings.BUILT_IN_CLUSTER_SETTINGS, + Netty4Plugin.SETTING_HTTP_NETTY_TLS_HANDSHAKES_MAX_IN_PROGRESS, + Netty4Plugin.SETTING_HTTP_NETTY_TLS_HANDSHAKES_MAX_DELAYED + ) ); } } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransportCloseNotifyTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransportCloseNotifyTests.java index e96415a848c17..029c7eb497f9d 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransportCloseNotifyTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransportCloseNotifyTests.java @@ -31,7 +31,6 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.env.TestEnvironment; -import org.elasticsearch.http.AbstractHttpServerTransportTestCase; import 
org.elasticsearch.http.AggregatingDispatcher; import org.elasticsearch.http.netty4.Netty4HttpServerTransport; import org.elasticsearch.rest.RestChannel; @@ -39,6 +38,7 @@ import org.elasticsearch.rest.RestResponse; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.telemetry.TelemetryProvider; +import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.netty4.SharedGroupFactory; @@ -60,8 +60,9 @@ import javax.net.ssl.SSLException; import static org.elasticsearch.test.SecuritySettingsSource.addSSLSettingsForNodePEMFiles; +import static org.elasticsearch.xpack.security.transport.netty4.SecurityNetty4TestUtils.randomClusterSettings; -public class SecurityNetty4HttpServerTransportCloseNotifyTests extends AbstractHttpServerTransportTestCase { +public class SecurityNetty4HttpServerTransportCloseNotifyTests extends ESTestCase { private static T safePoll(BlockingQueue queue) { try { diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransportTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransportTests.java index 4ca8a9f373c4a..71f40078076d8 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransportTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransportTests.java @@ -30,7 +30,6 @@ import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.env.Environment; import org.elasticsearch.env.TestEnvironment; -import org.elasticsearch.http.AbstractHttpServerTransportTestCase; import org.elasticsearch.http.AggregatingDispatcher; import org.elasticsearch.http.HttpHeadersValidationException; import 
org.elasticsearch.http.HttpRequest; @@ -45,6 +44,7 @@ import org.elasticsearch.rest.RestResponse; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.telemetry.TelemetryProvider; +import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.rest.FakeRestRequest; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; @@ -68,6 +68,7 @@ import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; import static org.elasticsearch.transport.Transports.TEST_MOCK_TRANSPORT_THREAD_PREFIX; +import static org.elasticsearch.xpack.security.transport.netty4.SecurityNetty4TestUtils.randomClusterSettings; import static org.elasticsearch.xpack.security.transport.netty4.SimpleSecurityNetty4ServerTransportTests.randomCapitalization; import static org.hamcrest.Matchers.arrayContaining; import static org.hamcrest.Matchers.containsString; @@ -79,7 +80,7 @@ import static org.hamcrest.Matchers.nullValue; import static org.mockito.Mockito.mock; -public class SecurityNetty4HttpServerTransportTests extends AbstractHttpServerTransportTestCase { +public class SecurityNetty4HttpServerTransportTests extends ESTestCase { private SSLService sslService; private Environment env; diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransportTlsHandshakeThrottleTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransportTlsHandshakeThrottleTests.java new file mode 100644 index 0000000000000..9eaeea5c2837f --- /dev/null +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransportTlsHandshakeThrottleTests.java @@ -0,0 +1,1000 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.security.transport.netty4; + +import io.netty.bootstrap.Bootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.SimpleUserEventChannelHandler; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.SslHandshakeCompletionEvent; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.SubscribableListener; +import org.elasticsearch.common.network.NetworkService; +import org.elasticsearch.common.settings.ClusterSettings; +import org.elasticsearch.common.settings.Setting; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.transport.TransportAddress; +import org.elasticsearch.common.util.concurrent.ConcurrentCollections; +import org.elasticsearch.common.util.concurrent.RunOnce; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.env.TestEnvironment; +import org.elasticsearch.http.HttpServerTransport; +import org.elasticsearch.http.netty4.Netty4HttpServerTransport; +import 
org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; +import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.telemetry.InstrumentType; +import org.elasticsearch.telemetry.Measurement; +import org.elasticsearch.telemetry.MetricRecorder; +import org.elasticsearch.telemetry.RecordingMeterRegistry; +import org.elasticsearch.telemetry.TelemetryProvider; +import org.elasticsearch.telemetry.metric.Instrument; +import org.elasticsearch.telemetry.metric.MeterRegistry; +import org.elasticsearch.telemetry.tracing.Tracer; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.NodeDisconnectedException; +import org.elasticsearch.transport.netty4.Netty4Plugin; +import org.elasticsearch.transport.netty4.SharedGroupFactory; +import org.elasticsearch.transport.netty4.TLSConfig; +import org.elasticsearch.xpack.core.XPackSettings; +import org.elasticsearch.xpack.core.ssl.SSLService; +import org.hamcrest.Matchers; + +import java.nio.channels.ClosedChannelException; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.IntStream; + +import static org.elasticsearch.telemetry.InstrumentType.LONG_ASYNC_COUNTER; +import static org.elasticsearch.telemetry.InstrumentType.LONG_GAUGE; +import static org.elasticsearch.test.SecuritySettingsSource.addSSLSettingsForNodePEMFiles; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.lessThanOrEqualTo; + +public class 
SecurityNetty4HttpServerTransportTlsHandshakeThrottleTests extends ESTestCase { + + /** + * Represents a handshake that has passed the throttle and is in progress on the server side. It is blocked by this test fixture until + * the {@link #unblock()} method is called. + *

+ * The server transport exposes these to the tests via the {@code handshakeBlockQueue} queue. + */ + private static class HandshakeBlock { + + private final ActionListener innerPromise; + private final String threadName = Thread.currentThread().getName(); + private final Channel channel; + + HandshakeBlock(ActionListener innerPromise, Channel channel) { + this.innerPromise = innerPromise; + this.channel = channel; + } + + void unblock() { + // complete innerPromise on the event loop to ensure that the blocked read and readComplete events are unblocked in order: + // if completed on the caller thread then a concurrent addListener call may invoke its listener before the waiting ones + channel.eventLoop().execute(() -> innerPromise.onResponse(null)); + } + + @Override + public String toString() { + return "handshake block promise on " + threadName + " for " + channel; + } + } + + /** + * Set up a {@link Netty4HttpServerTransport} with SSL enabled (using a self-signed certificate) and some extra handlers in the pipeline + * around the {@link io.netty.handler.ssl.SslHandler} to enable testing. The first handler, just before the + * {@link io.netty.handler.ssl.SslHandler}, watches for (and blocks) the initial messages that make up a TLS handshake after it has + * passed the throttle mechanism. At this point it increments the count of concurrent handshakes (asserting that it is within bounds), + * and adds an entry to the {@code handshakeBlockQueue} allowing the queue's consumer to release the blocked messages. The second + * handler, just after the {@link io.netty.handler.ssl.SslHandler}, watches for the {@link SslHandshakeCompletionEvent} to decrement the + * count of concurrent handshakes. 
+ */ + private Netty4HttpServerTransport createServerTransport( + ThreadPool threadPool, + SharedGroupFactory sharedGroupFactory, + int maxConcurrentTlsHandshakes, + int maxDelayedTlsHandshakes, + Queue handshakeBlockQueue, + MeterRegistry meterRegistry + ) { + final var dynamicConfiguration = randomBoolean(); + + final Settings.Builder builder = Settings.builder(); + addSSLSettingsForNodePEMFiles(builder, "xpack.security.http.", randomBoolean()); + final var settings = builder.put("xpack.security.http.ssl.enabled", true) + .put("path.home", createTempDir()) + .put( + Netty4Plugin.SETTING_HTTP_NETTY_TLS_HANDSHAKES_MAX_IN_PROGRESS.getKey(), + dynamicConfiguration ? between(0, 5) : maxConcurrentTlsHandshakes + ) + .put( + Netty4Plugin.SETTING_HTTP_NETTY_TLS_HANDSHAKES_MAX_DELAYED.getKey(), + dynamicConfiguration ? between(0, 5) : maxDelayedTlsHandshakes + ) + .build(); + final var env = TestEnvironment.newEnvironment(settings); + final var sslService = new SSLService(env); + final var tlsConfig = new TLSConfig(sslService.profile(XPackSettings.HTTP_SSL_PREFIX)::engine); + + final var inflightHandshakesByEventLoop = ConcurrentCollections.newConcurrentMap(); + + final List> settingsSet = new ArrayList<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + settingsSet.add(Netty4Plugin.SETTING_HTTP_NETTY_TLS_HANDSHAKES_MAX_IN_PROGRESS); + settingsSet.add(Netty4Plugin.SETTING_HTTP_NETTY_TLS_HANDSHAKES_MAX_DELAYED); + final var clusterSettings = new ClusterSettings(settings, Set.copyOf(settingsSet)); + + final var telemetryProvider = new TelemetryProvider() { + @Override + public Tracer getTracer() { + return Tracer.NOOP; + } + + @Override + public MeterRegistry getMeterRegistry() { + return meterRegistry; + } + }; + + final var server = new Netty4HttpServerTransport( + settings, + new NetworkService(Collections.emptyList()), + threadPool, + xContentRegistry(), + NEVER_CALLED_DISPATCHER, + clusterSettings, + sharedGroupFactory, + telemetryProvider, + tlsConfig, + null, + null + ) 
{ + @Override + public ChannelHandler configureServerChannelHandler() { + return new HttpChannelHandler(this, handlingSettings, tlsConfig, null, null) { + @Override + protected void initChannel(Channel ch) throws Exception { + super.initChannel(ch); + + final var workerThread = Thread.currentThread(); + final var handshakeCounter = inflightHandshakesByEventLoop.computeIfAbsent( + workerThread, + ignored -> new AtomicInteger() + ); + final var handshakeCompletePromise = new SubscribableListener<>(); + final var handshakeBlockPromise = new SubscribableListener(); + + final var handshakeStartRecorder = new RunOnce(() -> { + logger.info("--> handshake start detected on [{}]", ch); + assertThat(handshakeCounter.incrementAndGet(), Matchers.lessThanOrEqualTo(maxConcurrentTlsHandshakes)); + handshakeCompletePromise.addListener(ActionListener.running(handshakeCounter::decrementAndGet)); + }); + + final var handshakeBlock = new HandshakeBlock(handshakeBlockPromise, ch); + final var handshakeBlockPromiseEnqueuer = new RunOnce(() -> handshakeBlockQueue.add(handshakeBlock)); + + ch.pipeline().addBefore("ssl", "handshake-start-detector", new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + handshakeStartRecorder.run(); + assertSame(workerThread, Thread.currentThread()); + // ownership transfer of msg to handshakeBlockPromise - no refcounting needed + handshakeBlockPromise.addListener(ActionListener.running(() -> { + assertTrue(ctx.executor().inEventLoop()); + ctx.fireChannelRead(msg); + })); + handshakeBlockPromiseEnqueuer.run(); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) { + handshakeBlockPromise.addListener(ActionListener.running(() -> { + assertTrue(ctx.executor().inEventLoop()); + ctx.fireChannelReadComplete(); + })); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + if (handshakeBlockPromise.isDone() == false) { + 
handshakeBlockPromise.onFailure(new NodeDisconnectedException(null, "channel inactive", null, null)); + } + super.channelInactive(ctx); + } + }).addAfter("ssl", "handshake-complete-detector", new SimpleUserEventChannelHandler() { + @Override + protected void eventReceived(ChannelHandlerContext ctx, SslHandshakeCompletionEvent evt) { + handshakeCompletePromise.onResponse(null); + ctx.fireUserEventTriggered(evt); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + handshakeCompletePromise.onResponse(null); + super.channelInactive(ctx); + } + }); + } + }; + } + }; + server.start(); + + if (dynamicConfiguration) { + clusterSettings.applySettings( + Settings.builder() + .put(Netty4Plugin.SETTING_HTTP_NETTY_TLS_HANDSHAKES_MAX_IN_PROGRESS.getKey(), maxConcurrentTlsHandshakes) + .put(Netty4Plugin.SETTING_HTTP_NETTY_TLS_HANDSHAKES_MAX_DELAYED.getKey(), maxDelayedTlsHandshakes) + .build() + ); + } + + return server; + } + + private static SslContext newClientSslContext() { + try { + return SslContextBuilder.forClient().trustManager(InsecureTrustManagerFactory.INSTANCE).build(); + } catch (Exception e) { + throw new AssertionError(e); + } + } + + private static HandshakeBlock getNextBlock(BlockingQueue handshakeBlockQueue) { + try { + return Objects.requireNonNull( + handshakeBlockQueue.poll(SAFE_AWAIT_TIMEOUT.seconds(), TimeUnit.SECONDS), + "timed out waiting for handshake block" + ); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new AssertionError(e); + } catch (Exception e) { + throw new AssertionError(e); + } + } + + /** + * This test ensures that we permit up to the max number of concurrent handshakes without any throttling + */ + public void testNoThrottling() { + final List releasables = new ArrayList<>(); + try { + // connection-to-event-loop assignment is round-robin so this will never hit the limit + final var eventLoopCount = between(1, 5); + final var maxConcurrentTlsHandshakes = 
between(1, 5); + final var clientCount = between(1, eventLoopCount * maxConcurrentTlsHandshakes); + final var maxDelayedTlsHandshakes = between(0, 100); + + final var threadPool = newThreadPool(releasables); + final var sharedGroupFactory = new SharedGroupFactory( + Settings.builder().put(Netty4Plugin.WORKER_COUNT.getKey(), eventLoopCount).build() + ); + final var handshakeBlockQueue = ConcurrentCollections.newBlockingQueue(); + + final var meterRegistry = new RecordingMeterRegistry(); + final var metricRecorder = meterRegistry.getRecorder(); + + final var serverTransport = createServerTransport( + threadPool, + sharedGroupFactory, + maxConcurrentTlsHandshakes, + maxDelayedTlsHandshakes, + handshakeBlockQueue, + meterRegistry + ); + releasables.add(serverTransport); + + final var handshakeCompletePromises = startClientsAndGetHandshakeCompletePromises( + clientCount, + randomFrom(serverTransport.boundAddress().boundAddresses()), + releasables + ); + + final var handshakeBlocks = IntStream.range(0, clientCount).mapToObj(ignored -> getNextBlock(handshakeBlockQueue)).toList(); + + logger.info("--> all handshakes blocked"); + + metricRecorder.collect(); + assertLongMetric(metricRecorder, LONG_GAUGE, CURRENT_IN_PROGRESS_METRIC, clientCount); + assertLongMetric(metricRecorder, LONG_GAUGE, CURRENT_DELAYED_METRIC, 0); + assertLongMetric(metricRecorder, LONG_ASYNC_COUNTER, TOTAL_DELAYED_METRIC, 0); + assertLongMetric(metricRecorder, LONG_ASYNC_COUNTER, TOTAL_DROPPED_METRIC, 0); + metricRecorder.resetCalls(); + + handshakeBlocks.forEach(HandshakeBlock::unblock); + handshakeCompletePromises.forEach(ESTestCase::safeAwait); + + assertFinalStats(metricRecorder, 0, 0); + } catch (Exception e) { + throw new AssertionError(e); + } finally { + Collections.reverse(releasables); + Releasables.close(releasables); + } + } + + /** + * This test ensures that if we send more than the permitted number of TLS handshakes at once then the excess are throttled + */ + public void 
testThrottleConcurrentHandshakes() { + final List releasables = new ArrayList<>(); + try { + // connection-to-event-loop assignment is round-robin so this will always use all available slots before queueing + final var eventLoopCount = between(1, 5); + final var maxConcurrentTlsHandshakes = between(1, 5); + final var expectedDelayedHandshakes = between(1, 3 * eventLoopCount); + final var clientCount = eventLoopCount * maxConcurrentTlsHandshakes + expectedDelayedHandshakes; + final var maxDelayedTlsHandshakes = clientCount + between(0, 100); + + final var threadPool = newThreadPool(releasables); + final var sharedGroupFactory = new SharedGroupFactory( + Settings.builder().put(Netty4Plugin.WORKER_COUNT.getKey(), eventLoopCount).build() + ); + final var handshakeBlockQueue = ConcurrentCollections.newBlockingQueue(); + + final var meterRegistry = new RecordingMeterRegistry(); + final var metricRecorder = meterRegistry.getRecorder(); + + final var serverTransport = createServerTransport( + threadPool, + sharedGroupFactory, + maxConcurrentTlsHandshakes, + maxDelayedTlsHandshakes, + handshakeBlockQueue, + meterRegistry + ); + releasables.add(serverTransport); + + final var serverAddress = randomFrom(serverTransport.boundAddress().boundAddresses()); + + final var handshakeCompletePromises = startClientsAndGetHandshakeCompletePromises(clientCount, serverAddress, releasables); + + final var handshakeBlocks = new ArrayDeque(); + for (int i = 0; i < eventLoopCount * maxConcurrentTlsHandshakes; i++) { + handshakeBlocks.addLast(getNextBlock(handshakeBlockQueue)); + } + assertNull(handshakeBlockQueue.poll()); // this is key: all the handshakes beyond the limit are delayed + logger.info("--> max number of handshakes received & blocked"); + + awaitLongMetric(metricRecorder, LONG_ASYNC_COUNTER, TOTAL_DELAYED_METRIC, expectedDelayedHandshakes); + logger.info("--> expected number of delayed handshakes observed"); + + metricRecorder.collect(); + assertLongMetric(metricRecorder, 
LONG_GAUGE, CURRENT_IN_PROGRESS_METRIC, eventLoopCount * maxConcurrentTlsHandshakes); + assertLongMetric(metricRecorder, LONG_GAUGE, CURRENT_DELAYED_METRIC, expectedDelayedHandshakes); + assertLongMetric(metricRecorder, LONG_ASYNC_COUNTER, TOTAL_DELAYED_METRIC, expectedDelayedHandshakes); + assertLongMetric(metricRecorder, LONG_ASYNC_COUNTER, TOTAL_DROPPED_METRIC, 0); + metricRecorder.resetCalls(); + + while (handshakeBlocks.isEmpty() == false) { + handshakeBlocks.removeFirst().unblock(); + } + + for (int i = 0; i < expectedDelayedHandshakes; i++) { + getNextBlock(handshakeBlockQueue).unblock(); + } + assertNull(handshakeBlockQueue.poll()); + + logger.info("--> all handshakes unblocked"); + + handshakeCompletePromises.forEach(ESTestCase::safeAwait); + assertFinalStats(metricRecorder, expectedDelayedHandshakes, 0); + } catch (Exception e) { + throw new AssertionError(e); + } finally { + Collections.reverse(releasables); + Releasables.close(releasables); + } + } + + public void testProcessThrottledHandshakesInLifoOrder() { + final List releasables = new ArrayList<>(); + try { + final var eventLoopCount = 1; // no concurrency + final var expectedDelayedHandshakes = 2; + + final var threadPool = newThreadPool(releasables); + final var sharedGroupFactory = new SharedGroupFactory( + Settings.builder().put(Netty4Plugin.WORKER_COUNT.getKey(), eventLoopCount).build() + ); + final var handshakeBlockQueue = ConcurrentCollections.newBlockingQueue(); + + final var meterRegistry = new RecordingMeterRegistry(); + final var metricRecorder = meterRegistry.getRecorder(); + + final var serverTransport = createServerTransport( + threadPool, + sharedGroupFactory, + 1, + expectedDelayedHandshakes, + handshakeBlockQueue, + meterRegistry + ); + releasables.add(serverTransport); + + final var clientEventLoop = newClientEventLoop(releasables); + final var sslContext = newClientSslContext(); + final var serverAddress = randomFrom(serverTransport.boundAddress().boundAddresses()); + + class 
TestClient { + final SubscribableListener handshakeListener = new SubscribableListener<>(); + final Bootstrap bootstrap = new Bootstrap().group(clientEventLoop) + .channel(NioSocketChannel.class) + .remoteAddress(serverAddress.getAddress(), serverAddress.getPort()) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel ch) { + final var sslHandler = sslContext.newHandler(ch.alloc()); + ch.pipeline().addLast(sslHandler); + sslHandler.handshakeFuture().addListener(fut -> { + assertTrue(fut.isSuccess()); + handshakeListener.onResponse(null); + }); + } + }); + + TestClient() { + final var connectFuture = bootstrap.connect(); + releasables.add(() -> connectFuture.syncUninterruptibly().channel().close().syncUninterruptibly()); + } + } + + final var client0 = new TestClient(); + awaitLongMetric(metricRecorder, LONG_GAUGE, CURRENT_IN_PROGRESS_METRIC, 1); + + final var client1 = new TestClient(); + awaitLongMetric(metricRecorder, LONG_GAUGE, CURRENT_DELAYED_METRIC, 1); + + final var client2 = new TestClient(); + awaitLongMetric(metricRecorder, LONG_GAUGE, CURRENT_DELAYED_METRIC, 2); + + logger.info("--> all handshakes blocked/delayed, unblocking client0"); + + getNextBlock(handshakeBlockQueue).unblock(); + safeAwait(client0.handshakeListener); + assertFalse(client1.handshakeListener.isDone()); + assertFalse(client2.handshakeListener.isDone()); + + logger.info("--> client0 handshake complete, unblocking client2"); + + // unblocking the client0 handshake releases the one received last, i.e. 
from client2 + getNextBlock(handshakeBlockQueue).unblock(); + safeAwait(client2.handshakeListener); + assertFalse(client1.handshakeListener.isDone()); + + getNextBlock(handshakeBlockQueue).unblock(); + safeAwait(client1.handshakeListener); + + assertNull(handshakeBlockQueue.poll()); + + assertFinalStats(metricRecorder, 2, 0); + } catch (Exception e) { + throw new AssertionError(e); + } finally { + Collections.reverse(releasables); + Releasables.close(releasables); + } + } + + public void testDiscardExcessiveConcurrentHandshakes() { + final List releasables = new ArrayList<>(); + try { + // connection-to-event-loop assignment is round-robin so this will always use all available slots before rejecting + final var eventLoopCount = between(1, 5); + final var maxConcurrentTlsHandshakes = between(1, 5); + final var maxDelayedTlsHandshakes = between(1, 5); + final var excessiveHandshakes = between(eventLoopCount, 5); // at least eventLoopCount to ensure max delayed everywhere + final var clientCount = eventLoopCount * (maxConcurrentTlsHandshakes + maxDelayedTlsHandshakes) + excessiveHandshakes; + + final var threadPool = newThreadPool(releasables); + final var sharedGroupFactory = new SharedGroupFactory( + Settings.builder().put(Netty4Plugin.WORKER_COUNT.getKey(), eventLoopCount).build() + ); + final var handshakeBlockQueue = ConcurrentCollections.newBlockingQueue(); + + final var meterRegistry = new RecordingMeterRegistry(); + final var metricRecorder = meterRegistry.getRecorder(); + + final var serverTransport = createServerTransport( + threadPool, + sharedGroupFactory, + maxConcurrentTlsHandshakes, + maxDelayedTlsHandshakes, + handshakeBlockQueue, + meterRegistry + ); + releasables.add(serverTransport); + + final var handshakeCompletePromises = startClientsAndGetHandshakeCompletePromises( + clientCount, + randomFrom(serverTransport.boundAddress().boundAddresses()), + releasables + ); + + final var failedHandshakesLatch = new CountDownLatch(excessiveHandshakes); + final 
var exceptionCount = new AtomicInteger(); + handshakeCompletePromises.forEach(handshakeCompletePromise -> handshakeCompletePromise.addListener(new ActionListener<>() { + @Override + public void onResponse(Void unused) {} + + @Override + public void onFailure(Exception e) { + assertThat(exceptionCount.incrementAndGet(), lessThanOrEqualTo(excessiveHandshakes)); + assertThat(e, Matchers.instanceOf(NodeDisconnectedException.class)); + assertThat(e.getMessage(), equalTo("disconnected before handshake complete")); + failedHandshakesLatch.countDown(); + } + })); + + safeAwait(failedHandshakesLatch); + logger.info("--> excessive handshakes cancelled"); + + metricRecorder.collect(); + assertLongMetric(metricRecorder, LONG_GAUGE, CURRENT_IN_PROGRESS_METRIC, eventLoopCount * maxConcurrentTlsHandshakes); + assertLongMetric(metricRecorder, LONG_GAUGE, CURRENT_DELAYED_METRIC, eventLoopCount * maxDelayedTlsHandshakes); + assertLongMetric( + metricRecorder, + LONG_ASYNC_COUNTER, + TOTAL_DELAYED_METRIC, + eventLoopCount * maxDelayedTlsHandshakes + excessiveHandshakes + ); + assertLongMetric(metricRecorder, LONG_ASYNC_COUNTER, TOTAL_DROPPED_METRIC, excessiveHandshakes); + metricRecorder.resetCalls(); + + for (int i = 0; i < eventLoopCount * (maxConcurrentTlsHandshakes + maxDelayedTlsHandshakes); i++) { + getNextBlock(handshakeBlockQueue).unblock(); + } + logger.info("--> all handshakes released"); + + final var completeLatch = new CountDownLatch(handshakeCompletePromises.size()); + handshakeCompletePromises.forEach( + handshakeCompletePromise -> handshakeCompletePromise.addListener(ActionListener.running(completeLatch::countDown)) + ); + safeAwait(completeLatch); + logger.info("--> all handshakes completed"); + + assertFinalStats(metricRecorder, eventLoopCount * maxDelayedTlsHandshakes + excessiveHandshakes, excessiveHandshakes); + + } catch (Exception e) { + throw new AssertionError(e); + } finally { + Collections.reverse(releasables); + Releasables.close(releasables); + } + } + + 
public void testThrottleHandlersOmittedWhenDisabled() { + final List releasables = new ArrayList<>(); + try { + final var threadPool = newThreadPool(releasables); + final var sharedGroupFactory = new SharedGroupFactory(Settings.builder().put(Netty4Plugin.WORKER_COUNT.getKey(), 1).build()); + final var handlersObservedQueue = ConcurrentCollections.newBlockingQueue(); + + final Settings.Builder builder = Settings.builder(); + addSSLSettingsForNodePEMFiles(builder, "xpack.security.http.", randomBoolean()); + final var settings = builder.put("xpack.security.http.ssl.enabled", true) + .put("path.home", createTempDir()) + .put(Netty4Plugin.SETTING_HTTP_NETTY_TLS_HANDSHAKES_MAX_IN_PROGRESS.getKey(), between(0, 1)) + .build(); + final var env = TestEnvironment.newEnvironment(settings); + final var sslService = new SSLService(env); + final var tlsConfig = new TLSConfig(sslService.profile(XPackSettings.HTTP_SSL_PREFIX)::engine); + + final List> settingsSet = new ArrayList<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + settingsSet.add(Netty4Plugin.SETTING_HTTP_NETTY_TLS_HANDSHAKES_MAX_IN_PROGRESS); + settingsSet.add(Netty4Plugin.SETTING_HTTP_NETTY_TLS_HANDSHAKES_MAX_DELAYED); + final var clusterSettings = new ClusterSettings(settings, Set.copyOf(settingsSet)); + + final var serverTransport = new Netty4HttpServerTransport( + settings, + new NetworkService(Collections.emptyList()), + threadPool, + xContentRegistry(), + NEVER_CALLED_DISPATCHER, + clusterSettings, + sharedGroupFactory, + TelemetryProvider.NOOP, + tlsConfig, + null, + null + ) { + @Override + public ChannelHandler configureServerChannelHandler() { + return new HttpChannelHandler(this, handlingSettings, tlsConfig, null, null) { + @Override + protected void initChannel(Channel ch) throws Exception { + super.initChannel(ch); + final var hasInitialThrottleHandler = ch.pipeline().get("initial-tls-handshake-throttle") != null; + final var hasCompletionHandler = 
ch.pipeline().get("initial-tls-handshake-completion-watcher") != null; + assertEquals(hasInitialThrottleHandler, hasCompletionHandler); + assertTrue(handlersObservedQueue.offer(hasInitialThrottleHandler && hasCompletionHandler)); + } + }; + } + }; + serverTransport.start(); + releasables.add(serverTransport); + + final var clientEventLoop = newClientEventLoop(releasables); + final var sslContext = newClientSslContext(); + final var serverAddress = randomFrom(serverTransport.boundAddress().boundAddresses()); + + for (final var throttleConfigured : new Boolean[] { null, Boolean.TRUE, Boolean.FALSE }) { + + logger.info("--> throttleConfigured: {}", throttleConfigured); + + final boolean throttleEnabled; + if (throttleConfigured == null) { + throttleEnabled = settings.getAsInt(Netty4Plugin.SETTING_HTTP_NETTY_TLS_HANDSHAKES_MAX_IN_PROGRESS.getKey(), null) > 0; + } else { + final var settingsBuilder = Settings.builder(); + settingsBuilder.put( + Netty4Plugin.SETTING_HTTP_NETTY_TLS_HANDSHAKES_MAX_IN_PROGRESS.getKey(), + throttleConfigured ? 
between(1, 1000) : 0 + ); + throttleEnabled = throttleConfigured; + clusterSettings.applySettings(settingsBuilder.build()); + } + + final var connectFuture = new Bootstrap().group(clientEventLoop) + .channel(NioSocketChannel.class) + .remoteAddress(serverAddress.getAddress(), serverAddress.getPort()) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel ch) { + final var sslHandler = sslContext.newHandler(ch.alloc()); + ch.pipeline().addLast(sslHandler); + } + }) + .connect(); + releasables.add(() -> connectFuture.syncUninterruptibly().channel().close().syncUninterruptibly()); + + assertEquals(throttleEnabled, handlersObservedQueue.poll(SAFE_AWAIT_TIMEOUT.millis(), TimeUnit.MILLISECONDS)); + } + + } catch (Exception e) { + throw new AssertionError(e); + } finally { + Collections.reverse(releasables); + Releasables.close(releasables); + } + } + + public void testIgnoreIncompleteHandshakes() { + final List releasables = new ArrayList<>(); + try { + final var eventLoopCount = between(1, 5); + final var maxConcurrentTlsHandshakes = between(1, 5); + final var maxDelayedTlsHandshakes = between(1, 5); + final var clientCount = between(1, eventLoopCount * (maxConcurrentTlsHandshakes + maxDelayedTlsHandshakes) + between(1, 5)); + + final var threadPool = newThreadPool(releasables); + final var sharedGroupFactory = new SharedGroupFactory( + Settings.builder().put(Netty4Plugin.WORKER_COUNT.getKey(), eventLoopCount).build() + ); + final var handshakeBlockQueue = ConcurrentCollections.newBlockingQueue(); + + final var meterRegistry = new RecordingMeterRegistry(); + final var metricRecorder = meterRegistry.getRecorder(); + + final var serverTransport = createServerTransport( + threadPool, + sharedGroupFactory, + maxConcurrentTlsHandshakes, + maxDelayedTlsHandshakes, + handshakeBlockQueue, + meterRegistry + ); + releasables.add(serverTransport); + + final var clientEventLoop = newClientEventLoop(releasables); + final var sslContext = 
newClientSslContext(); + final var serverAddress = randomFrom(serverTransport.boundAddress().boundAddresses()); + final var completeLatch = new CountDownLatch(clientCount); + + for (int i = 0; i < clientCount; i++) { + final var bootstrap = new Bootstrap().group(clientEventLoop) + .channel(NioSocketChannel.class) + .remoteAddress(serverAddress.getAddress(), serverAddress.getPort()) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel ch) { + final var sslHandler = sslContext.newHandler(ch.alloc()); + ch.pipeline().addLast(new ChannelOutboundHandlerAdapter() { + boolean truncating; + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + if (truncating) { + return; + } + + truncating = true; + final var msgByteBuf = (ByteBuf) msg; + final var truncatedMsg = msgByteBuf.slice(0, between(1, msgByteBuf.readableBytes() - 1)); + ctx.executor() + .execute( + () -> ctx.writeAndFlush(truncatedMsg, promise) + .addListener(f -> ctx.close().addListener(ff -> completeLatch.countDown())) + ); + } + }).addLast(sslHandler); + } + }); + final var connectFuture = bootstrap.connect(); + releasables.add(() -> connectFuture.syncUninterruptibly().channel().close().syncUninterruptibly()); + } + + metricRecorder.collect(); + assertLongMetric(metricRecorder, LONG_GAUGE, CURRENT_IN_PROGRESS_METRIC, 0); + assertLongMetric(metricRecorder, LONG_GAUGE, CURRENT_DELAYED_METRIC, 0); + assertLongMetric(metricRecorder, LONG_ASYNC_COUNTER, TOTAL_DELAYED_METRIC, 0); + assertLongMetric(metricRecorder, LONG_ASYNC_COUNTER, TOTAL_DROPPED_METRIC, 0); + metricRecorder.resetCalls(); + + safeAwait(completeLatch); + + metricRecorder.collect(); + assertLongMetric(metricRecorder, LONG_GAUGE, CURRENT_IN_PROGRESS_METRIC, 0); + assertLongMetric(metricRecorder, LONG_GAUGE, CURRENT_DELAYED_METRIC, 0); + assertLongMetric(metricRecorder, LONG_ASYNC_COUNTER, TOTAL_DELAYED_METRIC, 0); + assertLongMetric(metricRecorder, LONG_ASYNC_COUNTER, 
TOTAL_DROPPED_METRIC, 0);
+            metricRecorder.resetCalls();
+
+            assertTrue(handshakeBlockQueue.isEmpty());
+
+        } catch (Exception e) {
+            throw new AssertionError(e);
+        } finally {
+            Collections.reverse(releasables);
+            Releasables.close(releasables);
+        }
+    }
+    /** Clients close right after their first flush; the in-progress and delayed handshake gauges must then read zero. */
+    public void testCleanUpOnEarlyClientClose() {
+        final List<Releasable> releasables = new ArrayList<>();
+        try {
+            final var eventLoopCount = between(1, 5);
+            final var maxConcurrentTlsHandshakes = between(1, 5);
+            final var maxDelayedTlsHandshakes = between(1, 5);
+            final var clientCount = between(1, eventLoopCount * (maxConcurrentTlsHandshakes + maxDelayedTlsHandshakes) + between(1, 5));
+
+            final var threadPool = newThreadPool(releasables);
+            final var sharedGroupFactory = new SharedGroupFactory(
+                Settings.builder().put(Netty4Plugin.WORKER_COUNT.getKey(), eventLoopCount).build()
+            );
+            final var handshakeBlockQueue = ConcurrentCollections.newBlockingQueue();
+            // NOTE(review): the queue above may need an explicit element type -- confirm against createServerTransport
+            final var meterRegistry = new RecordingMeterRegistry();
+            final var metricRecorder = meterRegistry.getRecorder();
+
+            final var serverTransport = createServerTransport(
+                threadPool,
+                sharedGroupFactory,
+                maxConcurrentTlsHandshakes,
+                maxDelayedTlsHandshakes,
+                handshakeBlockQueue,
+                meterRegistry
+            );
+            releasables.add(serverTransport);
+
+            final var clientEventLoop = newClientEventLoop(releasables);
+            final var sslContext = newClientSslContext();
+            final var serverAddress = randomFrom(serverTransport.boundAddress().boundAddresses());
+            final var completeLatch = new CountDownLatch(clientCount);
+
+            for (int i = 0; i < clientCount; i++) {
+                final var bootstrap = new Bootstrap().group(clientEventLoop)
+                    .channel(NioSocketChannel.class)
+                    .remoteAddress(serverAddress.getAddress(), serverAddress.getPort())
+                    .handler(new ChannelInitializer<SocketChannel>() {
+                        @Override
+                        protected void initChannel(SocketChannel ch) {
+                            final var sslHandler = sslContext.newHandler(ch.alloc());
+                            ch.pipeline().addLast(new ChannelOutboundHandlerAdapter() {
+                                @Override
+                                public void flush(ChannelHandlerContext ctx) throws Exception {
+                                    super.flush(ctx);
+                                    ctx.close().addListener(ff -> completeLatch.countDown()); // close as soon as anything is flushed
+                                }
+                            }).addLast(sslHandler);
+                        }
+                    });
+                final var connectFuture = bootstrap.connect();
+                releasables.add(() -> connectFuture.syncUninterruptibly().channel().close().syncUninterruptibly());
+            }
+            // the latch is counted down from each client's close listener, so every client has closed once it is released
+            safeAwait(completeLatch);
+            logger.info("--> completeLatch released");
+
+            assertTrue(waitUntil(() -> serverTransport.stats().getTotalOpen() == clientCount));
+            logger.info("--> all connections opened");
+
+            assertTrue(waitUntil(() -> serverTransport.stats().getServerOpen() == 0));
+            logger.info("--> all connections closed");
+
+            getNextBlock(handshakeBlockQueue);
+            logger.info("--> at least one handshake started");
+
+            awaitLongMetric(metricRecorder, LONG_GAUGE, CURRENT_IN_PROGRESS_METRIC, 0);
+            logger.info("--> CURRENT_IN_PROGRESS_METRIC reached zero");
+
+            metricRecorder.collect();
+            assertLongMetric(metricRecorder, LONG_GAUGE, CURRENT_IN_PROGRESS_METRIC, 0);
+            assertLongMetric(metricRecorder, LONG_GAUGE, CURRENT_DELAYED_METRIC, 0);
+            metricRecorder.resetCalls();
+
+        } catch (Exception e) {
+            throw new AssertionError(e);
+        } finally {
+            Collections.reverse(releasables); // release in reverse order of registration
+            Releasables.close(releasables);
+        }
+    }
+    /** Connects {@code clientCount} TLS clients; returns one listener per client, completed with its handshake outcome. */
+    private List<SubscribableListener<Void>> startClientsAndGetHandshakeCompletePromises(
+        int clientCount,
+        TransportAddress serverAddress,
+        List<Releasable> releasables
+    ) {
+        final var clientEventLoop = newClientEventLoop(releasables);
+        final var sslContext = newClientSslContext();
+        final List<SubscribableListener<Void>> handshakeCompletePromises = new ArrayList<>(clientCount);
+        for (int i = 0; i < clientCount; i++) {
+            final var handshakeCompletePromise = new SubscribableListener<Void>();
+            handshakeCompletePromises.add(handshakeCompletePromise);
+            final var bootstrap = new Bootstrap().group(clientEventLoop)
+                .channel(NioSocketChannel.class)
+                .remoteAddress(serverAddress.getAddress(), serverAddress.getPort())
+                .handler(new ChannelInitializer<SocketChannel>() {
+                    @Override
+                    protected void initChannel(SocketChannel ch) {
+                        final var sslHandler = sslContext.newHandler(ch.alloc());
+                        ch.pipeline().addLast(sslHandler);
+                        sslHandler.handshakeFuture().addListener(future -> {
+                            if (future.isSuccess()) {
+                                handshakeCompletePromise.onResponse(null);
+                            } else {
+                                ExceptionsHelper.maybeDieOnAnotherThread(future.cause());
+                                if (future.cause() instanceof ClosedChannelException closedChannelException) {
+                                    handshakeCompletePromise.onFailure(
+                                        new NodeDisconnectedException(
+                                            null,
+                                            "disconnected before handshake complete",
+                                            null,
+                                            closedChannelException
+                                        )
+                                    );
+                                } else {
+                                    handshakeCompletePromise.onFailure(new ElasticsearchException("handshake failed", future.cause()));
+                                }
+                            }
+                        });
+                        // a channel close is reported to the promise as a disconnect as well
+                        ch.closeFuture()
+                            .addListener(
+                                ignored -> handshakeCompletePromise.onFailure(
+                                    new NodeDisconnectedException(null, "disconnected before handshake complete", null, null)
+                                )
+                            );
+                        // close the client channel once the promise completes (presumably on success and failure alike -- confirm)
+                        handshakeCompletePromise.addListener(ActionListener.running(ch::close));
+                    }
+                });
+            final var connectFuture = bootstrap.connect();
+            releasables.add(() -> connectFuture.syncUninterruptibly().channel().close().syncUninterruptibly());
+        }
+        return List.copyOf(handshakeCompletePromises);
+    }
+    /** Creates a single-threaded client event loop group and registers its termination in {@code releasables}. */
+    private static EventLoopGroup newClientEventLoop(List<Releasable> releasables) {
+        final var clientEventLoop = new NioEventLoopGroup(1);
+        releasables.add(() -> ThreadPool.terminate(clientEventLoop, SAFE_AWAIT_TIMEOUT.millis(), TimeUnit.MILLISECONDS));
+        return clientEventLoop;
+    }
+    /** Creates a {@link TestThreadPool} named after the running test and registers its termination in {@code releasables}. */
+    private ThreadPool newThreadPool(List<Releasable> releasables) {
+        final var threadPool = new TestThreadPool(getTestName());
+        releasables.add(() -> ThreadPool.terminate(threadPool, SAFE_AWAIT_TIMEOUT.seconds(), TimeUnit.SECONDS));
+        return threadPool;
+    }
+    /** Waits for the in-progress gauge to reach zero, then asserts the final gauge and counter values. */
+    private static void assertFinalStats(MetricRecorder metricRecorder, int expectedTotalDelayed, int expectedTotalDropped) {
+        // clients may get handshake completion before server, so we have to busy-wait here before checking final metrics:
+        awaitLongMetric(metricRecorder, LONG_GAUGE,
CURRENT_IN_PROGRESS_METRIC, 0);
+        metricRecorder.collect();
+        assertLongMetric(metricRecorder, LONG_GAUGE, CURRENT_IN_PROGRESS_METRIC, 0);
+        assertLongMetric(metricRecorder, LONG_GAUGE, CURRENT_DELAYED_METRIC, 0);
+        assertLongMetric(metricRecorder, LONG_ASYNC_COUNTER, TOTAL_DELAYED_METRIC, expectedTotalDelayed);
+        assertLongMetric(metricRecorder, LONG_ASYNC_COUNTER, TOTAL_DROPPED_METRIC, expectedTotalDropped);
+        metricRecorder.resetCalls(); // clear the recorded calls so a later collect() starts from a clean slate
+    }
+    /** Asserts that exactly one measurement is recorded for the named metric and that it equals {@code expectedValue}. */
+    private static void assertLongMetric(
+        MetricRecorder metricRecorder,
+        InstrumentType instrumentType,
+        String name,
+        int expectedValue
+    ) {
+        assertEquals(
+            name,
+            List.of((long) expectedValue),
+            metricRecorder.getMeasurements(instrumentType, name).stream().map(Measurement::getLong).toList()
+        );
+    }
+    // class-level logger; static so that static helpers such as awaitLongMetric can use it
+    private static final Logger logger = LogManager.getLogger(SecurityNetty4HttpServerTransportTlsHandshakeThrottleTests.class);
+    /** Polls via {@code waitUntil} until the named metric reports exactly {@code expectedValue}; fails if that call returns false. */
+    private static void awaitLongMetric(
+        MetricRecorder metricRecorder,
+        InstrumentType instrumentType,
+        String name,
+        int expectedValue
+    ) {
+        final var expectedMeasurements = List.of((long) expectedValue);
+        assertTrue(waitUntil(() -> {
+            metricRecorder.collect();
+            final var measurements = metricRecorder.getMeasurements(instrumentType, name).stream().map(Measurement::getLong).toList();
+            metricRecorder.resetCalls(); // reset on every poll so each attempt only sees its own collect()
+            logger.info("--> awaitLongMetric[{}/{}] got {}", instrumentType, name, measurements);
+            assertThat(measurements, hasSize(1)); // anything other than a single measurement fails immediately instead of retrying
+            return measurements.equals(expectedMeasurements);
+        }));
+    }
+    /** Dispatcher whose two methods both call {@code fail}: the tests using it must never dispatch an HTTP request. */
+    private static final HttpServerTransport.Dispatcher NEVER_CALLED_DISPATCHER = new HttpServerTransport.Dispatcher() {
+        @Override
+        public void dispatchRequest(RestRequest request, RestChannel channel, ThreadContext threadContext) {
+            fail("dispatchRequest should never be called");
+        }
+
+        @Override
+        public void dispatchBadRequest(RestChannel channel, ThreadContext threadContext, Throwable cause) {
+            fail("dispatchBadRequest should never be called");
+        }
+    };
+    // names of the TLS handshake metrics inspected by these tests
+    private static
final String METRIC_PREFIX = "es.http.tls_handshakes."; // common prefix of the four metric names below
+    private static final String CURRENT_IN_PROGRESS_METRIC = METRIC_PREFIX + "in_progress.current"; // read as a LONG_GAUGE
+    private static final String CURRENT_DELAYED_METRIC = METRIC_PREFIX + "delayed.current"; // read as a LONG_GAUGE
+    private static final String TOTAL_DELAYED_METRIC = METRIC_PREFIX + "delayed.total"; // read as a LONG_ASYNC_COUNTER
+    private static final String TOTAL_DROPPED_METRIC = METRIC_PREFIX + "dropped.total"; // read as a LONG_ASYNC_COUNTER
+}
diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4TestUtils.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4TestUtils.java
new file mode 100644
index 0000000000000..6be39907a145c
--- /dev/null
+++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4TestUtils.java
@@ -0,0 +1,31 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.security.transport.netty4;
+
+import org.elasticsearch.common.settings.ClusterSettings;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.set.Sets;
+import org.elasticsearch.http.HttpTransportSettings;
+import org.elasticsearch.transport.netty4.Netty4Plugin;
+
+import static org.elasticsearch.test.ESTestCase.randomBoolean;
+/** Static helpers for the security Netty4 tests, declared as an enum without constants so that no instance can exist. */
+public enum SecurityNetty4TestUtils {
+    ; // no constants: the type only hosts the static method below
+    /** Builds {@link ClusterSettings} with a random client-stats flag and both TLS handshake throttle settings registered. */
+    public static ClusterSettings randomClusterSettings() {
+        return new ClusterSettings(
+            Settings.builder().put(HttpTransportSettings.SETTING_HTTP_CLIENT_STATS_ENABLED.getKey(), randomBoolean()).build(),
+            Sets.addToCopy(
+                ClusterSettings.BUILT_IN_CLUSTER_SETTINGS,
+                Netty4Plugin.SETTING_HTTP_NETTY_TLS_HANDSHAKES_MAX_IN_PROGRESS,
+                Netty4Plugin.SETTING_HTTP_NETTY_TLS_HANDSHAKES_MAX_DELAYED
+            )
+        );
+    }
+}