Skip to content

Commit 7379ad5

Browse files
authored
Drain responses on completion for TransportNodesAction (#130303) (#130513)
This PR ensures the node responses are copied and drained exclusively in onCompletion so that they do not get concurrently modified by cancellation. Resolves: #128852
1 parent a9018aa commit 7379ad5

File tree

3 files changed

+161
-8
lines changed

3 files changed

+161
-8
lines changed

docs/changelog/130303.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 130303
2+
summary: Drain responses on completion for `TransportNodesAction`
3+
area: Distributed
4+
type: bug
5+
issues: []

server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
import java.util.List;
4747
import java.util.Objects;
4848
import java.util.concurrent.Executor;
49+
import java.util.concurrent.atomic.AtomicBoolean;
4950

5051
import static org.elasticsearch.core.Strings.format;
5152

@@ -103,6 +104,7 @@ protected void doExecute(Task task, NodesRequest request, ActionListener<NodesRe
103104
final ActionContext actionContext = createActionContext(task, request);
104105
final ArrayList<NodeResponse> responses = new ArrayList<>(concreteNodes.length);
105106
final ArrayList<FailedNodeException> exceptions = new ArrayList<>(0);
107+
final AtomicBoolean responsesHandled = new AtomicBoolean(false);
106108

107109
final TransportRequestOptions transportRequestOptions = TransportRequestOptions.timeout(request.timeout());
108110

@@ -113,12 +115,14 @@ protected void doExecute(Task task, NodesRequest request, ActionListener<NodesRe
113115
private void addReleaseOnCancellationListener() {
114116
if (task instanceof CancellableTask cancellableTask) {
115117
cancellableTask.addListener(() -> {
116-
final List<NodeResponse> drainedResponses;
117-
synchronized (responses) {
118-
drainedResponses = List.copyOf(responses);
119-
responses.clear();
118+
if (responsesHandled.compareAndSet(false, true)) {
119+
final List<NodeResponse> drainedResponses;
120+
synchronized (responses) {
121+
drainedResponses = List.copyOf(responses);
122+
responses.clear();
123+
}
124+
Releasables.wrap(Iterators.map(drainedResponses.iterator(), r -> r::decRef)).close();
120125
}
121-
Releasables.wrap(Iterators.map(drainedResponses.iterator(), r -> r::decRef)).close();
122126
});
123127
}
124128
}
@@ -165,10 +169,18 @@ protected void onItemFailure(DiscoveryNode discoveryNode, Exception e) {
165169

166170
@Override
167171
protected CheckedConsumer<ActionListener<NodesResponse>, Exception> onCompletion() {
168-
// ref releases all happen-before here so no need to be synchronized
169172
return l -> {
170-
try (var ignored = Releasables.wrap(Iterators.map(responses.iterator(), r -> r::decRef))) {
171-
newResponseAsync(task, request, actionContext, responses, exceptions, l);
173+
if (responsesHandled.compareAndSet(false, true)) {
174+
// ref releases all happen-before here so no need to be synchronized
175+
try (var ignored = Releasables.wrap(Iterators.map(responses.iterator(), r -> r::decRef))) {
176+
newResponseAsync(task, request, actionContext, responses, exceptions, l);
177+
}
178+
} else {
179+
logger.debug("task cancelled after all responses were collected");
180+
assert task instanceof CancellableTask : "expect CancellableTask, but got: " + task;
181+
final var cancellableTask = (CancellableTask) task;
182+
assert cancellableTask.isCancelled();
183+
cancellableTask.notifyIfCancelled(l);
172184
}
173185
};
174186
}

server/src/test/java/org/elasticsearch/action/support/nodes/TransportNodesActionTests.java

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
package org.elasticsearch.action.support.nodes;
1111

1212
import org.elasticsearch.ElasticsearchException;
13+
import org.elasticsearch.ExceptionsHelper;
1314
import org.elasticsearch.action.ActionListener;
1415
import org.elasticsearch.action.FailedNodeException;
1516
import org.elasticsearch.action.support.ActionFilters;
@@ -57,6 +58,8 @@
5758
import java.util.List;
5859
import java.util.Map;
5960
import java.util.Set;
61+
import java.util.concurrent.CountDownLatch;
62+
import java.util.concurrent.CyclicBarrier;
6063
import java.util.concurrent.Executor;
6164
import java.util.concurrent.TimeUnit;
6265
import java.util.concurrent.atomic.AtomicInteger;
@@ -66,7 +69,9 @@
6669
import static java.util.Collections.emptyMap;
6770
import static org.elasticsearch.test.ClusterServiceUtils.createClusterService;
6871
import static org.elasticsearch.test.ClusterServiceUtils.setState;
72+
import static org.hamcrest.Matchers.containsString;
6973
import static org.hamcrest.Matchers.greaterThan;
74+
import static org.hamcrest.Matchers.hasSize;
7075
import static org.mockito.Mockito.mock;
7176

7277
public class TransportNodesActionTests extends ESTestCase {
@@ -316,6 +321,137 @@ protected Object createActionContext(Task task, TestNodesRequest request) {
316321
assertTrue(cancellableTask.isCancelled()); // keep task alive
317322
}
318323

324+
public void testCompletionShouldNotBeInterferedByCancellationAfterProcessingBegins() throws Exception {
325+
final var barrier = new CyclicBarrier(2);
326+
final var action = new TestTransportNodesAction(
327+
clusterService,
328+
transportService,
329+
new ActionFilters(Set.of()),
330+
TestNodeRequest::new,
331+
THREAD_POOL.executor(ThreadPool.Names.GENERIC)
332+
) {
333+
@Override
334+
protected void newResponseAsync(
335+
Task task,
336+
TestNodesRequest request,
337+
Void unused,
338+
List<TestNodeResponse> testNodeResponses,
339+
List<FailedNodeException> failures,
340+
ActionListener<TestNodesResponse> listener
341+
) {
342+
boolean waited = false;
343+
// Process node responses in a loop and ensure no ConcurrentModificationException will be thrown due to
344+
// concurrent cancellation coming after the loop has started, see also #128852
345+
for (var response : testNodeResponses) {
346+
if (waited == false) {
347+
waited = true;
348+
safeAwait(barrier);
349+
safeAwait(barrier);
350+
}
351+
}
352+
super.newResponseAsync(task, request, unused, testNodeResponses, failures, listener);
353+
}
354+
};
355+
356+
final CancellableTask cancellableTask = new CancellableTask(randomLong(), "transport", "action", "", null, emptyMap());
357+
final var cancelledFuture = new PlainActionFuture<Void>();
358+
cancellableTask.addListener(() -> cancelledFuture.onResponse(null));
359+
360+
final PlainActionFuture<TestNodesResponse> future = new PlainActionFuture<>();
361+
action.execute(cancellableTask, new TestNodesRequest(), future);
362+
363+
for (var capturedRequest : transport.getCapturedRequestsAndClear()) {
364+
completeOneRequest(capturedRequest);
365+
}
366+
367+
// Wait for the overall response to start processing the node responses in a loop and then cancel the task.
368+
// The cancellation should not interfere with the node response processing.
369+
safeAwait(barrier);
370+
TaskCancelHelper.cancel(cancellableTask, "simulated");
371+
safeGet(cancelledFuture);
372+
373+
// Let the process continue, and it should be successful
374+
safeAwait(barrier);
375+
assertResponseReleased(safeGet(future));
376+
}
377+
378+
public void testConcurrentlyCompletionAndCancellation() throws InterruptedException {
379+
final var action = getTestTransportNodesAction();
380+
381+
final CountDownLatch onCancelledLatch = new CountDownLatch(1);
382+
final CancellableTask cancellableTask = new CancellableTask(randomLong(), "transport", "action", "", null, emptyMap()) {
383+
@Override
384+
protected void onCancelled() {
385+
onCancelledLatch.countDown();
386+
}
387+
};
388+
389+
final PlainActionFuture<TestNodesResponse> future = new PlainActionFuture<>();
390+
action.execute(cancellableTask, new TestNodesRequest(), future);
391+
392+
final List<TestNodeResponse> nodeResponses = new ArrayList<>();
393+
final CapturingTransport.CapturedRequest[] capturedRequests = transport.getCapturedRequestsAndClear();
394+
for (int i = 0; i < capturedRequests.length - 1; i++) {
395+
final var capturedRequest = capturedRequests[i];
396+
nodeResponses.add(completeOneRequest(capturedRequest));
397+
}
398+
399+
final var raceBarrier = new CyclicBarrier(3);
400+
final Thread completeThread = new Thread(() -> {
401+
safeAwait(raceBarrier);
402+
nodeResponses.add(completeOneRequest(capturedRequests[capturedRequests.length - 1]));
403+
});
404+
final Thread cancelThread = new Thread(() -> {
405+
safeAwait(raceBarrier);
406+
TaskCancelHelper.cancel(cancellableTask, "simulated");
407+
});
408+
completeThread.start();
409+
cancelThread.start();
410+
safeAwait(raceBarrier);
411+
412+
// We expect either a successful response or a cancellation exception. All node responses should be released in both cases.
413+
try {
414+
final var testNodesResponse = future.actionGet(SAFE_AWAIT_TIMEOUT);
415+
assertThat(testNodesResponse.getNodes(), hasSize(capturedRequests.length));
416+
assertResponseReleased(testNodesResponse);
417+
} catch (Exception e) {
418+
final var taskCancelledException = (TaskCancelledException) ExceptionsHelper.unwrap(e, TaskCancelledException.class);
419+
assertNotNull("expect task cancellation exception, but got\n" + ExceptionsHelper.stackTrace(e), taskCancelledException);
420+
assertThat(e.getMessage(), containsString("task cancelled [simulated]"));
421+
assertTrue(cancellableTask.isCancelled());
422+
safeAwait(onCancelledLatch); // wait for the latch, the listener for releasing node responses is called before it
423+
assertTrue(nodeResponses.stream().allMatch(r -> r.hasReferences() == false));
424+
}
425+
426+
completeThread.join(10_000);
427+
cancelThread.join(10_000);
428+
assertFalse(completeThread.isAlive());
429+
assertFalse(cancelThread.isAlive());
430+
}
431+
432+
private void assertResponseReleased(TestNodesResponse response) {
433+
final var allResponsesReleasedListener = new SubscribableListener<Void>();
434+
try (var listeners = new RefCountingListener(allResponsesReleasedListener)) {
435+
response.addCloseListener(listeners.acquire());
436+
for (final var nodeResponse : response.getNodes()) {
437+
nodeResponse.addCloseListener(listeners.acquire());
438+
}
439+
}
440+
safeAwait(allResponsesReleasedListener);
441+
assertTrue(response.getNodes().stream().noneMatch(TestNodeResponse::hasReferences));
442+
assertFalse(response.hasReferences());
443+
}
444+
445+
private TestNodeResponse completeOneRequest(CapturingTransport.CapturedRequest capturedRequest) {
446+
final var response = new TestNodeResponse(capturedRequest.node());
447+
try {
448+
transport.getTransportResponseHandler(capturedRequest.requestId()).handleResponse(response);
449+
} finally {
450+
response.decRef();
451+
}
452+
return response;
453+
}
454+
319455
@BeforeClass
320456
public static void startThreadPool() {
321457
THREAD_POOL = new TestThreadPool(TransportNodesActionTests.class.getSimpleName());

0 commit comments

Comments
 (0)