diff --git a/docs/changelog/117787.yaml b/docs/changelog/117787.yaml new file mode 100644 index 0000000000000..342947d7cb08d --- /dev/null +++ b/docs/changelog/117787.yaml @@ -0,0 +1,5 @@ +pr: 117787 +summary: "Move HTTP content aggregation from Netty to RestController" +area: Network +type: enhancement +issues: [] diff --git a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4IncrementalRequestHandlingIT.java b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4IncrementalRequestHandlingIT.java index 4bb27af4bd0f5..59b8a08e90c9c 100644 --- a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4IncrementalRequestHandlingIT.java +++ b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4IncrementalRequestHandlingIT.java @@ -117,11 +117,10 @@ public void testEmptyContent() throws Exception { assertTrue(recvChunk.isLast); assertEquals(0, recvChunk.chunk.length()); recvChunk.chunk.close(); - assertFalse(handler.streamClosed); + assertBusy(() -> assertTrue(handler.streamClosed)); // send response to process following request handler.sendResponse(new RestResponse(RestStatus.OK, "")); - assertBusy(() -> assertTrue(handler.streamClosed)); } assertBusy(() -> assertEquals("should receive all server responses", totalRequests, ctx.clientRespQueue.size())); } @@ -154,10 +153,9 @@ public void testReceiveAllChunks() throws Exception { } } - assertFalse(handler.streamClosed); + assertBusy(() -> assertTrue(handler.streamClosed)); assertEquals("sent and received payloads are not the same", sendData, recvData); handler.sendResponse(new RestResponse(RestStatus.OK, "")); - assertBusy(() -> assertTrue(handler.streamClosed)); } assertBusy(() -> assertEquals("should receive all server responses", totalRequests, ctx.clientRespQueue.size())); } @@ -327,38 +325,35 @@ public void test413TooLargeOnExpect100Continue() throws Exception { assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, resp.status()); resp.release(); - // terminate request + // HttpRequestEncoder should properly close request, not required on server side ctx.clientChannel.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT); } } } - // ensures that oversized chunked encoded request has no limits at http layer - // rest handler is responsible for oversized requests - public void testOversizedChunkedEncodingNoLimits() throws Exception { + // ensures that oversized chunked encoded request has limits at http layer + // and closes connection after reaching limit + public void testOversizedChunkedEncodingLimits() throws Exception { try (var ctx = setupClientCtx()) { - for (var reqNo = 0; reqNo < randomIntBetween(2, 10); reqNo++) { - var id = opaqueId(reqNo); - var contentSize = maxContentLength() + 1; - var content = randomByteArrayOfLength(contentSize); - var is = new ByteBufInputStream(Unpooled.wrappedBuffer(content)); - var chunkedIs = new ChunkedStream(is); - var httpChunkedIs = new HttpChunkedInput(chunkedIs, LastHttpContent.EMPTY_LAST_CONTENT); - var req = httpRequest(id, 0); - HttpUtil.setTransferEncodingChunked(req, true); - - ctx.clientChannel.pipeline().addLast(new ChunkedWriteHandler()); - ctx.clientChannel.writeAndFlush(req); - ctx.clientChannel.writeAndFlush(httpChunkedIs); - var handler = ctx.awaitRestChannelAccepted(id); - var consumed = handler.readAllBytes(); - assertEquals(contentSize, consumed); - handler.sendResponse(new RestResponse(RestStatus.OK, "")); - - var resp = (FullHttpResponse) safePoll(ctx.clientRespQueue); - assertEquals(HttpResponseStatus.OK, resp.status()); - resp.release(); - } + var id = opaqueId(0); + var contentSize = maxContentLength() + 1; + var content = randomByteArrayOfLength(contentSize); + var is = new ByteBufInputStream(Unpooled.wrappedBuffer(content)); + var chunkedIs = new ChunkedStream(is); + var httpChunkedIs = new HttpChunkedInput(chunkedIs, LastHttpContent.EMPTY_LAST_CONTENT); + var req = httpRequest(id, 0); + HttpUtil.setTransferEncodingChunked(req, true); + + ctx.clientChannel.pipeline().addLast(new ChunkedWriteHandler()); + ctx.clientChannel.writeAndFlush(req); + ctx.clientChannel.writeAndFlush(httpChunkedIs); + var handler = ctx.awaitRestChannelAccepted(id); + handler.readAllBytes(); + + var resp = (FullHttpResponse) safePoll(ctx.clientRespQueue); + assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, resp.status()); + safeGet(ctx.clientChannel.closeFuture()); + resp.release(); } } @@ -594,7 +589,7 @@ record Ctx(String testName, String nodeName, Bootstrap clientBootstrap, Channel @Override public void close() throws Exception { safeGet(clientChannel.close()); - safeGet(clientBootstrap.config().group().shutdownGracefully()); + safeGet(clientBootstrap.config().group().shutdownGracefully(0, 0, TimeUnit.SECONDS)); clientRespQueue.forEach(o -> { if (o instanceof FullHttpResponse resp) resp.release(); }); for (var opaqueId : ControlServerRequestPlugin.handlers.keySet()) { if (opaqueId.startsWith(testName)) { @@ -655,24 +650,27 @@ void sendResponse(RestResponse response) { channel.sendResponse(response); } - int readBytes(int bytes) { + int readBytes(int bytes) throws InterruptedException { var consumed = 0; if (recvLast == false) { - while (consumed < bytes) { - stream.next(); - var recvChunk = safePoll(recvChunks); - consumed += recvChunk.chunk.length(); - recvChunk.chunk.close(); - if (recvChunk.isLast) { - recvLast = true; - break; + stream.next(); + while (consumed < bytes && streamClosed == false) { + var recvChunk = recvChunks.poll(10, TimeUnit.MILLISECONDS); + if (recvChunk != null) { + consumed += recvChunk.chunk.length(); + recvChunk.chunk.close(); + if (recvChunk.isLast) { + recvLast = true; + break; + } + stream.next(); } } } return consumed; } - int readAllBytes() { + int readAllBytes() throws InterruptedException { return readBytes(Integer.MAX_VALUE); } @@ -704,6 +702,11 @@ public String getName() { return ROUTE; } + @Override + public boolean supportContentStream() { + return true; + } + @Override public List routes() { return List.of(new Route(RestRequest.Method.POST, ROUTE)); diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/AutoReadSync.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/AutoReadSync.java new file mode 100644 index 0000000000000..e0e54f60f26f2 --- /dev/null +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/AutoReadSync.java @@ -0,0 +1,102 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.http.netty4; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelConfig; +import io.netty.util.AttributeKey; + +import java.util.BitSet; + +/** + * AutoReadSync provides coordinated access to the {@link ChannelConfig#setAutoRead(boolean)}. + * We use autoRead flag for the data flow control in the channel pipeline to prevent excessive + * buffering inside channel handlers. Every actor in the pipeline should obtain its own {@link Handle} + * by calling {@link AutoReadSync#getHandle} channel. Channel autoRead is enabled as long as all Handles + * are enabled. If one of handles disables autoRead, channel autoRead disables too. + * Simply, {@code channel.setAutoRead(allHandlesTrue)}. + *

+ * TODO: this flow control should be removed when {@link Netty4HttpHeaderValidator} moves to RestController. + * And whole control flow can be simplified to {@link io.netty.handler.flow.FlowControlHandler}. + */ +class AutoReadSync { + + private static final AttributeKey AUTO_READ_SYNC_KEY = AttributeKey.valueOf("AutoReadSync"); + private final Channel channel; + private final ChannelConfig config; + + // A pool of reusable handles and their states. Handle id is a sequence number in the set. + // Handles bitset is a pool of ids. Toggles bitset is a set of autoRead states. + // Default value for toggle is 0, which means autoRead is enabled. + private final BitSet handles; + private final BitSet toggles; + + AutoReadSync(Channel channel) { + this.channel = channel; + this.config = channel.config(); + this.handles = new BitSet(); + this.toggles = new BitSet(); + } + + static Handle getHandle(Channel channel) { + assert channel.eventLoop().inEventLoop(); + var autoRead = channel.attr(AUTO_READ_SYNC_KEY).get(); + if (autoRead == null) { + autoRead = new AutoReadSync(channel); + channel.attr(AUTO_READ_SYNC_KEY).set(autoRead); + } + return autoRead.getHandle(); + } + + Handle getHandle() { + var handleId = handles.nextClearBit(0); // next unused handle id + handles.set(handleId, true); // acquire lease + return new Handle(handleId); + } + + class Handle { + private final int id; + private boolean released; + + Handle(int id) { + this.id = id; + } + + private void assertState() { + assert channel.eventLoop().inEventLoop(); + assert released == false; + } + + boolean isEnabled() { + assertState(); + return toggles.get(id) == false; + } + + void enable() { + assertState(); + toggles.set(id, false); + config.setAutoRead(toggles.isEmpty()); + } + + void disable() { + assertState(); + toggles.set(id, true); + config.setAutoRead(false); + } + + void release() { + assertState(); + enable(); + handles.set(id, false); + released = true; + } + } + +} diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpAggregator.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpAggregator.java deleted file mode 100644 index 021ce09e0ed8e..0000000000000 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpAggregator.java +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". - */ - -package org.elasticsearch.http.netty4; - -import io.netty.channel.ChannelHandlerContext; -import io.netty.handler.codec.http.FullHttpRequest; -import io.netty.handler.codec.http.FullHttpResponse; -import io.netty.handler.codec.http.HttpContent; -import io.netty.handler.codec.http.HttpObject; -import io.netty.handler.codec.http.HttpObjectAggregator; -import io.netty.handler.codec.http.HttpRequest; -import io.netty.handler.codec.http.HttpResponseStatus; -import io.netty.handler.codec.http.HttpUtil; - -import org.elasticsearch.http.HttpPreRequest; -import org.elasticsearch.http.netty4.internal.HttpHeadersAuthenticatorUtils; - -import java.util.function.Predicate; - -/** - * A wrapper around {@link HttpObjectAggregator}. Provides optional content aggregation based on - * predicate. {@link HttpObjectAggregator} also handles Expect: 100-continue and oversized content. - * Unfortunately, Netty does not provide handlers for oversized messages beyond HttpObjectAggregator. - */ -public class Netty4HttpAggregator extends HttpObjectAggregator { - private static final Predicate IGNORE_TEST = (req) -> req.uri().startsWith("/_test/request-stream") == false; - - private final Predicate decider; - private boolean aggregating = true; - private boolean ignoreContentAfterContinueResponse = false; - - public Netty4HttpAggregator(int maxContentLength, Predicate decider) { - super(maxContentLength); - this.decider = decider; - } - - @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { - assert msg instanceof HttpObject; - if (msg instanceof HttpRequest request) { - var preReq = HttpHeadersAuthenticatorUtils.asHttpPreRequest(request); - aggregating = (decider.test(preReq) && IGNORE_TEST.test(preReq)) || request.decoderResult().isFailure(); - } - if (aggregating || msg instanceof FullHttpRequest) { - super.channelRead(ctx, msg); - } else { - handle(ctx, (HttpObject) msg); - } - } - - private void handle(ChannelHandlerContext ctx, HttpObject msg) { - if (msg instanceof HttpRequest request) { - var continueResponse = newContinueResponse(request, maxContentLength(), ctx.pipeline()); - if (continueResponse != null) { - // there are 3 responses expected: 100, 413, 417 - // on 100 we pass request further and reply to client to continue - // on 413/417 we ignore following content - ctx.writeAndFlush(continueResponse); - var resp = (FullHttpResponse) continueResponse; - if (resp.status() != HttpResponseStatus.CONTINUE) { - ignoreContentAfterContinueResponse = true; - return; - } - HttpUtil.set100ContinueExpected(request, false); - } - ignoreContentAfterContinueResponse = false; - ctx.fireChannelRead(msg); - } else { - var httpContent = (HttpContent) msg; - if (ignoreContentAfterContinueResponse) { - httpContent.release(); - } else { - ctx.fireChannelRead(msg); - } - } - } -} diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpContentSizeHandler.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpContentSizeHandler.java new file mode 100644 index 0000000000000..eedec15ccc87a --- /dev/null +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpContentSizeHandler.java @@ -0,0 +1,176 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.http.netty4; + +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpContent; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.HttpObject; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpRequestDecoder; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpUtil; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.handler.codec.http.LastHttpContent; + +import org.elasticsearch.core.SuppressForbidden; + +import static io.netty.handler.codec.http.HttpHeaderNames.CONNECTION; +import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH; + +/** + * Provides handling for Expect header and content size. Implements HTTP1.1 spec. + * Allows {@code Expect: 100-continue} header only. Other Expect headers will be rejected with + * {@code 417 Expectation Failed} reason. + *
+ * Replies {@code 100 Continue} to requests with allowed maxContentLength. + *
+ * Replies {@code 413 Request Entity Too Large} when content size exceeds maxContentLength. + * Clients sending oversized requests with Expect: 100-continue included are allowed to reuse same + * connection as long as they dont send content after rejection. Otherwise, when client started to + * send oversized content, we cannot safely accept it. Connection will be closed. + *

+ * TODO: move to RestController to allow content limits per RestHandler. + * Ideally we should be able to handle Continue and oversized request in the RestController. + * But that introduces a few challenges, basically re-implementation of HTTP protocol at the RestController: + *
    + *
  • + * 100 Continue is interim response, means RestChannel will send 2 responses for a single request. See + * rfc9110.html#status.100 + *
  • + *
  • + * RestChannel should be able to close underlying HTTP channel connection. + *
  • + *
+ */ +@SuppressForbidden(reason = "use of default ChannelFutureListener's CLOSE and CLOSE_ON_FAILURE") +public class Netty4HttpContentSizeHandler extends ChannelInboundHandlerAdapter { + + // copied from HttpObjectAggregator + private static final FullHttpResponse CONTINUE = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, + HttpResponseStatus.CONTINUE, + Unpooled.EMPTY_BUFFER + ); + private static final FullHttpResponse EXPECTATION_FAILED = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, + HttpResponseStatus.EXPECTATION_FAILED, + Unpooled.EMPTY_BUFFER + ); + private static final FullHttpResponse TOO_LARGE_CLOSE = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, + HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, + Unpooled.EMPTY_BUFFER + ); + private static final FullHttpResponse TOO_LARGE = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, + HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, + Unpooled.EMPTY_BUFFER + ); + + static { + EXPECTATION_FAILED.headers().set(CONTENT_LENGTH, 0); + TOO_LARGE.headers().set(CONTENT_LENGTH, 0); + + TOO_LARGE_CLOSE.headers().set(CONTENT_LENGTH, 0); + TOO_LARGE_CLOSE.headers().set(CONNECTION, HttpHeaderValues.CLOSE); + } + + private final HttpRequestDecoder decoder; + private final int maxContentLength; + private boolean ignoreFollowingContent = false; + private boolean contentNotAllowed = false; + private int contentLength = 0; + + public Netty4HttpContentSizeHandler(HttpRequestDecoder decoder, int maxContentLength) { + this.decoder = decoder; + this.maxContentLength = maxContentLength; + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + assert msg instanceof HttpObject; + if (msg instanceof HttpRequest request) { + handleRequest(ctx, request); + } else { + handleContent(ctx, (HttpContent) msg); + } + } + + private void replyAndForbidFollowingContent(ChannelHandlerContext ctx, FullHttpResponse errResponse) { + decoder.reset(); // reset decoder to skip following content + ctx.writeAndFlush(errResponse.retainedDuplicate()).addListener(ChannelFutureListener.CLOSE_ON_FAILURE); + contentNotAllowed = true; // some content might be already in the pipeline, we need to catch it + ignoreFollowingContent = true; + } + + private void handleRequest(ChannelHandlerContext ctx, HttpRequest request) { + final var expectValue = request.headers().get(HttpHeaderNames.EXPECT); + + boolean isContinueExpected = false; + if (expectValue != null) { + if (HttpHeaderValues.CONTINUE.toString().equalsIgnoreCase(expectValue)) { + isContinueExpected = true; + } else { + replyAndForbidFollowingContent(ctx, EXPECTATION_FAILED); + return; + } + } + + boolean isOversized = HttpUtil.getContentLength(request, -1) > maxContentLength; + if (isOversized) { + if (isContinueExpected) { + // Client is allowed to send content without waiting for Continue. + // See https://www.rfc-editor.org/rfc/rfc9110.html#section-10.1.1-11.3 + // + // Mark following content as forbidden to prevent unbounded content after Expect failed. + replyAndForbidFollowingContent(ctx, TOO_LARGE); + } else { + // Client is sending oversized content, we cannot safely take it. Closing channel. + ctx.writeAndFlush(TOO_LARGE_CLOSE.retainedDuplicate()).addListener(ChannelFutureListener.CLOSE); + ignoreFollowingContent = true; + } + } else { + if (isContinueExpected) { + ctx.writeAndFlush(CONTINUE.retainedDuplicate()); + HttpUtil.set100ContinueExpected(request, false); + } + ignoreFollowingContent = false; + contentNotAllowed = false; + contentLength = 0; + ctx.fireChannelRead(request); + } + } + + private void handleContent(ChannelHandlerContext ctx, HttpContent httpContent) { + if (contentNotAllowed && httpContent != LastHttpContent.EMPTY_LAST_CONTENT) { + httpContent.release(); + ctx.close(); + } else if (ignoreFollowingContent) { + httpContent.release(); + } else { + contentLength += httpContent.content().readableBytes(); + if (contentLength > maxContentLength) { + ignoreFollowingContent = true; + httpContent.release(); + // Client is sending oversized content, we cannot safely take it. Closing channel. + ctx.writeAndFlush(TOO_LARGE_CLOSE.retainedDuplicate()).addListener(ChannelFutureListener.CLOSE); + return; + } + ctx.fireChannelRead(httpContent); + } + } +} diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpHeaderValidator.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpHeaderValidator.java index 95a68cb52bbdb..33eee9278a731 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpHeaderValidator.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpHeaderValidator.java @@ -37,6 +37,7 @@ public class Netty4HttpHeaderValidator extends ChannelInboundHandlerAdapter { private final HttpValidator validator; private final ThreadContext threadContext; + private AutoReadSync.Handle autoRead; private ArrayDeque pending = new ArrayDeque<>(4); private State state = WAITING_TO_START; @@ -45,6 +46,12 @@ public Netty4HttpHeaderValidator(HttpValidator validator, ThreadContext threadCo this.threadContext = threadContext; } + @Override + public void channelRegistered(ChannelHandlerContext ctx) throws Exception { + autoRead = AutoReadSync.getHandle(ctx.channel()); + super.channelRegistered(ctx); + } + State getState() { return state; } @@ -61,7 +68,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception pending.add(ReferenceCountUtil.retain(httpObject)); requestStart(ctx); assert state == QUEUEING_DATA; - assert ctx.channel().config().isAutoRead() == false; + assert autoRead.isEnabled() == false; break; case QUEUEING_DATA: pending.add(ReferenceCountUtil.retain(httpObject)); @@ -83,7 +90,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception case DROPPING_DATA_PERMANENTLY: assert pending.isEmpty(); ReferenceCountUtil.release(httpObject); // consume without enqueuing - ctx.channel().config().setAutoRead(false); + autoRead.disable(); break; } } @@ -106,7 +113,7 @@ private void requestStart(ChannelHandlerContext ctx) { } state = QUEUEING_DATA; - ctx.channel().config().setAutoRead(false); + autoRead.disable(); if (httpRequest == null) { // this looks like a malformed request and will forward without validation @@ -149,10 +156,10 @@ public void onFailure(Exception e) { private void forwardFullRequest(ChannelHandlerContext ctx) { Transports.assertDefaultThreadContext(threadContext); assert ctx.channel().eventLoop().inEventLoop(); - assert ctx.channel().config().isAutoRead() == false; + assert autoRead.isEnabled() == false; assert state == QUEUEING_DATA; - ctx.channel().config().setAutoRead(true); + autoRead.enable(); boolean fullRequestForwarded = forwardData(ctx, pending); assert fullRequestForwarded || pending.isEmpty(); @@ -169,7 +176,7 @@ private void forwardFullRequest(ChannelHandlerContext ctx) { private void forwardRequestWithDecoderExceptionAndNoContent(ChannelHandlerContext ctx, Exception e) { Transports.assertDefaultThreadContext(threadContext); assert ctx.channel().eventLoop().inEventLoop(); - assert ctx.channel().config().isAutoRead() == false; + assert autoRead.isEnabled() == false; assert state == QUEUEING_DATA; HttpObject messageToForward = pending.getFirst(); @@ -180,7 +187,7 @@ private void forwardRequestWithDecoderExceptionAndNoContent(ChannelHandlerContex } messageToForward.setDecoderResult(DecoderResult.failure(e)); - ctx.channel().config().setAutoRead(true); + autoRead.enable(); ctx.fireChannelRead(messageToForward); assert fullRequestDropped || pending.isEmpty(); diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java index 1a391a05add58..d09ff07d2d6f8 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java @@ -116,7 +116,7 @@ public Netty4HttpPipeliningHandler( } @Override - public void channelRead(final ChannelHandlerContext ctx, final Object msg) { + public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception { activityTracker.startActivity(); try { if (msg instanceof HttpRequest request) { @@ -130,7 +130,7 @@ public void channelRead(final ChannelHandlerContext ctx, final Object msg) { } else { nonError = (Exception) cause; } - netty4HttpRequest = new Netty4HttpRequest(readSequence++, (FullHttpRequest) request, nonError); + netty4HttpRequest = new Netty4HttpRequest(readSequence++, request, nonError); } else { assert currentRequestStream == null : "current stream must be null for new request"; if (request instanceof FullHttpRequest fullHttpRequest) { @@ -139,7 +139,8 @@ public void channelRead(final ChannelHandlerContext ctx, final Object msg) { } else { var contentStream = new Netty4HttpRequestBodyStream( ctx.channel(), - serverTransport.getThreadPool().getThreadContext() + serverTransport.getThreadPool().getThreadContext(), + activityTracker ); currentRequestStream = contentStream; netty4HttpRequest = new Netty4HttpRequest(readSequence++, request, contentStream); diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequest.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequest.java index 2662ddf7e1440..8584aa99899ae 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequest.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequest.java @@ -11,11 +11,14 @@ import io.netty.buffer.Unpooled; import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.DefaultHttpRequest; import io.netty.handler.codec.http.EmptyHttpHeaders; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpUtil; import io.netty.handler.codec.http.QueryStringDecoder; import io.netty.handler.codec.http.cookie.Cookie; import io.netty.handler.codec.http.cookie.ServerCookieDecoder; @@ -23,7 +26,6 @@ import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.http.HttpBody; -import org.elasticsearch.http.HttpRequest; import org.elasticsearch.http.HttpResponse; import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestRequest; @@ -39,9 +41,9 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; -public class Netty4HttpRequest implements HttpRequest { +public class Netty4HttpRequest implements org.elasticsearch.http.HttpRequest { - private final FullHttpRequest request; + private final HttpRequest request; private final HttpBody content; private final Map> headers; private final AtomicBoolean released; @@ -49,8 +51,9 @@ public class Netty4HttpRequest implements HttpRequest { private final boolean pooled; private final int sequence; private final QueryStringDecoder queryStringDecoder; + private final int contentLength; - Netty4HttpRequest(int sequence, io.netty.handler.codec.http.HttpRequest request, Netty4HttpRequestBodyStream contentStream) { + Netty4HttpRequest(int sequence, HttpRequest request, Netty4HttpRequestBodyStream contentStream) { this( sequence, new DefaultFullHttpRequest( @@ -72,17 +75,18 @@ public class Netty4HttpRequest implements HttpRequest { this(sequence, request, new AtomicBoolean(false), true, Netty4Utils.fullHttpBodyFrom(request.content())); } - Netty4HttpRequest(int sequence, FullHttpRequest request, Exception inboundException) { - this(sequence, request, new AtomicBoolean(false), true, Netty4Utils.fullHttpBodyFrom(request.content()), inboundException); + Netty4HttpRequest(int sequence, HttpRequest request, Exception inboundException) { + this(sequence, request, new AtomicBoolean(false), true, HttpBody.empty(), inboundException); + } - private Netty4HttpRequest(int sequence, FullHttpRequest request, AtomicBoolean released, boolean pooled, HttpBody content) { + private Netty4HttpRequest(int sequence, HttpRequest request, AtomicBoolean released, boolean pooled, HttpBody content) { this(sequence, request, released, pooled, content, null); } private Netty4HttpRequest( int sequence, - FullHttpRequest request, + HttpRequest request, AtomicBoolean released, boolean pooled, HttpBody content, @@ -96,6 +100,15 @@ private Netty4HttpRequest( this.released = released; this.inboundException = inboundException; this.queryStringDecoder = new QueryStringDecoder(request.uri()); + this.contentLength = getContentLength(request); + } + + static int getContentLength(io.netty.handler.codec.http.HttpRequest request) { + if (HttpUtil.isTransferEncodingChunked(request)) { + return -1; + } else { + return HttpUtil.getContentLength(request, 0); + } } @Override @@ -113,6 +126,11 @@ public String rawPath() { return queryStringDecoder.rawPath(); } + @Override + public int contentLength() { + return contentLength; + } + @Override public HttpBody body() { assert released.get() == false; @@ -122,7 +140,6 @@ public HttpBody body() { @Override public void release() { if (pooled && released.compareAndSet(false, true)) { - request.release(); content.close(); } } @@ -147,27 +164,23 @@ public List strictCookies() { @Override public HttpVersion protocolVersion() { if (request.protocolVersion().equals(io.netty.handler.codec.http.HttpVersion.HTTP_1_0)) { - return HttpRequest.HttpVersion.HTTP_1_0; + return org.elasticsearch.http.HttpRequest.HttpVersion.HTTP_1_0; } else if (request.protocolVersion().equals(io.netty.handler.codec.http.HttpVersion.HTTP_1_1)) { - return HttpRequest.HttpVersion.HTTP_1_1; + return org.elasticsearch.http.HttpRequest.HttpVersion.HTTP_1_1; } else { throw new IllegalArgumentException("Unexpected http protocol version: " + request.protocolVersion()); } } @Override - public HttpRequest removeHeader(String header) { + public org.elasticsearch.http.HttpRequest removeHeader(String header) { HttpHeaders copiedHeadersWithout = request.headers().copy(); copiedHeadersWithout.remove(header); - HttpHeaders copiedTrailingHeadersWithout = request.trailingHeaders().copy(); - copiedTrailingHeadersWithout.remove(header); - FullHttpRequest requestWithoutHeader = new DefaultFullHttpRequest( + HttpRequest requestWithoutHeader = new DefaultHttpRequest( request.protocolVersion(), request.method(), request.uri(), - request.content(), - copiedHeadersWithout, - copiedTrailingHeadersWithout + copiedHeadersWithout ); return new Netty4HttpRequest(sequence, requestWithoutHeader, released, pooled, content); } diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequestBodyStream.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequestBodyStream.java index ac3e3aecf97b9..964e55390bf09 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequestBodyStream.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequestBodyStream.java @@ -16,6 +16,7 @@ import io.netty.handler.codec.http.HttpContent; import io.netty.handler.codec.http.LastHttpContent; +import org.elasticsearch.common.network.ThreadWatchdog; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.Releasables; import org.elasticsearch.http.HttpBody; @@ -33,25 +34,29 @@ public class Netty4HttpRequestBodyStream implements HttpBody.Stream { private final Channel channel; - private final ChannelFutureListener closeListener = future -> doClose(); private final List tracingHandlers = new ArrayList<>(4); private final ThreadContext threadContext; + private final AutoReadSync.Handle autoRead; + private final ThreadWatchdog.ActivityTracker activityTracker; private ByteBuf buf; private boolean requested = false; - private boolean closing = false; + private final ChannelFutureListener closeListener = future -> doClose(); private HttpBody.ChunkHandler handler; private ThreadContext.StoredContext requestContext; + private boolean hasLast = false; + private boolean closed = false; // used in tests private volatile int bufSize = 0; - private volatile boolean hasLast = false; - public Netty4HttpRequestBodyStream(Channel channel, ThreadContext threadContext) { + public Netty4HttpRequestBodyStream(Channel channel, ThreadContext threadContext, ThreadWatchdog.ActivityTracker activityTracker) { this.channel = channel; this.threadContext = threadContext; + this.activityTracker = activityTracker; this.requestContext = threadContext.newStoredContext(); + this.autoRead = AutoReadSync.getHandle(channel); Netty4Utils.addListener(channel.closeFuture(), closeListener); - channel.config().setAutoRead(false); + autoRead.disable(); } @Override @@ -72,7 +77,7 @@ public void addTracingHandler(ChunkHandler chunkHandler) { @Override public void next() { - assert closing == false : "cannot request next chunk on closing stream"; + assert closed == false : "cannot request next chunk on closing stream"; assert handler != null : "handler must be set before requesting next chunk"; requestContext = threadContext.newStoredContext(); channel.eventLoop().submit(() -> { @@ -81,18 +86,23 @@ public void next() { channel.read(); } else { try { + activityTracker.startActivity(); send(); - } catch (Exception e) { - channel.pipeline().fireExceptionCaught(e); + } catch (Throwable t) { + // must catch everything + doClose(); + channel.pipeline().fireExceptionCaught(t); + } finally { + activityTracker.stopActivity(); } } }); } - public void handleNettyContent(HttpContent httpContent) { + public void handleNettyContent(HttpContent httpContent) throws Exception { assert hasLast == false : "receive http content on completed stream"; hasLast = httpContent instanceof LastHttpContent; - if (closing) { + if (closed) { httpContent.release(); } else { addChunk(httpContent.content()); @@ -128,7 +138,7 @@ boolean hasLast() { return hasLast; } - private void send() { + private void send() throws Exception { assert requested; assert handler != null : "must set handler before receiving next chunk"; var bytesRef = Netty4Utils.toReleasableBytesReference(buf); @@ -142,8 +152,7 @@ private void send() { handler.onNext(bytesRef, hasLast); } if (hasLast) { - channel.config().setAutoRead(true); - channel.closeFuture().removeListener(closeListener); + doClose(); } } @@ -152,25 +161,32 @@ public void close() { if (channel.eventLoop().inEventLoop()) { doClose(); } else { - channel.eventLoop().submit(this::doClose); + if (channel.eventLoop().isShutdown() == false) { + channel.eventLoop().submit(this::doClose); + } } } private void doClose() { - closing = true; - try (var ignored = threadContext.restoreExistingContext(requestContext)) { - for (var tracer : tracingHandlers) { - Releasables.closeExpectNoException(tracer); - } - if (handler != null) { - handler.close(); + if (closed == false) { + closed = true; + try (var ignored = threadContext.restoreExistingContext(requestContext)) { + for (var tracer : tracingHandlers) { + Releasables.closeExpectNoException(tracer); + } + if (handler != null) { + handler.close(); + } + } finally { + if (buf != null) { + buf.release(); + buf = null; + bufSize = 0; + } + autoRead.release(); + channel.closeFuture().removeListener(closeListener); } } - if (buf != null) { - buf.release(); - buf = null; - bufSize = 0; - } - channel.config().setAutoRead(true); } + } diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java index 36c860f1fb90b..29fa65d8983eb 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java @@ -24,7 +24,6 @@ import io.netty.handler.codec.http.HttpContentCompressor; import io.netty.handler.codec.http.HttpContentDecompressor; import io.netty.handler.codec.http.HttpMessage; -import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpRequestDecoder; import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponseEncoder; @@ -376,14 +375,7 @@ protected HttpMessage createMessage(String[] initialLine) throws Exception { ) ); } - // combines the HTTP message pieces into a single full HTTP request (with headers and body) - final HttpObjectAggregator aggregator = new Netty4HttpAggregator( - handlingSettings.maxContentLength(), - httpPreRequest -> enabled.get() == false - || ((httpPreRequest.rawPath().endsWith("/_bulk") == false) - || httpPreRequest.rawPath().startsWith("/_xpack/monitoring/_bulk")) - ); - aggregator.setMaxCumulationBufferComponents(transport.maxCompositeBufferComponents); + final var contentSizeHandler = new Netty4HttpContentSizeHandler(decoder, handlingSettings.maxContentLength()); ch.pipeline() .addLast("decoder_compress", new HttpContentDecompressor()) // this handles request body decompression .addLast("encoder", new HttpResponseEncoder() { @@ -398,7 +390,7 @@ protected boolean isContentAlwaysEmpty(HttpResponse msg) { return super.isContentAlwaysEmpty(msg); } }) - .addLast("aggregator", aggregator); + .addLast("content_size", contentSizeHandler); if (handlingSettings.compression()) { ch.pipeline().addLast("encoder_compress", new HttpContentCompressor(handlingSettings.compressionLevel()) { @Override diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/AutoReadSyncTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/AutoReadSyncTests.java new file mode 100644 index 0000000000000..0b90cc2c05efe --- /dev/null +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/AutoReadSyncTests.java @@ -0,0 +1,99 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.http.netty4; + +import io.netty.channel.Channel; +import io.netty.channel.embedded.EmbeddedChannel; + +import org.elasticsearch.test.ESTestCase; + +import java.util.HashSet; +import java.util.stream.IntStream; + +public class AutoReadSyncTests extends ESTestCase { + + Channel chan; + + @Override + public void setUp() throws Exception { + super.setUp(); + chan = new EmbeddedChannel(); + } + + AutoReadSync.Handle getHandle() { + return AutoReadSync.getHandle(chan); + } + + public void testToggleSetAutoRead() { + var autoRead = getHandle(); + assertTrue("must be enabled by default", autoRead.isEnabled()); + + autoRead.disable(); + assertFalse("must disable handle", autoRead.isEnabled()); + assertFalse("must turn off chan autoRead", chan.config().isAutoRead()); + + autoRead.enable(); + assertTrue("must enable handle", autoRead.isEnabled()); + assertTrue("must turn on chan autoRead", chan.config().isAutoRead()); + + autoRead.disable(); + autoRead.release(); + assertTrue("must turn on chan autoRead on release", chan.config().isAutoRead()); + } + + public void testAnyToggleDisableAutoRead() { + var handles = IntStream.range(0, 100).mapToObj(i -> getHandle()).toList(); + handles.forEach(AutoReadSync.Handle::enable); + handles.get(between(0, 100)).disable(); + assertFalse(chan.config().isAutoRead()); + } + + public void testNewHandleDoesNotChangeAutoRead() { + var handle1 = getHandle(); + + handle1.disable(); + assertFalse(chan.config().isAutoRead()); + getHandle(); + assertFalse("acquiring new handle should enable autoRead", chan.config().isAutoRead()); + + handle1.enable(); + assertTrue(chan.config().isAutoRead()); + getHandle(); + assertTrue("acquiring new handle should not disable autoRead", chan.config().isAutoRead()); + } + + public void testAllTogglesEnableAutoRead() { + // mix-in acquire/release + var handles = new HashSet(); + IntStream.range(0, 100).mapToObj(i -> getHandle()).forEach(h -> { + h.disable(); + handles.add(h); + }); + assertFalse(chan.config().isAutoRead()); + + var toRelease = between(1, 98); // release some but not all + var releasedHandles = handles.stream().limit(toRelease).toList(); + releasedHandles.forEach(h -> { + h.release(); + handles.remove(h); + }); + assertFalse("releasing some but not all handles should not enable autoRead", chan.config().isAutoRead()); + + var lastHandle = getHandle(); + lastHandle.disable(); + for (var handle : handles) { + handle.enable(); + assertFalse("should not enable autoRead until lastHandle is enabled", chan.config().isAutoRead()); + } + lastHandle.enable(); + assertTrue(chan.config().isAutoRead()); + } + +} diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpHeaderThreadContextTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpHeaderThreadContextTests.java index 9a12ba75d7742..998146c6faeb8 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpHeaderThreadContextTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpHeaderThreadContextTests.java @@ -73,6 +73,7 @@ public void testSuccessfulSyncValidationUntamperedThreadContext() throws Excepti ) ); channel.pipeline().addLast(defaultContextAssertingChannelHandler(threadPool.getThreadContext())); + channel.pipeline().fireChannelRegistered(); // send first request through sendRequestThrough(isValidationSuccessful.get(), null); // send second request through, to check in case the context got stained by the first one through @@ -93,6 +94,7 @@ public void testFailedSyncValidationUntamperedThreadContext() throws Exception { ) ); channel.pipeline().addLast(defaultContextAssertingChannelHandler(threadPool.getThreadContext())); + channel.pipeline().fireChannelRegistered(); // send first request through sendRequestThrough(isValidationSuccessful.get(), null); // send second request through, to check in case the context got stained by the first one through @@ -115,6 +117,7 @@ public void testSuccessfulAsyncValidationUntamperedThreadContext() throws Except ) ); channel.pipeline().addLast(defaultContextAssertingChannelHandler(threadPool.getThreadContext())); + channel.pipeline().fireChannelRegistered(); // send first request through sendRequestThrough(isValidationSuccessful.get(), validationDone); // send second request through, to check in case the context got stained by the first one through @@ -137,6 +140,7 @@ public void testUnsuccessfulAsyncValidationUntamperedThreadContext() throws Exce ) ); channel.pipeline().addLast(defaultContextAssertingChannelHandler(threadPool.getThreadContext())); + channel.pipeline().fireChannelRegistered(); // send first request through sendRequestThrough(isValidationSuccessful.get(), validationDone); // send second request through, to check in case the context got stained by the first one through diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpHeaderValidatorTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpHeaderValidatorTests.java index 1c0b434105f28..e64eba606398f 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpHeaderValidatorTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpHeaderValidatorTests.java @@ -62,7 +62,6 @@ public void setUp() throws Exception { } private void reset() { - channel = new EmbeddedChannel(); header.set(null); listener.set(null); validationException.set(null); @@ -75,7 +74,7 @@ private void reset() { listener.set(validationCompleteListener); }; netty4HttpHeaderValidator = new Netty4HttpHeaderValidator(validator, new ThreadContext(Settings.EMPTY)); - channel.pipeline().addLast(netty4HttpHeaderValidator); + channel = new EmbeddedChannel(true, false, netty4HttpHeaderValidator); } public void testValidationPausesAndResumesData() { diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpRequestBodyStreamTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpRequestBodyStreamTests.java index d456bbecfbd20..41cb241f08338 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpRequestBodyStreamTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpRequestBodyStreamTests.java @@ -19,13 +19,13 @@ import io.netty.handler.flow.FlowControlHandler; import org.elasticsearch.common.bytes.ReleasableBytesReference; +import org.elasticsearch.common.network.ThreadWatchdog; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.http.HttpBody; import org.elasticsearch.test.ESTestCase; import java.util.ArrayList; -import java.util.HashMap; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; @@ -45,11 +45,11 @@ public void setUp() throws Exception { super.setUp(); channel = new EmbeddedChannel(); threadContext.putHeader("header1", "value1"); - stream = new Netty4HttpRequestBodyStream(channel, threadContext); + stream = new Netty4HttpRequestBodyStream(channel, threadContext, new ThreadWatchdog.ActivityTracker()); stream.setHandler(discardHandler); // set default handler, each test might override one channel.pipeline().addLast(new SimpleChannelInboundHandler(false) { @Override - protected void channelRead0(ChannelHandlerContext ctx, HttpContent msg) { + protected void channelRead0(ChannelHandlerContext ctx, HttpContent msg) throws Exception { stream.handleNettyContent(msg); } }); @@ -169,16 +169,7 @@ public void close() { assertThat(headers.get(), hasEntry("header1", "value1")); assertThat(headers.get(), hasEntry("header2", "value2")); assertThat(headers.get(), hasEntry("header3", "value3")); - assertTrue("should receive last content", gotLast.get()); - - headers.set(new HashMap<>()); - - stream.close(); - - assertThat(headers.get(), hasEntry("header1", "value1")); - assertThat(headers.get(), hasEntry("header2", "value2")); - assertThat(headers.get(), hasEntry("header3", "value3")); } HttpContent randomContent(int size, boolean isLast) { diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java index 1d39b993cef92..2f59392699bdc 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java @@ -192,9 +192,9 @@ private void runExpectHeaderTest( final int contentLength, final HttpResponseStatus expectedStatus ) throws InterruptedException { - final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() { + final HttpServerTransport.Dispatcher dispatcher = new AggregatingDispatcher() { @Override - public void dispatchRequest(RestRequest request, RestChannel channel, ThreadContext threadContext) { + public void dispatchAggregatedRequest(RestRequest request, RestChannel channel, ThreadContext threadContext) { channel.sendResponse(new RestResponse(OK, RestResponse.TEXT_CONTENT_TYPE, new BytesArray("done"))); } @@ -1057,9 +1057,9 @@ private void runRespondAfterServiceCloseTest(boolean clientCancel) throws Except final SubscribableListener transportClosedFuture = new SubscribableListener<>(); final CountDownLatch handlingRequestLatch = new CountDownLatch(1); - final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() { + final HttpServerTransport.Dispatcher dispatcher = new AggregatingDispatcher() { @Override - public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) { + public void dispatchAggregatedRequest(RestRequest request, RestChannel channel, ThreadContext threadContext) { assertEquals(request.uri(), url); final var response = RestResponse.chunked( OK, diff --git a/server/src/main/java/org/elasticsearch/action/ActionModule.java b/server/src/main/java/org/elasticsearch/action/ActionModule.java index 98d6284fd91d2..e3dee4af22768 100644 --- a/server/src/main/java/org/elasticsearch/action/ActionModule.java +++ b/server/src/main/java/org/elasticsearch/action/ActionModule.java @@ -931,7 +931,7 @@ public void initRestHandlers(Supplier nodesInCluster, Predicate< registerHandler.accept(new RestCountAction()); registerHandler.accept(new RestTermVectorsAction()); registerHandler.accept(new RestMultiTermVectorsAction()); - registerHandler.accept(new RestBulkAction(settings, bulkService)); + registerHandler.accept(new RestBulkAction(settings, clusterSettings, bulkService)); registerHandler.accept(new RestUpdateAction()); registerHandler.accept(new RestSearchAction(restController.getSearchUsageHolder(), clusterSupportsFeature)); diff --git a/server/src/main/java/org/elasticsearch/http/HttpBody.java b/server/src/main/java/org/elasticsearch/http/HttpBody.java index 6571125677fab..eb6adb3851a02 100644 --- a/server/src/main/java/org/elasticsearch/http/HttpBody.java +++ b/server/src/main/java/org/elasticsearch/http/HttpBody.java @@ -23,6 +23,10 @@ static Full fromBytesReference(BytesReference bytesRef) { return new ByteRefHttpBody(ReleasableBytesReference.wrap(bytesRef)); } + static Full fromReleasableBytesReference(ReleasableBytesReference relBytes) { + return new ByteRefHttpBody(relBytes); + } + static Full empty() { return new ByteRefHttpBody(ReleasableBytesReference.empty()); } @@ -56,9 +60,6 @@ default Stream asStream() { */ non-sealed interface Full extends HttpBody { ReleasableBytesReference bytes(); - - @Override - default void close() {} } /** @@ -107,11 +108,16 @@ non-sealed interface Stream extends HttpBody { @FunctionalInterface interface ChunkHandler extends Releasable { - void onNext(ReleasableBytesReference chunk, boolean isLast); + void onNext(ReleasableBytesReference chunk, boolean isLast) throws Exception; @Override default void close() {} } - record ByteRefHttpBody(ReleasableBytesReference bytes) implements Full {} + record ByteRefHttpBody(ReleasableBytesReference bytes) implements Full { + @Override + public void close() { + bytes.close(); + } + } } diff --git a/server/src/main/java/org/elasticsearch/http/HttpRequest.java b/server/src/main/java/org/elasticsearch/http/HttpRequest.java index b4b1bb84433c9..2b36dff967c91 100644 --- a/server/src/main/java/org/elasticsearch/http/HttpRequest.java +++ b/server/src/main/java/org/elasticsearch/http/HttpRequest.java @@ -28,6 +28,15 @@ enum HttpVersion { HTTP_1_1 } + /** + * Returns HTTP request content length, empty content has 0 length, unknown -1. Fully aggregated content returns its actual size. + * Streamed request returns content-length header value. There are two cases when content-length header is not present. + * First, when transfer-encoding is chunked. Request must not specify content-length header. Method returns -1. Second, when + * request does not have a body, for example, GET request without body can omit header. Method returns 0. + *

See RFC 9112 # Content-Length. + */ + int contentLength(); + HttpBody body(); List strictCookies(); diff --git a/server/src/main/java/org/elasticsearch/rest/DelegatingRestChannel.java b/server/src/main/java/org/elasticsearch/rest/DelegatingRestChannel.java new file mode 100644 index 0000000000000..6b2fa6995e4ee --- /dev/null +++ b/server/src/main/java/org/elasticsearch/rest/DelegatingRestChannel.java @@ -0,0 +1,79 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.rest; + +import org.elasticsearch.common.io.stream.BytesStream; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentType; + +import java.io.IOException; +import java.io.OutputStream; + +class DelegatingRestChannel implements RestChannel { + + private final RestChannel delegate; + + DelegatingRestChannel(RestChannel delegate) { + this.delegate = delegate; + } + + @Override + public XContentBuilder newBuilder() throws IOException { + return delegate.newBuilder(); + } + + @Override + public XContentBuilder newErrorBuilder() throws IOException { + return delegate.newErrorBuilder(); + } + + @Override + public XContentBuilder newBuilder(@Nullable XContentType xContentType, boolean useFiltering) throws IOException { + return delegate.newBuilder(xContentType, useFiltering); + } + + @Override + public XContentBuilder newBuilder(XContentType xContentType, XContentType responseContentType, boolean useFiltering) + throws IOException { + return delegate.newBuilder(xContentType, responseContentType, useFiltering); + } + + @Override + public XContentBuilder newBuilder(XContentType xContentType, XContentType responseContentType, boolean useFiltering, OutputStream out) + throws IOException { + return delegate.newBuilder(xContentType, responseContentType, useFiltering, out); + } + + @Override + public BytesStream bytesOutput() { + return delegate.bytesOutput(); + } + + @Override + public void releaseOutputBuffer() { + delegate.releaseOutputBuffer(); + } + + @Override + public RestRequest request() { + return delegate.request(); + } + + @Override + public boolean detailedErrorsEnabled() { + return delegate.detailedErrorsEnabled(); + } + + @Override + public void sendResponse(RestResponse response) { + delegate.sendResponse(response); + } +} diff --git a/server/src/main/java/org/elasticsearch/rest/RestContentAggregator.java b/server/src/main/java/org/elasticsearch/rest/RestContentAggregator.java new file mode 100644 index 0000000000000..6d9a827f07bb2 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/rest/RestContentAggregator.java @@ -0,0 +1,232 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.rest; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.bytes.CompositeBytesReference; +import org.elasticsearch.common.bytes.ReleasableBytesReference; +import org.elasticsearch.core.AbstractRefCounted; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.http.HttpBody; +import org.elasticsearch.http.HttpRequest; +import org.elasticsearch.http.HttpResponse; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.elasticsearch.http.HttpBody.ChunkHandler; + +public class RestContentAggregator { + + /** + * Aggregates streamed HTTP content and completes listener with RestRequest with full content. + * Completes with exception on unexpected HTTP content (ie content-length is 0, but receive + * non-empty chunk). + */ + public static void aggregate(RestRequest request, RestChannel channel, AggregateConsumer result) { + ChunkHandler handler; + if (request.contentLength() == 0) { + handler = new NoContent(request, channel, result); + } else { + handler = new ChunkAggregator(request, channel, result); + } + var stream = request.contentStream(); + stream.setHandler(handler); + stream.next(); + } + + @FunctionalInterface + public interface AggregateConsumer { + void accept(RestRequest request, RestChannel channel) throws Exception; + } + + /** + * Wraps streamed {@link HttpRequest} with aggregated content. + * Does not replace the original content-length header. + * Since full content is already available, we can use its length instead. + */ + private static class AggregatedHttpRequest implements HttpRequest { + final HttpRequest streamedRequest; + final HttpBody.Full aggregatedContent; + final AtomicBoolean released = new AtomicBoolean(false); + + private AggregatedHttpRequest(HttpRequest streamedRequest, HttpBody.Full aggregatedContent) { + this.streamedRequest = streamedRequest; + this.aggregatedContent = aggregatedContent; + } + + @Override + public int contentLength() { + return aggregatedContent.bytes().length(); + } + + @Override + public HttpBody body() { + return aggregatedContent; + } + + @Override + public List strictCookies() { + return streamedRequest.strictCookies(); + } + + @Override + public HttpVersion protocolVersion() { + return streamedRequest.protocolVersion(); + } + + @Override + public HttpRequest removeHeader(String header) { + var request = streamedRequest.removeHeader(header); + return new AggregatedHttpRequest(request, aggregatedContent); + } + + @Override + public HttpResponse createResponse(RestStatus status, BytesReference content) { + return streamedRequest.createResponse(status, content); + } + + @Override + public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBodyPart firstBodyPart) { + return streamedRequest.createResponse(status, firstBodyPart); + } + + @Override + public Exception getInboundException() { + return streamedRequest.getInboundException(); + } + + @Override + public void release() { + if (released.compareAndSet(false, true)) { + // request is not ref counted, but content is + aggregatedContent.close(); + } + } + + @Override + public RestRequest.Method method() { + return streamedRequest.method(); + } + + @Override + public String uri() { + return streamedRequest.uri(); + } + + @Override + public Map> getHeaders() { + return streamedRequest.getHeaders(); + } + } + + static final class AggregatedRestRequestChannel extends DelegatingRestChannel { + private final RestRequest request; + + AggregatedRestRequestChannel(RestChannel delegate, RestRequest request) { + super(delegate); + this.request = request; + } + + @Override + public RestRequest request() { + return request; + } + + @Override + public void sendResponse(RestResponse response) { + request.getHttpRequest().release(); // see DefaultRestChannel + super.sendResponse(response); + } + } + + /** + * A special case aggregator that expects no content. + * We still wait for the last empty content to proceed with streamedRequest handling. + */ + static class NoContent implements ChunkHandler { + private final RestRequest request; + private final RestChannel channel; + private final AggregateConsumer result; + + NoContent(RestRequest request, RestChannel channel, AggregateConsumer result) { + this.request = request; + this.channel = channel; + this.result = result; + } + + @Override + public void onNext(final ReleasableBytesReference lastEmptyChunk, boolean isLast) throws Exception { + assert lastEmptyChunk.length() == 0 && isLast; + var aggReq = new RestRequest(request, new AggregatedHttpRequest(request.getHttpRequest(), HttpBody.empty())); + var aggChan = new AggregatedRestRequestChannel(channel, aggReq); + result.accept(aggReq, aggChan); + } + } + + static class ChunkAggregator implements ChunkHandler { + + private final RestRequest request; + private final RestChannel channel; + private final HttpBody.Stream stream; + private final AggregateConsumer result; + private List aggregate; + + ChunkAggregator(RestRequest request, RestChannel channel, AggregateConsumer result) { + this.request = request; + this.channel = channel; + this.stream = request.contentStream(); + this.result = result; + this.aggregate = new ArrayList<>(); + } + + /** + * Compose and wrap all chunks into new {@link RestRequest} with full content. + */ + static RestRequest composeRestRequest(RestRequest streamedRestRequest, List chunks) { + final var composite = CompositeBytesReference.of(chunks.toArray(new BytesReference[0])); + final var refCnt = new AbstractRefCounted() { + + @Override + protected void closeInternal() { + Releasables.close(chunks); + } + }; + var aggregatedHttpRequest = new AggregatedHttpRequest( + streamedRestRequest.getHttpRequest(), + HttpBody.fromReleasableBytesReference(new ReleasableBytesReference(composite, refCnt)) + ); + return new RestRequest(streamedRestRequest, aggregatedHttpRequest); + } + + @Override + public void onNext(ReleasableBytesReference chunk, boolean isLast) throws Exception { + aggregate.add(chunk); + if (isLast == false) { + stream.next(); + } else { + var aggReq = composeRestRequest(request, aggregate); + var aggChan = new AggregatedRestRequestChannel(channel, aggReq); + aggregate = List.of(); + result.accept(aggReq, aggChan); + } + } + + @Override + public void close() { + Releasables.close(aggregate); + aggregate = List.of(); + } + + } + +} diff --git a/server/src/main/java/org/elasticsearch/rest/RestController.java b/server/src/main/java/org/elasticsearch/rest/RestController.java index 49801499ea991..d8d1b22141b08 100644 --- a/server/src/main/java/org/elasticsearch/rest/RestController.java +++ b/server/src/main/java/org/elasticsearch/rest/RestController.java @@ -22,7 +22,6 @@ import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.ReleasableBytesReference; -import org.elasticsearch.common.io.stream.BytesStream; import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.common.path.PathTrie; import org.elasticsearch.common.recycler.Recycler; @@ -53,7 +52,6 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; -import java.io.OutputStream; import java.util.Collections; import java.util.EnumSet; import java.util.Iterator; @@ -387,6 +385,27 @@ public Map getStats() { return Collections.unmodifiableSortedMap(allStats); } + private void maybeAggregateAndDispatch( + RestRequest request, + RestChannel channel, + RestHandler handler, + MethodHandlers methodHandlers, + ThreadContext context + ) throws Exception { + if (request.isStreamedContent() && handler.supportContentStream() == false) { + // all production non stream handlers + + RestContentAggregator.aggregate( + request, + channel, + (aggRequest, aggChan) -> dispatchRequest(aggRequest, aggChan, handler, methodHandlers, context) + ); + } else { + // stream handlers and tests that create request with full content + dispatchRequest(request, channel, handler, methodHandlers, context); + } + } + private void dispatchRequest( RestRequest request, RestChannel channel, @@ -621,7 +640,7 @@ private void tryAllHandlers(final RestRequest request, final RestChannel channel } else { startTrace(threadContext, channel, handlers.getPath()); var decoratedChannel = new MeteringRestChannelDecorator(channel, requestsCounter, handler.getConcreteRestHandler()); - dispatchRequest(request, decoratedChannel, handler, handlers, threadContext); + maybeAggregateAndDispatch(request, decoratedChannel, handler, handlers, threadContext); return; } } @@ -791,71 +810,6 @@ private static void recordRequestMetric(RestStatus statusCode, LongCounter reque } } - private static class DelegatingRestChannel implements RestChannel { - - private final RestChannel delegate; - - private DelegatingRestChannel(RestChannel delegate) { - this.delegate = delegate; - } - - @Override - public XContentBuilder newBuilder() throws IOException { - return delegate.newBuilder(); - } - - @Override - public XContentBuilder newErrorBuilder() throws IOException { - return delegate.newErrorBuilder(); - } - - @Override - public XContentBuilder newBuilder(@Nullable XContentType xContentType, boolean useFiltering) throws IOException { - return delegate.newBuilder(xContentType, useFiltering); - } - - @Override - public XContentBuilder newBuilder(XContentType xContentType, XContentType responseContentType, boolean useFiltering) - throws IOException { - return delegate.newBuilder(xContentType, responseContentType, useFiltering); - } - - @Override - public XContentBuilder newBuilder( - XContentType xContentType, - XContentType responseContentType, - boolean useFiltering, - OutputStream out - ) throws IOException { - return delegate.newBuilder(xContentType, responseContentType, useFiltering, out); - } - - @Override - public BytesStream bytesOutput() { - return delegate.bytesOutput(); - } - - @Override - public void releaseOutputBuffer() { - delegate.releaseOutputBuffer(); - } - - @Override - public RestRequest request() { - return delegate.request(); - } - - @Override - public boolean detailedErrorsEnabled() { - return delegate.detailedErrorsEnabled(); - } - - @Override - public void sendResponse(RestResponse response) { - delegate.sendResponse(response); - } - } - private static final class MeteringRestChannelDecorator extends DelegatingRestChannel { private final LongCounter requestsCounter; diff --git a/server/src/main/java/org/elasticsearch/rest/RestHandler.java b/server/src/main/java/org/elasticsearch/rest/RestHandler.java index 572e92e369a63..b1eb3c357c93a 100644 --- a/server/src/main/java/org/elasticsearch/rest/RestHandler.java +++ b/server/src/main/java/org/elasticsearch/rest/RestHandler.java @@ -40,6 +40,15 @@ default boolean canTripCircuitBreaker() { return true; } + /** + * Indicates if the RestHandler supports content processing as a stream of + * {@link org.elasticsearch.common.bytes.ReleasableBytesReference} chunks. + * See {@link org.elasticsearch.http.HttpBody.Stream}. + */ + default boolean supportContentStream() { + return false; + } + /** * Indicates if the RestHandler supports bulk content. A bulk request contains multiple objects * delineated by {@link XContent#bulkSeparator()}. If a handler returns true this will affect diff --git a/server/src/main/java/org/elasticsearch/rest/RestRequest.java b/server/src/main/java/org/elasticsearch/rest/RestRequest.java index a04bdcb32f2b4..f050dbeee24c7 100644 --- a/server/src/main/java/org/elasticsearch/rest/RestRequest.java +++ b/server/src/main/java/org/elasticsearch/rest/RestRequest.java @@ -176,6 +176,12 @@ protected RestRequest(RestRequest other) { this.requestId = other.requestId; } + protected RestRequest(RestRequest other, HttpRequest httpRequest) { + this(other); + this.consumedParams.addAll(other.consumedParams); + this.httpRequest = httpRequest; + } + private static @Nullable ParsedMediaType parseHeaderWithMediaType(Map> headers, String headerName) { // TODO: make all usages of headers case-insensitive List header = headers.get(headerName); @@ -291,11 +297,11 @@ public final String path() { } public boolean hasContent() { - return isStreamedContent() || contentLength() > 0; + return contentLength() != 0; } public int contentLength() { - return httpRequest.body().asFull().bytes().length(); + return httpRequest.contentLength(); } public boolean isFullContent() { diff --git a/server/src/main/java/org/elasticsearch/rest/action/document/RestBulkAction.java b/server/src/main/java/org/elasticsearch/rest/action/document/RestBulkAction.java index dea7b7138d0d0..e638bf19a089c 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/document/RestBulkAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/document/RestBulkAction.java @@ -22,6 +22,7 @@ import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.CompositeBytesReference; import org.elasticsearch.common.bytes.ReleasableBytesReference; +import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; @@ -63,14 +64,21 @@ public class RestBulkAction extends BaseRestHandler { private final boolean allowExplicitIndex; private final IncrementalBulkService bulkHandler; + private final IncrementalBulkService.Enabled incrementalEnabled; private final Set capabilities; - public RestBulkAction(Settings settings, IncrementalBulkService bulkHandler) { + public RestBulkAction(Settings settings, ClusterSettings clusterSettings, IncrementalBulkService bulkHandler) { this.allowExplicitIndex = MULTI_ALLOW_EXPLICIT_INDEX.get(settings); this.bulkHandler = bulkHandler; + this.incrementalEnabled = new IncrementalBulkService.Enabled(clusterSettings); this.capabilities = DataStream.isFailureStoreFeatureFlagEnabled() ? Set.of(FAILURE_STORE_STATUS_CAPABILITY) : Set.of(); } + @Override + public boolean supportContentStream() { + return incrementalEnabled.get(); + } + @Override public List routes() { return List.of( @@ -122,10 +130,10 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC } catch (Exception e) { return channel -> new RestToXContentListener<>(channel).onFailure(parseFailureException(e)); } - return channel -> { - content.mustIncRef(); - client.bulk(bulkRequest, ActionListener.releaseAfter(new RestRefCountedChunkedToXContentListener<>(channel), content)); - }; + return channel -> client.bulk( + bulkRequest, + ActionListener.withRef(new RestRefCountedChunkedToXContentListener<>(channel), content) + ); } else { String waitForActiveShards = request.param("wait_for_active_shards"); TimeValue timeout = request.paramAsTime("timeout", BulkShardRequest.DEFAULT_TIMEOUT); diff --git a/server/src/test/java/org/elasticsearch/http/AbstractHttpServerTransportTests.java b/server/src/test/java/org/elasticsearch/http/AbstractHttpServerTransportTests.java index fa774c0bcfd12..58883daf00184 100644 --- a/server/src/test/java/org/elasticsearch/http/AbstractHttpServerTransportTests.java +++ b/server/src/test/java/org/elasticsearch/http/AbstractHttpServerTransportTests.java @@ -459,7 +459,7 @@ public void testIncorrectHeaderHandling() { FakeRestRequest.FakeHttpRequest fakeHttpRequest = new FakeRestRequest.FakeHttpRequest( RestRequest.Method.GET, "/", - null, + BytesArray.EMPTY, headers ); @@ -475,7 +475,7 @@ public void testIncorrectHeaderHandling() { FakeRestRequest.FakeHttpRequest fakeHttpRequest = new FakeRestRequest.FakeHttpRequest( RestRequest.Method.GET, "/", - null, + BytesArray.EMPTY, headers ); diff --git a/server/src/test/java/org/elasticsearch/http/TestHttpRequest.java b/server/src/test/java/org/elasticsearch/http/TestHttpRequest.java index 27dc0be673abb..0bdb459d5fed1 100644 --- a/server/src/test/java/org/elasticsearch/http/TestHttpRequest.java +++ b/server/src/test/java/org/elasticsearch/http/TestHttpRequest.java @@ -47,6 +47,11 @@ public String uri() { return uri; } + @Override + public int contentLength() { + return 0; + } + @Override public HttpBody body() { return HttpBody.empty(); diff --git a/server/src/test/java/org/elasticsearch/rest/RestContentAggregatorTests.java b/server/src/test/java/org/elasticsearch/rest/RestContentAggregatorTests.java new file mode 100644 index 0000000000000..012bbdd3afafb --- /dev/null +++ b/server/src/test/java/org/elasticsearch/rest/RestContentAggregatorTests.java @@ -0,0 +1,166 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.rest; + +import org.elasticsearch.action.support.SubscribableListener; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.ReleasableBytesReference; +import org.elasticsearch.core.AbstractRefCounted; +import org.elasticsearch.http.HttpBody; +import org.elasticsearch.http.HttpRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.rest.FakeRestChannel; +import org.elasticsearch.xcontent.XContentParserConfiguration; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.stream.IntStream; + +import static org.elasticsearch.rest.RestRequest.Method.GET; +import static org.elasticsearch.test.rest.FakeRestRequest.FakeHttpRequest; + +public class RestContentAggregatorTests extends ESTestCase { + + static RestRequest restRequest(HttpRequest request) { + return RestRequest.request(XContentParserConfiguration.EMPTY, request, null); + } + + static RestRequest restRequest(HttpStream stream) { + return restRequest(new FakeHttpRequest(GET, "/", stream, Map.of("content-length", List.of("" + stream.contentLength)))); + } + + static FakeRestChannel restChan(RestRequest request) { + return new FakeRestChannel(request, false, 1); + } + + static List randomChunksList() { + return fixedChunksList(between(1, 16), between(1024, 8192)); + } + + static List fixedChunksList(int n, int size) { + return IntStream.range(0, n).mapToObj(i -> randomByteArrayOfLength(size)).toList(); + } + + static byte[] concatChunks(List chunks) { + var size = chunks.stream().mapToInt(c -> c.length).sum(); + var out = new byte[size]; + var off = 0; + for (var chunk : chunks) { + System.arraycopy(chunk, 0, out, off, chunk.length); + off += chunk.length; + } + return out; + } + + public void testNoContent() { + try (var stream = HttpStream.noContent()) { + var request = restRequest(stream); + var result = new SubscribableListener(); + RestContentAggregator.aggregate(request, restChan(request), (req, chan) -> result.onResponse(req)); + assertEquals(0, safeAwait(result).content().length()); + } + } + + public void testAggregateChunks() { + var chunks = randomChunksList(); + try (var stream = HttpStream.of(chunks)) { + var request = restRequest(stream); + var result = new SubscribableListener(); + RestContentAggregator.aggregate(request, restChan(request), (r, c) -> result.onResponse(r)); + var aggBytes = safeAwait(result).content().toBytesRef().bytes; + assertArrayEquals(concatChunks(chunks), aggBytes); + } + } + + /** + * An HttpBody Stream implementation with single thread executor. Must be closed after use. + */ + static class HttpStream implements HttpBody.Stream { + static final byte[] EMPTY = new byte[] {}; + final ExecutorService executor = Executors.newSingleThreadExecutor(); + final List chunks; + final int contentLength; + final SubscribableListener err = new SubscribableListener<>(); + int ind; + ChunkHandler handler; + + HttpStream(List chunks) { + this.chunks = chunks; + this.contentLength = chunks.stream().mapToInt(ReleasableBytesReference::length).sum(); + } + + static HttpStream noContent() { + return of(EMPTY); + } + + static HttpStream of(byte[]... chunks) { + return new HttpStream(fromArrays(chunks)); + } + + static HttpStream of(List chunks) { + return new HttpStream(fromList(chunks)); + } + + static ReleasableBytesReference wrappedChunk(byte[] arr) { + var bytesArray = new BytesArray(arr); + return new ReleasableBytesReference(bytesArray, new AbstractRefCounted() { + @Override + protected void closeInternal() { + + } + }); + } + + static List fromArrays(byte[]... chunks) { + return java.util.stream.Stream.of(chunks).map(HttpStream::wrappedChunk).toList(); + } + + static List fromList(List chunks) { + return chunks.stream().map(HttpStream::wrappedChunk).toList(); + } + + @Override + public HttpBody.ChunkHandler handler() { + return handler; + } + + @Override + public void addTracingHandler(ChunkHandler chunkHandler) {} + + @Override + public void setHandler(ChunkHandler chunkHandler) { + this.handler = chunkHandler; + } + + @Override + public void next() { + executor.submit(() -> { + var chunk = chunks.get(ind); + if (chunk != null) { + ind++; + try { + handler.onNext(chunk, ind == chunks.size()); + } catch (Exception e) { + err.onResponse(e); + } + } + }); + } + + @Override + public void close() { + handler.close(); + executor.shutdownNow(); + } + } + +} diff --git a/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java b/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java index 2fdb3daa26da4..a9082dec7cbb5 100644 --- a/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java +++ b/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java @@ -861,6 +861,11 @@ public String uri() { return "/"; } + @Override + public int contentLength() { + return body().asFull().bytes().length(); + } + @Override public HttpBody body() { if (hasContent) { diff --git a/server/src/test/java/org/elasticsearch/rest/RestRequestTests.java b/server/src/test/java/org/elasticsearch/rest/RestRequestTests.java index b391b77503400..27e7704eb188d 100644 --- a/server/src/test/java/org/elasticsearch/rest/RestRequestTests.java +++ b/server/src/test/java/org/elasticsearch/rest/RestRequestTests.java @@ -92,6 +92,7 @@ private void runConsumesContentTest(final CheckedConsumer< when(httpRequest.getHeaders()).thenReturn( Collections.singletonMap("Content-Type", Collections.singletonList(randomFrom("application/json", "application/x-ndjson"))) ); + when(httpRequest.contentLength()).thenReturn(1); final RestRequest request = RestRequest.request(XContentParserConfiguration.EMPTY, httpRequest, mock(HttpChannel.class)); assertFalse(request.isContentConsumed()); try { diff --git a/server/src/test/java/org/elasticsearch/rest/action/document/RestBulkActionTests.java b/server/src/test/java/org/elasticsearch/rest/action/document/RestBulkActionTests.java index f83ba1704f954..eaad630753e56 100644 --- a/server/src/test/java/org/elasticsearch/rest/action/document/RestBulkActionTests.java +++ b/server/src/test/java/org/elasticsearch/rest/action/document/RestBulkActionTests.java @@ -20,6 +20,8 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.ReleasableBytesReference; +import org.elasticsearch.common.settings.ClusterSettings; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Releasable; import org.elasticsearch.http.HttpBody; import org.elasticsearch.index.IndexVersion; @@ -49,6 +51,12 @@ */ public class RestBulkActionTests extends ESTestCase { + static RestBulkAction newRestBulkAction() { + final Settings settings = settings(IndexVersion.current()).build(); + final ClusterSettings clusterSettings = new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + return new RestBulkAction(settings, clusterSettings, new IncrementalBulkService(mock(Client.class), mock(IndexingPressure.class))); + } + public void testBulkPipelineUpsert() throws Exception { SetOnce bulkCalled = new SetOnce<>(); try (var threadPool = createThreadPool()) { @@ -63,10 +71,7 @@ public void bulk(BulkRequest request, ActionListener listener) { }; final Map params = new HashMap<>(); params.put("pipeline", "timestamps"); - new RestBulkAction( - settings(IndexVersion.current()).build(), - new IncrementalBulkService(mock(Client.class), mock(IndexingPressure.class)) - ).handleRequest( + newRestBulkAction().handleRequest( new FakeRestRequest.Builder(xContentRegistry()).withPath("my_index/_bulk").withParams(params).withContent(new BytesArray(""" {"index":{"_id":"1"}} {"field1":"val1"} @@ -98,10 +103,7 @@ public void bulk(BulkRequest request, ActionListener listener) { }; Map params = new HashMap<>(); { - new RestBulkAction( - settings(IndexVersion.current()).build(), - new IncrementalBulkService(mock(Client.class), mock(IndexingPressure.class)) - ).handleRequest( + newRestBulkAction().handleRequest( new FakeRestRequest.Builder(xContentRegistry()).withPath("my_index/_bulk") .withParams(params) .withContent(new BytesArray(""" @@ -122,10 +124,7 @@ public void bulk(BulkRequest request, ActionListener listener) { { params.put("list_executed_pipelines", "true"); bulkCalled.set(false); - new RestBulkAction( - settings(IndexVersion.current()).build(), - new IncrementalBulkService(mock(Client.class), mock(IndexingPressure.class)) - ).handleRequest( + newRestBulkAction().handleRequest( new FakeRestRequest.Builder(xContentRegistry()).withPath("my_index/_bulk") .withParams(params) .withContent(new BytesArray(""" @@ -145,10 +144,7 @@ public void bulk(BulkRequest request, ActionListener listener) { } { bulkCalled.set(false); - new RestBulkAction( - settings(IndexVersion.current()).build(), - new IncrementalBulkService(mock(Client.class), mock(IndexingPressure.class)) - ).handleRequest( + newRestBulkAction().handleRequest( new FakeRestRequest.Builder(xContentRegistry()).withPath("my_index/_bulk") .withParams(params) .withContent(new BytesArray(""" @@ -169,10 +165,7 @@ public void bulk(BulkRequest request, ActionListener listener) { { params.remove("list_executed_pipelines"); bulkCalled.set(false); - new RestBulkAction( - settings(IndexVersion.current()).build(), - new IncrementalBulkService(mock(Client.class), mock(IndexingPressure.class)) - ).handleRequest( + newRestBulkAction().handleRequest( new FakeRestRequest.Builder(xContentRegistry()).withPath("my_index/_bulk") .withParams(params) .withContent(new BytesArray(""" diff --git a/test/framework/src/main/java/org/elasticsearch/http/AbstractHttpServerTransportTestCase.java b/test/framework/src/main/java/org/elasticsearch/http/AbstractHttpServerTransportTestCase.java index fd260c015e505..243c8382f2d8f 100644 --- a/test/framework/src/main/java/org/elasticsearch/http/AbstractHttpServerTransportTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/http/AbstractHttpServerTransportTestCase.java @@ -10,6 +10,10 @@ import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.rest.RestContentAggregator; +import org.elasticsearch.rest.RestRequest; import org.elasticsearch.test.ESTestCase; public class AbstractHttpServerTransportTestCase extends ESTestCase { @@ -20,4 +24,14 @@ protected static ClusterSettings randomClusterSettings() { ClusterSettings.BUILT_IN_CLUSTER_SETTINGS ); } + + public abstract static class AggregatingDispatcher implements HttpServerTransport.Dispatcher { + + public abstract void dispatchAggregatedRequest(RestRequest request, RestChannel channel, ThreadContext threadContext); + + @Override + public void dispatchRequest(RestRequest request, RestChannel channel, ThreadContext threadContext) { + RestContentAggregator.aggregate(request, channel, (r, c) -> dispatchAggregatedRequest(r, c, threadContext)); + } + } } diff --git a/test/framework/src/main/java/org/elasticsearch/test/rest/FakeRestRequest.java b/test/framework/src/main/java/org/elasticsearch/test/rest/FakeRestRequest.java index 0c466b9162eb8..efa83fdc50239 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/rest/FakeRestRequest.java +++ b/test/framework/src/main/java/org/elasticsearch/test/rest/FakeRestRequest.java @@ -63,6 +63,10 @@ public FakeHttpRequest(Method method, String uri, BytesReference content, Map> headers) { + this(method, uri, body, headers, null); + } + private FakeHttpRequest( Method method, String uri, @@ -87,6 +91,19 @@ public String uri() { return uri; } + @Override + public int contentLength() { + if (content.isFull()) { + return content.asFull().bytes().length(); + } else { + if (headers.getOrDefault("transfer-encoding", List.of("")).getFirst().isEmpty()) { // no transfer encoding + return Integer.parseInt(headers.getOrDefault("content-length", List.of("0")).getFirst()); + } else { + return -1; + } + } + } + @Override public HttpBody body() { return content; diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransportCloseNotifyTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransportCloseNotifyTests.java index ec2881b989d0b..ed118a26faa98 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransportCloseNotifyTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransportCloseNotifyTests.java @@ -32,7 +32,6 @@ import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.env.TestEnvironment; import org.elasticsearch.http.AbstractHttpServerTransportTestCase; -import org.elasticsearch.http.HttpServerTransport; import org.elasticsearch.http.netty4.Netty4HttpServerTransport; import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestRequest; @@ -255,12 +254,12 @@ public void close() { } } - private static class QueuedDispatcher implements HttpServerTransport.Dispatcher { + private static class QueuedDispatcher extends AggregatingDispatcher { BlockingQueue reqQueue = new LinkedBlockingDeque<>(); BlockingDeque errQueue = new LinkedBlockingDeque<>(); @Override - public void dispatchRequest(RestRequest request, RestChannel channel, ThreadContext threadContext) { + public void dispatchAggregatedRequest(RestRequest request, RestChannel channel, ThreadContext threadContext) { reqQueue.add(new ReqCtx(request, channel, threadContext)); }