diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeComputeHandler.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeComputeHandler.java
index bea9c7b7a5db9..d8aa9ea3e258a 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeComputeHandler.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeComputeHandler.java
@@ -199,7 +199,6 @@ protected void sendRequest(
                 );
             }
         }.startComputeOnDataNodes(
-            clusterAlias,
             concreteIndices,
             originalIndices,
             PlannerUtils.canMatchFilter(dataNodePlan),
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSender.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSender.java
index 7abc0ba40af76..d2ba7ddc11792 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSender.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSender.java
@@ -30,7 +30,6 @@
 import org.elasticsearch.search.SearchShardTarget;
 import org.elasticsearch.search.internal.AliasFilter;
 import org.elasticsearch.tasks.CancellableTask;
-import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.TaskCancelledException;
 import org.elasticsearch.transport.TransportException;
 import org.elasticsearch.transport.TransportRequestOptions;
@@ -104,7 +103,6 @@ abstract class DataNodeRequestSender {
     }
 
     final void startComputeOnDataNodes(
-        String clusterAlias,
         Set<String> concreteIndices,
         OriginalIndices originalIndices,
         QueryBuilder requestFilter,
@@ -112,7 +110,7 @@ final void startComputeOnDataNodes(
         ActionListener<ComputeResponse> listener
     ) {
         final long startTimeInNanos = System.nanoTime();
-        searchShards(rootTask, clusterAlias, requestFilter, concreteIndices, originalIndices, ActionListener.wrap(targetShards -> {
+        searchShards(requestFilter, concreteIndices, originalIndices, ActionListener.wrap(targetShards -> {
             try (var computeListener = new ComputeListener(transportService.getThreadPool(), runOnTaskFailure, listener.map(profiles -> {
                 return new ComputeResponse(
                     profiles,
@@ -321,7 +319,7 @@ private void trackShardLevelFailure(ShardId shardId, boolean fatal, Exception or
     }
 
     /**
-     * Result from {@link #searchShards(Task, String, QueryBuilder, Set, OriginalIndices, ActionListener)} where can_match is performed to
+     * Result from {@link #searchShards(QueryBuilder, Set, OriginalIndices, ActionListener)} where can_match is performed to
      * determine what shards can be skipped and which target nodes are needed for running the ES|QL query
      *
     * @param shards List of target shards to perform the ES|QL query on
@@ -412,8 +410,6 @@ private List<NodeRequest> selectNodeRequests(TargetShards targetShards) {
     * to a situation where the column structure (i.e., matched data types) differs depending on the query.
     */
    void searchShards(
-        Task parentTask,
-        String clusterAlias,
         QueryBuilder filter,
         Set<String> concreteIndices,
         OriginalIndices originalIndices,
@@ -459,7 +455,7 @@ void searchShards(
             transportService.getLocalNode(),
             EsqlSearchShardsAction.TYPE.name(),
             searchShardsRequest,
-            parentTask,
+            rootTask,
             TransportRequestOptions.EMPTY,
             new ActionListenerResponseHandler<>(searchShardsListener, SearchShardsResponse::new, esqlExecutor)
         );
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSenderTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSenderTests.java
index 92c77f7bd47c7..0ca5d8f79ca8d 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSenderTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSenderTests.java
@@ -23,13 +23,11 @@
 import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.compute.test.ComputeTestCase;
-import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.index.Index;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.search.internal.AliasFilter;
 import org.elasticsearch.tasks.CancellableTask;
-import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.test.transport.MockTransportService;
 import org.elasticsearch.threadpool.FixedExecutorBuilder;
@@ -41,6 +39,7 @@
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
@@ -59,8 +58,11 @@
 import static org.elasticsearch.cluster.node.DiscoveryNodeRole.DATA_FROZEN_NODE_ROLE;
 import static org.elasticsearch.cluster.node.DiscoveryNodeRole.DATA_HOT_NODE_ROLE;
 import static org.elasticsearch.cluster.node.DiscoveryNodeRole.DATA_WARM_NODE_ROLE;
+import static org.elasticsearch.core.TimeValue.timeValueNanos;
 import static org.elasticsearch.xpack.esql.plugin.DataNodeRequestSender.NodeRequest;
 import static org.hamcrest.Matchers.anyOf;
+import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.equalTo;
@@ -120,12 +122,12 @@ public void testOnePass() {
         );
         Queue<NodeRequest> sent = ConcurrentCollections.newQueue();
         var future = sendRequests(targetShards, randomBoolean(), -1, (node, shardIds, aliasFilters, listener) -> {
-            sent.add(new NodeRequest(node, shardIds, aliasFilters));
+            sent.add(nodeRequest(node, shardIds));
             runWithDelay(() -> listener.onResponse(new DataNodeComputeResponse(List.of(), Map.of())));
         });
         safeGet(future);
         assertThat(sent.size(), equalTo(2));
-        assertThat(groupRequests(sent, 2), equalTo(Map.of(node1, List.of(shard1, shard3), node2, List.of(shard2, shard4))));
+        assertThat(sent, containsInAnyOrder(nodeRequest(node1, shard1, shard3), nodeRequest(node2, shard2, shard4)));
     }
 
     public void testMissingShards() {
@@ -163,7 +165,7 @@ public void testRetryThenSuccess() {
         );
         Queue<NodeRequest> sent = ConcurrentCollections.newQueue();
         var future = sendRequests(targetShards, randomBoolean(), -1, (node, shardIds, aliasFilters, listener) -> {
-            sent.add(new NodeRequest(node, shardIds, aliasFilters));
+            sent.add(nodeRequest(node, shardIds));
             Map<ShardId, Exception> failures = new HashMap<>();
             if (node.equals(node1) && shardIds.contains(shard5)) {
                 failures.put(shard5, new IOException("test"));
             }
@@ -179,10 +181,11 @@
             throw new AssertionError(e);
         }
         assertThat(sent, hasSize(5));
-        var firstRound = groupRequests(sent, 3);
-        assertThat(firstRound, equalTo(Map.of(node1, List.of(shard1, shard5), node4, List.of(shard2), node2, List.of(shard3, shard4))));
-        var secondRound = groupRequests(sent, 2);
-        assertThat(secondRound, equalTo(Map.of(node2, List.of(shard2), node3, List.of(shard5))));
+        assertThat(
+            take(sent, 3),
+            containsInAnyOrder(nodeRequest(node1, shard1, shard5), nodeRequest(node4, shard2), nodeRequest(node2, shard3, shard4))
+        );
+        assertThat(take(sent, 2), containsInAnyOrder(nodeRequest(node2, shard2), nodeRequest(node3, shard5)));
     }
 
     public void testRetryButFail() {
@@ -195,7 +198,7 @@
         );
         Queue<NodeRequest> sent = ConcurrentCollections.newQueue();
         var future = sendRequests(targetShards, false, -1, (node, shardIds, aliasFilters, listener) -> {
-            sent.add(new NodeRequest(node, shardIds, aliasFilters));
+            sent.add(nodeRequest(node, shardIds));
             Map<ShardId, Exception> failures = new HashMap<>();
             if (shardIds.contains(shard5)) {
                 failures.put(shard5, new IOException("test failure for shard5"));
             }
@@ -206,14 +209,12 @@
         assertNotNull(ExceptionsHelper.unwrap(error, IOException.class));
         // {node-1, node-2, node-4}, {node-3}, {node-2}
         assertThat(sent.size(), equalTo(5));
-        var firstRound = groupRequests(sent, 3);
-        assertThat(firstRound, equalTo(Map.of(node1, List.of(shard1, shard5), node2, List.of(shard3, shard4), node4, List.of(shard2))));
-        NodeRequest fourth = sent.remove();
-        assertThat(fourth.node(), equalTo(node3));
-        assertThat(fourth.shardIds(), equalTo(List.of(shard5)));
-        NodeRequest fifth = sent.remove();
-        assertThat(fifth.node(), equalTo(node2));
-        assertThat(fifth.shardIds(), equalTo(List.of(shard5)));
+        assertThat(
+            take(sent, 3),
+            containsInAnyOrder(nodeRequest(node1, shard1, shard5), nodeRequest(node2, shard3, shard4), nodeRequest(node4, shard2))
+        );
+        assertThat(take(sent, 1), containsInAnyOrder(nodeRequest(node3, shard5)));
+        assertThat(take(sent, 1), containsInAnyOrder(nodeRequest(node2, shard5)));
     }
 
     public void testDoNotRetryOnRequestLevelFailure() {
@@ -221,7 +222,7 @@
         Queue<NodeRequest> sent = ConcurrentCollections.newQueue();
         AtomicBoolean failed = new AtomicBoolean();
         var future = sendRequests(targetShards, false, -1, (node, shardIds, aliasFilters, listener) -> {
-            sent.add(new NodeRequest(node, shardIds, aliasFilters));
+            sent.add(nodeRequest(node, shardIds));
             if (node1.equals(node) && failed.compareAndSet(false, true)) {
                 runWithDelay(() -> listener.onFailure(new IOException("test request level failure"), true));
             } else {
@@ -232,8 +233,7 @@
         assertNotNull(ExceptionsHelper.unwrap(exception, IOException.class));
         // one round: {node-1, node-2}
         assertThat(sent.size(), equalTo(2));
-        var firstRound = groupRequests(sent, 2);
-        assertThat(firstRound, equalTo(Map.of(node1, List.of(shard1, shard3), node2, List.of(shard2))));
+        assertThat(sent, containsInAnyOrder(nodeRequest(node1, shard1, shard3), nodeRequest(node2, shard2)));
     }
 
     public void testAllowPartialResults() {
@@ -241,28 +241,27 @@
         Queue<NodeRequest> sent = ConcurrentCollections.newQueue();
         AtomicBoolean failed = new AtomicBoolean();
         var future = sendRequests(targetShards, true, -1, (node, shardIds, aliasFilters, listener) -> {
-            sent.add(new NodeRequest(node, shardIds, aliasFilters));
+            sent.add(nodeRequest(node, shardIds));
             if (node1.equals(node) && failed.compareAndSet(false, true)) {
                 runWithDelay(() -> listener.onFailure(new IOException("test request level failure"), true));
             } else {
                 runWithDelay(() -> listener.onResponse(new DataNodeComputeResponse(List.of(), Map.of())));
             }
         });
-        ComputeResponse resp = safeGet(future);
+        var response = safeGet(future);
+        assertThat(response.totalShards, equalTo(3));
+        assertThat(response.failedShards, equalTo(2));
+        assertThat(response.successfulShards, equalTo(1));
         // one round: {node-1, node-2}
         assertThat(sent.size(), equalTo(2));
-        var firstRound = groupRequests(sent, 2);
-        assertThat(firstRound, equalTo(Map.of(node1, List.of(shard1, shard3), node2, List.of(shard2))));
-        assertThat(resp.totalShards, equalTo(3));
-        assertThat(resp.failedShards, equalTo(2));
-        assertThat(resp.successfulShards, equalTo(1));
+        assertThat(sent, containsInAnyOrder(nodeRequest(node1, shard1, shard3), nodeRequest(node2, shard2)));
     }
 
     public void testNonFatalErrorIsRetriedOnAnotherShard() {
         var targetShards = List.of(targetShard(shard1, node1, node2));
         var sent = ConcurrentCollections.newQueue();
         var response = safeGet(sendRequests(targetShards, false, -1, (node, shardIds, aliasFilters, listener) -> {
-            sent.add(new NodeRequest(node, shardIds, aliasFilters));
+            sent.add(nodeRequest(node, shardIds));
             if (Objects.equals(node1, node)) {
                 runWithDelay(() -> listener.onFailure(new RuntimeException("test request level non fatal failure"), false));
             } else {
@@ -279,7 +278,7 @@ public void testNonFatalFailedOnAllNodes() {
         var targetShards = List.of(targetShard(shard1, node1, node2));
         var sent = ConcurrentCollections.newQueue();
         var future = sendRequests(targetShards, false, -1, (node, shardIds, aliasFilters, listener) -> {
-            sent.add(new NodeRequest(node, shardIds, aliasFilters));
+            sent.add(nodeRequest(node, shardIds));
             runWithDelay(() -> listener.onFailure(new RuntimeException("test request level non fatal failure"), false));
         });
         expectThrows(RuntimeException.class, equalTo("test request level non fatal failure"), future::actionGet);
@@ -290,7 +289,7 @@ public void testDoNotRetryCircuitBreakerException() {
         var targetShards = List.of(targetShard(shard1, node1, node2));
         var sent = ConcurrentCollections.newQueue();
         var future = sendRequests(targetShards, false, -1, (node, shardIds, aliasFilters, listener) -> {
-            sent.add(new NodeRequest(node, shardIds, aliasFilters));
+            sent.add(nodeRequest(node, shardIds));
             runWithDelay(() -> listener.onFailure(new CircuitBreakingException("cbe", randomFrom(Durability.values())), false));
         });
         expectThrows(CircuitBreakingException.class, equalTo("cbe"), future::actionGet);
@@ -321,7 +320,7 @@ public void testLimitConcurrentNodes() {
                 }
             }
 
-            sent.add(new NodeRequest(node, shardIds, aliasFilters));
+            sent.add(nodeRequest(node, shardIds));
             runWithDelay(() -> {
                 concurrentRequests.decrementAndGet();
                 listener.onResponse(new DataNodeComputeResponse(List.of(), Map.of()));
@@ -364,7 +363,7 @@ public void testSkipRemovesPriorNonFatalErrors() {
         var sent = ConcurrentCollections.newQueue();
 
         var response = safeGet(sendRequests(targetShards, randomBoolean(), 1, (node, shardIds, aliasFilters, listener) -> {
-            sent.add(new NodeRequest(node, shardIds, aliasFilters));
+            sent.add(nodeRequest(node, shardIds));
             runWithDelay(() -> {
                 if (Objects.equals(node.getId(), node1.getId()) && shardIds.equals(List.of(shard1))) {
                     listener.onFailure(new RuntimeException("test request level non fatal failure"), false);
@@ -406,29 +405,38 @@ public void testQueryHotShardsFirstWhenIlmMovesShard() {
         );
         var sent = ConcurrentCollections.newQueue();
         safeGet(sendRequests(targetShards, randomBoolean(), -1, (node, shardIds, aliasFilters, listener) -> {
-            sent.add(new NodeRequest(node, shardIds, aliasFilters));
+            sent.add(nodeRequest(node, shardIds));
             runWithDelay(() -> listener.onResponse(new DataNodeComputeResponse(List.of(), Map.of())));
         }));
-        assertThat(groupRequests(sent, 1), equalTo(Map.of(node1, List.of(shard1))));
-        assertThat(groupRequests(sent, 1), anyOf(equalTo(Map.of(node2, List.of(shard2))), equalTo(Map.of(warmNode2, List.of(shard2)))));
+        assertThat(take(sent, 1), containsInAnyOrder(nodeRequest(node1, shard1)));
+        assertThat(take(sent, 1), anyOf(contains(nodeRequest(node2, shard2)), contains(nodeRequest(warmNode2, shard2))));
     }
 
     static DataNodeRequestSender.TargetShard targetShard(ShardId shardId, DiscoveryNode... nodes) {
         return new DataNodeRequestSender.TargetShard(shardId, new ArrayList<>(Arrays.asList(nodes)), null);
     }
 
-    static Map<DiscoveryNode, List<ShardId>> groupRequests(Queue<NodeRequest> sent, int limit) {
-        Map<DiscoveryNode, List<ShardId>> map = new HashMap<>();
+    static DataNodeRequestSender.NodeRequest nodeRequest(DiscoveryNode node, ShardId... shardIds) {
+        return nodeRequest(node, Arrays.asList(shardIds));
+    }
+
+    static DataNodeRequestSender.NodeRequest nodeRequest(DiscoveryNode node, List<ShardId> shardIds) {
+        var copy = new ArrayList<>(shardIds);
+        Collections.sort(copy);
+        return new NodeRequest(node, copy, Map.of());
+    }
+
+    static Collection<NodeRequest> take(Queue<NodeRequest> queue, int limit) {
+        var result = new ArrayList<NodeRequest>(limit);
         for (int i = 0; i < limit; i++) {
-            NodeRequest r = sent.remove();
-            assertNull(map.put(r.node(), r.shardIds().stream().sorted().toList()));
+            result.add(queue.remove());
         }
-        return map;
+        return result;
     }
 
     void runWithDelay(Runnable runnable) {
         if (randomBoolean()) {
-            threadPool.schedule(runnable, TimeValue.timeValueNanos(between(0, 5000)), executor);
+            threadPool.schedule(runnable, timeValueNanos(between(0, 5000)), executor);
         } else {
             executor.execute(runnable);
         }
@@ -465,8 +473,6 @@ PlainActionFuture<ComputeResponse> sendRequests(
         ) {
             @Override
             void searchShards(
-                Task parentTask,
-                String clusterAlias,
                 QueryBuilder filter,
                 Set<String> concreteIndices,
                 OriginalIndices originalIndices,
@@ -477,7 +483,6 @@ void searchShards(
                     shards.size(),
                     0
                 );
-                assertSame(parentTask, task);
                 runWithDelay(() -> listener.onResponse(targetShards));
             }
 
@@ -492,7 +497,6 @@ protected void sendRequest(
             }
         };
         requestSender.startComputeOnDataNodes(
-            "",
             Set.of(randomAlphaOfLength(10)),
             new OriginalIndices(new String[0], SearchRequest.DEFAULT_INDICES_OPTIONS),
             null,