Skip to content

Commit d903f24

Browse files
authored
Add query cancellation support via _tasks/_cancel API for PPL queries (#5254)
* Add query cancellation support via _tasks/_cancel API for PPL queries Signed-off-by: Sunil Ramchandra Pawar <pawar_sr@apple.com> * Refactor PPL query cancellation to cooperative model and other PR suggestions. Signed-off-by: Sunil Ramchandra Pawar <pawar_sr@apple.com> --------- Signed-off-by: Sunil Ramchandra Pawar <pawar_sr@apple.com>
1 parent bebb75b commit d903f24

File tree

8 files changed

+164
-2
lines changed

8 files changed

+164
-2
lines changed

opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchQueryManager.java

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.opensearch.sql.executor.QueryId;
1717
import org.opensearch.sql.executor.QueryManager;
1818
import org.opensearch.sql.executor.execution.AbstractPlan;
19+
import org.opensearch.tasks.CancellableTask;
1920
import org.opensearch.threadpool.Scheduler;
2021
import org.opensearch.threadpool.ThreadPool;
2122
import org.opensearch.transport.client.node.NodeClient;
@@ -33,15 +34,32 @@ public class OpenSearchQueryManager implements QueryManager {
3334
public static final String SQL_WORKER_THREAD_POOL_NAME = "sql-worker";
3435
public static final String SQL_BACKGROUND_THREAD_POOL_NAME = "sql_background_io";
3536

37+
private static final ThreadLocal<CancellableTask> cancellableTask = new ThreadLocal<>();
38+
39+
public static void setCancellableTask(CancellableTask task) {
40+
cancellableTask.set(task);
41+
}
42+
43+
public static CancellableTask getCancellableTask() {
44+
return cancellableTask.get();
45+
}
46+
47+
public static void clearCancellableTask() {
48+
cancellableTask.remove();
49+
}
50+
3651
@Override
3752
public QueryId submit(AbstractPlan queryPlan) {
3853
TimeValue timeout = settings.getSettingValue(Settings.Key.PPL_QUERY_TIMEOUT);
39-
schedule(nodeClient, queryPlan::execute, timeout);
54+
CancellableTask cancelTask = cancellableTask.get();
55+
cancellableTask.remove();
56+
schedule(nodeClient, queryPlan::execute, timeout, cancelTask);
4057

4158
return queryPlan.getQueryId();
4259
}
4360

44-
private void schedule(NodeClient client, Runnable task, TimeValue timeout) {
61+
private void schedule(
62+
NodeClient client, Runnable task, TimeValue timeout, CancellableTask cancelTask) {
4563
ThreadPool threadPool = client.threadPool();
4664

4765
Runnable wrappedTask =
@@ -60,6 +78,8 @@ private void schedule(NodeClient client, Runnable task, TimeValue timeout) {
6078
timeout,
6179
ThreadPool.Names.GENERIC);
6280

81+
setCancellableTask(cancelTask);
82+
6383
try {
6484
task.run();
6585
timeoutTask.cancel();
@@ -76,6 +96,8 @@ private void schedule(NodeClient client, Runnable task, TimeValue timeout) {
7696
}
7797

7898
throw e;
99+
} finally {
100+
clearCancellableTask();
79101
}
80102
});
81103

opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexEnumerator.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,16 @@
1111
import lombok.EqualsAndHashCode;
1212
import lombok.ToString;
1313
import org.apache.calcite.linq4j.Enumerator;
14+
import org.opensearch.core.tasks.TaskCancelledException;
1415
import org.opensearch.sql.data.model.ExprValue;
1516
import org.opensearch.sql.data.model.ExprValueUtils;
1617
import org.opensearch.sql.exception.NonFallbackCalciteException;
1718
import org.opensearch.sql.expression.HighlightExpression;
1819
import org.opensearch.sql.monitor.ResourceMonitor;
1920
import org.opensearch.sql.opensearch.client.OpenSearchClient;
21+
import org.opensearch.sql.opensearch.executor.OpenSearchQueryManager;
2022
import org.opensearch.sql.opensearch.request.OpenSearchRequest;
23+
import org.opensearch.tasks.CancellableTask;
2124

2225
/**
2326
* Supports a simple iteration over a collection for OpenSearch index
@@ -55,6 +58,8 @@ public class OpenSearchIndexEnumerator implements Enumerator<Object> {
5558

5659
private ExprValue current = null;
5760

61+
private CancellableTask cancellableTask;
62+
5863
public OpenSearchIndexEnumerator(
5964
OpenSearchClient client,
6065
List<String> fields,
@@ -80,6 +85,7 @@ public OpenSearchIndexEnumerator(
8085
this.client = client;
8186
this.bgScanner = new BackgroundSearchScanner(client, maxResultWindow, queryBucketSize);
8287
this.bgScanner.startScanning(request);
88+
this.cancellableTask = OpenSearchQueryManager.getCancellableTask();
8389
}
8490

8591
private Iterator<ExprValue> fetchNextBatch() {
@@ -112,6 +118,10 @@ public boolean moveNext() {
112118
return false;
113119
}
114120

121+
if (cancellableTask != null && cancellableTask.isCancelled()) {
122+
throw new TaskCancelledException("The task is cancelled.");
123+
}
124+
115125
boolean shouldCheck = (queryCount % NUMBER_OF_NEXT_CALL_TO_CHECK == 0);
116126
if (shouldCheck) {
117127
org.opensearch.sql.monitor.ResourceStatus status = this.monitor.getStatus();

plugin/src/main/java/org/opensearch/sql/plugin/request/PPLQueryRequestFactory.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,11 @@ private static PPLQueryRequest parsePPLRequestFromPayload(RestRequest restReques
113113
if (pretty) {
114114
pplRequest.style(JsonResponseFormatter.Style.PRETTY);
115115
}
116+
// set queryId
117+
String queryId = jsonContent.optString("queryId", null);
118+
if (queryId != null) {
119+
pplRequest.queryId(queryId);
120+
}
116121
return pplRequest;
117122
} catch (JSONException e) {
118123
throw new IllegalArgumentException("Failed to parse request payload", e);
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.plugin.transport;
7+
8+
import java.util.Map;
9+
import org.opensearch.core.tasks.TaskId;
10+
import org.opensearch.tasks.CancellableTask;
11+
12+
public class PPLQueryTask extends CancellableTask {
13+
14+
public PPLQueryTask(
15+
long id,
16+
String type,
17+
String action,
18+
String description,
19+
TaskId parentTaskId,
20+
Map<String, String> headers) {
21+
super(id, type, action, description, parentTaskId, headers);
22+
}
23+
24+
@Override
25+
public boolean shouldCancelChildrenOnCancellation() {
26+
return true;
27+
}
28+
}

plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.opensearch.sql.legacy.metrics.MetricName;
3232
import org.opensearch.sql.legacy.metrics.Metrics;
3333
import org.opensearch.sql.monitor.profile.QueryProfiling;
34+
import org.opensearch.sql.opensearch.executor.OpenSearchQueryManager;
3435
import org.opensearch.sql.opensearch.setting.OpenSearchSettings;
3536
import org.opensearch.sql.plugin.config.OpenSearchPluginModule;
3637
import org.opensearch.sql.ppl.PPLService;
@@ -109,6 +110,9 @@ protected void doExecute(
109110
return;
110111
}
111112

113+
if (task instanceof PPLQueryTask pplQueryTask) {
114+
OpenSearchQueryManager.setCancellableTask(pplQueryTask);
115+
}
112116
Metrics.getInstance().getNumericalMetric(MetricName.PPL_REQ_TOTAL).increment();
113117
Metrics.getInstance().getNumericalMetric(MetricName.PPL_REQ_COUNT_TOTAL).increment();
114118

plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryRequest.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import java.io.ByteArrayOutputStream;
1010
import java.io.IOException;
1111
import java.util.Locale;
12+
import java.util.Map;
1213
import java.util.Optional;
1314
import lombok.Getter;
1415
import lombok.RequiredArgsConstructor;
@@ -21,6 +22,7 @@
2122
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
2223
import org.opensearch.core.common.io.stream.StreamInput;
2324
import org.opensearch.core.common.io.stream.StreamOutput;
25+
import org.opensearch.core.tasks.TaskId;
2426
import org.opensearch.sql.ppl.domain.PPLQueryRequest;
2527
import org.opensearch.sql.protocol.response.format.Format;
2628
import org.opensearch.sql.protocol.response.format.JsonResponseFormatter;
@@ -51,6 +53,11 @@ public class TransportPPLQueryRequest extends ActionRequest {
5153
@Accessors(fluent = true)
5254
private boolean profile = false;
5355

56+
@Setter
57+
@Getter
58+
@Accessors(fluent = true)
59+
private String queryId = null;
60+
5461
/** Constructor of TransportPPLQueryRequest from PPLQueryRequest. */
5562
public TransportPPLQueryRequest(PPLQueryRequest pplQueryRequest) {
5663
pplQuery = pplQueryRequest.getRequest();
@@ -61,6 +68,7 @@ public TransportPPLQueryRequest(PPLQueryRequest pplQueryRequest) {
6168
style = pplQueryRequest.style();
6269
profile = pplQueryRequest.profile();
6370
explainMode = pplQueryRequest.mode().getModeName();
71+
queryId = pplQueryRequest.queryId();
6472
}
6573

6674
/** Constructor of TransportPPLQueryRequest from StreamInput. */
@@ -75,6 +83,7 @@ public TransportPPLQueryRequest(StreamInput in) throws IOException {
7583
sanitize = in.readBoolean();
7684
style = in.readEnum(JsonResponseFormatter.Style.class);
7785
profile = in.readBoolean();
86+
queryId = in.readOptionalString();
7887
}
7988

8089
/** Re-create the object from the actionRequest. */
@@ -107,6 +116,7 @@ public void writeTo(StreamOutput out) throws IOException {
107116
out.writeBoolean(sanitize);
108117
out.writeEnum(style);
109118
out.writeBoolean(profile);
119+
out.writeOptionalString(queryId);
110120
}
111121

112122
public String getRequest() {
@@ -147,12 +157,25 @@ public ActionRequestValidationException validate() {
147157
return null;
148158
}
149159

160+
@Override
161+
public PPLQueryTask createTask(
162+
long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
163+
return new PPLQueryTask(id, type, action, getDescription(), parentTaskId, headers);
164+
}
165+
166+
@Override
167+
public String getDescription() {
168+
String prefix = (queryId != null) ? "PPL [queryId=" + queryId + "]: " : "PPL: ";
169+
return prefix + pplQuery;
170+
}
171+
150172
/** Convert to PPLQueryRequest. */
151173
public PPLQueryRequest toPPLQueryRequest() {
152174
PPLQueryRequest pplQueryRequest =
153175
new PPLQueryRequest(pplQuery, jsonContent, path, format, explainMode, profile);
154176
pplQueryRequest.sanitize(sanitize);
155177
pplQueryRequest.style(style);
178+
pplQueryRequest.queryId(queryId);
156179
return pplQueryRequest;
157180
}
158181
}
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.plugin.transport;
7+
8+
import static org.junit.Assert.*;
9+
10+
import java.util.Map;
11+
import org.junit.Test;
12+
import org.opensearch.core.tasks.TaskId;
13+
14+
public class PPLQueryTaskTest {
15+
16+
@Test
17+
public void testShouldCancelChildrenReturnsTrue() {
18+
PPLQueryTask pplQueryTask =
19+
new PPLQueryTask(
20+
1,
21+
"transport",
22+
"cluster:admin/opensearch/ppl",
23+
"test query",
24+
TaskId.EMPTY_TASK_ID,
25+
Map.of());
26+
assertTrue(pplQueryTask.shouldCancelChildrenOnCancellation());
27+
}
28+
29+
@Test
30+
public void testCreateTaskReturnsPPLQueryTask() {
31+
TransportPPLQueryRequest transportPPLQueryRequest =
32+
new TransportPPLQueryRequest("source=t a=1", null, "/_plugins/_ppl");
33+
PPLQueryTask task =
34+
transportPPLQueryRequest.createTask(
35+
1, "transport", "cluster:admin/opensearch/ppl", TaskId.EMPTY_TASK_ID, Map.of());
36+
assertNotNull(task);
37+
}
38+
39+
@Test
40+
public void testWithQueryId() {
41+
TransportPPLQueryRequest transportPPLQueryRequest =
42+
new TransportPPLQueryRequest("source=t a=1", null, "/_plugins/_ppl");
43+
transportPPLQueryRequest.queryId("test-123");
44+
assertEquals("PPL [queryId=test-123]: source=t a=1", transportPPLQueryRequest.getDescription());
45+
}
46+
47+
@Test
48+
public void testWithoutQueryId() {
49+
TransportPPLQueryRequest transportPPLQueryRequest =
50+
new TransportPPLQueryRequest("source=t a=1", null, "/_plugins/_ppl");
51+
assertEquals("PPL: source=t a=1", transportPPLQueryRequest.getDescription());
52+
}
53+
54+
@Test
55+
public void testCooperativeModel() {
56+
TransportPPLQueryRequest transportPPLQueryRequest =
57+
new TransportPPLQueryRequest("source=t a=1", null, "/_plugins/_ppl");
58+
PPLQueryTask task =
59+
transportPPLQueryRequest.createTask(
60+
1, "transport", "cluster:admin/opensearch/ppl", TaskId.EMPTY_TASK_ID, Map.of());
61+
assertFalse(task.isCancelled());
62+
task.cancel("Test");
63+
assertTrue(task.isCancelled());
64+
}
65+
}

ppl/src/main/java/org/opensearch/sql/ppl/domain/PPLQueryRequest.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ public class PPLQueryRequest {
5252
@Accessors(fluent = true)
5353
private boolean profile = false;
5454

55+
@Setter
56+
@Getter
57+
@Accessors(fluent = true)
58+
private String queryId = null;
59+
5560
public PPLQueryRequest(String pplQuery, JSONObject jsonContent, String path) {
5661
this(pplQuery, jsonContent, path, "");
5762
}

0 commit comments

Comments
 (0)