Skip to content

Commit 2c3e84e

Browse files
committed
Fix race condition in RestCancellableNodeClient (#126686)
Today we rely on registering the channel after registering the task to be cancelled to ensure that the task is cancelled even if the channel is closed concurrently. However the client may already have processed a cancellable request on the channel and therefore this mechanism doesn't work. With this change we make sure not to register another task after draining the registrations in order to cancel them. Closes #88201
1 parent 38499d8 commit 2c3e84e

File tree

4 files changed

+94
-40
lines changed

4 files changed

+94
-40
lines changed

docs/changelog/126686.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 126686
2+
summary: Fix race condition in `RestCancellableNodeClient`
3+
area: Task Management
4+
type: bug
5+
issues:
6+
- 88201

qa/smoke-test-http/src/internalClusterTest/java/org/elasticsearch/http/IndicesSegmentsRestCancellationIT.java

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,12 @@
1212
import org.apache.http.client.methods.HttpGet;
1313
import org.elasticsearch.action.admin.indices.segments.IndicesSegmentsAction;
1414
import org.elasticsearch.client.Request;
15-
import org.elasticsearch.test.junit.annotations.TestIssueLogging;
1615

1716
public class IndicesSegmentsRestCancellationIT extends BlockedSearcherRestCancellationTestCase {
18-
@TestIssueLogging(
19-
issueUrl = "https://github.com/elastic/elasticsearch/issues/88201",
20-
value = "org.elasticsearch.http.BlockedSearcherRestCancellationTestCase:DEBUG"
21-
+ ",org.elasticsearch.transport.TransportService:TRACE"
22-
)
2317
public void testIndicesSegmentsRestCancellation() throws Exception {
2418
runTest(new Request(HttpGet.METHOD_NAME, "/_segments"), IndicesSegmentsAction.NAME);
2519
}
2620

27-
@TestIssueLogging(
28-
issueUrl = "https://github.com/elastic/elasticsearch/issues/88201",
29-
value = "org.elasticsearch.http.BlockedSearcherRestCancellationTestCase:DEBUG"
30-
+ ",org.elasticsearch.transport.TransportService:TRACE"
31-
)
3221
public void testCatSegmentsRestCancellation() throws Exception {
3322
runTest(new Request(HttpGet.METHOD_NAME, "/_cat/segments"), IndicesSegmentsAction.NAME);
3423
}

server/src/main/java/org/elasticsearch/rest/action/RestCancellableNodeClient.java

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@
1818
import org.elasticsearch.client.internal.FilterClient;
1919
import org.elasticsearch.client.internal.OriginSettingClient;
2020
import org.elasticsearch.client.internal.node.NodeClient;
21+
import org.elasticsearch.core.Nullable;
2122
import org.elasticsearch.http.HttpChannel;
2223
import org.elasticsearch.tasks.CancellableTask;
2324
import org.elasticsearch.tasks.Task;
2425
import org.elasticsearch.tasks.TaskId;
2526

26-
import java.util.ArrayList;
27+
import java.util.Collection;
2728
import java.util.HashSet;
28-
import java.util.List;
2929
import java.util.Map;
3030
import java.util.Set;
3131
import java.util.concurrent.ConcurrentHashMap;
@@ -112,12 +112,14 @@ private void cancelTask(TaskId taskId) {
112112

113113
private class CloseListener implements ActionListener<Void> {
114114
private final AtomicReference<HttpChannel> channel = new AtomicReference<>();
115-
private final Set<TaskId> tasks = new HashSet<>();
115+
116+
@Nullable // if already drained
117+
private Set<TaskId> tasks = new HashSet<>();
116118

117119
CloseListener() {}
118120

119121
synchronized int getNumTasks() {
120-
return tasks.size();
122+
return tasks == null ? 0 : tasks.size();
121123
}
122124

123125
void maybeRegisterChannel(HttpChannel httpChannel) {
@@ -130,16 +132,23 @@ void maybeRegisterChannel(HttpChannel httpChannel) {
130132
}
131133
}
132134

133-
synchronized void registerTask(TaskHolder taskHolder, TaskId taskId) {
134-
taskHolder.taskId = taskId;
135-
if (taskHolder.completed == false) {
136-
this.tasks.add(taskId);
135+
void registerTask(TaskHolder taskHolder, TaskId taskId) {
136+
synchronized (this) {
137+
taskHolder.taskId = taskId;
138+
if (tasks != null) {
139+
if (taskHolder.completed == false) {
140+
tasks.add(taskId);
141+
}
142+
return;
143+
}
137144
}
145+
// else tasks == null so the channel is already closed
146+
cancelTask(taskId);
138147
}
139148

140149
synchronized void unregisterTask(TaskHolder taskHolder) {
141-
if (taskHolder.taskId != null) {
142-
this.tasks.remove(taskHolder.taskId);
150+
if (taskHolder.taskId != null && tasks != null) {
151+
tasks.remove(taskHolder.taskId);
143152
}
144153
taskHolder.completed = true;
145154
}
@@ -149,18 +158,20 @@ public void onResponse(Void aVoid) {
149158
final HttpChannel httpChannel = channel.get();
150159
assert httpChannel != null : "channel not registered";
151160
// when the channel gets closed it won't be reused: we can remove it from the map and forget about it.
152-
CloseListener closeListener = httpChannels.remove(httpChannel);
153-
assert closeListener != null : "channel not found in the map of tracked channels";
154-
final List<TaskId> toCancel;
155-
synchronized (this) {
156-
toCancel = new ArrayList<>(tasks);
157-
tasks.clear();
158-
}
159-
for (TaskId taskId : toCancel) {
161+
final CloseListener closeListener = httpChannels.remove(httpChannel);
162+
assert closeListener != null : "channel not found in the map of tracked channels: " + httpChannel;
163+
assert closeListener == CloseListener.this : "channel had a different CloseListener registered: " + httpChannel;
164+
for (final var taskId : drainTasks()) {
160165
cancelTask(taskId);
161166
}
162167
}
163168

169+
private synchronized Collection<TaskId> drainTasks() {
170+
final var drained = tasks;
171+
tasks = null;
172+
return drained;
173+
}
174+
164175
@Override
165176
public void onFailure(Exception e) {
166177
onResponse(null);

server/src/test/java/org/elasticsearch/rest/action/RestCancellableNodeClientTests.java

Lines changed: 59 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.elasticsearch.action.support.PlainActionFuture;
2222
import org.elasticsearch.client.internal.node.NodeClient;
2323
import org.elasticsearch.common.settings.Settings;
24+
import org.elasticsearch.common.util.set.Sets;
2425
import org.elasticsearch.http.HttpChannel;
2526
import org.elasticsearch.http.HttpResponse;
2627
import org.elasticsearch.tasks.Task;
@@ -44,6 +45,7 @@
4445
import java.util.concurrent.atomic.AtomicInteger;
4546
import java.util.concurrent.atomic.AtomicLong;
4647
import java.util.concurrent.atomic.AtomicReference;
48+
import java.util.function.LongSupplier;
4749

4850
public class RestCancellableNodeClientTests extends ESTestCase {
4951

@@ -148,8 +150,42 @@ public void testChannelAlreadyClosed() {
148150
assertEquals(totalSearches, testClient.cancelledTasks.size());
149151
}
150152

153+
public void testConcurrentExecuteAndClose() throws Exception {
154+
final var testClient = new TestClient(Settings.EMPTY, threadPool, true);
155+
int initialHttpChannels = RestCancellableNodeClient.getNumChannels();
156+
int numTasks = randomIntBetween(1, 30);
157+
TestHttpChannel channel = new TestHttpChannel();
158+
final var startLatch = new CountDownLatch(1);
159+
final var doneLatch = new CountDownLatch(numTasks + 1);
160+
final var expectedTasks = Sets.<TaskId>newHashSetWithExpectedSize(numTasks);
161+
for (int j = 0; j < numTasks; j++) {
162+
RestCancellableNodeClient client = new RestCancellableNodeClient(testClient, channel);
163+
threadPool.generic().execute(() -> {
164+
client.execute(TransportSearchAction.TYPE, new SearchRequest(), ActionListener.running(ESTestCase::fail));
165+
startLatch.countDown();
166+
doneLatch.countDown();
167+
});
168+
expectedTasks.add(new TaskId(testClient.getLocalNodeId(), j));
169+
}
170+
threadPool.generic().execute(() -> {
171+
try {
172+
safeAwait(startLatch);
173+
channel.awaitClose();
174+
} catch (InterruptedException e) {
175+
Thread.currentThread().interrupt();
176+
throw new AssertionError(e);
177+
} finally {
178+
doneLatch.countDown();
179+
}
180+
});
181+
safeAwait(doneLatch);
182+
assertEquals(initialHttpChannels, RestCancellableNodeClient.getNumChannels());
183+
assertEquals(expectedTasks, testClient.cancelledTasks);
184+
}
185+
151186
private static class TestClient extends NodeClient {
152-
private final AtomicLong counter = new AtomicLong(0);
187+
private final LongSupplier searchTaskIdGenerator = new AtomicLong(0)::getAndIncrement;
188+
private final LongSupplier cancelTaskIdGenerator = new AtomicLong(1000)::getAndIncrement;
153189
private final Set<TaskId> cancelledTasks = new CopyOnWriteArraySet<>();
154190
private final AtomicInteger searchRequests = new AtomicInteger(0);
155191
private final boolean timeout;
@@ -167,9 +203,17 @@ public <Request extends ActionRequest, Response extends ActionResponse> Task exe
167203
) {
168204
switch (action.name()) {
169205
case TransportCancelTasksAction.NAME -> {
170-
CancelTasksRequest cancelTasksRequest = (CancelTasksRequest) request;
171-
assertTrue("tried to cancel the same task more than once", cancelledTasks.add(cancelTasksRequest.getTargetTaskId()));
172-
Task task = request.createTask(counter.getAndIncrement(), "cancel_task", action.name(), null, Collections.emptyMap());
206+
assertTrue(
207+
"tried to cancel the same task more than once",
208+
cancelledTasks.add(asInstanceOf(CancelTasksRequest.class, request).getTargetTaskId())
209+
);
210+
Task task = request.createTask(
211+
cancelTaskIdGenerator.getAsLong(),
212+
"cancel_task",
213+
action.name(),
214+
null,
215+
Collections.emptyMap()
216+
);
173217
if (randomBoolean()) {
174218
listener.onResponse(null);
175219
} else {
@@ -180,7 +224,13 @@ public <Request extends ActionRequest, Response extends ActionResponse> Task exe
180224
}
181225
case TransportSearchAction.NAME -> {
182226
searchRequests.incrementAndGet();
183-
Task searchTask = request.createTask(counter.getAndIncrement(), "search", action.name(), null, Collections.emptyMap());
227+
Task searchTask = request.createTask(
228+
searchTaskIdGenerator.getAsLong(),
229+
"search",
230+
action.name(),
231+
null,
232+
Collections.emptyMap()
233+
);
184234
if (timeout == false) {
185235
if (rarely()) {
186236
// make sure that search is sometimes also called from the same thread before the task is returned
@@ -191,7 +241,7 @@ public <Request extends ActionRequest, Response extends ActionResponse> Task exe
191241
}
192242
return searchTask;
193243
}
194-
default -> throw new UnsupportedOperationException();
244+
default -> throw new AssertionError("unexpected action " + action.name());
195245
}
196246

197247
}
@@ -222,10 +272,7 @@ public InetSocketAddress getRemoteAddress() {
222272

223273
@Override
224274
public void close() {
225-
if (open.compareAndSet(true, false) == false) {
226-
assert false : "HttpChannel is already closed";
227-
return; // nothing to do
228-
}
275+
assertTrue("HttpChannel is already closed", open.compareAndSet(true, false));
229276
ActionListener<Void> listener = closeListener.get();
230277
if (listener != null) {
231278
boolean failure = randomBoolean();
@@ -241,6 +288,7 @@ public void close() {
241288
}
242289

243290
private void awaitClose() throws InterruptedException {
291+
assertNotNull("must set closeListener before calling awaitClose", closeListener.get());
244292
close();
245293
closeLatch.await();
246294
}
@@ -257,7 +305,7 @@ public void addCloseListener(ActionListener<Void> listener) {
257305
listener.onResponse(null);
258306
} else {
259307
if (closeListener.compareAndSet(null, listener) == false) {
260-
throw new IllegalStateException("close listener already set, only one is allowed!");
308+
throw new AssertionError("close listener already set, only one is allowed!");
261309
}
262310
}
263311
}

0 commit comments

Comments
 (0)