Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.internal.AliasFilter;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportRequestOptions;
Expand Down Expand Up @@ -104,15 +103,14 @@ abstract class DataNodeRequestSender {
}

final void startComputeOnDataNodes(
String clusterAlias,
Set<String> concreteIndices,
OriginalIndices originalIndices,
QueryBuilder requestFilter,
Runnable runOnTaskFailure,
ActionListener<ComputeResponse> listener
) {
final long startTimeInNanos = System.nanoTime();
searchShards(rootTask, clusterAlias, requestFilter, concreteIndices, originalIndices, ActionListener.wrap(targetShards -> {
searchShards(requestFilter, concreteIndices, originalIndices, ActionListener.wrap(targetShards -> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both `rootTask` and `clusterAlias` are already available as fields, so there is no need to pass them as parameters.

try (var computeListener = new ComputeListener(transportService.getThreadPool(), runOnTaskFailure, listener.map(profiles -> {
return new ComputeResponse(
profiles,
Expand Down Expand Up @@ -321,7 +319,7 @@ private void trackShardLevelFailure(ShardId shardId, boolean fatal, Exception or
}

/**
* Result from {@link #searchShards(Task, String, QueryBuilder, Set, OriginalIndices, ActionListener)} where can_match is performed to
* Result from {@link #searchShards(QueryBuilder, Set, OriginalIndices, ActionListener)} where can_match is performed to
* determine what shards can be skipped and which target nodes are needed for running the ES|QL query
*
* @param shards List of target shards to perform the ES|QL query on
Expand Down Expand Up @@ -412,8 +410,6 @@ private List<NodeRequest> selectNodeRequests(TargetShards targetShards) {
* to a situation where the column structure (i.e., matched data types) differs depending on the query.
*/
void searchShards(
Task parentTask,
String clusterAlias,
QueryBuilder filter,
Set<String> concreteIndices,
OriginalIndices originalIndices,
Expand Down Expand Up @@ -459,7 +455,7 @@ void searchShards(
transportService.getLocalNode(),
EsqlSearchShardsAction.TYPE.name(),
searchShardsRequest,
parentTask,
rootTask,
TransportRequestOptions.EMPTY,
new ActionListenerResponseHandler<>(searchShardsListener, SearchShardsResponse::new, esqlExecutor)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.compute.test.ComputeTestCase;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.shard.ShardId;
Expand All @@ -41,6 +40,7 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
Expand All @@ -59,8 +59,11 @@
import static org.elasticsearch.cluster.node.DiscoveryNodeRole.DATA_FROZEN_NODE_ROLE;
import static org.elasticsearch.cluster.node.DiscoveryNodeRole.DATA_HOT_NODE_ROLE;
import static org.elasticsearch.cluster.node.DiscoveryNodeRole.DATA_WARM_NODE_ROLE;
import static org.elasticsearch.core.TimeValue.timeValueNanos;
import static org.elasticsearch.xpack.esql.plugin.DataNodeRequestSender.NodeRequest;
import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
Expand Down Expand Up @@ -120,12 +123,12 @@ public void testOnePass() {
);
Queue<NodeRequest> sent = ConcurrentCollections.newQueue();
var future = sendRequests(targetShards, randomBoolean(), -1, (node, shardIds, aliasFilters, listener) -> {
sent.add(new NodeRequest(node, shardIds, aliasFilters));
sent.add(nodeRequest(node, shardIds));
runWithDelay(() -> listener.onResponse(new DataNodeComputeResponse(List.of(), Map.of())));
});
safeGet(future);
assertThat(sent.size(), equalTo(2));
assertThat(groupRequests(sent, 2), equalTo(Map.of(node1, List.of(shard1, shard3), node2, List.of(shard2, shard4))));
assertThat(sent, containsInAnyOrder(nodeRequest(node1, shard1, shard3), nodeRequest(node2, shard2, shard4)));
}

public void testMissingShards() {
Expand Down Expand Up @@ -163,7 +166,7 @@ public void testRetryThenSuccess() {
);
Queue<NodeRequest> sent = ConcurrentCollections.newQueue();
var future = sendRequests(targetShards, randomBoolean(), -1, (node, shardIds, aliasFilters, listener) -> {
sent.add(new NodeRequest(node, shardIds, aliasFilters));
sent.add(nodeRequest(node, shardIds));
Map<ShardId, Exception> failures = new HashMap<>();
if (node.equals(node1) && shardIds.contains(shard5)) {
failures.put(shard5, new IOException("test"));
Expand All @@ -179,10 +182,11 @@ public void testRetryThenSuccess() {
throw new AssertionError(e);
}
assertThat(sent, hasSize(5));
var firstRound = groupRequests(sent, 3);
assertThat(firstRound, equalTo(Map.of(node1, List.of(shard1, shard5), node4, List.of(shard2), node2, List.of(shard3, shard4))));
var secondRound = groupRequests(sent, 2);
assertThat(secondRound, equalTo(Map.of(node2, List.of(shard2), node3, List.of(shard5))));
assertThat(
take(sent, 3),
containsInAnyOrder(nodeRequest(node1, shard1, shard5), nodeRequest(node4, shard2), nodeRequest(node2, shard3, shard4))
);
assertThat(take(sent, 2), containsInAnyOrder(nodeRequest(node2, shard2), nodeRequest(node3, shard5)));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's compare domain objects rather than maps of lists.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW, you can use `matchesMap()` and get lovely error messages, if you want. But I have no objection to comparing in whatever way you like.

}

public void testRetryButFail() {
Expand All @@ -195,7 +199,7 @@ public void testRetryButFail() {
);
Queue<NodeRequest> sent = ConcurrentCollections.newQueue();
var future = sendRequests(targetShards, false, -1, (node, shardIds, aliasFilters, listener) -> {
sent.add(new NodeRequest(node, shardIds, aliasFilters));
sent.add(nodeRequest(node, shardIds));
Map<ShardId, Exception> failures = new HashMap<>();
if (shardIds.contains(shard5)) {
failures.put(shard5, new IOException("test failure for shard5"));
Expand All @@ -206,22 +210,20 @@ public void testRetryButFail() {
assertNotNull(ExceptionsHelper.unwrap(error, IOException.class));
// {node-1, node-2, node-4}, {node-3}, {node-2}
assertThat(sent.size(), equalTo(5));
var firstRound = groupRequests(sent, 3);
assertThat(firstRound, equalTo(Map.of(node1, List.of(shard1, shard5), node2, List.of(shard3, shard4), node4, List.of(shard2))));
NodeRequest fourth = sent.remove();
assertThat(fourth.node(), equalTo(node3));
assertThat(fourth.shardIds(), equalTo(List.of(shard5)));
NodeRequest fifth = sent.remove();
assertThat(fifth.node(), equalTo(node2));
assertThat(fifth.shardIds(), equalTo(List.of(shard5)));
assertThat(
take(sent, 3),
containsInAnyOrder(nodeRequest(node1, shard1, shard5), nodeRequest(node2, shard3, shard4), nodeRequest(node4, shard2))
);
assertThat(take(sent, 1), containsInAnyOrder(nodeRequest(node3, shard5)));
assertThat(take(sent, 1), containsInAnyOrder(nodeRequest(node2, shard5)));
}

public void testDoNotRetryOnRequestLevelFailure() {
var targetShards = List.of(targetShard(shard1, node1), targetShard(shard2, node2), targetShard(shard3, node1));
Queue<NodeRequest> sent = ConcurrentCollections.newQueue();
AtomicBoolean failed = new AtomicBoolean();
var future = sendRequests(targetShards, false, -1, (node, shardIds, aliasFilters, listener) -> {
sent.add(new NodeRequest(node, shardIds, aliasFilters));
sent.add(nodeRequest(node, shardIds));
if (node1.equals(node) && failed.compareAndSet(false, true)) {
runWithDelay(() -> listener.onFailure(new IOException("test request level failure"), true));
} else {
Expand All @@ -232,37 +234,35 @@ public void testDoNotRetryOnRequestLevelFailure() {
assertNotNull(ExceptionsHelper.unwrap(exception, IOException.class));
// one round: {node-1, node-2}
assertThat(sent.size(), equalTo(2));
var firstRound = groupRequests(sent, 2);
assertThat(firstRound, equalTo(Map.of(node1, List.of(shard1, shard3), node2, List.of(shard2))));
assertThat(sent, containsInAnyOrder(nodeRequest(node1, shard1, shard3), nodeRequest(node2, shard2)));
}

public void testAllowPartialResults() {
var targetShards = List.of(targetShard(shard1, node1), targetShard(shard2, node2), targetShard(shard3, node1, node2));
Queue<NodeRequest> sent = ConcurrentCollections.newQueue();
AtomicBoolean failed = new AtomicBoolean();
var future = sendRequests(targetShards, true, -1, (node, shardIds, aliasFilters, listener) -> {
sent.add(new NodeRequest(node, shardIds, aliasFilters));
sent.add(nodeRequest(node, shardIds));
if (node1.equals(node) && failed.compareAndSet(false, true)) {
runWithDelay(() -> listener.onFailure(new IOException("test request level failure"), true));
} else {
runWithDelay(() -> listener.onResponse(new DataNodeComputeResponse(List.of(), Map.of())));
}
});
ComputeResponse resp = safeGet(future);
var response = safeGet(future);
assertThat(response.totalShards, equalTo(3));
assertThat(response.failedShards, equalTo(2));
assertThat(response.successfulShards, equalTo(1));
// one round: {node-1, node-2}
assertThat(sent.size(), equalTo(2));
var firstRound = groupRequests(sent, 2);
assertThat(firstRound, equalTo(Map.of(node1, List.of(shard1, shard3), node2, List.of(shard2))));
assertThat(resp.totalShards, equalTo(3));
assertThat(resp.failedShards, equalTo(2));
assertThat(resp.successfulShards, equalTo(1));
assertThat(sent, containsInAnyOrder(nodeRequest(node1, shard1, shard3), nodeRequest(node2, shard2)));
}

public void testNonFatalErrorIsRetriedOnAnotherShard() {
var targetShards = List.of(targetShard(shard1, node1, node2));
var sent = ConcurrentCollections.<NodeRequest>newQueue();
var response = safeGet(sendRequests(targetShards, false, -1, (node, shardIds, aliasFilters, listener) -> {
sent.add(new NodeRequest(node, shardIds, aliasFilters));
sent.add(nodeRequest(node, shardIds));
if (Objects.equals(node1, node)) {
runWithDelay(() -> listener.onFailure(new RuntimeException("test request level non fatal failure"), false));
} else {
Expand All @@ -279,7 +279,7 @@ public void testNonFatalFailedOnAllNodes() {
var targetShards = List.of(targetShard(shard1, node1, node2));
var sent = ConcurrentCollections.<NodeRequest>newQueue();
var future = sendRequests(targetShards, false, -1, (node, shardIds, aliasFilters, listener) -> {
sent.add(new NodeRequest(node, shardIds, aliasFilters));
sent.add(nodeRequest(node, shardIds));
runWithDelay(() -> listener.onFailure(new RuntimeException("test request level non fatal failure"), false));
});
expectThrows(RuntimeException.class, equalTo("test request level non fatal failure"), future::actionGet);
Expand All @@ -290,7 +290,7 @@ public void testDoNotRetryCircuitBreakerException() {
var targetShards = List.of(targetShard(shard1, node1, node2));
var sent = ConcurrentCollections.<NodeRequest>newQueue();
var future = sendRequests(targetShards, false, -1, (node, shardIds, aliasFilters, listener) -> {
sent.add(new NodeRequest(node, shardIds, aliasFilters));
sent.add(nodeRequest(node, shardIds));
runWithDelay(() -> listener.onFailure(new CircuitBreakingException("cbe", randomFrom(Durability.values())), false));
});
expectThrows(CircuitBreakingException.class, equalTo("cbe"), future::actionGet);
Expand Down Expand Up @@ -321,7 +321,7 @@ public void testLimitConcurrentNodes() {
}
}

sent.add(new NodeRequest(node, shardIds, aliasFilters));
sent.add(nodeRequest(node, shardIds));
runWithDelay(() -> {
concurrentRequests.decrementAndGet();
listener.onResponse(new DataNodeComputeResponse(List.of(), Map.of()));
Expand Down Expand Up @@ -364,7 +364,7 @@ public void testSkipRemovesPriorNonFatalErrors() {

var sent = ConcurrentCollections.<NodeRequest>newQueue();
var response = safeGet(sendRequests(targetShards, randomBoolean(), 1, (node, shardIds, aliasFilters, listener) -> {
sent.add(new NodeRequest(node, shardIds, aliasFilters));
sent.add(nodeRequest(node, shardIds));
runWithDelay(() -> {
if (Objects.equals(node.getId(), node1.getId()) && shardIds.equals(List.of(shard1))) {
listener.onFailure(new RuntimeException("test request level non fatal failure"), false);
Expand Down Expand Up @@ -406,29 +406,38 @@ public void testQueryHotShardsFirstWhenIlmMovesShard() {
);
var sent = ConcurrentCollections.<NodeRequest>newQueue();
safeGet(sendRequests(targetShards, randomBoolean(), -1, (node, shardIds, aliasFilters, listener) -> {
sent.add(new NodeRequest(node, shardIds, aliasFilters));
sent.add(nodeRequest(node, shardIds));
runWithDelay(() -> listener.onResponse(new DataNodeComputeResponse(List.of(), Map.of())));
}));
assertThat(groupRequests(sent, 1), equalTo(Map.of(node1, List.of(shard1))));
assertThat(groupRequests(sent, 1), anyOf(equalTo(Map.of(node2, List.of(shard2))), equalTo(Map.of(warmNode2, List.of(shard2)))));
assertThat(take(sent, 1), containsInAnyOrder(nodeRequest(node1, shard1)));
assertThat(take(sent, 1), anyOf(contains(nodeRequest(node2, shard2)), contains(nodeRequest(warmNode2, shard2))));
}

/**
 * Test helper that builds a {@link DataNodeRequestSender.TargetShard} for the given shard.
 *
 * @param shardId the shard the target describes
 * @param nodes   the nodes passed to the target, in the order given
 * @return a target shard whose node list is a fresh, resizable copy of {@code nodes}
 */
static DataNodeRequestSender.TargetShard targetShard(ShardId shardId, DiscoveryNode... nodes) {
    // Arrays.asList alone is a fixed-size view over the varargs array; copy it into an ArrayList
    // so the resulting list is independently mutable.
    List<DiscoveryNode> candidateNodes = new ArrayList<>(Arrays.asList(nodes));
    // The third constructor argument is deliberately left null
    // (NOTE(review): presumably the alias filter - confirm against TargetShard).
    return new DataNodeRequestSender.TargetShard(shardId, candidateNodes, null);
}

static Map<DiscoveryNode, List<ShardId>> groupRequests(Queue<NodeRequest> sent, int limit) {
Map<DiscoveryNode, List<ShardId>> map = new HashMap<>();
/**
 * Varargs convenience overload: wraps {@code shardIds} in a list and delegates to
 * {@link #nodeRequest(DiscoveryNode, List)}.
 *
 * @param node     the node the request is addressed to
 * @param shardIds the shards carried by the request
 * @return the request built by the list-based overload
 */
static DataNodeRequestSender.NodeRequest nodeRequest(DiscoveryNode node, ShardId... shardIds) {
    List<ShardId> asList = Arrays.asList(shardIds);
    return nodeRequest(node, asList);
}

/**
 * Builds a {@code NodeRequest} for {@code node} with a sorted copy of {@code shardIds} and an
 * empty alias-filter map.
 * <p>
 * The shard ids are sorted so that two requests for the same node and the same set of shards
 * compare equal regardless of the order in which the shards were supplied. The caller's list is
 * never modified.
 *
 * @param node     the node the request is addressed to
 * @param shardIds the shards carried by the request, in any order
 * @return a request holding the node, the naturally ordered shard ids and {@code Map.of()}
 */
static DataNodeRequestSender.NodeRequest nodeRequest(DiscoveryNode node, List<ShardId> shardIds) {
    // Sort into a new resizable list instead of sorting a copy in place; the input stays untouched.
    List<ShardId> orderedShardIds = new ArrayList<>(shardIds.stream().sorted().toList());
    return new NodeRequest(node, orderedShardIds, Map.of());
}

static <T> Collection<T> take(Queue<T> queue, int limit) {
var result = new ArrayList<T>(limit);
for (int i = 0; i < limit; i++) {
NodeRequest r = sent.remove();
assertNull(map.put(r.node(), r.shardIds().stream().sorted().toList()));
result.add(queue.remove());
}
return map;
return result;
}

void runWithDelay(Runnable runnable) {
if (randomBoolean()) {
threadPool.schedule(runnable, TimeValue.timeValueNanos(between(0, 5000)), executor);
threadPool.schedule(runnable, timeValueNanos(between(0, 5000)), executor);
} else {
executor.execute(runnable);
}
Expand Down