diff --git a/docs/changelog/126686.yaml b/docs/changelog/126686.yaml new file mode 100644 index 0000000000000..802ec538e5c1e --- /dev/null +++ b/docs/changelog/126686.yaml @@ -0,0 +1,6 @@ +pr: 126686 +summary: Fix race condition in `RestCancellableNodeClient` +area: Task Management +type: bug +issues: + - 88201 diff --git a/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/IndicesSegmentsRestCancellationIT.java b/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/IndicesSegmentsRestCancellationIT.java index da9bb326f3c00..d2ab3b5a01636 100644 --- a/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/IndicesSegmentsRestCancellationIT.java +++ b/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/IndicesSegmentsRestCancellationIT.java @@ -11,23 +11,12 @@ import org.apache.http.client.methods.HttpGet; import org.elasticsearch.action.admin.indices.segments.IndicesSegmentsAction; import org.elasticsearch.client.Request; -import org.elasticsearch.test.junit.annotations.TestIssueLogging; public class IndicesSegmentsRestCancellationIT extends BlockedSearcherRestCancellationTestCase { - @TestIssueLogging( - issueUrl = "https://github.com/elastic/elasticsearch/issues/88201", - value = "org.elasticsearch.http.BlockedSearcherRestCancellationTestCase:DEBUG" - + ",org.elasticsearch.transport.TransportService:TRACE" - ) public void testIndicesSegmentsRestCancellation() throws Exception { runTest(new Request(HttpGet.METHOD_NAME, "/_segments"), IndicesSegmentsAction.NAME); } - @TestIssueLogging( - issueUrl = "https://github.com/elastic/elasticsearch/issues/88201", - value = "org.elasticsearch.http.BlockedSearcherRestCancellationTestCase:DEBUG" - + ",org.elasticsearch.transport.TransportService:TRACE" - ) public void testCatSegmentsRestCancellation() throws Exception { runTest(new Request(HttpGet.METHOD_NAME, "/_cat/segments"), IndicesSegmentsAction.NAME); } diff --git a/server/src/main/java/org/elasticsearch/rest/action/RestCancellableNodeClient.java b/server/src/main/java/org/elasticsearch/rest/action/RestCancellableNodeClient.java index 9192a89e8fa91..7f46e47243d72 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/RestCancellableNodeClient.java +++ b/server/src/main/java/org/elasticsearch/rest/action/RestCancellableNodeClient.java @@ -17,14 +17,14 @@ import org.elasticsearch.client.FilterClient; import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.core.Nullable; import org.elasticsearch.http.HttpChannel; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskId; -import java.util.ArrayList; +import java.util.Collection; import java.util.HashSet; -import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; @@ -111,12 +111,14 @@ private void cancelTask(TaskId taskId) { private class CloseListener implements ActionListener { private final AtomicReference channel = new AtomicReference<>(); - private final Set tasks = new HashSet<>(); + + @Nullable // if already drained + private Set tasks = new HashSet<>(); CloseListener() {} synchronized int getNumTasks() { - return tasks.size(); + return tasks == null ? 0 : tasks.size(); } void maybeRegisterChannel(HttpChannel httpChannel) { @@ -129,16 +131,23 @@ void maybeRegisterChannel(HttpChannel httpChannel) { } } - synchronized void registerTask(TaskHolder taskHolder, TaskId taskId) { - taskHolder.taskId = taskId; - if (taskHolder.completed == false) { - this.tasks.add(taskId); + void registerTask(TaskHolder taskHolder, TaskId taskId) { + synchronized (this) { + taskHolder.taskId = taskId; + if (tasks != null) { + if (taskHolder.completed == false) { + tasks.add(taskId); + } + return; + } } + // else tasks == null so the channel is already closed + cancelTask(taskId); } synchronized void unregisterTask(TaskHolder taskHolder) { - if (taskHolder.taskId != null) { - this.tasks.remove(taskHolder.taskId); + if (taskHolder.taskId != null && tasks != null) { + tasks.remove(taskHolder.taskId); } taskHolder.completed = true; } @@ -148,18 +157,20 @@ public void onResponse(Void aVoid) { final HttpChannel httpChannel = channel.get(); assert httpChannel != null : "channel not registered"; // when the channel gets closed it won't be reused: we can remove it from the map and forget about it. - CloseListener closeListener = httpChannels.remove(httpChannel); - assert closeListener != null : "channel not found in the map of tracked channels"; - final List toCancel; - synchronized (this) { - toCancel = new ArrayList<>(tasks); - tasks.clear(); - } - for (TaskId taskId : toCancel) { + final CloseListener closeListener = httpChannels.remove(httpChannel); + assert closeListener != null : "channel not found in the map of tracked channels: " + httpChannel; + assert closeListener == CloseListener.this : "channel had a different CloseListener registered: " + httpChannel; + for (final TaskId taskId : drainTasks()) { cancelTask(taskId); } } + private synchronized Collection drainTasks() { + final Collection drained = tasks; + tasks = null; + return drained; + } + @Override public void onFailure(Exception e) { onResponse(null); diff --git a/server/src/test/java/org/elasticsearch/rest/action/RestCancellableNodeClientTests.java b/server/src/test/java/org/elasticsearch/rest/action/RestCancellableNodeClientTests.java index a94c03adade83..4f880ddc92820 100644 --- a/server/src/test/java/org/elasticsearch/rest/action/RestCancellableNodeClientTests.java +++ b/server/src/test/java/org/elasticsearch/rest/action/RestCancellableNodeClientTests.java @@ -33,6 +33,7 @@ import java.net.InetSocketAddress; import java.util.ArrayList; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Set; import java.util.concurrent.CopyOnWriteArraySet; @@ -43,6 +44,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.LongSupplier; public class RestCancellableNodeClientTests extends ESTestCase { @@ -150,8 +152,42 @@ public void testChannelAlreadyClosed() { } } + public void testConcurrentExecuteAndClose() { + final TestClient testClient = new TestClient(Settings.EMPTY, threadPool, true); + int initialHttpChannels = RestCancellableNodeClient.getNumChannels(); + int numTasks = randomIntBetween(1, 30); + TestHttpChannel channel = new TestHttpChannel(); + final CountDownLatch startLatch = new CountDownLatch(1); + final CountDownLatch doneLatch = new CountDownLatch(numTasks + 1); + final Set expectedTasks = new HashSet<>(numTasks); + for (int j = 0; j < numTasks; j++) { + RestCancellableNodeClient client = new RestCancellableNodeClient(testClient, channel); + threadPool.generic().execute(() -> { + client.execute(SearchAction.INSTANCE, new SearchRequest(), ActionListener.wrap(ESTestCase::fail)); + startLatch.countDown(); + doneLatch.countDown(); + }); + expectedTasks.add(new TaskId(testClient.getLocalNodeId(), j)); + } + threadPool.generic().execute(() -> { + try { + safeAwait(startLatch); + channel.awaitClose(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new AssertionError(e); + } finally { + doneLatch.countDown(); + } + }); + safeAwait(doneLatch); + assertEquals(initialHttpChannels, RestCancellableNodeClient.getNumChannels()); + assertEquals(expectedTasks, testClient.cancelledTasks); + } + private static class TestClient extends NodeClient { - private final AtomicLong counter = new AtomicLong(0); + private final LongSupplier searchTaskIdGenerator = new AtomicLong(0)::getAndIncrement; + private final LongSupplier cancelTaskIdGenerator = new AtomicLong(1000)::getAndIncrement; private final Set cancelledTasks = new CopyOnWriteArraySet<>(); private final AtomicInteger searchRequests = new AtomicInteger(0); private final boolean timeout; @@ -171,7 +207,13 @@ public Task exe case CancelTasksAction.NAME: CancelTasksRequest cancelTasksRequest = (CancelTasksRequest) request; assertTrue("tried to cancel the same task more than once", cancelledTasks.add(cancelTasksRequest.getTaskId())); - Task task = request.createTask(counter.getAndIncrement(), "cancel_task", action.name(), null, Collections.emptyMap()); + Task task = request.createTask( + cancelTaskIdGenerator.getAsLong(), + "cancel_task", + action.name(), + null, + Collections.emptyMap() + ); if (randomBoolean()) { listener.onResponse(null); } else { @@ -182,7 +224,13 @@ public Task exe return task; case SearchAction.NAME: searchRequests.incrementAndGet(); - Task searchTask = request.createTask(counter.getAndIncrement(), "search", action.name(), null, Collections.emptyMap()); + Task searchTask = request.createTask( + searchTaskIdGenerator.getAsLong(), + "search", + action.name(), + null, + Collections.emptyMap() + ); if (timeout == false) { if (rarely()) { // make sure that search is sometimes also called from the same thread before the task is returned @@ -193,7 +241,7 @@ public Task exe } return searchTask; default: - throw new UnsupportedOperationException(); + throw new AssertionError("unexpected action " + action.name()); } } @@ -224,9 +272,7 @@ public InetSocketAddress getRemoteAddress() { @Override public void close() { - if (open.compareAndSet(true, false) == false) { - throw new IllegalStateException("channel already closed!"); - } + assertTrue("HttpChannel is already closed", open.compareAndSet(true, false)); ActionListener listener = closeListener.get(); if (listener != null) { boolean failure = randomBoolean(); @@ -242,6 +288,7 @@ public void close() { } private void awaitClose() throws InterruptedException { + assertNotNull("must set closeListener before calling awaitClose", closeListener.get()); close(); closeLatch.await(); } @@ -258,7 +305,7 @@ public void addCloseListener(ActionListener listener) { listener.onResponse(null); } else { if (closeListener.compareAndSet(null, listener) == false) { - throw new IllegalStateException("close listener already set, only one is allowed!"); + throw new AssertionError("close listener already set, only one is allowed!"); } } }