diff --git a/distribution/archives/integ-test-zip/src/javaRestTest/java/org/elasticsearch/test/rest/RequestsWithoutContentIT.java b/distribution/archives/integ-test-zip/src/javaRestTest/java/org/elasticsearch/test/rest/RequestsWithoutContentIT.java index c95c4c1d198f2..8732110bb1937 100644 --- a/distribution/archives/integ-test-zip/src/javaRestTest/java/org/elasticsearch/test/rest/RequestsWithoutContentIT.java +++ b/distribution/archives/integ-test-zip/src/javaRestTest/java/org/elasticsearch/test/rest/RequestsWithoutContentIT.java @@ -27,10 +27,9 @@ public void testIndexMissingBody() throws IOException { } public void testBulkMissingBody() throws IOException { - ResponseException responseException = expectThrows( - ResponseException.class, - () -> client().performRequest(new Request(randomBoolean() ? "POST" : "PUT", "/_bulk")) - ); + Request request = new Request(randomBoolean() ? "POST" : "PUT", "/_bulk"); + request.setJsonEntity(""); + ResponseException responseException = expectThrows(ResponseException.class, () -> client().performRequest(request)); assertResponseException(responseException, "request body is required"); } diff --git a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4HttpRequestSizeLimitIT.java b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4HttpRequestSizeLimitIT.java index d9cfe009718b7..fcd45e9f9f47e 100644 --- a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4HttpRequestSizeLimitIT.java +++ b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4HttpRequestSizeLimitIT.java @@ -14,6 +14,7 @@ import io.netty.util.ReferenceCounted; import org.elasticsearch.ESNetty4IntegTestCase; +import org.elasticsearch.action.bulk.IncrementalBulkService; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.unit.ByteSizeUnit; @@ -52,6 +53,8 @@ protected boolean addMockHttpTransport() { protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { return Settings.builder() .put(super.nodeSettings(nodeOrdinal, otherSettings)) + // TODO: We do not currently support in flight circuit breaker limits for bulk. However, IndexingPressure applies + .put(IncrementalBulkService.INCREMENTAL_BULK.getKey(), false) .put(HierarchyCircuitBreakerService.IN_FLIGHT_REQUESTS_CIRCUIT_BREAKER_LIMIT_SETTING.getKey(), LIMIT) .build(); } diff --git a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4IncrementalRequestHandlingIT.java b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4IncrementalRequestHandlingIT.java new file mode 100644 index 0000000000000..2b9c77b17bced --- /dev/null +++ b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4IncrementalRequestHandlingIT.java @@ -0,0 +1,695 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.http.netty4; + +import io.netty.bootstrap.Bootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufInputStream; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.DefaultHttpContent; +import io.netty.handler.codec.http.DefaultHttpRequest; +import io.netty.handler.codec.http.DefaultLastHttpContent; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpChunkedInput; +import io.netty.handler.codec.http.HttpClientCodec; +import io.netty.handler.codec.http.HttpContent; +import io.netty.handler.codec.http.HttpObjectAggregator; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpUtil; +import io.netty.handler.codec.http.LastHttpContent; +import io.netty.handler.stream.ChunkedStream; +import io.netty.handler.stream.ChunkedWriteHandler; + +import org.apache.logging.log4j.Level; +import org.elasticsearch.ESNetty4IntegTestCase; +import org.elasticsearch.action.support.SubscribableListener; +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.common.bytes.ReleasableBytesReference; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.settings.ClusterSettings; +import org.elasticsearch.common.settings.IndexScopedSettings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.settings.SettingsFilter; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.util.CollectionUtils; +import org.elasticsearch.features.NodeFeature; +import org.elasticsearch.http.HttpBodyTracer; +import org.elasticsearch.http.HttpHandlingSettings; +import org.elasticsearch.http.HttpServerTransport; +import org.elasticsearch.http.HttpTransportSettings; +import org.elasticsearch.plugins.ActionPlugin; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.rest.RestController; +import org.elasticsearch.rest.RestHandler; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestResponse; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.test.MockLog; +import org.elasticsearch.test.junit.annotations.TestLogging; +import org.elasticsearch.transport.netty4.Netty4Utils; + +import java.util.Collection; +import java.util.List; +import java.util.concurrent.BlockingDeque; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import java.util.function.Predicate; +import java.util.function.Supplier; + +import 
static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH; +import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_TYPE; +import static io.netty.handler.codec.http.HttpHeaderValues.APPLICATION_JSON; +import static io.netty.handler.codec.http.HttpMethod.POST; +import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; + +@ESIntegTestCase.ClusterScope(numDataNodes = 1) +public class Netty4IncrementalRequestHandlingIT extends ESNetty4IntegTestCase { + + @Override + protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { + Settings.Builder builder = Settings.builder().put(super.nodeSettings(nodeOrdinal, otherSettings)); + builder.put(HttpTransportSettings.SETTING_HTTP_MAX_CONTENT_LENGTH.getKey(), new ByteSizeValue(50, ByteSizeUnit.MB)); + return builder.build(); + } + + // ensure that empty http content produces a single zero-size chunk + public void testEmptyContent() throws Exception { + try (var ctx = setupClientCtx()) { + var totalRequests = randomIntBetween(1, 10); + for (int reqNo = 0; reqNo < totalRequests; reqNo++) { + var opaqueId = opaqueId(reqNo); + + // send request with empty content + ctx.clientChannel.writeAndFlush(fullHttpRequest(opaqueId, Unpooled.EMPTY_BUFFER)); + var handler = ctx.awaitRestChannelAccepted(opaqueId); + handler.stream.next(); + + // should receive a single empty chunk + var recvChunk = safePoll(handler.recvChunks); + assertTrue(recvChunk.isLast); + assertEquals(0, recvChunk.chunk.length()); + recvChunk.chunk.close(); + assertFalse(handler.streamClosed); + + // send response to process following request + handler.sendResponse(new RestResponse(RestStatus.OK, "")); + assertBusy(() -> assertTrue(handler.streamClosed)); + } + assertBusy(() -> assertEquals("should receive all server responses", totalRequests, ctx.clientRespQueue.size())); + } + } + + // ensures content integrity: no losses and no re-ordering + public void testReceiveAllChunks() throws Exception { + try (var ctx = setupClientCtx()) { + var totalRequests = randomIntBetween(1, 10); + for (int reqNo = 0; reqNo < totalRequests; reqNo++) { + var opaqueId = opaqueId(reqNo); + + // this dataset will be compared with the one on the server side + var dataSize = randomIntBetween(1024, maxContentLength()); + var sendData = Unpooled.wrappedBuffer(randomByteArrayOfLength(dataSize)); + sendData.retain(); + ctx.clientChannel.writeAndFlush(fullHttpRequest(opaqueId, sendData)); + + var handler = ctx.awaitRestChannelAccepted(opaqueId); + + var recvData = Unpooled.buffer(dataSize); + while (true) { + handler.stream.next(); + var recvChunk = safePoll(handler.recvChunks); + try (recvChunk.chunk) { + recvData.writeBytes(Netty4Utils.toByteBuf(recvChunk.chunk)); + if (recvChunk.isLast) { + break; + } + } + } + + assertFalse(handler.streamClosed); + assertEquals("sent and received payloads are not the same", sendData, recvData); + handler.sendResponse(new RestResponse(RestStatus.OK, "")); + assertBusy(() -> assertTrue(handler.streamClosed)); + } + assertBusy(() -> assertEquals("should receive all server responses", totalRequests, ctx.clientRespQueue.size())); + } + } + + // ensures that all received chunks are released when the connection is closed and the handler is notified + public void testClientConnectionCloseMidStream() throws Exception { + try (var ctx = setupClientCtx()) { + var opaqueId = opaqueId(0); + + // write half of the http request + ctx.clientChannel.write(httpRequest(opaqueId, 2 * 1024)); + ctx.clientChannel.writeAndFlush(randomContent(1024, false)); + + // await until the stream handler is ready, then request the full content
+ var handler = ctx.awaitRestChannelAccepted(opaqueId); + assertBusy(() -> assertNotNull(handler.stream.buf())); + + // enable auto-read to receive channel close event + handler.stream.channel().config().setAutoRead(true); + assertFalse(handler.streamClosed); + + // terminate the connection and wait until resources are released + ctx.clientChannel.close(); + assertBusy(() -> { + assertNull(handler.stream.buf()); + assertTrue(handler.streamClosed); + }); + } + } + + // ensures that all received chunks are released when the server decides to close the connection + public void testServerCloseConnectionMidStream() throws Exception { + try (var ctx = setupClientCtx()) { + var opaqueId = opaqueId(0); + + // write half of the http request + ctx.clientChannel.write(httpRequest(opaqueId, 2 * 1024)); + ctx.clientChannel.writeAndFlush(randomContent(1024, false)); + + // await until the stream handler is ready, then request the full content + var handler = ctx.awaitRestChannelAccepted(opaqueId); + assertBusy(() -> assertNotNull(handler.stream.buf())); + assertFalse(handler.streamClosed); + + // terminate the connection on the server and wait until resources are released + handler.channel.request().getHttpChannel().close(); + assertBusy(() -> { + assertNull(handler.stream.buf()); + assertTrue(handler.streamClosed); + }); + } + } + + // ensure that the client's socket buffers data when the server is not consuming it + public void testClientBackpressure() throws Exception { + try (var ctx = setupClientCtx()) { + var opaqueId = opaqueId(0); + var payloadSize = maxContentLength(); + var totalParts = 10; + var partSize = payloadSize / totalParts; + ctx.clientChannel.writeAndFlush(httpRequest(opaqueId, payloadSize)); + for (int i = 0; i < totalParts; i++) { + ctx.clientChannel.writeAndFlush(randomContent(partSize, false)); + } + assertFalse( + "should not flush last content immediately", + ctx.clientChannel.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT).isDone() + ); + + var handler = ctx.awaitRestChannelAccepted(opaqueId); + + // some data flushes from the channel into the OS buffer and won't be visible here, usually 4-8 MB + var osBufferOffset = MBytes(10); + + // incrementally read data on the server side and ensure the client-side buffer drains accordingly + for (int readBytes = 0; readBytes <= payloadSize; readBytes += partSize) { + var minBufSize = Math.max(payloadSize - readBytes - osBufferOffset, 0); + var maxBufSize = Math.max(payloadSize - readBytes, 0); + // it is hard to tell when the client's channel is no longer flushing data + // it might take a few busy-iterations before the channel buffer is flushed to the OS + // and bytesBeforeWritable stops changing + assertBusy(() -> { + var bufSize = ctx.clientChannel.bytesBeforeWritable(); + assertTrue( + "client's channel buffer should be in range [" + minBufSize + "," + maxBufSize + "], got " + bufSize, + bufSize >= minBufSize && bufSize <= maxBufSize + ); + }); + handler.readBytes(partSize); + } + assertTrue(handler.stream.hasLast()); + } + } + + // ensures that the server replies with 100-continue for an acceptable request size + public void test100Continue() throws Exception { + try (var ctx = setupClientCtx()) { + for (int reqNo = 0; reqNo < randomIntBetween(2, 10); reqNo++) { + var id = opaqueId(reqNo); + var acceptableContentLength = randomIntBetween(0, maxContentLength()); + + // send request header and await 100-continue + var req = httpRequest(id, acceptableContentLength); + HttpUtil.set100ContinueExpected(req, true); + ctx.clientChannel.writeAndFlush(req); + var resp = (FullHttpResponse) safePoll(ctx.clientRespQueue); +
assertEquals(HttpResponseStatus.CONTINUE, resp.status()); + resp.release(); + + // send content + var content = randomContent(acceptableContentLength, true); + ctx.clientChannel.writeAndFlush(content); + + // consume content and reply 200 + var handler = ctx.awaitRestChannelAccepted(id); + var consumed = handler.readAllBytes(); + assertEquals(acceptableContentLength, consumed); + handler.sendResponse(new RestResponse(RestStatus.OK, "")); + + resp = (FullHttpResponse) safePoll(ctx.clientRespQueue); + assertEquals(HttpResponseStatus.OK, resp.status()); + resp.release(); + } + } + } + + // ensures that the server replies with 413 Request Entity Too Large for an oversized request with Expect: 100-continue + public void test413TooLargeOnExpect100Continue() throws Exception { + try (var ctx = setupClientCtx()) { + for (int reqNo = 0; reqNo < randomIntBetween(2, 10); reqNo++) { + var id = opaqueId(reqNo); + var oversized = maxContentLength() + 1; + + // send request header and await 413 too large + var req = httpRequest(id, oversized); + HttpUtil.set100ContinueExpected(req, true); + ctx.clientChannel.writeAndFlush(req); + var resp = (FullHttpResponse) safePoll(ctx.clientRespQueue); + assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, resp.status()); + resp.release(); + + // terminate request + ctx.clientChannel.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT); + } + } + } + + // ensures that an oversized chunked-encoded request is not limited at the http layer + // the rest handler is responsible for oversized requests + public void testOversizedChunkedEncodingNoLimits() throws Exception { + try (var ctx = setupClientCtx()) { + for (var reqNo = 0; reqNo < randomIntBetween(2, 10); reqNo++) { + var id = opaqueId(reqNo); + var contentSize = maxContentLength() + 1; + var content = randomByteArrayOfLength(contentSize); + var is = new ByteBufInputStream(Unpooled.wrappedBuffer(content)); + var chunkedIs = new ChunkedStream(is); + var httpChunkedIs = new HttpChunkedInput(chunkedIs, LastHttpContent.EMPTY_LAST_CONTENT); + var req = httpRequest(id, 0); + HttpUtil.setTransferEncodingChunked(req, true); + + ctx.clientChannel.pipeline().addLast(new ChunkedWriteHandler()); + ctx.clientChannel.writeAndFlush(req); + ctx.clientChannel.writeAndFlush(httpChunkedIs); + var handler = ctx.awaitRestChannelAccepted(id); + var consumed = handler.readAllBytes(); + assertEquals(contentSize, consumed); + handler.sendResponse(new RestResponse(RestStatus.OK, "")); + + var resp = (FullHttpResponse) safePoll(ctx.clientRespQueue); + assertEquals(HttpResponseStatus.OK, resp.status()); + resp.release(); + } + } + } + + // ensures that we don't leak buffers in the stream on 400 Bad Request + // some bad requests are dispatched by the rest controller before reaching the rest handler + // the test relies on netty's buffer leak detection + public void testBadRequestReleaseQueuedChunks() throws Exception { + try (var ctx = setupClientCtx()) { + for (var reqNo = 0; reqNo < randomIntBetween(2, 10); reqNo++) { + var id = opaqueId(reqNo); + var contentSize = randomIntBetween(0, maxContentLength()); + var req = httpRequest(id, contentSize); + var content = randomContent(contentSize, true); + + // set an unacceptable content-type + req.headers().set(CONTENT_TYPE, "unknown"); + ctx.clientChannel.writeAndFlush(req); + ctx.clientChannel.writeAndFlush(content); + + var resp = (FullHttpResponse) safePoll(ctx.clientRespQueue); + assertEquals(HttpResponseStatus.BAD_REQUEST, resp.status()); + resp.release(); + } + } + } + + private static long transportStatsRequestBytesSize(Ctx ctx) { + var
httpTransport = internalCluster().getInstance(HttpServerTransport.class, ctx.nodeName); + var stats = httpTransport.stats().clientStats(); + var bytes = 0L; + for (var s : stats) { + bytes += s.requestSizeBytes(); + } + return bytes; + } + + /** + * ensures that {@link org.elasticsearch.http.HttpClientStatsTracker} counts streamed content bytes + */ + public void testHttpClientStats() throws Exception { + try (var ctx = setupClientCtx()) { + // need to offset starting point, since we reuse cluster and other tests already sent some data + var totalBytesSent = transportStatsRequestBytesSize(ctx); + + for (var reqNo = 0; reqNo < randomIntBetween(2, 10); reqNo++) { + var id = opaqueId(reqNo); + var contentSize = randomIntBetween(0, maxContentLength()); + totalBytesSent += contentSize; + ctx.clientChannel.writeAndFlush(httpRequest(id, contentSize)); + ctx.clientChannel.writeAndFlush(randomContent(contentSize, true)); + var handler = ctx.awaitRestChannelAccepted(id); + handler.readAllBytes(); + handler.sendResponse(new RestResponse(RestStatus.OK, "")); + assertEquals(totalBytesSent, transportStatsRequestBytesSize(ctx)); + } + } + } + + /** + * ensures that we log parts of http body and final line + */ + @TestLogging( + reason = "testing TRACE logging", + value = "org.elasticsearch.http.HttpTracer:TRACE,org.elasticsearch.http.HttpBodyTracer:TRACE" + ) + public void testHttpBodyLogging() throws Exception { + assertHttpBodyLogging((ctx) -> () -> { + try { + var req = fullHttpRequest(opaqueId(0), randomByteBuf(8 * 1024)); + ctx.clientChannel.writeAndFlush(req); + var handler = ctx.awaitRestChannelAccepted(opaqueId(0)); + handler.readAllBytes(); + } catch (Exception e) { + fail(e); + } + }); + } + + /** + * ensures that we log some parts of body and final line when connection is closed in the middle + */ + @TestLogging( + reason = "testing TRACE logging", + value = "org.elasticsearch.http.HttpTracer:TRACE,org.elasticsearch.http.HttpBodyTracer:TRACE" + ) + public void testHttpBodyLoggingChannelClose() throws Exception { + assertHttpBodyLogging((ctx) -> () -> { + try { + var req = httpRequest(opaqueId(0), 2 * 8192); + var halfContent = randomContent(8192, false); + ctx.clientChannel.writeAndFlush(req); + ctx.clientChannel.writeAndFlush(halfContent); + var handler = ctx.awaitRestChannelAccepted(opaqueId(0)); + handler.readBytes(8192); + ctx.clientChannel.close(); + handler.stream.next(); + assertBusy(() -> assertTrue(handler.streamClosed)); + } catch (Exception e) { + fail(e); + } + }); + } + + // asserts that we emit at least one logging event for a part and last line + // http body should be large enough to split across multiple lines, > 4kb + private void assertHttpBodyLogging(Function test) throws Exception { + try (var ctx = setupClientCtx()) { + MockLog.assertThatLogger( + test.apply(ctx), + HttpBodyTracer.class, + new MockLog.SeenEventExpectation( + "request part", + HttpBodyTracer.class.getCanonicalName(), + Level.TRACE, + "* request body [part *]*" + ), + new MockLog.SeenEventExpectation( + "request end", + HttpBodyTracer.class.getCanonicalName(), + Level.TRACE, + "* request body (gzip compressed, base64-encoded, and split into * parts on preceding log lines; for details see " + + "https://www.elastic.co/guide/en/elasticsearch/reference/master/modules-network.html#http-rest-request-tracer)" + ) + ); + } + } + + private int maxContentLength() { + return HttpHandlingSettings.fromSettings(internalCluster().getInstance(Settings.class)).maxContentLength(); + } + + private String opaqueId(int reqNo) 
{ + return getTestName() + "-" + reqNo; + } + + static int MBytes(int m) { + return m * 1024 * 1024; + } + + static T safePoll(BlockingDeque queue) { + try { + var t = queue.poll(SAFE_AWAIT_TIMEOUT.seconds(), TimeUnit.SECONDS); + assertNotNull("queue is empty", t); + return t; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new AssertionError(e); + } + } + + static FullHttpRequest fullHttpRequest(String opaqueId, ByteBuf content) { + var req = new DefaultFullHttpRequest(HTTP_1_1, POST, ControlServerRequestPlugin.ROUTE, Unpooled.wrappedBuffer(content)); + req.headers().add(CONTENT_LENGTH, content.readableBytes()); + req.headers().add(CONTENT_TYPE, APPLICATION_JSON); + req.headers().add(Task.X_OPAQUE_ID_HTTP_HEADER, opaqueId); + return req; + } + + static HttpRequest httpRequest(String opaqueId, int contentLength) { + return httpRequest(ControlServerRequestPlugin.ROUTE, opaqueId, contentLength); + } + + static HttpRequest httpRequest(String uri, String opaqueId, int contentLength) { + var req = new DefaultHttpRequest(HTTP_1_1, POST, uri); + req.headers().add(CONTENT_LENGTH, contentLength); + req.headers().add(CONTENT_TYPE, APPLICATION_JSON); + req.headers().add(Task.X_OPAQUE_ID_HTTP_HEADER, opaqueId); + return req; + } + + static HttpContent randomContent(int size, boolean isLast) { + var buf = Unpooled.wrappedBuffer(randomByteArrayOfLength(size)); + if (isLast) { + return new DefaultLastHttpContent(buf); + } else { + return new DefaultHttpContent(buf); + } + } + + static ByteBuf randomByteBuf(int size) { + return Unpooled.wrappedBuffer(randomByteArrayOfLength(size)); + } + + Ctx setupClientCtx() throws Exception { + var nodeName = internalCluster().getRandomNodeName(); + var clientRespQueue = new LinkedBlockingDeque<>(16); + var bootstrap = bootstrapClient(nodeName, clientRespQueue); + var channel = bootstrap.connect().sync().channel(); + return new Ctx(getTestName(), nodeName, bootstrap, channel, clientRespQueue); + } + + Bootstrap bootstrapClient(String node, BlockingQueue queue) { + var httpServer = internalCluster().getInstance(HttpServerTransport.class, node); + var remoteAddr = randomFrom(httpServer.boundAddress().boundAddresses()); + return new Bootstrap().group(new NioEventLoopGroup(1)) + .channel(NioSocketChannel.class) + .remoteAddress(remoteAddr.getAddress(), remoteAddr.getPort()) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel ch) { + var p = ch.pipeline(); + p.addLast(new HttpClientCodec()); + p.addLast(new HttpObjectAggregator(4096)); + p.addLast(new SimpleChannelInboundHandler() { + @Override + protected void channelRead0(ChannelHandlerContext ctx, FullHttpResponse msg) { + msg.retain(); + queue.add(msg); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + queue.add(cause); + } + }); + } + }); + } + + @Override + protected Collection> nodePlugins() { + return CollectionUtils.concatLists(List.of(ControlServerRequestPlugin.class), super.nodePlugins()); + } + + @Override + protected boolean addMockHttpTransport() { + return false; // enable http + } + + record Ctx(String testName, String nodeName, Bootstrap clientBootstrap, Channel clientChannel, BlockingDeque clientRespQueue) + implements + AutoCloseable { + + @Override + public void close() throws Exception { + safeGet(clientChannel.close()); + safeGet(clientBootstrap.config().group().shutdownGracefully()); + clientRespQueue.forEach(o -> { if (o instanceof FullHttpResponse resp) resp.release(); }); + for 
(var opaqueId : ControlServerRequestPlugin.handlers.keySet()) { + if (opaqueId.startsWith(testName)) { + var handler = ControlServerRequestPlugin.handlers.get(opaqueId); + handler.recvChunks.forEach(c -> c.chunk.close()); + handler.channel.request().getHttpChannel().close(); + ControlServerRequestPlugin.handlers.remove(opaqueId); + } + } + } + + ServerRequestHandler awaitRestChannelAccepted(String opaqueId) throws Exception { + assertBusy(() -> assertTrue(ControlServerRequestPlugin.handlers.containsKey(opaqueId))); + var handler = ControlServerRequestPlugin.handlers.get(opaqueId); + safeAwait(handler.channelAccepted); + return handler; + } + } + + static class ServerRequestHandler implements BaseRestHandler.RequestBodyChunkConsumer { + final SubscribableListener channelAccepted = new SubscribableListener<>(); + final String opaqueId; + final BlockingDeque recvChunks = new LinkedBlockingDeque<>(); + final Netty4HttpRequestBodyStream stream; + RestChannel channel; + boolean recvLast = false; + volatile boolean streamClosed = false; + + ServerRequestHandler(String opaqueId, Netty4HttpRequestBodyStream stream) { + this.opaqueId = opaqueId; + this.stream = stream; + } + + @Override + public void handleChunk(RestChannel channel, ReleasableBytesReference chunk, boolean isLast) { + recvChunks.add(new Chunk(chunk, isLast)); + } + + @Override + public void accept(RestChannel channel) throws Exception { + this.channel = channel; + channelAccepted.onResponse(null); + } + + @Override + public void streamClose() { + streamClosed = true; + } + + void sendResponse(RestResponse response) { + channel.sendResponse(response); + } + + int readBytes(int bytes) { + var consumed = 0; + if (recvLast == false) { + while (consumed < bytes) { + stream.next(); + var recvChunk = safePoll(recvChunks); + consumed += recvChunk.chunk.length(); + recvChunk.chunk.close(); + if (recvChunk.isLast) { + recvLast = true; + break; + } + } + } + return consumed; + } + + int readAllBytes() { + return readBytes(Integer.MAX_VALUE); + } + + record Chunk(ReleasableBytesReference chunk, boolean isLast) {} + } + + // takes full control of rest handler from the outside + public static class ControlServerRequestPlugin extends Plugin implements ActionPlugin { + + static final String ROUTE = "/_test/request-stream"; + + static final ConcurrentHashMap handlers = new ConcurrentHashMap<>(); + + @Override + public Collection getRestHandlers( + Settings settings, + NamedWriteableRegistry namedWriteableRegistry, + RestController restController, + ClusterSettings clusterSettings, + IndexScopedSettings indexScopedSettings, + SettingsFilter settingsFilter, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier nodesInCluster, + Predicate clusterSupportsFeature + ) { + return List.of(new BaseRestHandler() { + @Override + public boolean allowsUnsafeBuffers() { + return true; + } + + @Override + public String getName() { + return ROUTE; + } + + @Override + public List routes() { + return List.of(new Route(RestRequest.Method.POST, ROUTE)); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) { + var stream = (Netty4HttpRequestBodyStream) request.contentStream(); + var opaqueId = request.getHeaders().get(Task.X_OPAQUE_ID_HTTP_HEADER).get(0); + var handler = new ServerRequestHandler(opaqueId, stream); + handlers.put(opaqueId, handler); + return handler; + } + }); + } + } + +} diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpAggregator.java 
b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpAggregator.java new file mode 100644 index 0000000000000..3c9e684ef4279 --- /dev/null +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpAggregator.java @@ -0,0 +1,84 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.http.netty4; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpContent; +import io.netty.handler.codec.http.HttpObject; +import io.netty.handler.codec.http.HttpObjectAggregator; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpUtil; + +import org.elasticsearch.http.HttpPreRequest; +import org.elasticsearch.http.netty4.internal.HttpHeadersAuthenticatorUtils; + +import java.util.function.Predicate; + +/** + * A wrapper around {@link HttpObjectAggregator}. Provides optional content aggregation based on + * predicate. {@link HttpObjectAggregator} also handles Expect: 100-continue and oversized content. + * Unfortunately, Netty does not provide handlers for oversized messages beyond HttpObjectAggregator. + */ +public class Netty4HttpAggregator extends HttpObjectAggregator { + private static final Predicate IGNORE_TEST = (req) -> req.uri().startsWith("/_test/request-stream") == false; + + private final Predicate decider; + private boolean aggregating = true; + private boolean ignoreContentAfterContinueResponse = false; + + public Netty4HttpAggregator(int maxContentLength, Predicate decider) { + super(maxContentLength); + this.decider = decider; + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + assert msg instanceof HttpObject; + if (msg instanceof HttpRequest request) { + var preReq = HttpHeadersAuthenticatorUtils.asHttpPreRequest(request); + aggregating = decider.test(preReq) && IGNORE_TEST.test(preReq); + } + if (aggregating || msg instanceof FullHttpRequest) { + super.channelRead(ctx, msg); + } else { + handle(ctx, (HttpObject) msg); + } + } + + private void handle(ChannelHandlerContext ctx, HttpObject msg) { + if (msg instanceof HttpRequest request) { + var continueResponse = newContinueResponse(request, maxContentLength(), ctx.pipeline()); + if (continueResponse != null) { + // there are 3 responses expected: 100, 413, 417 + // on 100 we pass request further and reply to client to continue + // on 413/417 we ignore following content + ctx.writeAndFlush(continueResponse); + var resp = (FullHttpResponse) continueResponse; + if (resp.status() != HttpResponseStatus.CONTINUE) { + ignoreContentAfterContinueResponse = true; + return; + } + HttpUtil.set100ContinueExpected(request, false); + } + ignoreContentAfterContinueResponse = false; + ctx.fireChannelRead(msg); + } else { + var httpContent = (HttpContent) msg; + if (ignoreContentAfterContinueResponse) { + httpContent.release(); + } else { + ctx.fireChannelRead(msg); + } + } + } 
+} diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpHeaderValidator.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpHeaderValidator.java index 88b458fd1c416..95a68cb52bbdb 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpHeaderValidator.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpHeaderValidator.java @@ -61,6 +61,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception pending.add(ReferenceCountUtil.retain(httpObject)); requestStart(ctx); assert state == QUEUEING_DATA; + assert ctx.channel().config().isAutoRead() == false; break; case QUEUEING_DATA: pending.add(ReferenceCountUtil.retain(httpObject)); @@ -77,14 +78,14 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception if (httpObject instanceof LastHttpContent) { state = WAITING_TO_START; } - // fall-through + ReferenceCountUtil.release(httpObject); + break; case DROPPING_DATA_PERMANENTLY: assert pending.isEmpty(); ReferenceCountUtil.release(httpObject); // consume without enqueuing + ctx.channel().config().setAutoRead(false); break; } - - setAutoReadForState(ctx, state); } private void requestStart(ChannelHandlerContext ctx) { @@ -105,6 +106,7 @@ private void requestStart(ChannelHandlerContext ctx) { } state = QUEUEING_DATA; + ctx.channel().config().setAutoRead(false); if (httpRequest == null) { // this looks like a malformed request and will forward without validation @@ -150,6 +152,7 @@ private void forwardFullRequest(ChannelHandlerContext ctx) { assert ctx.channel().config().isAutoRead() == false; assert state == QUEUEING_DATA; + ctx.channel().config().setAutoRead(true); boolean fullRequestForwarded = forwardData(ctx, pending); assert fullRequestForwarded || pending.isEmpty(); @@ -161,7 +164,6 @@ private void forwardFullRequest(ChannelHandlerContext ctx) { } assert state == WAITING_TO_START || state == QUEUEING_DATA || state == FORWARDING_DATA_UNTIL_NEXT_REQUEST; - setAutoReadForState(ctx, state); } private void forwardRequestWithDecoderExceptionAndNoContent(ChannelHandlerContext ctx, Exception e) { @@ -177,6 +179,8 @@ private void forwardRequestWithDecoderExceptionAndNoContent(ChannelHandlerContex messageToForward = toReplace.replace(Unpooled.EMPTY_BUFFER); } messageToForward.setDecoderResult(DecoderResult.failure(e)); + + ctx.channel().config().setAutoRead(true); ctx.fireChannelRead(messageToForward); assert fullRequestDropped || pending.isEmpty(); @@ -188,7 +192,6 @@ private void forwardRequestWithDecoderExceptionAndNoContent(ChannelHandlerContex } assert state == WAITING_TO_START || state == QUEUEING_DATA || state == DROPPING_DATA_UNTIL_NEXT_REQUEST; - setAutoReadForState(ctx, state); } @Override @@ -244,10 +247,6 @@ private static void maybeResizePendingDown(int largeSize, ArrayDeque } } - private static void setAutoReadForState(ChannelHandlerContext ctx, State state) { - ctx.channel().config().setAutoRead((state == QUEUEING_DATA || state == DROPPING_DATA_PERMANENTLY) == false); - } - enum State { WAITING_TO_START, QUEUEING_DATA, diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java index d8eadf4fca95d..b08c93a4dc240 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java +++ 
b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java @@ -20,8 +20,11 @@ import io.netty.handler.codec.http.DefaultHttpResponse; import io.netty.handler.codec.http.DefaultLastHttpContent; import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpContent; import io.netty.handler.codec.http.HttpObject; +import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.LastHttpContent; import io.netty.handler.ssl.SslCloseCompletionEvent; import io.netty.util.ReferenceCountUtil; import io.netty.util.concurrent.Future; @@ -71,6 +74,12 @@ private record ChunkedWrite(PromiseCombiner combiner, ChannelPromise onDone, Chu @Nullable private ChunkedWrite currentChunkedWrite; + /** + * HTTP request content stream for current request, it's null if there is no current request or request is fully-aggregated + */ + @Nullable + private Netty4HttpRequestBodyStream currentRequestStream; + /* * The current read and write sequence numbers. Read sequence numbers are attached to requests in the order they are read from the * channel, and then transferred to responses. A response is not written to the channel context until its sequence number matches the @@ -110,23 +119,38 @@ public Netty4HttpPipeliningHandler( public void channelRead(final ChannelHandlerContext ctx, final Object msg) { activityTracker.startActivity(); try { - assert msg instanceof FullHttpRequest : "Should have fully aggregated message already but saw [" + msg + "]"; - final FullHttpRequest fullHttpRequest = (FullHttpRequest) msg; - final Netty4HttpRequest netty4HttpRequest; - if (fullHttpRequest.decoderResult().isFailure()) { - final Throwable cause = fullHttpRequest.decoderResult().cause(); - final Exception nonError; - if (cause instanceof Error) { - ExceptionsHelper.maybeDieOnAnotherThread(cause); - nonError = new Exception(cause); + if (msg instanceof HttpRequest request) { + final Netty4HttpRequest netty4HttpRequest; + if (request.decoderResult().isFailure()) { + final Throwable cause = request.decoderResult().cause(); + final Exception nonError; + if (cause instanceof Error) { + ExceptionsHelper.maybeDieOnAnotherThread(cause); + nonError = new Exception(cause); + } else { + nonError = (Exception) cause; + } + netty4HttpRequest = new Netty4HttpRequest(readSequence++, (FullHttpRequest) request, nonError); } else { - nonError = (Exception) cause; + assert currentRequestStream == null : "current stream must be null for new request"; + if (request instanceof FullHttpRequest fullHttpRequest) { + netty4HttpRequest = new Netty4HttpRequest(readSequence++, fullHttpRequest); + currentRequestStream = null; + } else { + var contentStream = new Netty4HttpRequestBodyStream(ctx.channel()); + currentRequestStream = contentStream; + netty4HttpRequest = new Netty4HttpRequest(readSequence++, request, contentStream); + } } - netty4HttpRequest = new Netty4HttpRequest(readSequence++, fullHttpRequest, nonError); + handlePipelinedRequest(ctx, netty4HttpRequest); } else { - netty4HttpRequest = new Netty4HttpRequest(readSequence++, fullHttpRequest); + assert msg instanceof HttpContent : "expect HttpContent got " + msg; + assert currentRequestStream != null : "current stream must exists before handling http content"; + currentRequestStream.handleNettyContent((HttpContent) msg); + if (msg instanceof LastHttpContent) { + currentRequestStream = null; + } } - handlePipelinedRequest(ctx, netty4HttpRequest); } finally { 
activityTracker.stopActivity(); } diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequest.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequest.java index fa4babea21555..b04da46a2d7d7 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequest.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequest.java @@ -12,6 +12,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.EmptyHttpHeaders; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaders; @@ -21,6 +22,7 @@ import io.netty.handler.codec.http.cookie.ServerCookieEncoder; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.http.HttpBody; import org.elasticsearch.http.HttpRequest; import org.elasticsearch.http.HttpResponse; import org.elasticsearch.rest.ChunkedRestResponseBodyPart; @@ -40,22 +42,40 @@ public class Netty4HttpRequest implements HttpRequest { private final FullHttpRequest request; - private final BytesReference content; + private final HttpBody content; private final Map> headers; private final AtomicBoolean released; private final Exception inboundException; private final boolean pooled; private final int sequence; + Netty4HttpRequest(int sequence, io.netty.handler.codec.http.HttpRequest request, Netty4HttpRequestBodyStream contentStream) { + this( + sequence, + new DefaultFullHttpRequest( + request.protocolVersion(), + request.method(), + request.uri(), + Unpooled.EMPTY_BUFFER, + request.headers(), + EmptyHttpHeaders.INSTANCE + ), + new AtomicBoolean(false), + true, + contentStream, + null + ); + } + Netty4HttpRequest(int sequence, FullHttpRequest request) { - this(sequence, request, new AtomicBoolean(false), true, Netty4Utils.toBytesReference(request.content())); + this(sequence, request, new AtomicBoolean(false), true, Netty4Utils.fullHttpBodyFrom(request.content())); } Netty4HttpRequest(int sequence, FullHttpRequest request, Exception inboundException) { - this(sequence, request, new AtomicBoolean(false), true, Netty4Utils.toBytesReference(request.content()), inboundException); + this(sequence, request, new AtomicBoolean(false), true, Netty4Utils.fullHttpBodyFrom(request.content()), inboundException); } - private Netty4HttpRequest(int sequence, FullHttpRequest request, AtomicBoolean released, boolean pooled, BytesReference content) { + private Netty4HttpRequest(int sequence, FullHttpRequest request, AtomicBoolean released, boolean pooled, HttpBody content) { this(sequence, request, released, pooled, content, null); } @@ -64,7 +84,7 @@ private Netty4HttpRequest( FullHttpRequest request, AtomicBoolean released, boolean pooled, - BytesReference content, + HttpBody content, Exception inboundException ) { this.sequence = sequence; @@ -87,7 +107,7 @@ public String uri() { } @Override - public BytesReference content() { + public HttpBody body() { assert released.get() == false; return content; } @@ -96,6 +116,7 @@ public BytesReference content() { public void release() { if (pooled && released.compareAndSet(false, true)) { request.release(); + content.close(); } } @@ -107,6 +128,12 @@ public HttpRequest releaseAndCopy() { } try { final ByteBuf copiedContent = Unpooled.copiedBuffer(request.content()); + HttpBody newContent; + if 
(content.isStream()) { + newContent = content; + } else { + newContent = Netty4Utils.fullHttpBodyFrom(copiedContent); + } return new Netty4HttpRequest( sequence, new DefaultFullHttpRequest( @@ -119,7 +146,7 @@ public HttpRequest releaseAndCopy() { ), new AtomicBoolean(false), false, - Netty4Utils.toBytesReference(copiedContent) + newContent ); } finally { release(); diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequestBodyStream.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequestBodyStream.java new file mode 100644 index 0000000000000..96f7deea978d9 --- /dev/null +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequestBodyStream.java @@ -0,0 +1,161 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.http.netty4; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFutureListener; +import io.netty.handler.codec.http.HttpContent; +import io.netty.handler.codec.http.LastHttpContent; + +import org.elasticsearch.core.Releasables; +import org.elasticsearch.http.HttpBody; +import org.elasticsearch.transport.netty4.Netty4Utils; + +import java.util.ArrayList; +import java.util.List; + +/** + * Netty based implementation of {@link HttpBody.Stream}. + * This implementation utilizes {@link io.netty.channel.ChannelConfig#setAutoRead(boolean)} + * to prevent buffering of the entire payload. But sometimes the upstream can send a few chunks of data despite + * autoRead=off.
In this case chunks will be buffered until downstream calls {@link Stream#next()} + */ +public class Netty4HttpRequestBodyStream implements HttpBody.Stream { + + private final Channel channel; + private final ChannelFutureListener closeListener = future -> doClose(); + private final List tracingHandlers = new ArrayList<>(4); + private ByteBuf buf; + private boolean hasLast = false; + private boolean requested = false; + private boolean closing = false; + private HttpBody.ChunkHandler handler; + + public Netty4HttpRequestBodyStream(Channel channel) { + this.channel = channel; + Netty4Utils.addListener(channel.closeFuture(), closeListener); + channel.config().setAutoRead(false); + } + + @Override + public ChunkHandler handler() { + return handler; + } + + @Override + public void setHandler(ChunkHandler chunkHandler) { + this.handler = chunkHandler; + } + + @Override + public void addTracingHandler(ChunkHandler chunkHandler) { + assert tracingHandlers.contains(chunkHandler) == false; + tracingHandlers.add(chunkHandler); + } + + @Override + public void next() { + assert closing == false : "cannot request next chunk on closing stream"; + assert handler != null : "handler must be set before requesting next chunk"; + channel.eventLoop().submit(() -> { + requested = true; + if (buf == null) { + channel.read(); + } else { + send(); + } + }); + } + + public void handleNettyContent(HttpContent httpContent) { + assert hasLast == false : "receive http content on completed stream"; + hasLast = httpContent instanceof LastHttpContent; + if (closing) { + httpContent.release(); + } else { + addChunk(httpContent.content()); + if (requested) { + send(); + } + } + if (hasLast) { + channel.config().setAutoRead(true); + channel.closeFuture().removeListener(closeListener); + } + } + + // adds chunk to current buffer, will allocate composite buffer when need to hold more than 1 chunk + private void addChunk(ByteBuf chunk) { + assert chunk != null; + if (buf == null) { + buf = chunk; + } else if (buf instanceof CompositeByteBuf comp) { + comp.addComponent(true, chunk); + } else { + var comp = channel.alloc().compositeBuffer(); + comp.addComponent(true, buf); + comp.addComponent(true, chunk); + buf = comp; + } + } + + // visible for test + Channel channel() { + return channel; + } + + // visible for test + ByteBuf buf() { + return buf; + } + + // visible for test + boolean hasLast() { + return hasLast; + } + + private void send() { + assert requested; + assert handler != null : "must set handler before receiving next chunk"; + var bytesRef = Netty4Utils.toReleasableBytesReference(buf); + requested = false; + buf = null; + for (var tracer : tracingHandlers) { + tracer.onNext(bytesRef, hasLast); + } + handler.onNext(bytesRef, hasLast); + } + + @Override + public void close() { + if (channel.eventLoop().inEventLoop()) { + doClose(); + } else { + channel.eventLoop().submit(this::doClose); + } + } + + private void doClose() { + closing = true; + for (var tracer : tracingHandlers) { + Releasables.closeExpectNoException(tracer); + } + if (handler != null) { + handler.close(); + } + if (buf != null) { + buf.release(); + buf = null; + } + channel.config().setAutoRead(true); + } +} diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java index f38bd1107ab33..c6e7fa3517771 100644 --- 
a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java @@ -37,6 +37,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.action.bulk.IncrementalBulkService; import org.elasticsearch.common.network.CloseableChannel; import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.network.ThreadWatchdog; @@ -97,6 +98,7 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport { private final TLSConfig tlsConfig; private final AcceptChannelHandler.AcceptPredicate acceptChannelPredicate; private final HttpValidator httpValidator; + private final IncrementalBulkService.Enabled enabled; private final ThreadWatchdog threadWatchdog; private final int readTimeoutMillis; @@ -135,6 +137,7 @@ public Netty4HttpServerTransport( this.acceptChannelPredicate = acceptChannelPredicate; this.httpValidator = httpValidator; this.threadWatchdog = networkService.getThreadWatchdog(); + this.enabled = new IncrementalBulkService.Enabled(clusterSettings); this.pipeliningMaxEvents = SETTING_PIPELINING_MAX_EVENTS.get(settings); @@ -280,7 +283,7 @@ public void onException(HttpChannel channel, Exception cause) { } public ChannelHandler configureServerChannelHandler() { - return new HttpChannelHandler(this, handlingSettings, tlsConfig, acceptChannelPredicate, httpValidator); + return new HttpChannelHandler(this, handlingSettings, tlsConfig, acceptChannelPredicate, httpValidator, enabled); } static final AttributeKey HTTP_CHANNEL_KEY = AttributeKey.newInstance("es-http-channel"); @@ -293,19 +296,22 @@ protected static class HttpChannelHandler extends ChannelInitializer { private final TLSConfig tlsConfig; private final BiPredicate acceptChannelPredicate; private final HttpValidator httpValidator; + private final IncrementalBulkService.Enabled enabled; protected HttpChannelHandler( final Netty4HttpServerTransport transport, final HttpHandlingSettings handlingSettings, final TLSConfig tlsConfig, @Nullable final BiPredicate acceptChannelPredicate, - @Nullable final HttpValidator httpValidator + @Nullable final HttpValidator httpValidator, + IncrementalBulkService.Enabled enabled ) { this.transport = transport; this.handlingSettings = handlingSettings; this.tlsConfig = tlsConfig; this.acceptChannelPredicate = acceptChannelPredicate; this.httpValidator = httpValidator; + this.enabled = enabled; } @Override @@ -366,7 +372,13 @@ protected HttpMessage createMessage(String[] initialLine) throws Exception { ); } // combines the HTTP message pieces into a single full HTTP request (with headers and body) - final HttpObjectAggregator aggregator = new HttpObjectAggregator(handlingSettings.maxContentLength()); + final HttpObjectAggregator aggregator = new Netty4HttpAggregator( + handlingSettings.maxContentLength(), + httpPreRequest -> enabled.get() == false + || (httpPreRequest.uri().contains("_bulk") == false + || httpPreRequest.uri().contains("_bulk_update") + || httpPreRequest.uri().contains("/_xpack/monitoring/_bulk")) + ); aggregator.setMaxCumulationBufferComponents(transport.maxCompositeBufferComponents); ch.pipeline() .addLast("decoder_compress", new HttpContentDecompressor()) // this handles request body decompression diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Utils.java 
b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Utils.java index b3596c75999ec..f57aa0e680fa1 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Utils.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Utils.java @@ -27,11 +27,13 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.bytes.ReleasableBytesReference; import org.elasticsearch.common.recycler.Recycler; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.core.Booleans; import org.elasticsearch.core.SuppressForbidden; +import org.elasticsearch.http.HttpBody; import org.elasticsearch.transport.TransportException; import java.io.IOException; @@ -128,6 +130,14 @@ public static BytesReference toBytesReference(final ByteBuf buffer) { } } + public static ReleasableBytesReference toReleasableBytesReference(final ByteBuf buffer) { + return new ReleasableBytesReference(toBytesReference(buffer), buffer::release); + } + + public static HttpBody.Full fullHttpBodyFrom(final ByteBuf buf) { + return new HttpBody.ByteRefHttpBody(toBytesReference(buf)); + } + public static Recycler createRecycler(Settings settings) { // If this method is called by super ctor the processors will not be set. Accessing NettyAllocator initializes netty's internals // setting the processors. We must do it ourselves first just in case. diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpHeaderValidatorTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpHeaderValidatorTests.java index c2d52ac761034..1c0b434105f28 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpHeaderValidatorTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpHeaderValidatorTests.java @@ -117,6 +117,36 @@ public void testValidationPausesAndResumesData() { assertThat(netty4HttpHeaderValidator.getState(), equalTo(QUEUEING_DATA)); } + public void testValidatorDoesNotTweakAutoReadAfterValidationComplete() { + assertTrue(channel.config().isAutoRead()); + assertThat(netty4HttpHeaderValidator.getState(), equalTo(WAITING_TO_START)); + + final DefaultHttpRequest request = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/uri"); + DefaultHttpContent content = new DefaultHttpContent(Unpooled.buffer(4)); + channel.writeInbound(request); + channel.writeInbound(content); + + assertThat(header.get(), sameInstance(request)); + // channel is paused + assertThat(channel.readInbound(), nullValue()); + assertFalse(channel.config().isAutoRead()); + + // channel is resumed + listener.get().onResponse(null); + channel.runPendingTasks(); + + assertTrue(channel.config().isAutoRead()); + assertThat(netty4HttpHeaderValidator.getState(), equalTo(FORWARDING_DATA_UNTIL_NEXT_REQUEST)); + assertThat(channel.readInbound(), sameInstance(request)); + assertThat(channel.readInbound(), sameInstance(content)); + assertThat(channel.readInbound(), nullValue()); + assertThat(content.refCnt(), equalTo(1)); + channel.config().setAutoRead(false); + + channel.writeOutbound(new DefaultHttpContent(Unpooled.buffer(4))); + assertFalse(channel.config().isAutoRead()); + } + public void testContentForwardedAfterValidation() { 
assertTrue(channel.config().isAutoRead()); assertThat(netty4HttpHeaderValidator.getState(), equalTo(WAITING_TO_START)); diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpRequestBodyStreamTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpRequestBodyStreamTests.java new file mode 100644 index 0000000000000..0f35de483dc82 --- /dev/null +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpRequestBodyStreamTests.java @@ -0,0 +1,122 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.http.netty4; + +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.DefaultHttpContent; +import io.netty.handler.codec.http.DefaultLastHttpContent; +import io.netty.handler.codec.http.HttpContent; +import io.netty.handler.flow.FlowControlHandler; + +import org.elasticsearch.common.bytes.ReleasableBytesReference; +import org.elasticsearch.http.HttpBody; +import org.elasticsearch.test.ESTestCase; +import org.junit.Before; + +import java.util.ArrayList; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +public class Netty4HttpRequestBodyStreamTests extends ESTestCase { + + EmbeddedChannel channel; + Netty4HttpRequestBodyStream stream; + static HttpBody.ChunkHandler discardHandler = (chunk, isLast) -> chunk.close(); + + @Before + public void createStream() { + channel = new EmbeddedChannel(); + stream = new Netty4HttpRequestBodyStream(channel); + stream.setHandler(discardHandler); // set default handler, each test might override one + channel.pipeline().addLast(new SimpleChannelInboundHandler() { + @Override + protected void channelRead0(ChannelHandlerContext ctx, HttpContent msg) { + msg.retain(); + stream.handleNettyContent(msg); + } + }); + } + + // ensures that no chunks are sent downstream without request + public void testEnqueueChunksBeforeRequest() { + var totalChunks = randomIntBetween(1, 100); + for (int i = 0; i < totalChunks; i++) { + channel.writeInbound(randomContent(1024)); + } + assertEquals(totalChunks * 1024, stream.buf().readableBytes()); + } + + // ensures all received chunks can be flushed downstream + public void testFlushAllReceivedChunks() { + var chunks = new ArrayList(); + var totalBytes = new AtomicInteger(); + stream.setHandler((chunk, isLast) -> { + chunks.add(chunk); + totalBytes.addAndGet(chunk.length()); + }); + + var chunkSize = 1024; + var totalChunks = randomIntBetween(1, 100); + for (int i = 0; i < totalChunks; i++) { + channel.writeInbound(randomContent(chunkSize)); + } + stream.next(); + channel.runPendingTasks(); + assertEquals("should receive all chunks as single composite", 1, chunks.size()); + assertEquals(chunkSize * totalChunks, totalBytes.get()); + } + + // ensures that we read from channel when no current chunks available + // and pass next chunk downstream without holding + public void 
testReadFromChannel() { + var gotChunks = new ArrayList(); + var gotLast = new AtomicBoolean(false); + stream.setHandler((chunk, isLast) -> { + gotChunks.add(chunk); + gotLast.set(isLast); + }); + channel.pipeline().addFirst(new FlowControlHandler()); // block all incoming messages, need explicit channel.read() + var chunkSize = 1024; + var totalChunks = randomIntBetween(1, 32); + for (int i = 0; i < totalChunks - 1; i++) { + channel.writeInbound(randomContent(chunkSize)); + } + channel.writeInbound(randomLastContent(chunkSize)); + + for (int i = 0; i < totalChunks; i++) { + assertNull("should not enqueue chunks", stream.buf()); + stream.next(); + channel.runPendingTasks(); + assertEquals("each next() should produce single chunk", i + 1, gotChunks.size()); + } + assertTrue("should receive last content", gotLast.get()); + } + + HttpContent randomContent(int size, boolean isLast) { + var buf = Unpooled.wrappedBuffer(randomByteArrayOfLength(size)); + if (isLast) { + return new DefaultLastHttpContent(buf); + } else { + return new DefaultHttpContent(buf); + } + } + + HttpContent randomContent(int size) { + return randomContent(size, false); + } + + HttpContent randomLastContent(int size) { + return randomContent(size, true); + } + +} diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java index 18df654dd435f..3fd5cc44a3403 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java @@ -46,6 +46,7 @@ import org.elasticsearch.ElasticsearchSecurityException; import org.elasticsearch.ElasticsearchWrapperException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.bulk.IncrementalBulkService; import org.elasticsearch.action.support.ActionTestUtils; import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.client.Request; @@ -419,7 +420,8 @@ public ChannelHandler configureServerChannelHandler() { handlingSettings, TLSConfig.noTLS(), null, - randomFrom((httpPreRequest, channel, listener) -> listener.onResponse(null), null) + randomFrom((httpPreRequest, channel, listener) -> listener.onResponse(null), null), + new IncrementalBulkService.Enabled(clusterSettings) ) { @Override protected void initChannel(Channel ch) throws Exception { @@ -905,7 +907,7 @@ public void dispatchBadRequest(final RestChannel channel, final ThreadContext th assertThat(channel.request().getHttpRequest().header(headerReference.get()), is(headerValueReference.get())); assertThat(channel.request().getHttpRequest().method(), is(translateRequestMethod(httpMethodReference.get()))); // assert content is dropped - assertThat(channel.request().getHttpRequest().content().utf8ToString(), is("")); + assertThat(channel.request().getHttpRequest().body().asFull().bytes().utf8ToString(), is("")); try { channel.sendResponse(new RestResponse(channel, (Exception) ((ElasticsearchWrapperException) cause).getCause())); } catch (IOException e) { diff --git a/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/IncrementalBulkRestIT.java b/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/IncrementalBulkRestIT.java new file mode 100644 index 0000000000000..08026e0435f33 --- /dev/null +++ 
b/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/IncrementalBulkRestIT.java @@ -0,0 +1,187 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.http; + +import org.elasticsearch.action.bulk.IncrementalBulkService; +import org.elasticsearch.client.Request; +import org.elasticsearch.client.Response; +import org.elasticsearch.client.ResponseException; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.xcontent.json.JsonXContent; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.rest.RestStatus.OK; +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.Matchers.equalTo; + +@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.SUITE, supportsDedicatedMasters = false, numDataNodes = 2, numClientNodes = 0) +public class IncrementalBulkRestIT extends HttpSmokeTestCase { + + public void testBulkMissingBody() throws IOException { + Request request = new Request(randomBoolean() ? "POST" : "PUT", "/_bulk"); + request.setJsonEntity(""); + ResponseException responseException = expectThrows(ResponseException.class, () -> getRestClient().performRequest(request)); + assertEquals(400, responseException.getResponse().getStatusLine().getStatusCode()); + assertThat(responseException.getMessage(), containsString("request body is required")); + } + + public void testBulkRequestBodyImproperlyTerminated() throws IOException { + Request request = new Request(randomBoolean() ? "POST" : "PUT", "/_bulk"); + // missing final line of the bulk body. 
cannot process + request.setJsonEntity( + "{\"index\":{\"_index\":\"index_name\",\"_id\":\"1\"}}\n" + + "{\"field\":1}\n" + + "{\"index\":{\"_index\":\"index_name\",\"_id\":\"2\"}" + ); + ResponseException responseException = expectThrows(ResponseException.class, () -> getRestClient().performRequest(request)); + assertEquals(400, responseException.getResponse().getStatusLine().getStatusCode()); + assertThat(responseException.getMessage(), containsString("could not parse bulk request body")); + } + + public void testIncrementalBulk() throws IOException { + Request createRequest = new Request("PUT", "/index_name"); + createRequest.setJsonEntity(""" + { + "settings": { + "index": { + "number_of_shards": 1, + "number_of_replicas": 1, + "write.wait_for_active_shards": 2 + } + } + }"""); + final Response indexCreatedResponse = getRestClient().performRequest(createRequest); + assertThat(indexCreatedResponse.getStatusLine().getStatusCode(), equalTo(OK.getStatus())); + + Request firstBulkRequest = new Request("POST", "/index_name/_bulk"); + + // index documents for the rollup job + String bulkBody = "{\"index\":{\"_index\":\"index_name\",\"_id\":\"1\"}}\n" + + "{\"field\":1}\n" + + "{\"index\":{\"_index\":\"index_name\",\"_id\":\"2\"}}\n" + + "{\"field\":1}\n" + + "\r\n"; + + firstBulkRequest.setJsonEntity(bulkBody); + + final Response indexSuccessFul = getRestClient().performRequest(firstBulkRequest); + assertThat(indexSuccessFul.getStatusLine().getStatusCode(), equalTo(OK.getStatus())); + + sendLargeBulk(); + } + + public void testBulkWithIncrementalDisabled() throws IOException { + Request createRequest = new Request("PUT", "/index_name"); + createRequest.setJsonEntity(""" + { + "settings": { + "index": { + "number_of_shards": 1, + "number_of_replicas": 1, + "write.wait_for_active_shards": 2 + } + } + }"""); + final Response indexCreatedResponse = getRestClient().performRequest(createRequest); + assertThat(indexCreatedResponse.getStatusLine().getStatusCode(), equalTo(OK.getStatus())); + + Request firstBulkRequest = new Request("POST", "/index_name/_bulk"); + + // index documents for the rollup job + String bulkBody = "{\"index\":{\"_index\":\"index_name\",\"_id\":\"1\"}}\n" + + "{\"field\":1}\n" + + "{\"index\":{\"_index\":\"index_name\",\"_id\":\"2\"}}\n" + + "{\"field\":1}\n" + + "\r\n"; + + firstBulkRequest.setJsonEntity(bulkBody); + + final Response indexSuccessFul = getRestClient().performRequest(firstBulkRequest); + assertThat(indexSuccessFul.getStatusLine().getStatusCode(), equalTo(OK.getStatus())); + + updateClusterSettings(Settings.builder().put(IncrementalBulkService.INCREMENTAL_BULK.getKey(), false)); + + internalCluster().getInstances(IncrementalBulkService.class).forEach(i -> i.setForTests(false)); + + try { + sendLargeBulk(); + } finally { + internalCluster().getInstances(IncrementalBulkService.class).forEach(i -> i.setForTests(true)); + updateClusterSettings(Settings.builder().put(IncrementalBulkService.INCREMENTAL_BULK.getKey(), (String) null)); + } + } + + public void testIncrementalMalformed() throws IOException { + Request createRequest = new Request("PUT", "/index_name"); + createRequest.setJsonEntity(""" + { + "settings": { + "index": { + "number_of_shards": 1, + "number_of_replicas": 1, + "write.wait_for_active_shards": 2 + } + } + }"""); + final Response indexCreatedResponse = getRestClient().performRequest(createRequest); + assertThat(indexCreatedResponse.getStatusLine().getStatusCode(), equalTo(OK.getStatus())); + + Request bulkRequest = new Request("POST", 
"/index_name/_bulk"); + + // index documents for the rollup job + final StringBuilder bulk = new StringBuilder(); + bulk.append("{\"index\":{\"_index\":\"index_name\"}}\n"); + bulk.append("{\"field\":1}\n"); + bulk.append("{}\n"); + bulk.append("\r\n"); + + bulkRequest.setJsonEntity(bulk.toString()); + + expectThrows(ResponseException.class, () -> getRestClient().performRequest(bulkRequest)); + } + + @SuppressWarnings("unchecked") + private static void sendLargeBulk() throws IOException { + Request bulkRequest = new Request("POST", "/index_name/_bulk"); + + // index documents for the rollup job + final StringBuilder bulk = new StringBuilder(); + bulk.append("{\"delete\":{\"_index\":\"index_name\",\"_id\":\"1\"}}\n"); + int updates = 0; + for (int i = 0; i < 1000; i++) { + bulk.append("{\"index\":{\"_index\":\"index_name\"}}\n"); + bulk.append("{\"field\":").append(i).append("}\n"); + if (randomBoolean() && randomBoolean() && randomBoolean() && randomBoolean()) { + ++updates; + bulk.append("{\"update\":{\"_index\":\"index_name\",\"_id\":\"2\"}}\n"); + bulk.append("{\"doc\":{\"field\":").append(i).append("}}\n"); + } + } + bulk.append("\r\n"); + + bulkRequest.setJsonEntity(bulk.toString()); + + final Response bulkResponse = getRestClient().performRequest(bulkRequest); + assertThat(bulkResponse.getStatusLine().getStatusCode(), equalTo(OK.getStatus())); + Map responseMap = XContentHelper.convertToMap( + JsonXContent.jsonXContent, + bulkResponse.getEntity().getContent(), + true + ); + + assertFalse((Boolean) responseMap.get("errors")); + assertThat(((List) responseMap.get("items")).size(), equalTo(1001 + updates)); + } +} diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/bulk/IncrementalBulkIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/bulk/IncrementalBulkIT.java new file mode 100644 index 0000000000000..d7a5d4e2ac973 --- /dev/null +++ b/server/src/internalClusterTest/java/org/elasticsearch/action/bulk/IncrementalBulkIT.java @@ -0,0 +1,532 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + + package org.elasticsearch.action.bulk; + + import org.elasticsearch.action.DocWriteRequest; + import org.elasticsearch.action.index.IndexRequest; + import org.elasticsearch.action.support.PlainActionFuture; + import org.elasticsearch.cluster.metadata.IndexMetadata; + import org.elasticsearch.common.bytes.BytesReference; + import org.elasticsearch.common.settings.Settings; + import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; + import org.elasticsearch.core.AbstractRefCounted; + import org.elasticsearch.core.Releasable; + import org.elasticsearch.index.Index; + import org.elasticsearch.index.IndexService; + import org.elasticsearch.index.IndexingPressure; + import org.elasticsearch.index.query.QueryBuilders; + import org.elasticsearch.index.shard.IndexShard; + import org.elasticsearch.index.shard.ShardId; + import org.elasticsearch.indices.IndicesService; + import org.elasticsearch.ingest.IngestClientIT; + import org.elasticsearch.plugins.Plugin; + import org.elasticsearch.test.ESIntegTestCase; + import org.elasticsearch.threadpool.ThreadPool; + + import java.util.ArrayList; + import java.util.Arrays; + import java.util.Collection; + import java.util.List; + import java.util.Map; + import java.util.concurrent.ConcurrentLinkedQueue; + import java.util.concurrent.CountDownLatch; + import java.util.concurrent.ExecutorService; + import java.util.concurrent.Executors; + import java.util.concurrent.atomic.AtomicBoolean; + import java.util.concurrent.atomic.AtomicLong; + + import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailures; + import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse; + import static org.elasticsearch.xcontent.XContentFactory.jsonBuilder; + import static org.hamcrest.Matchers.equalTo; + import static org.hamcrest.Matchers.greaterThan; + import static org.hamcrest.Matchers.instanceOf; + + public class IncrementalBulkIT extends ESIntegTestCase { + + @Override + protected Collection<Class<? extends Plugin>> nodePlugins() { + return List.of(IngestClientIT.ExtendedIngestTestPlugin.class); + } + + @Override + protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { + return Settings.builder() + .put(super.nodeSettings(nodeOrdinal, otherSettings)) + .put(IndexingPressure.SPLIT_BULK_THRESHOLD.getKey(), "512B") + .build(); + } + + public void testSingleBulkRequest() { + String index = "test"; + createIndex(index); + + IncrementalBulkService incrementalBulkService = internalCluster().getInstance(IncrementalBulkService.class); + + IncrementalBulkService.Handler handler = incrementalBulkService.newBulkRequest(); + IndexRequest indexRequest = indexRequest(index); + + PlainActionFuture<BulkResponse> future = new PlainActionFuture<>(); + AbstractRefCounted refCounted = AbstractRefCounted.of(() -> {}); + handler.lastItems(List.of(indexRequest), refCounted::decRef, future); + + BulkResponse bulkResponse = future.actionGet(); + assertNoFailures(bulkResponse); + + refresh(index); + + assertResponse(prepareSearch(index).setQuery(QueryBuilders.matchAllQuery()), searchResponse -> { + assertNoFailures(searchResponse); + assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) 1)); + }); + + assertFalse(refCounted.hasReferences()); + } + + public void testBufferedResourcesReleasedOnClose() { + String index = "test"; + createIndex(index); + + String nodeName = internalCluster().getRandomNodeName(); + IncrementalBulkService incrementalBulkService = internalCluster().getInstance(IncrementalBulkService.class, nodeName); + IndexingPressure indexingPressure 
= internalCluster().getInstance(IndexingPressure.class, nodeName); + + IncrementalBulkService.Handler handler = incrementalBulkService.newBulkRequest(); + IndexRequest indexRequest = indexRequest(index); + + AbstractRefCounted refCounted = AbstractRefCounted.of(() -> {}); + handler.addItems(List.of(indexRequest), refCounted::decRef, () -> {}); + + assertTrue(refCounted.hasReferences()); + assertThat(indexingPressure.stats().getCurrentCoordinatingBytes(), greaterThan(0L)); + + handler.close(); + + assertFalse(refCounted.hasReferences()); + assertThat(indexingPressure.stats().getCurrentCoordinatingBytes(), equalTo(0L)); + } + + public void testIndexingPressureRejection() { + String index = "test"; + createIndex(index); + + String nodeName = internalCluster().getRandomNodeName(); + IncrementalBulkService incrementalBulkService = internalCluster().getInstance(IncrementalBulkService.class, nodeName); + IndexingPressure indexingPressure = internalCluster().getInstance(IndexingPressure.class, nodeName); + + try (Releasable r = indexingPressure.markCoordinatingOperationStarted(1, indexingPressure.stats().getMemoryLimit(), true)) { + IncrementalBulkService.Handler handler = incrementalBulkService.newBulkRequest(); + AbstractRefCounted refCounted = AbstractRefCounted.of(() -> {}); + + if (randomBoolean()) { + AtomicBoolean nextPage = new AtomicBoolean(false); + refCounted.incRef(); + handler.addItems(List.of(indexRequest(index)), refCounted::decRef, () -> nextPage.set(true)); + assertTrue(nextPage.get()); + } + + PlainActionFuture future = new PlainActionFuture<>(); + handler.lastItems(List.of(indexRequest(index)), refCounted::decRef, future); + + expectThrows(EsRejectedExecutionException.class, future::actionGet); + assertFalse(refCounted.hasReferences()); + } + } + + public void testIncrementalBulkRequestMemoryBackOff() throws Exception { + String index = "test"; + createIndex(index); + + String nodeName = internalCluster().getRandomNodeName(); + IncrementalBulkService incrementalBulkService = internalCluster().getInstance(IncrementalBulkService.class, nodeName); + IndexingPressure indexingPressure = internalCluster().getInstance(IndexingPressure.class, nodeName); + + IncrementalBulkService.Handler handler = incrementalBulkService.newBulkRequest(); + + AbstractRefCounted refCounted = AbstractRefCounted.of(() -> {}); + AtomicBoolean nextPage = new AtomicBoolean(false); + + IndexRequest indexRequest = indexRequest(index); + long total = indexRequest.ramBytesUsed(); + while (total < 512) { + refCounted.incRef(); + handler.addItems(List.of(indexRequest), refCounted::decRef, () -> nextPage.set(true)); + assertTrue(nextPage.get()); + nextPage.set(false); + indexRequest = indexRequest(index); + total += indexRequest.ramBytesUsed(); + } + + assertThat(indexingPressure.stats().getCurrentCombinedCoordinatingAndPrimaryBytes(), greaterThan(0L)); + refCounted.incRef(); + handler.addItems(List.of(indexRequest(index)), refCounted::decRef, () -> nextPage.set(true)); + + assertBusy(() -> assertThat(indexingPressure.stats().getCurrentCombinedCoordinatingAndPrimaryBytes(), equalTo(0L))); + + PlainActionFuture future = new PlainActionFuture<>(); + handler.lastItems(List.of(indexRequest), refCounted::decRef, future); + + BulkResponse bulkResponse = future.actionGet(); + assertNoFailures(bulkResponse); + assertFalse(refCounted.hasReferences()); + } + + public void testMultipleBulkPartsWithBackoff() { + ExecutorService executorService = Executors.newFixedThreadPool(1); + + try (Releasable ignored = 
executorService::shutdown;) { + String index = "test"; + createIndex(index); + + IncrementalBulkService incrementalBulkService = internalCluster().getInstance(IncrementalBulkService.class); + long docs = randomIntBetween(200, 400); + + IncrementalBulkService.Handler handler = incrementalBulkService.newBulkRequest(); + + BulkResponse bulkResponse = executeBulk(docs, index, handler, executorService); + assertNoFailures(bulkResponse); + + refresh(index); + + assertResponse(prepareSearch(index).setQuery(QueryBuilders.matchAllQuery()), searchResponse -> { + assertNoFailures(searchResponse); + assertThat(searchResponse.getHits().getTotalHits().value, equalTo(docs)); + }); + } + } + + public void testGlobalBulkFailure() throws InterruptedException { + ExecutorService executorService = Executors.newFixedThreadPool(1); + CountDownLatch blockingLatch = new CountDownLatch(1); + + try (Releasable ignored = executorService::shutdown; Releasable ignored2 = blockingLatch::countDown) { + String index = "test"; + createIndex(index); + + String randomNodeName = internalCluster().getRandomNodeName(); + IncrementalBulkService incrementalBulkService = internalCluster().getInstance(IncrementalBulkService.class, randomNodeName); + ThreadPool threadPool = internalCluster().getInstance(ThreadPool.class, randomNodeName); + + int threadCount = threadPool.info(ThreadPool.Names.WRITE).getMax(); + long queueSize = threadPool.info(ThreadPool.Names.WRITE).getQueueSize().singles(); + blockWritePool(threadCount, threadPool, blockingLatch); + + Runnable runnable = () -> {}; + for (int i = 0; i < queueSize; i++) { + threadPool.executor(ThreadPool.Names.WRITE).execute(runnable); + } + + IncrementalBulkService.Handler handler = incrementalBulkService.newBulkRequest(); + if (randomBoolean()) { + expectThrows( + EsRejectedExecutionException.class, + () -> executeBulk(randomIntBetween(200, 400), index, handler, executorService) + ); + } else { + PlainActionFuture future = new PlainActionFuture<>(); + AbstractRefCounted refCounted = AbstractRefCounted.of(() -> {}); + handler.lastItems(List.of(indexRequest(index)), refCounted::decRef, future); + assertFalse(refCounted.hasReferences()); + expectThrows(EsRejectedExecutionException.class, future::actionGet); + } + } + } + + public void testBulkLevelBulkFailureAfterFirstIncrementalRequest() throws Exception { + ExecutorService executorService = Executors.newFixedThreadPool(1); + + try (Releasable ignored = executorService::shutdown) { + String index = "test"; + createIndex(index); + + String randomNodeName = internalCluster().getRandomNodeName(); + IncrementalBulkService incrementalBulkService = internalCluster().getInstance(IncrementalBulkService.class, randomNodeName); + ThreadPool threadPool = internalCluster().getInstance(ThreadPool.class, randomNodeName); + IncrementalBulkService.Handler handler = incrementalBulkService.newBulkRequest(); + AbstractRefCounted refCounted = AbstractRefCounted.of(() -> {}); + PlainActionFuture future = new PlainActionFuture<>(); + + int threadCount = threadPool.info(ThreadPool.Names.WRITE).getMax(); + long queueSize = threadPool.info(ThreadPool.Names.WRITE).getQueueSize().singles(); + + CountDownLatch blockingLatch1 = new CountDownLatch(1); + + AtomicBoolean nextRequested = new AtomicBoolean(true); + AtomicLong hits = new AtomicLong(0); + try (Releasable ignored2 = blockingLatch1::countDown;) { + blockWritePool(threadCount, threadPool, blockingLatch1); + while (nextRequested.get()) { + nextRequested.set(false); + refCounted.incRef(); + 
handler.addItems(List.of(indexRequest(index)), refCounted::decRef, () -> nextRequested.set(true)); + hits.incrementAndGet(); + } + } + assertBusy(() -> assertTrue(nextRequested.get())); + + CountDownLatch blockingLatch2 = new CountDownLatch(1); + + try (Releasable ignored3 = blockingLatch2::countDown;) { + blockWritePool(threadCount, threadPool, blockingLatch2); + Runnable runnable = () -> {}; + // Fill Queue + for (int i = 0; i < queueSize; i++) { + threadPool.executor(ThreadPool.Names.WRITE).execute(runnable); + } + + handler.lastItems(List.of(indexRequest(index)), refCounted::decRef, future); + } + + // Should not throw because some succeeded + BulkResponse bulkResponse = future.actionGet(); + + assertTrue(bulkResponse.hasFailures()); + BulkItemResponse[] items = bulkResponse.getItems(); + assertThat(Arrays.stream(items).filter(r -> r.isFailed() == false).count(), equalTo(hits.get())); + assertThat(items[items.length - 1].getFailure().getCause(), instanceOf(EsRejectedExecutionException.class)); + + refresh(index); + + assertResponse(prepareSearch(index).setQuery(QueryBuilders.matchAllQuery()), searchResponse -> { + assertNoFailures(searchResponse); + assertThat(searchResponse.getHits().getTotalHits().value, equalTo(hits.get())); + }); + } + } + + public void testShortCircuitShardLevelFailure() throws Exception { + String index = "test"; + createIndex(index, 2, 0); + + String coordinatingOnlyNode = internalCluster().startCoordinatingOnlyNode(Settings.EMPTY); + + AbstractRefCounted refCounted = AbstractRefCounted.of(() -> {}); + IncrementalBulkService incrementalBulkService = internalCluster().getInstance(IncrementalBulkService.class, coordinatingOnlyNode); + IncrementalBulkService.Handler handler = incrementalBulkService.newBulkRequest(); + + AtomicBoolean nextRequested = new AtomicBoolean(true); + AtomicLong hits = new AtomicLong(0); + while (nextRequested.get()) { + nextRequested.set(false); + refCounted.incRef(); + handler.addItems(List.of(indexRequest(index)), refCounted::decRef, () -> nextRequested.set(true)); + hits.incrementAndGet(); + } + + assertBusy(() -> assertTrue(nextRequested.get())); + + String node = findShard(resolveIndex(index), 0); + String secondShardNode = findShard(resolveIndex(index), 1); + IndexingPressure primaryPressure = internalCluster().getInstance(IndexingPressure.class, node); + long memoryLimit = primaryPressure.stats().getMemoryLimit(); + long primaryRejections = primaryPressure.stats().getPrimaryRejections(); + try (Releasable releasable = primaryPressure.markPrimaryOperationStarted(10, memoryLimit, false)) { + while (primaryPressure.stats().getPrimaryRejections() == primaryRejections) { + while (nextRequested.get()) { + nextRequested.set(false); + refCounted.incRef(); + List> requests = new ArrayList<>(); + for (int i = 0; i < 20; ++i) { + requests.add(indexRequest(index)); + } + handler.addItems(requests, refCounted::decRef, () -> nextRequested.set(true)); + } + assertBusy(() -> assertTrue(nextRequested.get())); + } + } + + while (nextRequested.get()) { + nextRequested.set(false); + refCounted.incRef(); + handler.addItems(List.of(indexRequest(index)), refCounted::decRef, () -> nextRequested.set(true)); + } + + assertBusy(() -> assertTrue(nextRequested.get())); + + PlainActionFuture future = new PlainActionFuture<>(); + handler.lastItems(List.of(indexRequest(index)), refCounted::decRef, future); + + BulkResponse bulkResponse = future.actionGet(); + assertTrue(bulkResponse.hasFailures()); + for (int i = 0; i < hits.get(); ++i) { + 
assertFalse(bulkResponse.getItems()[i].isFailed()); + } + + boolean shardsOnDifferentNodes = node.equals(secondShardNode) == false; + for (int i = (int) hits.get(); i < bulkResponse.getItems().length; ++i) { + BulkItemResponse item = bulkResponse.getItems()[i]; + if (item.getResponse() != null && item.getResponse().getShardId().id() == 1 && shardsOnDifferentNodes) { + assertFalse(item.isFailed()); + } else { + assertTrue(item.isFailed()); + assertThat(item.getFailure().getCause().getCause(), instanceOf(EsRejectedExecutionException.class)); + } + } + } + + public void testShortCircuitShardLevelFailureWithIngestNodeHop() throws Exception { + String dataOnlyNode = internalCluster().startDataOnlyNode(); + String index = "test1"; + + // We ensure that the index is assigned to a non-ingest node to ensure that indexing pressure does not reject at the coordinating + // level. + createIndex( + index, + Settings.builder() + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1) + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + .put("index.routing.allocation.require._name", dataOnlyNode) + .build() + ); + + String pipelineId = "pipeline_id"; + BytesReference pipelineSource = BytesReference.bytes( + jsonBuilder().startObject() + .field("description", "my_pipeline") + .startArray("processors") + .startObject() + .startObject("test") + .endObject() + .endObject() + .endArray() + .endObject() + ); + + putJsonPipeline(pipelineId, pipelineSource); + + // By adding an ingest pipeline and sending the request to a coordinating node without the ingest role, we ensure that we are + // testing the serialization of shard level requests over the wire. This is because the transport bulk action will be dispatched to + // a node with the ingest role. + String coordinatingOnlyNode = internalCluster().startCoordinatingOnlyNode(Settings.EMPTY); + + AbstractRefCounted refCounted = AbstractRefCounted.of(() -> {}); + IncrementalBulkService incrementalBulkService = internalCluster().getInstance(IncrementalBulkService.class, coordinatingOnlyNode); + IncrementalBulkService.Handler handler = incrementalBulkService.newBulkRequest(); + + AtomicBoolean nextRequested = new AtomicBoolean(true); + AtomicLong hits = new AtomicLong(0); + while (nextRequested.get()) { + nextRequested.set(false); + refCounted.incRef(); + handler.addItems(List.of(indexRequest(index).setPipeline(pipelineId)), refCounted::decRef, () -> nextRequested.set(true)); + hits.incrementAndGet(); + } + + assertBusy(() -> assertTrue(nextRequested.get())); + + String node = findShard(resolveIndex(index), 0); + assertThat(node, equalTo(dataOnlyNode)); + IndexingPressure primaryPressure = internalCluster().getInstance(IndexingPressure.class, node); + long memoryLimit = primaryPressure.stats().getMemoryLimit(); + try (Releasable releasable = primaryPressure.markPrimaryOperationStarted(10, memoryLimit, false)) { + while (nextRequested.get()) { + nextRequested.set(false); + refCounted.incRef(); + handler.addItems(List.of(indexRequest(index).setPipeline(pipelineId)), refCounted::decRef, () -> nextRequested.set(true)); + } + + assertBusy(() -> assertTrue(nextRequested.get())); + } + + while (nextRequested.get()) { + nextRequested.set(false); + refCounted.incRef(); + handler.addItems(List.of(indexRequest(index).setPipeline(pipelineId)), refCounted::decRef, () -> nextRequested.set(true)); + } + + assertBusy(() -> assertTrue(nextRequested.get())); + + PlainActionFuture future = new PlainActionFuture<>(); + handler.lastItems(List.of(indexRequest(index)), refCounted::decRef, 
future); + + BulkResponse bulkResponse = future.actionGet(); + assertTrue(bulkResponse.hasFailures()); + for (int i = 0; i < hits.get(); ++i) { + assertFalse(bulkResponse.getItems()[i].isFailed()); + } + + for (int i = (int) hits.get(); i < bulkResponse.getItems().length; ++i) { + BulkItemResponse item = bulkResponse.getItems()[i]; + assertTrue(item.isFailed()); + assertThat(item.getFailure().getCause().getCause(), instanceOf(EsRejectedExecutionException.class)); + } + } + + private static void blockWritePool(int threadCount, ThreadPool threadPool, CountDownLatch blockingLatch) throws InterruptedException { + CountDownLatch startedLatch = new CountDownLatch(threadCount); + for (int i = 0; i < threadCount; i++) { + threadPool.executor(ThreadPool.Names.WRITE).execute(() -> { + startedLatch.countDown(); + try { + blockingLatch.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + }); + } + startedLatch.await(); + } + + private BulkResponse executeBulk(long docs, String index, IncrementalBulkService.Handler handler, ExecutorService executorService) { + ConcurrentLinkedQueue> queue = new ConcurrentLinkedQueue<>(); + for (int i = 0; i < docs; i++) { + IndexRequest indexRequest = indexRequest(index); + queue.add(indexRequest); + } + + AbstractRefCounted refCounted = AbstractRefCounted.of(() -> {}); + PlainActionFuture future = new PlainActionFuture<>(); + Runnable r = new Runnable() { + + @Override + public void run() { + int toRemove = Math.min(randomIntBetween(5, 10), queue.size()); + ArrayList> docs = new ArrayList<>(); + for (int i = 0; i < toRemove; i++) { + docs.add(queue.poll()); + } + + if (queue.isEmpty()) { + handler.lastItems(docs, refCounted::decRef, future); + } else { + refCounted.incRef(); + handler.addItems(docs, refCounted::decRef, () -> executorService.execute(this)); + } + } + }; + + executorService.execute(r); + + BulkResponse bulkResponse = future.actionGet(); + assertFalse(refCounted.hasReferences()); + return bulkResponse; + } + + private static IndexRequest indexRequest(String index) { + IndexRequest indexRequest = new IndexRequest(); + indexRequest.index(index); + indexRequest.source(Map.of("field", randomAlphaOfLength(10))); + return indexRequest; + } + + protected static String findShard(Index index, int shardId) { + for (String node : internalCluster().getNodeNames()) { + var indicesService = internalCluster().getInstance(IndicesService.class, node); + IndexService indexService = indicesService.indexService(index); + if (indexService != null) { + IndexShard shard = indexService.getShardOrNull(shardId); + if (shard != null && shard.isActive() && shard.routingEntry().primary()) { + return node; + } + } + } + throw new AssertionError("IndexShard instance not found for shard " + new ShardId(index, shardId)); + } +} diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index c25825475aa9c..cef4bd14d992b 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -218,6 +218,7 @@ static TransportVersion def(int id) { public static final TransportVersion ML_SCHEDULED_EVENT_TIME_SHIFT_CONFIGURATION = def(8_742_00_0); public static final TransportVersion SIMULATE_COMPONENT_TEMPLATES_SUBSTITUTIONS = def(8_743_00_0); public static final TransportVersion ML_INFERENCE_IBM_WATSONX_EMBEDDINGS_ADDED = def(8_744_00_0); + public static final TransportVersion BULK_INCREMENTAL_STATE = 
def(8_745_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/action/ActionModule.java b/server/src/main/java/org/elasticsearch/action/ActionModule.java index 6c736b47bc94c..2d72f5d71ccda 100644 --- a/server/src/main/java/org/elasticsearch/action/ActionModule.java +++ b/server/src/main/java/org/elasticsearch/action/ActionModule.java @@ -160,6 +160,7 @@ import org.elasticsearch.action.admin.indices.template.put.TransportPutIndexTemplateAction; import org.elasticsearch.action.admin.indices.validate.query.TransportValidateQueryAction; import org.elasticsearch.action.admin.indices.validate.query.ValidateQueryAction; +import org.elasticsearch.action.bulk.IncrementalBulkService; import org.elasticsearch.action.bulk.SimulateBulkAction; import org.elasticsearch.action.bulk.TransportBulkAction; import org.elasticsearch.action.bulk.TransportShardBulkAction; @@ -448,6 +449,7 @@ public class ActionModule extends AbstractModule { private final List actionPlugins; private final Map> actions; private final ActionFilters actionFilters; + private final IncrementalBulkService bulkService; private final AutoCreateIndex autoCreateIndex; private final DestructiveOperations destructiveOperations; private final RestController restController; @@ -476,7 +478,8 @@ public ActionModule( ClusterService clusterService, RerouteService rerouteService, List> reservedStateHandlers, - RestExtension restExtension + RestExtension restExtension, + IncrementalBulkService bulkService ) { this.settings = settings; this.indexNameExpressionResolver = indexNameExpressionResolver; @@ -488,6 +491,7 @@ public ActionModule( this.threadPool = threadPool; actions = setupActions(actionPlugins); actionFilters = setupActionFilters(actionPlugins); + this.bulkService = bulkService; autoCreateIndex = new AutoCreateIndex(settings, clusterSettings, indexNameExpressionResolver, systemIndices); destructiveOperations = new DestructiveOperations(settings, clusterSettings); Set headers = Stream.concat( @@ -928,7 +932,7 @@ public void initRestHandlers(Supplier nodesInCluster, Predicate< registerHandler.accept(new RestCountAction()); registerHandler.accept(new RestTermVectorsAction()); registerHandler.accept(new RestMultiTermVectorsAction()); - registerHandler.accept(new RestBulkAction(settings)); + registerHandler.accept(new RestBulkAction(settings, bulkService)); registerHandler.accept(new RestUpdateAction()); registerHandler.accept(new RestSearchAction(restController.getSearchUsageHolder(), clusterSupportsFeature)); diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java index 98e3548ecf30e..13229fbf65fef 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java @@ -95,6 +95,7 @@ final class BulkOperation extends ActionRunnable { private final OriginSettingClient rolloverClient; private final Set failureStoresToBeRolledOver = ConcurrentCollections.newConcurrentSet(); private final Set failedRolloverRequests = ConcurrentCollections.newConcurrentSet(); + private final Map shortCircuitShardFailures = ConcurrentCollections.newConcurrentMap(); private final FailureStoreMetrics failureStoreMetrics; BulkOperation( @@ -164,6 +165,7 @@ final class BulkOperation extends ActionRunnable { this.observer = observer; this.failureStoreDocumentConverter = failureStoreDocumentConverter; this.rolloverClient = new 
OriginSettingClient(client, LAZY_ROLLOVER_ORIGIN); + this.shortCircuitShardFailures.putAll(bulkRequest.incrementalState().shardLevelFailures()); this.failureStoreMetrics = failureStoreMetrics; } @@ -403,7 +405,12 @@ private void redirectFailuresOrCompleteBulkOperation() { private void completeBulkOperation() { listener.onResponse( - new BulkResponse(responses.toArray(new BulkItemResponse[responses.length()]), buildTookInMillis(startTimeNanos)) + new BulkResponse( + responses.toArray(new BulkItemResponse[responses.length()]), + buildTookInMillis(startTimeNanos), + BulkResponse.NO_INGEST_TOOK, + new BulkRequest.IncrementalState(shortCircuitShardFailures, bulkRequest.incrementalState().indexingPressureAccounted()) + ) ); // Allow memory for bulk shard request items to be reclaimed before all items have been completed bulkRequest = null; @@ -429,90 +436,102 @@ private void discardRedirectsAndFinish(Exception exception) { } private void executeBulkShardRequest(BulkShardRequest bulkShardRequest, Releasable releaseOnFinish) { - client.executeLocally(TransportShardBulkAction.TYPE, bulkShardRequest, new ActionListener<>() { + ShardId shardId = bulkShardRequest.shardId(); - // Lazily get the cluster state to avoid keeping it around longer than it is needed - private ClusterState clusterState = null; + // Short circuit the shard level request with the existing shard failure. + if (shortCircuitShardFailures.containsKey(shardId)) { + handleShardFailure(bulkShardRequest, clusterService.state(), shortCircuitShardFailures.get(shardId)); + releaseOnFinish.close(); + } else { + client.executeLocally(TransportShardBulkAction.TYPE, bulkShardRequest, new ActionListener<>() { - private ClusterState getClusterState() { - if (clusterState == null) { - clusterState = clusterService.state(); - } - return clusterState; - } + // Lazily get the cluster state to avoid keeping it around longer than it is needed + private ClusterState clusterState = null; - @Override - public void onResponse(BulkShardResponse bulkShardResponse) { - for (int idx = 0; idx < bulkShardResponse.getResponses().length; idx++) { - // We zip the requests and responses together so that we can identify failed documents and potentially store them - BulkItemResponse bulkItemResponse = bulkShardResponse.getResponses()[idx]; - BulkItemRequest bulkItemRequest = bulkShardRequest.items()[idx]; - - if (bulkItemResponse.isFailed()) { - assert bulkItemRequest.id() == bulkItemResponse.getItemId() : "Bulk items were returned out of order"; - processFailure(bulkItemRequest, bulkItemResponse.getFailure().getCause()); - addFailure(bulkItemResponse); - } else { - bulkItemResponse.getResponse().setShardInfo(bulkShardResponse.getShardInfo()); - responses.set(bulkItemResponse.getItemId(), bulkItemResponse); + private ClusterState getClusterState() { + if (clusterState == null) { + clusterState = clusterService.state(); } + return clusterState; } - completeShardOperation(); - } - @Override - public void onFailure(Exception e) { - // create failures for all relevant requests - for (BulkItemRequest request : bulkShardRequest.items()) { - final String indexName = request.index(); - DocWriteRequest docWriteRequest = request.request(); - - processFailure(request, e); - addFailure(docWriteRequest, request.id(), indexName, e); + @Override + public void onResponse(BulkShardResponse bulkShardResponse) { + for (int idx = 0; idx < bulkShardResponse.getResponses().length; idx++) { + // We zip the requests and responses together so that we can identify failed documents and 
potentially store them + BulkItemResponse bulkItemResponse = bulkShardResponse.getResponses()[idx]; + BulkItemRequest bulkItemRequest = bulkShardRequest.items()[idx]; + + if (bulkItemResponse.isFailed()) { + assert bulkItemRequest.id() == bulkItemResponse.getItemId() : "Bulk items were returned out of order"; + processFailure(bulkItemRequest, getClusterState(), bulkItemResponse.getFailure().getCause()); + addFailure(bulkItemResponse); + } else { + bulkItemResponse.getResponse().setShardInfo(bulkShardResponse.getShardInfo()); + responses.set(bulkItemResponse.getItemId(), bulkItemResponse); + } + } + completeShardOperation(); } - completeShardOperation(); - } - private void completeShardOperation() { - // Clear our handle on the cluster state to allow it to be cleaned up - clusterState = null; - releaseOnFinish.close(); - } + @Override + public void onFailure(Exception e) { + assert shortCircuitShardFailures.containsKey(shardId) == false; + shortCircuitShardFailures.put(shardId, e); - private void processFailure(BulkItemRequest bulkItemRequest, Exception cause) { - var error = ExceptionsHelper.unwrapCause(cause); - var errorType = ElasticsearchException.getExceptionName(error); - DocWriteRequest docWriteRequest = bulkItemRequest.request(); - DataStream failureStoreCandidate = getRedirectTargetCandidate(docWriteRequest, getClusterState().metadata()); - // If the candidate is not null, the BulkItemRequest targets a data stream, but we'll still have to check if - // it has the failure store enabled. - if (failureStoreCandidate != null) { - // Do not redirect documents to a failure store that were already headed to one. - var isFailureStoreDoc = docWriteRequest instanceof IndexRequest indexRequest && indexRequest.isWriteToFailureStore(); - if (isFailureStoreDoc == false - && failureStoreCandidate.isFailureStoreEnabled() - && error instanceof VersionConflictEngineException == false) { - // Redirect to failure store. - maybeMarkFailureStoreForRollover(failureStoreCandidate); - addDocumentToRedirectRequests(bulkItemRequest, cause, failureStoreCandidate.getName()); - failureStoreMetrics.incrementFailureStore( - bulkItemRequest.index(), - errorType, - FailureStoreMetrics.ErrorLocation.SHARD - ); - } else { - // If we can't redirect to a failure store (because either the data stream doesn't have the failure store enabled - // or this request was already targeting a failure store), we increment the rejected counter. 
- failureStoreMetrics.incrementRejected( - bulkItemRequest.index(), - errorType, - FailureStoreMetrics.ErrorLocation.SHARD, - isFailureStoreDoc - ); - } + // create failures for all relevant requests + handleShardFailure(bulkShardRequest, getClusterState(), e); + completeShardOperation(); } + + private void completeShardOperation() { + // Clear our handle on the cluster state to allow it to be cleaned up + clusterState = null; + releaseOnFinish.close(); + } + }); + } + } + + private void handleShardFailure(BulkShardRequest bulkShardRequest, ClusterState clusterState, Exception e) { + // create failures for all relevant requests + for (BulkItemRequest request : bulkShardRequest.items()) { + final String indexName = request.index(); + DocWriteRequest docWriteRequest = request.request(); + + processFailure(request, clusterState, e); + addFailure(docWriteRequest, request.id(), indexName, e); + } + } + + private void processFailure(BulkItemRequest bulkItemRequest, ClusterState clusterState, Exception cause) { + var error = ExceptionsHelper.unwrapCause(cause); + var errorType = ElasticsearchException.getExceptionName(error); + DocWriteRequest docWriteRequest = bulkItemRequest.request(); + DataStream failureStoreCandidate = getRedirectTargetCandidate(docWriteRequest, clusterState.metadata()); + // If the candidate is not null, the BulkItemRequest targets a data stream, but we'll still have to check if + // it has the failure store enabled. + if (failureStoreCandidate != null) { + // Do not redirect documents to a failure store that were already headed to one. + var isFailureStoreDoc = docWriteRequest instanceof IndexRequest indexRequest && indexRequest.isWriteToFailureStore(); + if (isFailureStoreDoc == false + && failureStoreCandidate.isFailureStoreEnabled() + && error instanceof VersionConflictEngineException == false) { + // Redirect to failure store. + maybeMarkFailureStoreForRollover(failureStoreCandidate); + addDocumentToRedirectRequests(bulkItemRequest, cause, failureStoreCandidate.getName()); + failureStoreMetrics.incrementFailureStore(bulkItemRequest.index(), errorType, FailureStoreMetrics.ErrorLocation.SHARD); + } else { + // If we can't redirect to a failure store (because either the data stream doesn't have the failure store enabled + // or this request was already targeting a failure store), we increment the rejected counter. 
+ failureStoreMetrics.incrementRejected( + bulkItemRequest.index(), + errorType, + FailureStoreMetrics.ErrorLocation.SHARD, + isFailureStoreDoc + ); } - }); + } } /** diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkRequest.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkRequest.java index 1425dde28ea3b..558901f102299 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkRequest.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkRequest.java @@ -11,6 +11,7 @@ import org.apache.lucene.util.Accountable; import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.CompositeIndicesRequest; @@ -27,9 +28,11 @@ import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.RestApiVersion; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.transport.RawIndexingDataTransportRequest; import org.elasticsearch.xcontent.XContentType; @@ -72,6 +75,7 @@ public class BulkRequest extends ActionRequest private final Set indices = new HashSet<>(); protected TimeValue timeout = BulkShardRequest.DEFAULT_TIMEOUT; + private IncrementalState incrementalState = IncrementalState.EMPTY; private ActiveShardCount waitForActiveShards = ActiveShardCount.DEFAULT; private RefreshPolicy refreshPolicy = RefreshPolicy.NONE; private String globalPipeline; @@ -93,6 +97,11 @@ public BulkRequest(StreamInput in) throws IOException { for (DocWriteRequest request : requests) { indices.add(Objects.requireNonNull(request.index(), "request index must not be null")); } + if (in.getTransportVersion().onOrAfter(TransportVersions.BULK_INCREMENTAL_STATE)) { + incrementalState = new BulkRequest.IncrementalState(in); + } else { + incrementalState = BulkRequest.IncrementalState.EMPTY; + } } public BulkRequest(@Nullable String globalIndex) { @@ -327,6 +336,10 @@ public final BulkRequest timeout(TimeValue timeout) { return this; } + public void incrementalState(IncrementalState incrementalState) { + this.incrementalState = incrementalState; + } + /** * Note for internal callers (NOT high level rest client), * the global parameter setting is ignored when used with: @@ -365,6 +378,10 @@ public TimeValue timeout() { return timeout; } + public IncrementalState incrementalState() { + return incrementalState; + } + public String pipeline() { return globalPipeline; } @@ -436,6 +453,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeCollection(requests, DocWriteRequest::writeDocumentRequest); refreshPolicy.writeTo(out); out.writeTimeValue(timeout); + if (out.getTransportVersion().onOrAfter(TransportVersions.BULK_INCREMENTAL_STATE)) { + incrementalState.writeTo(out); + } } @Override @@ -486,6 +506,20 @@ public Map getComponentTemplateSubstitutions() throws return Map.of(); } + record IncrementalState(Map shardLevelFailures, boolean indexingPressureAccounted) implements Writeable { + + static final IncrementalState EMPTY = new IncrementalState(Collections.emptyMap(), false); + + IncrementalState(StreamInput in) throws IOException { + 
this(in.readMap(ShardId::new, input -> input.readException()), false); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeMap(shardLevelFailures, (o, s) -> s.writeTo(o), StreamOutput::writeException); + } + } + /* * This copies this bulk request, but without all of its inner requests or the set of indices found in those requests */ diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkRequestModifier.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkRequestModifier.java index 3e47c78a76354..282e4d33fb83b 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkRequestModifier.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkRequestModifier.java @@ -114,7 +114,12 @@ BulkRequest getBulkRequest() { ActionListener wrapActionListenerIfNeeded(long ingestTookInMillis, ActionListener actionListener) { if (itemResponses.isEmpty()) { return actionListener.map( - response -> new BulkResponse(response.getItems(), response.getTook().getMillis(), ingestTookInMillis) + response -> new BulkResponse( + response.getItems(), + response.getTook().getMillis(), + ingestTookInMillis, + response.getIncrementalState() + ) ); } else { return actionListener.map(response -> { @@ -139,7 +144,7 @@ ActionListener wrapActionListenerIfNeeded(long ingestTookInMillis, assertResponsesAreCorrect(bulkResponses, allResponses); } - return new BulkResponse(allResponses, response.getTook().getMillis(), ingestTookInMillis); + return new BulkResponse(allResponses, response.getTook().getMillis(), ingestTookInMillis, response.getIncrementalState()); }); } } diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkRequestParser.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkRequestParser.java index e94bfff69d3d1..c27e3d319d7ca 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkRequestParser.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkRequestParser.java @@ -86,13 +86,13 @@ public BulkRequestParser(boolean deprecateOrErrorOnType, RestApiVersion restApiV .withRestApiVersion(restApiVersion); } - private static int findNextMarker(byte marker, int from, BytesReference data) { + private static int findNextMarker(byte marker, int from, BytesReference data, boolean isIncremental) { final int res = data.indexOf(marker, from); if (res != -1) { assert res >= 0; return res; } - if (from != data.length()) { + if (from != data.length() && isIncremental == false) { throw new IllegalArgumentException("The bulk request must be terminated by a newline [\\n]"); } return res; @@ -137,18 +137,57 @@ public void parse( Consumer updateRequestConsumer, Consumer deleteRequestConsumer ) throws IOException { - XContent xContent = xContentType.xContent(); - int line = 0; - int from = 0; - byte marker = xContent.bulkSeparator(); // Bulk requests can contain a lot of repeated strings for the index, pipeline and routing parameters. This map is used to // deduplicate duplicate strings parsed for these parameters. While it does not prevent instantiating the duplicate strings, it // reduces their lifetime to the lifetime of this parse call instead of the lifetime of the full bulk request. 
final Map stringDeduplicator = new HashMap<>(); + + incrementalParse( + data, + defaultIndex, + defaultRouting, + defaultFetchSourceContext, + defaultPipeline, + defaultRequireAlias, + defaultRequireDataStream, + defaultListExecutedPipelines, + allowExplicitIndex, + xContentType, + indexRequestConsumer, + updateRequestConsumer, + deleteRequestConsumer, + false, + stringDeduplicator + ); + } + + public int incrementalParse( + BytesReference data, + String defaultIndex, + String defaultRouting, + FetchSourceContext defaultFetchSourceContext, + String defaultPipeline, + Boolean defaultRequireAlias, + Boolean defaultRequireDataStream, + Boolean defaultListExecutedPipelines, + boolean allowExplicitIndex, + XContentType xContentType, + BiConsumer indexRequestConsumer, + Consumer updateRequestConsumer, + Consumer deleteRequestConsumer, + boolean isIncremental, + Map stringDeduplicator + ) throws IOException { + XContent xContent = xContentType.xContent(); + byte marker = xContent.bulkSeparator(); boolean typesDeprecationLogged = false; + int line = 0; + int from = 0; + int consumed = 0; + while (true) { - int nextMarker = findNextMarker(marker, from, data); + int nextMarker = findNextMarker(marker, from, data, isIncremental); if (nextMarker == -1) { break; } @@ -333,8 +372,9 @@ public void parse( .setIfSeqNo(ifSeqNo) .setIfPrimaryTerm(ifPrimaryTerm) ); + consumed = from; } else { - nextMarker = findNextMarker(marker, from, data); + nextMarker = findNextMarker(marker, from, data, isIncremental); if (nextMarker == -1) { break; } @@ -407,9 +447,11 @@ public void parse( } // move pointers from = nextMarker + 1; + consumed = from; } } } + return isIncremental ? consumed : from; } @UpdateForV9 diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkResponse.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkResponse.java index 8f12341d71e7b..b02d7acf66d14 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkResponse.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkResponse.java @@ -9,6 +9,7 @@ package org.elasticsearch.action.bulk; +import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.common.collect.Iterators; import org.elasticsearch.common.io.stream.StreamInput; @@ -37,12 +38,18 @@ public class BulkResponse extends ActionResponse implements Iterable INCREMENTAL_BULK = boolSetting( + "rest.incremental_bulk", + true, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + private final Client client; + private final AtomicBoolean enabledForTests = new AtomicBoolean(true); + private final IndexingPressure indexingPressure; + private final ThreadContext threadContext; + + public IncrementalBulkService(Client client, IndexingPressure indexingPressure, ThreadContext threadContext) { + this.client = client; + this.indexingPressure = indexingPressure; + this.threadContext = threadContext; + } + + public Handler newBulkRequest() { + ensureEnabled(); + return newBulkRequest(null, null, null); + } + + public Handler newBulkRequest(@Nullable String waitForActiveShards, @Nullable TimeValue timeout, @Nullable String refresh) { + ensureEnabled(); + return new Handler(client, threadContext, indexingPressure, waitForActiveShards, timeout, refresh); + } + + private void ensureEnabled() { + if (enabledForTests.get() == false) { + throw new AssertionError("Unexpected incremental bulk request"); + } + } + + // This method only exists to tests that the feature flag works. 
Remove once we no longer need the flag. + public void setForTests(boolean value) { + enabledForTests.set(value); + } + + public static class Enabled implements Supplier { + + private final AtomicBoolean incrementalBulksEnabled = new AtomicBoolean(true); + + public Enabled() {} + + public Enabled(ClusterSettings clusterSettings) { + incrementalBulksEnabled.set(clusterSettings.get(INCREMENTAL_BULK)); + clusterSettings.addSettingsUpdateConsumer(INCREMENTAL_BULK, incrementalBulksEnabled::set); + } + + @Override + public Boolean get() { + return incrementalBulksEnabled.get(); + } + } + + public static class Handler implements Releasable { + + public static final BulkRequest.IncrementalState EMPTY_STATE = new BulkRequest.IncrementalState(Collections.emptyMap(), true); + + private final Client client; + private final ThreadContext threadContext; + private final IndexingPressure indexingPressure; + private final ActiveShardCount waitForActiveShards; + private final TimeValue timeout; + private final String refresh; + + private final ArrayList releasables = new ArrayList<>(4); + private final ArrayList responses = new ArrayList<>(2); + private boolean closed = false; + private boolean globalFailure = false; + private boolean incrementalRequestSubmitted = false; + private ThreadContext.StoredContext requestContext; + private Exception bulkActionLevelFailure = null; + private BulkRequest bulkRequest = null; + + protected Handler( + Client client, + ThreadContext threadContext, + IndexingPressure indexingPressure, + @Nullable String waitForActiveShards, + @Nullable TimeValue timeout, + @Nullable String refresh + ) { + this.client = client; + this.threadContext = threadContext; + this.requestContext = threadContext.newStoredContext(); + this.indexingPressure = indexingPressure; + this.waitForActiveShards = waitForActiveShards != null ? 
ActiveShardCount.parseString(waitForActiveShards) : null; + this.timeout = timeout; + this.refresh = refresh; + createNewBulkRequest(EMPTY_STATE); + } + + public void addItems(List> items, Releasable releasable, Runnable nextItems) { + assert closed == false; + if (bulkActionLevelFailure != null) { + shortCircuitDueToTopLevelFailure(items, releasable); + nextItems.run(); + } else { + assert bulkRequest != null; + if (internalAddItems(items, releasable)) { + if (shouldBackOff()) { + final boolean isFirstRequest = incrementalRequestSubmitted == false; + incrementalRequestSubmitted = true; + try (ThreadContext.StoredContext ignored = threadContext.stashContext()) { + requestContext.restore(); + final ArrayList toRelease = new ArrayList<>(releasables); + releasables.clear(); + client.bulk(bulkRequest, ActionListener.runAfter(new ActionListener<>() { + + @Override + public void onResponse(BulkResponse bulkResponse) { + handleBulkSuccess(bulkResponse); + createNewBulkRequest( + new BulkRequest.IncrementalState(bulkResponse.getIncrementalState().shardLevelFailures(), true) + ); + } + + @Override + public void onFailure(Exception e) { + handleBulkFailure(isFirstRequest, e); + } + }, () -> { + requestContext = threadContext.newStoredContext(); + toRelease.forEach(Releasable::close); + nextItems.run(); + })); + } + } else { + nextItems.run(); + } + } else { + nextItems.run(); + } + } + } + + private boolean shouldBackOff() { + return indexingPressure.shouldSplitBulks(); + } + + public void lastItems(List> items, Releasable releasable, ActionListener listener) { + if (bulkActionLevelFailure != null) { + shortCircuitDueToTopLevelFailure(items, releasable); + errorResponse(listener); + } else { + assert bulkRequest != null; + if (internalAddItems(items, releasable)) { + try (ThreadContext.StoredContext ignored = threadContext.stashContext()) { + requestContext.restore(); + final ArrayList toRelease = new ArrayList<>(releasables); + releasables.clear(); + client.bulk(bulkRequest, ActionListener.runBefore(new ActionListener<>() { + + private final boolean isFirstRequest = incrementalRequestSubmitted == false; + + @Override + public void onResponse(BulkResponse bulkResponse) { + handleBulkSuccess(bulkResponse); + listener.onResponse(combineResponses()); + } + + @Override + public void onFailure(Exception e) { + handleBulkFailure(isFirstRequest, e); + errorResponse(listener); + } + }, () -> toRelease.forEach(Releasable::close))); + } + } else { + errorResponse(listener); + } + } + } + + @Override + public void close() { + closed = true; + releasables.forEach(Releasable::close); + releasables.clear(); + } + + private void shortCircuitDueToTopLevelFailure(List> items, Releasable releasable) { + assert releasables.isEmpty(); + assert bulkRequest == null; + if (globalFailure == false) { + addItemLevelFailures(items); + } + Releasables.close(releasable); + } + + private void errorResponse(ActionListener listener) { + if (globalFailure) { + listener.onFailure(bulkActionLevelFailure); + } else { + listener.onResponse(combineResponses()); + } + } + + private void handleBulkSuccess(BulkResponse bulkResponse) { + responses.add(bulkResponse); + bulkRequest = null; + } + + private void handleBulkFailure(boolean isFirstRequest, Exception e) { + assert bulkActionLevelFailure == null; + globalFailure = isFirstRequest; + bulkActionLevelFailure = e; + addItemLevelFailures(bulkRequest.requests()); + bulkRequest = null; + } + + private void addItemLevelFailures(List> items) { + BulkItemResponse[] bulkItemResponses = new 
BulkItemResponse[items.size()]; + int idx = 0; + for (DocWriteRequest item : items) { + BulkItemResponse.Failure failure = new BulkItemResponse.Failure(item.index(), item.id(), bulkActionLevelFailure); + bulkItemResponses[idx++] = BulkItemResponse.failure(idx, item.opType(), failure); + } + + responses.add(new BulkResponse(bulkItemResponses, 0, 0)); + } + + private boolean internalAddItems(List> items, Releasable releasable) { + try { + bulkRequest.add(items); + releasables.add(releasable); + releasables.add( + indexingPressure.markCoordinatingOperationStarted( + items.size(), + items.stream().mapToLong(Accountable::ramBytesUsed).sum(), + false + ) + ); + return true; + } catch (EsRejectedExecutionException e) { + handleBulkFailure(incrementalRequestSubmitted == false, e); + releasables.forEach(Releasable::close); + releasables.clear(); + return false; + } + } + + private void createNewBulkRequest(BulkRequest.IncrementalState incrementalState) { + bulkRequest = new BulkRequest(); + bulkRequest.incrementalState(incrementalState); + + if (waitForActiveShards != null) { + bulkRequest.waitForActiveShards(waitForActiveShards); + } + if (timeout != null) { + bulkRequest.timeout(timeout); + } + if (refresh != null) { + bulkRequest.setRefreshPolicy(refresh); + } + } + + private void releaseCurrentReferences() { + bulkRequest = null; + releasables.forEach(Releasable::close); + releasables.clear(); + } + + private BulkResponse combineResponses() { + long tookInMillis = 0; + long ingestTookInMillis = 0; + int itemResponseCount = 0; + for (BulkResponse response : responses) { + tookInMillis += response.getTookInMillis(); + ingestTookInMillis += response.getIngestTookInMillis(); + itemResponseCount += response.getItems().length; + } + BulkItemResponse[] bulkItemResponses = new BulkItemResponse[itemResponseCount]; + int i = 0; + for (BulkResponse response : responses) { + for (BulkItemResponse itemResponse : response.getItems()) { + bulkItemResponses[i++] = itemResponse; + } + } + + return new BulkResponse(bulkItemResponses, tookInMillis, ingestTookInMillis); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportAbstractBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportAbstractBulkAction.java index d306299645d64..78652081c9f0d 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportAbstractBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportAbstractBulkAction.java @@ -112,7 +112,12 @@ protected void doExecute(Task task, BulkRequest bulkRequest, ActionListener {}; + } else { + releasable = indexingPressure.markCoordinatingOperationStarted(indexingOps, indexingBytes, isOnlySystem); + } final ActionListener releasingListener = ActionListener.runBefore(listener, releasable::close); final Executor executor = isOnlySystem ? 
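
Downstream, the Handler above is fed those parsed items in batches. A minimal, hedged sketch of the calling protocol (incrementalBulkService, the two batches and their release callbacks are assumptions, not part of the change):

// Illustrative use of IncrementalBulkService.Handler; batch1/batch2 and the release Runnables are hypothetical.
IncrementalBulkService.Handler handler = incrementalBulkService.newBulkRequest();
handler.addItems(batch1, () -> releaseBatch1Buffers(), () -> {
    // invoked once the handler is ready for more items (it may have flushed an intermediate bulk first)
    handler.lastItems(batch2, () -> releaseBatch2Buffers(), ActionListener.wrap(
        bulkResponse -> logger.info("bulk finished with [{}] item responses", bulkResponse.getItems().length),
        e -> logger.warn("bulk failed", e)
    ));
});
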
systemWriteExecutor : writeExecutor; ensureClusterStateThenForkAndExecute(task, bulkRequest, executor, releasingListener); diff --git a/server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java b/server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java index 18aaaf414101b..2ab0318490f7a 100644 --- a/server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java +++ b/server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java @@ -11,6 +11,7 @@ import org.apache.logging.log4j.LogManager; import org.elasticsearch.action.admin.cluster.configuration.TransportAddVotingConfigExclusionsAction; import org.elasticsearch.action.admin.indices.close.TransportCloseIndexAction; +import org.elasticsearch.action.bulk.IncrementalBulkService; import org.elasticsearch.action.bulk.WriteAckDelay; import org.elasticsearch.action.datastreams.autosharding.DataStreamAutoShardingService; import org.elasticsearch.action.ingest.SimulatePipelineTransportAction; @@ -242,6 +243,7 @@ public void apply(Settings value, Settings current, Settings previous) { Metadata.SETTING_READ_ONLY_SETTING, Metadata.SETTING_READ_ONLY_ALLOW_DELETE_SETTING, ShardLimitValidator.SETTING_CLUSTER_MAX_SHARDS_PER_NODE, + IncrementalBulkService.INCREMENTAL_BULK, RecoverySettings.INDICES_RECOVERY_MAX_BYTES_PER_SEC_SETTING, RecoverySettings.INDICES_RECOVERY_RETRY_DELAY_STATE_SYNC_SETTING, RecoverySettings.INDICES_RECOVERY_RETRY_DELAY_NETWORK_SETTING, @@ -560,6 +562,7 @@ public void apply(Settings value, Settings current, Settings previous) { FsHealthService.REFRESH_INTERVAL_SETTING, FsHealthService.SLOW_PATH_LOGGING_THRESHOLD_SETTING, IndexingPressure.MAX_INDEXING_BYTES, + IndexingPressure.SPLIT_BULK_THRESHOLD, ShardLimitValidator.SETTING_CLUSTER_MAX_SHARDS_PER_NODE_FROZEN, DataTier.ENFORCE_DEFAULT_TIER_PREFERENCE_SETTING, CoordinationDiagnosticsService.IDENTITY_CHANGES_THRESHOLD_SETTING, diff --git a/server/src/main/java/org/elasticsearch/http/HttpBody.java b/server/src/main/java/org/elasticsearch/http/HttpBody.java new file mode 100644 index 0000000000000..a10487502ed3c --- /dev/null +++ b/server/src/main/java/org/elasticsearch/http/HttpBody.java @@ -0,0 +1,118 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.http; + +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.bytes.ReleasableBytesReference; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Releasable; + +/** + * A super-interface for different HTTP content implementations + */ +public sealed interface HttpBody extends Releasable permits HttpBody.Full, HttpBody.Stream { + + static Full fromBytesReference(BytesReference bytesRef) { + return new ByteRefHttpBody(bytesRef); + } + + static Full empty() { + return new ByteRefHttpBody(BytesArray.EMPTY); + } + + default boolean isFull() { + return this instanceof Full; + } + + default boolean isStream() { + return this instanceof Stream; + } + + /** + * Assumes that HTTP body is a full content. 
If not sure, use {@link HttpBody#isFull()}. + */ + default Full asFull() { + assert this instanceof Full; + return (Full) this; + } + + /** + * Assumes that HTTP body is a lazy-stream. If not sure, use {@link HttpBody#isStream()}. + */ + default Stream asStream() { + assert this instanceof Stream; + return (Stream) this; + } + + /** + * Full content represents a complete http body content that can be accessed immediately. + */ + non-sealed interface Full extends HttpBody { + BytesReference bytes(); + + @Override + default void close() {} + } + + /** + * Stream is a lazy-loaded content. Stream supports only single handler, this handler must be + * set before requesting next chunk. + */ + non-sealed interface Stream extends HttpBody { + /** + * Returns current handler + */ + @Nullable + ChunkHandler handler(); + + /** + * Adds tracing chunk handler. Tracing handler will be invoked before main handler, and + * should never release or call for next chunk. It should be used for monitoring and + * logging purposes. + */ + void addTracingHandler(ChunkHandler chunkHandler); + + /** + * Sets handler that can handle next chunk + */ + void setHandler(ChunkHandler chunkHandler); + + /** + * Request next chunk of data from the network. The size of the chunk depends on following + * factors. If request is not compressed then chunk size will be up to + * {@link HttpTransportSettings#SETTING_HTTP_MAX_CHUNK_SIZE}. If request is compressed then + * chunk size will be up to max_chunk_size * compression_ratio. Multiple calls can be + * deduplicated when next chunk is not yet available. It's recommended to call "next" once + * for every chunk. + *
+         * {@code
+         *     stream.setHandler((chunk, isLast) -> {
+         *         processChunk(chunk);
+         *         if (isLast == false) {
+         *             stream.next();
+         *         }
+         *     });
+         * }
+         * </pre>
+ */ + void next(); + } + + @FunctionalInterface + interface ChunkHandler extends Releasable { + void onNext(ReleasableBytesReference chunk, boolean isLast); + + @Override + default void close() {} + } + + record ByteRefHttpBody(BytesReference bytes) implements Full {} +} diff --git a/server/src/main/java/org/elasticsearch/http/HttpClientStatsTracker.java b/server/src/main/java/org/elasticsearch/http/HttpClientStatsTracker.java index 9f7a4fdc2ee6e..59e45242e46c5 100644 --- a/server/src/main/java/org/elasticsearch/http/HttpClientStatsTracker.java +++ b/server/src/main/java/org/elasticsearch/http/HttpClientStatsTracker.java @@ -227,7 +227,11 @@ synchronized void update(HttpRequest httpRequest, HttpChannel httpChannel, long lastRequestTimeMillis = currentTimeMillis; lastUri = httpRequest.uri(); requestCount += 1; - requestSizeBytes += httpRequest.content().length(); + if (httpRequest.body().isFull()) { + requestSizeBytes += httpRequest.body().asFull().bytes().length(); + } else { + httpRequest.body().asStream().addTracingHandler((chunk, last) -> requestSizeBytes += chunk.length()); + } } private static String getFirstValueForHeader(final HttpRequest request, final String header) { diff --git a/server/src/main/java/org/elasticsearch/http/HttpRequest.java b/server/src/main/java/org/elasticsearch/http/HttpRequest.java index b41f82def5013..ca6e51f2cec08 100644 --- a/server/src/main/java/org/elasticsearch/http/HttpRequest.java +++ b/server/src/main/java/org/elasticsearch/http/HttpRequest.java @@ -28,7 +28,7 @@ enum HttpVersion { HTTP_1_1 } - BytesReference content(); + HttpBody body(); List strictCookies(); @@ -47,7 +47,7 @@ enum HttpVersion { Exception getInboundException(); /** - * Release any resources associated with this request. Implementations should be idempotent. The behavior of {@link #content()} + * Release any resources associated with this request. Implementations should be idempotent. The behavior of {@link #body()} * after this method has been invoked is undefined and implementation specific. 
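
Since HttpRequest now exposes an HttpBody rather than raw bytes, a consumer has to branch on the two shapes. A simplified sketch, with process(...) standing in for real handling and with release and back-pressure details reduced to the basics:

// Illustrative consumer of HttpRequest#body(); process(...) is hypothetical.
HttpBody body = httpRequest.body();
if (body.isFull()) {
    process(body.asFull().bytes());                 // whole content is already buffered in memory
} else {
    HttpBody.Stream stream = body.asStream();
    stream.setHandler((chunk, isLast) -> {
        process(chunk);
        chunk.close();                              // the main handler releases chunks; tracing handlers must not
        if (isLast == false) {
            stream.next();                          // pull the next chunk from the transport
        }
    });
    stream.next();                                  // request the first chunk
}
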
*/ void release(); diff --git a/server/src/main/java/org/elasticsearch/http/HttpTracer.java b/server/src/main/java/org/elasticsearch/http/HttpTracer.java index 81c406b3545ed..3d8360e6ee3fa 100644 --- a/server/src/main/java/org/elasticsearch/http/HttpTracer.java +++ b/server/src/main/java/org/elasticsearch/http/HttpTracer.java @@ -12,6 +12,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.bytes.ReleasableBytesReference; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; @@ -21,6 +22,7 @@ import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; +import java.io.IOException; import java.io.OutputStream; import java.util.List; @@ -78,10 +80,10 @@ HttpTracer maybeLogRequest(RestRequest restRequest, @Nullable Exception e) { e ); if (isBodyTracerEnabled()) { - try (var stream = HttpBodyTracer.getBodyOutputStream(restRequest.getRequestId(), HttpBodyTracer.Type.REQUEST)) { - restRequest.content().writeTo(stream); - } catch (Exception e2) { - assert false : e2; // no real IO here + if (restRequest.isFullContent()) { + logFullContent(restRequest); + } else { + logStreamContent(restRequest); } } @@ -90,6 +92,53 @@ HttpTracer maybeLogRequest(RestRequest restRequest, @Nullable Exception e) { return null; } + private void logFullContent(RestRequest restRequest) { + try (var stream = HttpBodyTracer.getBodyOutputStream(restRequest.getRequestId(), HttpBodyTracer.Type.REQUEST)) { + restRequest.content().writeTo(stream); + } catch (Exception e2) { + assert false : e2; // no real IO here + } + } + + private void logStreamContent(RestRequest restRequest) { + restRequest.contentStream().addTracingHandler(new LoggingChunkHandler(restRequest)); + } + + private static class LoggingChunkHandler implements HttpBody.ChunkHandler { + private final OutputStream stream; + private volatile boolean closed = false; + + LoggingChunkHandler(RestRequest request) { + stream = HttpBodyTracer.getBodyOutputStream(request.getRequestId(), HttpBodyTracer.Type.REQUEST); + } + + @Override + public void onNext(ReleasableBytesReference chunk, boolean isLast) { + try { + chunk.writeTo(stream); + } catch (IOException e) { + assert false : e; // no real IO + } finally { + if (isLast) { + this.close(); + } + } + } + + @Override + public void close() { + if (closed) { + return; + } + try { + closed = true; + stream.close(); + } catch (IOException e) { + assert false : e; // no real IO + } + } + } + boolean isBodyTracerEnabled() { return HttpBodyTracer.isEnabled(); } diff --git a/server/src/main/java/org/elasticsearch/index/IndexingPressure.java b/server/src/main/java/org/elasticsearch/index/IndexingPressure.java index 70300222883d2..14f8b92db3eaa 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexingPressure.java +++ b/server/src/main/java/org/elasticsearch/index/IndexingPressure.java @@ -30,6 +30,12 @@ public class IndexingPressure { Setting.Property.NodeScope ); + public static final Setting SPLIT_BULK_THRESHOLD = Setting.memorySizeSetting( + "indexing_pressure.memory.split_bulk_threshold", + "8.5%", + Setting.Property.NodeScope + ); + private static final Logger logger = LogManager.getLogger(IndexingPressure.class); private final AtomicLong currentCombinedCoordinatingAndPrimaryBytes = new AtomicLong(0); @@ -57,10 +63,12 @@ public class IndexingPressure { private final AtomicLong 
primaryDocumentRejections = new AtomicLong(0); private final long primaryAndCoordinatingLimits; + private final long splitBulkThreshold; private final long replicaLimits; public IndexingPressure(Settings settings) { this.primaryAndCoordinatingLimits = MAX_INDEXING_BYTES.get(settings).getBytes(); + this.splitBulkThreshold = SPLIT_BULK_THRESHOLD.get(settings).getBytes(); this.replicaLimits = (long) (this.primaryAndCoordinatingLimits * 1.5); } @@ -204,6 +212,10 @@ public Releasable markReplicaOperationStarted(int operations, long bytes, boolea }); } + public boolean shouldSplitBulks() { + return currentCombinedCoordinatingAndPrimaryBytes.get() >= splitBulkThreshold; + } + public IndexingPressureStats stats() { return new IndexingPressureStats( totalCombinedCoordinatingAndPrimaryBytes.get(), diff --git a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java index b6a63aefcfaff..c4816b440f568 100644 --- a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java +++ b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java @@ -24,6 +24,7 @@ import org.elasticsearch.action.admin.cluster.repositories.reservedstate.ReservedRepositoryAction; import org.elasticsearch.action.admin.indices.template.reservedstate.ReservedComposableIndexTemplateAction; import org.elasticsearch.action.bulk.FailureStoreMetrics; +import org.elasticsearch.action.bulk.IncrementalBulkService; import org.elasticsearch.action.datastreams.autosharding.DataStreamAutoShardingService; import org.elasticsearch.action.ingest.ReservedPipelineAction; import org.elasticsearch.action.search.SearchExecutionStatsCollector; @@ -890,6 +891,13 @@ private void construct( .map(TerminationHandlerProvider::handler); terminationHandler = getSinglePlugin(terminationHandlers, TerminationHandler.class).orElse(null); + final IndexingPressure indexingLimits = new IndexingPressure(settings); + final IncrementalBulkService incrementalBulkService = new IncrementalBulkService( + client, + indexingLimits, + threadPool.getThreadContext() + ); + ActionModule actionModule = new ActionModule( settings, clusterModule.getIndexNameExpressionResolver(), @@ -915,7 +923,8 @@ private void construct( metadataCreateIndexService, dataStreamGlobalRetentionSettings ), - pluginsService.loadSingletonServiceProvider(RestExtension.class, RestExtension::allowAll) + pluginsService.loadSingletonServiceProvider(RestExtension.class, RestExtension::allowAll), + incrementalBulkService ); modules.add(actionModule); @@ -978,7 +987,6 @@ private void construct( SearchExecutionStatsCollector.makeWrapper(responseCollectorService) ); final HttpServerTransport httpServerTransport = serviceProvider.newHttpTransport(pluginsService, networkModule); - final IndexingPressure indexingLimits = new IndexingPressure(settings); SnapshotsService snapshotsService = new SnapshotsService( settings, @@ -1140,6 +1148,7 @@ private void construct( b.bind(PageCacheRecycler.class).toInstance(pageCacheRecycler); b.bind(IngestService.class).toInstance(ingestService); b.bind(IndexingPressure.class).toInstance(indexingLimits); + b.bind(IncrementalBulkService.class).toInstance(incrementalBulkService); b.bind(AggregationUsageService.class).toInstance(searchModule.getValuesSourceRegistry().getUsageService()); b.bind(MetaStateService.class).toInstance(metaStateService); b.bind(IndicesService.class).toInstance(indicesService); diff --git a/server/src/main/java/org/elasticsearch/rest/BaseRestHandler.java 
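
Both of the knobs above are ordinary settings, so they can be set per node (and rest.incremental_bulk can also be updated dynamically). The values below are arbitrary examples, shown only to illustrate the keys:

// Illustrative node settings; the keys come from the change above, the values are examples only.
Settings nodeSettings = Settings.builder()
    .put(IncrementalBulkService.INCREMENTAL_BULK.getKey(), false)      // fall back to fully buffered bulk handling
    .put(IndexingPressure.SPLIT_BULK_THRESHOLD.getKey(), "1KB")        // split incremental bulks early under memory pressure
    .build();
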
b/server/src/main/java/org/elasticsearch/rest/BaseRestHandler.java index 27432050c8b45..5f12a2bdd6783 100644 --- a/server/src/main/java/org/elasticsearch/rest/BaseRestHandler.java +++ b/server/src/main/java/org/elasticsearch/rest/BaseRestHandler.java @@ -11,6 +11,7 @@ import org.apache.lucene.search.spell.LevenshteinDistance; import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.common.bytes.ReleasableBytesReference; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Setting.Property; import org.elasticsearch.common.util.set.Sets; @@ -19,6 +20,7 @@ import org.elasticsearch.core.Releasable; import org.elasticsearch.core.RestApiVersion; import org.elasticsearch.core.Tuple; +import org.elasticsearch.http.HttpBody; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.rest.action.admin.cluster.RestNodesUsageAction; @@ -117,12 +119,28 @@ public final void handleRequest(RestRequest request, RestChannel channel, NodeCl throw new IllegalArgumentException(unrecognized(request, unconsumedParams, candidateParams, "parameter")); } - if (request.hasContent() && request.isContentConsumed() == false) { + if (request.hasContent() && (request.isContentConsumed() == false && request.isFullContent())) { throw new IllegalArgumentException( "request [" + request.method() + " " + request.path() + "] does not support having a body" ); } + if (request.isStreamedContent()) { + assert action instanceof RequestBodyChunkConsumer; + var chunkConsumer = (RequestBodyChunkConsumer) action; + request.contentStream().setHandler(new HttpBody.ChunkHandler() { + @Override + public void onNext(ReleasableBytesReference chunk, boolean isLast) { + chunkConsumer.handleChunk(channel, chunk, isLast); + } + + @Override + public void close() { + chunkConsumer.streamClose(); + } + }); + } + usageCount.increment(); // execute the action action.accept(channel); @@ -180,6 +198,17 @@ protected interface RestChannelConsumer extends CheckedConsumer 0) { + if (request.hasContent()) { if (isContentTypeDisallowed(request) || handler.mediaTypesValid(request) == false) { sendContentTypeErrorMessage(request.getAllHeaderValues("Content-Type"), channel); return; @@ -454,6 +453,9 @@ private void dispatchRequest( return; } } + // TODO: estimate streamed content size for circuit breaker, + // something like http_max_chunk_size * avg_compression_ratio(for compressed content) + final int contentLength = request.isFullContent() ? 
request.contentLength() : 0; try { if (handler.canTripCircuitBreaker()) { inFlightRequestsBreaker(circuitBreakerService).addEstimateBytesAndMaybeBreak(contentLength, ""); diff --git a/server/src/main/java/org/elasticsearch/rest/RestRequest.java b/server/src/main/java/org/elasticsearch/rest/RestRequest.java index fb227f471256d..e48677f46d57a 100644 --- a/server/src/main/java/org/elasticsearch/rest/RestRequest.java +++ b/server/src/main/java/org/elasticsearch/rest/RestRequest.java @@ -24,6 +24,7 @@ import org.elasticsearch.core.RestApiVersion; import org.elasticsearch.core.TimeValue; import org.elasticsearch.core.Tuple; +import org.elasticsearch.http.HttpBody; import org.elasticsearch.http.HttpChannel; import org.elasticsearch.http.HttpRequest; import org.elasticsearch.telemetry.tracing.Traceable; @@ -304,16 +305,28 @@ public final String path() { } public boolean hasContent() { - return contentLength() > 0; + return isStreamedContent() || contentLength() > 0; } public int contentLength() { - return httpRequest.content().length(); + return httpRequest.body().asFull().bytes().length(); + } + + public boolean isFullContent() { + return httpRequest.body().isFull(); } public BytesReference content() { this.contentConsumed = true; - return httpRequest.content(); + return httpRequest.body().asFull().bytes(); + } + + public boolean isStreamedContent() { + return httpRequest.body().isStream(); + } + + public HttpBody.Stream contentStream() { + return httpRequest.body().asStream(); } /** diff --git a/server/src/main/java/org/elasticsearch/rest/action/document/RestBulkAction.java b/server/src/main/java/org/elasticsearch/rest/action/document/RestBulkAction.java index 74009401f02c9..ff87bb834f3e1 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/document/RestBulkAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/document/RestBulkAction.java @@ -9,22 +9,39 @@ package org.elasticsearch.rest.action.document; +import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.action.DocWriteRequest; import org.elasticsearch.action.bulk.BulkRequest; +import org.elasticsearch.action.bulk.BulkRequestParser; import org.elasticsearch.action.bulk.BulkShardRequest; +import org.elasticsearch.action.bulk.IncrementalBulkService; import org.elasticsearch.action.support.ActiveShardCount; import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.bytes.CompositeBytesReference; +import org.elasticsearch.common.bytes.ReleasableBytesReference; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; import org.elasticsearch.core.RestApiVersion; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; import org.elasticsearch.rest.action.RestRefCountedChunkedToXContentListener; +import org.elasticsearch.rest.action.RestToXContentListener; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; +import org.elasticsearch.transport.Transports; import java.io.IOException; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.function.Supplier; import static org.elasticsearch.rest.RestRequest.Method.POST; import 
static org.elasticsearch.rest.RestRequest.Method.PUT; @@ -40,12 +57,15 @@ */ @ServerlessScope(Scope.PUBLIC) public class RestBulkAction extends BaseRestHandler { + public static final String TYPES_DEPRECATION_MESSAGE = "[types removal] Specifying types in bulk requests is deprecated."; private final boolean allowExplicitIndex; + private final IncrementalBulkService bulkHandler; - public RestBulkAction(Settings settings) { + public RestBulkAction(Settings settings, IncrementalBulkService bulkHandler) { this.allowExplicitIndex = MULTI_ALLOW_EXPLICIT_INDEX.get(settings); + this.bulkHandler = bulkHandler; } @Override @@ -67,38 +87,199 @@ public String getName() { @Override public RestChannelConsumer prepareRequest(final RestRequest request, final NodeClient client) throws IOException { - if (request.getRestApiVersion() == RestApiVersion.V_7 && request.hasParam("type")) { - request.param("type"); + if (request.isStreamedContent() == false) { + if (request.getRestApiVersion() == RestApiVersion.V_7 && request.hasParam("type")) { + request.param("type"); + } + BulkRequest bulkRequest = new BulkRequest(); + String defaultIndex = request.param("index"); + String defaultRouting = request.param("routing"); + FetchSourceContext defaultFetchSourceContext = FetchSourceContext.parseFromRestRequest(request); + String defaultPipeline = request.param("pipeline"); + boolean defaultListExecutedPipelines = request.paramAsBoolean("list_executed_pipelines", false); + String waitForActiveShards = request.param("wait_for_active_shards"); + if (waitForActiveShards != null) { + bulkRequest.waitForActiveShards(ActiveShardCount.parseString(waitForActiveShards)); + } + Boolean defaultRequireAlias = request.paramAsBoolean(DocWriteRequest.REQUIRE_ALIAS, false); + boolean defaultRequireDataStream = request.paramAsBoolean(DocWriteRequest.REQUIRE_DATA_STREAM, false); + bulkRequest.timeout(request.paramAsTime("timeout", BulkShardRequest.DEFAULT_TIMEOUT)); + bulkRequest.setRefreshPolicy(request.param("refresh")); + bulkRequest.add( + request.requiredContent(), + defaultIndex, + defaultRouting, + defaultFetchSourceContext, + defaultPipeline, + defaultRequireAlias, + defaultRequireDataStream, + defaultListExecutedPipelines, + allowExplicitIndex, + request.getXContentType(), + request.getRestApiVersion() + ); + + return channel -> client.bulk(bulkRequest, new RestRefCountedChunkedToXContentListener<>(channel)); + } else { + if (request.getRestApiVersion() == RestApiVersion.V_7 && request.hasParam("type")) { + request.param("type"); + } + + String waitForActiveShards = request.param("wait_for_active_shards"); + TimeValue timeout = request.paramAsTime("timeout", BulkShardRequest.DEFAULT_TIMEOUT); + String refresh = request.param("refresh"); + return new ChunkHandler(allowExplicitIndex, request, () -> bulkHandler.newBulkRequest(waitForActiveShards, timeout, refresh)); + } + } + + static class ChunkHandler implements BaseRestHandler.RequestBodyChunkConsumer { + + private final boolean allowExplicitIndex; + private final RestRequest request; + + private final Map stringDeduplicator = new HashMap<>(); + private final String defaultIndex; + private final String defaultRouting; + private final FetchSourceContext defaultFetchSourceContext; + private final String defaultPipeline; + private final boolean defaultListExecutedPipelines; + private final Boolean defaultRequireAlias; + private final boolean defaultRequireDataStream; + private final BulkRequestParser parser; + private final Supplier handlerSupplier; + private 
IncrementalBulkService.Handler handler; + + private volatile RestChannel restChannel; + private boolean shortCircuited; + private int bytesParsed = 0; + private final ArrayDeque unParsedChunks = new ArrayDeque<>(4); + private final ArrayList> items = new ArrayList<>(4); + + ChunkHandler(boolean allowExplicitIndex, RestRequest request, Supplier handlerSupplier) { + this.allowExplicitIndex = allowExplicitIndex; + this.request = request; + this.defaultIndex = request.param("index"); + this.defaultRouting = request.param("routing"); + this.defaultFetchSourceContext = FetchSourceContext.parseFromRestRequest(request); + this.defaultPipeline = request.param("pipeline"); + this.defaultListExecutedPipelines = request.paramAsBoolean("list_executed_pipelines", false); + this.defaultRequireAlias = request.paramAsBoolean(DocWriteRequest.REQUIRE_ALIAS, false); + this.defaultRequireDataStream = request.paramAsBoolean(DocWriteRequest.REQUIRE_DATA_STREAM, false); + // TODO: Fix type deprecation logging + this.parser = new BulkRequestParser(false, request.getRestApiVersion()); + this.handlerSupplier = handlerSupplier; } - BulkRequest bulkRequest = new BulkRequest(); - String defaultIndex = request.param("index"); - String defaultRouting = request.param("routing"); - FetchSourceContext defaultFetchSourceContext = FetchSourceContext.parseFromRestRequest(request); - String defaultPipeline = request.param("pipeline"); - boolean defaultListExecutedPipelines = request.paramAsBoolean("list_executed_pipelines", false); - String waitForActiveShards = request.param("wait_for_active_shards"); - if (waitForActiveShards != null) { - bulkRequest.waitForActiveShards(ActiveShardCount.parseString(waitForActiveShards)); + + @Override + public void accept(RestChannel restChannel) { + this.restChannel = restChannel; + this.handler = handlerSupplier.get(); + request.contentStream().next(); } - Boolean defaultRequireAlias = request.paramAsBoolean(DocWriteRequest.REQUIRE_ALIAS, false); - boolean defaultRequireDataStream = request.paramAsBoolean(DocWriteRequest.REQUIRE_DATA_STREAM, false); - bulkRequest.timeout(request.paramAsTime("timeout", BulkShardRequest.DEFAULT_TIMEOUT)); - bulkRequest.setRefreshPolicy(request.param("refresh")); - bulkRequest.add( - request.requiredContent(), - defaultIndex, - defaultRouting, - defaultFetchSourceContext, - defaultPipeline, - defaultRequireAlias, - defaultRequireDataStream, - defaultListExecutedPipelines, - allowExplicitIndex, - request.getXContentType(), - request.getRestApiVersion() - ); - return channel -> client.bulk(bulkRequest, new RestRefCountedChunkedToXContentListener<>(channel)); + @Override + public void handleChunk(RestChannel channel, ReleasableBytesReference chunk, boolean isLast) { + assert handler != null; + assert channel == restChannel; + if (shortCircuited) { + chunk.close(); + return; + } + + final BytesReference data; + int bytesConsumed; + if (chunk.length() == 0) { + chunk.close(); + bytesConsumed = 0; + } else { + try { + unParsedChunks.add(chunk); + + if (unParsedChunks.size() > 1) { + data = CompositeBytesReference.of(unParsedChunks.toArray(new ReleasableBytesReference[0])); + } else { + data = chunk; + } + + // TODO: Check that the behavior here vs. 
globalRouting, globalPipeline, globalRequireAlias, globalRequireDatsStream in + // BulkRequest#add is fine + bytesConsumed = parser.incrementalParse( + data, + defaultIndex, + defaultRouting, + defaultFetchSourceContext, + defaultPipeline, + defaultRequireAlias, + defaultRequireDataStream, + defaultListExecutedPipelines, + allowExplicitIndex, + request.getXContentType(), + (request, type) -> items.add(request), + items::add, + items::add, + isLast == false, + stringDeduplicator + ); + bytesParsed += bytesConsumed; + + } catch (Exception e) { + shortCircuit(); + new RestToXContentListener<>(channel).onFailure( + new ElasticsearchParseException("could not parse bulk request body", e) + ); + return; + } + } + + final ArrayList releasables = accountParsing(bytesConsumed); + if (isLast) { + assert unParsedChunks.isEmpty(); + if (bytesParsed == 0) { + shortCircuit(); + new RestToXContentListener<>(channel).onFailure(new ElasticsearchParseException("request body is required")); + } else { + assert channel != null; + ArrayList> toPass = new ArrayList<>(items); + items.clear(); + handler.lastItems(toPass, () -> Releasables.close(releasables), new RestRefCountedChunkedToXContentListener<>(channel)); + } + } else if (items.isEmpty() == false) { + ArrayList> toPass = new ArrayList<>(items); + items.clear(); + handler.addItems(toPass, () -> Releasables.close(releasables), () -> request.contentStream().next()); + } else { + assert releasables.isEmpty(); + request.contentStream().next(); + } + } + + @Override + public void streamClose() { + assert Transports.assertTransportThread(); + shortCircuit(); + } + + private void shortCircuit() { + shortCircuited = true; + Releasables.close(handler); + Releasables.close(unParsedChunks); + unParsedChunks.clear(); + } + + private ArrayList accountParsing(int bytesConsumed) { + ArrayList releasables = new ArrayList<>(unParsedChunks.size()); + while (bytesConsumed > 0) { + ReleasableBytesReference reference = unParsedChunks.removeFirst(); + releasables.add(reference); + if (bytesConsumed >= reference.length()) { + bytesConsumed -= reference.length(); + } else { + unParsedChunks.addFirst(reference.retainedSlice(bytesConsumed, reference.length() - bytesConsumed)); + bytesConsumed = 0; + } + } + return releasables; + } } @Override diff --git a/server/src/test/java/org/elasticsearch/action/ActionModuleTests.java b/server/src/test/java/org/elasticsearch/action/ActionModuleTests.java index 4ed493a94e20e..871062a687429 100644 --- a/server/src/test/java/org/elasticsearch/action/ActionModuleTests.java +++ b/server/src/test/java/org/elasticsearch/action/ActionModuleTests.java @@ -10,6 +10,7 @@ package org.elasticsearch.action; import org.elasticsearch.action.admin.cluster.node.info.TransportNodesInfoAction; +import org.elasticsearch.action.bulk.IncrementalBulkService; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.TransportAction; import org.elasticsearch.client.internal.node.NodeClient; @@ -129,7 +130,8 @@ public void testSetupRestHandlerContainsKnownBuiltin() { mock(ClusterService.class), null, List.of(), - RestExtension.allowAll() + RestExtension.allowAll(), + new IncrementalBulkService(null, null, new ThreadContext(Settings.EMPTY)) ); actionModule.initRestHandlers(null, null); // At this point the easiest way to confirm that a handler is loaded is to try to register another one on top of it and to fail @@ -193,7 +195,8 @@ public String getName() { mock(ClusterService.class), null, List.of(), - RestExtension.allowAll() + 
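
To make the byte accounting in accountParsing above concrete, a small worked example (the chunk sizes are made up):

// Suppose three unparsed chunks of 10 bytes each are queued and the parser reports 25 bytes consumed:
//   chunk1 (10 bytes): 25 >= 10 -> moved to releasables, 15 bytes still to account for
//   chunk2 (10 bytes): 15 >= 10 -> moved to releasables,  5 bytes still to account for
//   chunk3 (10 bytes):  5 <  10 -> moved to releasables, but chunk3.retainedSlice(5, 5) is put back
//                                  at the head of the queue so the unparsed 5-byte tail survives
// The returned releasables are closed only after the Handler is done with the items built from those bytes.
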
RestExtension.allowAll(), + new IncrementalBulkService(null, null, new ThreadContext(Settings.EMPTY)) ); Exception e = expectThrows(IllegalArgumentException.class, () -> actionModule.initRestHandlers(null, null)); assertThat(e.getMessage(), startsWith("Cannot replace existing handler for [/_nodes] for method: GET")); @@ -250,7 +253,8 @@ public List getRestHandlers( mock(ClusterService.class), null, List.of(), - RestExtension.allowAll() + RestExtension.allowAll(), + new IncrementalBulkService(null, null, new ThreadContext(Settings.EMPTY)) ); actionModule.initRestHandlers(null, null); // At this point the easiest way to confirm that a handler is loaded is to try to register another one on top of it and to fail @@ -300,7 +304,8 @@ public void test3rdPartyHandlerIsNotInstalled() { mock(ClusterService.class), null, List.of(), - RestExtension.allowAll() + RestExtension.allowAll(), + new IncrementalBulkService(null, null, new ThreadContext(Settings.EMPTY)) ) ); assertThat( @@ -341,7 +346,8 @@ public void test3rdPartyRestControllerIsNotInstalled() { mock(ClusterService.class), null, List.of(), - RestExtension.allowAll() + RestExtension.allowAll(), + new IncrementalBulkService(null, null, new ThreadContext(Settings.EMPTY)) ) ); assertThat( diff --git a/server/src/test/java/org/elasticsearch/http/AbstractHttpServerTransportTests.java b/server/src/test/java/org/elasticsearch/http/AbstractHttpServerTransportTests.java index c842dbd294b65..981eae9d60694 100644 --- a/server/src/test/java/org/elasticsearch/http/AbstractHttpServerTransportTests.java +++ b/server/src/test/java/org/elasticsearch/http/AbstractHttpServerTransportTests.java @@ -15,6 +15,7 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionModule; +import org.elasticsearch.action.bulk.IncrementalBulkService; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.UUIDs; import org.elasticsearch.common.bytes.BytesArray; @@ -1177,7 +1178,8 @@ public Collection getRestHeaders() { mock(ClusterService.class), null, List.of(), - RestExtension.allowAll() + RestExtension.allowAll(), + new IncrementalBulkService(null, null, new ThreadContext(Settings.EMPTY)) ); } diff --git a/server/src/test/java/org/elasticsearch/http/HttpClientStatsTrackerTests.java b/server/src/test/java/org/elasticsearch/http/HttpClientStatsTrackerTests.java index 7de283bab2ea1..a1129e4a717fd 100644 --- a/server/src/test/java/org/elasticsearch/http/HttpClientStatsTrackerTests.java +++ b/server/src/test/java/org/elasticsearch/http/HttpClientStatsTrackerTests.java @@ -120,7 +120,7 @@ public void testStatsCollection() { assertThat(clientStats.remoteAddress(), equalTo(NetworkAddress.format(httpChannel.getRemoteAddress()))); assertThat(clientStats.lastUri(), equalTo(httpRequest1.uri())); assertThat(clientStats.requestCount(), equalTo(1L)); - requestLength += httpRequest1.content().length(); + requestLength += httpRequest1.body().asFull().bytes().length(); assertThat(clientStats.requestSizeBytes(), equalTo(requestLength)); assertThat(clientStats.closedTimeMillis(), equalTo(-1L)); assertThat(clientStats.openedTimeMillis(), equalTo(openTimeMillis)); @@ -150,7 +150,7 @@ public void testStatsCollection() { assertThat(clientStats.remoteAddress(), equalTo(NetworkAddress.format(httpChannel.getRemoteAddress()))); assertThat(clientStats.lastUri(), equalTo(httpRequest2.uri())); assertThat(clientStats.requestCount(), equalTo(2L)); - requestLength += httpRequest2.content().length(); + 
requestLength += httpRequest2.body().asFull().bytes().length(); assertThat(clientStats.requestSizeBytes(), equalTo(requestLength)); assertThat(clientStats.closedTimeMillis(), equalTo(-1L)); assertThat(clientStats.openedTimeMillis(), equalTo(openTimeMillis)); diff --git a/server/src/test/java/org/elasticsearch/http/TestHttpRequest.java b/server/src/test/java/org/elasticsearch/http/TestHttpRequest.java index d6cf010a90471..8cd61453a3391 100644 --- a/server/src/test/java/org/elasticsearch/http/TestHttpRequest.java +++ b/server/src/test/java/org/elasticsearch/http/TestHttpRequest.java @@ -9,7 +9,6 @@ package org.elasticsearch.http; -import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestRequest; @@ -49,8 +48,8 @@ public String uri() { } @Override - public BytesReference content() { - return BytesArray.EMPTY; + public HttpBody body() { + return HttpBody.empty(); } @Override diff --git a/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java b/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java index 8853665ae2641..1d946681661e7 100644 --- a/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java +++ b/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java @@ -27,6 +27,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.RestApiVersion; +import org.elasticsearch.http.HttpBody; import org.elasticsearch.http.HttpHeadersValidationException; import org.elasticsearch.http.HttpInfo; import org.elasticsearch.http.HttpRequest; @@ -831,11 +832,11 @@ public String uri() { } @Override - public BytesReference content() { + public HttpBody body() { if (hasContent) { - return new BytesArray("test"); + return HttpBody.fromBytesReference(new BytesArray("test")); } - return BytesArray.EMPTY; + return HttpBody.empty(); } @Override diff --git a/server/src/test/java/org/elasticsearch/rest/RestRequestTests.java b/server/src/test/java/org/elasticsearch/rest/RestRequestTests.java index 17c66b888b320..8a0ca5ba6c8a5 100644 --- a/server/src/test/java/org/elasticsearch/rest/RestRequestTests.java +++ b/server/src/test/java/org/elasticsearch/rest/RestRequestTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.core.CheckedConsumer; +import org.elasticsearch.http.HttpBody; import org.elasticsearch.http.HttpChannel; import org.elasticsearch.http.HttpRequest; import org.elasticsearch.test.ESTestCase; @@ -87,7 +88,7 @@ public void testContentLengthDoesNotConsumesContent() { private void runConsumesContentTest(final CheckedConsumer consumer, final boolean expected) { final HttpRequest httpRequest = mock(HttpRequest.class); when(httpRequest.uri()).thenReturn(""); - when(httpRequest.content()).thenReturn(new BytesArray(new byte[1])); + when(httpRequest.body()).thenReturn(HttpBody.fromBytesReference(new BytesArray(new byte[1]))); when(httpRequest.getHeaders()).thenReturn( Collections.singletonMap("Content-Type", Collections.singletonList(randomFrom("application/json", "application/x-ndjson"))) ); diff --git a/server/src/test/java/org/elasticsearch/rest/action/document/RestBulkActionTests.java b/server/src/test/java/org/elasticsearch/rest/action/document/RestBulkActionTests.java index 772ff0efb1218..d3cd6dd9ca420 100644 --- 
a/server/src/test/java/org/elasticsearch/rest/action/document/RestBulkActionTests.java +++ b/server/src/test/java/org/elasticsearch/rest/action/document/RestBulkActionTests.java @@ -11,23 +11,37 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.DocWriteRequest; import org.elasticsearch.action.bulk.BulkRequest; import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.bulk.IncrementalBulkService; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.update.UpdateRequest; +import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.ReleasableBytesReference; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.http.HttpBody; import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.IndexingPressure; import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.client.NoOpNodeClient; +import org.elasticsearch.test.rest.FakeRestChannel; import org.elasticsearch.test.rest.FakeRestRequest; import org.elasticsearch.xcontent.XContentType; +import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.mockito.Mockito.mock; @@ -51,7 +65,10 @@ public void bulk(BulkRequest request, ActionListener listener) { }; final Map params = new HashMap<>(); params.put("pipeline", "timestamps"); - new RestBulkAction(settings(IndexVersion.current()).build()).handleRequest( + new RestBulkAction( + settings(IndexVersion.current()).build(), + new IncrementalBulkService(mock(Client.class), mock(IndexingPressure.class), new ThreadContext(Settings.EMPTY)) + ).handleRequest( new FakeRestRequest.Builder(xContentRegistry()).withPath("my_index/_bulk").withParams(params).withContent(new BytesArray(""" {"index":{"_id":"1"}} {"field1":"val1"} @@ -83,7 +100,10 @@ public void bulk(BulkRequest request, ActionListener listener) { }; Map params = new HashMap<>(); { - new RestBulkAction(settings(IndexVersion.current()).build()).handleRequest( + new RestBulkAction( + settings(IndexVersion.current()).build(), + new IncrementalBulkService(mock(Client.class), mock(IndexingPressure.class), new ThreadContext(Settings.EMPTY)) + ).handleRequest( new FakeRestRequest.Builder(xContentRegistry()).withPath("my_index/_bulk") .withParams(params) .withContent(new BytesArray(""" @@ -104,7 +124,10 @@ public void bulk(BulkRequest request, ActionListener listener) { { params.put("list_executed_pipelines", "true"); bulkCalled.set(false); - new RestBulkAction(settings(IndexVersion.current()).build()).handleRequest( + new RestBulkAction( + settings(IndexVersion.current()).build(), + new IncrementalBulkService(mock(Client.class), mock(IndexingPressure.class), new ThreadContext(Settings.EMPTY)) + ).handleRequest( new FakeRestRequest.Builder(xContentRegistry()).withPath("my_index/_bulk") .withParams(params) .withContent(new BytesArray(""" @@ -124,7 +147,10 @@ public void bulk(BulkRequest request, ActionListener listener) { } { 
bulkCalled.set(false); - new RestBulkAction(settings(IndexVersion.current()).build()).handleRequest( + new RestBulkAction( + settings(IndexVersion.current()).build(), + new IncrementalBulkService(mock(Client.class), mock(IndexingPressure.class), new ThreadContext(Settings.EMPTY)) + ).handleRequest( new FakeRestRequest.Builder(xContentRegistry()).withPath("my_index/_bulk") .withParams(params) .withContent(new BytesArray(""" @@ -145,7 +171,10 @@ public void bulk(BulkRequest request, ActionListener listener) { { params.remove("list_executed_pipelines"); bulkCalled.set(false); - new RestBulkAction(settings(IndexVersion.current()).build()).handleRequest( + new RestBulkAction( + settings(IndexVersion.current()).build(), + new IncrementalBulkService(mock(Client.class), mock(IndexingPressure.class), new ThreadContext(Settings.EMPTY)) + ).handleRequest( new FakeRestRequest.Builder(xContentRegistry()).withPath("my_index/_bulk") .withParams(params) .withContent(new BytesArray(""" @@ -165,4 +194,98 @@ public void bulk(BulkRequest request, ActionListener listener) { } } } + + public void testIncrementalParsing() { + ArrayList> docs = new ArrayList<>(); + AtomicBoolean isLast = new AtomicBoolean(false); + AtomicBoolean next = new AtomicBoolean(false); + + FakeRestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withPath("my_index/_bulk") + .withMethod(RestRequest.Method.POST) + .withBody(new HttpBody.Stream() { + @Override + public void close() {} + + @Override + public ChunkHandler handler() { + return null; + } + + @Override + public void addTracingHandler(ChunkHandler chunkHandler) {} + + @Override + public void setHandler(ChunkHandler chunkHandler) {} + + @Override + public void next() { + next.set(true); + } + }) + .withHeaders(Map.of("Content-Type", Collections.singletonList("application/json"))) + .build(); + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + + RestBulkAction.ChunkHandler chunkHandler = new RestBulkAction.ChunkHandler( + true, + request, + () -> new IncrementalBulkService.Handler(null, new ThreadContext(Settings.EMPTY), null, null, null, null) { + + @Override + public void addItems(List> items, Releasable releasable, Runnable nextItems) { + releasable.close(); + docs.addAll(items); + } + + @Override + public void lastItems(List> items, Releasable releasable, ActionListener listener) { + releasable.close(); + docs.addAll(items); + isLast.set(true); + } + } + ); + + chunkHandler.accept(channel); + ReleasableBytesReference r1 = new ReleasableBytesReference(new BytesArray("{\"index\":{\"_index\":\"index_name\"}}\n"), () -> {}); + chunkHandler.handleChunk(channel, r1, false); + assertThat(docs, empty()); + assertTrue(next.get()); + next.set(false); + assertFalse(isLast.get()); + + ReleasableBytesReference r2 = new ReleasableBytesReference(new BytesArray("{\"field\":1}"), () -> {}); + chunkHandler.handleChunk(channel, r2, false); + assertThat(docs, empty()); + assertTrue(next.get()); + next.set(false); + assertFalse(isLast.get()); + assertTrue(r1.hasReferences()); + assertTrue(r2.hasReferences()); + + ReleasableBytesReference r3 = new ReleasableBytesReference(new BytesArray("\n{\"delete\":"), () -> {}); + chunkHandler.handleChunk(channel, r3, false); + assertThat(docs, hasSize(1)); + assertFalse(next.get()); + assertFalse(isLast.get()); + assertFalse(r1.hasReferences()); + assertFalse(r2.hasReferences()); + assertTrue(r3.hasReferences()); + + ReleasableBytesReference r4 = new ReleasableBytesReference(new 
BytesArray("{\"_index\":\"test\",\"_id\":\"2\"}}"), () -> {}); + chunkHandler.handleChunk(channel, r4, false); + assertThat(docs, hasSize(1)); + assertTrue(next.get()); + next.set(false); + assertFalse(isLast.get()); + + ReleasableBytesReference r5 = new ReleasableBytesReference(new BytesArray("\n"), () -> {}); + chunkHandler.handleChunk(channel, r5, true); + assertThat(docs, hasSize(2)); + assertFalse(next.get()); + assertTrue(isLast.get()); + assertFalse(r3.hasReferences()); + assertFalse(r4.hasReferences()); + assertFalse(r5.hasReferences()); + } } diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java index 9132474fa9415..92e480aff3bc9 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java @@ -26,6 +26,7 @@ import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.DocWriteRequest; import org.elasticsearch.action.DocWriteResponse; import org.elasticsearch.action.admin.cluster.allocation.ClusterAllocationExplainRequest; import org.elasticsearch.action.admin.cluster.allocation.ClusterAllocationExplainResponse; @@ -48,6 +49,8 @@ import org.elasticsearch.action.admin.indices.template.put.PutIndexTemplateRequestBuilder; import org.elasticsearch.action.bulk.BulkRequestBuilder; import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.bulk.IncrementalBulkService; +import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.action.ingest.DeletePipelineRequest; import org.elasticsearch.action.ingest.DeletePipelineTransportAction; @@ -122,6 +125,7 @@ import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexModule; import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.IndexingPressure; import org.elasticsearch.index.MergePolicyConfig; import org.elasticsearch.index.MergeSchedulerConfig; import org.elasticsearch.index.MockEngineFactoryPlugin; @@ -192,6 +196,7 @@ import java.util.Random; import java.util.Set; import java.util.concurrent.Callable; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; @@ -1776,11 +1781,49 @@ public void indexRandom(boolean forceRefresh, boolean dummyDocuments, boolean ma ); logger.info("Index [{}] docs async: [{}] bulk: [{}] partitions [{}]", builders.size(), false, true, partition.size()); for (List segmented : partition) { - BulkRequestBuilder bulkBuilder = client().prepareBulk(); - for (IndexRequestBuilder indexRequestBuilder : segmented) { - bulkBuilder.add(indexRequestBuilder); + BulkResponse actionGet; + if (randomBoolean()) { + BulkRequestBuilder bulkBuilder = client().prepareBulk(); + for (IndexRequestBuilder indexRequestBuilder : segmented) { + bulkBuilder.add(indexRequestBuilder); + } + actionGet = bulkBuilder.get(); + } else { + IncrementalBulkService bulkService = internalCluster().getInstance(IncrementalBulkService.class); + IncrementalBulkService.Handler handler = bulkService.newBulkRequest(); + + ConcurrentLinkedQueue queue = new ConcurrentLinkedQueue<>(); + segmented.forEach(b -> queue.add(b.request())); + + PlainActionFuture future = new PlainActionFuture<>(); + 
AtomicInteger runs = new AtomicInteger(0); + Runnable r = new Runnable() { + + @Override + public void run() { + int toRemove = Math.min(randomIntBetween(5, 10), queue.size()); + ArrayList> docs = new ArrayList<>(); + for (int i = 0; i < toRemove; i++) { + docs.add(queue.poll()); + } + + if (queue.isEmpty()) { + handler.lastItems(docs, () -> {}, future); + } else { + handler.addItems(docs, () -> {}, () -> { + // Every 10 runs dispatch to new thread to prevent stackoverflow + if (runs.incrementAndGet() % 10 == 0) { + new Thread(this).start(); + } else { + this.run(); + } + }); + } + } + }; + r.run(); + actionGet = future.actionGet(); } - BulkResponse actionGet = bulkBuilder.get(); assertThat(actionGet.hasFailures() ? actionGet.buildFailureMessage() : "", actionGet.hasFailures(), equalTo(false)); } } @@ -2061,6 +2104,9 @@ protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { TransportSearchAction.DEFAULT_PRE_FILTER_SHARD_SIZE.getKey(), randomFrom(1, 2, SearchRequest.DEFAULT_PRE_FILTER_SHARD_SIZE) ); + if (randomBoolean()) { + builder.put(IndexingPressure.SPLIT_BULK_THRESHOLD.getKey(), randomFrom("256B", "1KB", "64KB")); + } return builder.build(); } diff --git a/test/framework/src/main/java/org/elasticsearch/test/rest/FakeRestRequest.java b/test/framework/src/main/java/org/elasticsearch/test/rest/FakeRestRequest.java index ab3543d8f2bb7..9ddcf39d24d98 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/rest/FakeRestRequest.java +++ b/test/framework/src/main/java/org/elasticsearch/test/rest/FakeRestRequest.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.http.HttpBody; import org.elasticsearch.http.HttpChannel; import org.elasticsearch.http.HttpRequest; import org.elasticsearch.http.HttpResponse; @@ -54,24 +55,24 @@ public static class FakeHttpRequest implements HttpRequest { private final Method method; private final String uri; - private final BytesReference content; + private final HttpBody content; private final Map> headers; private final Exception inboundException; public FakeHttpRequest(Method method, String uri, BytesReference content, Map> headers) { - this(method, uri, content, headers, null); + this(method, uri, content == null ? HttpBody.empty() : HttpBody.fromBytesReference(content), headers, null); } private FakeHttpRequest( Method method, String uri, - BytesReference content, + HttpBody content, Map> headers, Exception inboundException ) { this.method = method; this.uri = uri; - this.content = content == null ? 
BytesArray.EMPTY : content; + this.content = content; this.headers = headers; this.inboundException = inboundException; } @@ -87,7 +88,7 @@ public String uri() { } @Override - public BytesReference content() { + public HttpBody body() { return content; } @@ -195,7 +196,7 @@ public static class Builder { private Map params = new HashMap<>(); - private BytesReference content = BytesArray.EMPTY; + private HttpBody content = HttpBody.empty(); private String path = "/"; @@ -221,13 +222,18 @@ public Builder withParams(Map params) { } public Builder withContent(BytesReference contentBytes, XContentType xContentType) { - this.content = contentBytes; + this.content = HttpBody.fromBytesReference(contentBytes); if (xContentType != null) { headers.put("Content-Type", Collections.singletonList(xContentType.mediaType())); } return this; } + public Builder withBody(HttpBody body) { + this.content = body; + return this; + } + public Builder withPath(String path) { this.path = path; return this; diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/SecurityTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/SecurityTests.java index a07a7a3a5dd27..8d580f10e5137 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/SecurityTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/SecurityTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionModule; import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.bulk.IncrementalBulkService; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; @@ -822,7 +823,8 @@ public void testSecurityRestHandlerInterceptorCanBeInstalled() throws IllegalAcc mock(ClusterService.class), null, List.of(), - RestExtension.allowAll() + RestExtension.allowAll(), + new IncrementalBulkService(null, null, new ThreadContext(Settings.EMPTY)) ); actionModule.initRestHandlers(null, null); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrailTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrailTests.java index cb524a48d0ec7..5adc1e351931d 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrailTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrailTests.java @@ -2644,7 +2644,7 @@ public void testAuthenticationSuccessRest() throws Exception { checkedFields.put(LoggingAuditTrail.REQUEST_ID_FIELD_NAME, requestId); checkedFields.put(LoggingAuditTrail.URL_PATH_FIELD_NAME, "_uri"); if (includeRequestBody && Strings.hasLength(request.content())) { - checkedFields.put(LoggingAuditTrail.REQUEST_BODY_FIELD_NAME, request.getHttpRequest().content().utf8ToString()); + checkedFields.put(LoggingAuditTrail.REQUEST_BODY_FIELD_NAME, request.getHttpRequest().body().asFull().bytes().utf8ToString()); } if (params.isEmpty() == false) { checkedFields.put(LoggingAuditTrail.URL_QUERY_FIELD_NAME, "foo=bar&evac=true");
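
For tests that do not need a streaming body, the new FakeRestRequest.Builder#withBody hook also accepts a fully buffered HttpBody; a purely illustrative example (content and headers are made up, not taken from the change):

// Illustrative only: a fully buffered body via the new builder hook.
FakeRestRequest request = new FakeRestRequest.Builder(xContentRegistry())
    .withPath("/my_index/_bulk")
    .withMethod(RestRequest.Method.POST)
    .withBody(HttpBody.fromBytesReference(new BytesArray("{\"index\":{}}\n{\"field\":1}\n")))
    .withHeaders(Map.of("Content-Type", List.of("application/json")))
    .build();
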