Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import org.elasticsearch.common.recycler.Recycler;
import org.elasticsearch.common.transport.NetworkExceptionHelper;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.RefCounted;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
Expand Down Expand Up @@ -114,6 +113,7 @@ void sendRequest(
assert assertValidTransportVersion(transportVersion);
sendMessage(
channel,
MessageDirection.REQUEST,
action,
request,
requestId,
Expand Down Expand Up @@ -146,7 +146,8 @@ void sendResponse(
try {
sendMessage(
channel,
null,
MessageDirection.RESPONSE,
action,
response,
requestId,
isHandshake,
Expand Down Expand Up @@ -188,7 +189,8 @@ void sendErrorResponse(
try {
sendMessage(
channel,
null,
MessageDirection.RESPONSE_ERROR,
action,
msg,
requestId,
false,
Expand All @@ -204,29 +206,36 @@ void sendErrorResponse(
}
}

public enum MessageDirection {
REQUEST,
RESPONSE,
RESPONSE_ERROR
}

private void sendMessage(
TcpChannel channel,
@Nullable String requestAction,
MessageDirection messageDirection,
String action,
Writeable writeable,
long requestId,
boolean isHandshake,
Compression.Scheme compressionScheme,
Compression.Scheme possibleCompressionScheme,
TransportVersion version,
ResponseStatsConsumer responseStatsConsumer,
Releasable onAfter
) throws IOException {
compressionScheme = writeable instanceof BytesTransportRequest ? null : compressionScheme;
assert action != null;
final var compressionScheme = writeable instanceof BytesTransportRequest ? null : possibleCompressionScheme;
final BytesReference message;
boolean serializeSuccess = false;
final boolean isError = writeable instanceof RemoteTransportException;
final RecyclerBytesStreamOutput byteStreamOutput = new RecyclerBytesStreamOutput(recycler);
try {
message = serialize(
requestAction,
messageDirection,
action,
requestId,
isHandshake,
version,
isError,
compressionScheme,
writeable,
threadPool.getThreadContext(),
Expand All @@ -242,14 +251,23 @@ private void sendMessage(
}
}
responseStatsConsumer.addResponseStats(message.length());
final var responseType = writeable.getClass();
final boolean compress = compressionScheme != null;
final var messageType = writeable.getClass();
internalSend(
channel,
message,
requestAction == null
? () -> "Response{" + requestId + "}{" + isError + "}{" + compress + "}{" + isHandshake + "}{" + responseType + "}"
: () -> "Request{" + requestAction + "}{" + requestId + "}{" + isError + "}{" + compress + "}{" + isHandshake + "}",
() -> (messageDirection == MessageDirection.REQUEST ? "Request{" : "Response{")
+ action
+ "}{id="
+ requestId
+ "}{err="
+ (messageDirection == MessageDirection.RESPONSE_ERROR)
+ "}{cs="
+ compressionScheme
+ "}{hs="
+ isHandshake
+ "}{t="
+ messageType
+ "}",
ActionListener.releasing(
message instanceof ReleasableBytesReference r
? Releasables.wrap(byteStreamOutput, onAfter, r)
Expand All @@ -260,11 +278,11 @@ private void sendMessage(

// public for tests
public static BytesReference serialize(
@Nullable String requestAction,
MessageDirection messageDirection,
String action,
long requestId,
boolean isHandshake,
TransportVersion version,
boolean isError,
Compression.Scheme compressionScheme,
Writeable writeable,
ThreadContext threadContext,
Expand All @@ -273,41 +291,43 @@ public static BytesReference serialize(
compressionScheme = compressionScheme == Compression.Scheme.LZ4 && version.before(Compression.Scheme.LZ4_VERSION)
? null
: compressionScheme;
assert action != null;
assert byteStreamOutput.position() == 0;
byteStreamOutput.setTransportVersion(version);
final int headerSize = TcpHeader.headerSize(version);
byteStreamOutput.skip(headerSize);
final int variableHeaderLength;
if (version.onOrAfter(TcpHeader.VERSION_WITH_HEADER_SIZE)) {
threadContext.writeTo(byteStreamOutput);
if (requestAction != null) {
if (messageDirection == MessageDirection.REQUEST) {
if (version.before(TransportVersions.V_8_0_0)) {
// empty features array
byteStreamOutput.writeStringArray(Strings.EMPTY_ARRAY);
}
byteStreamOutput.writeString(requestAction);
byteStreamOutput.writeString(action);
}
variableHeaderLength = Math.toIntExact(byteStreamOutput.position() - headerSize);
} else {
variableHeaderLength = -1;
}
BytesReference message = serializeMessageBody(
messageDirection,
writeable,
compressionScheme,
version,
byteStreamOutput,
variableHeaderLength,
threadContext,
requestAction
action
);
byte status = 0;
if (requestAction == null) {
if (messageDirection != MessageDirection.REQUEST) {
status = TransportStatus.setResponse(status);
}
if (isHandshake) {
status = TransportStatus.setHandshake(status);
}
if (isError) {
if (messageDirection == MessageDirection.RESPONSE_ERROR) {
status = TransportStatus.setError(status);
}
if (compressionScheme != null) {
Expand All @@ -319,6 +339,7 @@ public static BytesReference serialize(
}

private static BytesReference serializeMessageBody(
MessageDirection messageDirection,
Writeable writeable,
Compression.Scheme compressionScheme,
TransportVersion version,
Expand All @@ -334,12 +355,14 @@ private static BytesReference serializeMessageBody(
stream.setTransportVersion(version);
if (variableHeaderLength == -1) {
threadContext.writeTo(stream);
if (requestAction != null) {
if (messageDirection == MessageDirection.REQUEST) {
stream.writeStringArray(Strings.EMPTY_ARRAY);
stream.writeString(requestAction);
}
}
if (writeable instanceof BytesTransportRequest bRequest) {
assert stream == byteStreamOutput;
assert compressionScheme == null;
bRequest.writeThin(stream);
zeroCopyBuffer = bRequest.bytes;
} else if (writeable instanceof RemoteTransportException remoteTransportException) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,32 +56,17 @@ public void testDecode() throws IOException {
}

try (RecyclerBytesStreamOutput os = new RecyclerBytesStreamOutput(recycler)) {
final BytesReference totalBytes;
if (isRequest) {
totalBytes = OutboundHandler.serialize(
action,
requestId,
false,
TransportVersion.current(),
false,
null,
new TestRequest(randomAlphaOfLength(100)),
threadContext,
os
);
} else {
totalBytes = OutboundHandler.serialize(
null,
requestId,
false,
TransportVersion.current(),
false,
null,
new TestResponse(randomAlphaOfLength(100)),
threadContext,
os
);
}
final BytesReference totalBytes = OutboundHandler.serialize(
isRequest ? OutboundHandler.MessageDirection.REQUEST : OutboundHandler.MessageDirection.RESPONSE,
action,
requestId,
false,
TransportVersion.current(),
null,
isRequest ? new TestRequest(randomAlphaOfLength(100)) : new TestResponse(randomAlphaOfLength(100)),
threadContext,
os
);
int totalHeaderSize = TcpHeader.headerSize(TransportVersion.current()) + totalBytes.getInt(
TcpHeader.VARIABLE_HEADER_SIZE_POSITION
);
Expand Down Expand Up @@ -138,11 +123,11 @@ public void testDecodePreHeaderSizeVariableInt() throws IOException {
// 8.0 is only compatible with handshakes on a pre-variable int version
try (RecyclerBytesStreamOutput os = new RecyclerBytesStreamOutput(recycler)) {
final BytesReference totalBytes = OutboundHandler.serialize(
OutboundHandler.MessageDirection.REQUEST,
action,
requestId,
true,
preHeaderVariableInt,
false,
compressionScheme,
new TestRequest(contentValue),
threadContext,
Expand Down Expand Up @@ -195,11 +180,11 @@ public void testDecodeHandshakeV7Compatibility() throws IOException {
TransportVersion handshakeCompat = TransportHandshaker.V7_HANDSHAKE_VERSION;
try (RecyclerBytesStreamOutput os = new RecyclerBytesStreamOutput(recycler)) {
BytesReference bytes = OutboundHandler.serialize(
OutboundHandler.MessageDirection.REQUEST,
action,
requestId,
true,
handshakeCompat,
false,
null,
new TestRequest(randomAlphaOfLength(100)),
threadContext,
Expand Down Expand Up @@ -247,11 +232,11 @@ private void doHandshakeCompatibilityTest(TransportVersion transportVersion, Com
int totalHeaderSize = TcpHeader.headerSize(transportVersion);
try (RecyclerBytesStreamOutput os = new RecyclerBytesStreamOutput(recycler)) {
final BytesReference bytes = OutboundHandler.serialize(
OutboundHandler.MessageDirection.REQUEST,
action,
requestId,
true,
transportVersion,
false,
compressionScheme,
new TestRequest(randomAlphaOfLength(100)),
threadContext,
Expand Down Expand Up @@ -298,11 +283,11 @@ public void testClientChannelTypeFailsDecodingRequests() throws Exception {

try (RecyclerBytesStreamOutput os = new RecyclerBytesStreamOutput(recycler)) {
final BytesReference bytes = OutboundHandler.serialize(
OutboundHandler.MessageDirection.REQUEST,
action,
requestId,
isHandshake,
version,
false,
randomFrom(Compression.Scheme.DEFLATE, Compression.Scheme.LZ4, null),
new TestRequest(randomAlphaOfLength(100)),
threadContext,
Expand Down Expand Up @@ -348,11 +333,11 @@ public void testServerChannelTypeFailsDecodingResponses() throws Exception {

try (RecyclerBytesStreamOutput os = new RecyclerBytesStreamOutput(recycler)) {
final BytesReference bytes = OutboundHandler.serialize(
null,
OutboundHandler.MessageDirection.RESPONSE,
"test:action",
requestId,
isHandshake,
version,
false,
randomFrom(Compression.Scheme.DEFLATE, Compression.Scheme.LZ4, null),
new TestRequest(randomAlphaOfLength(100)),
threadContext,
Expand Down Expand Up @@ -388,38 +373,23 @@ public void testCompressedDecode() throws IOException {
} else {
threadContext.addResponseHeader(headerKey, headerValue);
}
final BytesReference totalBytes;
TransportMessage transportMessage;
Compression.Scheme scheme = randomFrom(Compression.Scheme.DEFLATE, Compression.Scheme.LZ4);

try (RecyclerBytesStreamOutput os = new RecyclerBytesStreamOutput(recycler)) {
if (isRequest) {
transportMessage = new TestRequest(randomAlphaOfLength(100));
totalBytes = OutboundHandler.serialize(
action,
requestId,
false,
TransportVersion.current(),
false,
scheme,
transportMessage,
threadContext,
os
);
} else {
transportMessage = new TestResponse(randomAlphaOfLength(100));
totalBytes = OutboundHandler.serialize(
null,
requestId,
false,
TransportVersion.current(),
false,
scheme,
transportMessage,
threadContext,
os
);
}
final TransportMessage transportMessage = isRequest
? new TestRequest(randomAlphaOfLength(100))
: new TestResponse(randomAlphaOfLength(100));
final BytesReference totalBytes = OutboundHandler.serialize(
isRequest ? OutboundHandler.MessageDirection.REQUEST : OutboundHandler.MessageDirection.RESPONSE,
action,
requestId,
false,
TransportVersion.current(),
scheme,
transportMessage,
threadContext,
os
);
final BytesStreamOutput out = new BytesStreamOutput();
transportMessage.writeTo(out);
final BytesReference uncompressedBytes = out.bytes();
Expand Down Expand Up @@ -479,11 +449,11 @@ public void testCompressedDecodeHandshakeCompatibility() throws IOException {
TransportVersion handshakeCompat = TransportHandshaker.V7_HANDSHAKE_VERSION;
try (RecyclerBytesStreamOutput os = new RecyclerBytesStreamOutput(recycler)) {
final BytesReference bytes = OutboundHandler.serialize(
OutboundHandler.MessageDirection.REQUEST,
action,
requestId,
true,
handshakeCompat,
false,
Compression.Scheme.DEFLATE,
new TestRequest(randomAlphaOfLength(100)),
threadContext,
Expand Down Expand Up @@ -517,11 +487,11 @@ public void testVersionIncompatibilityDecodeException() throws IOException {
final ReleasableBytesReference releasable1;
try (RecyclerBytesStreamOutput os = new RecyclerBytesStreamOutput(recycler)) {
final BytesReference bytes = OutboundHandler.serialize(
OutboundHandler.MessageDirection.REQUEST,
action,
requestId,
false,
incompatibleVersion,
false,
Compression.Scheme.DEFLATE,
new TestRequest(randomAlphaOfLength(100)),
threadContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,11 @@ public TestResponse read(StreamInput in) throws IOException {
String requestValue = randomAlphaOfLength(10);
BytesRefRecycler recycler = new BytesRefRecycler(PageCacheRecycler.NON_RECYCLING_INSTANCE);
BytesReference fullRequestBytes = OutboundHandler.serialize(
OutboundHandler.MessageDirection.REQUEST,
action,
requestId,
false,
TransportVersion.current(),
false,
null,
new TestRequest(requestValue),
threadPool.getThreadContext(),
Expand Down
Loading