diff --git a/docs/changelog/120774.yaml b/docs/changelog/120774.yaml new file mode 100644 index 0000000000000..8157e1725be83 --- /dev/null +++ b/docs/changelog/120774.yaml @@ -0,0 +1,5 @@ +pr: 120774 +summary: Retry ES|QL node requests on shard level failures +area: ES|QL +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 1b4931236d56f..d905787f4d5d7 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -179,6 +179,8 @@ static TransportVersion def(int id) { public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED = def(9_003_0_00); public static final TransportVersion REMOVE_DESIRED_NODE_VERSION = def(9_004_0_00); public static final TransportVersion ESQL_DRIVER_TASK_DESCRIPTION = def(9_005_0_00); + public static final TransportVersion ESQL_RETRY_ON_SHARD_LEVEL_FAILURE = def(9_006_0_00); + /* * STOP! READ THIS FIRST! No, really, * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _ diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FailureCollector.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FailureCollector.java index 337075edbdcf6..c492ba6796350 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FailureCollector.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FailureCollector.java @@ -45,7 +45,7 @@ public FailureCollector(int maxExceptions) { this.nonCancelledExceptionsPermits = new Semaphore(maxExceptions); } - private static Exception unwrapTransportException(TransportException te) { + public static Exception unwrapTransportException(TransportException te) { final Throwable cause = te.getCause(); if (cause == null) { return te; diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlRetryIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlRetryIT.java new file mode 100644 index 0000000000000..05b2211deecb8 --- /dev/null +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlRetryIT.java @@ -0,0 +1,91 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.action; + +import org.elasticsearch.action.index.IndexRequestBuilder; +import org.elasticsearch.index.IndexService; +import org.elasticsearch.index.shard.IndexShard; +import org.elasticsearch.indices.IndicesService; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.test.transport.MockTransportService; +import org.elasticsearch.xpack.esql.EsqlTestUtils; +import org.elasticsearch.xpack.esql.plugin.ComputeService; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.elasticsearch.index.shard.IndexShardTestCase.closeShardNoCheck; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.hamcrest.Matchers.equalTo; + +public class EsqlRetryIT extends AbstractEsqlIntegTestCase { + + @Override + protected Collection<Class<? extends Plugin>> nodePlugins() { + List<Class<? extends Plugin>> plugins = new ArrayList<>(super.nodePlugins()); + plugins.add(MockTransportService.TestPlugin.class); + return plugins; + } + + public void testRetryOnShardFailures() throws Exception { + populateIndices(); + try { + final AtomicBoolean relocated = new AtomicBoolean(); + for (String node : internalCluster().getNodeNames()) { + // fail some target shards while handling the data node request + MockTransportService.getInstance(node) + .addRequestHandlingBehavior(ComputeService.DATA_ACTION_NAME, (handler, request, channel, task) -> { + if (relocated.compareAndSet(false, true)) { + closeOrFailShards(node); + } + handler.messageReceived(request, channel, task); + }); + } + try (var resp = run("FROM log-* | STATS COUNT(timestamp) | LIMIT 1")) { + assertThat(EsqlTestUtils.getValuesList(resp).get(0).get(0), equalTo(7L)); + } + } finally { + for (String node : internalCluster().getNodeNames()) { + MockTransportService.getInstance(node).clearAllRules(); + } + } + } + + private void populateIndices() { + internalCluster().ensureAtLeastNumDataNodes(2); + assertAcked(prepareCreate("log-index-1").setSettings(indexSettings(between(1, 3), 1)).setMapping("timestamp", "type=date")); + assertAcked(prepareCreate("log-index-2").setSettings(indexSettings(between(1, 3), 1)).setMapping("timestamp", "type=date")); + List<IndexRequestBuilder> reqs = new ArrayList<>(); + reqs.add(prepareIndex("log-index-1").setSource("timestamp", "2015-07-08")); + reqs.add(prepareIndex("log-index-1").setSource("timestamp", "2018-07-08")); + reqs.add(prepareIndex("log-index-1").setSource("timestamp", "2020-03-03")); + reqs.add(prepareIndex("log-index-1").setSource("timestamp", "2020-09-09")); + reqs.add(prepareIndex("log-index-2").setSource("timestamp", "2019-10-12")); + reqs.add(prepareIndex("log-index-2").setSource("timestamp", "2020-02-02")); + reqs.add(prepareIndex("log-index-2").setSource("timestamp", "2020-10-10")); + indexRandom(true, reqs); + ensureGreen("log-index-1", "log-index-2"); + indicesAdmin().prepareRefresh("log-index-1", "log-index-2").get(); + } + + private void closeOrFailShards(String nodeName) throws Exception { + final IndicesService indicesService = internalCluster().getInstance(IndicesService.class, nodeName); + for (IndexService indexService : indicesService) { + for (IndexShard indexShard : indexService) { + if (randomBoolean()) { + indexShard.failShard("simulated", new IOException("simulated failure")); + } else if (randomBoolean()) { + closeShardNoCheck(indexShard); + } + } + } + } +} diff --git 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeComputeHandler.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeComputeHandler.java index 40a87fca4dc25..ba2f3c5dfdc2c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeComputeHandler.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeComputeHandler.java @@ -7,13 +7,11 @@ package org.elasticsearch.xpack.esql.plugin; +import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListenerResponseHandler; import org.elasticsearch.action.ActionRunnable; import org.elasticsearch.action.OriginalIndices; -import org.elasticsearch.action.search.SearchShardsGroup; -import org.elasticsearch.action.search.SearchShardsRequest; -import org.elasticsearch.action.search.SearchShardsResponse; import org.elasticsearch.action.support.ChannelActionListener; import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.cluster.node.DiscoveryNode; @@ -24,12 +22,9 @@ import org.elasticsearch.compute.operator.exchange.ExchangeSourceHandler; import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.Releasable; -import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.Index; -import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.index.shard.ShardId; -import org.elasticsearch.index.shard.ShardNotFoundException; import org.elasticsearch.search.SearchService; import org.elasticsearch.search.internal.AliasFilter; import org.elasticsearch.search.internal.SearchContext; @@ -43,7 +38,6 @@ import org.elasticsearch.transport.TransportRequestHandler; import org.elasticsearch.transport.TransportRequestOptions; import org.elasticsearch.transport.TransportService; -import org.elasticsearch.xpack.esql.action.EsqlSearchShardsAction; import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.plan.physical.ExchangeSinkExec; import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; @@ -57,6 +51,9 @@ import java.util.Set; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.xpack.esql.plugin.EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME; @@ -70,6 +67,7 @@ final class DataNodeComputeHandler implements TransportRequestHandler<DataNodeRequest> ActionListener<ComputeResponse> outListener ) { - QueryBuilder requestFilter = PlannerUtils.requestTimestampFilter(dataNodePlan); - var listener = ActionListener.runAfter(outListener, exchangeSource.addEmptySink()::close); - final long startTimeInNanos = System.nanoTime(); - lookupDataNodes(parentTask, clusterAlias, requestFilter, concreteIndices, originalIndices, ActionListener.wrap(dataNodeResult -> { - try (var computeListener = new ComputeListener(transportService.getThreadPool(), runOnTaskFailure, listener.map(profiles -> { - TimeValue took = TimeValue.timeValueNanos(System.nanoTime() - startTimeInNanos); - return new ComputeResponse( - profiles, - took, - dataNodeResult.totalShards(), - dataNodeResult.totalShards(), - dataNodeResult.skippedShards(), - 0 - ); - }))) { + DataNodeRequestSender sender = new DataNodeRequestSender(transportService, esqlExecutor, parentTask) { + @Override + protected void sendRequest(
+ DiscoveryNode node, + List<ShardId> shardIds, + Map<Index, AliasFilter> aliasFilters, + NodeListener nodeListener + ) { + final AtomicLong pagesFetched = new AtomicLong(); + var listener = ActionListener.wrap(nodeListener::onResponse, e -> nodeListener.onFailure(e, pagesFetched.get() > 0)); + final Transport.Connection connection; + try { + connection = transportService.getConnection(node); + } catch (Exception e) { + listener.onFailure(e); + return; + } + var queryPragmas = configuration.pragmas(); + var childSessionId = computeService.newChildSession(sessionId); // For each target node, first open a remote exchange on the remote node, then link the exchange source to // the new remote exchange sink, and initialize the computation on the target node via data-node-request. - for (DataNode node : dataNodeResult.dataNodes()) { - var queryPragmas = configuration.pragmas(); - var childSessionId = computeService.newChildSession(sessionId); - ActionListener<ComputeResponse> nodeListener = computeListener.acquireCompute().map(ComputeResponse::getProfiles); - ExchangeService.openExchange( - transportService, - node.connection, - childSessionId, - queryPragmas.exchangeBufferSize(), - esqlExecutor, - nodeListener.delegateFailureAndWrap((l, unused) -> { - var remoteSink = exchangeService.newRemoteSink(parentTask, childSessionId, transportService, node.connection); + ExchangeService.openExchange( + transportService, + connection, + childSessionId, + queryPragmas.exchangeBufferSize(), + esqlExecutor, + listener.delegateFailureAndWrap((l, unused) -> { + final AtomicReference<DataNodeComputeResponse> nodeResponseRef = new AtomicReference<>(); + try ( + var computeListener = new ComputeListener(threadPool, runOnTaskFailure, l.map(ignored -> nodeResponseRef.get())) + ) { + final var remoteSink = exchangeService.newRemoteSink(parentTask, childSessionId, transportService, connection); exchangeSource.addRemoteSink( remoteSink, true, - () -> {}, + pagesFetched::incrementAndGet, queryPragmas.concurrentExchangeClients(), computeListener.acquireAvoid() ); - final boolean sameNode = transportService.getLocalNode().getId().equals(node.connection.getNode().getId()); + final boolean sameNode = transportService.getLocalNode().getId().equals(connection.getNode().getId()); var dataNodeRequest = new DataNodeRequest( childSessionId, configuration, clusterAlias, - node.shardIds, - node.aliasFilters, + shardIds, + aliasFilters, dataNodePlan, originalIndices.indices(), originalIndices.indicesOptions(), sameNode == false && queryPragmas.nodeLevelReduction() ); transportService.sendChildRequest( - node.connection, + connection, ComputeService.DATA_ACTION_NAME, dataNodeRequest, parentTask, TransportRequestOptions.EMPTY, - new ActionListenerResponseHandler<>(nodeListener, ComputeResponse::new, esqlExecutor) + new ActionListenerResponseHandler<>(computeListener.acquireCompute().map(r -> { + nodeResponseRef.set(r); + return r.profiles(); + }), DataNodeComputeResponse::new, esqlExecutor) ); - }) - ); - } - } - }, listener::onFailure)); - } - - private void acquireSearchContexts( - String clusterAlias, - List<ShardId> shardIds, - Configuration configuration, - Map<Index, AliasFilter> aliasFilters, - ActionListener<List<SearchContext>> listener - ) { - final List<IndexShard> targetShards = new ArrayList<>(); - try { - for (ShardId shardId : shardIds) { - var indexShard = searchService.getIndicesService().indexServiceSafe(shardId.getIndex()).getShard(shardId.id()); - targetShards.add(indexShard); - } - } catch (Exception e) { - listener.onFailure(e); - return; - } - final var doAcquire = ActionRunnable.supply(listener, () -> { - final List<SearchContext> searchContexts = new 
ArrayList<>(targetShards.size()); - boolean success = false; - try { - for (IndexShard shard : targetShards) { - var aliasFilter = aliasFilters.getOrDefault(shard.shardId().getIndex(), AliasFilter.EMPTY); - var shardRequest = new ShardSearchRequest( - shard.shardId(), - configuration.absoluteStartedTimeInMillis(), - aliasFilter, - clusterAlias - ); - // TODO: `searchService.createSearchContext` allows opening search contexts without limits, - // we need to limit the number of active search contexts here or in SearchService - SearchContext context = searchService.createSearchContext(shardRequest, SearchService.NO_TIMEOUT); - searchContexts.add(context); - } - for (SearchContext searchContext : searchContexts) { - searchContext.preProcess(); - } - success = true; - return searchContexts; - } finally { - if (success == false) { - IOUtils.close(searchContexts); - } - } - }); - final AtomicBoolean waitedForRefreshes = new AtomicBoolean(); - try (RefCountingRunnable refs = new RefCountingRunnable(() -> { - if (waitedForRefreshes.get()) { - esqlExecutor.execute(doAcquire); - } else { - doAcquire.run(); - } - })) { - for (IndexShard targetShard : targetShards) { - final Releasable ref = refs.acquire(); - targetShard.ensureShardSearchActive(await -> { - try (ref) { - if (await) { - waitedForRefreshes.set(true); } - } - }); - } - } - } - - record DataNode(Transport.Connection connection, List<ShardId> shardIds, Map<Index, AliasFilter> aliasFilters) { - - } - - /** - * Result from lookupDataNodes where can_match is performed to determine what shards can be skipped - * and which target nodes are needed for running the ES|QL query - * - * @param dataNodes list of DataNode to perform the ES|QL query on - * @param totalShards Total number of shards (from can_match phase), including skipped shards - * @param skippedShards Number of skipped shards (from can_match phase) - */ - record DataNodeResult(List<DataNode> dataNodes, int totalShards, int skippedShards) {} - - /** - * Performs can_match and find the target nodes for the given target indices and filter. - *

- * Ideally, the search_shards API should be called before the field-caps API; however, this can lead - * to a situation where the column structure (i.e., matched data types) differs depending on the query. - */ - private void lookupDataNodes( - Task parentTask, - String clusterAlias, - QueryBuilder filter, - Set<String> concreteIndices, - OriginalIndices originalIndices, - ActionListener<DataNodeResult> listener - ) { - ActionListener<SearchShardsResponse> searchShardsListener = listener.map(resp -> { - Map<String, DiscoveryNode> nodes = new HashMap<>(); - for (DiscoveryNode node : resp.getNodes()) { - nodes.put(node.getId(), node); - } - Map<String, List<ShardId>> nodeToShards = new HashMap<>(); - Map<String, Map<Index, AliasFilter>> nodeToAliasFilters = new HashMap<>(); - int totalShards = 0; - int skippedShards = 0; - for (SearchShardsGroup group : resp.getGroups()) { - var shardId = group.shardId(); - if (group.allocatedNodes().isEmpty()) { - throw new ShardNotFoundException(group.shardId(), "no shard copies found {}", group.shardId()); - } - if (concreteIndices.contains(shardId.getIndexName()) == false) { - continue; - } - totalShards++; - if (group.skipped()) { - skippedShards++; - continue; - } - String targetNode = group.allocatedNodes().get(0); - nodeToShards.computeIfAbsent(targetNode, k -> new ArrayList<>()).add(shardId); - AliasFilter aliasFilter = resp.getAliasFilters().get(shardId.getIndex().getUUID()); - if (aliasFilter != null) { - nodeToAliasFilters.computeIfAbsent(targetNode, k -> new HashMap<>()).put(shardId.getIndex(), aliasFilter); - } - } - List<DataNode> dataNodes = new ArrayList<>(nodeToShards.size()); - for (Map.Entry<String, List<ShardId>> e : nodeToShards.entrySet()) { - DiscoveryNode node = nodes.get(e.getKey()); - Map<Index, AliasFilter> aliasFilters = nodeToAliasFilters.getOrDefault(e.getKey(), Map.of()); - dataNodes.add(new DataNode(transportService.getConnection(node), e.getValue(), aliasFilters)); + }) + ); } - return new DataNodeResult(dataNodes, totalShards, skippedShards); - }); - SearchShardsRequest searchShardsRequest = new SearchShardsRequest( - originalIndices.indices(), - originalIndices.indicesOptions(), - filter, - null, - null, - false, - clusterAlias - ); - transportService.sendChildRequest( - transportService.getLocalNode(), - EsqlSearchShardsAction.TYPE.name(), - searchShardsRequest, - parentTask, - TransportRequestOptions.EMPTY, - new ActionListenerResponseHandler<>(searchShardsListener, SearchShardsResponse::new, esqlExecutor) + }; + sender.startComputeOnDataNodes( + clusterAlias, + concreteIndices, + originalIndices, + PlannerUtils.requestTimestampFilter(dataNodePlan), + runOnTaskFailure, + ActionListener.runAfter(outListener, exchangeSource.addEmptySink()::close) + ); } @@ -318,12 +182,16 @@ private class DataNodeRequestExecutor { private final ComputeListener computeListener; private final int maxConcurrentShards; private final ExchangeSink blockingSink; // block until we have completed on all shards or the coordinator has enough data + private final boolean failFastOnShardFailure; + private final Map<ShardId, Exception> shardLevelFailures; DataNodeRequestExecutor( DataNodeRequest request, CancellableTask parentTask, ExchangeSinkHandler exchangeSink, int maxConcurrentShards, + boolean failFastOnShardFailure, + Map<ShardId, Exception> shardLevelFailures, ComputeListener computeListener ) { this.request = request; @@ -331,6 +199,8 @@ private class DataNodeRequestExecutor { this.exchangeSink = exchangeSink; this.computeListener = computeListener; this.maxConcurrentShards = maxConcurrentShards; + this.failFastOnShardFailure = failFastOnShardFailure; + this.shardLevelFailures = shardLevelFailures; this.blockingSink = exchangeSink.createExchangeSink(() -> 
{}); } @@ -346,6 +216,7 @@ private void runBatch(int startBatchIndex) { final String clusterAlias = request.clusterAlias(); final var sessionId = request.sessionId(); final int endBatchIndex = Math.min(startBatchIndex + maxConcurrentShards, request.shardIds().size()); + final AtomicInteger pagesProduced = new AtomicInteger(); List<ShardId> shardIds = request.shardIds().subList(startBatchIndex, endBatchIndex); ActionListener<List<DriverProfile>> batchListener = new ActionListener<>() { final ActionListener<List<DriverProfile>> ref = computeListener.acquireCompute(); @@ -361,15 +232,26 @@ public void onResponse(List<DriverProfile> result) { @Override public void onFailure(Exception e) { - try { - exchangeService.finishSinkHandler(request.sessionId(), e); - } finally { - ref.onFailure(e); + if (pagesProduced.get() == 0 && failFastOnShardFailure == false) { + for (ShardId shardId : shardIds) { + addShardLevelFailure(shardId, e); + } + onResponse(List.of()); + } else { + try { + exchangeService.finishSinkHandler(request.sessionId(), e); + } finally { + ref.onFailure(e); + } } } }; acquireSearchContexts(clusterAlias, shardIds, configuration, request.aliasFilters(), ActionListener.wrap(searchContexts -> { assert ThreadPool.assertCurrentThreadPool(ThreadPool.Names.SEARCH, ESQL_WORKER_THREAD_POOL_NAME); + if (searchContexts.isEmpty()) { + batchListener.onResponse(List.of()); + return; + } var computeContext = new ComputeContext( sessionId, "data", @@ -378,12 +260,80 @@ public void onFailure(Exception e) { configuration, configuration.newFoldContext(), null, - () -> exchangeSink.createExchangeSink(() -> {}) + () -> exchangeSink.createExchangeSink(pagesProduced::incrementAndGet) ); computeService.runCompute(parentTask, computeContext, request.plan(), batchListener); }, batchListener::onFailure)); } + private void acquireSearchContexts( + String clusterAlias, + List<ShardId> shardIds, + Configuration configuration, + Map<Index, AliasFilter> aliasFilters, + ActionListener<List<SearchContext>> listener + ) { + final List<IndexShard> targetShards = new ArrayList<>(); + for (ShardId shardId : shardIds) { + try { + var indexShard = searchService.getIndicesService().indexServiceSafe(shardId.getIndex()).getShard(shardId.id()); + targetShards.add(indexShard); + } catch (Exception e) { + if (addShardLevelFailure(shardId, e) == false) { + listener.onFailure(e); + return; + } + } + } + final var doAcquire = ActionRunnable.supply(listener, () -> { + final List<SearchContext> searchContexts = new ArrayList<>(targetShards.size()); + SearchContext context = null; + for (IndexShard shard : targetShards) { + try { + var aliasFilter = aliasFilters.getOrDefault(shard.shardId().getIndex(), AliasFilter.EMPTY); + var shardRequest = new ShardSearchRequest( + shard.shardId(), + configuration.absoluteStartedTimeInMillis(), + aliasFilter, + clusterAlias + ); + // TODO: `searchService.createSearchContext` allows opening search contexts without limits, + // we need to limit the number of active search contexts here or in SearchService + context = searchService.createSearchContext(shardRequest, SearchService.NO_TIMEOUT); + context.preProcess(); + searchContexts.add(context); + } catch (Exception e) { + if (addShardLevelFailure(shard.shardId(), e)) { + IOUtils.close(context); + } else { + IOUtils.closeWhileHandlingException(context, () -> IOUtils.close(searchContexts)); + throw e; + } + } + } + return searchContexts; + }); + final AtomicBoolean waitedForRefreshes = new AtomicBoolean(); + try (RefCountingRunnable refs = new RefCountingRunnable(() -> { + if (waitedForRefreshes.get()) { + esqlExecutor.execute(doAcquire); + } else { + doAcquire.run(); + } + })) { + for 
(IndexShard targetShard : targetShards) { + final Releasable ref = refs.acquire(); + targetShard.ensureShardSearchActive(await -> { + try (ref) { + if (await) { + waitedForRefreshes.set(true); + } + } + }); + } + } + } + private void onBatchCompleted(int lastBatchIndex) { if (lastBatchIndex < request.shardIds().size() && exchangeSink.isFinished() == false) { runBatch(lastBatchIndex); @@ -396,6 +346,14 @@ private void onBatchCompleted(int lastBatchIndex) { blockingSink.finish(); } } + + private boolean addShardLevelFailure(ShardId shardId, Exception e) { + if (failFastOnShardFailure) { + return false; + } + shardLevelFailures.put(shardId, e); + return true; + } } private void runComputeOnDataNode( @@ -403,13 +361,15 @@ private void runComputeOnDataNode( String externalId, PhysicalPlan reducePlan, DataNodeRequest request, - ActionListener<ComputeResponse> listener + boolean failFastOnShardFailure, + ActionListener<DataNodeComputeResponse> listener ) { + final Map<ShardId, Exception> shardLevelFailures = new HashMap<>(); try ( ComputeListener computeListener = new ComputeListener( transportService.getThreadPool(), computeService.cancelQueryOnFailure(task), - listener.map(ComputeResponse::new) + listener.map(profiles -> new DataNodeComputeResponse(profiles, shardLevelFailures)) ) ) { var parentListener = computeListener.acquireAvoid(); @@ -421,6 +381,8 @@ private void runComputeOnDataNode( task, internalSink, request.configuration().pragmas().maxConcurrentShardsPerNode(), + failFastOnShardFailure, + shardLevelFailures, computeListener ); dataNodeRequestExecutor.start(); @@ -467,7 +429,7 @@ private void runComputeOnDataNode( @Override public void messageReceived(DataNodeRequest request, TransportChannel channel, Task task) { - final ActionListener<ComputeResponse> listener = new ChannelActionListener<>(channel); + final ActionListener<DataNodeComputeResponse> listener = new ChannelActionListener<>(channel); final PhysicalPlan reductionPlan; if (request.plan() instanceof ExchangeSinkExec plan) { reductionPlan = ComputeService.reductionPlan(plan, request.runNodeLevelReduction()); @@ -487,6 +449,8 @@ public void messageReceived(DataNodeRequest request, TransportChannel channel, T request.indicesOptions(), request.runNodeLevelReduction() ); - runComputeOnDataNode((CancellableTask) task, sessionId, reductionPlan, request, listener); + // the sender doesn't support retry on shard failures, so we need to fail fast here. + final boolean failFastOnShardFailures = channel.getVersion().before(TransportVersions.ESQL_RETRY_ON_SHARD_LEVEL_FAILURE); + runComputeOnDataNode((CancellableTask) task, sessionId, reductionPlan, request, failFastOnShardFailures, listener); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeComputeResponse.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeComputeResponse.java new file mode 100644 index 0000000000000..34a92fb135277 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeComputeResponse.java @@ -0,0 +1,64 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.plugin; + +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.compute.operator.DriverProfile; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.transport.TransportResponse; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +/** + * The compute result of {@link DataNodeRequest} + */ +final class DataNodeComputeResponse extends TransportResponse { + private final List<DriverProfile> profiles; + private final Map<ShardId, Exception> shardLevelFailures; + + DataNodeComputeResponse(List<DriverProfile> profiles, Map<ShardId, Exception> shardLevelFailures) { + this.profiles = profiles; + this.shardLevelFailures = shardLevelFailures; + } + + DataNodeComputeResponse(StreamInput in) throws IOException { + if (in.getTransportVersion().onOrAfter(TransportVersions.ESQL_RETRY_ON_SHARD_LEVEL_FAILURE)) { + this.profiles = in.readCollectionAsImmutableList(DriverProfile::new); + this.shardLevelFailures = in.readMap(ShardId::new, StreamInput::readException); + } else { + this.profiles = Objects.requireNonNullElse(new ComputeResponse(in).getProfiles(), List.of()); + this.shardLevelFailures = Map.of(); + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_RETRY_ON_SHARD_LEVEL_FAILURE)) { + out.writeCollection(profiles, (o, v) -> v.writeTo(o)); + out.writeMap(shardLevelFailures, (o, v) -> v.writeTo(o), StreamOutput::writeException); + } else { + if (shardLevelFailures.isEmpty() == false) { + throw new IllegalStateException("shard level failures are not supported in old versions"); + } + new ComputeResponse(profiles).writeTo(out); + } + } + + List<DriverProfile> profiles() { + return profiles; + } + + Map<ShardId, Exception> shardLevelFailures() { + return shardLevelFailures; + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSender.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSender.java new file mode 100644 index 0000000000000..6af2c12ace086 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSender.java @@ -0,0 +1,343 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.plugin; + +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionListenerResponseHandler; +import org.elasticsearch.action.NoShardAvailableActionException; +import org.elasticsearch.action.OriginalIndices; +import org.elasticsearch.action.search.SearchShardsGroup; +import org.elasticsearch.action.search.SearchShardsRequest; +import org.elasticsearch.action.search.SearchShardsResponse; +import org.elasticsearch.action.support.TransportActions; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.util.concurrent.ConcurrentCollections; +import org.elasticsearch.compute.operator.DriverProfile; +import org.elasticsearch.compute.operator.FailureCollector; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.Index; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.search.internal.AliasFilter; +import org.elasticsearch.tasks.CancellableTask; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskCancelledException; +import org.elasticsearch.transport.TransportException; +import org.elasticsearch.transport.TransportRequestOptions; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.esql.action.EsqlSearchShardsAction; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.Executor; +import java.util.concurrent.Semaphore; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.locks.ReentrantLock; + +/** + * Handles computes within a single cluster by dispatching {@link DataNodeRequest} to data nodes + * and executing these computes on the data nodes. 
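+ * <p> + * Illustrative usage sketch (hedged, not part of this change; in this PR the concrete subclass is created by {@code DataNodeComputeHandler}): + * <pre>{@code + * var sender = new DataNodeRequestSender(transportService, esqlExecutor, rootTask) { + *     @Override + *     void sendRequest(DiscoveryNode node, List<ShardId> shardIds, Map<Index, AliasFilter> aliasFilters, NodeListener listener) { + *         // open an exchange and dispatch a DataNodeRequest to the node; report per-shard + *         // failures through the NodeListener so they can be retried on the remaining copies + *     } + * }; + * sender.startComputeOnDataNodes(clusterAlias, concreteIndices, originalIndices, requestFilter, runOnTaskFailure, listener); + * }</pre>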
+ */ +abstract class DataNodeRequestSender { + private final TransportService transportService; + private final Executor esqlExecutor; + private final CancellableTask rootTask; + private final ReentrantLock sendingLock = new ReentrantLock(); + private final Queue<ShardId> pendingShardIds = ConcurrentCollections.newQueue(); + private final Map<DiscoveryNode, Semaphore> nodePermits = new HashMap<>(); + private final Map<ShardId, ShardFailure> shardFailures = ConcurrentCollections.newConcurrentMap(); + private final AtomicBoolean changed = new AtomicBoolean(); + + DataNodeRequestSender(TransportService transportService, Executor esqlExecutor, CancellableTask rootTask) { + this.transportService = transportService; + this.esqlExecutor = esqlExecutor; + this.rootTask = rootTask; + } + + final void startComputeOnDataNodes( + String clusterAlias, + Set<String> concreteIndices, + OriginalIndices originalIndices, + QueryBuilder requestFilter, + Runnable runOnTaskFailure, + ActionListener<ComputeResponse> listener + ) { + final long startTimeInNanos = System.nanoTime(); + searchShards(rootTask, clusterAlias, requestFilter, concreteIndices, originalIndices, ActionListener.wrap(targetShards -> { + try (var computeListener = new ComputeListener(transportService.getThreadPool(), runOnTaskFailure, listener.map(profiles -> { + TimeValue took = TimeValue.timeValueNanos(System.nanoTime() - startTimeInNanos); + return new ComputeResponse( + profiles, + took, + targetShards.totalShards(), + targetShards.totalShards(), + targetShards.skippedShards(), + 0 + ); + }))) { + for (TargetShard shard : targetShards.shards.values()) { + for (DiscoveryNode node : shard.remainingNodes) { + nodePermits.putIfAbsent(node, new Semaphore(1)); + } + } + pendingShardIds.addAll(targetShards.shards.keySet()); + trySendingRequestsForPendingShards(targetShards, computeListener); + } + }, listener::onFailure)); + } + + private void trySendingRequestsForPendingShards(TargetShards targetShards, ComputeListener computeListener) { + changed.set(true); + final ActionListener<Void> listener = computeListener.acquireAvoid(); + try { + while (sendingLock.tryLock()) { + try { + if (changed.compareAndSet(true, false) == false) { + break; + } + for (ShardId shardId : pendingShardIds) { + if (targetShards.getShard(shardId).remainingNodes.isEmpty()) { + shardFailures.compute( + shardId, + (k, v) -> new ShardFailure( + true, + v == null ? 
new NoShardAvailableActionException(shardId, "no shard copies found") : v.failure + ) + ); + } + } + if (shardFailures.values().stream().anyMatch(shardFailure -> shardFailure.fatal)) { + for (var e : shardFailures.values()) { + computeListener.acquireAvoid().onFailure(e.failure); + } + } else { + var nodeRequests = selectNodeRequests(targetShards); + for (NodeRequest request : nodeRequests) { + sendOneNodeRequest(targetShards, computeListener, request); + } + } + } finally { + sendingLock.unlock(); + } + } + } finally { + listener.onResponse(null); + } + } + + private void sendOneNodeRequest(TargetShards targetShards, ComputeListener computeListener, NodeRequest request) { + final ActionListener<List<DriverProfile>> listener = computeListener.acquireCompute(); + sendRequest(request.node, request.shardIds, request.aliasFilters, new NodeListener() { + void onAfter(List<DriverProfile> profiles) { + nodePermits.get(request.node).release(); + trySendingRequestsForPendingShards(targetShards, computeListener); + listener.onResponse(profiles); + } + + @Override + public void onResponse(DataNodeComputeResponse response) { + // remove failures of successful shards + for (ShardId shardId : targetShards.shardIds()) { + if (response.shardLevelFailures().containsKey(shardId) == false) { + shardFailures.remove(shardId); + } + } + for (Map.Entry<ShardId, Exception> e : response.shardLevelFailures().entrySet()) { + final ShardId shardId = e.getKey(); + trackShardLevelFailure(shardId, false, e.getValue()); + pendingShardIds.add(shardId); + } + onAfter(response.profiles()); + } + + @Override + public void onFailure(Exception e, boolean receivedData) { + for (ShardId shardId : request.shardIds) { + trackShardLevelFailure(shardId, receivedData, e); + pendingShardIds.add(shardId); + } + onAfter(List.of()); + } + }); + } + + abstract void sendRequest(DiscoveryNode node, List<ShardId> shardIds, Map<Index, AliasFilter> aliasFilters, NodeListener nodeListener); + + interface NodeListener { + void onResponse(DataNodeComputeResponse response); + + void onFailure(Exception e, boolean receivedData); + } + + private static Exception unwrapFailure(Exception e) { + e = e instanceof TransportException te ? FailureCollector.unwrapTransportException(te) : e; + if (TransportActions.isShardNotAvailableException(e)) { + return NoShardAvailableActionException.forOnShardFailureWrapper(e.getMessage()); + } else { + return e; + } + } + + private void trackShardLevelFailure(ShardId shardId, boolean fatal, Exception originalEx) { + final Exception e = unwrapFailure(originalEx); + // Retain only one meaningful exception and avoid suppressing previous failures to minimize memory usage, especially when handling + // many shards. 
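+ // A failure is treated as fatal when the data node has already streamed pages back for this request (results + // may have been partially consumed, so its shards cannot simply be retried on another copy) or when the task + // has been cancelled; a fatal failure stops further retries and is propagated to the caller.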
+ shardFailures.compute(shardId, (k, current) -> { + boolean mergedFatal = fatal || ExceptionsHelper.unwrap(e, TaskCancelledException.class) != null; + if (current == null) { + return new ShardFailure(mergedFatal, e); + } + mergedFatal |= current.fatal; + if (e instanceof NoShardAvailableActionException || ExceptionsHelper.unwrap(e, TaskCancelledException.class) != null) { + return new ShardFailure(mergedFatal, current.failure); + } + return new ShardFailure(mergedFatal, e); + }); + } + + /** + * Result from {@link #searchShards(Task, String, QueryBuilder, Set, OriginalIndices, ActionListener)} where can_match is performed to + * determine what shards can be skipped and which target nodes are needed for running the ES|QL query + * + * @param shards List of target shards to perform the ES|QL query on + * @param totalShards Total number of shards (from can_match phase), including skipped shards + * @param skippedShards Number of skipped shards (from can_match phase) + */ + record TargetShards(Map<ShardId, TargetShard> shards, int totalShards, int skippedShards) { + TargetShard getShard(ShardId shardId) { + return shards.get(shardId); + } + + Set<ShardId> shardIds() { + return shards.keySet(); + } + } + + /** + * (Remaining) allocated nodes of a given shard id and its alias filter + */ + record TargetShard(ShardId shardId, List<DiscoveryNode> remainingNodes, AliasFilter aliasFilter) { + + } + + record NodeRequest(DiscoveryNode node, List<ShardId> shardIds, Map<Index, AliasFilter> aliasFilters) { + + } + + private record ShardFailure(boolean fatal, Exception failure) { + + } + + /** + * Selects the next nodes to send requests to. Limits to at most one outstanding request per node. + * If there is already a request in-flight to a node, another request will not be sent to the same node + * until the first request completes. Instead, the next node in the remaining nodes will be tried. + */ + private List<NodeRequest> selectNodeRequests(TargetShards targetShards) { + assert sendingLock.isHeldByCurrentThread(); + final Map<DiscoveryNode, List<ShardId>> nodeToShardIds = new HashMap<>(); + final Iterator<ShardId> shardsIt = pendingShardIds.iterator(); + while (shardsIt.hasNext()) { + ShardId shardId = shardsIt.next(); + TargetShard shard = targetShards.getShard(shardId); + Iterator<DiscoveryNode> nodesIt = shard.remainingNodes.iterator(); + DiscoveryNode selectedNode = null; + while (nodesIt.hasNext()) { + DiscoveryNode node = nodesIt.next(); + if (nodeToShardIds.containsKey(node) || nodePermits.get(node).tryAcquire()) { + nodesIt.remove(); + shardsIt.remove(); + selectedNode = node; + break; + } + } + if (selectedNode != null) { + nodeToShardIds.computeIfAbsent(selectedNode, unused -> new ArrayList<>()).add(shard.shardId); + } + } + final List<NodeRequest> nodeRequests = new ArrayList<>(nodeToShardIds.size()); + for (var e : nodeToShardIds.entrySet()) { + List<ShardId> shardIds = e.getValue(); + Map<Index, AliasFilter> aliasFilters = new HashMap<>(); + for (ShardId shardId : shardIds) { + var aliasFilter = targetShards.getShard(shardId).aliasFilter; + if (aliasFilter != null) { + aliasFilters.put(shardId.getIndex(), aliasFilter); + } + } + nodeRequests.add(new NodeRequest(e.getKey(), shardIds, aliasFilters)); + } + return nodeRequests; + } + + /** + * Performs can_match and finds the target nodes for the given target indices and filter. + *
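+ * Shards whose index is not among the given concrete indices are ignored, and shards that can_match marks as + * skipped are counted but never queried.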

+ * Ideally, the search_shards API should be called before the field-caps API; however, this can lead + * to a situation where the column structure (i.e., matched data types) differs depending on the query. + */ + void searchShards( + Task parentTask, + String clusterAlias, + QueryBuilder filter, + Set<String> concreteIndices, + OriginalIndices originalIndices, + ActionListener<TargetShards> listener + ) { + ActionListener<SearchShardsResponse> searchShardsListener = listener.map(resp -> { + Map<String, DiscoveryNode> nodes = new HashMap<>(); + for (DiscoveryNode node : resp.getNodes()) { + nodes.put(node.getId(), node); + } + int totalShards = 0; + int skippedShards = 0; + Map<ShardId, TargetShard> shards = new HashMap<>(); + for (SearchShardsGroup group : resp.getGroups()) { + var shardId = group.shardId(); + if (concreteIndices.contains(shardId.getIndexName()) == false) { + continue; + } + totalShards++; + if (group.skipped()) { + skippedShards++; + continue; + } + List<DiscoveryNode> allocatedNodes = new ArrayList<>(group.allocatedNodes().size()); + for (String n : group.allocatedNodes()) { + allocatedNodes.add(nodes.get(n)); + } + AliasFilter aliasFilter = resp.getAliasFilters().get(shardId.getIndex().getUUID()); + shards.put(shardId, new TargetShard(shardId, allocatedNodes, aliasFilter)); + } + return new TargetShards(shards, totalShards, skippedShards); + }); + SearchShardsRequest searchShardsRequest = new SearchShardsRequest( + originalIndices.indices(), + originalIndices.indicesOptions(), + filter, + null, + null, + false, + clusterAlias + ); + transportService.sendChildRequest( + transportService.getLocalNode(), + EsqlSearchShardsAction.TYPE.name(), + searchShardsRequest, + parentTask, + TransportRequestOptions.EMPTY, + new ActionListenerResponseHandler<>(searchShardsListener, SearchShardsResponse::new, esqlExecutor) + ); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSenderTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSenderTests.java new file mode 100644 index 0000000000000..e181d9bb34955 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSenderTests.java @@ -0,0 +1,287 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.plugin; + +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.NoShardAvailableActionException; +import org.elasticsearch.action.OriginalIndices; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.node.DiscoveryNodeUtils; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ConcurrentCollections; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.compute.test.ComputeTestCase; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.Index; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.search.internal.AliasFilter; +import org.elasticsearch.tasks.CancellableTask; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.threadpool.FixedExecutorBuilder; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.transport.TransportService; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static org.elasticsearch.xpack.esql.plugin.DataNodeRequestSender.NodeRequest; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class DataNodeRequestSenderTests extends ComputeTestCase { + + private TestThreadPool threadPool; + private Executor executor = null; + private static final String ESQL_TEST_EXECUTOR = "esql_test_executor"; + + private final DiscoveryNode node1 = DiscoveryNodeUtils.create("node-1"); + private final DiscoveryNode node2 = DiscoveryNodeUtils.create("node-2"); + private final DiscoveryNode node3 = DiscoveryNodeUtils.create("node-3"); + private final DiscoveryNode node4 = DiscoveryNodeUtils.create("node-4"); + private final DiscoveryNode node5 = DiscoveryNodeUtils.create("node-5"); + private final ShardId shard1 = new ShardId("index", "n/a", 1); + private final ShardId shard2 = new ShardId("index", "n/a", 2); + private final ShardId shard3 = new ShardId("index", "n/a", 3); + private final ShardId shard4 = new ShardId("index", "n/a", 4); + private final ShardId shard5 = new ShardId("index", "n/a", 5); + + @Before + public void setThreadPool() { + int numThreads = randomBoolean() ? 
1 : between(2, 16); + threadPool = new TestThreadPool( + "test", + new FixedExecutorBuilder(Settings.EMPTY, ESQL_TEST_EXECUTOR, numThreads, 1024, "esql", EsExecutors.TaskTrackingConfig.DEFAULT) + ); + executor = threadPool.executor(ESQL_TEST_EXECUTOR); + } + + @After + public void shutdownThreadPool() throws Exception { + terminate(threadPool); + } + + public void testEmpty() { + var future = sendRequests(List.of(), (node, shardIds, aliasFilters, listener) -> fail("expect no data-node request is sent")); + var resp = safeGet(future); + assertThat(resp.totalShards, equalTo(0)); + } + + public void testOnePass() { + var targetShards = List.of( + targetShard(shard1, node1), + targetShard(shard2, node2, node4), + targetShard(shard3, node1, node2), + targetShard(shard4, node2, node3) + ); + Queue<NodeRequest> sent = ConcurrentCollections.newQueue(); + var future = sendRequests(targetShards, (node, shardIds, aliasFilters, listener) -> { + sent.add(new NodeRequest(node, shardIds, aliasFilters)); + var resp = new DataNodeComputeResponse(List.of(), Map.of()); + runWithDelay(() -> listener.onResponse(resp)); + }); + safeGet(future); + assertThat(sent.size(), equalTo(2)); + assertThat(groupRequests(sent, 2), equalTo(Map.of(node1, List.of(shard1, shard3), node2, List.of(shard2, shard4)))); + } + + public void testMissingShards() { + var targetShards = List.of(targetShard(shard1, node1), targetShard(shard3), targetShard(shard4, node2, node3)); + var future = sendRequests(targetShards, (node, shardIds, aliasFilters, listener) -> { + fail("expect no data-node request is sent when target shards are missing"); + }); + var error = expectThrows(NoShardAvailableActionException.class, future::actionGet); + assertThat(error.getMessage(), containsString("no shard copies found")); + } + + public void testRetryThenSuccess() { + var targetShards = List.of( + targetShard(shard1, node1), + targetShard(shard2, node4, node2), + targetShard(shard3, node2, node3), + targetShard(shard4, node2, node3), + targetShard(shard5, node1, node3, node2) + ); + Queue<NodeRequest> sent = ConcurrentCollections.newQueue(); + var future = sendRequests(targetShards, (node, shardIds, aliasFilters, listener) -> { + sent.add(new NodeRequest(node, shardIds, aliasFilters)); + Map<ShardId, Exception> failures = new HashMap<>(); + if (node.equals(node1) && shardIds.contains(shard5)) { + failures.put(shard5, new IOException("test")); + } + if (node.equals(node4) && shardIds.contains(shard2)) { + failures.put(shard2, new IOException("test")); + } + runWithDelay(() -> listener.onResponse(new DataNodeComputeResponse(List.of(), failures))); + }); + try { + future.actionGet(1, TimeUnit.MINUTES); + } catch (Exception e) { + throw new AssertionError(e); + } + assertThat(sent, hasSize(5)); + var firstRound = groupRequests(sent, 3); + assertThat(firstRound, equalTo(Map.of(node1, List.of(shard1, shard5), node4, List.of(shard2), node2, List.of(shard3, shard4)))); + var secondRound = groupRequests(sent, 2); + assertThat(secondRound, equalTo(Map.of(node2, List.of(shard2), node3, List.of(shard5)))); + } + + public void testRetryButFail() { + var targetShards = List.of( + targetShard(shard1, node1), + targetShard(shard2, node4, node2), + targetShard(shard3, node2, node3), + targetShard(shard4, node2, node3), + targetShard(shard5, node1, node3, node2) + ); + Queue<NodeRequest> sent = ConcurrentCollections.newQueue(); + var future = sendRequests(targetShards, (node, shardIds, aliasFilters, listener) -> { + sent.add(new NodeRequest(node, shardIds, aliasFilters)); + Map<ShardId, Exception> failures = new HashMap<>(); + if 
(shardIds.contains(shard5)) { + failures.put(shard5, new IOException("test failure for shard5")); + } + runWithDelay(() -> listener.onResponse(new DataNodeComputeResponse(List.of(), failures))); + }); + var error = expectThrows(Exception.class, future::actionGet); + assertNotNull(ExceptionsHelper.unwrap(error, IOException.class)); + // {node-1, node-2, node-4}, {node-3}, {node-2} + assertThat(sent.size(), equalTo(5)); + var firstRound = groupRequests(sent, 3); + assertThat(firstRound, equalTo(Map.of(node1, List.of(shard1, shard5), node2, List.of(shard3, shard4), node4, List.of(shard2)))); + NodeRequest fourth = sent.remove(); + assertThat(fourth.node(), equalTo(node3)); + assertThat(fourth.shardIds(), equalTo(List.of(shard5))); + NodeRequest fifth = sent.remove(); + assertThat(fifth.node(), equalTo(node2)); + assertThat(fifth.shardIds(), equalTo(List.of(shard5))); + } + + public void testDoNotRetryOnRequestLevelFailure() { + var targetShards = List.of(targetShard(shard1, node1), targetShard(shard2, node2), targetShard(shard3, node1)); + Queue<NodeRequest> sent = ConcurrentCollections.newQueue(); + AtomicBoolean failed = new AtomicBoolean(); + var future = sendRequests(targetShards, (node, shardIds, aliasFilters, listener) -> { + sent.add(new NodeRequest(node, shardIds, aliasFilters)); + if (node1.equals(node) && failed.compareAndSet(false, true)) { + runWithDelay(() -> listener.onFailure(new IOException("test request level failure"), true)); + } else { + runWithDelay(() -> listener.onResponse(new DataNodeComputeResponse(List.of(), Map.of()))); + } + }); + Exception exception = expectThrows(Exception.class, future::actionGet); + assertNotNull(ExceptionsHelper.unwrap(exception, IOException.class)); + // one round: {node-1, node-2} + assertThat(sent.size(), equalTo(2)); + var firstRound = groupRequests(sent, 2); + assertThat(firstRound, equalTo(Map.of(node1, List.of(shard1, shard3), node2, List.of(shard2)))); + } + + static DataNodeRequestSender.TargetShard targetShard(ShardId shardId, DiscoveryNode... 
nodes) { + return new DataNodeRequestSender.TargetShard(shardId, new ArrayList<>(Arrays.asList(nodes)), null); + } + + static Map<DiscoveryNode, List<ShardId>> groupRequests(Queue<NodeRequest> sent, int limit) { + Map<DiscoveryNode, List<ShardId>> map = new HashMap<>(); + for (int i = 0; i < limit; i++) { + NodeRequest r = sent.remove(); + assertNull(map.put(r.node(), r.shardIds().stream().sorted().toList())); + } + return map; + } + + void runWithDelay(Runnable runnable) { + if (randomBoolean()) { + threadPool.schedule(runnable, TimeValue.timeValueNanos(between(0, 5000)), executor); + } else { + executor.execute(runnable); + } + } + + PlainActionFuture<ComputeResponse> sendRequests(List<DataNodeRequestSender.TargetShard> shards, Sender sender) { + PlainActionFuture<ComputeResponse> future = new PlainActionFuture<>(); + TransportService transportService = mock(TransportService.class); + when(transportService.getThreadPool()).thenReturn(threadPool); + CancellableTask task = new CancellableTask( + randomNonNegativeLong(), + "type", + "action", + randomAlphaOfLength(10), + TaskId.EMPTY_TASK_ID, + Collections.emptyMap() + ); + DataNodeRequestSender requestSender = new DataNodeRequestSender(transportService, executor, task) { + @Override + void searchShards( + Task parentTask, + String clusterAlias, + QueryBuilder filter, + Set<String> concreteIndices, + OriginalIndices originalIndices, + ActionListener<TargetShards> listener + ) { + var targetShards = new TargetShards( + shards.stream().collect(Collectors.toMap(TargetShard::shardId, Function.identity())), + shards.size(), + 0 + ); + assertSame(parentTask, task); + runWithDelay(() -> listener.onResponse(targetShards)); + } + + @Override + protected void sendRequest( + DiscoveryNode node, + List<ShardId> shardIds, + Map<Index, AliasFilter> aliasFilters, + NodeListener listener + ) { + sender.sendRequestToOneNode(node, shardIds, aliasFilters, listener); + } + }; + requestSender.startComputeOnDataNodes( + "", + Set.of(randomAlphaOfLength(10)), + new OriginalIndices(new String[0], SearchRequest.DEFAULT_INDICES_OPTIONS), + null, + () -> {}, + future + ); + return future; + } + + interface Sender { + void sendRequestToOneNode( + DiscoveryNode node, + List<ShardId> shardIds, + Map<Index, AliasFilter> aliasFilters, + DataNodeRequestSender.NodeListener listener + ); + } +}