diff --git a/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/action/support/tasks/RestListTasksCancellationIT.java b/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/action/support/tasks/RestListTasksCancellationIT.java new file mode 100644 index 0000000000000..29e59af7b9f70 --- /dev/null +++ b/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/action/support/tasks/RestListTasksCancellationIT.java @@ -0,0 +1,72 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.action.support.tasks; + +import org.apache.http.client.methods.HttpGet; +import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksAction; +import org.elasticsearch.action.admin.cluster.state.ClusterStateAction; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.client.Cancellable; +import org.elasticsearch.client.Request; +import org.elasticsearch.client.Response; +import org.elasticsearch.http.HttpSmokeTestCase; +import org.elasticsearch.tasks.TaskManager; +import org.elasticsearch.transport.TransportService; + +import java.util.ArrayList; +import java.util.concurrent.CancellationException; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.action.support.ActionTestUtils.wrapAsRestResponseListener; +import static org.elasticsearch.test.TaskAssertions.awaitTaskWithPrefix; + +public class RestListTasksCancellationIT extends HttpSmokeTestCase { + + public void testListTasksCancellation() throws Exception { + final Request clusterStateRequest = new Request(HttpGet.METHOD_NAME, "/_cluster/state"); + clusterStateRequest.addParameter("wait_for_metadata_version", Long.toString(Long.MAX_VALUE)); + clusterStateRequest.addParameter("wait_for_timeout", "1h"); + + final PlainActionFuture clusterStateFuture = new PlainActionFuture<>(); + final Cancellable clusterStateCancellable = getRestClient().performRequestAsync( + clusterStateRequest, + wrapAsRestResponseListener(clusterStateFuture) + ); + + awaitTaskWithPrefix(ClusterStateAction.NAME); + + final Request tasksRequest = new Request(HttpGet.METHOD_NAME, "/_tasks"); + tasksRequest.addParameter("actions", ClusterStateAction.NAME); + tasksRequest.addParameter("wait_for_completion", Boolean.toString(true)); + tasksRequest.addParameter("timeout", "1h"); + + final PlainActionFuture tasksFuture = new PlainActionFuture<>(); + final Cancellable tasksCancellable = getRestClient().performRequestAsync(tasksRequest, wrapAsRestResponseListener(tasksFuture)); + + awaitTaskWithPrefix(ListTasksAction.NAME + "[n]"); + + tasksCancellable.cancel(); + + final var taskManagers = new ArrayList(internalCluster().getNodeNames().length); + for (final var transportService : internalCluster().getInstances(TransportService.class)) { + taskManagers.add(transportService.getTaskManager()); + } + assertBusy( + () -> assertFalse( + taskManagers.stream() + .flatMap(taskManager -> taskManager.getCancellableTasks().values().stream()) + .anyMatch(t -> t.getAction().startsWith(ListTasksAction.NAME)) + ) + ); + + expectThrows(CancellationException.class, () -> tasksFuture.actionGet(10, TimeUnit.SECONDS)); + clusterStateCancellable.cancel(); + } + +} diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/list/ListTasksRequest.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/list/ListTasksRequest.java index 5b0194c81283e..597c9821e48ec 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/list/ListTasksRequest.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/list/ListTasksRequest.java @@ -14,9 +14,12 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskId; import java.io.IOException; +import java.util.Map; import static org.elasticsearch.action.ValidateActions.addValidationError; import static org.elasticsearch.common.regex.Regex.simpleMatch; @@ -119,4 +122,8 @@ public ListTasksRequest setDescriptions(String... descriptions) { return this; } + @Override + public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { + return new CancellableTask(id, type, action, "", parentTaskId, headers); + } } diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/list/TransportListTasksAction.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/list/TransportListTasksAction.java index d3a9ab80db5ca..eaaebb5d2bb9c 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/list/TransportListTasksAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/list/TransportListTasksAction.java @@ -24,6 +24,7 @@ import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.RemovedTaskListener; import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.tasks.TaskInfo; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; @@ -76,7 +77,13 @@ protected void taskOperation(CancellableTask actionTask, ListTasksRequest reques } @Override - protected void processTasks(ListTasksRequest request, ActionListener> nodeOperation) { + protected void doExecute(Task task, ListTasksRequest request, ActionListener listener) { + assert task instanceof CancellableTask; + super.doExecute(task, request, listener); + } + + @Override + protected void processTasks(CancellableTask nodeTask, ListTasksRequest request, ActionListener> nodeOperation) { if (request.getWaitForCompletion()) { final ListenableActionFuture> future = new ListenableActionFuture<>(); final List processedTasks = new ArrayList<>(); @@ -137,8 +144,9 @@ protected void processTasks(ListTasksRequest request, ActionListener> threadPool, ThreadPool.Names.SAME ); + nodeTask.addListener(() -> future.onFailure(new TaskCancelledException("task cancelled"))); } else { - super.processTasks(request, nodeOperation); + super.processTasks(nodeTask, request, nodeOperation); } } } diff --git a/server/src/main/java/org/elasticsearch/action/support/tasks/TransportTasksAction.java b/server/src/main/java/org/elasticsearch/action/support/tasks/TransportTasksAction.java index 4c563b95449e7..81f8f575f528d 100644 --- a/server/src/main/java/org/elasticsearch/action/support/tasks/TransportTasksAction.java +++ b/server/src/main/java/org/elasticsearch/action/support/tasks/TransportTasksAction.java @@ -198,7 +198,7 @@ protected String[] resolveNodes(TasksRequest request, DiscoveryNodes discoveryNo } } - protected void processTasks(TasksRequest request, ActionListener> nodeOperation) { + protected void processTasks(CancellableTask nodeTask, TasksRequest request, ActionListener> nodeOperation) { nodeOperation.onResponse(processTasks(request)); } @@ -255,6 +255,7 @@ public void messageReceived(final NodeTaskRequest request, final TransportChanne assert task instanceof CancellableTask; TasksRequest tasksRequest = request.tasksRequest; processTasks( + (CancellableTask) task, tasksRequest, new ChannelActionListener(channel).delegateFailure( (l, tasks) -> nodeOperation((CancellableTask) task, l, tasksRequest, tasks) diff --git a/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestListTasksAction.java b/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestListTasksAction.java index 99417fbc962b7..cbf8baa9a2ea9 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestListTasksAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestListTasksAction.java @@ -18,6 +18,7 @@ import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestCancellableNodeClient; import org.elasticsearch.rest.action.RestChunkedToXContentListener; import org.elasticsearch.tasks.TaskId; @@ -49,7 +50,9 @@ public String getName() { public RestChannelConsumer prepareRequest(final RestRequest request, final NodeClient client) throws IOException { final ListTasksRequest listTasksRequest = generateListTasksRequest(request); final String groupBy = request.param("group_by", "nodes"); - return channel -> client.admin().cluster().listTasks(listTasksRequest, listTasksResponseListener(nodesInCluster, groupBy, channel)); + return channel -> new RestCancellableNodeClient(client, request.getHttpChannel()).admin() + .cluster() + .listTasks(listTasksRequest, listTasksResponseListener(nodesInCluster, groupBy, channel)); } public static ListTasksRequest generateListTasksRequest(RestRequest request) { diff --git a/test/framework/src/main/java/org/elasticsearch/action/support/ActionTestUtils.java b/test/framework/src/main/java/org/elasticsearch/action/support/ActionTestUtils.java index b9ae3d0a62e91..49c3df17d60dd 100644 --- a/test/framework/src/main/java/org/elasticsearch/action/support/ActionTestUtils.java +++ b/test/framework/src/main/java/org/elasticsearch/action/support/ActionTestUtils.java @@ -15,11 +15,12 @@ import org.elasticsearch.client.ResponseListener; import org.elasticsearch.core.CheckedConsumer; import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskId; import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.transport.Transport; -import static org.elasticsearch.action.support.PlainActionFuture.newFuture; -import static org.mockito.Mockito.mock; +import java.util.Map; +import java.util.concurrent.TimeUnit; public class ActionTestUtils { @@ -29,10 +30,11 @@ public static R TransportAction action, Request request ) { - PlainActionFuture future = newFuture(); - Task task = mock(Task.class); - action.execute(task, request, future); - return future.actionGet(); + return PlainActionFuture.get( + future -> action.execute(request.createTask(1L, "direct", action.actionName, TaskId.EMPTY_TASK_ID, Map.of()), request, future), + 10, + TimeUnit.SECONDS + ); } public static Response executeBlockingWithTask( @@ -41,9 +43,11 @@ public static R TransportAction action, Request request ) { - PlainActionFuture future = newFuture(); - taskManager.registerAndExecute("transport", action, request, localConnection, future); - return future.actionGet(); + return PlainActionFuture.get( + future -> taskManager.registerAndExecute("transport", action, request, localConnection, future), + 10, + TimeUnit.SECONDS + ); } /**