diff --git a/server/src/main/java/org/elasticsearch/transport/NetworkMessage.java b/server/src/main/java/org/elasticsearch/transport/NetworkMessage.java deleted file mode 100644 index 8a0be7627ef9a..0000000000000 --- a/server/src/main/java/org/elasticsearch/transport/NetworkMessage.java +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". - */ -package org.elasticsearch.transport; - -import org.elasticsearch.TransportVersion; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.util.concurrent.ThreadContext; - -/** - * Represents a transport message sent over the network. Subclasses implement serialization and - * deserialization. - */ -public abstract class NetworkMessage { - - protected final TransportVersion version; - protected final Writeable threadContext; - protected final long requestId; - protected final byte status; - protected final Compression.Scheme compressionScheme; - - NetworkMessage( - ThreadContext threadContext, - TransportVersion version, - byte status, - long requestId, - Compression.Scheme compressionScheme - ) { - this.threadContext = threadContext.captureAsWriteable(); - this.version = version; - this.requestId = requestId; - this.compressionScheme = compressionScheme; - if (this.compressionScheme != null) { - this.status = TransportStatus.setCompress(status); - } else { - this.status = status; - } - } - - boolean isCompress() { - return TransportStatus.isCompress(status); - } - - boolean isHandshake() { - return TransportStatus.isHandshake(status); - } - - boolean isError() { - return TransportStatus.isError(status); - } -} diff --git a/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java b/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java index 88686e7cf63e1..dcc4e08b52c20 100644 --- a/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java +++ b/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java @@ -17,25 +17,35 @@ import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.bytes.CompositeBytesReference; import org.elasticsearch.common.bytes.ReleasableBytesReference; +import org.elasticsearch.common.compress.CompressorFactory; +import org.elasticsearch.common.io.stream.OutputStreamStreamOutput; import org.elasticsearch.common.io.stream.RecyclerBytesStreamOutput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.network.CloseableChannel; import org.elasticsearch.common.network.HandlingTimeTracker; import org.elasticsearch.common.recycler.Recycler; import org.elasticsearch.common.transport.NetworkExceptionHelper; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.RefCounted; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; +import org.elasticsearch.core.Streams; import org.elasticsearch.core.TimeValue; import org.elasticsearch.core.UpdateForV10; import org.elasticsearch.threadpool.ThreadPool; import java.io.IOException; +import java.util.function.Supplier; import static org.elasticsearch.core.Strings.format; -final class OutboundHandler { +public final class OutboundHandler { private static final Logger logger = LogManager.getLogger(OutboundHandler.class); @@ -85,7 +95,7 @@ void setSlowLogThreshold(TimeValue slowLogThreshold) { * thread. */ void sendBytes(TcpChannel channel, BytesReference bytes, ActionListener listener) { - internalSend(channel, bytes, null, listener); + internalSend(channel, bytes, () -> "raw bytes", listener); } /** @@ -104,18 +114,14 @@ void sendRequest( final boolean isHandshake ) throws IOException, TransportException { assert assertValidTransportVersion(transportVersion); - final OutboundMessage.Request message = new OutboundMessage.Request( - threadPool.getThreadContext(), - request, - transportVersion, + sendMessage( + channel, action, + request, requestId, isHandshake, - compressionScheme - ); - sendMessage( - channel, - message, + compressionScheme, + transportVersion, ResponseStatsConsumer.NONE, () -> messageListener.onRequestSent(node, requestId, action, request, options) ); @@ -138,17 +144,19 @@ void sendResponse( final ResponseStatsConsumer responseStatsConsumer ) { assert assertValidTransportVersion(transportVersion); - OutboundMessage.Response message = new OutboundMessage.Response( - threadPool.getThreadContext(), - response, - transportVersion, - requestId, - isHandshake, - compressionScheme - ); assert response.hasReferences(); try { - sendMessage(channel, message, responseStatsConsumer, () -> messageListener.onResponseSent(requestId, action)); + sendMessage( + channel, + null, + response, + requestId, + isHandshake, + compressionScheme, + transportVersion, + responseStatsConsumer, + () -> messageListener.onResponseSent(requestId, action) + ); } catch (Exception ex) { if (isHandshake) { logger.error( @@ -178,16 +186,19 @@ void sendErrorResponse( final Exception error ) { assert assertValidTransportVersion(transportVersion); - OutboundMessage.Response message = new OutboundMessage.Response( - threadPool.getThreadContext(), - new RemoteTransportException(nodeName, channel.getLocalAddress(), action, error), - transportVersion, - requestId, - false, - null - ); + var msg = new RemoteTransportException(nodeName, channel.getLocalAddress(), action, error); try { - sendMessage(channel, message, responseStatsConsumer, () -> messageListener.onResponseSent(requestId, action, error)); + sendMessage( + channel, + null, + msg, + requestId, + false, + null, + transportVersion, + responseStatsConsumer, + () -> messageListener.onResponseSent(requestId, action, error) + ); } catch (Exception sendException) { sendException.addSuppressed(error); logger.error(() -> format("Failed to send error response on channel [%s], closing channel", channel), sendException); @@ -197,27 +208,35 @@ void sendErrorResponse( private void sendMessage( TcpChannel channel, - OutboundMessage networkMessage, + @Nullable String requestAction, + Writeable writeable, + long requestId, + boolean isHandshake, + Compression.Scheme compressionScheme, + TransportVersion version, ResponseStatsConsumer responseStatsConsumer, Releasable onAfter ) throws IOException { - final RecyclerBytesStreamOutput byteStreamOutput; - boolean bufferSuccess = false; - try { - byteStreamOutput = new RecyclerBytesStreamOutput(recycler); - bufferSuccess = true; - } finally { - if (bufferSuccess == false) { - Releasables.closeExpectNoException(onAfter); - } - } + compressionScheme = writeable instanceof BytesTransportRequest ? null : compressionScheme; final BytesReference message; boolean serializeSuccess = false; + final boolean isError = writeable instanceof RemoteTransportException; + final RecyclerBytesStreamOutput byteStreamOutput = new RecyclerBytesStreamOutput(recycler); try { - message = networkMessage.serialize(byteStreamOutput); + message = serialize( + requestAction, + requestId, + isHandshake, + version, + isError, + compressionScheme, + writeable, + threadPool.getThreadContext(), + byteStreamOutput + ); serializeSuccess = true; } catch (Exception e) { - logger.warn(() -> "failed to serialize outbound message [" + networkMessage + "]", e); + logger.warn(() -> "failed to serialize outbound message [" + writeable + "]", e); throw e; } finally { if (serializeSuccess == false) { @@ -225,10 +244,14 @@ private void sendMessage( } } responseStatsConsumer.addResponseStats(message.length()); + final var responseType = writeable.getClass(); + final boolean compress = compressionScheme != null; internalSend( channel, message, - networkMessage, + requestAction == null + ? () -> "Response{" + requestId + "}{" + isError + "}{" + compress + "}{" + isHandshake + "}{" + responseType + "}" + : () -> "Request{" + requestAction + "}{" + requestId + "}{" + isError + "}{" + compress + "}{" + isHandshake + "}", ActionListener.releasing( message instanceof ReleasableBytesReference r ? Releasables.wrap(byteStreamOutput, onAfter, r) @@ -237,10 +260,105 @@ private void sendMessage( ); } + // public for tests + public static BytesReference serialize( + @Nullable String requestAction, + long requestId, + boolean isHandshake, + TransportVersion version, + boolean isError, + Compression.Scheme compressionScheme, + Writeable writeable, + ThreadContext threadContext, + RecyclerBytesStreamOutput byteStreamOutput + ) throws IOException { + assert byteStreamOutput.position() == 0; + byteStreamOutput.setTransportVersion(version); + byteStreamOutput.skip(TcpHeader.HEADER_SIZE); + threadContext.writeTo(byteStreamOutput); + if (requestAction != null) { + if (version.before(TransportVersions.V_8_0_0)) { + // empty features array + byteStreamOutput.writeStringArray(Strings.EMPTY_ARRAY); + } + byteStreamOutput.writeString(requestAction); + } + + final int variableHeaderLength = Math.toIntExact(byteStreamOutput.position() - TcpHeader.HEADER_SIZE); + BytesReference message = serializeMessageBody(writeable, compressionScheme, version, byteStreamOutput); + byte status = 0; + if (requestAction == null) { + status = TransportStatus.setResponse(status); + } + if (isHandshake) { + status = TransportStatus.setHandshake(status); + } + if (isError) { + status = TransportStatus.setError(status); + } + if (compressionScheme != null) { + status = TransportStatus.setCompress(status); + } + byteStreamOutput.seek(0); + TcpHeader.writeHeader(byteStreamOutput, requestId, status, version, message.length() - TcpHeader.HEADER_SIZE, variableHeaderLength); + return message; + } + + private static BytesReference serializeMessageBody( + Writeable writeable, + Compression.Scheme compressionScheme, + TransportVersion version, + RecyclerBytesStreamOutput byteStreamOutput + ) throws IOException { + // The compressible bytes stream will not close the underlying bytes stream + final StreamOutput stream = compressionScheme != null ? wrapCompressed(compressionScheme, byteStreamOutput) : byteStreamOutput; + final ReleasableBytesReference zeroCopyBuffer; + try { + stream.setTransportVersion(version); + if (writeable instanceof BytesTransportRequest bRequest) { + bRequest.writeThin(stream); + zeroCopyBuffer = bRequest.bytes; + } else if (writeable instanceof RemoteTransportException remoteTransportException) { + stream.writeException(remoteTransportException); + zeroCopyBuffer = ReleasableBytesReference.empty(); + } else { + writeable.writeTo(stream); + zeroCopyBuffer = ReleasableBytesReference.empty(); + } + } finally { + // We have to close here before accessing the bytes when using compression to ensure that some marker bytes (EOS marker) + // are written. + if (compressionScheme != null) { + stream.close(); + } + } + final BytesReference msg = byteStreamOutput.bytes(); + if (zeroCopyBuffer.length() == 0) { + return msg; + } + zeroCopyBuffer.mustIncRef(); + return new ReleasableBytesReference(CompositeBytesReference.of(msg, zeroCopyBuffer), (RefCounted) zeroCopyBuffer); + } + + // compressed stream wrapped bytes must be no-close wrapped since we need to close the compressed wrapper below to release + // resources and write EOS marker bytes but must not yet release the bytes themselves + private static StreamOutput wrapCompressed(Compression.Scheme compressionScheme, RecyclerBytesStreamOutput bytesStream) + throws IOException { + if (compressionScheme == Compression.Scheme.DEFLATE) { + return new OutputStreamStreamOutput( + CompressorFactory.COMPRESSOR.threadLocalOutputStream(org.elasticsearch.core.Streams.noCloseStream(bytesStream)) + ); + } else if (compressionScheme == Compression.Scheme.LZ4) { + return new OutputStreamStreamOutput(Compression.Scheme.lz4OutputStream(Streams.noCloseStream(bytesStream))); + } else { + throw new IllegalArgumentException("Invalid compression scheme: " + compressionScheme); + } + } + private void internalSend( TcpChannel channel, BytesReference reference, - @Nullable OutboundMessage message, + Supplier messageDescription, ActionListener listener ) { final long startTime = threadPool.rawRelativeTimeInMillis(); @@ -280,7 +398,7 @@ private void maybeLogSlowMessage(boolean success) { logger.warn( "sending transport message [{}] of size [{}] on [{}] took [{}ms] which is above the warn " + "threshold of [{}ms] with success [{}]", - message, + messageDescription.get(), messageSize, channel, took, diff --git a/server/src/main/java/org/elasticsearch/transport/OutboundMessage.java b/server/src/main/java/org/elasticsearch/transport/OutboundMessage.java deleted file mode 100644 index 798385edefd6f..0000000000000 --- a/server/src/main/java/org/elasticsearch/transport/OutboundMessage.java +++ /dev/null @@ -1,202 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". - */ -package org.elasticsearch.transport; - -import org.elasticsearch.TransportVersion; -import org.elasticsearch.TransportVersions; -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.bytes.CompositeBytesReference; -import org.elasticsearch.common.bytes.ReleasableBytesReference; -import org.elasticsearch.common.compress.CompressorFactory; -import org.elasticsearch.common.io.stream.OutputStreamStreamOutput; -import org.elasticsearch.common.io.stream.RecyclerBytesStreamOutput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.core.RefCounted; -import org.elasticsearch.core.Streams; - -import java.io.IOException; - -abstract class OutboundMessage extends NetworkMessage { - - protected final Writeable message; - - OutboundMessage( - ThreadContext threadContext, - TransportVersion version, - byte status, - long requestId, - Compression.Scheme compressionScheme, - Writeable message - ) { - super(threadContext, version, status, requestId, compressionScheme); - this.message = message; - } - - BytesReference serialize(RecyclerBytesStreamOutput bytesStream) throws IOException { - bytesStream.setTransportVersion(version); - bytesStream.skip(TcpHeader.HEADER_SIZE); - - // The compressible bytes stream will not close the underlying bytes stream - BytesReference reference; - final long preHeaderPosition = bytesStream.position(); - - writeVariableHeader(bytesStream); - int variableHeaderLength = Math.toIntExact(bytesStream.position() - preHeaderPosition); - - final boolean compress = TransportStatus.isCompress(status); - final StreamOutput stream = compress ? wrapCompressed(bytesStream) : bytesStream; - final ReleasableBytesReference zeroCopyBuffer; - try { - stream.setTransportVersion(version); - if (message instanceof BytesTransportRequest bRequest) { - bRequest.writeThin(stream); - zeroCopyBuffer = bRequest.bytes; - } else if (message instanceof RemoteTransportException) { - stream.writeException((RemoteTransportException) message); - zeroCopyBuffer = ReleasableBytesReference.empty(); - } else { - message.writeTo(stream); - zeroCopyBuffer = ReleasableBytesReference.empty(); - } - } finally { - // We have to close here before accessing the bytes when using compression to ensure that some marker bytes (EOS marker) - // are written. - if (compress) { - stream.close(); - } - } - final BytesReference message = bytesStream.bytes(); - if (zeroCopyBuffer.length() == 0) { - reference = message; - } else { - zeroCopyBuffer.mustIncRef(); - reference = new ReleasableBytesReference(CompositeBytesReference.of(message, zeroCopyBuffer), (RefCounted) zeroCopyBuffer); - } - - bytesStream.seek(0); - final int contentSize = reference.length() - TcpHeader.HEADER_SIZE; - TcpHeader.writeHeader(bytesStream, requestId, status, version, contentSize, variableHeaderLength); - return reference; - } - - // compressed stream wrapped bytes must be no-close wrapped since we need to close the compressed wrapper below to release - // resources and write EOS marker bytes but must not yet release the bytes themselves - private StreamOutput wrapCompressed(RecyclerBytesStreamOutput bytesStream) throws IOException { - if (compressionScheme == Compression.Scheme.DEFLATE) { - return new OutputStreamStreamOutput( - CompressorFactory.COMPRESSOR.threadLocalOutputStream(org.elasticsearch.core.Streams.noCloseStream(bytesStream)) - ); - } else if (compressionScheme == Compression.Scheme.LZ4) { - return new OutputStreamStreamOutput(Compression.Scheme.lz4OutputStream(Streams.noCloseStream(bytesStream))); - } else { - throw new IllegalArgumentException("Invalid compression scheme: " + compressionScheme); - } - } - - protected void writeVariableHeader(StreamOutput stream) throws IOException { - threadContext.writeTo(stream); - } - - static class Request extends OutboundMessage { - - private final String action; - - Request( - ThreadContext threadContext, - Writeable message, - TransportVersion version, - String action, - long requestId, - boolean isHandshake, - Compression.Scheme compressionScheme - ) { - super(threadContext, version, setStatus(isHandshake), requestId, adjustCompressionScheme(compressionScheme, message), message); - this.action = action; - } - - @Override - protected void writeVariableHeader(StreamOutput stream) throws IOException { - super.writeVariableHeader(stream); - if (version.before(TransportVersions.V_8_0_0)) { - // empty features array - stream.writeStringArray(Strings.EMPTY_ARRAY); - } - stream.writeString(action); - } - - // Do not compress instances of BytesTransportRequest - private static Compression.Scheme adjustCompressionScheme(Compression.Scheme compressionScheme, Writeable message) { - if (message instanceof BytesTransportRequest) { - return null; - } else { - return compressionScheme; - } - } - - private static byte setStatus(boolean isHandshake) { - byte status = 0; - status = TransportStatus.setRequest(status); - if (isHandshake) { - status = TransportStatus.setHandshake(status); - } - - return status; - } - - @Override - public String toString() { - return "Request{" + action + "}{" + requestId + "}{" + isError() + "}{" + isCompress() + "}{" + isHandshake() + "}"; - } - } - - static class Response extends OutboundMessage { - - Response( - ThreadContext threadContext, - Writeable message, - TransportVersion version, - long requestId, - boolean isHandshake, - Compression.Scheme compressionScheme - ) { - super(threadContext, version, setStatus(isHandshake, message), requestId, compressionScheme, message); - } - - private static byte setStatus(boolean isHandshake, Writeable message) { - byte status = 0; - status = TransportStatus.setResponse(status); - if (message instanceof RemoteTransportException) { - status = TransportStatus.setError(status); - } - if (isHandshake) { - status = TransportStatus.setHandshake(status); - } - - return status; - } - - @Override - public String toString() { - return "Response{" - + requestId - + "}{" - + isError() - + "}{" - + isCompress() - + "}{" - + isHandshake() - + "}{" - + message.getClass() - + "}"; - } - } -} diff --git a/server/src/test/java/org/elasticsearch/transport/InboundDecoderTests.java b/server/src/test/java/org/elasticsearch/transport/InboundDecoderTests.java index be51cecc2cf9a..50fbb2ae4895e 100644 --- a/server/src/test/java/org/elasticsearch/transport/InboundDecoderTests.java +++ b/server/src/test/java/org/elasticsearch/transport/InboundDecoderTests.java @@ -54,30 +54,34 @@ public void testDecode() throws IOException { } else { threadContext.addResponseHeader(headerKey, headerValue); } - OutboundMessage message; - if (isRequest) { - message = new OutboundMessage.Request( - threadContext, - new TestRequest(randomAlphaOfLength(100)), - TransportVersion.current(), - action, - requestId, - false, - null - ); - } else { - message = new OutboundMessage.Response( - threadContext, - new TestResponse(randomAlphaOfLength(100)), - TransportVersion.current(), - requestId, - false, - null - ); - } try (RecyclerBytesStreamOutput os = new RecyclerBytesStreamOutput(recycler)) { - final BytesReference totalBytes = message.serialize(os); + final BytesReference totalBytes; + if (isRequest) { + totalBytes = OutboundHandler.serialize( + action, + requestId, + false, + TransportVersion.current(), + false, + null, + new TestRequest(randomAlphaOfLength(100)), + threadContext, + os + ); + } else { + totalBytes = OutboundHandler.serialize( + null, + requestId, + false, + TransportVersion.current(), + false, + null, + new TestResponse(randomAlphaOfLength(100)), + threadContext, + os + ); + } int totalHeaderSize = TcpHeader.HEADER_SIZE + totalBytes.getInt(TcpHeader.VARIABLE_HEADER_SIZE_POSITION); final BytesReference messageBytes = totalBytes.slice(totalHeaderSize, totalBytes.length() - totalHeaderSize); @@ -137,18 +141,19 @@ private void doHandshakeCompatibilityTest(TransportVersion transportVersion, Com final String headerKey = randomAlphaOfLength(10); final String headerValue = randomAlphaOfLength(20); threadContext.putHeader(headerKey, headerValue); - OutboundMessage message = new OutboundMessage.Request( - threadContext, - new TestRequest(randomAlphaOfLength(100)), - transportVersion, - action, - requestId, - true, - compressionScheme - ); try (RecyclerBytesStreamOutput os = new RecyclerBytesStreamOutput(recycler)) { - final BytesReference bytes = message.serialize(os); + final BytesReference bytes = OutboundHandler.serialize( + action, + requestId, + true, + transportVersion, + false, + compressionScheme, + new TestRequest(randomAlphaOfLength(100)), + threadContext, + os + ); InboundDecoder decoder = new InboundDecoder(recycler); final ArrayList fragments = new ArrayList<>(); @@ -187,18 +192,19 @@ public void testClientChannelTypeFailsDecodingRequests() throws Exception { ? randomFrom(TransportHandshaker.ALLOWED_HANDSHAKE_VERSIONS) : TransportVersionUtils.randomCompatibleVersion(random()); logger.info("--> version = {}", version); - OutboundMessage message = new OutboundMessage.Request( - threadContext, - new TestRequest(randomAlphaOfLength(100)), - version, - action, - requestId, - isHandshake, - randomFrom(Compression.Scheme.DEFLATE, Compression.Scheme.LZ4, null) - ); try (RecyclerBytesStreamOutput os = new RecyclerBytesStreamOutput(recycler)) { - final BytesReference bytes = message.serialize(os); + final BytesReference bytes = OutboundHandler.serialize( + action, + requestId, + isHandshake, + version, + false, + randomFrom(Compression.Scheme.DEFLATE, Compression.Scheme.LZ4, null), + new TestRequest(randomAlphaOfLength(100)), + threadContext, + os + ); try (InboundDecoder clientDecoder = new InboundDecoder(recycler, ChannelType.CLIENT)) { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, @@ -234,17 +240,19 @@ public void testServerChannelTypeFailsDecodingResponses() throws Exception { final var version = isHandshake ? randomFrom(TransportHandshaker.ALLOWED_HANDSHAKE_VERSIONS) : TransportVersionUtils.randomCompatibleVersion(random()); - OutboundMessage message = new OutboundMessage.Response( - threadContext, - new TestResponse(randomAlphaOfLength(100)), - version, - requestId, - isHandshake, - randomFrom(Compression.Scheme.DEFLATE, Compression.Scheme.LZ4, null) - ); try (RecyclerBytesStreamOutput os = new RecyclerBytesStreamOutput(recycler)) { - final BytesReference bytes = message.serialize(os); + final BytesReference bytes = OutboundHandler.serialize( + null, + requestId, + isHandshake, + version, + false, + randomFrom(Compression.Scheme.DEFLATE, Compression.Scheme.LZ4, null), + new TestRequest(randomAlphaOfLength(100)), + threadContext, + os + ); try (InboundDecoder decoder = new InboundDecoder(recycler, ChannelType.SERVER)) { final ReleasableBytesReference releasable1 = wrapAsReleasable(bytes); IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> decoder.decode(releasable1, ignored -> {})); @@ -273,27 +281,38 @@ public void testCompressedDecode() throws IOException { } else { threadContext.addResponseHeader(headerKey, headerValue); } - OutboundMessage message; + final BytesReference totalBytes; TransportMessage transportMessage; Compression.Scheme scheme = randomFrom(Compression.Scheme.DEFLATE, Compression.Scheme.LZ4); - if (isRequest) { - transportMessage = new TestRequest(randomAlphaOfLength(100)); - message = new OutboundMessage.Request( - threadContext, - transportMessage, - TransportVersion.current(), - action, - requestId, - false, - scheme - ); - } else { - transportMessage = new TestResponse(randomAlphaOfLength(100)); - message = new OutboundMessage.Response(threadContext, transportMessage, TransportVersion.current(), requestId, false, scheme); - } try (RecyclerBytesStreamOutput os = new RecyclerBytesStreamOutput(recycler)) { - final BytesReference totalBytes = message.serialize(os); + if (isRequest) { + transportMessage = new TestRequest(randomAlphaOfLength(100)); + totalBytes = OutboundHandler.serialize( + action, + requestId, + false, + TransportVersion.current(), + false, + scheme, + transportMessage, + threadContext, + os + ); + } else { + transportMessage = new TestResponse(randomAlphaOfLength(100)); + totalBytes = OutboundHandler.serialize( + null, + requestId, + false, + TransportVersion.current(), + false, + scheme, + transportMessage, + threadContext, + os + ); + } final BytesStreamOutput out = new BytesStreamOutput(); transportMessage.writeTo(out); final BytesReference uncompressedBytes = out.bytes(); @@ -351,19 +370,19 @@ public void testVersionIncompatibilityDecodeException() throws IOException { String action = "test-request"; long requestId = randomNonNegativeLong(); TransportVersion incompatibleVersion = TransportVersionUtils.getPreviousVersion(TransportVersions.MINIMUM_COMPATIBLE); - OutboundMessage message = new OutboundMessage.Request( - threadContext, - new TestRequest(randomAlphaOfLength(100)), - incompatibleVersion, - action, - requestId, - false, - Compression.Scheme.DEFLATE - ); - final ReleasableBytesReference releasable1; try (RecyclerBytesStreamOutput os = new RecyclerBytesStreamOutput(recycler)) { - final BytesReference bytes = message.serialize(os); + final BytesReference bytes = OutboundHandler.serialize( + action, + requestId, + false, + incompatibleVersion, + false, + Compression.Scheme.DEFLATE, + new TestRequest(randomAlphaOfLength(100)), + threadContext, + os + ); InboundDecoder decoder = new InboundDecoder(recycler); final ArrayList fragments = new ArrayList<>(); diff --git a/server/src/test/java/org/elasticsearch/transport/InboundHandlerTests.java b/server/src/test/java/org/elasticsearch/transport/InboundHandlerTests.java index 713123469ecfe..0a11c413a43bf 100644 --- a/server/src/test/java/org/elasticsearch/transport/InboundHandlerTests.java +++ b/server/src/test/java/org/elasticsearch/transport/InboundHandlerTests.java @@ -170,18 +170,19 @@ public TestResponse read(StreamInput in) throws IOException { ); requestHandlers.registerHandler(registry); String requestValue = randomAlphaOfLength(10); - OutboundMessage.Request request = new OutboundMessage.Request( - threadPool.getThreadContext(), - new TestRequest(requestValue), - TransportVersion.current(), + BytesRefRecycler recycler = new BytesRefRecycler(PageCacheRecycler.NON_RECYCLING_INSTANCE); + BytesReference fullRequestBytes = OutboundHandler.serialize( action, requestId, false, - null + TransportVersion.current(), + false, + null, + new TestRequest(requestValue), + threadPool.getThreadContext(), + new RecyclerBytesStreamOutput(recycler) ); - BytesRefRecycler recycler = new BytesRefRecycler(PageCacheRecycler.NON_RECYCLING_INSTANCE); - BytesReference fullRequestBytes = request.serialize(new RecyclerBytesStreamOutput(recycler)); BytesReference requestContent = fullRequestBytes.slice(TcpHeader.HEADER_SIZE, fullRequestBytes.length() - TcpHeader.HEADER_SIZE); Header requestHeader = new Header( fullRequestBytes.length() - 6, diff --git a/server/src/test/java/org/elasticsearch/transport/InboundPipelineTests.java b/server/src/test/java/org/elasticsearch/transport/InboundPipelineTests.java index 347c2dab878c1..9f4ce4811a9e7 100644 --- a/server/src/test/java/org/elasticsearch/transport/InboundPipelineTests.java +++ b/server/src/test/java/org/elasticsearch/transport/InboundPipelineTests.java @@ -23,7 +23,6 @@ import org.elasticsearch.common.util.MockPageCacheRecycler; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.Releasable; -import org.elasticsearch.core.Streams; import org.elasticsearch.core.TimeValue; import org.elasticsearch.core.Tuple; import org.elasticsearch.test.ESTestCase; @@ -109,47 +108,54 @@ public void testPipelineHandling() throws IOException { final MessageData messageData; Exception expectedExceptionClass = null; - OutboundMessage message; - if (isRequest) { - if (rarely()) { - messageData = new MessageData(version, requestId, true, compressionScheme, breakThisAction, null); - message = new OutboundMessage.Request( - threadContext, - new TestRequest(value), - version, - breakThisAction, + BytesReference message; + try (RecyclerBytesStreamOutput temporaryOutput = new RecyclerBytesStreamOutput(recycler)) { + if (isRequest) { + if (rarely()) { + messageData = new MessageData(version, requestId, true, compressionScheme, breakThisAction, null); + message = OutboundHandler.serialize( + breakThisAction, + requestId, + false, + version, + false, + compressionScheme, + new TestRequest(value), + threadContext, + temporaryOutput + ); + expectedExceptionClass = new CircuitBreakingException("", CircuitBreaker.Durability.PERMANENT); + } else { + messageData = new MessageData(version, requestId, true, compressionScheme, actionName, value); + message = OutboundHandler.serialize( + actionName, + requestId, + false, + version, + false, + compressionScheme, + new TestRequest(value), + threadContext, + temporaryOutput + ); + } + } else { + messageData = new MessageData(version, requestId, false, compressionScheme, null, value); + message = OutboundHandler.serialize( + null, requestId, false, - compressionScheme - ); - expectedExceptionClass = new CircuitBreakingException("", CircuitBreaker.Durability.PERMANENT); - } else { - messageData = new MessageData(version, requestId, true, compressionScheme, actionName, value); - message = new OutboundMessage.Request( - threadContext, - new TestRequest(value), version, - actionName, - requestId, false, - compressionScheme + compressionScheme, + new TestResponse(value), + threadContext, + temporaryOutput ); } - } else { - messageData = new MessageData(version, requestId, false, compressionScheme, null, value); - message = new OutboundMessage.Response( - threadContext, - new TestResponse(value), - version, - requestId, - false, - compressionScheme - ); - } - expected.add(new Tuple<>(messageData, expectedExceptionClass)); - try (RecyclerBytesStreamOutput temporaryOutput = new RecyclerBytesStreamOutput(recycler)) { - Streams.copy(message.serialize(temporaryOutput).streamInput(), streamOutput, false); + expected.add(new Tuple<>(messageData, expectedExceptionClass)); + message.writeTo(streamOutput); } } @@ -213,23 +219,34 @@ public void testDecodeExceptionIsPropagated() throws IOException { final boolean isRequest = randomBoolean(); final long requestId = randomNonNegativeLong(); - OutboundMessage message; + BytesReference message; if (isRequest) { - message = new OutboundMessage.Request( - threadContext, - new TestRequest(value), - invalidVersion, + message = OutboundHandler.serialize( actionName, requestId, false, - null + invalidVersion, + false, + null, + new TestRequest(value), + threadContext, + streamOutput ); } else { - message = new OutboundMessage.Response(threadContext, new TestResponse(value), invalidVersion, requestId, false, null); + message = OutboundHandler.serialize( + null, + requestId, + false, + invalidVersion, + false, + null, + new TestResponse(value), + threadContext, + streamOutput + ); } - final BytesReference reference = message.serialize(streamOutput); - try (ReleasableBytesReference releasable = ReleasableBytesReference.wrap(reference)) { + try (ReleasableBytesReference releasable = ReleasableBytesReference.wrap(message)) { expectThrows(IllegalStateException.class, () -> pipeline.handleBytes(new FakeTcpChannel(), releasable)); } @@ -258,14 +275,33 @@ public void testEnsureBodyIsNotPrematurelyReleased() throws IOException { final boolean isRequest = randomBoolean(); final long requestId = randomNonNegativeLong(); - OutboundMessage message; + final BytesReference reference; if (isRequest) { - message = new OutboundMessage.Request(threadContext, new TestRequest(value), version, actionName, requestId, false, null); + reference = OutboundHandler.serialize( + actionName, + requestId, + false, + version, + false, + null, + new TestRequest(value), + threadContext, + streamOutput + ); } else { - message = new OutboundMessage.Response(threadContext, new TestResponse(value), version, requestId, false, null); + reference = OutboundHandler.serialize( + null, + requestId, + false, + version, + false, + null, + new TestResponse(value), + threadContext, + streamOutput + ); } - final BytesReference reference = message.serialize(streamOutput); final int variableHeaderSize = reference.getInt(TcpHeader.HEADER_SIZE - 4); final int totalHeaderSize = TcpHeader.HEADER_SIZE + variableHeaderSize; final AtomicBoolean bodyReleased = new AtomicBoolean(false); diff --git a/server/src/test/java/org/elasticsearch/transport/TransportLoggerTests.java b/server/src/test/java/org/elasticsearch/transport/TransportLoggerTests.java index 828ff3f152e60..2167a9d6bcad7 100644 --- a/server/src/test/java/org/elasticsearch/transport/TransportLoggerTests.java +++ b/server/src/test/java/org/elasticsearch/transport/TransportLoggerTests.java @@ -70,16 +70,17 @@ private BytesReference buildRequest() throws IOException { BytesRefRecycler recycler = new BytesRefRecycler(PageCacheRecycler.NON_RECYCLING_INSTANCE); Compression.Scheme compress = randomFrom(Compression.Scheme.DEFLATE, Compression.Scheme.LZ4, null); try (RecyclerBytesStreamOutput bytesStreamOutput = new RecyclerBytesStreamOutput(recycler)) { - OutboundMessage.Request request = new OutboundMessage.Request( - new ThreadContext(Settings.EMPTY), - new EmptyRequest(), - TransportVersion.current(), + return OutboundHandler.serialize( "internal:test", randomInt(30), false, - compress + TransportVersion.current(), + false, + compress, + new EmptyRequest(), + new ThreadContext(Settings.EMPTY), + bytesStreamOutput ); - return request.serialize(bytesStreamOutput); } } } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/TestOutboundRequestMessage.java b/test/framework/src/main/java/org/elasticsearch/transport/TestOutboundRequestMessage.java deleted file mode 100644 index 2d49d64868ace..0000000000000 --- a/test/framework/src/main/java/org/elasticsearch/transport/TestOutboundRequestMessage.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". - */ - -package org.elasticsearch.transport; - -import org.elasticsearch.TransportVersion; -import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.io.stream.RecyclerBytesStreamOutput; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.util.concurrent.ThreadContext; - -import java.io.IOException; - -public class TestOutboundRequestMessage extends OutboundMessage.Request { - public TestOutboundRequestMessage( - ThreadContext threadContext, - Writeable message, - TransportVersion version, - String action, - long requestId, - boolean isHandshake, - Compression.Scheme compressionScheme - ) { - super(threadContext, message, version, action, requestId, isHandshake, compressionScheme); - - } - - @Override - public BytesReference serialize(RecyclerBytesStreamOutput bytesStream) throws IOException { - return super.serialize(bytesStream); - } -} diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4ServerTransportAuthenticationTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4ServerTransportAuthenticationTests.java index d294fb50046d6..8fb4d5c004a3f 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4ServerTransportAuthenticationTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4ServerTransportAuthenticationTests.java @@ -38,13 +38,13 @@ import org.elasticsearch.transport.BytesRefRecycler; import org.elasticsearch.transport.Compression; import org.elasticsearch.transport.EmptyRequest; +import org.elasticsearch.transport.OutboundHandler; import org.elasticsearch.transport.ProxyConnectionStrategy; import org.elasticsearch.transport.RemoteClusterPortSettings; import org.elasticsearch.transport.RemoteClusterService; import org.elasticsearch.transport.RemoteConnectionStrategy; import org.elasticsearch.transport.RemoteTransportException; import org.elasticsearch.transport.SniffConnectionStrategy; -import org.elasticsearch.transport.TestOutboundRequestMessage; import org.elasticsearch.transport.TransportInterceptor; import org.elasticsearch.transport.TransportRequest; import org.elasticsearch.transport.TransportRequestHandler; @@ -331,18 +331,19 @@ public void testConnectionDisconnectedWhenAuthnFails() throws Exception { TransportAddress[] boundRemoteIngressAddresses = remoteSecurityNetty4ServerTransport.boundRemoteIngressAddress().boundAddresses(); InetSocketAddress remoteIngressTransportAddress = randomFrom(boundRemoteIngressAddresses).address(); try (Socket socket = new MockSocket(remoteIngressTransportAddress.getAddress(), remoteIngressTransportAddress.getPort())) { - TestOutboundRequestMessage message = new TestOutboundRequestMessage( - threadPool.getThreadContext(), - new EmptyRequest(), - TransportVersion.current(), + Recycler recycler = new BytesRefRecycler(PageCacheRecycler.NON_RECYCLING_INSTANCE); + RecyclerBytesStreamOutput out = new RecyclerBytesStreamOutput(recycler); + BytesReference bytesReference = OutboundHandler.serialize( "internal:whatever", randomNonNegativeLong(), false, - randomFrom(Compression.Scheme.DEFLATE, Compression.Scheme.LZ4, null) + TransportVersion.current(), + false, + randomFrom(Compression.Scheme.DEFLATE, Compression.Scheme.LZ4, null), + new EmptyRequest(), + threadPool.getThreadContext(), + out ); - Recycler recycler = new BytesRefRecycler(PageCacheRecycler.NON_RECYCLING_INSTANCE); - RecyclerBytesStreamOutput out = new RecyclerBytesStreamOutput(recycler); - BytesReference bytesReference = message.serialize(out); socket.getOutputStream().write(Arrays.copyOfRange(bytesReference.array(), 0, bytesReference.length())); socket.getOutputStream().flush();