Skip to content

Commit 0724840

Browse files
committed
Simplify DataNodeRequestSender
1 parent 3db258e commit 0724840

File tree

2 files changed

+53
-48
lines changed

2 files changed

+53
-48
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSender.java

Lines changed: 3 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,6 @@
3030
import org.elasticsearch.search.SearchShardTarget;
3131
import org.elasticsearch.search.internal.AliasFilter;
3232
import org.elasticsearch.tasks.CancellableTask;
33-
import org.elasticsearch.tasks.Task;
3433
import org.elasticsearch.tasks.TaskCancelledException;
3534
import org.elasticsearch.transport.TransportException;
3635
import org.elasticsearch.transport.TransportRequestOptions;
@@ -104,15 +103,14 @@ abstract class DataNodeRequestSender {
104103
}
105104

106105
final void startComputeOnDataNodes(
107-
String clusterAlias,
108106
Set<String> concreteIndices,
109107
OriginalIndices originalIndices,
110108
QueryBuilder requestFilter,
111109
Runnable runOnTaskFailure,
112110
ActionListener<ComputeResponse> listener
113111
) {
114112
final long startTimeInNanos = System.nanoTime();
115-
searchShards(rootTask, clusterAlias, requestFilter, concreteIndices, originalIndices, ActionListener.wrap(targetShards -> {
113+
searchShards(requestFilter, concreteIndices, originalIndices, ActionListener.wrap(targetShards -> {
116114
try (var computeListener = new ComputeListener(transportService.getThreadPool(), runOnTaskFailure, listener.map(profiles -> {
117115
return new ComputeResponse(
118116
profiles,
@@ -321,7 +319,7 @@ private void trackShardLevelFailure(ShardId shardId, boolean fatal, Exception or
321319
}
322320

323321
/**
324-
* Result from {@link #searchShards(Task, String, QueryBuilder, Set, OriginalIndices, ActionListener)} where can_match is performed to
322+
* Result from {@link #searchShards(QueryBuilder, Set, OriginalIndices, ActionListener)} where can_match is performed to
325323
* determine what shards can be skipped and which target nodes are needed for running the ES|QL query
326324
*
327325
* @param shards List of target shards to perform the ES|QL query on
@@ -412,8 +410,6 @@ private List<NodeRequest> selectNodeRequests(TargetShards targetShards) {
412410
* to a situation where the column structure (i.e., matched data types) differs depending on the query.
413411
*/
414412
void searchShards(
415-
Task parentTask,
416-
String clusterAlias,
417413
QueryBuilder filter,
418414
Set<String> concreteIndices,
419415
OriginalIndices originalIndices,
@@ -459,7 +455,7 @@ void searchShards(
459455
transportService.getLocalNode(),
460456
EsqlSearchShardsAction.TYPE.name(),
461457
searchShardsRequest,
462-
parentTask,
458+
rootTask,
463459
TransportRequestOptions.EMPTY,
464460
new ActionListenerResponseHandler<>(searchShardsListener, SearchShardsResponse::new, esqlExecutor)
465461
);

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSenderTests.java

Lines changed: 50 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,6 @@
2323
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
2424
import org.elasticsearch.common.util.concurrent.EsExecutors;
2525
import org.elasticsearch.compute.test.ComputeTestCase;
26-
import org.elasticsearch.core.TimeValue;
2726
import org.elasticsearch.index.Index;
2827
import org.elasticsearch.index.query.QueryBuilder;
2928
import org.elasticsearch.index.shard.ShardId;
@@ -41,6 +40,7 @@
4140
import java.io.IOException;
4241
import java.util.ArrayList;
4342
import java.util.Arrays;
43+
import java.util.Collection;
4444
import java.util.Collections;
4545
import java.util.HashMap;
4646
import java.util.List;
@@ -59,8 +59,11 @@
5959
import static org.elasticsearch.cluster.node.DiscoveryNodeRole.DATA_FROZEN_NODE_ROLE;
6060
import static org.elasticsearch.cluster.node.DiscoveryNodeRole.DATA_HOT_NODE_ROLE;
6161
import static org.elasticsearch.cluster.node.DiscoveryNodeRole.DATA_WARM_NODE_ROLE;
62+
import static org.elasticsearch.core.TimeValue.timeValueNanos;
6263
import static org.elasticsearch.xpack.esql.plugin.DataNodeRequestSender.NodeRequest;
6364
import static org.hamcrest.Matchers.anyOf;
65+
import static org.hamcrest.Matchers.contains;
66+
import static org.hamcrest.Matchers.containsInAnyOrder;
6467
import static org.hamcrest.Matchers.containsString;
6568
import static org.hamcrest.Matchers.empty;
6669
import static org.hamcrest.Matchers.equalTo;
@@ -120,12 +123,12 @@ public void testOnePass() {
120123
);
121124
Queue<NodeRequest> sent = ConcurrentCollections.newQueue();
122125
var future = sendRequests(targetShards, randomBoolean(), -1, (node, shardIds, aliasFilters, listener) -> {
123-
sent.add(new NodeRequest(node, shardIds, aliasFilters));
126+
sent.add(nodeRequest(node, shardIds));
124127
runWithDelay(() -> listener.onResponse(new DataNodeComputeResponse(List.of(), Map.of())));
125128
});
126129
safeGet(future);
127130
assertThat(sent.size(), equalTo(2));
128-
assertThat(groupRequests(sent, 2), equalTo(Map.of(node1, List.of(shard1, shard3), node2, List.of(shard2, shard4))));
131+
assertThat(sent, containsInAnyOrder(nodeRequest(node1, shard1, shard3), nodeRequest(node2, shard2, shard4)));
129132
}
130133

131134
public void testMissingShards() {
@@ -163,7 +166,7 @@ public void testRetryThenSuccess() {
163166
);
164167
Queue<NodeRequest> sent = ConcurrentCollections.newQueue();
165168
var future = sendRequests(targetShards, randomBoolean(), -1, (node, shardIds, aliasFilters, listener) -> {
166-
sent.add(new NodeRequest(node, shardIds, aliasFilters));
169+
sent.add(nodeRequest(node, shardIds));
167170
Map<ShardId, Exception> failures = new HashMap<>();
168171
if (node.equals(node1) && shardIds.contains(shard5)) {
169172
failures.put(shard5, new IOException("test"));
@@ -179,10 +182,11 @@ public void testRetryThenSuccess() {
179182
throw new AssertionError(e);
180183
}
181184
assertThat(sent, hasSize(5));
182-
var firstRound = groupRequests(sent, 3);
183-
assertThat(firstRound, equalTo(Map.of(node1, List.of(shard1, shard5), node4, List.of(shard2), node2, List.of(shard3, shard4))));
184-
var secondRound = groupRequests(sent, 2);
185-
assertThat(secondRound, equalTo(Map.of(node2, List.of(shard2), node3, List.of(shard5))));
185+
assertThat(
186+
take(sent, 3),
187+
containsInAnyOrder(nodeRequest(node1, shard1, shard5), nodeRequest(node4, shard2), nodeRequest(node2, shard3, shard4))
188+
);
189+
assertThat(take(sent, 2), containsInAnyOrder(nodeRequest(node2, shard2), nodeRequest(node3, shard5)));
186190
}
187191

188192
public void testRetryButFail() {
@@ -195,7 +199,7 @@ public void testRetryButFail() {
195199
);
196200
Queue<NodeRequest> sent = ConcurrentCollections.newQueue();
197201
var future = sendRequests(targetShards, false, -1, (node, shardIds, aliasFilters, listener) -> {
198-
sent.add(new NodeRequest(node, shardIds, aliasFilters));
202+
sent.add(nodeRequest(node, shardIds));
199203
Map<ShardId, Exception> failures = new HashMap<>();
200204
if (shardIds.contains(shard5)) {
201205
failures.put(shard5, new IOException("test failure for shard5"));
@@ -206,22 +210,20 @@ public void testRetryButFail() {
206210
assertNotNull(ExceptionsHelper.unwrap(error, IOException.class));
207211
// {node-1, node-2, node-4}, {node-3}, {node-2}
208212
assertThat(sent.size(), equalTo(5));
209-
var firstRound = groupRequests(sent, 3);
210-
assertThat(firstRound, equalTo(Map.of(node1, List.of(shard1, shard5), node2, List.of(shard3, shard4), node4, List.of(shard2))));
211-
NodeRequest fourth = sent.remove();
212-
assertThat(fourth.node(), equalTo(node3));
213-
assertThat(fourth.shardIds(), equalTo(List.of(shard5)));
214-
NodeRequest fifth = sent.remove();
215-
assertThat(fifth.node(), equalTo(node2));
216-
assertThat(fifth.shardIds(), equalTo(List.of(shard5)));
213+
assertThat(
214+
take(sent, 3),
215+
containsInAnyOrder(nodeRequest(node1, shard1, shard5), nodeRequest(node2, shard3, shard4), nodeRequest(node4, shard2))
216+
);
217+
assertThat(take(sent, 1), containsInAnyOrder(nodeRequest(node3, shard5)));
218+
assertThat(take(sent, 1), containsInAnyOrder(nodeRequest(node2, shard5)));
217219
}
218220

219221
public void testDoNotRetryOnRequestLevelFailure() {
220222
var targetShards = List.of(targetShard(shard1, node1), targetShard(shard2, node2), targetShard(shard3, node1));
221223
Queue<NodeRequest> sent = ConcurrentCollections.newQueue();
222224
AtomicBoolean failed = new AtomicBoolean();
223225
var future = sendRequests(targetShards, false, -1, (node, shardIds, aliasFilters, listener) -> {
224-
sent.add(new NodeRequest(node, shardIds, aliasFilters));
226+
sent.add(nodeRequest(node, shardIds));
225227
if (node1.equals(node) && failed.compareAndSet(false, true)) {
226228
runWithDelay(() -> listener.onFailure(new IOException("test request level failure"), true));
227229
} else {
@@ -232,37 +234,35 @@ public void testDoNotRetryOnRequestLevelFailure() {
232234
assertNotNull(ExceptionsHelper.unwrap(exception, IOException.class));
233235
// one round: {node-1, node-2}
234236
assertThat(sent.size(), equalTo(2));
235-
var firstRound = groupRequests(sent, 2);
236-
assertThat(firstRound, equalTo(Map.of(node1, List.of(shard1, shard3), node2, List.of(shard2))));
237+
assertThat(sent, containsInAnyOrder(nodeRequest(node1, shard1, shard3), nodeRequest(node2, shard2)));
237238
}
238239

239240
public void testAllowPartialResults() {
240241
var targetShards = List.of(targetShard(shard1, node1), targetShard(shard2, node2), targetShard(shard3, node1, node2));
241242
Queue<NodeRequest> sent = ConcurrentCollections.newQueue();
242243
AtomicBoolean failed = new AtomicBoolean();
243244
var future = sendRequests(targetShards, true, -1, (node, shardIds, aliasFilters, listener) -> {
244-
sent.add(new NodeRequest(node, shardIds, aliasFilters));
245+
sent.add(nodeRequest(node, shardIds));
245246
if (node1.equals(node) && failed.compareAndSet(false, true)) {
246247
runWithDelay(() -> listener.onFailure(new IOException("test request level failure"), true));
247248
} else {
248249
runWithDelay(() -> listener.onResponse(new DataNodeComputeResponse(List.of(), Map.of())));
249250
}
250251
});
251-
ComputeResponse resp = safeGet(future);
252+
var response = safeGet(future);
253+
assertThat(response.totalShards, equalTo(3));
254+
assertThat(response.failedShards, equalTo(2));
255+
assertThat(response.successfulShards, equalTo(1));
252256
// one round: {node-1, node-2}
253257
assertThat(sent.size(), equalTo(2));
254-
var firstRound = groupRequests(sent, 2);
255-
assertThat(firstRound, equalTo(Map.of(node1, List.of(shard1, shard3), node2, List.of(shard2))));
256-
assertThat(resp.totalShards, equalTo(3));
257-
assertThat(resp.failedShards, equalTo(2));
258-
assertThat(resp.successfulShards, equalTo(1));
258+
assertThat(sent, containsInAnyOrder(nodeRequest(node1, shard1, shard3), nodeRequest(node2, shard2)));
259259
}
260260

261261
public void testNonFatalErrorIsRetriedOnAnotherShard() {
262262
var targetShards = List.of(targetShard(shard1, node1, node2));
263263
var sent = ConcurrentCollections.<NodeRequest>newQueue();
264264
var response = safeGet(sendRequests(targetShards, false, -1, (node, shardIds, aliasFilters, listener) -> {
265-
sent.add(new NodeRequest(node, shardIds, aliasFilters));
265+
sent.add(nodeRequest(node, shardIds));
266266
if (Objects.equals(node1, node)) {
267267
runWithDelay(() -> listener.onFailure(new RuntimeException("test request level non fatal failure"), false));
268268
} else {
@@ -279,7 +279,7 @@ public void testNonFatalFailedOnAllNodes() {
279279
var targetShards = List.of(targetShard(shard1, node1, node2));
280280
var sent = ConcurrentCollections.<NodeRequest>newQueue();
281281
var future = sendRequests(targetShards, false, -1, (node, shardIds, aliasFilters, listener) -> {
282-
sent.add(new NodeRequest(node, shardIds, aliasFilters));
282+
sent.add(nodeRequest(node, shardIds));
283283
runWithDelay(() -> listener.onFailure(new RuntimeException("test request level non fatal failure"), false));
284284
});
285285
expectThrows(RuntimeException.class, equalTo("test request level non fatal failure"), future::actionGet);
@@ -290,7 +290,7 @@ public void testDoNotRetryCircuitBreakerException() {
290290
var targetShards = List.of(targetShard(shard1, node1, node2));
291291
var sent = ConcurrentCollections.<NodeRequest>newQueue();
292292
var future = sendRequests(targetShards, false, -1, (node, shardIds, aliasFilters, listener) -> {
293-
sent.add(new NodeRequest(node, shardIds, aliasFilters));
293+
sent.add(nodeRequest(node, shardIds));
294294
runWithDelay(() -> listener.onFailure(new CircuitBreakingException("cbe", randomFrom(Durability.values())), false));
295295
});
296296
expectThrows(CircuitBreakingException.class, equalTo("cbe"), future::actionGet);
@@ -321,7 +321,7 @@ public void testLimitConcurrentNodes() {
321321
}
322322
}
323323

324-
sent.add(new NodeRequest(node, shardIds, aliasFilters));
324+
sent.add(nodeRequest(node, shardIds));
325325
runWithDelay(() -> {
326326
concurrentRequests.decrementAndGet();
327327
listener.onResponse(new DataNodeComputeResponse(List.of(), Map.of()));
@@ -364,7 +364,7 @@ public void testSkipRemovesPriorNonFatalErrors() {
364364

365365
var sent = ConcurrentCollections.<NodeRequest>newQueue();
366366
var response = safeGet(sendRequests(targetShards, randomBoolean(), 1, (node, shardIds, aliasFilters, listener) -> {
367-
sent.add(new NodeRequest(node, shardIds, aliasFilters));
367+
sent.add(nodeRequest(node, shardIds));
368368
runWithDelay(() -> {
369369
if (Objects.equals(node.getId(), node1.getId()) && shardIds.equals(List.of(shard1))) {
370370
listener.onFailure(new RuntimeException("test request level non fatal failure"), false);
@@ -406,29 +406,38 @@ public void testQueryHotShardsFirstWhenIlmMovesShard() {
406406
);
407407
var sent = ConcurrentCollections.<NodeRequest>newQueue();
408408
safeGet(sendRequests(targetShards, randomBoolean(), -1, (node, shardIds, aliasFilters, listener) -> {
409-
sent.add(new NodeRequest(node, shardIds, aliasFilters));
409+
sent.add(nodeRequest(node, shardIds));
410410
runWithDelay(() -> listener.onResponse(new DataNodeComputeResponse(List.of(), Map.of())));
411411
}));
412-
assertThat(groupRequests(sent, 1), equalTo(Map.of(node1, List.of(shard1))));
413-
assertThat(groupRequests(sent, 1), anyOf(equalTo(Map.of(node2, List.of(shard2))), equalTo(Map.of(warmNode2, List.of(shard2)))));
412+
assertThat(take(sent, 1), containsInAnyOrder(nodeRequest(node1, shard1)));
413+
assertThat(take(sent, 1), anyOf(contains(nodeRequest(node2, shard2)), contains(nodeRequest(warmNode2, shard2))));
414414
}
415415

416416
static DataNodeRequestSender.TargetShard targetShard(ShardId shardId, DiscoveryNode... nodes) {
417417
return new DataNodeRequestSender.TargetShard(shardId, new ArrayList<>(Arrays.asList(nodes)), null);
418418
}
419419

420-
static Map<DiscoveryNode, List<ShardId>> groupRequests(Queue<NodeRequest> sent, int limit) {
421-
Map<DiscoveryNode, List<ShardId>> map = new HashMap<>();
420+
static DataNodeRequestSender.NodeRequest nodeRequest(DiscoveryNode node, ShardId... shardIds) {
421+
return nodeRequest(node, Arrays.asList(shardIds));
422+
}
423+
424+
static DataNodeRequestSender.NodeRequest nodeRequest(DiscoveryNode node, List<ShardId> shardIds) {
425+
var copy = new ArrayList<>(shardIds);
426+
Collections.sort(copy);
427+
return new NodeRequest(node, copy, Map.of());
428+
}
429+
430+
static <T> Collection<T> take(Queue<T> queue, int limit) {
431+
var result = new ArrayList<T>(limit);
422432
for (int i = 0; i < limit; i++) {
423-
NodeRequest r = sent.remove();
424-
assertNull(map.put(r.node(), r.shardIds().stream().sorted().toList()));
433+
result.add(queue.remove());
425434
}
426-
return map;
435+
return result;
427436
}
428437

429438
void runWithDelay(Runnable runnable) {
430439
if (randomBoolean()) {
431-
threadPool.schedule(runnable, TimeValue.timeValueNanos(between(0, 5000)), executor);
440+
threadPool.schedule(runnable, timeValueNanos(between(0, 5000)), executor);
432441
} else {
433442
executor.execute(runnable);
434443
}

0 commit comments

Comments
 (0)