Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/126686.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 126686
summary: Fix race condition in `RestCancellableNodeClient`
area: Task Management
type: bug
issues:
- 88201
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,12 @@
import org.apache.http.client.methods.HttpGet;
import org.elasticsearch.action.admin.indices.segments.IndicesSegmentsAction;
import org.elasticsearch.client.Request;
import org.elasticsearch.test.junit.annotations.TestIssueLogging;

public class IndicesSegmentsRestCancellationIT extends BlockedSearcherRestCancellationTestCase {
/**
 * Delegates to {@code runTest} with a {@code GET /_segments} request and
 * {@link IndicesSegmentsAction#NAME} as the action name.
 */
@TestIssueLogging(
    issueUrl = "https://github.com/elastic/elasticsearch/issues/88201",
    value = "org.elasticsearch.http.BlockedSearcherRestCancellationTestCase:DEBUG"
        + ",org.elasticsearch.transport.TransportService:TRACE"
)
public void testIndicesSegmentsRestCancellation() throws Exception {
    final Request segmentsRequest = new Request(HttpGet.METHOD_NAME, "/_segments");
    runTest(segmentsRequest, IndicesSegmentsAction.NAME);
}

/**
 * Delegates to {@code runTest} with a {@code GET /_cat/segments} request and
 * {@link IndicesSegmentsAction#NAME} as the action name.
 */
@TestIssueLogging(
    issueUrl = "https://github.com/elastic/elasticsearch/issues/88201",
    value = "org.elasticsearch.http.BlockedSearcherRestCancellationTestCase:DEBUG"
        + ",org.elasticsearch.transport.TransportService:TRACE"
)
public void testCatSegmentsRestCancellation() throws Exception {
    final Request catSegmentsRequest = new Request(HttpGet.METHOD_NAME, "/_cat/segments");
    runTest(catSegmentsRequest, IndicesSegmentsAction.NAME);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ private void cancelTask(TaskId taskId) {
private class CloseListener implements ActionListener<Void> {
private final AtomicReference<HttpChannel> channel = new AtomicReference<>();
private final Set<TaskId> tasks = new HashSet<>();
private boolean tasksDrained = false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: whatever you prefer really, but to me it seems that in these spots it's mostly easier to just make tasks non-final and null it out to signal tasksDrained. That removes the need for the copy and somewhat hardens the design against adding a task after tasks have been drained.
It's also one less field which is always nice :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah fair enough, see 20f492f. Required adding more null checks than I initially expected...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah right the stats :) thanks!


CloseListener() {}

Expand All @@ -130,11 +131,17 @@ void maybeRegisterChannel(HttpChannel httpChannel) {
}
}

synchronized void registerTask(TaskHolder taskHolder, TaskId taskId) {
taskHolder.taskId = taskId;
if (taskHolder.completed == false) {
this.tasks.add(taskId);
/**
 * Records the id of a started task on its holder and either tracks it for later cancellation or,
 * if the tracked tasks have already been drained, cancels it straight away.
 *
 * @param taskHolder holder whose {@code taskId} field is set to {@code taskId}
 * @param taskId     id of the task to track or cancel
 */
void registerTask(TaskHolder taskHolder, TaskId taskId) {
// The holder update and the drained check happen under this listener's monitor, the same monitor
// under which onResponse copies and clears `tasks` and sets `tasksDrained`.
synchronized (this) {
taskHolder.taskId = taskId;
if (tasksDrained == false) {
// Not drained yet: remember the task unless its holder is already marked completed.
if (taskHolder.completed == false) {
this.tasks.add(taskId);
}
return;
}
}
// Only reached when tasksDrained is true: the drain has already taken its copy of `tasks`, so
// this task would not be in it. Cancel it here instead, after leaving the synchronized block.
cancelTask(taskId);
}

synchronized void unregisterTask(TaskHolder taskHolder) {
Expand All @@ -155,6 +162,7 @@ public void onResponse(Void aVoid) {
synchronized (this) {
toCancel = new ArrayList<>(tasks);
tasks.clear();
tasksDrained = true;
}
for (TaskId taskId : toCancel) {
cancelTask(taskId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.http.HttpChannel;
import org.elasticsearch.http.HttpResponse;
import org.elasticsearch.tasks.Task;
Expand All @@ -44,6 +45,7 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.LongSupplier;

public class RestCancellableNodeClientTests extends ESTestCase {

Expand Down Expand Up @@ -148,8 +150,42 @@ public void testChannelAlreadyClosed() {
assertEquals(totalSearches, testClient.cancelledTasks.size());
}

/**
 * Submits several search executions that share one HTTP channel, closes that channel concurrently,
 * and then checks that the channel count is back to its initial value and that the set of cancelled
 * task ids equals the set of search task ids that were started.
 */
public void testConcurrentExecuteAndClose() throws Exception {
// NOTE(review): the third constructor argument is presumably the `timeout` flag — confirm against the TestClient constructor.
final var testClient = new TestClient(Settings.EMPTY, threadPool, true);
int initialHttpChannels = RestCancellableNodeClient.getNumChannels();
int numTasks = randomIntBetween(1, 30);
TestHttpChannel channel = new TestHttpChannel();
// Released by the first worker whose execute call returns; the closer waits on it before closing.
final var startLatch = new CountDownLatch(1);
// One count per worker plus one for the closer task.
final var doneLatch = new CountDownLatch(numTasks + 1);
final var expectedTasks = Sets.<TaskId>newHashSetWithExpectedSize(numTasks);
for (int j = 0; j < numTasks; j++) {
RestCancellableNodeClient client = new RestCancellableNodeClient(testClient, channel);
threadPool.generic().execute(() -> {
// NOTE(review): the listener presumably fails the test if it is ever completed — confirm ActionListener.running semantics.
client.execute(TransportSearchAction.TYPE, new SearchRequest(), ActionListener.running(ESTestCase::fail));
// Count down only after execute has returned, so the closer cannot start before at least one execute call has finished.
startLatch.countDown();
doneLatch.countDown();
});
// TestClient draws search task ids from a generator that starts at 0, so the ids are 0..numTasks-1.
expectedTasks.add(new TaskId(testClient.getLocalNodeId(), j));
}
threadPool.generic().execute(() -> {
try {
safeAwait(startLatch);
// Closes the channel while the remaining workers may still be executing.
channel.awaitClose();
} catch (InterruptedException e) {
// Restore the interrupt flag before surfacing the failure.
Thread.currentThread().interrupt();
throw new AssertionError(e);
} finally {
doneLatch.countDown();
}
});
safeAwait(doneLatch);
assertEquals(initialHttpChannels, RestCancellableNodeClient.getNumChannels());
assertEquals(expectedTasks, testClient.cancelledTasks);
}

private static class TestClient extends NodeClient {
private final AtomicLong counter = new AtomicLong(0);
private final LongSupplier searchTaskIdGenerator = new AtomicLong(0)::getAndIncrement;
private final LongSupplier cancelTaskIdGenerator = new AtomicLong(1000)::getAndIncrement;
private final Set<TaskId> cancelledTasks = new CopyOnWriteArraySet<>();
private final AtomicInteger searchRequests = new AtomicInteger(0);
private final boolean timeout;
Expand All @@ -167,9 +203,17 @@ public <Request extends ActionRequest, Response extends ActionResponse> Task exe
) {
switch (action.name()) {
case TransportCancelTasksAction.NAME -> {
CancelTasksRequest cancelTasksRequest = (CancelTasksRequest) request;
assertTrue("tried to cancel the same task more than once", cancelledTasks.add(cancelTasksRequest.getTargetTaskId()));
Task task = request.createTask(counter.getAndIncrement(), "cancel_task", action.name(), null, Collections.emptyMap());
assertTrue(
"tried to cancel the same task more than once",
cancelledTasks.add(asInstanceOf(CancelTasksRequest.class, request).getTargetTaskId())
);
Task task = request.createTask(
cancelTaskIdGenerator.getAsLong(),
"cancel_task",
action.name(),
null,
Collections.emptyMap()
);
if (randomBoolean()) {
listener.onResponse(null);
} else {
Expand All @@ -180,7 +224,13 @@ public <Request extends ActionRequest, Response extends ActionResponse> Task exe
}
case TransportSearchAction.NAME -> {
searchRequests.incrementAndGet();
Task searchTask = request.createTask(counter.getAndIncrement(), "search", action.name(), null, Collections.emptyMap());
Task searchTask = request.createTask(
searchTaskIdGenerator.getAsLong(),
"search",
action.name(),
null,
Collections.emptyMap()
);
if (timeout == false) {
if (rarely()) {
// make sure that search is sometimes also called from the same thread before the task is returned
Expand All @@ -191,7 +241,7 @@ public <Request extends ActionRequest, Response extends ActionResponse> Task exe
}
return searchTask;
}
default -> throw new UnsupportedOperationException();
default -> throw new AssertionError("unexpected action " + action.name());
}

}
Expand Down Expand Up @@ -222,10 +272,7 @@ public InetSocketAddress getRemoteAddress() {

@Override
public void close() {
if (open.compareAndSet(true, false) == false) {
assert false : "HttpChannel is already closed";
return; // nothing to do
}
assertTrue("HttpChannel is already closed", open.compareAndSet(true, false));
ActionListener<Void> listener = closeListener.get();
if (listener != null) {
boolean failure = randomBoolean();
Expand All @@ -241,6 +288,7 @@ public void close() {
}

/**
 * Asserts that a close listener has been set, closes this channel, and then waits on
 * {@code closeLatch}.
 *
 * @throws InterruptedException if the calling thread is interrupted while waiting on the latch
 */
private void awaitClose() throws InterruptedException {
    final ActionListener<Void> registeredListener = closeListener.get();
    assertNotNull("must set closeListener before calling awaitClose", registeredListener);
    close();
    closeLatch.await();
}
Expand All @@ -257,7 +305,7 @@ public void addCloseListener(ActionListener<Void> listener) {
listener.onResponse(null);
} else {
if (closeListener.compareAndSet(null, listener) == false) {
throw new IllegalStateException("close listener already set, only one is allowed!");
throw new AssertionError("close listener already set, only one is allowed!");
}
}
}
Expand Down
Loading