From 5a1e40fb5051864f270802075df3d013835a27f0 Mon Sep 17 00:00:00 2001 From: David Turner Date: Fri, 11 Apr 2025 15:59:46 +0100 Subject: [PATCH] Fix race condition in `RestCancellableNodeClient` (#126686) Today we rely on registering the channel after registering the task to be cancelled to ensure that the task is cancelled even if the channel is closed concurrently. However the client may already have processed a cancellable request on the channel and therefore this mechanism doesn't work. With this change we make sure not to register another task after draining the registrations in order to cancel them. Closes #88201 --- docs/changelog/126686.yaml | 6 ++ .../IndicesSegmentsRestCancellationIT.java | 11 --- .../action/RestCancellableNodeClient.java | 47 ++++++++----- .../RestCancellableNodeClientTests.java | 70 ++++++++++++++++--- 4 files changed, 94 insertions(+), 40 deletions(-) create mode 100644 docs/changelog/126686.yaml diff --git a/docs/changelog/126686.yaml b/docs/changelog/126686.yaml new file mode 100644 index 0000000000000..802ec538e5c1e --- /dev/null +++ b/docs/changelog/126686.yaml @@ -0,0 +1,6 @@ +pr: 126686 +summary: Fix race condition in `RestCancellableNodeClient` +area: Task Management +type: bug +issues: + - 88201 diff --git a/qa/smoke-test-http/src/internalClusterTest/java/org/elasticsearch/http/IndicesSegmentsRestCancellationIT.java b/qa/smoke-test-http/src/internalClusterTest/java/org/elasticsearch/http/IndicesSegmentsRestCancellationIT.java index 92fde6d7765cc..a90b04d54649c 100644 --- a/qa/smoke-test-http/src/internalClusterTest/java/org/elasticsearch/http/IndicesSegmentsRestCancellationIT.java +++ b/qa/smoke-test-http/src/internalClusterTest/java/org/elasticsearch/http/IndicesSegmentsRestCancellationIT.java @@ -12,23 +12,12 @@ import org.apache.http.client.methods.HttpGet; import org.elasticsearch.action.admin.indices.segments.IndicesSegmentsAction; import org.elasticsearch.client.Request; -import org.elasticsearch.test.junit.annotations.TestIssueLogging; public class IndicesSegmentsRestCancellationIT extends BlockedSearcherRestCancellationTestCase { - @TestIssueLogging( - issueUrl = "https://github.com/elastic/elasticsearch/issues/88201", - value = "org.elasticsearch.http.BlockedSearcherRestCancellationTestCase:DEBUG" - + ",org.elasticsearch.transport.TransportService:TRACE" - ) public void testIndicesSegmentsRestCancellation() throws Exception { runTest(new Request(HttpGet.METHOD_NAME, "/_segments"), IndicesSegmentsAction.NAME); } - @TestIssueLogging( - issueUrl = "https://github.com/elastic/elasticsearch/issues/88201", - value = "org.elasticsearch.http.BlockedSearcherRestCancellationTestCase:DEBUG" - + ",org.elasticsearch.transport.TransportService:TRACE" - ) public void testCatSegmentsRestCancellation() throws Exception { runTest(new Request(HttpGet.METHOD_NAME, "/_cat/segments"), IndicesSegmentsAction.NAME); } diff --git a/server/src/main/java/org/elasticsearch/rest/action/RestCancellableNodeClient.java b/server/src/main/java/org/elasticsearch/rest/action/RestCancellableNodeClient.java index 33b3ef35671e3..e4e8378e4355e 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/RestCancellableNodeClient.java +++ b/server/src/main/java/org/elasticsearch/rest/action/RestCancellableNodeClient.java @@ -18,14 +18,14 @@ import org.elasticsearch.client.internal.FilterClient; import org.elasticsearch.client.internal.OriginSettingClient; import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.core.Nullable; import org.elasticsearch.http.HttpChannel; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskId; -import java.util.ArrayList; +import java.util.Collection; import java.util.HashSet; -import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; @@ -112,12 +112,14 @@ private void cancelTask(TaskId taskId) { private class CloseListener implements ActionListener { private final AtomicReference channel = new AtomicReference<>(); - private final Set tasks = new HashSet<>(); + + @Nullable // if already drained + private Set tasks = new HashSet<>(); CloseListener() {} synchronized int getNumTasks() { - return tasks.size(); + return tasks == null ? 0 : tasks.size(); } void maybeRegisterChannel(HttpChannel httpChannel) { @@ -130,16 +132,23 @@ void maybeRegisterChannel(HttpChannel httpChannel) { } } - synchronized void registerTask(TaskHolder taskHolder, TaskId taskId) { - taskHolder.taskId = taskId; - if (taskHolder.completed == false) { - this.tasks.add(taskId); + void registerTask(TaskHolder taskHolder, TaskId taskId) { + synchronized (this) { + taskHolder.taskId = taskId; + if (tasks != null) { + if (taskHolder.completed == false) { + tasks.add(taskId); + } + return; + } } + // else tasks == null so the channel is already closed + cancelTask(taskId); } synchronized void unregisterTask(TaskHolder taskHolder) { - if (taskHolder.taskId != null) { - this.tasks.remove(taskHolder.taskId); + if (taskHolder.taskId != null && tasks != null) { + tasks.remove(taskHolder.taskId); } taskHolder.completed = true; } @@ -149,18 +158,20 @@ public void onResponse(Void aVoid) { final HttpChannel httpChannel = channel.get(); assert httpChannel != null : "channel not registered"; // when the channel gets closed it won't be reused: we can remove it from the map and forget about it. - CloseListener closeListener = httpChannels.remove(httpChannel); - assert closeListener != null : "channel not found in the map of tracked channels"; - final List toCancel; - synchronized (this) { - toCancel = new ArrayList<>(tasks); - tasks.clear(); - } - for (TaskId taskId : toCancel) { + final CloseListener closeListener = httpChannels.remove(httpChannel); + assert closeListener != null : "channel not found in the map of tracked channels: " + httpChannel; + assert closeListener == CloseListener.this : "channel had a different CloseListener registered: " + httpChannel; + for (final var taskId : drainTasks()) { cancelTask(taskId); } } + private synchronized Collection drainTasks() { + final var drained = tasks; + tasks = null; + return drained; + } + @Override public void onFailure(Exception e) { onResponse(null); diff --git a/server/src/test/java/org/elasticsearch/rest/action/RestCancellableNodeClientTests.java b/server/src/test/java/org/elasticsearch/rest/action/RestCancellableNodeClientTests.java index 74c6fceddf71b..c58621d03ce8f 100644 --- a/server/src/test/java/org/elasticsearch/rest/action/RestCancellableNodeClientTests.java +++ b/server/src/test/java/org/elasticsearch/rest/action/RestCancellableNodeClientTests.java @@ -21,6 +21,7 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.http.HttpChannel; import org.elasticsearch.http.HttpResponse; import org.elasticsearch.tasks.Task; @@ -44,6 +45,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.LongSupplier; public class RestCancellableNodeClientTests extends ESTestCase { @@ -148,8 +150,42 @@ public void testChannelAlreadyClosed() { assertEquals(totalSearches, testClient.cancelledTasks.size()); } + public void testConcurrentExecuteAndClose() throws Exception { + final var testClient = new TestClient(Settings.EMPTY, threadPool, true); + int initialHttpChannels = RestCancellableNodeClient.getNumChannels(); + int numTasks = randomIntBetween(1, 30); + TestHttpChannel channel = new TestHttpChannel(); + final var startLatch = new CountDownLatch(1); + final var doneLatch = new CountDownLatch(numTasks + 1); + final var expectedTasks = Sets.newHashSetWithExpectedSize(numTasks); + for (int j = 0; j < numTasks; j++) { + RestCancellableNodeClient client = new RestCancellableNodeClient(testClient, channel); + threadPool.generic().execute(() -> { + client.execute(TransportSearchAction.TYPE, new SearchRequest(), ActionListener.running(ESTestCase::fail)); + startLatch.countDown(); + doneLatch.countDown(); + }); + expectedTasks.add(new TaskId(testClient.getLocalNodeId(), j)); + } + threadPool.generic().execute(() -> { + try { + safeAwait(startLatch); + channel.awaitClose(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new AssertionError(e); + } finally { + doneLatch.countDown(); + } + }); + safeAwait(doneLatch); + assertEquals(initialHttpChannels, RestCancellableNodeClient.getNumChannels()); + assertEquals(expectedTasks, testClient.cancelledTasks); + } + private static class TestClient extends NodeClient { - private final AtomicLong counter = new AtomicLong(0); + private final LongSupplier searchTaskIdGenerator = new AtomicLong(0)::getAndIncrement; + private final LongSupplier cancelTaskIdGenerator = new AtomicLong(1000)::getAndIncrement; private final Set cancelledTasks = new CopyOnWriteArraySet<>(); private final AtomicInteger searchRequests = new AtomicInteger(0); private final boolean timeout; @@ -167,9 +203,17 @@ public Task exe ) { switch (action.name()) { case TransportCancelTasksAction.NAME -> { - CancelTasksRequest cancelTasksRequest = (CancelTasksRequest) request; - assertTrue("tried to cancel the same task more than once", cancelledTasks.add(cancelTasksRequest.getTargetTaskId())); - Task task = request.createTask(counter.getAndIncrement(), "cancel_task", action.name(), null, Collections.emptyMap()); + assertTrue( + "tried to cancel the same task more than once", + cancelledTasks.add(asInstanceOf(CancelTasksRequest.class, request).getTargetTaskId()) + ); + Task task = request.createTask( + cancelTaskIdGenerator.getAsLong(), + "cancel_task", + action.name(), + null, + Collections.emptyMap() + ); if (randomBoolean()) { listener.onResponse(null); } else { @@ -180,7 +224,13 @@ public Task exe } case TransportSearchAction.NAME -> { searchRequests.incrementAndGet(); - Task searchTask = request.createTask(counter.getAndIncrement(), "search", action.name(), null, Collections.emptyMap()); + Task searchTask = request.createTask( + searchTaskIdGenerator.getAsLong(), + "search", + action.name(), + null, + Collections.emptyMap() + ); if (timeout == false) { if (rarely()) { // make sure that search is sometimes also called from the same thread before the task is returned @@ -191,7 +241,7 @@ public Task exe } return searchTask; } - default -> throw new UnsupportedOperationException(); + default -> throw new AssertionError("unexpected action " + action.name()); } } @@ -222,10 +272,7 @@ public InetSocketAddress getRemoteAddress() { @Override public void close() { - if (open.compareAndSet(true, false) == false) { - assert false : "HttpChannel is already closed"; - return; // nothing to do - } + assertTrue("HttpChannel is already closed", open.compareAndSet(true, false)); ActionListener listener = closeListener.get(); if (listener != null) { boolean failure = randomBoolean(); @@ -241,6 +288,7 @@ public void close() { } private void awaitClose() throws InterruptedException { + assertNotNull("must set closeListener before calling awaitClose", closeListener.get()); close(); closeLatch.await(); } @@ -257,7 +305,7 @@ public void addCloseListener(ActionListener listener) { listener.onResponse(null); } else { if (closeListener.compareAndSet(null, listener) == false) { - throw new IllegalStateException("close listener already set, only one is allowed!"); + throw new AssertionError("close listener already set, only one is allowed!"); } } }