Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/130303.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 130303
summary: Drain responses on completion for `TransportNodesAction`
area: Distributed
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicBoolean;

import static org.elasticsearch.core.Strings.format;

Expand Down Expand Up @@ -99,6 +100,7 @@ protected void doExecute(Task task, NodesRequest request, ActionListener<NodesRe
final ActionContext actionContext = createActionContext(task, request);
final ArrayList<NodeResponse> responses = new ArrayList<>(concreteNodes.length);
final ArrayList<FailedNodeException> exceptions = new ArrayList<>(0);
final AtomicBoolean responsesHandled = new AtomicBoolean(false);

final TransportRequestOptions transportRequestOptions = TransportRequestOptions.timeout(request.timeout());

Expand All @@ -109,12 +111,14 @@ protected void doExecute(Task task, NodesRequest request, ActionListener<NodesRe
// NOTE(review): this span is a rendered diff hunk, so the pre-change and post-change versions of the listener
// body appear interleaved below (the drain-and-release sequence is shown twice).
//
// When the task is a CancellableTask, registers a cancellation listener that snapshots and clears `responses`
// while holding its monitor, and then decRefs every drained response outside the lock. In the guarded variant
// the drain only runs if this listener wins the responsesHandled compare-and-set (false -> true).
private void addReleaseOnCancellationListener() {
if (task instanceof CancellableTask cancellableTask) {
cancellableTask.addListener(() -> {
final List<NodeResponse> drainedResponses;
synchronized (responses) {
drainedResponses = List.copyOf(responses);
responses.clear();
// Guarded variant: only one of the cancellation listener and the completion handler may flip the flag.
if (responsesHandled.compareAndSet(false, true)) {
final List<NodeResponse> drainedResponses;
synchronized (responses) {
// Take an immutable snapshot and empty the shared list so the references are released exactly once.
drainedResponses = List.copyOf(responses);
responses.clear();
}
Releasables.wrap(Iterators.map(drainedResponses.iterator(), r -> r::decRef)).close();
}
Releasables.wrap(Iterators.map(drainedResponses.iterator(), r -> r::decRef)).close();
});
}
}
Expand Down Expand Up @@ -161,10 +165,18 @@ protected void onItemFailure(DiscoveryNode discoveryNode, Exception e) {

// NOTE(review): this span is a rendered diff hunk, so the removed unguarded try-with-resources line and the
// added guarded version both appear below.
//
// Returns the consumer that runs once all items have completed. If it wins the responsesHandled
// compare-and-set it builds the overall response via newResponseAsync, with a try-with-resources releasable
// that decRefs every collected node response on exit. Otherwise the flag was already set (per the log
// message: the task was cancelled after all responses were collected), so it asserts that the task is a
// cancelled CancellableTask and hands the listener to notifyIfCancelled instead.
@Override
protected CheckedConsumer<ActionListener<NodesResponse>, Exception> onCompletion() {
// ref releases all happen-before here so no need to be synchronized
return l -> {
try (var ignored = Releasables.wrap(Iterators.map(responses.iterator(), r -> r::decRef))) {
newResponseAsync(task, request, actionContext, responses, exceptions, l);
// Guarded variant: proceed only if the responses have not already been handled elsewhere.
if (responsesHandled.compareAndSet(false, true)) {
// ref releases all happen-before here so no need to be synchronized
try (var ignored = Releasables.wrap(Iterators.map(responses.iterator(), r -> r::decRef))) {
newResponseAsync(task, request, actionContext, responses, exceptions, l);
}
} else {
// The flag was already set, so the responses are not used here; the listener is passed to the task instead.
logger.debug("task cancelled after all responses were collected");
assert task instanceof CancellableTask : "expect CancellableTask, but got: " + task;
final var cancellableTask = (CancellableTask) task;
assert cancellableTask.isCancelled();
cancellableTask.notifyIfCancelled(l);
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
package org.elasticsearch.action.support.nodes;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.FailedNodeException;
import org.elasticsearch.action.support.ActionFilters;
Expand Down Expand Up @@ -57,6 +58,8 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
Expand All @@ -66,7 +69,9 @@
import static java.util.Collections.emptyMap;
import static org.elasticsearch.test.ClusterServiceUtils.createClusterService;
import static org.elasticsearch.test.ClusterServiceUtils.setState;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.hasSize;
import static org.mockito.Mockito.mock;

public class TransportNodesActionTests extends ESTestCase {
Expand Down Expand Up @@ -316,6 +321,137 @@ protected Object createActionContext(Task task, TestNodesRequest request) {
assertTrue(cancellableTask.isCancelled()); // keep task alive
}

/**
 * Cancels the task while the overall response is already being built from the node responses and verifies
 * that the request still completes successfully with every response released afterwards.
 */
public void testCompletionShouldNotBeInterferedByCancellationAfterProcessingBegins() throws Exception {
    // Two-party barrier used twice: once to signal "iteration has started", once to signal "cancellation is done".
    final var processingBarrier = new CyclicBarrier(2);
    final var action = new TestTransportNodesAction(
        clusterService,
        transportService,
        new ActionFilters(Set.of()),
        TestNodeRequest::new,
        THREAD_POOL.executor(ThreadPool.Names.GENERIC)
    ) {
        @Override
        protected void newResponseAsync(
            Task task,
            TestNodesRequest request,
            Void unused,
            List<TestNodeResponse> testNodeResponses,
            List<FailedNodeException> failures,
            ActionListener<TestNodesResponse> listener
        ) {
            // Walk the node responses with a live iterator and pause after the first element, so that a
            // cancellation arriving mid-iteration would surface as a ConcurrentModificationException if it
            // touched the list, see also #128852
            final var iterator = testNodeResponses.iterator();
            boolean paused = false;
            while (iterator.hasNext()) {
                iterator.next();
                if (paused) {
                    continue;
                }
                paused = true;
                safeAwait(processingBarrier);
                safeAwait(processingBarrier);
            }
            super.newResponseAsync(task, request, unused, testNodeResponses, failures, listener);
        }
    };

    final var cancellableTask = new CancellableTask(randomLong(), "transport", "action", "", null, emptyMap());
    final PlainActionFuture<Void> cancelledFuture = new PlainActionFuture<>();
    cancellableTask.addListener(() -> cancelledFuture.onResponse(null));

    final var responseFuture = new PlainActionFuture<TestNodesResponse>();
    action.execute(cancellableTask, new TestNodesRequest(), responseFuture);

    final var pendingRequests = transport.getCapturedRequestsAndClear();
    for (int i = 0; i < pendingRequests.length; i++) {
        completeOneRequest(pendingRequests[i]);
    }

    // First rendezvous: the response builder is now in the middle of its loop, so cancel the task and wait
    // until the cancellation listeners have run. This must not disturb the in-flight processing.
    safeAwait(processingBarrier);
    TaskCancelHelper.cancel(cancellableTask, "simulated");
    safeGet(cancelledFuture);

    // Second rendezvous: release the response builder; the request is expected to succeed.
    safeAwait(processingBarrier);
    assertResponseReleased(safeGet(responseFuture));
}

/**
 * Races the completion of the last outstanding node request against cancellation of the task.
 * Either outcome is acceptable, a successful response or a task-cancelled failure, but in both cases every
 * node response must end up fully released.
 *
 * @throws InterruptedException if the test thread is interrupted while joining the racing threads
 */
public void testConcurrentlyCompletionAndCancellation() throws InterruptedException {
    final var action = getTestTransportNodesAction();

    final CountDownLatch onCancelledLatch = new CountDownLatch(1);
    final CancellableTask cancellableTask = new CancellableTask(randomLong(), "transport", "action", "", null, emptyMap()) {
        @Override
        protected void onCancelled() {
            onCancelledLatch.countDown();
        }
    };

    final PlainActionFuture<TestNodesResponse> future = new PlainActionFuture<>();
    action.execute(cancellableTask, new TestNodesRequest(), future);

    // Complete every request except the last one up-front; the last one is completed by the racing thread.
    final List<TestNodeResponse> nodeResponses = new ArrayList<>();
    final CapturingTransport.CapturedRequest[] capturedRequests = transport.getCapturedRequestsAndClear();
    for (int i = 0; i < capturedRequests.length - 1; i++) {
        final var capturedRequest = capturedRequests[i];
        nodeResponses.add(completeOneRequest(capturedRequest));
    }

    // Three parties: the completing thread, the cancelling thread and this test thread.
    final var raceBarrier = new CyclicBarrier(3);
    final Thread completeThread = new Thread(() -> {
        safeAwait(raceBarrier);
        nodeResponses.add(completeOneRequest(capturedRequests[capturedRequests.length - 1]));
    });
    final Thread cancelThread = new Thread(() -> {
        safeAwait(raceBarrier);
        TaskCancelHelper.cancel(cancellableTask, "simulated");
    });
    completeThread.start();
    cancelThread.start();
    safeAwait(raceBarrier);

    // We expect either a successful response or a cancellation exception. All node responses should be released in both cases.
    try {
        final var testNodesResponse = future.actionGet(SAFE_AWAIT_TIMEOUT);
        assertThat(testNodesResponse.getNodes(), hasSize(capturedRequests.length));
        assertResponseReleased(testNodesResponse);
    } catch (Exception e) {
        final var taskCancelledException = (TaskCancelledException) ExceptionsHelper.unwrap(e, TaskCancelledException.class);
        assertNotNull("expect task cancellation exception, but got\n" + ExceptionsHelper.stackTrace(e), taskCancelledException);
        assertThat(e.getMessage(), containsString("task cancelled [simulated]"));
        assertTrue(cancellableTask.isCancelled());
        safeAwait(onCancelledLatch); // wait for the latch, the listener for releasing node responses is called before it
        // completeThread appends to nodeResponses, which is a plain ArrayList. The future can fail with the
        // cancellation before that thread has finished, so join it first: this both avoids iterating the list
        // while it is being modified and makes the last response visible to the check below.
        completeThread.join(10_000);
        assertFalse(completeThread.isAlive());
        assertTrue(nodeResponses.stream().allMatch(r -> r.hasReferences() == false));
    }

    // Neither racing thread may be left running once the outcome has been observed.
    completeThread.join(10_000);
    cancelThread.join(10_000);
    assertFalse(completeThread.isAlive());
    assertFalse(cancelThread.isAlive());
}

/**
 * Waits until the overall response and each of its node responses have been closed, then asserts that none
 * of them holds any remaining references.
 *
 * @param response the overall nodes response whose release is being verified
 */
private void assertResponseReleased(TestNodesResponse response) {
    final var everythingClosed = new SubscribableListener<Void>();
    // One acquired listener per ref-counted object; the subscribable listener completes once all have fired
    // and the try-with-resources block has released the initial reference.
    try (var closeListeners = new RefCountingListener(everythingClosed)) {
        response.addCloseListener(closeListeners.acquire());
        response.getNodes().forEach(node -> node.addCloseListener(closeListeners.acquire()));
    }
    safeAwait(everythingClosed);
    assertTrue(response.getNodes().stream().allMatch(node -> node.hasReferences() == false));
    assertFalse(response.hasReferences());
}

/**
 * Answers one captured transport request with a fresh node response and drops the creator's reference
 * afterwards, regardless of whether the handler threw.
 *
 * @param capturedRequest the captured request to respond to
 * @return the node response that was passed to the request's response handler
 */
private TestNodeResponse completeOneRequest(CapturingTransport.CapturedRequest capturedRequest) {
    final var nodeResponse = new TestNodeResponse(capturedRequest.node());
    try {
        final var handler = transport.getTransportResponseHandler(capturedRequest.requestId());
        handler.handleResponse(nodeResponse);
    } finally {
        // Release the reference held by this method; any further references belong to the receiver.
        nodeResponse.decRef();
    }
    return nodeResponse;
}

@BeforeClass
public static void startThreadPool() {
THREAD_POOL = new TestThreadPool(TransportNodesActionTests.class.getSimpleName());
Expand Down