diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/SearchCancellationIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/SearchCancellationIT.java index c2feaa4e6fe9f..4771764a11b23 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/SearchCancellationIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/SearchCancellationIT.java @@ -29,7 +29,7 @@ import org.elasticsearch.script.ScriptType; import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder; -import org.elasticsearch.search.internal.ReaderContext; +import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.test.AbstractSearchCancellationTestCase; @@ -42,6 +42,7 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.index.query.QueryBuilders.matchAllQuery; import static org.elasticsearch.index.query.QueryBuilders.scriptQuery; @@ -240,80 +241,103 @@ public void testCancelMultiSearch() throws Exception { } public void testCancelFailedSearchWhenPartialResultDisallowed() throws Exception { - // TODO: make this test compatible with batched execution, currently the exceptions are slightly different with batched - updateClusterSettings(Settings.builder().put(SearchService.BATCHED_QUERY_PHASE.getKey(), false)); - // Have at least two nodes so that we have parallel execution of two request guaranteed even if max concurrent requests per node - // are limited to 1 - internalCluster().ensureAtLeastNumDataNodes(2); - int numberOfShards = between(2, 5); - createIndex("test", numberOfShards, 0); - indexTestData(); - - // Define (but don't run) the search 
request, expecting a partial shard failure. We will run it later. - Thread searchThread = new Thread(() -> { - logger.info("Executing search"); - SearchPhaseExecutionException e = expectThrows( - SearchPhaseExecutionException.class, - prepareSearch("test").setSearchType(SearchType.QUERY_THEN_FETCH) - .setQuery(scriptQuery(new Script(ScriptType.INLINE, "mockscript", SEARCH_BLOCK_SCRIPT_NAME, Collections.emptyMap()))) - .setAllowPartialSearchResults(false) - .setSize(1000) - ); - assertThat(e.getMessage(), containsString("Partial shards failure")); - }); - - // When the search request executes, block all shards except 1. - final List searchShardBlockingPlugins = initSearchShardBlockingPlugin(); - AtomicBoolean letOneShardProceed = new AtomicBoolean(); - // Ensure we have at least one task waiting on the latch - CountDownLatch waitingTaskLatch = new CountDownLatch(1); - CountDownLatch shardTaskLatch = new CountDownLatch(1); - for (SearchShardBlockingPlugin plugin : searchShardBlockingPlugins) { - plugin.setRunOnNewReaderContext((ReaderContext c) -> { - if (letOneShardProceed.compareAndSet(false, true)) { - // Let one shard continue. - } else { - // Signal that we have a task waiting on the latch - waitingTaskLatch.countDown(); - safeAwait(shardTaskLatch); // Block the other shards. - } + boolean useBatched = randomBoolean(); + try { + if (useBatched == false) { // It's true by default + updateClusterSettings(Settings.builder().put(SearchService.BATCHED_QUERY_PHASE.getKey(), false)); + } + // Have at least two nodes so that we have parallel execution of two requests guaranteed even if max concurrent requests per node + // are limited to 1 + internalCluster().ensureAtLeastNumDataNodes(2); + int numberOfShards = between(2, 5); + createIndex("test", numberOfShards, 0); + indexTestData(); + + // Define (but don't run) the search request, expecting a partial shard failure. We will run it later.
+ Thread searchThread = new Thread(() -> { + logger.info("Executing search"); + SearchPhaseExecutionException e = expectThrows( + SearchPhaseExecutionException.class, + prepareSearch("test").setSearchType(SearchType.QUERY_THEN_FETCH) + .setQuery( + scriptQuery(new Script(ScriptType.INLINE, "mockscript", SEARCH_BLOCK_SCRIPT_NAME, Collections.emptyMap())) + ) + .setAllowPartialSearchResults(false) + .setSize(1000) + ); + assertThat(e.getMessage(), containsString("Partial shards failure")); }); - } - // For the shard that was allowed to proceed, have a single query-execution thread throw an exception. - final List plugins = initBlockFactory(); - AtomicBoolean oneThreadWillError = new AtomicBoolean(); - for (ScriptedBlockPlugin plugin : plugins) { - plugin.disableBlock(); - plugin.setBeforeExecution(() -> { - if (oneThreadWillError.compareAndSet(false, true)) { - // wait for some task to get to the latch - safeAwait(waitingTaskLatch); - // then throw the exception - throw new IllegalStateException("This will cancel the ContextIndexSearcher.search task"); - } - }); - } + // When the search request executes, allow some shards to proceed and block others + final List searchShardBlockingPlugins = initSearchShardBlockingPlugin(); + CountDownLatch waitingTaskLatch = new CountDownLatch(1); + CountDownLatch shardTaskLatch = new CountDownLatch(1); + final AtomicReference selectedNodeId = new AtomicReference<>(); + final AtomicBoolean letOneShardProceed = new AtomicBoolean(); + for (SearchShardBlockingPlugin plugin : searchShardBlockingPlugins) { + plugin.setRunOnPreQueryPhase((SearchContext c) -> { + if (useBatched) { // Allow all the shards on one node to continue. Block all others. 
+ String nodeId = c.shardTarget().getNodeId(); + if (selectedNodeId.compareAndSet(null, nodeId) || nodeId.equals(selectedNodeId.get())) { + logger.info("Allowing shard [{}] on node [{}] to proceed", c.shardTarget().getShardId(), nodeId); + } else { + logger.info("Blocking shard [{}] on node [{}]", c.shardTarget().getShardId(), nodeId); + // Signal that we have a task waiting on the latch + waitingTaskLatch.countDown(); + safeAwait(shardTaskLatch); // Block shards on other nodes + } + } else { // Allow one shard to continue. Block all others. + if (letOneShardProceed.compareAndSet(false, true)) { + logger.info("Allowing shard [{}] to proceed", c.shardTarget().getShardId()); + } else { + logger.info("Blocking shard [{}]", c.shardTarget().getShardId()); + // Signal that we have a task waiting on the latch + waitingTaskLatch.countDown(); + safeAwait(shardTaskLatch); // Block all other shards + } + } + }); + } - // Now run the search request. - logger.info("Starting search thread"); - searchThread.start(); + // For the shards that were allowed to proceed, have a single query-execution thread throw an exception. 
+ final List plugins = initBlockFactory(); + AtomicBoolean oneThreadWillError = new AtomicBoolean(); + for (ScriptedBlockPlugin plugin : plugins) { + plugin.disableBlock(); + plugin.setBeforeExecution(() -> { + if (oneThreadWillError.compareAndSet(false, true)) { + // wait for some task to get to the latch + safeAwait(waitingTaskLatch); + // then throw the exception + throw new IllegalStateException("This will cancel the ContextIndexSearcher.search task"); + } + }); + } - try { - assertBusy(() -> { - final List coordinatorSearchTask = getCoordinatorSearchTasks(); - logger.info("Checking tasks: {}", coordinatorSearchTask); - assertThat("The Coordinator should have one SearchTask.", coordinatorSearchTask, hasSize(1)); - assertTrue("The SearchTask should be cancelled.", coordinatorSearchTask.get(0).isCancelled()); - for (var shardQueryTask : getShardQueryTasks()) { - assertTrue("All SearchShardTasks should then be cancelled", shardQueryTask.isCancelled()); - } - }, 30, TimeUnit.SECONDS); + // Now run the search request. + logger.info("Starting search thread"); + searchThread.start(); + + try { + assertBusy(() -> { + final List coordinatorSearchTask = getCoordinatorSearchTasks(); + logger.info("Checking tasks: {}", coordinatorSearchTask); + assertThat("The Coordinator should have one SearchTask.", coordinatorSearchTask, hasSize(1)); + assertTrue("The SearchTask should be cancelled.", coordinatorSearchTask.get(0).isCancelled()); + for (var shardQueryTask : getShardQueryTasks()) { + assertTrue("All SearchShardTasks should then be cancelled", shardQueryTask.isCancelled()); + } + }, 30, TimeUnit.SECONDS); + } finally { + shardTaskLatch.countDown(); // unblock the shardTasks, allowing the test to conclude. 
+ searchThread.join(); + plugins.forEach(plugin -> plugin.setBeforeExecution(() -> {})); + searchShardBlockingPlugins.forEach(plugin -> plugin.setRunOnPreQueryPhase((SearchContext c) -> {})); + } } finally { - shardTaskLatch.countDown(); // unblock the shardTasks, allowing the test to conclude. - searchThread.join(); - plugins.forEach(plugin -> plugin.setBeforeExecution(() -> {})); - searchShardBlockingPlugins.forEach(plugin -> plugin.setRunOnNewReaderContext((ReaderContext c) -> {})); + if (useBatched == false) { + updateClusterSettings(Settings.builder().putNull(SearchService.BATCHED_QUERY_PHASE.getKey())); + } } } diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java index fffbd26adce50..51406d8c9ad19 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java @@ -720,8 +720,8 @@ private void setFailure(QueryPerNodeState state, int dataNodeLocalIdx, Exception @Override public void onFailure(Exception e) { - // TODO: count down fully and just respond with an exception if partial results aren't allowed as an - // optimization + // Note: this shard won't be retried until it returns to the coordinating node where the shard iterator lives + // TODO: consider alternatives that don't wait for the entire batch to complete before retrying the shard setFailure(state, dataNodeLocalIdx, e); doneFuture.onResponse(null); } diff --git a/test/framework/src/main/java/org/elasticsearch/test/AbstractSearchCancellationTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/AbstractSearchCancellationTestCase.java index d83cab3c7c205..5ecb2f24acb32 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/AbstractSearchCancellationTestCase.java +++ 
b/test/framework/src/main/java/org/elasticsearch/test/AbstractSearchCancellationTestCase.java @@ -26,7 +26,7 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.script.MockScriptPlugin; import org.elasticsearch.search.SearchService; -import org.elasticsearch.search.internal.ReaderContext; +import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.lookup.LeafStoredFieldsLookup; import org.elasticsearch.tasks.TaskInfo; import org.junit.BeforeClass; @@ -279,10 +279,10 @@ protected List initSearchShardBlockingPlugin() { } public static class SearchShardBlockingPlugin extends Plugin { - private final AtomicReference> runOnNewReaderContext = new AtomicReference<>(); + private final AtomicReference> runOnPreQueryPhase = new AtomicReference<>(); - public void setRunOnNewReaderContext(Consumer consumer) { - runOnNewReaderContext.set(consumer); + public void setRunOnPreQueryPhase(Consumer consumer) { + runOnPreQueryPhase.set(consumer); } @Override @@ -290,9 +290,9 @@ public void onIndexModule(IndexModule indexModule) { super.onIndexModule(indexModule); indexModule.addSearchOperationListener(new SearchOperationListener() { @Override - public void onNewReaderContext(ReaderContext c) { - if (runOnNewReaderContext.get() != null) { - runOnNewReaderContext.get().accept(c); + public void onPreQueryPhase(SearchContext c) { + if (runOnPreQueryPhase.get() != null) { + runOnPreQueryPhase.get().accept(c); } } });