Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ public InboundMessage finishAggregation() throws IOException {
checkBreaker(aggregated.getHeader(), aggregated.getContentLength(), breakerControl);
}
if (isShortCircuited()) {
aggregated.decRef();
aggregated.close();
success = true;
return new InboundMessage(aggregated.getHeader(), aggregationException);
} else {
Expand All @@ -131,7 +131,7 @@ public InboundMessage finishAggregation() throws IOException {
} finally {
resetCurrentAggregation();
if (success == false) {
aggregated.decRef();
aggregated.close();
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -87,21 +86,31 @@ void setSlowLogThreshold(TimeValue slowLogThreshold) {
this.slowLogThresholdMs = slowLogThreshold.getMillis();
}

/**
* @param message the transport message received, guaranteed to be closed by this method if it returns without exception.
* Callers must ensure that {@code message} is closed if this method throws an exception but must not release
* the message themselves otherwise
*/
void inboundMessage(TcpChannel channel, InboundMessage message) throws Exception {
final long startTime = threadPool.rawRelativeTimeInMillis();
channel.getChannelStats().markAccessed(startTime);
TransportLogger.logInboundMessage(channel, message);

if (message.isPing()) {
keepAlive.receiveKeepAlive(channel);
keepAlive.receiveKeepAlive(channel); // pings hold no resources, no need to close
} else {
messageReceived(channel, message, startTime);
messageReceived(channel, /* autocloses absent exception */ message, startTime);
}
}

// Empty stream constant to avoid instantiating a new stream for empty messages.
private static final StreamInput EMPTY_STREAM_INPUT = new ByteBufferStreamInput(ByteBuffer.wrap(BytesRef.EMPTY_BYTES));

/**
* @param message the transport message received, guaranteed to be closed by this method if it returns without exception.
* Callers must ensure that {@code message} is closed if this method throws an exception but must not release
* the message themselves otherwise
*/
private void messageReceived(TcpChannel channel, InboundMessage message, long startTime) throws IOException {
final InetSocketAddress remoteAddress = channel.getRemoteAddress();
final Header header = message.getHeader();
Expand All @@ -115,14 +124,16 @@ private void messageReceived(TcpChannel channel, InboundMessage message, long st
threadContext.setHeaders(header.getHeaders());
threadContext.putTransient("_remote_address", remoteAddress);
if (header.isRequest()) {
handleRequest(channel, message);
handleRequest(channel, /* autocloses absent exception */ message);
} else {
// Responses do not support short circuiting currently
assert message.isShortCircuit() == false;
responseHandler = findResponseHandler(header);
// ignore if its null, the service logs it
if (responseHandler != null) {
executeResponseHandler(message, responseHandler, remoteAddress);
executeResponseHandler( /* autocloses absent exception */ message, responseHandler, remoteAddress);
} else {
message.close();
}
}
} finally {
Expand All @@ -135,6 +146,11 @@ private void messageReceived(TcpChannel channel, InboundMessage message, long st
}
}

/**
* @param message the transport message received, guaranteed to be closed by this method if it returns without exception.
* Callers must ensure that {@code message} is closed if this method throws an exception but must not release
* the message themselves otherwise
*/
private void executeResponseHandler(
InboundMessage message,
TransportResponseHandler<?> responseHandler,
Expand All @@ -145,13 +161,13 @@ private void executeResponseHandler(
final StreamInput streamInput = namedWriteableStream(message.openOrGetStreamInput());
assert assertRemoteVersion(streamInput, header.getVersion());
if (header.isError()) {
handlerResponseError(streamInput, message, responseHandler);
handlerResponseError(streamInput, /* autocloses */ message, responseHandler);
} else {
handleResponse(remoteAddress, streamInput, responseHandler, message);
handleResponse(remoteAddress, streamInput, responseHandler, /* autocloses */ message);
}
} else {
assert header.isError() == false;
handleResponse(remoteAddress, EMPTY_STREAM_INPUT, responseHandler, message);
handleResponse(remoteAddress, EMPTY_STREAM_INPUT, responseHandler, /* autocloses */ message);
}
}

Expand Down Expand Up @@ -220,10 +236,15 @@ private void verifyResponseReadFully(Header header, TransportResponseHandler<?>
}
}

/**
* @param message the transport message received, guaranteed to be closed by this method if it returns without exception.
* Callers must ensure that {@code message} is closed if this method throws an exception but must not release
* the message themselves otherwise
*/
private <T extends TransportRequest> void handleRequest(TcpChannel channel, InboundMessage message) throws IOException {
final Header header = message.getHeader();
if (header.isHandshake()) {
handleHandshakeRequest(channel, message);
handleHandshakeRequest(channel, /* autocloses */ message);
return;
}

Expand All @@ -243,7 +264,7 @@ private <T extends TransportRequest> void handleRequest(TcpChannel channel, Inbo
Releasables.assertOnce(message.takeBreakerReleaseControl())
);

try {
try (message) {
messageListener.onRequestReceived(requestId, action);
if (reg != null) {
reg.addRequestStats(header.getNetworkMessageSize() + TcpHeader.BYTES_REQUIRED_FOR_MESSAGE_SIZE);
Expand All @@ -260,6 +281,7 @@ private <T extends TransportRequest> void handleRequest(TcpChannel channel, Inbo
final T request;
try {
request = reg.newRequest(stream);
message.close(); // eager release message to save heap
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we drop this one line from this PR and do it in an immediate follow-up? That way we have one commit that replaces refcounting with one-shot closing without changing semantics and then another that reduces the lifecycle.

Copy link
Contributor Author

@original-brownbear original-brownbear Apr 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

++ makes sense dropped it here

} catch (Exception e) {
assert ignoreDeserializationErrors : e;
throw e;
Expand Down Expand Up @@ -331,6 +353,9 @@ public void onAfter() {
}
}

/**
* @param message guaranteed to get closed by this method
*/
private void handleHandshakeRequest(TcpChannel channel, InboundMessage message) throws IOException {
var header = message.getHeader();
assert header.actionName.equals(TransportHandshaker.HANDSHAKE_ACTION_NAME);
Expand All @@ -351,7 +376,7 @@ private void handleHandshakeRequest(TcpChannel channel, InboundMessage message)
true,
Releasables.assertOnce(message.takeBreakerReleaseControl())
);
try {
try (message) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Admittedly annoying to spread the release logic across so many spots but in the end that is precisely what we need/want, release right where we deserialize. (If we want this cleaner we need to deserialise in fewer spots I guess?)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this is the same thing as refcount-negative methods: we need to find a way to make the fact that the method closes one of its arguments explicit at the call site.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or alternatively (and I'm starting to think this is more practical), any method that consumes an object like that must either pass it on elsewhere (and do so exactly once) or release it.
Just documenting that something closes the argument still leaves a lot of complexity around having to check that the last receiver of the instances is one of those closing-enabled methods and that other methods don't start holding on to instances doesn't it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The trouble is I don't think this is (or could be made) true of every method that takes an argument which represents nontrivial resources. We need to support both.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... still leaves a lot of complexity around having to check...

sure, it's not a fully automated answer (i.e. a proper borrow-checker) but at least it reduces those checks to method-local things rather than having to reason about things up and down the stack some arbitrary distance.

handshaker.handleHandshake(transportChannel, requestId, stream);
} catch (Exception e) {
logger.warn(
Expand All @@ -371,29 +396,30 @@ private static void sendErrorResponse(String actionName, TransportChannel transp
}
}

/**
* @param message guaranteed to get closed by this method
*/
private <T extends TransportResponse> void handleResponse(
InetSocketAddress remoteAddress,
final StreamInput stream,
final TransportResponseHandler<T> handler,
final InboundMessage inboundMessage
final InboundMessage message
) {
final var executor = handler.executor();
if (executor == EsExecutors.DIRECT_EXECUTOR_SERVICE) {
// no need to provide a buffer release here, we never escape the buffer when handling directly
doHandleResponse(handler, remoteAddress, stream, inboundMessage.getHeader(), () -> {});
doHandleResponse(handler, remoteAddress, stream, /* autocloses */ message);
} else {
inboundMessage.mustIncRef();
// release buffer once we deserialize the message, but have a fail-safe in #onAfter below in case that didn't work out
final Releasable releaseBuffer = Releasables.releaseOnce(inboundMessage::decRef);
executor.execute(new ForkingResponseHandlerRunnable(handler, null) {
@Override
protected void doRun() {
doHandleResponse(handler, remoteAddress, stream, inboundMessage.getHeader(), releaseBuffer);
doHandleResponse(handler, remoteAddress, stream, /* autocloses */ message);
}

@Override
public void onAfter() {
Releasables.closeExpectNoException(releaseBuffer);
message.close();
}
});
}
Expand All @@ -404,20 +430,19 @@ public void onAfter() {
* @param handler response handler
* @param remoteAddress remote address that the message was sent from
* @param stream bytes stream for reading the message
* @param header message header
* @param releaseResponseBuffer releasable that will be released once the message has been read from the {@code stream}
* @param inboundMessage inbound message, guaranteed to get closed by this method
* @param <T> response message type
*/
private <T extends TransportResponse> void doHandleResponse(
TransportResponseHandler<T> handler,
InetSocketAddress remoteAddress,
final StreamInput stream,
final Header header,
Releasable releaseResponseBuffer
InboundMessage inboundMessage
) {
final T response;
try (releaseResponseBuffer) {
try (inboundMessage) {
response = handler.read(stream);
verifyResponseReadFully(inboundMessage.getHeader(), handler, stream);
} catch (Exception e) {
final TransportException serializationException = new TransportSerializationException(
"Failed to deserialize response from handler [" + handler + "]",
Expand All @@ -429,7 +454,6 @@ private <T extends TransportResponse> void doHandleResponse(
return;
}
try {
verifyResponseReadFully(header, handler, stream);
handler.handleResponse(response);
} catch (Exception e) {
doHandleException(handler, new ResponseHandlerFailureTransportException(e));
Expand All @@ -438,9 +462,12 @@ private <T extends TransportResponse> void doHandleResponse(
}
}

/**
* @param message guaranteed to get closed by this method
*/
private void handlerResponseError(StreamInput stream, InboundMessage message, final TransportResponseHandler<?> handler) {
Exception error;
try {
try (message) {
error = stream.readException();
verifyResponseReadFully(message.getHeader(), handler, stream);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.core.AbstractRefCounted;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.Releasable;

import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;
import java.util.Objects;

public class InboundMessage extends AbstractRefCounted {
public class InboundMessage implements Releasable {

private final Header header;
private final ReleasableBytesReference content;
Expand All @@ -28,6 +29,19 @@ public class InboundMessage extends AbstractRefCounted {
private Releasable breakerRelease;
private StreamInput streamInput;

@SuppressWarnings("unused") // updated via CLOSED (and _only_ via CLOSED)
private boolean closed;

private static final VarHandle CLOSED;

static {
try {
CLOSED = MethodHandles.lookup().findVarHandle(InboundMessage.class, "closed", boolean.class);
} catch (Exception e) {
throw new ExceptionInInitializerError(e);
}
}

public InboundMessage(Header header, ReleasableBytesReference content, Releasable breakerRelease) {
this.header = header;
this.content = content;
Expand Down Expand Up @@ -84,7 +98,7 @@ public Releasable takeBreakerReleaseControl() {

public StreamInput openOrGetStreamInput() throws IOException {
assert isPing == false && content != null;
assert hasReferences();
assert (boolean) CLOSED.getAcquire(this) == false;
if (streamInput == null) {
streamInput = content.streamInput();
streamInput.setTransportVersion(header.getVersion());
Expand All @@ -98,7 +112,10 @@ public String toString() {
}

@Override
protected void closeInternal() {
public void close() {
if (CLOSED.compareAndSet(this, false, true) == false) {
return;
}
try {
IOUtils.close(streamInput, content, breakerRelease);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,13 @@ private void forwardFragment(TcpChannel channel, Object fragment) throws IOExcep
InboundMessage aggregated = aggregator.finishAggregation();
try {
statsTracker.markMessageReceived();
messageHandler.accept(channel, aggregated);
messageHandler.accept(channel, /* autocloses */ aggregated);
aggregated = null;
} finally {
aggregated.decRef();
if (aggregated != null) {
// TODO doesn't messageHandler auto-close always?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looking at messageHandler (i.e. TcpTransport#inboundMessage) I think we close on all paths already, and catch everything anyway, so this seems redundant?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right ... let me fix that right up :)

aggregated.close();
}
}
} else {
assert aggregator.isAggregating();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -813,9 +813,14 @@ protected void serverAcceptedChannel(TcpChannel channel) {
*/
public void inboundMessage(TcpChannel channel, InboundMessage message) {
try {
inboundHandler.inboundMessage(channel, message);
inboundHandler.inboundMessage(channel, /* autocloses absent exception */ message);
message = null;
} catch (Exception e) {
onException(channel, e);
} finally {
if (message != null) {
message.close();
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ public void testInboundAggregation() throws IOException {
for (ReleasableBytesReference reference : references) {
assertTrue(reference.hasReferences());
}
aggregated.decRef();
aggregated.close();
for (ReleasableBytesReference reference : references) {
assertFalse(reference.hasReferences());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public void testPipelineHandling() throws IOException {
final List<Tuple<MessageData, Exception>> actual = new ArrayList<>();
final List<ReleasableBytesReference> toRelease = new ArrayList<>();
final BiConsumer<TcpChannel, InboundMessage> messageHandler = (c, m) -> {
try {
try (m) {
final Header header = m.getHeader();
final MessageData actualData;
final TransportVersion version = header.getVersion();
Expand Down Expand Up @@ -204,7 +204,7 @@ private static Compression.Scheme getCompressionScheme() {
}

public void testDecodeExceptionIsPropagated() throws IOException {
BiConsumer<TcpChannel, InboundMessage> messageHandler = (c, m) -> {};
BiConsumer<TcpChannel, InboundMessage> messageHandler = (c, m) -> m.close();
final StatsTracker statsTracker = new StatsTracker();
final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime());
final InboundDecoder decoder = new InboundDecoder(recycler);
Expand Down Expand Up @@ -245,7 +245,7 @@ public void testDecodeExceptionIsPropagated() throws IOException {
}

public void testEnsureBodyIsNotPrematurelyReleased() throws IOException {
BiConsumer<TcpChannel, InboundMessage> messageHandler = (c, m) -> {};
BiConsumer<TcpChannel, InboundMessage> messageHandler = (c, m) -> m.close();
final StatsTracker statsTracker = new StatsTracker();
final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime());
final InboundDecoder decoder = new InboundDecoder(recycler);
Expand Down