Skip to content

Commit b96a2f6

Browse files
authored
Simplify DataNodeRequestSender (#126664)
1 parent 24dfda5 commit b96a2f6

File tree

3 files changed

+53
-54
lines changed

3 files changed

+53
-54
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeComputeHandler.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,6 @@ protected void sendRequest(
199199
);
200200
}
201201
}.startComputeOnDataNodes(
202-
clusterAlias,
203202
concreteIndices,
204203
originalIndices,
205204
PlannerUtils.canMatchFilter(dataNodePlan),

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSender.java

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
import org.elasticsearch.search.SearchShardTarget;
3131
import org.elasticsearch.search.internal.AliasFilter;
3232
import org.elasticsearch.tasks.CancellableTask;
33-
import org.elasticsearch.tasks.Task;
3433
import org.elasticsearch.tasks.TaskCancelledException;
3534
import org.elasticsearch.transport.TransportException;
3635
import org.elasticsearch.transport.TransportRequestOptions;
@@ -104,15 +103,14 @@ abstract class DataNodeRequestSender {
104103
}
105104

106105
final void startComputeOnDataNodes(
107-
String clusterAlias,
108106
Set<String> concreteIndices,
109107
OriginalIndices originalIndices,
110108
QueryBuilder requestFilter,
111109
Runnable runOnTaskFailure,
112110
ActionListener<ComputeResponse> listener
113111
) {
114112
final long startTimeInNanos = System.nanoTime();
115-
searchShards(rootTask, clusterAlias, requestFilter, concreteIndices, originalIndices, ActionListener.wrap(targetShards -> {
113+
searchShards(requestFilter, concreteIndices, originalIndices, ActionListener.wrap(targetShards -> {
116114
try (var computeListener = new ComputeListener(transportService.getThreadPool(), runOnTaskFailure, listener.map(profiles -> {
117115
return new ComputeResponse(
118116
profiles,
@@ -321,7 +319,7 @@ private void trackShardLevelFailure(ShardId shardId, boolean fatal, Exception or
321319
}
322320

323321
/**
324-
* Result from {@link #searchShards(Task, String, QueryBuilder, Set, OriginalIndices, ActionListener)} where can_match is performed to
322+
* Result from {@link #searchShards(QueryBuilder, Set, OriginalIndices, ActionListener)} where can_match is performed to
325323
* determine what shards can be skipped and which target nodes are needed for running the ES|QL query
326324
*
327325
* @param shards List of target shards to perform the ES|QL query on
@@ -412,8 +410,6 @@ private List<NodeRequest> selectNodeRequests(TargetShards targetShards) {
412410
* to a situation where the column structure (i.e., matched data types) differs depending on the query.
413411
*/
414412
void searchShards(
415-
Task parentTask,
416-
String clusterAlias,
417413
QueryBuilder filter,
418414
Set<String> concreteIndices,
419415
OriginalIndices originalIndices,
@@ -459,7 +455,7 @@ void searchShards(
459455
transportService.getLocalNode(),
460456
EsqlSearchShardsAction.TYPE.name(),
461457
searchShardsRequest,
462-
parentTask,
458+
rootTask,
463459
TransportRequestOptions.EMPTY,
464460
new ActionListenerResponseHandler<>(searchShardsListener, SearchShardsResponse::new, esqlExecutor)
465461
);

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSenderTests.java

Lines changed: 50 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,11 @@
2323
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
2424
import org.elasticsearch.common.util.concurrent.EsExecutors;
2525
import org.elasticsearch.compute.test.ComputeTestCase;
26-
import org.elasticsearch.core.TimeValue;
2726
import org.elasticsearch.index.Index;
2827
import org.elasticsearch.index.query.QueryBuilder;
2928
import org.elasticsearch.index.shard.ShardId;
3029
import org.elasticsearch.search.internal.AliasFilter;
3130
import org.elasticsearch.tasks.CancellableTask;
32-
import org.elasticsearch.tasks.Task;
3331
import org.elasticsearch.tasks.TaskId;
3432
import org.elasticsearch.test.transport.MockTransportService;
3533
import org.elasticsearch.threadpool.FixedExecutorBuilder;
@@ -41,6 +39,7 @@
4139
import java.io.IOException;
4240
import java.util.ArrayList;
4341
import java.util.Arrays;
42+
import java.util.Collection;
4443
import java.util.Collections;
4544
import java.util.HashMap;
4645
import java.util.List;
@@ -59,8 +58,11 @@
5958
import static org.elasticsearch.cluster.node.DiscoveryNodeRole.DATA_FROZEN_NODE_ROLE;
6059
import static org.elasticsearch.cluster.node.DiscoveryNodeRole.DATA_HOT_NODE_ROLE;
6160
import static org.elasticsearch.cluster.node.DiscoveryNodeRole.DATA_WARM_NODE_ROLE;
61+
import static org.elasticsearch.core.TimeValue.timeValueNanos;
6262
import static org.elasticsearch.xpack.esql.plugin.DataNodeRequestSender.NodeRequest;
6363
import static org.hamcrest.Matchers.anyOf;
64+
import static org.hamcrest.Matchers.contains;
65+
import static org.hamcrest.Matchers.containsInAnyOrder;
6466
import static org.hamcrest.Matchers.containsString;
6567
import static org.hamcrest.Matchers.empty;
6668
import static org.hamcrest.Matchers.equalTo;
@@ -120,12 +122,12 @@ public void testOnePass() {
120122
);
121123
Queue<NodeRequest> sent = ConcurrentCollections.newQueue();
122124
var future = sendRequests(targetShards, randomBoolean(), -1, (node, shardIds, aliasFilters, listener) -> {
123-
sent.add(new NodeRequest(node, shardIds, aliasFilters));
125+
sent.add(nodeRequest(node, shardIds));
124126
runWithDelay(() -> listener.onResponse(new DataNodeComputeResponse(List.of(), Map.of())));
125127
});
126128
safeGet(future);
127129
assertThat(sent.size(), equalTo(2));
128-
assertThat(groupRequests(sent, 2), equalTo(Map.of(node1, List.of(shard1, shard3), node2, List.of(shard2, shard4))));
130+
assertThat(sent, containsInAnyOrder(nodeRequest(node1, shard1, shard3), nodeRequest(node2, shard2, shard4)));
129131
}
130132

131133
public void testMissingShards() {
@@ -163,7 +165,7 @@ public void testRetryThenSuccess() {
163165
);
164166
Queue<NodeRequest> sent = ConcurrentCollections.newQueue();
165167
var future = sendRequests(targetShards, randomBoolean(), -1, (node, shardIds, aliasFilters, listener) -> {
166-
sent.add(new NodeRequest(node, shardIds, aliasFilters));
168+
sent.add(nodeRequest(node, shardIds));
167169
Map<ShardId, Exception> failures = new HashMap<>();
168170
if (node.equals(node1) && shardIds.contains(shard5)) {
169171
failures.put(shard5, new IOException("test"));
@@ -179,10 +181,11 @@ public void testRetryThenSuccess() {
179181
throw new AssertionError(e);
180182
}
181183
assertThat(sent, hasSize(5));
182-
var firstRound = groupRequests(sent, 3);
183-
assertThat(firstRound, equalTo(Map.of(node1, List.of(shard1, shard5), node4, List.of(shard2), node2, List.of(shard3, shard4))));
184-
var secondRound = groupRequests(sent, 2);
185-
assertThat(secondRound, equalTo(Map.of(node2, List.of(shard2), node3, List.of(shard5))));
184+
assertThat(
185+
take(sent, 3),
186+
containsInAnyOrder(nodeRequest(node1, shard1, shard5), nodeRequest(node4, shard2), nodeRequest(node2, shard3, shard4))
187+
);
188+
assertThat(take(sent, 2), containsInAnyOrder(nodeRequest(node2, shard2), nodeRequest(node3, shard5)));
186189
}
187190

188191
public void testRetryButFail() {
@@ -195,7 +198,7 @@ public void testRetryButFail() {
195198
);
196199
Queue<NodeRequest> sent = ConcurrentCollections.newQueue();
197200
var future = sendRequests(targetShards, false, -1, (node, shardIds, aliasFilters, listener) -> {
198-
sent.add(new NodeRequest(node, shardIds, aliasFilters));
201+
sent.add(nodeRequest(node, shardIds));
199202
Map<ShardId, Exception> failures = new HashMap<>();
200203
if (shardIds.contains(shard5)) {
201204
failures.put(shard5, new IOException("test failure for shard5"));
@@ -206,22 +209,20 @@ public void testRetryButFail() {
206209
assertNotNull(ExceptionsHelper.unwrap(error, IOException.class));
207210
// {node-1, node-2, node-4}, {node-3}, {node-2}
208211
assertThat(sent.size(), equalTo(5));
209-
var firstRound = groupRequests(sent, 3);
210-
assertThat(firstRound, equalTo(Map.of(node1, List.of(shard1, shard5), node2, List.of(shard3, shard4), node4, List.of(shard2))));
211-
NodeRequest fourth = sent.remove();
212-
assertThat(fourth.node(), equalTo(node3));
213-
assertThat(fourth.shardIds(), equalTo(List.of(shard5)));
214-
NodeRequest fifth = sent.remove();
215-
assertThat(fifth.node(), equalTo(node2));
216-
assertThat(fifth.shardIds(), equalTo(List.of(shard5)));
212+
assertThat(
213+
take(sent, 3),
214+
containsInAnyOrder(nodeRequest(node1, shard1, shard5), nodeRequest(node2, shard3, shard4), nodeRequest(node4, shard2))
215+
);
216+
assertThat(take(sent, 1), containsInAnyOrder(nodeRequest(node3, shard5)));
217+
assertThat(take(sent, 1), containsInAnyOrder(nodeRequest(node2, shard5)));
217218
}
218219

219220
public void testDoNotRetryOnRequestLevelFailure() {
220221
var targetShards = List.of(targetShard(shard1, node1), targetShard(shard2, node2), targetShard(shard3, node1));
221222
Queue<NodeRequest> sent = ConcurrentCollections.newQueue();
222223
AtomicBoolean failed = new AtomicBoolean();
223224
var future = sendRequests(targetShards, false, -1, (node, shardIds, aliasFilters, listener) -> {
224-
sent.add(new NodeRequest(node, shardIds, aliasFilters));
225+
sent.add(nodeRequest(node, shardIds));
225226
if (node1.equals(node) && failed.compareAndSet(false, true)) {
226227
runWithDelay(() -> listener.onFailure(new IOException("test request level failure"), true));
227228
} else {
@@ -232,37 +233,35 @@ public void testDoNotRetryOnRequestLevelFailure() {
232233
assertNotNull(ExceptionsHelper.unwrap(exception, IOException.class));
233234
// one round: {node-1, node-2}
234235
assertThat(sent.size(), equalTo(2));
235-
var firstRound = groupRequests(sent, 2);
236-
assertThat(firstRound, equalTo(Map.of(node1, List.of(shard1, shard3), node2, List.of(shard2))));
236+
assertThat(sent, containsInAnyOrder(nodeRequest(node1, shard1, shard3), nodeRequest(node2, shard2)));
237237
}
238238

239239
public void testAllowPartialResults() {
240240
var targetShards = List.of(targetShard(shard1, node1), targetShard(shard2, node2), targetShard(shard3, node1, node2));
241241
Queue<NodeRequest> sent = ConcurrentCollections.newQueue();
242242
AtomicBoolean failed = new AtomicBoolean();
243243
var future = sendRequests(targetShards, true, -1, (node, shardIds, aliasFilters, listener) -> {
244-
sent.add(new NodeRequest(node, shardIds, aliasFilters));
244+
sent.add(nodeRequest(node, shardIds));
245245
if (node1.equals(node) && failed.compareAndSet(false, true)) {
246246
runWithDelay(() -> listener.onFailure(new IOException("test request level failure"), true));
247247
} else {
248248
runWithDelay(() -> listener.onResponse(new DataNodeComputeResponse(List.of(), Map.of())));
249249
}
250250
});
251-
ComputeResponse resp = safeGet(future);
251+
var response = safeGet(future);
252+
assertThat(response.totalShards, equalTo(3));
253+
assertThat(response.failedShards, equalTo(2));
254+
assertThat(response.successfulShards, equalTo(1));
252255
// one round: {node-1, node-2}
253256
assertThat(sent.size(), equalTo(2));
254-
var firstRound = groupRequests(sent, 2);
255-
assertThat(firstRound, equalTo(Map.of(node1, List.of(shard1, shard3), node2, List.of(shard2))));
256-
assertThat(resp.totalShards, equalTo(3));
257-
assertThat(resp.failedShards, equalTo(2));
258-
assertThat(resp.successfulShards, equalTo(1));
257+
assertThat(sent, containsInAnyOrder(nodeRequest(node1, shard1, shard3), nodeRequest(node2, shard2)));
259258
}
260259

261260
public void testNonFatalErrorIsRetriedOnAnotherShard() {
262261
var targetShards = List.of(targetShard(shard1, node1, node2));
263262
var sent = ConcurrentCollections.<NodeRequest>newQueue();
264263
var response = safeGet(sendRequests(targetShards, false, -1, (node, shardIds, aliasFilters, listener) -> {
265-
sent.add(new NodeRequest(node, shardIds, aliasFilters));
264+
sent.add(nodeRequest(node, shardIds));
266265
if (Objects.equals(node1, node)) {
267266
runWithDelay(() -> listener.onFailure(new RuntimeException("test request level non fatal failure"), false));
268267
} else {
@@ -279,7 +278,7 @@ public void testNonFatalFailedOnAllNodes() {
279278
var targetShards = List.of(targetShard(shard1, node1, node2));
280279
var sent = ConcurrentCollections.<NodeRequest>newQueue();
281280
var future = sendRequests(targetShards, false, -1, (node, shardIds, aliasFilters, listener) -> {
282-
sent.add(new NodeRequest(node, shardIds, aliasFilters));
281+
sent.add(nodeRequest(node, shardIds));
283282
runWithDelay(() -> listener.onFailure(new RuntimeException("test request level non fatal failure"), false));
284283
});
285284
expectThrows(RuntimeException.class, equalTo("test request level non fatal failure"), future::actionGet);
@@ -290,7 +289,7 @@ public void testDoNotRetryCircuitBreakerException() {
290289
var targetShards = List.of(targetShard(shard1, node1, node2));
291290
var sent = ConcurrentCollections.<NodeRequest>newQueue();
292291
var future = sendRequests(targetShards, false, -1, (node, shardIds, aliasFilters, listener) -> {
293-
sent.add(new NodeRequest(node, shardIds, aliasFilters));
292+
sent.add(nodeRequest(node, shardIds));
294293
runWithDelay(() -> listener.onFailure(new CircuitBreakingException("cbe", randomFrom(Durability.values())), false));
295294
});
296295
expectThrows(CircuitBreakingException.class, equalTo("cbe"), future::actionGet);
@@ -321,7 +320,7 @@ public void testLimitConcurrentNodes() {
321320
}
322321
}
323322

324-
sent.add(new NodeRequest(node, shardIds, aliasFilters));
323+
sent.add(nodeRequest(node, shardIds));
325324
runWithDelay(() -> {
326325
concurrentRequests.decrementAndGet();
327326
listener.onResponse(new DataNodeComputeResponse(List.of(), Map.of()));
@@ -364,7 +363,7 @@ public void testSkipRemovesPriorNonFatalErrors() {
364363

365364
var sent = ConcurrentCollections.<NodeRequest>newQueue();
366365
var response = safeGet(sendRequests(targetShards, randomBoolean(), 1, (node, shardIds, aliasFilters, listener) -> {
367-
sent.add(new NodeRequest(node, shardIds, aliasFilters));
366+
sent.add(nodeRequest(node, shardIds));
368367
runWithDelay(() -> {
369368
if (Objects.equals(node.getId(), node1.getId()) && shardIds.equals(List.of(shard1))) {
370369
listener.onFailure(new RuntimeException("test request level non fatal failure"), false);
@@ -406,29 +405,38 @@ public void testQueryHotShardsFirstWhenIlmMovesShard() {
406405
);
407406
var sent = ConcurrentCollections.<NodeRequest>newQueue();
408407
safeGet(sendRequests(targetShards, randomBoolean(), -1, (node, shardIds, aliasFilters, listener) -> {
409-
sent.add(new NodeRequest(node, shardIds, aliasFilters));
408+
sent.add(nodeRequest(node, shardIds));
410409
runWithDelay(() -> listener.onResponse(new DataNodeComputeResponse(List.of(), Map.of())));
411410
}));
412-
assertThat(groupRequests(sent, 1), equalTo(Map.of(node1, List.of(shard1))));
413-
assertThat(groupRequests(sent, 1), anyOf(equalTo(Map.of(node2, List.of(shard2))), equalTo(Map.of(warmNode2, List.of(shard2)))));
411+
assertThat(take(sent, 1), containsInAnyOrder(nodeRequest(node1, shard1)));
412+
assertThat(take(sent, 1), anyOf(contains(nodeRequest(node2, shard2)), contains(nodeRequest(warmNode2, shard2))));
414413
}
415414

416415
static DataNodeRequestSender.TargetShard targetShard(ShardId shardId, DiscoveryNode... nodes) {
417416
return new DataNodeRequestSender.TargetShard(shardId, new ArrayList<>(Arrays.asList(nodes)), null);
418417
}
419418

420-
static Map<DiscoveryNode, List<ShardId>> groupRequests(Queue<NodeRequest> sent, int limit) {
421-
Map<DiscoveryNode, List<ShardId>> map = new HashMap<>();
419+
static DataNodeRequestSender.NodeRequest nodeRequest(DiscoveryNode node, ShardId... shardIds) {
420+
return nodeRequest(node, Arrays.asList(shardIds));
421+
}
422+
423+
static DataNodeRequestSender.NodeRequest nodeRequest(DiscoveryNode node, List<ShardId> shardIds) {
424+
var copy = new ArrayList<>(shardIds);
425+
Collections.sort(copy);
426+
return new NodeRequest(node, copy, Map.of());
427+
}
428+
429+
static <T> Collection<T> take(Queue<T> queue, int limit) {
430+
var result = new ArrayList<T>(limit);
422431
for (int i = 0; i < limit; i++) {
423-
NodeRequest r = sent.remove();
424-
assertNull(map.put(r.node(), r.shardIds().stream().sorted().toList()));
432+
result.add(queue.remove());
425433
}
426-
return map;
434+
return result;
427435
}
428436

429437
void runWithDelay(Runnable runnable) {
430438
if (randomBoolean()) {
431-
threadPool.schedule(runnable, TimeValue.timeValueNanos(between(0, 5000)), executor);
439+
threadPool.schedule(runnable, timeValueNanos(between(0, 5000)), executor);
432440
} else {
433441
executor.execute(runnable);
434442
}
@@ -465,8 +473,6 @@ PlainActionFuture<ComputeResponse> sendRequests(
465473
) {
466474
@Override
467475
void searchShards(
468-
Task parentTask,
469-
String clusterAlias,
470476
QueryBuilder filter,
471477
Set<String> concreteIndices,
472478
OriginalIndices originalIndices,
@@ -477,7 +483,6 @@ void searchShards(
477483
shards.size(),
478484
0
479485
);
480-
assertSame(parentTask, task);
481486
runWithDelay(() -> listener.onResponse(targetShards));
482487
}
483488

@@ -492,7 +497,6 @@ protected void sendRequest(
492497
}
493498
};
494499
requestSender.startComputeOnDataNodes(
495-
"",
496500
Set.of(randomAlphaOfLength(10)),
497501
new OriginalIndices(new String[0], SearchRequest.DEFAULT_INDICES_OPTIONS),
498502
null,

0 commit comments

Comments
 (0)