Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.action.support.tasks;

import org.apache.http.client.methods.HttpGet;
import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksAction;
import org.elasticsearch.action.admin.cluster.state.ClusterStateAction;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.Cancellable;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.http.HttpSmokeTestCase;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.transport.TransportService;

import java.util.ArrayList;
import java.util.concurrent.CancellationException;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.action.support.ActionTestUtils.wrapAsRestResponseListener;
import static org.elasticsearch.test.TaskAssertions.awaitTaskWithPrefix;

public class RestListTasksCancellationIT extends HttpSmokeTestCase {

public void testListTasksCancellation() throws Exception {
final Request clusterStateRequest = new Request(HttpGet.METHOD_NAME, "/_cluster/state");
clusterStateRequest.addParameter("wait_for_metadata_version", Long.toString(Long.MAX_VALUE));
clusterStateRequest.addParameter("wait_for_timeout", "1h");

final PlainActionFuture<Response> clusterStateFuture = new PlainActionFuture<>();
final Cancellable clusterStateCancellable = getRestClient().performRequestAsync(
clusterStateRequest,
wrapAsRestResponseListener(clusterStateFuture)
);

awaitTaskWithPrefix(ClusterStateAction.NAME);

final Request tasksRequest = new Request(HttpGet.METHOD_NAME, "/_tasks");
tasksRequest.addParameter("actions", ClusterStateAction.NAME);
tasksRequest.addParameter("wait_for_completion", Boolean.toString(true));
tasksRequest.addParameter("timeout", "1h");

final PlainActionFuture<Response> tasksFuture = new PlainActionFuture<>();
final Cancellable tasksCancellable = getRestClient().performRequestAsync(tasksRequest, wrapAsRestResponseListener(tasksFuture));

awaitTaskWithPrefix(ListTasksAction.NAME + "[n]");

tasksCancellable.cancel();

final var taskManagers = new ArrayList<TaskManager>(internalCluster().getNodeNames().length);
for (final var transportService : internalCluster().getInstances(TransportService.class)) {
Copy link
Contributor

@arteam arteam May 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's pity that internalCluster().getInstances returns Iterable, but it's actually a list. Otherwise, we could perform a check just with the Streams API.

 internalCluster().getInstances(TransportService.class)
     .stream()
     .map(TransportService::getTaskManager)
     .flatMap(taskManager -> taskManager.getCancellableTasks().values().stream())
     .anyMatch(t -> t.getAction().startsWith(ListTasksAction.NAME));

taskManagers.add(transportService.getTaskManager());
}
assertBusy(
() -> assertFalse(
taskManagers.stream()
.flatMap(taskManager -> taskManager.getCancellableTasks().values().stream())
.anyMatch(t -> t.getAction().startsWith(ListTasksAction.NAME))
)
);

expectThrows(CancellationException.class, () -> tasksFuture.actionGet(10, TimeUnit.SECONDS));
clusterStateCancellable.cancel();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;

import java.io.IOException;
import java.util.Map;

import static org.elasticsearch.action.ValidateActions.addValidationError;
import static org.elasticsearch.common.regex.Regex.simpleMatch;
Expand Down Expand Up @@ -119,4 +122,8 @@ public ListTasksRequest setDescriptions(String... descriptions) {
return this;
}

@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new CancellableTask(id, type, action, "", parentTaskId, headers);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.RemovedTaskListener;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.tasks.TaskInfo;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
Expand Down Expand Up @@ -76,7 +77,13 @@ protected void taskOperation(CancellableTask actionTask, ListTasksRequest reques
}

@Override
protected void processTasks(ListTasksRequest request, ActionListener<List<Task>> nodeOperation) {
protected void doExecute(Task task, ListTasksRequest request, ActionListener<ListTasksResponse> listener) {
assert task instanceof CancellableTask;
super.doExecute(task, request, listener);
}

@Override
protected void processTasks(CancellableTask nodeTask, ListTasksRequest request, ActionListener<List<Task>> nodeOperation) {
if (request.getWaitForCompletion()) {
final ListenableActionFuture<List<Task>> future = new ListenableActionFuture<>();
final List<Task> processedTasks = new ArrayList<>();
Expand Down Expand Up @@ -137,8 +144,9 @@ protected void processTasks(ListTasksRequest request, ActionListener<List<Task>>
threadPool,
ThreadPool.Names.SAME
);
nodeTask.addListener(() -> future.onFailure(new TaskCancelledException("task cancelled")));
} else {
super.processTasks(request, nodeOperation);
super.processTasks(nodeTask, request, nodeOperation);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ protected String[] resolveNodes(TasksRequest request, DiscoveryNodes discoveryNo
}
}

protected void processTasks(TasksRequest request, ActionListener<List<OperationTask>> nodeOperation) {
protected void processTasks(CancellableTask nodeTask, TasksRequest request, ActionListener<List<OperationTask>> nodeOperation) {
nodeOperation.onResponse(processTasks(request));
}

Expand Down Expand Up @@ -255,6 +255,7 @@ public void messageReceived(final NodeTaskRequest request, final TransportChanne
assert task instanceof CancellableTask;
TasksRequest tasksRequest = request.tasksRequest;
processTasks(
(CancellableTask) task,
tasksRequest,
new ChannelActionListener<NodeTasksResponse>(channel).delegateFailure(
(l, tasks) -> nodeOperation((CancellableTask) task, l, tasksRequest, tasks)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.action.RestCancellableNodeClient;
import org.elasticsearch.rest.action.RestChunkedToXContentListener;
import org.elasticsearch.tasks.TaskId;

Expand Down Expand Up @@ -49,7 +50,9 @@ public String getName() {
public RestChannelConsumer prepareRequest(final RestRequest request, final NodeClient client) throws IOException {
final ListTasksRequest listTasksRequest = generateListTasksRequest(request);
final String groupBy = request.param("group_by", "nodes");
return channel -> client.admin().cluster().listTasks(listTasksRequest, listTasksResponseListener(nodesInCluster, groupBy, channel));
return channel -> new RestCancellableNodeClient(client, request.getHttpChannel()).admin()
.cluster()
.listTasks(listTasksRequest, listTasksResponseListener(nodesInCluster, groupBy, channel));
}

public static ListTasksRequest generateListTasksRequest(RestRequest request) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
import org.elasticsearch.client.ResponseListener;
import org.elasticsearch.core.CheckedConsumer;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.transport.Transport;

import static org.elasticsearch.action.support.PlainActionFuture.newFuture;
import static org.mockito.Mockito.mock;
import java.util.Map;
import java.util.concurrent.TimeUnit;

public class ActionTestUtils {

Expand All @@ -29,10 +30,11 @@ public static <Request extends ActionRequest, Response extends ActionResponse> R
TransportAction<Request, Response> action,
Request request
) {
PlainActionFuture<Response> future = newFuture();
Task task = mock(Task.class);
action.execute(task, request, future);
return future.actionGet();
return PlainActionFuture.get(
future -> action.execute(request.createTask(1L, "direct", action.actionName, TaskId.EMPTY_TASK_ID, Map.of()), request, future),
10,
TimeUnit.SECONDS
);
}

public static <Request extends ActionRequest, Response extends ActionResponse> Response executeBlockingWithTask(
Expand All @@ -41,9 +43,11 @@ public static <Request extends ActionRequest, Response extends ActionResponse> R
TransportAction<Request, Response> action,
Request request
) {
PlainActionFuture<Response> future = newFuture();
taskManager.registerAndExecute("transport", action, request, localConnection, future);
return future.actionGet();
return PlainActionFuture.get(
future -> taskManager.registerAndExecute("transport", action, request, localConnection, future),
10,
TimeUnit.SECONDS
);
}

/**
Expand Down