Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.io.stream.RecyclerBytesStreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.network.CloseableChannel;
import org.elasticsearch.common.network.HandlingTimeTracker;
import org.elasticsearch.common.recycler.Recycler;
Expand Down Expand Up @@ -115,6 +116,7 @@ void sendRequest(
);
sendMessage(
channel,
request,
message,
ResponseStatsConsumer.NONE,
() -> messageListener.onRequestSent(node, requestId, action, request, options)
Expand Down Expand Up @@ -148,7 +150,7 @@ void sendResponse(
);
assert response.hasReferences();
try {
sendMessage(channel, message, responseStatsConsumer, () -> messageListener.onResponseSent(requestId, action, response));
sendMessage(channel, response, message, responseStatsConsumer, () -> messageListener.onResponseSent(requestId, action));
} catch (Exception ex) {
if (isHandshake) {
logger.error(
Expand Down Expand Up @@ -178,16 +180,17 @@ void sendErrorResponse(
final Exception error
) {
assert assertValidTransportVersion(transportVersion);
var msg = new RemoteTransportException(nodeName, channel.getLocalAddress(), action, error);
OutboundMessage.Response message = new OutboundMessage.Response(
threadPool.getThreadContext(),
new RemoteTransportException(nodeName, channel.getLocalAddress(), action, error),
msg,
transportVersion,
requestId,
false,
null
);
try {
sendMessage(channel, message, responseStatsConsumer, () -> messageListener.onResponseSent(requestId, action, error));
sendMessage(channel, msg, message, responseStatsConsumer, () -> messageListener.onResponseSent(requestId, action, error));
} catch (Exception sendException) {
sendException.addSuppressed(error);
logger.error(() -> format("Failed to send error response on channel [%s], closing channel", channel), sendException);
Expand All @@ -197,6 +200,7 @@ void sendErrorResponse(

private void sendMessage(
TcpChannel channel,
Writeable writeable,
OutboundMessage networkMessage,
ResponseStatsConsumer responseStatsConsumer,
Releasable onAfter
Expand All @@ -214,7 +218,7 @@ private void sendMessage(
final BytesReference message;
boolean serializeSuccess = false;
try {
message = networkMessage.serialize(byteStreamOutput);
message = networkMessage.serialize(writeable, byteStreamOutput);
serializeSuccess = true;
} catch (Exception e) {
logger.warn(() -> "failed to serialize outbound message [" + networkMessage + "]", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,17 @@

abstract class OutboundMessage extends NetworkMessage {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no usages of NetworkMessage left now either?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right deleted it :)


protected final Writeable message;

OutboundMessage(
ThreadContext threadContext,
TransportVersion version,
byte status,
long requestId,
Compression.Scheme compressionScheme,
Writeable message
Compression.Scheme compressionScheme
) {
super(threadContext, version, status, requestId, compressionScheme);
this.message = message;
}

BytesReference serialize(RecyclerBytesStreamOutput bytesStream) throws IOException {
BytesReference serialize(Writeable message, RecyclerBytesStreamOutput bytesStream) throws IOException {
bytesStream.setTransportVersion(version);
bytesStream.skip(TcpHeader.HEADER_SIZE);

Expand Down Expand Up @@ -74,12 +70,12 @@ BytesReference serialize(RecyclerBytesStreamOutput bytesStream) throws IOExcepti
stream.close();
}
}
final BytesReference message = bytesStream.bytes();
final BytesReference msg = bytesStream.bytes();
if (zeroCopyBuffer.length() == 0) {
reference = message;
reference = msg;
} else {
zeroCopyBuffer.mustIncRef();
reference = new ReleasableBytesReference(CompositeBytesReference.of(message, zeroCopyBuffer), (RefCounted) zeroCopyBuffer);
reference = new ReleasableBytesReference(CompositeBytesReference.of(msg, zeroCopyBuffer), (RefCounted) zeroCopyBuffer);
}

bytesStream.seek(0);
Expand Down Expand Up @@ -119,7 +115,7 @@ static class Request extends OutboundMessage {
boolean isHandshake,
Compression.Scheme compressionScheme
) {
super(threadContext, version, setStatus(isHandshake), requestId, adjustCompressionScheme(compressionScheme, message), message);
super(threadContext, version, setStatus(isHandshake), requestId, adjustCompressionScheme(compressionScheme, message));
this.action = action;
}

Expand Down Expand Up @@ -168,7 +164,7 @@ static class Response extends OutboundMessage {
boolean isHandshake,
Compression.Scheme compressionScheme
) {
super(threadContext, version, setStatus(isHandshake, message), requestId, compressionScheme, message);
super(threadContext, version, setStatus(isHandshake, message), requestId, compressionScheme);
}

private static byte setStatus(boolean isHandshake, Writeable message) {
Expand All @@ -186,17 +182,7 @@ private static byte setStatus(boolean isHandshake, Writeable message) {

@Override
public String toString() {
return "Response{"
+ requestId
+ "}{"
+ isError()
+ "}{"
+ isCompress()
+ "}{"
+ isHandshake()
+ "}{"
+ message.getClass()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather we kept some way to identify the response type, although I must say I never much liked sharing its class here. Could we plumb in the action from the TcpTransportChannel down to here first?

+ "}";
return "Response{" + requestId + "}{" + isError() + "}{" + isCompress() + "}{" + isHandshake() + "}";
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,8 @@ default void onRequestReceived(long requestId, String action) {}
* Called for every action response sent after the response has been passed to the underlying network implementation.
* @param requestId the request ID (unique per client)
* @param action the request action
* @param response the response send
*/
default void onResponseSent(long requestId, String action, TransportResponse response) {}
default void onResponseSent(long requestId, String action) {}

/***
* Called for every failed action response after the response has been passed to the underlying network implementation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1288,7 +1288,7 @@ public void onResponseReceived(long requestId, Transport.ResponseContext holder)

/** called by the {@link Transport} implementation once a response was sent to calling node */
@Override
public void onResponseSent(long requestId, String action, TransportResponse response) {
public void onResponseSent(long requestId, String action) {
if (tracerLog.isTraceEnabled() && shouldTraceAction(action)) {
tracerLog.trace("[{}][{}] sent response", requestId, action);
}
Expand Down Expand Up @@ -1541,7 +1541,7 @@ public String getProfileName() {

@Override
public void sendResponse(TransportResponse response) {
service.onResponseSent(requestId, action, response);
service.onResponseSent(requestId, action);
try (var shutdownBlock = service.pendingDirectHandlers.withRef()) {
if (shutdownBlock == null) {
// already shutting down, the handler will be completed by sendRequestInternal or doStop
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.RecyclerBytesStreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.MockPageCacheRecycler;
import org.elasticsearch.common.util.concurrent.ThreadContext;
Expand Down Expand Up @@ -55,29 +56,17 @@ public void testDecode() throws IOException {
threadContext.addResponseHeader(headerKey, headerValue);
}
OutboundMessage message;
final Writeable body;
if (isRequest) {
message = new OutboundMessage.Request(
threadContext,
new TestRequest(randomAlphaOfLength(100)),
TransportVersion.current(),
action,
requestId,
false,
null
);
body = new TestRequest(randomAlphaOfLength(100));
message = new OutboundMessage.Request(threadContext, body, TransportVersion.current(), action, requestId, false, null);
} else {
message = new OutboundMessage.Response(
threadContext,
new TestResponse(randomAlphaOfLength(100)),
TransportVersion.current(),
requestId,
false,
null
);
body = new TestResponse(randomAlphaOfLength(100));
message = new OutboundMessage.Response(threadContext, body, TransportVersion.current(), requestId, false, null);
}

try (RecyclerBytesStreamOutput os = new RecyclerBytesStreamOutput(recycler)) {
final BytesReference totalBytes = message.serialize(os);
final BytesReference totalBytes = message.serialize(body, os);
int totalHeaderSize = TcpHeader.HEADER_SIZE + totalBytes.getInt(TcpHeader.VARIABLE_HEADER_SIZE_POSITION);
final BytesReference messageBytes = totalBytes.slice(totalHeaderSize, totalBytes.length() - totalHeaderSize);

Expand Down Expand Up @@ -137,9 +126,10 @@ private void doHandshakeCompatibilityTest(TransportVersion transportVersion, Com
final String headerKey = randomAlphaOfLength(10);
final String headerValue = randomAlphaOfLength(20);
threadContext.putHeader(headerKey, headerValue);
var body = new TestRequest(randomAlphaOfLength(100));
OutboundMessage message = new OutboundMessage.Request(
threadContext,
new TestRequest(randomAlphaOfLength(100)),
body,
transportVersion,
action,
requestId,
Expand All @@ -148,7 +138,7 @@ private void doHandshakeCompatibilityTest(TransportVersion transportVersion, Com
);

try (RecyclerBytesStreamOutput os = new RecyclerBytesStreamOutput(recycler)) {
final BytesReference bytes = message.serialize(os);
final BytesReference bytes = message.serialize(body, os);

InboundDecoder decoder = new InboundDecoder(recycler);
final ArrayList<Object> fragments = new ArrayList<>();
Expand Down Expand Up @@ -187,9 +177,10 @@ public void testClientChannelTypeFailsDecodingRequests() throws Exception {
? randomFrom(TransportHandshaker.ALLOWED_HANDSHAKE_VERSIONS)
: TransportVersionUtils.randomCompatibleVersion(random());
logger.info("--> version = {}", version);
var req = new TestRequest(randomAlphaOfLength(100));
OutboundMessage message = new OutboundMessage.Request(
threadContext,
new TestRequest(randomAlphaOfLength(100)),
req,
version,
action,
requestId,
Expand All @@ -198,7 +189,7 @@ public void testClientChannelTypeFailsDecodingRequests() throws Exception {
);

try (RecyclerBytesStreamOutput os = new RecyclerBytesStreamOutput(recycler)) {
final BytesReference bytes = message.serialize(os);
final BytesReference bytes = message.serialize(req, os);
try (InboundDecoder clientDecoder = new InboundDecoder(recycler, ChannelType.CLIENT)) {
IllegalArgumentException e = expectThrows(
IllegalArgumentException.class,
Expand Down Expand Up @@ -234,17 +225,18 @@ public void testServerChannelTypeFailsDecodingResponses() throws Exception {
final var version = isHandshake
? randomFrom(TransportHandshaker.ALLOWED_HANDSHAKE_VERSIONS)
: TransportVersionUtils.randomCompatibleVersion(random());
var resp = new TestResponse(randomAlphaOfLength(100));
OutboundMessage message = new OutboundMessage.Response(
threadContext,
new TestResponse(randomAlphaOfLength(100)),
resp,
version,
requestId,
isHandshake,
randomFrom(Compression.Scheme.DEFLATE, Compression.Scheme.LZ4, null)
);

try (RecyclerBytesStreamOutput os = new RecyclerBytesStreamOutput(recycler)) {
final BytesReference bytes = message.serialize(os);
final BytesReference bytes = message.serialize(resp, os);
try (InboundDecoder decoder = new InboundDecoder(recycler, ChannelType.SERVER)) {
final ReleasableBytesReference releasable1 = wrapAsReleasable(bytes);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> decoder.decode(releasable1, ignored -> {}));
Expand Down Expand Up @@ -293,7 +285,7 @@ public void testCompressedDecode() throws IOException {
}

try (RecyclerBytesStreamOutput os = new RecyclerBytesStreamOutput(recycler)) {
final BytesReference totalBytes = message.serialize(os);
final BytesReference totalBytes = message.serialize(transportMessage, os);
final BytesStreamOutput out = new BytesStreamOutput();
transportMessage.writeTo(out);
final BytesReference uncompressedBytes = out.bytes();
Expand Down Expand Up @@ -351,9 +343,10 @@ public void testVersionIncompatibilityDecodeException() throws IOException {
String action = "test-request";
long requestId = randomNonNegativeLong();
TransportVersion incompatibleVersion = TransportVersionUtils.getPreviousVersion(TransportVersions.MINIMUM_COMPATIBLE);
var req = new TestRequest(randomAlphaOfLength(10));
OutboundMessage message = new OutboundMessage.Request(
threadContext,
new TestRequest(randomAlphaOfLength(100)),
req,
incompatibleVersion,
action,
requestId,
Expand All @@ -363,7 +356,7 @@ public void testVersionIncompatibilityDecodeException() throws IOException {

final ReleasableBytesReference releasable1;
try (RecyclerBytesStreamOutput os = new RecyclerBytesStreamOutput(recycler)) {
final BytesReference bytes = message.serialize(os);
final BytesReference bytes = message.serialize(req, os);

InboundDecoder decoder = new InboundDecoder(recycler);
final ArrayList<Object> fragments = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.RecyclerBytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.network.HandlingTimeTracker;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.PageCacheRecycler;
Expand Down Expand Up @@ -170,9 +171,10 @@ public TestResponse read(StreamInput in) throws IOException {
);
requestHandlers.registerHandler(registry);
String requestValue = randomAlphaOfLength(10);
final Writeable body = new TestRequest(requestValue);
OutboundMessage.Request request = new OutboundMessage.Request(
threadPool.getThreadContext(),
new TestRequest(requestValue),
body,
TransportVersion.current(),
action,
requestId,
Expand All @@ -181,7 +183,7 @@ public TestResponse read(StreamInput in) throws IOException {
);

BytesRefRecycler recycler = new BytesRefRecycler(PageCacheRecycler.NON_RECYCLING_INSTANCE);
BytesReference fullRequestBytes = request.serialize(new RecyclerBytesStreamOutput(recycler));
BytesReference fullRequestBytes = request.serialize(body, new RecyclerBytesStreamOutput(recycler));
BytesReference requestContent = fullRequestBytes.slice(TcpHeader.HEADER_SIZE, fullRequestBytes.length() - TcpHeader.HEADER_SIZE);
Header requestHeader = new Header(
fullRequestBytes.length() - 6,
Expand Down
Loading
Loading