Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ public InboundMessage finishAggregation() throws IOException {
checkBreaker(aggregated.getHeader(), aggregated.getContentLength(), breakerControl);
}
if (isShortCircuited()) {
aggregated.decRef();
aggregated.close();
success = true;
return new InboundMessage(aggregated.getHeader(), aggregationException);
} else {
Expand All @@ -131,7 +131,7 @@ public InboundMessage finishAggregation() throws IOException {
} finally {
resetCurrentAggregation();
if (success == false) {
aggregated.decRef();
aggregated.close();
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -123,6 +122,8 @@ private void messageReceived(TcpChannel channel, InboundMessage message, long st
// ignore if its null, the service logs it
if (responseHandler != null) {
executeResponseHandler(message, responseHandler, remoteAddress);
} else {
message.close();
}
}
} finally {
Expand Down Expand Up @@ -258,7 +259,7 @@ private <T extends TransportRequest> void handleRequest(TcpChannel channel, Inbo
final StreamInput stream = namedWriteableStream(message.openOrGetStreamInput());
assert assertRemoteVersion(stream, header.getVersion());
final T request;
try {
try (message) {
request = reg.newRequest(stream);
} catch (Exception e) {
assert ignoreDeserializationErrors : e;
Expand Down Expand Up @@ -351,7 +352,7 @@ private void handleHandshakeRequest(TcpChannel channel, InboundMessage message)
true,
Releasables.assertOnce(message.takeBreakerReleaseControl())
);
try {
try (message) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Admittedly annoying to spread the release logic across so many spots but in the end that is precisely what we need/want, release right where we deserialize. (If we want this cleaner we need to deserialise in fewer spots I guess?)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this is the same thing as refcount-negative methods: we need to find a way to make the fact that the method closes one of its arguments explicit at the call site.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or alternatively (and I'm starting to think this is more practical), any method that consumes an object like that must either pass it on elsewhere (and do so exactly once) or release it.
Just documenting that something closes the argument still leaves a lot of complexity around having to check that the last receiver of the instances is one of those closing-enabled methods and that other methods don't start holding on to instances doesn't it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The trouble is I don't think this is (or could be made) true of every method that takes an argument which represents nontrivial resources. We need to support both.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... still leaves a lot of complexity around having to check...

sure, it's not a fully automated answer (i.e. a proper borrow-checker) but at least it reduces those checks to method-local things rather than having to reason about things up and down the stack some arbitrary distance.

handshaker.handleHandshake(transportChannel, requestId, stream);
} catch (Exception e) {
logger.warn(
Expand Down Expand Up @@ -380,20 +381,18 @@ private <T extends TransportResponse> void handleResponse(
final var executor = handler.executor();
if (executor == EsExecutors.DIRECT_EXECUTOR_SERVICE) {
// no need to provide a buffer release here, we never escape the buffer when handling directly
doHandleResponse(handler, remoteAddress, stream, inboundMessage.getHeader(), () -> {});
doHandleResponse(handler, remoteAddress, stream, inboundMessage);
} else {
inboundMessage.mustIncRef();
// release buffer once we deserialize the message, but have a fail-safe in #onAfter below in case that didn't work out
final Releasable releaseBuffer = Releasables.releaseOnce(inboundMessage::decRef);
executor.execute(new ForkingResponseHandlerRunnable(handler, null) {
@Override
protected void doRun() {
doHandleResponse(handler, remoteAddress, stream, inboundMessage.getHeader(), releaseBuffer);
doHandleResponse(handler, remoteAddress, stream, inboundMessage);
}

@Override
public void onAfter() {
Releasables.closeExpectNoException(releaseBuffer);
inboundMessage.close();
}
});
}
Expand All @@ -404,20 +403,19 @@ public void onAfter() {
* @param handler response handler
* @param remoteAddress remote address that the message was sent from
* @param stream bytes stream for reading the message
* @param header message header
* @param releaseResponseBuffer releasable that will be released once the message has been read from the {@code stream}
* @param inboundMessage inbound message
* @param <T> response message type
*/
private <T extends TransportResponse> void doHandleResponse(
TransportResponseHandler<T> handler,
InetSocketAddress remoteAddress,
final StreamInput stream,
final Header header,
Releasable releaseResponseBuffer
InboundMessage inboundMessage
) {
final T response;
try (releaseResponseBuffer) {
try (inboundMessage) {
response = handler.read(stream);
verifyResponseReadFully(inboundMessage.getHeader(), handler, stream);
} catch (Exception e) {
final TransportException serializationException = new TransportSerializationException(
"Failed to deserialize response from handler [" + handler + "]",
Expand All @@ -429,7 +427,6 @@ private <T extends TransportResponse> void doHandleResponse(
return;
}
try {
verifyResponseReadFully(header, handler, stream);
handler.handleResponse(response);
} catch (Exception e) {
doHandleException(handler, new ResponseHandlerFailureTransportException(e));
Expand All @@ -440,7 +437,7 @@ private <T extends TransportResponse> void doHandleResponse(

private void handlerResponseError(StreamInput stream, InboundMessage message, final TransportResponseHandler<?> handler) {
Exception error;
try {
try (message) {
error = stream.readException();
verifyResponseReadFully(message.getHeader(), handler, stream);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.core.AbstractRefCounted;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.Releasable;

import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;
import java.util.Objects;

public class InboundMessage extends AbstractRefCounted {
public class InboundMessage implements Releasable {

private final Header header;
private final ReleasableBytesReference content;
Expand All @@ -28,6 +29,18 @@ public class InboundMessage extends AbstractRefCounted {
private Releasable breakerRelease;
private StreamInput streamInput;

private boolean closed;

private static final VarHandle CLOSED;

static {
try {
CLOSED = MethodHandles.lookup().findVarHandle(InboundMessage.class, "closed", boolean.class);
} catch (Exception e) {
throw new ExceptionInInitializerError(e);
}
}

public InboundMessage(Header header, ReleasableBytesReference content, Releasable breakerRelease) {
this.header = header;
this.content = content;
Expand Down Expand Up @@ -84,7 +97,7 @@ public Releasable takeBreakerReleaseControl() {

public StreamInput openOrGetStreamInput() throws IOException {
assert isPing == false && content != null;
assert hasReferences();
assert (boolean) CLOSED.getAcquire(this) == false;
if (streamInput == null) {
streamInput = content.streamInput();
streamInput.setTransportVersion(header.getVersion());
Expand All @@ -98,7 +111,10 @@ public String toString() {
}

@Override
protected void closeInternal() {
public void close() {
if (CLOSED.compareAndSet(this, false, true) == false) {
return;
}
try {
IOUtils.close(streamInput, content, breakerRelease);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,11 @@ private void forwardFragment(TcpChannel channel, Object fragment) throws IOExcep
try {
statsTracker.markMessageReceived();
messageHandler.accept(channel, aggregated);
aggregated = null;
} finally {
aggregated.decRef();
if (aggregated != null) {
aggregated.close();
}
}
} else {
assert aggregator.isAggregating();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,7 @@ public void inboundMessage(TcpChannel channel, InboundMessage message) {
try {
inboundHandler.inboundMessage(channel, message);
} catch (Exception e) {
message.close();
onException(channel, e);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ public void testInboundAggregation() throws IOException {
for (ReleasableBytesReference reference : references) {
assertTrue(reference.hasReferences());
}
aggregated.decRef();
aggregated.close();
for (ReleasableBytesReference reference : references) {
assertFalse(reference.hasReferences());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public void testPipelineHandling() throws IOException {
final List<Tuple<MessageData, Exception>> actual = new ArrayList<>();
final List<ReleasableBytesReference> toRelease = new ArrayList<>();
final BiConsumer<TcpChannel, InboundMessage> messageHandler = (c, m) -> {
try {
try (m) {
final Header header = m.getHeader();
final MessageData actualData;
final TransportVersion version = header.getVersion();
Expand Down Expand Up @@ -204,7 +204,7 @@ private static Compression.Scheme getCompressionScheme() {
}

public void testDecodeExceptionIsPropagated() throws IOException {
BiConsumer<TcpChannel, InboundMessage> messageHandler = (c, m) -> {};
BiConsumer<TcpChannel, InboundMessage> messageHandler = (c, m) -> m.close();
final StatsTracker statsTracker = new StatsTracker();
final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime());
final InboundDecoder decoder = new InboundDecoder(recycler);
Expand Down Expand Up @@ -245,7 +245,7 @@ public void testDecodeExceptionIsPropagated() throws IOException {
}

public void testEnsureBodyIsNotPrematurelyReleased() throws IOException {
BiConsumer<TcpChannel, InboundMessage> messageHandler = (c, m) -> {};
BiConsumer<TcpChannel, InboundMessage> messageHandler = (c, m) -> m.close();
final StatsTracker statsTracker = new StatsTracker();
final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime());
final InboundDecoder decoder = new InboundDecoder(recycler);
Expand Down