Skip to content

Commit bdd2b42

Browse files
authored
Collect warnings in compute service (#103031) (#103079)
We have implemented #99927 in DriverRunner. However, we also need to implement this in ComputeService, where we spawn multiple requests, to avoid losing response headers. Relates #99927. Closes #100163. Closes #102982. Closes #102871. Closes #103028.
1 parent 268565e commit bdd2b42

File tree

6 files changed

+253
-33
lines changed

6 files changed

+253
-33
lines changed

docs/changelog/103031.yaml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
pr: 103031
summary: Collect warnings in compute service
area: ES|QL
type: bug
issues:
 - 100163
 - 103028
 - 102871
 - 102982

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverRunner.java

Lines changed: 3 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,11 @@
99

1010
import org.elasticsearch.ExceptionsHelper;
1111
import org.elasticsearch.action.ActionListener;
12-
import org.elasticsearch.common.util.concurrent.AtomicArray;
1312
import org.elasticsearch.common.util.concurrent.CountDown;
1413
import org.elasticsearch.common.util.concurrent.ThreadContext;
1514
import org.elasticsearch.tasks.TaskCancelledException;
1615

17-
import java.util.HashMap;
18-
import java.util.LinkedHashSet;
1916
import java.util.List;
20-
import java.util.Map;
21-
import java.util.Set;
2217
import java.util.concurrent.atomic.AtomicReference;
2318

2419
/**
@@ -41,11 +36,10 @@ public DriverRunner(ThreadContext threadContext) {
4136
*/
4237
public void runToCompletion(List<Driver> drivers, ActionListener<Void> listener) {
4338
AtomicReference<Exception> failure = new AtomicReference<>();
44-
AtomicArray<Map<String, List<String>>> responseHeaders = new AtomicArray<>(drivers.size());
39+
var responseHeadersCollector = new ResponseHeadersCollector(threadContext);
4540
CountDown counter = new CountDown(drivers.size());
4641
for (int i = 0; i < drivers.size(); i++) {
4742
Driver driver = drivers.get(i);
48-
int driverIndex = i;
4943
ActionListener<Void> driverListener = new ActionListener<>() {
5044
@Override
5145
public void onResponse(Void unused) {
@@ -80,9 +74,9 @@ public void onFailure(Exception e) {
8074
}
8175

8276
private void done() {
83-
responseHeaders.setOnce(driverIndex, threadContext.getResponseHeaders());
77+
responseHeadersCollector.collect();
8478
if (counter.countDown()) {
85-
mergeResponseHeaders(responseHeaders);
79+
responseHeadersCollector.finish();
8680
Exception error = failure.get();
8781
if (error != null) {
8882
listener.onFailure(error);
@@ -96,23 +90,4 @@ private void done() {
9690
start(driver, driverListener);
9791
}
9892
}
99-
100-
private void mergeResponseHeaders(AtomicArray<Map<String, List<String>>> responseHeaders) {
101-
final Map<String, Set<String>> merged = new HashMap<>();
102-
for (int i = 0; i < responseHeaders.length(); i++) {
103-
final Map<String, List<String>> resp = responseHeaders.get(i);
104-
if (resp == null || resp.isEmpty()) {
105-
continue;
106-
}
107-
for (Map.Entry<String, List<String>> e : resp.entrySet()) {
108-
// Use LinkedHashSet to retain the order of the values
109-
merged.computeIfAbsent(e.getKey(), k -> new LinkedHashSet<>(e.getValue().size())).addAll(e.getValue());
110-
}
111-
}
112-
for (Map.Entry<String, Set<String>> e : merged.entrySet()) {
113-
for (String v : e.getValue()) {
114-
threadContext.addResponseHeader(e.getKey(), v);
115-
}
116-
}
117-
}
11893
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.operator;

import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.common.util.concurrent.ThreadContext;

import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;

/**
 * A helper class that can be used to collect and merge response headers from multiple child requests.
 */
public final class ResponseHeadersCollector {
    private final ThreadContext threadContext;
    // Concurrent queue: collect() may be invoked simultaneously from the threads completing child requests.
    private final Queue<Map<String, List<String>>> collected = ConcurrentCollections.newQueue();

    public ResponseHeadersCollector(ThreadContext threadContext) {
        this.threadContext = threadContext;
    }

    /**
     * Called when a child request is completed to collect the response headers of the responding thread
     */
    public void collect() {
        final Map<String, List<String>> headers = threadContext.getResponseHeaders();
        if (headers.isEmpty()) {
            return;
        }
        collected.add(headers);
    }

    /**
     * Called when all child requests are completed. This will merge all collected response headers
     * from the child requests and restore to the current thread.
     */
    public void finish() {
        final Map<String, Set<String>> mergedByName = new HashMap<>();
        for (Map<String, List<String>> headers = collected.poll(); headers != null; headers = collected.poll()) {
            for (Map.Entry<String, List<String>> entry : headers.entrySet()) {
                // Use LinkedHashSet to retain the order of the values while dropping duplicates across requests
                mergedByName.computeIfAbsent(entry.getKey(), name -> new LinkedHashSet<>(entry.getValue().size()))
                    .addAll(entry.getValue());
            }
        }
        mergedByName.forEach((name, values) -> {
            for (String value : values) {
                threadContext.addResponseHeader(name, value);
            }
        });
    }
}
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.operator;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.RefCountingListener;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.FixedExecutorBuilder;
import org.elasticsearch.threadpool.TestThreadPool;

import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.TimeUnit;

import static org.hamcrest.Matchers.equalTo;

public class ResponseHeadersCollectorTests extends ESTestCase {

    /**
     * Schedules N tasks that each emit a distinct "Warnings" response header inside a stashed
     * thread context, collects the headers as each task completes, and verifies that
     * {@link ResponseHeadersCollector#finish()} restores the union of all warnings to the
     * calling thread's context.
     */
    public void testCollect() {
        final int threadCount = randomIntBetween(1, 10);
        final TestThreadPool threadPool = new TestThreadPool(
            getTestClass().getSimpleName(),
            new FixedExecutorBuilder(Settings.EMPTY, "test", threadCount, 1024, "test", EsExecutors.TaskTrackingConfig.DEFAULT)
        );
        final Set<String> expectedWarnings = new HashSet<>();
        try {
            final ThreadContext threadContext = threadPool.getThreadContext();
            final var collector = new ResponseHeadersCollector(threadContext);
            final PlainActionFuture<Void> completionFuture = new PlainActionFuture<>();
            // Runs once every child task has completed: merge the collected headers and compare.
            final Runnable mergeAndVerify = () -> {
                collector.finish();
                final List<String> actualWarnings = threadContext.getResponseHeaders().getOrDefault("Warnings", List.of());
                assertThat(Sets.newHashSet(actualWarnings), equalTo(expectedWarnings));
            };
            try (RefCountingListener refs = new RefCountingListener(ActionListener.runAfter(completionFuture, mergeAndVerify))) {
                final CyclicBarrier barrier = new CyclicBarrier(threadCount);
                for (int i = 0; i < threadCount; i++) {
                    final String warning = "warning-" + i;
                    expectedWarnings.add(warning);
                    // Capture the responding thread's headers just before the ref is released.
                    final ActionListener<Void> childListener = ActionListener.runBefore(refs.acquire(), collector::collect);
                    threadPool.schedule(new ActionRunnable<>(childListener) {
                        @Override
                        protected void doRun() throws Exception {
                            // Make all tasks emit their warning concurrently.
                            barrier.await(30, TimeUnit.SECONDS);
                            try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
                                threadContext.addResponseHeader("Warnings", warning);
                                childListener.onResponse(null);
                            }
                        }
                    }, TimeValue.timeValueNanos(between(0, 1000_000)), threadPool.executor("test"));
                }
            }
            completionFuture.actionGet(TimeValue.timeValueSeconds(30));
        } finally {
            terminate(threadPool);
        }
    }
}
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.esql.action;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.test.junit.annotations.TestLogging;
import org.elasticsearch.transport.TransportService;

import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;

@TestLogging(value = "org.elasticsearch.xpack.esql:TRACE", reason = "debug")
public class WarningsIT extends AbstractEsqlIntegTestCase {

    public void testCollectWarnings() {
        // Pick the nodes hosting the two indices; sometimes force them to be distinct
        // to exercise the multi-node warning-collection path.
        final String node1;
        final String node2;
        if (randomBoolean()) {
            internalCluster().ensureAtLeastNumDataNodes(2);
            node1 = randomDataNode().getName();
            node2 = randomValueOtherThan(node1, () -> randomDataNode().getName());
        } else {
            node1 = randomDataNode().getName();
            node2 = randomDataNode().getName();
        }

        final int numDocs1 = randomIntBetween(1, 15);
        populateIndex("index-1", node1, "192.", numDocs1);
        final int numDocs2 = randomIntBetween(1, 15);
        populateIndex("index-2", node2, "10.", numDocs2);

        final DiscoveryNode coordinator = randomFrom(clusterService().state().nodes().stream().toList());
        client().admin().indices().prepareRefresh("index-1", "index-2").get();

        final EsqlQueryRequest request = new EsqlQueryRequest();
        // Every indexed host value fails to_ip(), so each document should contribute one warning.
        request.query("FROM index-* | EVAL ip = to_ip(host) | STATS s = COUNT(*) by ip | KEEP ip | LIMIT 100");
        request.pragmas(randomPragmas());
        final PlainActionFuture<EsqlQueryResponse> future = new PlainActionFuture<>();
        client(coordinator.getName()).execute(EsqlQueryAction.INSTANCE, request, ActionListener.runBefore(future, () -> {
            // Inspect the coordinator's thread context before the response listener completes.
            var threadpool = internalCluster().getInstance(TransportService.class, coordinator.getName()).getThreadPool();
            Map<String, List<String>> responseHeaders = threadpool.getThreadContext().getResponseHeaders();
            List<String> warnings = responseHeaders.getOrDefault("Warning", List.of())
                .stream()
                .filter(w -> w.contains("is not an IP string literal"))
                .toList();
            int expectedWarnings = Math.min(20, numDocs1 + numDocs2);
            // we cap the number of warnings per node
            assertThat(warnings.size(), greaterThanOrEqualTo(expectedWarnings));
        }));
        future.actionGet(30, TimeUnit.SECONDS).close();
    }

    /**
     * Creates {@code index} pinned to {@code allocationNode} and indexes {@code numDocs}
     * documents whose {@code host} keyword values are {@code hostPrefix + i} (not valid IPs).
     */
    private void populateIndex(String index, String allocationNode, String hostPrefix, int numDocs) {
        assertAcked(
            client().admin()
                .indices()
                .prepareCreate(index)
                .setSettings(Settings.builder().put("index.routing.allocation.require._name", allocationNode))
                .setMapping("host", "type=keyword")
        );
        for (int i = 0; i < numDocs; i++) {
            client().prepareIndex(index).setSource("host", hostPrefix + i).get();
        }
    }

    private DiscoveryNode randomDataNode() {
        return randomFrom(clusterService().state().nodes().getDataNodes().values());
    }
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import org.elasticsearch.compute.data.Page;
3030
import org.elasticsearch.compute.operator.Driver;
3131
import org.elasticsearch.compute.operator.DriverTaskRunner;
32+
import org.elasticsearch.compute.operator.ResponseHeadersCollector;
3233
import org.elasticsearch.compute.operator.exchange.ExchangeResponse;
3334
import org.elasticsearch.compute.operator.exchange.ExchangeService;
3435
import org.elasticsearch.compute.operator.exchange.ExchangeSinkHandler;
@@ -150,6 +151,8 @@ public void execute(
150151
LOGGER.debug("Sending data node plan\n{}\n with filter [{}]", dataNodePlan, requestFilter);
151152

152153
String[] originalIndices = PlannerUtils.planOriginalIndices(physicalPlan);
154+
var responseHeadersCollector = new ResponseHeadersCollector(transportService.getThreadPool().getThreadContext());
155+
listener = ActionListener.runBefore(listener, responseHeadersCollector::finish);
153156
computeTargetNodes(
154157
rootTask,
155158
requestFilter,
@@ -170,7 +173,16 @@ public void execute(
170173
exchangeSource.addCompletionListener(requestRefs.acquire());
171174
// run compute on the coordinator
172175
var computeContext = new ComputeContext(sessionId, List.of(), configuration, exchangeSource, null);
173-
runCompute(rootTask, computeContext, coordinatorPlan, cancelOnFailure(rootTask, cancelled, requestRefs.acquire()));
176+
runCompute(
177+
rootTask,
178+
computeContext,
179+
coordinatorPlan,
180+
cancelOnFailure(
181+
rootTask,
182+
cancelled,
183+
ActionListener.runBefore(requestRefs.acquire(), responseHeadersCollector::collect)
184+
)
185+
);
174186
// run compute on remote nodes
175187
// TODO: This is wrong, we need to be able to cancel
176188
runComputeOnRemoteNodes(
@@ -180,7 +192,11 @@ public void execute(
180192
dataNodePlan,
181193
exchangeSource,
182194
targetNodes,
183-
() -> cancelOnFailure(rootTask, cancelled, requestRefs.acquire()).map(unused -> null)
195+
() -> cancelOnFailure(
196+
rootTask,
197+
cancelled,
198+
ActionListener.runBefore(requestRefs.acquire(), responseHeadersCollector::collect)
199+
)
184200
);
185201
}
186202
})
@@ -194,7 +210,7 @@ private void runComputeOnRemoteNodes(
194210
PhysicalPlan dataNodePlan,
195211
ExchangeSourceHandler exchangeSource,
196212
List<TargetNode> targetNodes,
197-
Supplier<ActionListener<DataNodeResponse>> listener
213+
Supplier<ActionListener<Void>> listener
198214
) {
199215
// Do not complete the exchange sources until we have linked all remote sinks
200216
final SubscribableListener<Void> blockingSinkFuture = new SubscribableListener<>();
@@ -223,7 +239,7 @@ private void runComputeOnRemoteNodes(
223239
new DataNodeRequest(sessionId, configuration, targetNode.shardIds, targetNode.aliasFilters, dataNodePlan),
224240
rootTask,
225241
TransportRequestOptions.EMPTY,
226-
new ActionListenerResponseHandler<>(delegate, DataNodeResponse::new, esqlExecutor)
242+
new ActionListenerResponseHandler<>(delegate.map(ignored -> null), DataNodeResponse::new, esqlExecutor)
227243
);
228244
})
229245
);
@@ -442,7 +458,10 @@ public void messageReceived(DataNodeRequest request, TransportChannel channel, T
442458
runCompute(parentTask, computeContext, request.plan(), ActionListener.wrap(unused -> {
443459
// don't return until all pages are fetched
444460
exchangeSink.addCompletionListener(
445-
ActionListener.releaseAfter(listener, () -> exchangeService.finishSinkHandler(sessionId, null))
461+
ContextPreservingActionListener.wrapPreservingContext(
462+
ActionListener.releaseAfter(listener, () -> exchangeService.finishSinkHandler(sessionId, null)),
463+
transportService.getThreadPool().getThreadContext()
464+
)
446465
);
447466
}, e -> {
448467
exchangeService.finishSinkHandler(sessionId, e);

0 commit comments

Comments
 (0)