diff --git a/docs/changelog/126653.yaml b/docs/changelog/126653.yaml new file mode 100644 index 0000000000000..1497aa7a40053 --- /dev/null +++ b/docs/changelog/126653.yaml @@ -0,0 +1,5 @@ +pr: 126653 +summary: Retry shard movements during ESQL query +area: ES|QL +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchShardsAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchShardsAction.java index cf08107c57017..f9b23a773f2d8 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchShardsAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchShardsAction.java @@ -131,11 +131,8 @@ public void searchShards(Task task, SearchShardsRequest searchShardsRequest, Act listener.delegateFailureAndWrap((delegate, searchRequest) -> { Index[] concreteIndices = resolvedIndices.getConcreteLocalIndices(); final Set indicesAndAliases = indexNameExpressionResolver.resolveExpressions( - project.metadata(), - searchRequest.indices() - ); final Map aliasFilters = transportSearchAction.buildIndexAliasFilters( project, diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java index 2bd53cbfc9d30..41554bed78bfa 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java @@ -15,6 +15,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.RemoteException; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.project.ProjectResolver; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.breaker.CircuitBreaker; @@ -383,6 +384,7 @@ public static LogicalOptimizerContext unboundLogicalOptimizerContext() { mock(SearchService.class), null, mock(ClusterService.class), + mock(ProjectResolver.class), mock(IndexNameExpressionResolver.class), null, mockInferenceRunner() diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/ManyShardsIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/ManyShardsIT.java index c5809d5d5ed1c..506c143b85150 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/ManyShardsIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/ManyShardsIT.java @@ -25,7 +25,6 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.MockSearchService; import org.elasticsearch.search.SearchService; -import org.elasticsearch.test.junit.annotations.TestIssueLogging; import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.transport.RemoteTransportException; import org.elasticsearch.transport.TransportChannel; @@ -260,10 +259,6 @@ public void testLimitConcurrentShards() { } } - @TestIssueLogging( - issueUrl = "https://github.com/elastic/elasticsearch/issues/125947", - value = "logger.org.elasticsearch.cluster.routing.allocation.ShardChangesObserver:TRACE" - ) public void testCancelUnnecessaryRequests() { assumeTrue("Requires pragmas", canUseQueryPragmas()); internalCluster().ensureAtLeastNumDataNodes(3); diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSenderIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSenderIT.java new file mode 100644 index 0000000000000..1e22d2c69c881 --- /dev/null +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSenderIT.java @@ -0,0 +1,159 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.plugin; + +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.CollectionUtils; +import org.elasticsearch.compute.operator.exchange.ExchangeService; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.test.transport.MockTransportService; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase; +import org.elasticsearch.xpack.esql.action.EsqlQueryResponse; + +import java.util.Collection; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.LongAdder; + +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.as; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.getValuesList; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.hasSize; + +public class DataNodeRequestSenderIT extends AbstractEsqlIntegTestCase { + + @Override + protected Collection> nodePlugins() { + return CollectionUtils.appendToCopy(super.nodePlugins(), MockTransportService.TestPlugin.class); + } + + public void testSearchWhileRelocating() throws InterruptedException { + internalCluster().ensureAtLeastNumDataNodes(3); + var primaries = randomIntBetween(1, 10); + var replicas = randomIntBetween(0, 1); + + indicesAdmin().prepareCreate("index-1").setSettings(indexSettings(primaries, replicas)).get(); + + var docs = randomIntBetween(10, 100); + var bulk = client().prepareBulk("index-1").setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + for (int i = 0; i < docs; i++) { + bulk.add(new IndexRequest().source("key", "value-1")); + } + bulk.get(); + + // start background searches + var stopped = new AtomicBoolean(false); + var queries = new LongAdder(); + var threads = new Thread[randomIntBetween(1, 5)]; + for (int i = 0; i < threads.length; i++) { + threads[i] = new Thread(() -> { + while (stopped.get() == false) { + try (EsqlQueryResponse resp = run("FROM index-1")) { + assertThat(getValuesList(resp), hasSize(docs)); + } + queries.increment(); + } + }); + } + for (Thread thread : threads) { + thread.start(); + } + + // start shard movements + var rounds = randomIntBetween(1, 10); + var names = internalCluster().getNodeNames(); + for (int i = 0; i < rounds; i++) { + for (String name : names) { + client().admin() + .cluster() + .prepareUpdateSettings(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT) + .setPersistentSettings(Settings.builder().put("cluster.routing.allocation.exclude._name", name)) + .get(); + ensureGreen("index-1"); + Thread.yield(); + } + } + + stopped.set(true); + for (Thread thread : threads) { + thread.join(10_000); + } + + client().admin() + .cluster() + .prepareUpdateSettings(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT) + .setPersistentSettings(Settings.builder().putNull("cluster.routing.allocation.exclude._name")) + .get(); + assertThat(queries.sum(), greaterThan((long) threads.length)); + } + + public void testRetryOnShardMovement() { + internalCluster().ensureAtLeastNumDataNodes(2); + + assertAcked( + client().admin() + .indices() + .prepareCreate("index-1") + .setSettings( + Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + ) + ); + assertAcked( + client().admin() + .indices() + .prepareCreate("index-2") + .setSettings( + Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + ) + ); + client().prepareBulk("index-1") + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .add(new IndexRequest().source("key", "value-1")) + .get(); + client().prepareBulk("index-2") + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .add(new IndexRequest().source("key", "value-2")) + .get(); + + var shouldMove = new AtomicBoolean(true); + + for (TransportService transportService : internalCluster().getInstances(TransportService.class)) { + as(transportService, MockTransportService.class).addRequestHandlingBehavior( + ExchangeService.OPEN_EXCHANGE_ACTION_NAME, + (handler, request, channel, task) -> { + // move index shard + if (shouldMove.compareAndSet(true, false)) { + var currentShardNodeId = clusterService().state() + .routingTable() + .index("index-1") + .shard(0) + .primaryShard() + .currentNodeId(); + assertAcked( + client().admin() + .indices() + .prepareUpdateSettings("index-1") + .setSettings(Settings.builder().put("index.routing.allocation.exclude._id", currentShardNodeId)) + ); + ensureGreen("index-1"); + } + // execute data node request + handler.messageReceived(request, channel, task); + } + ); + } + + try (EsqlQueryResponse resp = run("FROM " + randomFrom("index-1,index-2", "index-*"))) { + assertThat(getValuesList(resp), hasSize(2)); + } + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java index 4e91e1d505791..5b4ba7140822e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java @@ -12,6 +12,7 @@ import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.ShardSearchFailure; import org.elasticsearch.cluster.RemoteException; +import org.elasticsearch.cluster.project.ProjectResolver; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.concurrent.RunOnce; @@ -130,6 +131,7 @@ public class ComputeService { private final LookupFromIndexService lookupFromIndexService; private final InferenceRunner inferenceRunner; private final ClusterService clusterService; + private final ProjectResolver projectResolver; private final AtomicLong childSessionIdGenerator = new AtomicLong(); private final DataNodeComputeHandler dataNodeComputeHandler; private final ClusterComputeHandler clusterComputeHandler; @@ -157,7 +159,16 @@ public ComputeService( this.lookupFromIndexService = lookupFromIndexService; this.inferenceRunner = transportActionServices.inferenceRunner(); this.clusterService = transportActionServices.clusterService(); - this.dataNodeComputeHandler = new DataNodeComputeHandler(this, searchService, transportService, exchangeService, esqlExecutor); + this.projectResolver = transportActionServices.projectResolver(); + this.dataNodeComputeHandler = new DataNodeComputeHandler( + this, + clusterService, + projectResolver, + searchService, + transportService, + exchangeService, + esqlExecutor + ); this.clusterComputeHandler = new ClusterComputeHandler( this, exchangeService, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeComputeHandler.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeComputeHandler.java index 59da28fae7279..2c1677e012078 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeComputeHandler.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeComputeHandler.java @@ -16,6 +16,8 @@ import org.elasticsearch.action.support.ChannelActionListener; import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.project.ProjectResolver; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.compute.operator.DriverCompletionInfo; import org.elasticsearch.compute.operator.exchange.ExchangeService; import org.elasticsearch.compute.operator.exchange.ExchangeSink; @@ -66,6 +68,8 @@ final class DataNodeComputeHandler implements TransportRequestHandler { private final ComputeService computeService; private final SearchService searchService; + private final ClusterService clusterService; + private final ProjectResolver projectResolver; private final TransportService transportService; private final ExchangeService exchangeService; private final Executor esqlExecutor; @@ -73,12 +77,16 @@ final class DataNodeComputeHandler implements TransportRequestHandler 0 ? new Semaphore(concurrentRequests) : null; + this.remainingUnavailableShardResolutionAttempts = new AtomicInteger( + unavailableShardResolutionAttempts >= 0 ? unavailableShardResolutionAttempts : Integer.MAX_VALUE + ); } - final void startComputeOnDataNodes( - Set concreteIndices, - OriginalIndices originalIndices, - QueryBuilder requestFilter, - Runnable runOnTaskFailure, - ActionListener listener - ) { + final void startComputeOnDataNodes(Set concreteIndices, Runnable runOnTaskFailure, ActionListener listener) { final long startTimeInNanos = System.nanoTime(); - searchShards(requestFilter, concreteIndices, originalIndices, ActionListener.wrap(targetShards -> { + searchShards(concreteIndices, ActionListener.wrap(targetShards -> { try ( var computeListener = new ComputeListener( transportService.getThreadPool(), @@ -118,7 +137,7 @@ final void startComputeOnDataNodes( listener.map( completionInfo -> new ComputeResponse( completionInfo, - TimeValue.timeValueNanos(System.nanoTime() - startTimeInNanos), + timeValueNanos(System.nanoTime() - startTimeInNanos), targetShards.totalShards(), targetShards.totalShards() - shardFailures.size() - skippedShards.get(), targetShards.skippedShards() + skippedShards.get(), @@ -128,11 +147,6 @@ final void startComputeOnDataNodes( ) ) ) { - for (TargetShard shard : targetShards.shards.values()) { - for (DiscoveryNode node : shard.remainingNodes) { - nodePermits.putIfAbsent(node, new Semaphore(1)); - } - } pendingShardIds.addAll(order(targetShards)); trySendingRequestsForPendingShards(targetShards, computeListener); } @@ -242,11 +256,27 @@ private List selectFailures() { private void sendOneNodeRequest(TargetShards targetShards, ComputeListener computeListener, NodeRequest request) { final ActionListener listener = computeListener.acquireCompute(); sendRequest(request.node, request.shardIds, request.aliasFilters, new NodeListener() { + + private final Set pendingRetries = new HashSet<>(); + void onAfter(DriverCompletionInfo info) { nodePermits.get(request.node).release(); if (concurrentRequests != null) { concurrentRequests.release(); } + + if (pendingRetries.isEmpty() == false && remainingUnavailableShardResolutionAttempts.decrementAndGet() >= 0) { + try { + sendingLock.lock(); + var resolutions = resolveShards(pendingRetries); + for (var entry : resolutions.entrySet()) { + targetShards.shards.get(entry.getKey()).remainingNodes.addAll(entry.getValue()); + } + } finally { + sendingLock.unlock(); + } + } + trySendingRequestsForPendingShards(targetShards, computeListener); listener.onResponse(info); } @@ -259,10 +289,11 @@ public void onResponse(DataNodeComputeResponse response) { shardFailures.remove(shardId); } } - for (Map.Entry e : response.shardLevelFailures().entrySet()) { - final ShardId shardId = e.getKey(); - trackShardLevelFailure(shardId, false, e.getValue()); + for (var entry : response.shardLevelFailures().entrySet()) { + final ShardId shardId = entry.getKey(); + trackShardLevelFailure(shardId, false, entry.getValue()); pendingShardIds.add(shardId); + maybeScheduleRetry(shardId, false, entry.getValue()); } onAfter(response.completionInfo()); } @@ -272,6 +303,7 @@ public void onFailure(Exception e, boolean receivedData) { for (ShardId shardId : request.shardIds) { trackShardLevelFailure(shardId, receivedData, e); pendingShardIds.add(shardId); + maybeScheduleRetry(shardId, receivedData, e); } onAfter(DriverCompletionInfo.EMPTY); } @@ -285,6 +317,14 @@ public void onSkip() { onResponse(new DataNodeComputeResponse(DriverCompletionInfo.EMPTY, Map.of())); } } + + private void maybeScheduleRetry(ShardId shardId, boolean receivedData, Exception e) { + if (receivedData == false + && targetShards.getShard(shardId).remainingNodes.isEmpty() + && unwrapFailure(shardId, e) instanceof NoShardAvailableActionException) { + pendingRetries.add(shardId); + } + } }); } @@ -327,7 +367,7 @@ private void trackShardLevelFailure(ShardId shardId, boolean fatal, Exception or } /** - * Result from {@link #searchShards(QueryBuilder, Set, OriginalIndices, ActionListener)} where can_match is performed to + * Result from {@link #searchShards(Set, ActionListener)} where can_match is performed to * determine what shards can be skipped and which target nodes are needed for running the ES|QL query * * @param shards List of target shards to perform the ES|QL query on @@ -379,7 +419,7 @@ private List selectNodeRequests(TargetShards targetShards) { } if (concurrentRequests == null || concurrentRequests.tryAcquire()) { - if (nodePermits.get(node).tryAcquire()) { + if (nodePermits.computeIfAbsent(node, n -> new Semaphore(1)).tryAcquire()) { pendingRequest = new ArrayList<>(); pendingRequest.add(shard.shardId); nodeToShardIds.put(node, pendingRequest); @@ -417,20 +457,15 @@ private List selectNodeRequests(TargetShards targetShards) { * Ideally, the search_shards API should be called before the field-caps API; however, this can lead * to a situation where the column structure (i.e., matched data types) differs depending on the query. */ - void searchShards( - QueryBuilder filter, - Set concreteIndices, - OriginalIndices originalIndices, - ActionListener listener - ) { + void searchShards(Set concreteIndices, ActionListener listener) { ActionListener searchShardsListener = listener.map(resp -> { - Map nodes = new HashMap<>(); + Map nodes = newHashMap(resp.getNodes().size()); for (DiscoveryNode node : resp.getNodes()) { nodes.put(node.getId(), node); } int totalShards = 0; int skippedShards = 0; - Map shards = new HashMap<>(); + Map shards = newHashMap(resp.getGroups().size()); for (SearchShardsGroup group : resp.getGroups()) { var shardId = group.shardId(); if (concreteIndices.contains(shardId.getIndexName()) == false) { @@ -450,10 +485,10 @@ void searchShards( } return new TargetShards(shards, totalShards, skippedShards); }); - SearchShardsRequest searchShardsRequest = new SearchShardsRequest( + var searchShardsRequest = new SearchShardsRequest( originalIndices.indices(), originalIndices.indicesOptions(), - filter, + requestFilter, null, null, true, // unavailable_shards will be handled by the sender @@ -468,4 +503,24 @@ void searchShards( new ActionListenerResponseHandler<>(searchShardsListener, SearchShardsResponse::new, esqlExecutor) ); } + + /** + * Attempts to resolve shards locations after they have been moved + */ + Map> resolveShards(Set shardIds) { + var project = projectResolver.getProjectState(clusterService.state()); + var nodes = Maps.>newMapWithExpectedSize(shardIds.size()); + for (var shardId : shardIds) { + nodes.put( + shardId, + project.routingTable() + .shardRoutingTable(shardId) + .allShards() + .filter(shard -> shard.active() && shard.isSearchable()) + .map(shard -> project.cluster().nodes().get(shard.currentNodeId())) + .toList() + ); + } + return nodes; + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/QueryPragmas.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/QueryPragmas.java index b8aa1a7badcab..b0a99abf14288 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/QueryPragmas.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/QueryPragmas.java @@ -66,6 +66,9 @@ public final class QueryPragmas implements Writeable { public static final Setting MAX_CONCURRENT_SHARDS_PER_NODE = // Setting.intSetting("max_concurrent_shards_per_node", 10, 1, 100); + public static final Setting UNAVAILABLE_SHARD_RESOLUTION_ATTEMPTS = // + Setting.intSetting("unavailable_shard_resolution_attempts", 10, -1); + public static final Setting NODE_LEVEL_REDUCTION = Setting.boolSetting("node_level_reduction", true); public static final Setting FOLD_LIMIT = Setting.memorySizeSetting("fold_limit", "5%"); @@ -156,6 +159,14 @@ public int maxConcurrentShardsPerNode() { return MAX_CONCURRENT_SHARDS_PER_NODE.get(settings); } + /** + * Amount of attempts moved shards could be retried. + * This setting is protecting query from endlessly chasing moving shards. + */ + public int unavailableShardResolutionAttempts() { + return UNAVAILABLE_SHARD_RESOLUTION_ATTEMPTS.get(settings); + } + /** * Returns true if each data node should perform a local reduction for sort, limit, topN, stats or false if the coordinator node * will perform the reduction. diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportActionServices.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportActionServices.java index 0874ff4068227..ccabe09fd466c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportActionServices.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportActionServices.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.plugin; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.project.ProjectResolver; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.compute.operator.exchange.ExchangeService; import org.elasticsearch.search.SearchService; @@ -20,6 +21,7 @@ public record TransportActionServices( SearchService searchService, ExchangeService exchangeService, ClusterService clusterService, + ProjectResolver projectResolver, IndexNameExpressionResolver indexNameExpressionResolver, UsageService usageService, InferenceRunner inferenceRunner diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java index 176545b705fe3..0404f7291fe18 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java @@ -15,6 +15,7 @@ import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.project.ProjectResolver; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; @@ -94,6 +95,7 @@ public TransportEsqlQueryAction( SearchService searchService, ExchangeService exchangeService, ClusterService clusterService, + ProjectResolver projectResolver, ThreadPool threadPool, BigArrays bigArrays, BlockFactoryProvider blockFactoryProvider, @@ -149,6 +151,7 @@ public TransportEsqlQueryAction( searchService, exchangeService, clusterService, + projectResolver, indexNameExpressionResolver, usageService, new InferenceRunner(client) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSenderTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSenderTests.java index ae09d270d6f3c..ce0ef53eadb3f 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSenderTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSenderTests.java @@ -25,8 +25,8 @@ import org.elasticsearch.compute.operator.DriverCompletionInfo; import org.elasticsearch.compute.test.ComputeTestCase; import org.elasticsearch.index.Index; -import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.index.shard.ShardNotFoundException; import org.elasticsearch.search.internal.AliasFilter; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.TaskId; @@ -34,6 +34,7 @@ import org.elasticsearch.threadpool.FixedExecutorBuilder; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.esql.plugin.DataNodeRequestSender.NodeListener; import org.junit.After; import org.junit.Before; @@ -43,6 +44,7 @@ import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; @@ -105,9 +107,9 @@ public void shutdownThreadPool() { public void testEmpty() { var future = sendRequests( - List.of(), randomBoolean(), -1, + List.of(), (node, shardIds, aliasFilters, listener) -> fail("expect no data-node request is sent") ); var resp = safeGet(future); @@ -122,7 +124,7 @@ public void testOnePass() { targetShard(shard4, node2, node3) ); Queue sent = ConcurrentCollections.newQueue(); - var future = sendRequests(targetShards, randomBoolean(), -1, (node, shardIds, aliasFilters, listener) -> { + var future = sendRequests(randomBoolean(), -1, targetShards, (node, shardIds, aliasFilters, listener) -> { sent.add(nodeRequest(node, shardIds)); runWithDelay(() -> listener.onResponse(new DataNodeComputeResponse(DriverCompletionInfo.EMPTY, Map.of()))); }); @@ -134,14 +136,14 @@ public void testOnePass() { public void testMissingShards() { { var targetShards = List.of(targetShard(shard1, node1), targetShard(shard3), targetShard(shard4, node2, node3)); - var future = sendRequests(targetShards, false, -1, (node, shardIds, aliasFilters, listener) -> { + var future = sendRequests(false, -1, targetShards, (node, shardIds, aliasFilters, listener) -> { fail("expect no data-node request is sent when target shards are missing"); }); expectThrows(NoShardAvailableActionException.class, containsString("no shard copies found"), future::actionGet); } { var targetShards = List.of(targetShard(shard1, node1), targetShard(shard3), targetShard(shard4, node2, node3)); - var future = sendRequests(targetShards, true, -1, (node, shardIds, aliasFilters, listener) -> { + var future = sendRequests(true, -1, targetShards, (node, shardIds, aliasFilters, listener) -> { assertThat(shard3, not(in(shardIds))); runWithDelay(() -> listener.onResponse(new DataNodeComputeResponse(DriverCompletionInfo.EMPTY, Map.of()))); }); @@ -165,7 +167,7 @@ public void testRetryThenSuccess() { targetShard(shard5, node1, node3, node2) ); Queue sent = ConcurrentCollections.newQueue(); - var future = sendRequests(targetShards, randomBoolean(), -1, (node, shardIds, aliasFilters, listener) -> { + var future = sendRequests(randomBoolean(), -1, targetShards, (node, shardIds, aliasFilters, listener) -> { sent.add(nodeRequest(node, shardIds)); Map failures = new HashMap<>(); if (node.equals(node1) && shardIds.contains(shard5)) { @@ -198,7 +200,7 @@ public void testRetryButFail() { targetShard(shard5, node1, node3, node2) ); Queue sent = ConcurrentCollections.newQueue(); - var future = sendRequests(targetShards, false, -1, (node, shardIds, aliasFilters, listener) -> { + var future = sendRequests(false, -1, targetShards, (node, shardIds, aliasFilters, listener) -> { sent.add(nodeRequest(node, shardIds)); Map failures = new HashMap<>(); if (shardIds.contains(shard5)) { @@ -222,7 +224,7 @@ public void testDoNotRetryOnRequestLevelFailure() { var targetShards = List.of(targetShard(shard1, node1), targetShard(shard2, node2), targetShard(shard3, node1)); Queue sent = ConcurrentCollections.newQueue(); AtomicBoolean failed = new AtomicBoolean(); - var future = sendRequests(targetShards, false, -1, (node, shardIds, aliasFilters, listener) -> { + var future = sendRequests(false, -1, targetShards, (node, shardIds, aliasFilters, listener) -> { sent.add(nodeRequest(node, shardIds)); if (node1.equals(node) && failed.compareAndSet(false, true)) { runWithDelay(() -> listener.onFailure(new IOException("test request level failure"), true)); @@ -241,7 +243,7 @@ public void testAllowPartialResults() { var targetShards = List.of(targetShard(shard1, node1), targetShard(shard2, node2), targetShard(shard3, node1, node2)); Queue sent = ConcurrentCollections.newQueue(); AtomicBoolean failed = new AtomicBoolean(); - var future = sendRequests(targetShards, true, -1, (node, shardIds, aliasFilters, listener) -> { + var future = sendRequests(true, -1, targetShards, (node, shardIds, aliasFilters, listener) -> { sent.add(nodeRequest(node, shardIds)); if (node1.equals(node) && failed.compareAndSet(false, true)) { runWithDelay(() -> listener.onFailure(new IOException("test request level failure"), true)); @@ -261,7 +263,7 @@ public void testAllowPartialResults() { public void testNonFatalErrorIsRetriedOnAnotherShard() { var targetShards = List.of(targetShard(shard1, node1, node2)); var sent = ConcurrentCollections.newQueue(); - var response = safeGet(sendRequests(targetShards, false, -1, (node, shardIds, aliasFilters, listener) -> { + var response = safeGet(sendRequests(false, -1, targetShards, (node, shardIds, aliasFilters, listener) -> { sent.add(nodeRequest(node, shardIds)); if (Objects.equals(node1, node)) { runWithDelay(() -> listener.onFailure(new RuntimeException("test request level non fatal failure"), false)); @@ -278,7 +280,7 @@ public void testNonFatalErrorIsRetriedOnAnotherShard() { public void testNonFatalFailedOnAllNodes() { var targetShards = List.of(targetShard(shard1, node1, node2)); var sent = ConcurrentCollections.newQueue(); - var future = sendRequests(targetShards, false, -1, (node, shardIds, aliasFilters, listener) -> { + var future = sendRequests(false, -1, targetShards, (node, shardIds, aliasFilters, listener) -> { sent.add(nodeRequest(node, shardIds)); runWithDelay(() -> listener.onFailure(new RuntimeException("test request level non fatal failure"), false)); }); @@ -289,7 +291,7 @@ public void testNonFatalFailedOnAllNodes() { public void testDoNotRetryCircuitBreakerException() { var targetShards = List.of(targetShard(shard1, node1, node2)); var sent = ConcurrentCollections.newQueue(); - var future = sendRequests(targetShards, false, -1, (node, shardIds, aliasFilters, listener) -> { + var future = sendRequests(false, -1, targetShards, (node, shardIds, aliasFilters, listener) -> { sent.add(nodeRequest(node, shardIds)); runWithDelay(() -> listener.onFailure(new CircuitBreakingException("cbe", randomFrom(Durability.values())), false)); }); @@ -310,7 +312,7 @@ public void testLimitConcurrentNodes() { AtomicInteger maxConcurrentRequests = new AtomicInteger(0); AtomicInteger concurrentRequests = new AtomicInteger(0); var sent = ConcurrentCollections.newQueue(); - var response = safeGet(sendRequests(targetShards, randomBoolean(), concurrency, (node, shardIds, aliasFilters, listener) -> { + var response = safeGet(sendRequests(randomBoolean(), concurrency, targetShards, (node, shardIds, aliasFilters, listener) -> { concurrentRequests.incrementAndGet(); while (true) { @@ -344,7 +346,7 @@ public void testSkipNodes() { ); AtomicInteger processed = new AtomicInteger(0); - var response = safeGet(sendRequests(targetShards, randomBoolean(), 1, (node, shardIds, aliasFilters, listener) -> { + var response = safeGet(sendRequests(randomBoolean(), 1, targetShards, (node, shardIds, aliasFilters, listener) -> { runWithDelay(() -> { if (processed.incrementAndGet() == 1) { listener.onResponse(new DataNodeComputeResponse(DriverCompletionInfo.EMPTY, Map.of())); @@ -363,7 +365,7 @@ public void testSkipRemovesPriorNonFatalErrors() { var targetShards = List.of(targetShard(shard1, node1, node2), targetShard(shard2, node3)); var sent = ConcurrentCollections.newQueue(); - var response = safeGet(sendRequests(targetShards, randomBoolean(), 1, (node, shardIds, aliasFilters, listener) -> { + var response = safeGet(sendRequests(randomBoolean(), 1, targetShards, (node, shardIds, aliasFilters, listener) -> { sent.add(nodeRequest(node, shardIds)); runWithDelay(() -> { if (Objects.equals(node.getId(), node1.getId()) && shardIds.equals(List.of(shard1))) { @@ -383,20 +385,31 @@ public void testSkipRemovesPriorNonFatalErrors() { } public void testQueryHotShardsFirst() { + var warnNode = DiscoveryNodeUtils.builder("node-2").roles(Set.of(DATA_WARM_NODE_ROLE)).build(); + var coldNode = DiscoveryNodeUtils.builder("node-3").roles(Set.of(DATA_COLD_NODE_ROLE)).build(); + var frozenNode = DiscoveryNodeUtils.builder("node-4").roles(Set.of(DATA_FROZEN_NODE_ROLE)).build(); var targetShards = shuffledList( List.of( targetShard(shard1, node1), - targetShard(shard2, DiscoveryNodeUtils.builder("node-2").roles(Set.of(DATA_WARM_NODE_ROLE)).build()), - targetShard(shard3, DiscoveryNodeUtils.builder("node-3").roles(Set.of(DATA_COLD_NODE_ROLE)).build()), - targetShard(shard4, DiscoveryNodeUtils.builder("node-4").roles(Set.of(DATA_FROZEN_NODE_ROLE)).build()) + targetShard(shard2, warnNode), + targetShard(shard3, coldNode), + targetShard(shard4, frozenNode) ) ); - var sent = Collections.synchronizedList(new ArrayList()); - safeGet(sendRequests(targetShards, randomBoolean(), -1, (node, shardIds, aliasFilters, listener) -> { - sent.add(node.getId()); + var sent = ConcurrentCollections.newQueue(); + safeGet(sendRequests(randomBoolean(), -1, targetShards, (node, shardIds, aliasFilters, listener) -> { + sent.add(nodeRequest(node, shardIds)); runWithDelay(() -> listener.onResponse(new DataNodeComputeResponse(DriverCompletionInfo.EMPTY, Map.of()))); })); - assertThat(sent, equalTo(List.of("node-1", "node-2", "node-3", "node-4"))); + assertThat( + sent, + contains( + nodeRequest(node1, shard1), + nodeRequest(warnNode, shard2), + nodeRequest(coldNode, shard3), + nodeRequest(frozenNode, shard4) + ) + ); } public void testQueryHotShardsFirstWhenIlmMovesShard() { @@ -405,14 +418,89 @@ public void testQueryHotShardsFirstWhenIlmMovesShard() { List.of(targetShard(shard1, node1), targetShard(shard2, shuffledList(List.of(node2, warmNode2)).toArray(DiscoveryNode[]::new))) ); var sent = ConcurrentCollections.newQueue(); - safeGet(sendRequests(targetShards, randomBoolean(), -1, (node, shardIds, aliasFilters, listener) -> { + safeGet(sendRequests(randomBoolean(), -1, targetShards, (node, shardIds, aliasFilters, listener) -> { sent.add(nodeRequest(node, shardIds)); runWithDelay(() -> listener.onResponse(new DataNodeComputeResponse(DriverCompletionInfo.EMPTY, Map.of()))); })); - assertThat(take(sent, 1), containsInAnyOrder(nodeRequest(node1, shard1))); + assertThat(take(sent, 1), contains(nodeRequest(node1, shard1))); assertThat(take(sent, 1), anyOf(contains(nodeRequest(node2, shard2)), contains(nodeRequest(warmNode2, shard2)))); } + public void testRetryMovedShard() { + var attempt = new AtomicInteger(0); + var response = safeGet( + sendRequests(randomBoolean(), -1, List.of(targetShard(shard1, node1)), shardIds -> switch (attempt.incrementAndGet()) { + case 1 -> Map.of(shard1, List.of(node2)); + case 2 -> Map.of(shard1, List.of(node3)); + default -> Map.of(shard1, List.of(node4)); + }, + (node, shardIds, aliasFilters, listener) -> runWithDelay( + () -> listener.onResponse( + Objects.equals(node, node4) + ? new DataNodeComputeResponse(DriverCompletionInfo.EMPTY, Map.of()) + : new DataNodeComputeResponse(DriverCompletionInfo.EMPTY, Map.of(shard1, new ShardNotFoundException(shard1))) + ) + ) + ) + ); + assertThat(response.totalShards, equalTo(1)); + assertThat(response.successfulShards, equalTo(1)); + assertThat(response.skippedShards, equalTo(0)); + assertThat(response.failedShards, equalTo(0)); + assertThat(attempt.get(), equalTo(3)); + } + + public void testDoesNotRetryMovedShardIndefinitely() { + var attempt = new AtomicInteger(0); + var response = safeGet(sendRequests(true, -1, List.of(targetShard(shard1, node1)), shardIds -> { + attempt.incrementAndGet(); + return Map.of(shard1, List.of(node2)); + }, + (node, shardIds, aliasFilters, listener) -> runWithDelay( + () -> listener.onResponse( + new DataNodeComputeResponse(DriverCompletionInfo.EMPTY, Map.of(shard1, new ShardNotFoundException(shard1))) + ) + ) + )); + assertThat(response.totalShards, equalTo(1)); + assertThat(response.successfulShards, equalTo(0)); + assertThat(response.skippedShards, equalTo(0)); + assertThat(response.failedShards, equalTo(1)); + assertThat(attempt.get(), equalTo(10)); + } + + public void testRetryOnlyMovedShards() { + var attempt = new AtomicInteger(0); + var resolvedShards = Collections.synchronizedSet(new HashSet<>()); + var response = safeGet( + sendRequests(randomBoolean(), -1, List.of(targetShard(shard1, node1, node3), targetShard(shard2, node2)), shardIds -> { + attempt.incrementAndGet(); + resolvedShards.addAll(shardIds); + return Map.of(shard2, List.of(node4)); + }, (node, shardIds, aliasFilters, listener) -> runWithDelay(() -> { + if (Objects.equals(node, node1)) { + // search is going to be retried from replica on node3 without shard resolution + listener.onResponse( + new DataNodeComputeResponse(DriverCompletionInfo.EMPTY, Map.of(shard1, new ShardNotFoundException(shard1))) + ); + } else if (Objects.equals(node, node2)) { + // search is going to be retried after resolving new shard node since there are no replicas + listener.onResponse( + new DataNodeComputeResponse(DriverCompletionInfo.EMPTY, Map.of(shard2, new ShardNotFoundException(shard2))) + ); + } else { + listener.onResponse(new DataNodeComputeResponse(DriverCompletionInfo.EMPTY, Map.of())); + } + })) + ); + assertThat(response.totalShards, equalTo(2)); + assertThat(response.successfulShards, equalTo(2)); + assertThat(response.skippedShards, equalTo(0)); + assertThat(response.failedShards, equalTo(0)); + assertThat(attempt.get(), equalTo(1)); + assertThat("Must retry only affected shards", resolvedShards, contains(shard2)); + } + static DataNodeRequestSender.TargetShard targetShard(ShardId shardId, DiscoveryNode... nodes) { return new DataNodeRequestSender.TargetShard(shardId, new ArrayList<>(Arrays.asList(nodes)), null); } @@ -444,9 +532,21 @@ void runWithDelay(Runnable runnable) { } PlainActionFuture sendRequests( + boolean allowPartialResults, + int concurrentRequests, List shards, + Sender sender + ) { + return sendRequests(allowPartialResults, concurrentRequests, shards, shardIds -> { + throw new AssertionError("No shard resolution is expected here"); + }, sender); + } + + PlainActionFuture sendRequests( boolean allowPartialResults, int concurrentRequests, + List shards, + Resolver resolver, Sender sender ) { PlainActionFuture future = new PlainActionFuture<>(); @@ -464,27 +564,35 @@ PlainActionFuture sendRequests( TaskId.EMPTY_TASK_ID, Collections.emptyMap() ); - DataNodeRequestSender requestSender = new DataNodeRequestSender( + new DataNodeRequestSender( + null, + null, transportService, executor, - "", task, + new OriginalIndices(new String[0], SearchRequest.DEFAULT_INDICES_OPTIONS), + null, + "", allowPartialResults, - concurrentRequests + concurrentRequests, + 10 ) { @Override - void searchShards( - QueryBuilder filter, - Set concreteIndices, - OriginalIndices originalIndices, - ActionListener listener - ) { - var targetShards = new TargetShards( - shards.stream().collect(Collectors.toMap(TargetShard::shardId, Function.identity())), - shards.size(), - 0 + void searchShards(Set concreteIndices, ActionListener listener) { + runWithDelay( + () -> listener.onResponse( + new TargetShards( + shards.stream().collect(Collectors.toMap(TargetShard::shardId, Function.identity())), + shards.size(), + 0 + ) + ) ); - runWithDelay(() -> listener.onResponse(targetShards)); + } + + @Override + Map> resolveShards(Set shardIds) { + return resolver.resolve(shardIds); } @Override @@ -496,23 +604,15 @@ protected void sendRequest( ) { sender.sendRequestToOneNode(node, shardIds, aliasFilters, listener); } - }; - requestSender.startComputeOnDataNodes( - Set.of(randomAlphaOfLength(10)), - new OriginalIndices(new String[0], SearchRequest.DEFAULT_INDICES_OPTIONS), - null, - () -> {}, - future - ); + }.startComputeOnDataNodes(Set.of(randomAlphaOfLength(10)), () -> {}, future); return future; } + interface Resolver { + Map> resolve(Set shardIds); + } + interface Sender { - void sendRequestToOneNode( - DiscoveryNode node, - List shardIds, - Map aliasFilters, - DataNodeRequestSender.NodeListener listener - ); + void sendRequestToOneNode(DiscoveryNode node, List shardIds, Map aliasFilters, NodeListener listener); } }