diff --git a/build.gradle b/build.gradle index 02cd2de7..4733acf5 100644 --- a/build.gradle +++ b/build.gradle @@ -24,7 +24,14 @@ repositories { dependencies { testImplementation 'junit:junit:4.12' testImplementation 'org.hamcrest:hamcrest-library:1.3' - testImplementation 'org.apache.logging.log4j:log4j-core:2.17.0' + implementation 'org.apache.logging.log4j:log4j-core:2.17.0' + + // needed to run ServerRunner self hosted + implementation "com.fasterxml.jackson.core:jackson-core:${jacksonVersion}" + implementation "com.fasterxml.jackson.core:jackson-databind:${jacksonDatabindVersion}" + implementation "com.fasterxml.jackson.core:jackson-annotations:${jacksonVersion}" + implementation "com.fasterxml.jackson.module:jackson-module-afterburner:2.15.2" + implementation "io.netty:netty-buffer:${nettyVersion}" implementation "io.netty:netty-codec:${nettyVersion}" implementation "io.netty:netty-common:${nettyVersion}" @@ -104,3 +111,15 @@ publishing { } } } + +jar { + duplicatesStrategy(DuplicatesStrategy.EXCLUDE) + + manifest { + attributes "Main-Class": "org.logstash.beats.ServerRunner" + } + + from { + configurations.runtimeClasspath.collect { it.isDirectory() ? it : zipTree(it) } + } +} diff --git a/src/main/java/org/logstash/beats/BeatsParser.java b/src/main/java/org/logstash/beats/BeatsParser.java index 61337d3b..f0d49fb0 100644 --- a/src/main/java/org/logstash/beats/BeatsParser.java +++ b/src/main/java/org/logstash/beats/BeatsParser.java @@ -3,6 +3,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufOutputStream; +import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.ByteToMessageDecoder; import org.apache.logging.log4j.LogManager; @@ -48,16 +49,20 @@ private enum States { @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws InvalidFrameProtocolException, IOException { - if(!hasEnoughBytes(in)) { - if (decodingCompressedBuffer){ + if (!hasEnoughBytes(in)) { + if (decodingCompressedBuffer) { throw new InvalidFrameProtocolException("Insufficient bytes in compressed content to decode: " + currentState); } return; } + if (!ctx.channel().isOpen()) { + logger.info("Channel is not open, {}", ctx.channel()); + } + switch (currentState) { case READ_HEADER: { - logger.trace("Running: READ_HEADER"); + logger.trace("Running: READ_HEADER {}", ctx.channel()); int version = Protocol.version(in.readByte()); @@ -70,7 +75,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) t batch = new V1Batch(); } } - transition(States.READ_FRAME_TYPE); + transition(States.READ_FRAME_TYPE, ctx.channel()); break; } case READ_FRAME_TYPE: { @@ -78,20 +83,20 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) t switch(frameType) { case Protocol.CODE_WINDOW_SIZE: { - transition(States.READ_WINDOW_SIZE); + transition(States.READ_WINDOW_SIZE, ctx.channel()); break; } case Protocol.CODE_JSON_FRAME: { // Reading Sequence + size of the payload - transition(States.READ_JSON_HEADER); + transition(States.READ_JSON_HEADER, ctx.channel()); break; } case Protocol.CODE_COMPRESSED_FRAME: { - transition(States.READ_COMPRESSED_FRAME_HEADER); + transition(States.READ_COMPRESSED_FRAME_HEADER, ctx.channel()); break; } case Protocol.CODE_FRAME: { - transition(States.READ_DATA_FIELDS); + transition(States.READ_DATA_FIELDS, ctx.channel()); break; } default: { @@ -101,8 +106,10 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) t break; } case READ_WINDOW_SIZE: { - logger.trace("Running: READ_WINDOW_SIZE"); - batch.setBatchSize((int) in.readUnsignedInt()); + logger.trace("Running: READ_WINDOW_SIZE {}", ctx.channel()); + int batchSize = (int) in.readUnsignedInt(); + logger.trace("Batch size: {} - channel {}", batchSize, ctx.channel()); + batch.setBatchSize(batchSize); // This is unlikely to happen but I have no way to known when a frame is // actually completely done other than checking the windows and the sequence number, @@ -114,12 +121,12 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) t batchComplete(); } - transition(States.READ_HEADER); + transition(States.READ_HEADER, ctx.channel()); break; } case READ_DATA_FIELDS: { // Lumberjack version 1 protocol, which use the Key:Value format. - logger.trace("Running: READ_DATA_FIELDS"); + logger.trace("Running: READ_DATA_FIELDS {}", ctx.channel()); sequence = (int) in.readUnsignedInt(); int fieldsCount = (int) in.readUnsignedInt(); int count = 0; @@ -152,34 +159,36 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) t out.add(batch); batchComplete(); } - transition(States.READ_HEADER); + transition(States.READ_HEADER, ctx.channel()); break; } case READ_JSON_HEADER: { - logger.trace("Running: READ_JSON_HEADER"); + logger.trace("Running: READ_JSON_HEADER {}", ctx.channel()); sequence = (int) in.readUnsignedInt(); + logger.trace("Sequence num to read {} for channel {}", sequence, ctx.channel()); int jsonPayloadSize = (int) in.readUnsignedInt(); if(jsonPayloadSize <= 0) { throw new InvalidFrameProtocolException("Invalid json length, received: " + jsonPayloadSize); } - transition(States.READ_JSON, jsonPayloadSize); + transition(States.READ_JSON, jsonPayloadSize, ctx.channel()); break; } case READ_COMPRESSED_FRAME_HEADER: { - logger.trace("Running: READ_COMPRESSED_FRAME_HEADER"); + logger.trace("Running: READ_COMPRESSED_FRAME_HEADER {}", ctx.channel()); - transition(States.READ_COMPRESSED_FRAME, in.readInt()); + transition(States.READ_COMPRESSED_FRAME, in.readInt(), ctx.channel()); break; } case READ_COMPRESSED_FRAME: { - logger.trace("Running: READ_COMPRESSED_FRAME"); + logger.trace("Running: READ_COMPRESSED_FRAME {}", ctx.channel()); + inflateCompressedFrame(ctx, in, (buffer) -> { - transition(States.READ_HEADER); + transition(States.READ_HEADER, ctx.channel()); decodingCompressedBuffer = true; try { @@ -188,23 +197,32 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) t } } finally { decodingCompressedBuffer = false; - transition(States.READ_HEADER); + transition(States.READ_HEADER, ctx.channel()); } }); break; } case READ_JSON: { - logger.trace("Running: READ_JSON"); - ((V2Batch)batch).addMessage(sequence, in, requiredBytes); - if(batch.isComplete()) { - if(logger.isTraceEnabled()) { - logger.trace("Sending batch size: " + this.batch.size() + ", windowSize: " + batch.getBatchSize() + " , seq: " + sequence); + logger.trace("Running: READ_JSON {}", ctx.channel()); + try { + ((V2Batch) batch).addMessage(sequence, in, requiredBytes); + } catch (Throwable th) { + // batch has to release its internal buffer before bubbling up the exception + batch.release(); + + // re throw the same error after released the internal buffer + throw th; + } + + if (batch.isComplete()) { + if (logger.isTraceEnabled()) { + logger.trace("Sending batch size: " + this.batch.size() + ", windowSize: " + batch.getBatchSize() + " , seq: " + sequence + " {}", ctx.channel()); } out.add(batch); batchComplete(); } - transition(States.READ_HEADER); + transition(States.READ_HEADER, ctx.channel()); break; } } @@ -238,13 +256,13 @@ private boolean hasEnoughBytes(ByteBuf in) { return in.readableBytes() >= requiredBytes; } - private void transition(States next) { - transition(next, next.length); + private void transition(States next, Channel ch) { + transition(next, next.length, ch); } - private void transition(States nextState, int requiredBytes) { + private void transition(States nextState, int requiredBytes, Channel ch) { if (logger.isTraceEnabled()) { - logger.trace("Transition, from: " + currentState + ", to: " + nextState + ", requiring " + requiredBytes + " bytes"); + logger.trace("Transition, from: " + currentState + ", to: " + nextState + ", requiring " + requiredBytes + " bytes {}", ch); } this.currentState = nextState; this.requiredBytes = requiredBytes; diff --git a/src/main/java/org/logstash/beats/FlowLimiterHandler.java b/src/main/java/org/logstash/beats/FlowLimiterHandler.java new file mode 100644 index 00000000..6a0da517 --- /dev/null +++ b/src/main/java/org/logstash/beats/FlowLimiterHandler.java @@ -0,0 +1,56 @@ +package org.logstash.beats; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +/** + * Configure the channel where it's installed to operate the reads in pull mode, + * disabling the autoread and explicitly invoking the read operation. + * The flow control to keep the outgoing buffer under control is done + * avoiding to read in new bytes if the outgoing direction became not writable, this + * excert back pressure to the TCP layer and ultimately to the upstream system. + * */ +@Sharable +public final class FlowLimiterHandler extends ChannelInboundHandlerAdapter { + + private final static Logger logger = LogManager.getLogger(FlowLimiterHandler.class); + + @Override + public void channelRegistered(final ChannelHandlerContext ctx) throws Exception { + ctx.channel().config().setAutoRead(false); + super.channelRegistered(ctx); + } + + @Override + public void channelActive(final ChannelHandlerContext ctx) throws Exception { + super.channelActive(ctx); + if (isAutoreadDisabled(ctx.channel()) && ctx.channel().isWritable()) { + ctx.channel().read(); + } + } + + @Override + public void channelReadComplete(final ChannelHandlerContext ctx) throws Exception { + super.channelReadComplete(ctx); + if (isAutoreadDisabled(ctx.channel()) && ctx.channel().isWritable()) { + ctx.channel().read(); + } + } + + private boolean isAutoreadDisabled(Channel channel) { + return !channel.config().isAutoRead(); + } + + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { + ctx.channel().read(); + super.channelWritabilityChanged(ctx); + + logger.debug("Writability on channel {} changed to {}", ctx.channel(), ctx.channel().isWritable()); + } + +} diff --git a/src/main/java/org/logstash/beats/OOMConnectionCloser.java b/src/main/java/org/logstash/beats/OOMConnectionCloser.java new file mode 100644 index 00000000..57a69802 --- /dev/null +++ b/src/main/java/org/logstash/beats/OOMConnectionCloser.java @@ -0,0 +1,59 @@ +package org.logstash.beats; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +public class OOMConnectionCloser extends ChannelInboundHandlerAdapter { + + private static class DirectMemoryUsage { + private final long used; + private final long pinned; + private final short ratio; + + private DirectMemoryUsage(long used, long pinned) { + this.used = used; + this.pinned = pinned; + this.ratio = (short) Math.round(((double) pinned / used) * 100); + } + + static DirectMemoryUsage capture() { + PooledByteBufAllocator allocator = (PooledByteBufAllocator) ByteBufAllocator.DEFAULT; + long usedDirectMemory = allocator.metric().usedDirectMemory(); + long pinnedDirectMemory = allocator.pinnedDirectMemory(); + return new DirectMemoryUsage(usedDirectMemory, pinnedDirectMemory); + } + } + + private final static Logger logger = LogManager.getLogger(OOMConnectionCloser.class); + + public static final Pattern DIRECT_MEMORY_ERROR = Pattern.compile("^Cannot reserve \\d* bytes of direct buffer memory.*$"); + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (isDirectMemoryOOM(cause)) { + DirectMemoryUsage direct = DirectMemoryUsage.capture(); + logger.info("Direct memory status, used: {}, pinned: {}, ratio: {}", direct.used, direct.pinned, direct.ratio); + logger.warn("Dropping connection {} due to lack of available Direct Memory. Please lower the number of concurrent connections or reduce the batch size. " + + "Alternatively, raise -XX:MaxDirectMemorySize option in the JVM running Logstash", ctx.channel()); + ctx.flush(); + ctx.close(); + } else { + super.exceptionCaught(ctx, cause); + } + } + + private boolean isDirectMemoryOOM(Throwable th) { + if (!(th instanceof OutOfMemoryError)) { + return false; + } + Matcher m = DIRECT_MEMORY_ERROR.matcher(th.getMessage()); + return m.matches(); + } +} \ No newline at end of file diff --git a/src/main/java/org/logstash/beats/Server.java b/src/main/java/org/logstash/beats/Server.java index c343aaf6..8e7786b6 100644 --- a/src/main/java/org/logstash/beats/Server.java +++ b/src/main/java/org/logstash/beats/Server.java @@ -1,6 +1,8 @@ package org.logstash.beats; import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.PooledByteBufAllocator; import io.netty.channel.*; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.SocketChannel; @@ -112,7 +114,6 @@ private class BeatsInitializer extends ChannelInitializer { private final int IDLESTATE_WRITER_IDLE_TIME_SECONDS = 5; private final EventExecutorGroup idleExecutorGroup; - private final EventExecutorGroup beatsHandlerExecutorGroup; private final IMessageListener localMessageListener; private final int localClientInactivityTimeoutSeconds; @@ -121,7 +122,6 @@ private class BeatsInitializer extends ChannelInitializer { this.localMessageListener = messageListener; this.localClientInactivityTimeoutSeconds = clientInactivityTimeoutSeconds; idleExecutorGroup = new DefaultEventExecutorGroup(DEFAULT_IDLESTATEHANDLER_THREAD); - beatsHandlerExecutorGroup = new DefaultEventExecutorGroup(beatsHandlerThread); } public void initChannel(SocketChannel socket){ @@ -130,11 +130,29 @@ public void initChannel(SocketChannel socket){ if (isSslEnabled()) { pipeline.addLast(SSL_HANDLER, sslHandlerProvider.sslHandlerForChannel(socket)); } - pipeline.addLast(idleExecutorGroup, IDLESTATE_HANDLER, - new IdleStateHandler(localClientInactivityTimeoutSeconds, IDLESTATE_WRITER_IDLE_TIME_SECONDS, localClientInactivityTimeoutSeconds)); - pipeline.addLast(BEATS_ACKER, new AckEncoder()); - pipeline.addLast(CONNECTION_HANDLER, new ConnectionHandler()); - pipeline.addLast(beatsHandlerExecutorGroup, new BeatsParser(), new BeatsHandler(localMessageListener)); +// pipeline.addLast(idleExecutorGroup, IDLESTATE_HANDLER, +// new IdleStateHandler(localClientInactivityTimeoutSeconds, IDLESTATE_WRITER_IDLE_TIME_SECONDS, localClientInactivityTimeoutSeconds)); +// pipeline.addLast(BEATS_ACKER, new AckEncoder()); +// pipeline.addLast(CONNECTION_HANDLER, new ConnectionHandler()); + +// pipeline.addLast(new FlowLimiterHandler()); +// pipeline.addLast(new ThunderingGuardHandler()); + pipeline.addLast("beats parser", new BeatsParser()); +// pipeline.addLast(new OOMConnectionCloser()); +// pipeline.addLast("beats handler", new BeatsHandler(localMessageListener)); + pipeline.addLast(new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + logger.warn("Exception {} received on {}", cause.getMessage(), ctx.channel()); +// pipeline.remove("beats parser"); +// if (cause instanceof OutOfMemoryError) { + ctx.channel().close(); +// } + super.exceptionCaught(ctx, cause); + } + }); + + logger.info("Starting with handlers: {}", pipeline.names()); } @@ -152,7 +170,6 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E public void shutdownEventExecutor() { try { idleExecutorGroup.shutdownGracefully().sync(); - beatsHandlerExecutorGroup.shutdownGracefully().sync(); } catch (InterruptedException e) { throw new IllegalStateException(e); } diff --git a/src/main/java/org/logstash/beats/ServerRunner.java b/src/main/java/org/logstash/beats/ServerRunner.java new file mode 100644 index 00000000..3b263a89 --- /dev/null +++ b/src/main/java/org/logstash/beats/ServerRunner.java @@ -0,0 +1,57 @@ +package org.logstash.beats; + +import io.netty.channel.ChannelHandlerContext; +import org.logstash.netty.SslContextBuilder; +import org.logstash.netty.SslHandlerProvider; + +public class ServerRunner { + + public static void main(String[] args) throws Exception { + int clientInactivityTimeoutSeconds = 60; + Server server = new Server("127.0.0.1", 3333, clientInactivityTimeoutSeconds, Runtime.getRuntime().availableProcessors()); + + // enable TLS + System.out.println("Using SSL"); + + String sslCertificate = "/Users/andrea/workspace/certificates/client_from_root.crt"; + String sslKey = "/Users/andrea/workspace/certificates/client_from_root.key.pkcs8"; + String[] certificateAuthorities = new String[] { "/Users/andrea/workspace/certificates/root.crt" }; + + SslContextBuilder sslBuilder = new SslContextBuilder(sslCertificate, sslKey, null) + .setProtocols(new String[] { "TLSv1.2", "TLSv1.3" }) + .setClientAuthentication(SslContextBuilder.SslClientVerifyMode.REQUIRED, certificateAuthorities); + SslHandlerProvider sslHandlerProvider = new SslHandlerProvider(sslBuilder.buildContext(), 10000); + server.setSslHandlerProvider(sslHandlerProvider); + + // no TLS + + server.setMessageListener(new IMessageListener() { + @Override + public void onNewMessage(ChannelHandlerContext ctx, Message message) { + + } + + @Override + public void onNewConnection(ChannelHandlerContext ctx) { + + } + + @Override + public void onConnectionClose(ChannelHandlerContext ctx) { + + } + + @Override + public void onException(ChannelHandlerContext ctx, Throwable cause) { + + } + + @Override + public void onChannelInitializeException(ChannelHandlerContext ctx, Throwable cause) { + + } + }); + // blocking till the end + server.listen(); + } +} diff --git a/src/main/java/org/logstash/beats/ThunderingGuardHandler.java b/src/main/java/org/logstash/beats/ThunderingGuardHandler.java new file mode 100644 index 00000000..a880ed35 --- /dev/null +++ b/src/main/java/org/logstash/beats/ThunderingGuardHandler.java @@ -0,0 +1,40 @@ +package org.logstash.beats; + +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +/** + * This handler is responsible to avoid accepting new connections when the direct memory + * consumption is close to the MaxDirectMemorySize. + *

+ * If the total allocated direct memory is close to the max memory size and also the pinned + * bytes from the direct memory allocator is close to the direct memory used, then it drops the new + * incoming connections. + * */ +@Sharable +public final class ThunderingGuardHandler extends ChannelInboundHandlerAdapter { + + private final static long MAX_DIRECT_MEMORY = io.netty.util.internal.PlatformDependent.maxDirectMemory(); + + private final static Logger logger = LogManager.getLogger(ThunderingGuardHandler.class); + + @Override + public void channelRegistered(final ChannelHandlerContext ctx) throws Exception { + PooledByteBufAllocator pooledAllocator = (PooledByteBufAllocator) ctx.alloc(); + long usedDirectMemory = pooledAllocator.metric().usedDirectMemory(); + if (usedDirectMemory > MAX_DIRECT_MEMORY * 0.90) { + long pinnedDirectMemory = pooledAllocator.pinnedDirectMemory(); + if (pinnedDirectMemory >= usedDirectMemory * 0.80) { + ctx.close(); + logger.warn("Dropping connection {} due to high resource consumption", ctx.channel()); + return; + } + } + + super.channelRegistered(ctx); + } +} diff --git a/src/test/java/org/logstash/beats/FlowLimiterHandlerTest.java b/src/test/java/org/logstash/beats/FlowLimiterHandlerTest.java new file mode 100644 index 00000000..f268621e --- /dev/null +++ b/src/test/java/org/logstash/beats/FlowLimiterHandlerTest.java @@ -0,0 +1,195 @@ +package org.logstash.beats; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import org.junit.Test; + +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; + +import static org.junit.Assert.*; + +public class FlowLimiterHandlerTest { + + private ReadMessagesCollector readMessagesCollector; + + private static ByteBuf prepareSample(int numBytes) { + return prepareSample(numBytes, 'A'); + } + + private static ByteBuf prepareSample(int numBytes, char c) { + ByteBuf payload = PooledByteBufAllocator.DEFAULT.directBuffer(numBytes); + for (int i = 0; i < numBytes; i++) { + payload.writeByte(c); + } + return payload; + } + + private ChannelInboundHandlerAdapter onClientConnected(Consumer action) { + return new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + super.channelActive(ctx); + action.accept(ctx); + } + }; + } + + private static class ReadMessagesCollector extends SimpleChannelInboundHandler { + private Channel clientChannel; + private final NioEventLoopGroup group; + boolean firstChunkRead = false; + + ReadMessagesCollector(NioEventLoopGroup group) { + this.group = group; + } + + @Override + protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception { + if (!firstChunkRead) { + assertEquals("Expect to read a first chunk and no others", 32, msg.readableBytes()); + firstChunkRead = true; + + // client write other data that MUSTN'T be read by the server, because + // is rate limited. + clientChannel.writeAndFlush(prepareSample(16)).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + // on successful flush schedule a shutdown + ctx.channel().eventLoop().schedule(new Runnable() { + @Override + public void run() { + group.shutdownGracefully(); + } + }, 2, TimeUnit.SECONDS); + } else { + ctx.fireExceptionCaught(future.cause()); + } + } + }); + + } else { + // the first read happened, no other reads are commanded by the server + // should never pass here + fail("Shouldn't never be notified other data while in rate limiting"); + } + } + + public void updateClient(Channel clientChannel) { + assertNotNull(clientChannel); + this.clientChannel = clientChannel; + } + } + + + private static class AssertionsHandler extends ChannelInboundHandlerAdapter { + + private final NioEventLoopGroup group; + + private Throwable lastError; + + public AssertionsHandler(NioEventLoopGroup group) { + this.group = group; + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + lastError = cause; + group.shutdownGracefully(); + } + + public void assertNoErrors() { + if (lastError != null) { + if (lastError instanceof AssertionError) { + throw (AssertionError) lastError; + } else { + fail("Failed with error" + lastError); + } + } + } + } + + @Test + public void givenAChannelInNotWriteableStateWhenNewBuffersAreSentByClientThenNoDecodeTakePartOnServerSide() throws Exception { + final int highWaterMark = 32 * 1024; + FlowLimiterHandler sut = new FlowLimiterHandler(); + + NioEventLoopGroup group = new NioEventLoopGroup(); + ServerBootstrap b = new ServerBootstrap(); + + readMessagesCollector = new ReadMessagesCollector(group); + AssertionsHandler assertionsHandler = new AssertionsHandler(group); + try { + b.group(group) + .channel(NioServerSocketChannel.class) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel ch) throws Exception { + ch.config().setWriteBufferHighWaterMark(highWaterMark); + ch.pipeline() + .addLast(onClientConnected(ctx -> { + // write as much to move the channel in not writable state + fillOutboundWatermark(ctx, highWaterMark); + // ask the client to send some data present on the channel + clientChannel.writeAndFlush(prepareSample(32)); + })) + .addLast(sut) + .addLast(readMessagesCollector) + .addLast(assertionsHandler); + } + }); + ChannelFuture future = b.bind("0.0.0.0", 1234).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + startAClient(group); + } + } + }).sync(); + future.channel().closeFuture().sync(); + } finally { + group.shutdownGracefully().sync(); + } + + assertionsHandler.assertNoErrors(); + } + + private static void fillOutboundWatermark(ChannelHandlerContext ctx, int highWaterMark) { + final ByteBuf payload = prepareSample(highWaterMark, 'C'); + while (ctx.channel().isWritable()) { + ctx.pipeline().writeAndFlush(payload.copy()); + } + } + + Channel clientChannel; + + private void startAClient(NioEventLoopGroup group) { + Bootstrap b = new Bootstrap(); + b.group(group) + .channel(NioSocketChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel ch) throws Exception { + ch.config().setAutoRead(false); + clientChannel = ch; + readMessagesCollector.updateClient(clientChannel); + } + }); + b.connect("localhost", 1234); + } + +} \ No newline at end of file diff --git a/src/test/java/org/logstash/beats/ThunderingGuardHandlerTest.java b/src/test/java/org/logstash/beats/ThunderingGuardHandlerTest.java new file mode 100644 index 00000000..d93f6582 --- /dev/null +++ b/src/test/java/org/logstash/beats/ThunderingGuardHandlerTest.java @@ -0,0 +1,83 @@ +package org.logstash.beats; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.ReferenceCounted; +import io.netty.util.internal.PlatformDependent; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.*; + +public class ThunderingGuardHandlerTest { + + public static final int MB = 1024 * 1024; + public static final long MAX_DIRECT_MEMORY_BYTES = PlatformDependent.maxDirectMemory(); + + @Test + public void testVerifyDirectMemoryCouldGoBeyondThe90Percent() { + // allocate 90% of direct memory + List allocatedBuffers = allocateDirectMemory(MAX_DIRECT_MEMORY_BYTES, 0.9); + + // allocate one more + ByteBuf payload = PooledByteBufAllocator.DEFAULT.directBuffer(1 * MB); + long usedDirectMemory = PooledByteBufAllocator.DEFAULT.metric().usedDirectMemory(); + long pinnedDirectMemory = PooledByteBufAllocator.DEFAULT.pinnedDirectMemory(); + + // verify + assertTrue("Direct memory allocation should be > 90% of the max available", usedDirectMemory > 0.9 * MAX_DIRECT_MEMORY_BYTES); + assertTrue("Direct memory usage should be > 80% of the max available", pinnedDirectMemory > 0.8 * MAX_DIRECT_MEMORY_BYTES); + + allocatedBuffers.forEach(ReferenceCounted::release); + payload.release(); + } + + private static List allocateDirectMemory(long maxDirectMemoryBytes, double percentage) { + List allocatedBuffers = new ArrayList<>(); + final long numBuffersToAllocate = (long) (maxDirectMemoryBytes / MB * percentage); + for (int i = 0; i < numBuffersToAllocate; i++) { + allocatedBuffers.add(PooledByteBufAllocator.DEFAULT.directBuffer(1 * MB)); + } + return allocatedBuffers; + } + + @Test + public void givenUsedDirectMemoryAndPinnedMemoryAreCloseToTheMaxDirectAvailableWhenNewConnectionIsCreatedThenItIsReject() { + EmbeddedChannel channel = new EmbeddedChannel(new ThunderingGuardHandler()); + + // consume > 90% of the direct memory + List allocatedBuffers = allocateDirectMemory(MAX_DIRECT_MEMORY_BYTES, 0.9); + // allocate one more + ByteBuf payload = PooledByteBufAllocator.DEFAULT.directBuffer(1 * MB); + + channel.pipeline().fireChannelRegistered(); + + // verify + assertFalse("Under constrained memory new channels has to be forcibly closed", channel.isOpen()); + + allocatedBuffers.forEach(ReferenceCounted::release); + payload.release(); + } + + @Test + public void givenUsedDirectMemoryAndNotPinnedWhenNewConnectionIsCreatedThenItIsAccepted() { + EmbeddedChannel channel = new EmbeddedChannel(new ThunderingGuardHandler()); + + // consume > 90% of the direct memory + List allocatedBuffers = allocateDirectMemory(MAX_DIRECT_MEMORY_BYTES, 0.9); + allocatedBuffers.forEach(ReferenceCounted::release); + // allocate one more + ByteBuf payload = PooledByteBufAllocator.DEFAULT.directBuffer(1 * MB); + payload.release(); + + channel.pipeline().fireChannelRegistered(); + + // verify + assertTrue("Despite memory is allocated but not pinned, new connections MUST be accepted", channel.isOpen()); + + } + +} \ No newline at end of file