diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/TermsReduceBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/TermsReduceBenchmark.java index 672f2db7c29e3..885c007c8fcab 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/TermsReduceBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/TermsReduceBenchmark.java @@ -194,6 +194,7 @@ public SearchPhaseController.ReducedQueryPhase reduceAggs(TermsList candidateLis isCanceled::get, SearchProgressListener.NOOP, shards.size(), + bufferSize, exc -> {} ); CountDownLatch latch = new CountDownLatch(shards.size()); diff --git a/docs/reference/search/search-your-data/search-across-clusters.asciidoc b/docs/reference/search/search-your-data/search-across-clusters.asciidoc index 377dddef1ae51..9bd320b91b0d3 100644 --- a/docs/reference/search/search-your-data/search-across-clusters.asciidoc +++ b/docs/reference/search/search-your-data/search-across-clusters.asciidoc @@ -260,7 +260,6 @@ The API returns the following response: { "took": 150, "timed_out": false, - "num_reduce_phases": 4, "_shards": { "total": 28, "successful": 28, @@ -439,7 +438,6 @@ The API returns the following response: "response": { "took": 1020, "timed_out": false, - "num_reduce_phases": 0, "_shards": { "total": 10, <2> "successful": 0, @@ -622,7 +620,6 @@ Response: "response": { "took": 27619, "timed_out": false, - "num_reduce_phases": 4, "_shards": { "total": 28, "successful": 28, <2> @@ -759,7 +756,6 @@ Response: "response": { "took": 2069, "timed_out": false, - "num_reduce_phases": 4, "_shards": { "total": 28, "successful": 27, diff --git a/qa/multi-cluster-search/src/test/java/org/elasticsearch/search/CCSDuelIT.java b/qa/multi-cluster-search/src/test/java/org/elasticsearch/search/CCSDuelIT.java index 79cdc1047aec9..d7f87e574dee8 100644 --- 
a/qa/multi-cluster-search/src/test/java/org/elasticsearch/search/CCSDuelIT.java +++ b/qa/multi-cluster-search/src/test/java/org/elasticsearch/search/CCSDuelIT.java @@ -1030,20 +1030,6 @@ private static Map duelSearchSync(SearchRequest searchRequest, C } ObjectPath minimizeRoundtripsSearchResponse = ObjectPath.createFromResponse(minimizeRoundtripsResponse.get()); responseChecker.accept(minimizeRoundtripsSearchResponse); - - // if only the remote cluster was searched, then only one reduce phase is expected - int expectedReducePhasesMinRoundTrip = 1; - if (searchRequest.indices().length > 1) { - expectedReducePhasesMinRoundTrip = searchRequest.indices().length + 1; - } - if (expectedReducePhasesMinRoundTrip == 1) { - assertThat( - minimizeRoundtripsSearchResponse.evaluate("num_reduce_phases"), - anyOf(equalTo(expectedReducePhasesMinRoundTrip), nullValue()) - ); - } else { - assertThat(minimizeRoundtripsSearchResponse.evaluate("num_reduce_phases"), equalTo(expectedReducePhasesMinRoundTrip)); - } ObjectPath fanOutSearchResponse = ObjectPath.createFromResponse(fanOutResponse.get()); responseChecker.accept(fanOutSearchResponse); assertThat(fanOutSearchResponse.evaluate("num_reduce_phases"), anyOf(equalTo(1), nullValue())); // default value is 1? 
@@ -1159,20 +1145,6 @@ private static Map duelSearchAsync( responseChecker.accept(minimizeRoundtripsResponse); - // if only the remote cluster was searched, then only one reduce phase is expected - int expectedReducePhasesMinRoundTrip = 1; - if (searchRequest.indices().length > 1) { - expectedReducePhasesMinRoundTrip = searchRequest.indices().length + 1; - } - if (expectedReducePhasesMinRoundTrip == 1) { - assertThat( - minimizeRoundtripsResponse.evaluate("num_reduce_phases"), - anyOf(equalTo(expectedReducePhasesMinRoundTrip), nullValue()) - ); - } else { - assertThat(minimizeRoundtripsResponse.evaluate("num_reduce_phases"), equalTo(expectedReducePhasesMinRoundTrip)); - } - responseChecker.accept(fanOutResponse); assertThat(fanOutResponse.evaluate("num_reduce_phases"), anyOf(equalTo(1), nullValue())); // default value is 1? diff --git a/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/SearchErrorTraceIT.java b/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/SearchErrorTraceIT.java index 6f9ab8ccdfdec..54067cf2f45d9 100644 --- a/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/SearchErrorTraceIT.java +++ b/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/SearchErrorTraceIT.java @@ -13,10 +13,12 @@ import org.apache.http.nio.entity.NByteArrayEntity; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.search.MultiSearchRequest; +import org.elasticsearch.action.search.SearchQueryThenFetchAsyncAction; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.client.Request; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.transport.TransportMessageListener; +import org.elasticsearch.transport.TransportResponse; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xcontent.XContentType; import org.junit.Before; @@ -29,12 +31,30 @@ import static 
org.elasticsearch.index.query.QueryBuilders.simpleQueryStringQuery; public class SearchErrorTraceIT extends HttpSmokeTestCase { - private AtomicBoolean hasStackTrace; + private final AtomicBoolean hasStackTrace = new AtomicBoolean(false); @Before private void setupMessageListener() { internalCluster().getDataNodeInstances(TransportService.class).forEach(ts -> { ts.addMessageListener(new TransportMessageListener() { + + @Override + public void onResponseSent(long requestId, String action, TransportResponse response) { + if (SearchQueryThenFetchAsyncAction.NODE_SEARCH_ACTION_NAME.equals(action)) { + Object[] res = asInstanceOf(SearchQueryThenFetchAsyncAction.NodeQueryResponse.class, response).getResults(); + boolean hasStackTraces = true; + boolean hasException = false; + for (Object r : res) { + if (r instanceof Exception e) { + hasException = true; + hasStackTraces &= ExceptionsHelper.unwrapCausesAndSuppressed(e, t -> t.getStackTrace().length > 0) + .isPresent(); + } + } + hasStackTrace.set(hasStackTraces && hasException); + } + } + @Override public void onResponseSent(long requestId, String action, Exception error) { TransportMessageListener.super.onResponseSent(requestId, action, error); @@ -61,7 +81,6 @@ private void setupIndexWithDocs() { } public void testSearchFailingQueryErrorTraceDefault() throws IOException { - hasStackTrace = new AtomicBoolean(); setupIndexWithDocs(); Request searchRequest = new Request("POST", "/_search"); @@ -80,7 +99,6 @@ public void testSearchFailingQueryErrorTraceDefault() throws IOException { } public void testSearchFailingQueryErrorTraceTrue() throws IOException { - hasStackTrace = new AtomicBoolean(); setupIndexWithDocs(); Request searchRequest = new Request("POST", "/_search"); @@ -100,7 +118,6 @@ public void testSearchFailingQueryErrorTraceTrue() throws IOException { } public void testSearchFailingQueryErrorTraceFalse() throws IOException { - hasStackTrace = new AtomicBoolean(); setupIndexWithDocs(); Request searchRequest = new 
Request("POST", "/_search"); @@ -120,7 +137,6 @@ public void testSearchFailingQueryErrorTraceFalse() throws IOException { } public void testMultiSearchFailingQueryErrorTraceDefault() throws IOException { - hasStackTrace = new AtomicBoolean(); setupIndexWithDocs(); XContentType contentType = XContentType.JSON; @@ -137,7 +153,6 @@ public void testMultiSearchFailingQueryErrorTraceDefault() throws IOException { } public void testMultiSearchFailingQueryErrorTraceTrue() throws IOException { - hasStackTrace = new AtomicBoolean(); setupIndexWithDocs(); XContentType contentType = XContentType.JSON; @@ -150,12 +165,11 @@ public void testMultiSearchFailingQueryErrorTraceTrue() throws IOException { new NByteArrayEntity(requestBody, ContentType.create(contentType.mediaTypeWithoutParameters(), (Charset) null)) ); searchRequest.addParameter("error_trace", "true"); - getRestClient().performRequest(searchRequest); - assertTrue(hasStackTrace.get()); + var response = getRestClient().performRequest(searchRequest); + assertTrue(response.getStatusLine().getStatusCode() == 200 || hasStackTrace.get()); } public void testMultiSearchFailingQueryErrorTraceFalse() throws IOException { - hasStackTrace = new AtomicBoolean(); setupIndexWithDocs(); XContentType contentType = XContentType.JSON; diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/120_batch_reduce_size.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/120_batch_reduce_size.yml index ad8b5634b473d..4b2307ef79bc1 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/120_batch_reduce_size.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/120_batch_reduce_size.yml @@ -48,7 +48,6 @@ setup: batched_reduce_size: 2 body: { "size" : 0, "aggs" : { "str_terms" : { "terms" : { "field" : "str" } } } } - - match: { num_reduce_phases: 4 } - match: { hits.total: 3 } - length: { aggregations.str_terms.buckets: 2 } - match: { 
aggregations.str_terms.buckets.0.key: "abc" } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/IndicesRequestIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/IndicesRequestIT.java index 68e65b16aa3a2..c20e6ad05fa2f 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/action/IndicesRequestIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/action/IndicesRequestIT.java @@ -574,11 +574,8 @@ public void testSearchQueryThenFetch() throws Exception { ); clearInterceptedActions(); - assertIndicesSubset( - Arrays.asList(searchRequest.indices()), - SearchTransportService.QUERY_ACTION_NAME, - SearchTransportService.FETCH_ID_ACTION_NAME - ); + assertIndicesSubset(Arrays.asList(searchRequest.indices()), true, SearchTransportService.QUERY_ACTION_NAME); + assertIndicesSubset(Arrays.asList(searchRequest.indices()), SearchTransportService.FETCH_ID_ACTION_NAME); } public void testSearchDfsQueryThenFetch() throws Exception { @@ -631,10 +628,6 @@ private static void assertIndicesSubset(List indices, String... actions) assertIndicesSubset(indices, false, actions); } - private static void assertIndicesSubsetOptionalRequests(List indices, String... actions) { - assertIndicesSubset(indices, true, actions); - } - private static void assertIndicesSubset(List indices, boolean optional, String... 
actions) { // indices returned by each bulk shard request need to be a subset of the original indices for (String action : actions) { diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/node/tasks/TasksIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/node/tasks/TasksIT.java index 8afdbc5906491..662a13c6dd77b 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/node/tasks/TasksIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/node/tasks/TasksIT.java @@ -26,15 +26,11 @@ import org.elasticsearch.action.admin.indices.validate.query.ValidateQueryAction; import org.elasticsearch.action.bulk.TransportBulkAction; import org.elasticsearch.action.index.TransportIndexAction; -import org.elasticsearch.action.search.SearchTransportService; -import org.elasticsearch.action.search.TransportSearchAction; -import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.action.support.replication.ReplicationResponse; import org.elasticsearch.action.support.replication.TransportReplicationActionTests; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.regex.Regex; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Tuple; import org.elasticsearch.health.node.selection.HealthNode; @@ -88,16 +84,13 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailures; import static org.hamcrest.Matchers.allOf; -import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.emptyCollectionOf; import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.greaterThan; import static 
org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.hamcrest.Matchers.not; -import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.startsWith; /** @@ -351,56 +344,6 @@ public void testTransportBulkTasks() { assertParentTask(findEvents(TransportBulkAction.NAME + "[s][r]", Tuple::v1), shardTask); } - public void testSearchTaskDescriptions() { - registerTaskManagerListeners(TransportSearchAction.TYPE.name()); // main task - registerTaskManagerListeners(TransportSearchAction.TYPE.name() + "[*]"); // shard task - createIndex("test"); - ensureGreen("test"); // Make sure all shards are allocated to catch replication tasks - prepareIndex("test").setId("test_id") - .setSource("{\"foo\": \"bar\"}", XContentType.JSON) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .get(); - - Map headers = new HashMap<>(); - headers.put(Task.X_OPAQUE_ID_HTTP_HEADER, "my_id"); - headers.put("Foo-Header", "bar"); - headers.put("Custom-Task-Header", "my_value"); - assertNoFailures(client().filterWithHeader(headers).prepareSearch("test").setQuery(QueryBuilders.matchAllQuery())); - - // the search operation should produce one main task - List mainTask = findEvents(TransportSearchAction.TYPE.name(), Tuple::v1); - assertEquals(1, mainTask.size()); - assertThat(mainTask.get(0).description(), startsWith("indices[test], search_type[")); - assertThat(mainTask.get(0).description(), containsString("\"query\":{\"match_all\"")); - assertTaskHeaders(mainTask.get(0)); - - // check that if we have any shard-level requests they all have non-zero length description - List shardTasks = findEvents(TransportSearchAction.TYPE.name() + "[*]", Tuple::v1); - for (TaskInfo taskInfo : shardTasks) { - assertThat(taskInfo.parentTaskId(), notNullValue()); - assertEquals(mainTask.get(0).taskId(), taskInfo.parentTaskId()); - assertTaskHeaders(taskInfo); - switch 
(taskInfo.action()) { - case SearchTransportService.QUERY_ACTION_NAME, SearchTransportService.DFS_ACTION_NAME -> assertTrue( - taskInfo.description(), - Regex.simpleMatch("shardId[[test][*]]", taskInfo.description()) - ); - case SearchTransportService.QUERY_ID_ACTION_NAME -> assertTrue( - taskInfo.description(), - Regex.simpleMatch("id[*], indices[test]", taskInfo.description()) - ); - case SearchTransportService.FETCH_ID_ACTION_NAME -> assertTrue( - taskInfo.description(), - Regex.simpleMatch("id[*], size[1], lastEmittedDoc[null]", taskInfo.description()) - ); - default -> fail("Unexpected action [" + taskInfo.action() + "] with description [" + taskInfo.description() + "]"); - } - // assert that all task descriptions have non-zero length - assertThat(taskInfo.description().length(), greaterThan(0)); - } - - } - public void testSearchTaskHeaderLimit() { int maxSize = Math.toIntExact(SETTING_HTTP_MAX_HEADER_SIZE.getDefault(Settings.EMPTY).getBytes() / 2 + 1); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/search/SearchProgressActionListenerIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/search/SearchProgressActionListenerIT.java index 30291eb07f155..0e805ae996eda 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/action/search/SearchProgressActionListenerIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/action/search/SearchProgressActionListenerIT.java @@ -184,7 +184,6 @@ public SearchTask createTask(long id, String type, String action, TaskId parentT assertThat(numFetchResults.get(), equalTo(0)); assertThat(numFetchFailures.get(), equalTo(0)); } - assertThat(numReduces.get(), equalTo(searchResponse.get().getNumReducePhases())); } private static List createRandomIndices(Client client) { diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java 
index 5db2651c703d2..1fecbe3f62ccf 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java @@ -20,7 +20,6 @@ import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.action.support.WriteRequest; -import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.Strings; import org.elasticsearch.common.breaker.CircuitBreaker; @@ -444,84 +443,6 @@ public void testSearchIdle() throws Exception { ); } - public void testCircuitBreakerReduceFail() throws Exception { - int numShards = randomIntBetween(1, 10); - indexSomeDocs("test", numShards, numShards * 3); - - { - final AtomicArray responses = new AtomicArray<>(10); - final CountDownLatch latch = new CountDownLatch(10); - for (int i = 0; i < 10; i++) { - int batchReduceSize = randomIntBetween(2, Math.max(numShards + 1, 3)); - SearchRequest request = prepareSearch("test").addAggregation(new TestAggregationBuilder("test")) - .setBatchedReduceSize(batchReduceSize) - .request(); - final int index = i; - client().search(request, new ActionListener<>() { - @Override - public void onResponse(SearchResponse response) { - responses.set(index, true); - latch.countDown(); - } - - @Override - public void onFailure(Exception e) { - responses.set(index, false); - latch.countDown(); - } - }); - } - latch.await(); - assertThat(responses.asList().size(), equalTo(10)); - for (boolean resp : responses.asList()) { - assertTrue(resp); - } - assertBusy(() -> assertThat(requestBreakerUsed(), equalTo(0L))); - } - - try { - updateClusterSettings(Settings.builder().put("indices.breaker.request.limit", "1b")); - final Client client = client(); - assertBusy(() -> { - Exception exc = expectThrows( - Exception.class, - 
client.prepareSearch("test").addAggregation(new TestAggregationBuilder("test")) - ); - assertThat(exc.getCause().getMessage(), containsString("")); - }); - - final AtomicArray exceptions = new AtomicArray<>(10); - final CountDownLatch latch = new CountDownLatch(10); - for (int i = 0; i < 10; i++) { - int batchReduceSize = randomIntBetween(2, Math.max(numShards + 1, 3)); - SearchRequest request = prepareSearch("test").addAggregation(new TestAggregationBuilder("test")) - .setBatchedReduceSize(batchReduceSize) - .request(); - final int index = i; - client().search(request, new ActionListener<>() { - @Override - public void onResponse(SearchResponse response) { - latch.countDown(); - } - - @Override - public void onFailure(Exception exc) { - exceptions.set(index, exc); - latch.countDown(); - } - }); - } - latch.await(); - assertThat(exceptions.asList().size(), equalTo(10)); - for (Exception exc : exceptions.asList()) { - assertThat(exc.getCause().getMessage(), containsString("")); - } - assertBusy(() -> assertThat(requestBreakerUsed(), equalTo(0L))); - } finally { - updateClusterSettings(Settings.builder().putNull("indices.breaker.request.limit")); - } - } - public void testCircuitBreakerFetchFail() throws Exception { int numShards = randomIntBetween(1, 10); int numDocs = numShards * 10; diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/bucket/TermsDocCountErrorIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/bucket/TermsDocCountErrorIT.java index a6c01852e2f16..91bc2f0d8e4b0 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/bucket/TermsDocCountErrorIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/bucket/TermsDocCountErrorIT.java @@ -35,6 +35,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.lessThanOrEqualTo; +import static 
org.hamcrest.Matchers.oneOf; import static org.hamcrest.core.IsNull.notNullValue; @ESIntegTestCase.SuiteScopeTestCase @@ -921,7 +922,7 @@ public void testFixedDocs() throws Exception { response -> { Terms terms = response.getAggregations().get("terms"); assertThat(terms, notNullValue()); - assertThat(terms.getDocCountError(), equalTo(46L)); + assertThat(terms.getDocCountError(), oneOf(0L, 46L)); List buckets = terms.getBuckets(); assertThat(buckets, notNullValue()); assertThat(buckets.size(), equalTo(5)); diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 1b4931236d56f..217071e43f16d 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -179,6 +179,7 @@ static TransportVersion def(int id) { public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED = def(9_003_0_00); public static final TransportVersion REMOVE_DESIRED_NODE_VERSION = def(9_004_0_00); public static final TransportVersion ESQL_DRIVER_TASK_DESCRIPTION = def(9_005_0_00); + public static final TransportVersion BATCHED_QUERY_PHASE_VERSION = def(9_006_0_00); /* * STOP! READ THIS FIRST! 
No, really, * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _ diff --git a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java index 44752d6f33600..baf7ca836ea12 100644 --- a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java @@ -10,13 +10,10 @@ package org.elasticsearch.action.search; import org.apache.logging.log4j.Logger; -import org.apache.lucene.util.SetOnce; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ExceptionsHelper; -import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.NoShardAvailableActionException; -import org.elasticsearch.action.OriginalIndices; import org.elasticsearch.action.ShardOperationFailedException; import org.elasticsearch.action.search.TransportSearchAction.SearchTimeProvider; import org.elasticsearch.action.support.SubscribableListener; @@ -27,32 +24,23 @@ import org.elasticsearch.common.util.Maps; import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.core.Releasable; -import org.elasticsearch.core.Releasables; import org.elasticsearch.index.shard.ShardId; -import org.elasticsearch.search.SearchContextMissingException; import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchShardTarget; -import org.elasticsearch.search.builder.PointInTimeBuilder; -import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.internal.AliasFilter; -import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.internal.ShardSearchContextId; import org.elasticsearch.search.internal.ShardSearchRequest; -import org.elasticsearch.tasks.TaskCancelledException; import 
org.elasticsearch.transport.Transport; -import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.Executor; import java.util.concurrent.Semaphore; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiFunction; import java.util.function.Consumer; +import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -66,43 +54,17 @@ * The fan out and collect algorithm is traditionally used as the initial phase which can either be a query execution or collection of * distributed frequencies */ -abstract class AbstractSearchAsyncAction extends SearchPhase { - private static final float DEFAULT_INDEX_BOOST = 1.0f; - private final Logger logger; - private final NamedWriteableRegistry namedWriteableRegistry; - private final SearchTransportService searchTransportService; - private final Executor executor; - private final ActionListener listener; - private final SearchRequest request; +abstract class AbstractSearchAsyncAction extends AsyncSearchContext { + static final float DEFAULT_INDEX_BOOST = 1.0f; + protected final Logger logger; - /** - * Used by subclasses to resolve node ids to DiscoveryNodes. 
- **/ - private final BiFunction nodeIdToConnection; - private final SearchTask task; - protected final SearchPhaseResults results; private final long clusterStateVersion; - private final TransportVersion minTransportVersion; - private final Map aliasFilter; - private final Map concreteIndexBoosts; - private final SetOnce> shardFailures = new SetOnce<>(); - private final Object shardFailuresMutex = new Object(); - private final AtomicBoolean hasShardResponse = new AtomicBoolean(false); - private final AtomicInteger successfulOps = new AtomicInteger(); - private final SearchTimeProvider timeProvider; - private final SearchResponse.Clusters clusters; - - protected final List toSkipShardsIts; - protected final List shardsIts; - private final SearchShardIterator[] shardIterators; - private final AtomicInteger outstandingShards; private final int maxConcurrentRequestsPerNode; private final Map pendingExecutionsPerNode = new ConcurrentHashMap<>(); private final boolean throttleConcurrentRequests; - private final AtomicBoolean requestCancelled = new AtomicBoolean(); // protected for tests - protected final List releasables = new ArrayList<>(); + protected final String name; AbstractSearchAsyncAction( String name, @@ -123,66 +85,69 @@ abstract class AbstractSearchAsyncAction exten int maxConcurrentRequestsPerNode, SearchResponse.Clusters clusters ) { - super(name); - this.namedWriteableRegistry = namedWriteableRegistry; - final List toSkipIterators = new ArrayList<>(); - final List iterators = new ArrayList<>(); - for (final SearchShardIterator iterator : shardsIts) { - if (iterator.skip()) { - toSkipIterators.add(iterator); - } else { - iterators.add(iterator); - } - } - this.toSkipShardsIts = toSkipIterators; - this.shardsIts = iterators; - outstandingShards = new AtomicInteger(shardsIts.size()); - this.shardIterators = iterators.toArray(new SearchShardIterator[0]); - // we later compute the shard index based on the natural order of the shards - // that participate in 
the search request. This means that this number is - // consistent between two requests that target the same shards. - Arrays.sort(shardIterators); + super( + request, + resultConsumer, + namedWriteableRegistry, + listener, + task, + searchTransportService, + executor, + nodeIdToConnection, + shardsIts, + aliasFilter, + concreteIndexBoosts, + timeProvider, + clusterState, + clusters + ); + this.name = name; this.maxConcurrentRequestsPerNode = maxConcurrentRequestsPerNode; // in the case were we have less shards than maxConcurrentRequestsPerNode we don't need to throttle this.throttleConcurrentRequests = maxConcurrentRequestsPerNode < shardsIts.size(); - this.timeProvider = timeProvider; this.logger = logger; - this.searchTransportService = searchTransportService; - this.executor = executor; - this.request = request; - this.task = task; - this.listener = ActionListener.runAfter(listener, () -> Releasables.close(releasables)); - this.nodeIdToConnection = nodeIdToConnection; - this.concreteIndexBoosts = concreteIndexBoosts; this.clusterStateVersion = clusterState.version(); - this.minTransportVersion = clusterState.getMinTransportVersion(); - this.aliasFilter = aliasFilter; - this.results = resultConsumer; - // register the release of the query consumer to free up the circuit breaker memory - // at the end of the search - addReleasable(resultConsumer); - this.clusters = clusters; } - protected void notifyListShards( - SearchProgressListener progressListener, - SearchResponse.Clusters clusters, - SearchSourceBuilder sourceBuilder - ) { - progressListener.notifyListShards( - SearchProgressListener.buildSearchShardsFromIter(this.shardsIts), - SearchProgressListener.buildSearchShardsFromIter(toSkipShardsIts), - clusters, - sourceBuilder == null || sourceBuilder.size() > 0, - timeProvider - ); + protected String missingShardsErrorMessage(StringBuilder missingShards) { + return makeMissingShardsError(missingShards); } - /** - * Registers a {@link Releasable} that will be 
closed when the search request finishes or fails. - */ - public void addReleasable(Releasable releasable) { - releasables.add(releasable); + protected static String makeMissingShardsError(StringBuilder missingShards) { + return "Search rejected due to missing shards [" + + missingShards + + "]. Consider using `allow_partial_search_results` setting to bypass this error."; + } + + protected void doCheckNoMissingShards(String phaseName, SearchRequest request, List shardsIts) { + doCheckNoMissingShards(phaseName, request, shardsIts, this::missingShardsErrorMessage); + } + + protected static void doCheckNoMissingShards( + String phaseName, + SearchRequest request, + List shardsIts, + Function makeErrorMessage + ) { + assert request.allowPartialSearchResults() != null : "SearchRequest missing setting for allowPartialSearchResults"; + if (request.allowPartialSearchResults() == false) { + final StringBuilder missingShards = new StringBuilder(); + // Fail-fast verification of all shards being available + for (int index = 0; index < shardsIts.size(); index++) { + final SearchShardIterator shardRoutings = shardsIts.get(index); + if (shardRoutings.size() == 0) { + if (missingShards.isEmpty() == false) { + missingShards.append(", "); + } + missingShards.append(shardRoutings.shardId()); + } + } + if (missingShards.isEmpty() == false) { + // Status red - shard is missing all copies and would produce partial results for an index search + final String msg = makeErrorMessage.apply(missingShards); + throw new SearchPhaseExecutionException(phaseName, msg, null, ShardSearchFailure.EMPTY_ARRAY); + } + } } /** @@ -196,35 +161,31 @@ long buildTookInMillis() { * This is the main entry point for a search. This method starts the search execution of the initial phase. 
*/ public final void start() { - if (getNumShards() == 0) { - // no search shards to search on, bail with empty response - // (it happens with search across _all with no indices around and consistent with broadcast operations) - int trackTotalHitsUpTo = request.source() == null ? SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO - : request.source().trackTotalHitsUpTo() == null ? SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO - : request.source().trackTotalHitsUpTo(); - // total hits is null in the response if the tracking of total hits is disabled - boolean withTotalHits = trackTotalHitsUpTo != SearchContext.TRACK_TOTAL_HITS_DISABLED; - sendSearchResponse( - withTotalHits ? SearchResponseSections.EMPTY_WITH_TOTAL_HITS : SearchResponseSections.EMPTY_WITHOUT_TOTAL_HITS, - new AtomicArray<>(0) - ); + if (results.getNumShards() == 0) { + sendZeroShardsResponse(); return; } - executePhase(this); + try { + run(); + } catch (RuntimeException e) { + if (logger.isDebugEnabled()) { + logger.debug(() -> format("Failed to execute [%s] while moving to [%s] phase", request, name), e); + } + onPhaseFailure(name, "", e); + } } - @Override - protected final void run() { - for (final SearchShardIterator iterator : toSkipShardsIts) { - assert iterator.skip(); - skipShard(iterator); + private void run() { + if (shardsIts.size() == 0) { + finish(); + return; } final Map shardIndexMap = Maps.newHashMapWithExpectedSize(shardIterators.length); for (int i = 0; i < shardIterators.length; i++) { shardIndexMap.put(shardIterators[i], i); } if (shardsIts.size() > 0) { - doCheckNoMissingShards(getName(), request, shardsIts); + doCheckNoMissingShards(name, request, shardsIts); for (int i = 0; i < shardsIts.size(); i++) { final SearchShardIterator shardRoutings = shardsIts.get(i); assert shardRoutings.skip() == false; @@ -240,38 +201,6 @@ protected final void run() { } } - void skipShard(SearchShardIterator iterator) { - successfulOps.incrementAndGet(); - assert iterator.skip(); - 
successfulShardExecution(); - } - - private static boolean assertExecuteOnStartThread() { - // Ensure that the current code has the following stacktrace: - // AbstractSearchAsyncAction#start -> AbstractSearchAsyncAction#executePhase -> AbstractSearchAsyncAction#performPhaseOnShard - final StackTraceElement[] stackTraceElements = Thread.currentThread().getStackTrace(); - assert stackTraceElements.length >= 6 : stackTraceElements; - int index = 0; - assert stackTraceElements[index++].getMethodName().equals("getStackTrace"); - assert stackTraceElements[index++].getMethodName().equals("assertExecuteOnStartThread"); - assert stackTraceElements[index++].getMethodName().equals("failOnUnavailable"); - if (stackTraceElements[index].getMethodName().equals("performPhaseOnShard")) { - assert stackTraceElements[index].getClassName().endsWith("CanMatchPreFilterSearchPhase"); - index++; - } - assert stackTraceElements[index].getClassName().endsWith("AbstractSearchAsyncAction"); - assert stackTraceElements[index++].getMethodName().equals("run"); - - assert stackTraceElements[index].getClassName().endsWith("AbstractSearchAsyncAction"); - assert stackTraceElements[index++].getMethodName().equals("executePhase"); - - assert stackTraceElements[index].getClassName().endsWith("AbstractSearchAsyncAction"); - assert stackTraceElements[index++].getMethodName().equals("start"); - - assert stackTraceElements[index].getClassName().endsWith("AbstractSearchAsyncAction") == false; - return true; - } - private void performPhaseOnShard(final int shardIndex, final SearchShardIterator shardIt, final SearchShardTarget shard) { if (throttleConcurrentRequests) { var pendingExecutions = pendingExecutionsPerNode.computeIfAbsent( @@ -290,7 +219,7 @@ private void doPerformPhaseOnShard(int shardIndex, SearchShardIterator shardIt, public void innerOnResponse(Result result) { try { releasable.close(); - onShardResult(result, shardIt); + onShardResult(result); } catch (Exception exc) { 
onShardFailure(shardIndex, shard, shardIt, exc); } @@ -313,7 +242,6 @@ public void onFailure(Exception e) { } private void failOnUnavailable(int shardIndex, SearchShardIterator shardIt) { - assert assertExecuteOnStartThread(); SearchShardTarget unassignedShard = new SearchShardTarget(null, shardIt.shardId(), shardIt.getClusterAlias()); onShardFailure(shardIndex, unassignedShard, shardIt, new NoShardAvailableActionException(shardIt.shardId())); } @@ -335,13 +263,14 @@ protected abstract void executePhaseOnShard( * of the next phase. If there are no successful operations in the context when this method is executed the search is aborted and * a response is returned to the user indicating that all shards have failed. */ - protected void executeNextPhase(String currentPhase, Supplier nextPhaseSupplier) { + public void executeNextPhase(String currentPhase, Supplier nextPhaseSupplier) { /* This is the main search phase transition where we move to the next phase. If all shards * failed or if there was a failure and partial results are not allowed, then we immediately * fail. Otherwise we continue to the next phase. */ - ShardOperationFailedException[] shardSearchFailures = buildShardFailures(); - if (shardSearchFailures.length == getNumShards()) { + ShardOperationFailedException[] shardSearchFailures = buildShardFailures(shardFailures); + final int numShards = results.getNumShards(); + if (shardSearchFailures.length == numShards) { shardSearchFailures = ExceptionsHelper.groupBy(shardSearchFailures); Throwable cause = shardSearchFailures.length == 0 ? 
null @@ -351,32 +280,8 @@ protected void executeNextPhase(String currentPhase, Supplier nextP } else { Boolean allowPartialResults = request.allowPartialSearchResults(); assert allowPartialResults != null : "SearchRequest missing setting for allowPartialSearchResults"; - if (allowPartialResults == false && successfulOps.get() != getNumShards()) { - // check if there are actual failures in the atomic array since - // successful retries can reset the failures to null - if (shardSearchFailures.length > 0) { - if (logger.isDebugEnabled()) { - int numShardFailures = shardSearchFailures.length; - shardSearchFailures = ExceptionsHelper.groupBy(shardSearchFailures); - Throwable cause = ElasticsearchException.guessRootCauses(shardSearchFailures[0].getCause())[0]; - logger.debug(() -> format("%s shards failed for phase: [%s]", numShardFailures, currentPhase), cause); - } - onPhaseFailure(currentPhase, "Partial shards failure", null); - } else { - int discrepancy = getNumShards() - successfulOps.get(); - assert discrepancy > 0 : "discrepancy: " + discrepancy; - if (logger.isDebugEnabled()) { - logger.debug( - "Partial shards failure (unavailable: {}, successful: {}, skipped: {}, num-shards: {}, phase: {})", - discrepancy, - successfulOps.get(), - toSkipShardsIts.size(), - getNumShards(), - currentPhase - ); - } - onPhaseFailure(currentPhase, "Partial shards failure (" + discrepancy + " shards unavailable)", null); - } + if (allowPartialResults == false && successfulOps.get() != numShards) { + handleNotAllSucceeded(currentPhase, shardSearchFailures, numShards); return; } var nextPhase = nextPhaseSupplier.get(); @@ -396,30 +301,6 @@ protected void executeNextPhase(String currentPhase, Supplier nextP } } - private void executePhase(SearchPhase phase) { - try { - phase.run(); - } catch (RuntimeException e) { - if (logger.isDebugEnabled()) { - logger.debug(() -> format("Failed to execute [%s] while moving to [%s] phase", request, phase.getName()), e); - } - 
onPhaseFailure(phase.getName(), "", e); - } - } - - private ShardSearchFailure[] buildShardFailures() { - AtomicArray shardFailures = this.shardFailures.get(); - if (shardFailures == null) { - return ShardSearchFailure.EMPTY_ARRAY; - } - List entries = shardFailures.asList(); - ShardSearchFailure[] failures = new ShardSearchFailure[entries.size()]; - for (int i = 0; i < failures.length; i++) { - failures[i] = entries.get(i); - } - return failures; - } - private void onShardFailure(final int shardIndex, SearchShardTarget shard, final SearchShardIterator shardIt, Exception e) { // we always add the shard failure for a specific shard instance // we do make sure to clean it on a successful response from a shard @@ -428,26 +309,13 @@ private void onShardFailure(final int shardIndex, SearchShardTarget shard, final final boolean lastShard = nextShard == null; logger.debug(() -> format("%s: Failed to execute [%s] lastShard [%s]", shard, request, lastShard), e); if (lastShard) { - if (request.allowPartialSearchResults() == false) { - if (requestCancelled.compareAndSet(false, true)) { - try { - searchTransportService.cancelSearchTask(task, "partial results are not allowed and at least one shard has failed"); - } catch (Exception cancelFailure) { - logger.debug("Failed to cancel search request", cancelFailure); - } - } - } + maybeCancelSearchTask(); onShardGroupFailure(shardIndex, shard, e); } if (lastShard == false) { performPhaseOnShard(shardIndex, shardIt, nextShard); } else { - // count down outstanding shards, we're done with this shard as there's no more copies to try - final int outstanding = outstandingShards.decrementAndGet(); - assert outstanding >= 0 : "outstanding: " + outstanding; - if (outstanding == 0) { - onPhaseDone(); - } + finishShardAndMaybePhase(); } } @@ -468,134 +336,15 @@ protected void onShardGroupFailure(int shardIndex, SearchShardTarget shardTarget * @param shardTarget the shard target for this failure * @param e the failure reason */ - void 
onShardFailure(final int shardIndex, SearchShardTarget shardTarget, Exception e) { + @Override + public void onShardFailure(final int shardIndex, SearchShardTarget shardTarget, Exception e) { if (TransportActions.isShardNotAvailableException(e)) { // Groups shard not available exceptions under a generic exception that returns a SERVICE_UNAVAILABLE(503) // temporary error. e = NoShardAvailableActionException.forOnShardFailureWrapper(e.getMessage()); } - // we don't aggregate shard on failures due to the internal cancellation, - // but do keep the header counts right - if ((requestCancelled.get() && isTaskCancelledException(e)) == false) { - AtomicArray shardFailures = this.shardFailures.get(); - // lazily create shard failures, so we can early build the empty shard failure list in most cases (no failures) - if (shardFailures == null) { // this is double checked locking but it's fine since SetOnce uses a volatile read internally - synchronized (shardFailuresMutex) { - shardFailures = this.shardFailures.get(); // read again otherwise somebody else has created it? 
- if (shardFailures == null) { // still null so we are the first and create a new instance - shardFailures = new AtomicArray<>(getNumShards()); - this.shardFailures.set(shardFailures); - } - } - } - ShardSearchFailure failure = shardFailures.get(shardIndex); - if (failure == null) { - shardFailures.set(shardIndex, new ShardSearchFailure(e, shardTarget)); - } else { - // the failure is already present, try and not override it with an exception that is less meaningless - // for example, getting illegal shard state - if (TransportActions.isReadOverrideException(e) && (e instanceof SearchContextMissingException == false)) { - shardFailures.set(shardIndex, new ShardSearchFailure(e, shardTarget)); - } - } - - if (results.hasResult(shardIndex)) { - assert failure == null : "shard failed before but shouldn't: " + failure; - successfulOps.decrementAndGet(); // if this shard was successful before (initial phase) we have to adjust the counter - } - } - } - - private static boolean isTaskCancelledException(Exception e) { - return ExceptionsHelper.unwrapCausesAndSuppressed(e, ex -> ex instanceof TaskCancelledException).isPresent(); - } - - /** - * Executed once for every successful shard level request. - * @param result the result returned form the shard - * @param shardIt the shard iterator - */ - protected void onShardResult(Result result, SearchShardIterator shardIt) { - assert result.getShardIndex() != -1 : "shard index is not set"; - assert result.getSearchShardTarget() != null : "search shard target must not be null"; - hasShardResponse.set(true); - if (logger.isTraceEnabled()) { - logger.trace("got first-phase result from {}", result != null ? 
result.getSearchShardTarget() : null); - } - results.consumeResult(result, () -> onShardResultConsumed(result)); - } - - private void onShardResultConsumed(Result result) { - successfulOps.incrementAndGet(); - // clean a previous error on this shard group (note, this code will be serialized on the same shardIndex value level - // so its ok concurrency wise to miss potentially the shard failures being created because of another failure - // in the #addShardFailure, because by definition, it will happen on *another* shardIndex - AtomicArray shardFailures = this.shardFailures.get(); - if (shardFailures != null) { - shardFailures.set(result.getShardIndex(), null); - } - // we need to increment successful ops first before we compare the exit condition otherwise if we - // are fast we could concurrently update totalOps but then preempt one of the threads which can - // cause the successor to read a wrong value from successfulOps if second phase is very fast ie. count etc. - // increment all the "future" shards to update the total ops since we some may work and some may not... - // and when that happens, we break on total ops, so we must maintain them - successfulShardExecution(); - } - - private void successfulShardExecution() { - final int outstanding = outstandingShards.decrementAndGet(); - assert outstanding >= 0 : "outstanding: " + outstanding; - if (outstanding == 0) { - onPhaseDone(); - } - } - - /** - * Returns the total number of shards to the current search across all indices - */ - public final int getNumShards() { - return results.getNumShards(); - } - - /** - * Returns a logger for this context to prevent each individual phase to create their own logger. 
- */ - public final Logger getLogger() { - return logger; - } - - /** - * Returns the currently executing search task - */ - public final SearchTask getTask() { - return task; - } - - /** - * Returns the currently executing search request - */ - public final SearchRequest getRequest() { - return request; - } - /** - * Returns the targeted {@link OriginalIndices} for the provided {@code shardIndex}. - */ - public OriginalIndices getOriginalIndices(int shardIndex) { - return shardIterators[shardIndex].getOriginalIndices(); - } - - /** - * Checks if the given context id is part of the point in time of this search (if exists). - * We should not release search contexts that belong to the point in time during or after searches. - */ - public boolean isPartOfPointInTime(ShardSearchContextId contextId) { - final PointInTimeBuilder pointInTimeBuilder = request.pointInTimeBuilder(); - if (pointInTimeBuilder != null) { - return request.pointInTimeBuilder().getSearchContextId(namedWriteableRegistry).contains(contextId); - } else { - return false; - } + handleFailedAndCancelled(shardIndex, shardTarget, e); } private SearchResponse buildSearchResponse( @@ -606,12 +355,13 @@ private SearchResponse buildSearchResponse( ) { int numSuccess = successfulOps.get(); int numFailures = failures.length; - assert numSuccess + numFailures == getNumShards() - : "numSuccess(" + numSuccess + ") + numFailures(" + numFailures + ") != totalShards(" + getNumShards() + ")"; + final int numShards = results.getNumShards(); + assert numSuccess + numFailures == numShards + : "numSuccess(" + numSuccess + ") + numFailures(" + numFailures + ") != totalShards(" + numShards + ")"; return new SearchResponse( internalSearchResponse, scrollId, - getNumShards(), + numShards, numSuccess, toSkipShardsIts.size(), buildTookInMillis(), @@ -631,8 +381,8 @@ boolean buildPointInTimeFromSearchResults() { * @param internalSearchResponse the internal search response * @param queryResults the results of the query phase */ - 
public void sendSearchResponse(SearchResponseSections internalSearchResponse, AtomicArray queryResults) { - ShardSearchFailure[] failures = buildShardFailures(); + public void sendSearchResponse(SearchResponseSections internalSearchResponse, AtomicArray queryResults) { + ShardSearchFailure[] failures = buildShardFailures(shardFailures); Boolean allowPartialResults = request.allowPartialSearchResults(); assert allowPartialResults != null : "SearchRequest missing setting for allowPartialSearchResults"; if (allowPartialResults == false && failures.length > 0) { @@ -643,13 +393,7 @@ public void sendSearchResponse(SearchResponseSections internalSearchResponse, At if (buildPointInTimeFromSearchResults()) { searchContextId = SearchContextId.encode(queryResults.asList(), aliasFilter, minTransportVersion, failures); } else { - if (request.source() != null - && request.source().pointInTimeBuilder() != null - && request.source().pointInTimeBuilder().singleSession() == false) { - searchContextId = request.source().pointInTimeBuilder().getEncodedId(); - } else { - searchContextId = null; - } + searchContextId = buildSearchContextId(); } ActionListener.respondAndRelease(listener, buildSearchResponse(internalSearchResponse, failures, scrollId, searchContextId)); } @@ -662,31 +406,9 @@ public void sendSearchResponse(SearchResponseSections internalSearchResponse, At * @param msg an optional message * @param cause the cause of the phase failure */ + @Override public void onPhaseFailure(String phase, String msg, Throwable cause) { - raisePhaseFailure(new SearchPhaseExecutionException(phase, msg, cause, buildShardFailures())); - } - - /** - * This method should be called if a search phase failed to ensure all relevant reader contexts are released. - * This method will also notify the listener and sends back a failure to the user. 
- * - * @param exception the exception explaining or causing the phase failure - */ - private void raisePhaseFailure(SearchPhaseExecutionException exception) { - results.getSuccessfulResults().forEach((entry) -> { - // Do not release search contexts that are part of the point in time - if (entry.getContextId() != null && isPartOfPointInTime(entry.getContextId()) == false) { - try { - SearchShardTarget searchShardTarget = entry.getSearchShardTarget(); - Transport.Connection connection = getConnection(searchShardTarget.getClusterAlias(), searchShardTarget.getNodeId()); - sendReleaseSearchContext(entry.getContextId(), connection); - } catch (Exception inner) { - inner.addSuppressed(exception); - logger.trace("failed to release context", inner); - } - } - }); - listener.onFailure(exception); + raisePhaseFailure(new SearchPhaseExecutionException(phase, msg, cause, buildShardFailures(shardFailures))); } /** @@ -695,39 +417,17 @@ private void raisePhaseFailure(SearchPhaseExecutionException exception) { * @see org.elasticsearch.search.fetch.FetchSearchResult#getContextId() * */ - void sendReleaseSearchContext(ShardSearchContextId contextId, Transport.Connection connection) { + @Override + public void sendReleaseSearchContext(ShardSearchContextId contextId, Transport.Connection connection) { assert isPartOfPointInTime(contextId) == false : "Must not release point in time context [" + contextId + "]"; if (connection != null) { searchTransportService.sendFreeContext(connection, contextId, ActionListener.noop()); } } - /** - * Executed once all shard results have been received and processed - * @see #onShardFailure(int, SearchShardTarget, Exception) - * @see #onShardResult(SearchPhaseResult, SearchShardIterator) - */ - private void onPhaseDone() { // as a tribute to @kimchy aka. 
finishHim() - executeNextPhase(getName(), this::getNextPhase); - } - - /** - * Returns a connection to the node if connected otherwise and {@link org.elasticsearch.transport.ConnectTransportException} will be - * thrown. - */ - public final Transport.Connection getConnection(String clusterAlias, String nodeId) { - return nodeIdToConnection.apply(clusterAlias, nodeId); - } - - /** - * Returns the {@link SearchTransportService} to send shard request to other nodes - */ - public SearchTransportService getSearchTransport() { - return searchTransportService; - } - - public final void execute(Runnable command) { - executor.execute(command); + @Override + protected void finish() { // as a tribute to @kimchy aka. finishHim() + executeNextPhase(name, this::getNextPhase); } /** @@ -746,7 +446,7 @@ protected final ShardSearchRequest buildShardSearchRequest(SearchShardIterator s request, shardIt.shardId(), shardIndex, - getNumShards(), + results.getNumShards(), filter, indexBoost, timeProvider.absoluteStartMillis(), @@ -758,7 +458,7 @@ protected final ShardSearchRequest buildShardSearchRequest(SearchShardIterator s // can return a null response if the request rewrites to match none rather // than creating an empty response in the search thread pool. // Note that, we have to disable this shortcut for queries that create a context (scroll and search context). 
- shardRequest.canReturnNullResponseIfMatchNoDocs(hasShardResponse.get() && shardRequest.scroll() == null); + shardRequest.canReturnNullResponseIfMatchNoDocs(hasShardResponse && shardRequest.scroll() == null); return shardRequest; } @@ -767,7 +467,7 @@ protected final ShardSearchRequest buildShardSearchRequest(SearchShardIterator s */ protected abstract SearchPhase getNextPhase(); - private static final class PendingExecutions { + static final class PendingExecutions { private final Semaphore semaphore; private final ConcurrentLinkedQueue> queue = new ConcurrentLinkedQueue<>(); diff --git a/server/src/main/java/org/elasticsearch/action/search/AsyncSearchContext.java b/server/src/main/java/org/elasticsearch/action/search/AsyncSearchContext.java new file mode 100644 index 0000000000000..b6fbec39dbe84 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/action/search/AsyncSearchContext.java @@ -0,0 +1,431 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.action.search; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.OriginalIndices; +import org.elasticsearch.action.ShardOperationFailedException; +import org.elasticsearch.action.support.TransportActions; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.util.concurrent.AtomicArray; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.search.SearchContextMissingException; +import org.elasticsearch.search.SearchPhaseResult; +import org.elasticsearch.search.SearchShardTarget; +import org.elasticsearch.search.builder.PointInTimeBuilder; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.internal.AliasFilter; +import org.elasticsearch.search.internal.SearchContext; +import org.elasticsearch.search.internal.ShardSearchContextId; +import org.elasticsearch.tasks.TaskCancelledException; +import org.elasticsearch.transport.Transport; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiFunction; +import java.util.function.Supplier; + +import static org.elasticsearch.core.Strings.format; + +public abstract class AsyncSearchContext { + + private static final Logger logger = LogManager.getLogger(AsyncSearchContext.class); + + protected final SearchRequest request; + + protected 
final SearchPhaseResults results; + + private final NamedWriteableRegistry namedWriteableRegistry; + + protected final ActionListener listener; + + protected volatile boolean hasShardResponse = false; + + // protected for tests + protected final List releasables = new ArrayList<>(); + + private final AtomicBoolean requestCancelled = new AtomicBoolean(); + + protected final SearchTask task; + + protected final AtomicInteger successfulOps = new AtomicInteger(); + + protected final SearchTransportService searchTransportService; + private final Executor executor; + + protected final List toSkipShardsIts; + protected final List shardsIts; + protected final SearchShardIterator[] shardIterators; + + protected final SetOnce> shardFailures = new SetOnce<>(); + private final Object shardFailuresMutex = new Object(); + + protected final TransportVersion minTransportVersion; + protected final Map aliasFilter; + protected final Map concreteIndexBoosts; + protected final TransportSearchAction.SearchTimeProvider timeProvider; + protected final SearchResponse.Clusters clusters; + + private final AtomicInteger outstandingShards; + + /** + * Used by subclasses to resolve node ids to DiscoveryNodes. 
+ **/ + protected final BiFunction nodeIdToConnection; + + protected AsyncSearchContext( + SearchRequest request, + SearchPhaseResults results, + NamedWriteableRegistry namedWriteableRegistry, + ActionListener listener, + SearchTask task, + SearchTransportService searchTransportService, + Executor executor, + BiFunction nodeIdToConnection, + List shardsIts, + Map aliasFilter, + Map concreteIndexBoosts, + TransportSearchAction.SearchTimeProvider timeProvider, + ClusterState clusterState, + SearchResponse.Clusters clusters + ) { + final List toSkipIterators = new ArrayList<>(); + final List iterators = new ArrayList<>(); + for (final SearchShardIterator iterator : shardsIts) { + if (iterator.skip()) { + toSkipIterators.add(iterator); + } else { + iterators.add(iterator); + } + } + this.toSkipShardsIts = toSkipIterators; + this.successfulOps.setRelease(toSkipIterators.size()); + this.shardsIts = iterators; + + this.shardIterators = iterators.toArray(new SearchShardIterator[0]); + // we later compute the shard index based on the natural order of the shards + // that participate in the search request. This means that this number is + // consistent between two requests that target the same shards. 
+ Arrays.sort(shardIterators); + outstandingShards = new AtomicInteger(shardIterators.length); + this.request = request; + this.results = results; + this.namedWriteableRegistry = namedWriteableRegistry; + this.listener = ActionListener.runAfter(listener, () -> Releasables.close(releasables)); + this.task = task; + this.searchTransportService = searchTransportService; + this.executor = executor; + this.nodeIdToConnection = nodeIdToConnection; + // register the release of the query consumer to free up the circuit breaker memory + // at the end of the search + releasables.add(results); + + this.timeProvider = timeProvider; + this.concreteIndexBoosts = concreteIndexBoosts; + this.minTransportVersion = clusterState.getMinTransportVersion(); + this.aliasFilter = aliasFilter; + this.clusters = clusters; + } + + protected void notifyListShards( + SearchProgressListener progressListener, + SearchResponse.Clusters clusters, + SearchSourceBuilder sourceBuilder + ) { + progressListener.notifyListShards( + SearchProgressListener.buildSearchShardsFromIter(this.shardsIts), + SearchProgressListener.buildSearchShardsFromIter(toSkipShardsIts), + clusters, + sourceBuilder == null || sourceBuilder.size() > 0, + timeProvider + ); + } + + static boolean isTaskCancelledException(Exception e) { + return ExceptionsHelper.unwrapCausesAndSuppressed(e, ex -> ex instanceof TaskCancelledException).isPresent(); + } + + static ShardSearchFailure[] buildShardFailures(SetOnce> shardFailuresRef) { + AtomicArray shardFailures = shardFailuresRef.get(); + if (shardFailures == null) { + return ShardSearchFailure.EMPTY_ARRAY; + } + List entries = shardFailures.asList(); + ShardSearchFailure[] failures = new ShardSearchFailure[entries.size()]; + for (int i = 0; i < failures.length; i++) { + failures[i] = entries.get(i); + } + return failures; + } + + static boolean isPartOfPIT(NamedWriteableRegistry namedWriteableRegistry, SearchRequest request, ShardSearchContextId contextId) { + final PointInTimeBuilder 
pointInTimeBuilder = request.pointInTimeBuilder(); + if (pointInTimeBuilder != null) { + return request.pointInTimeBuilder().getSearchContextId(namedWriteableRegistry).contains(contextId); + } else { + return false; + } + } + + protected void maybeCancelSearchTask() { + if (request.allowPartialSearchResults() == false) { + if (requestCancelled.compareAndSet(false, true)) { + try { + searchTransportService.cancelSearchTask( + task.getId(), + "partial results are not allowed and at least one shard has failed" + ); + } catch (Exception cancelFailure) { + logger.debug("Failed to cancel search request", cancelFailure); + } + } + } + } + + protected final void sendZeroShardsResponse() { + // no search shards to search on, bail with empty response + // (it happens with search across _all with no indices around and consistent with broadcast operations) + var source = request.source(); + int trackTotalHitsUpTo = source == null ? SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO + : source.trackTotalHitsUpTo() == null ? SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO + : source.trackTotalHitsUpTo(); + // total hits is null in the response if the tracking of total hits is disabled + boolean withTotalHits = trackTotalHitsUpTo != SearchContext.TRACK_TOTAL_HITS_DISABLED; + sendSearchResponse( + withTotalHits ? 
SearchResponseSections.EMPTY_WITH_TOTAL_HITS : SearchResponseSections.EMPTY_WITHOUT_TOTAL_HITS, + new AtomicArray<>(0) + ); + } + + protected final void handleFailedAndCancelled(int shardIndex, SearchShardTarget shardTarget, Exception e) { + // we don't aggregate shard on failures due to the internal cancellation, + // but do keep the header counts right + if ((requestCancelled.get() && isTaskCancelledException(e)) == false) { + AtomicArray shardFailures = this.shardFailures.get(); + // lazily create shard failures, so we can early build the empty shard failure list in most cases (no failures) + if (shardFailures == null) { // this is double checked locking but it's fine since SetOnce uses a volatile read internally + synchronized (shardFailuresMutex) { + shardFailures = this.shardFailures.get(); // read again otherwise somebody else has created it? + if (shardFailures == null) { // still null so we are the first and create a new instance + shardFailures = new AtomicArray<>(results.getNumShards()); + this.shardFailures.set(shardFailures); + } + } + } + ShardSearchFailure failure = shardFailures.get(shardIndex); + if (failure == null) { + shardFailures.set(shardIndex, new ShardSearchFailure(e, shardTarget)); + } else { + // the failure is already present, try and not override it with an exception that is less meaningless + // for example, getting illegal shard state + if (TransportActions.isReadOverrideException(e) && (e instanceof SearchContextMissingException == false)) { + shardFailures.set(shardIndex, new ShardSearchFailure(e, shardTarget)); + } + } + + if (results.hasResult(shardIndex)) { + assert outstandingShards.getAcquire() == 0 : "should only be called by subsequent phases, not during query"; + assert failure == null : "shard failed before but shouldn't: " + failure; + successfulOps.decrementAndGet(); // if this shard was successful before (initial phase) we need to count down the successes + } + } + } + + protected final boolean finishShard() { + return 
outstandingShards.decrementAndGet() == 0; + } + + /** + * Returns the currently executing search request + */ + public final SearchRequest getRequest() { + return request; + } + + abstract void sendSearchResponse(SearchResponseSections internalSearchResponse, AtomicArray queryResults); + + /** + * Returns the {@link SearchTransportService} to send shard request to other nodes + */ + public SearchTransportService getSearchTransport() { + return searchTransportService; + } + + /** + * Returns the currently executing search task + */ + public final SearchTask getTask() { + return task; + } + + abstract void onPhaseFailure(String phase, String msg, Throwable cause); + + /** + * Registers a {@link Releasable} that will be closed when the search request finishes or fails. + */ + public final void addReleasable(Releasable releasable) { + releasables.add(releasable); + } + + public final void execute(Runnable command) { + executor.execute(command); + } + + abstract void onShardFailure(int shardIndex, SearchShardTarget shard, Exception e); + + public final Transport.Connection getConnection(String clusterAlias, String nodeId) { + return nodeIdToConnection.apply(clusterAlias, nodeId); + } + + /** + * Returns the targeted {@link OriginalIndices} for the provided {@code shardIndex}. + */ + public OriginalIndices getOriginalIndices(int shardIndex) { + return shardIterators[shardIndex].getOriginalIndices(); + } + + abstract void sendReleaseSearchContext(ShardSearchContextId contextId, Transport.Connection connection); + + abstract void executeNextPhase(String currentPhase, Supplier nextPhaseSupplier); + + /** + * This method should be called if a search phase failed to ensure all relevant reader contexts are released. + * This method will also notify the listener and sends back a failure to the user. 
+ * + * @param exception the exception explaining or causing the phase failure + */ + protected final void raisePhaseFailure(SearchPhaseExecutionException exception) { + results.getSuccessfulResults().forEach((entry) -> { + // Do not release search contexts that are part of the point in time + if (entry.getContextId() != null && isPartOfPointInTime(entry.getContextId()) == false) { + try { + SearchShardTarget searchShardTarget = entry.getSearchShardTarget(); + Transport.Connection connection = getConnection(searchShardTarget.getClusterAlias(), searchShardTarget.getNodeId()); + sendReleaseSearchContext(entry.getContextId(), connection); + } catch (Exception inner) { + inner.addSuppressed(exception); + logger.trace("failed to release context", inner); + } + } + }); + outstandingShards.set(0); // we're done no more shards to process, the phase has failed + listener.onFailure(exception); + } + + /** + * Checks if the given context id is part of the point in time of this search (if exists). + * We should not release search contexts that belong to the point in time during or after searches. 
+ */ + public boolean isPartOfPointInTime(ShardSearchContextId contextId) { + return isPartOfPIT(namedWriteableRegistry, request, contextId); + } + + protected final void executePhase(SearchPhase phase) { + try { + phase.run(); + } catch (Exception e) { + if (logger.isDebugEnabled()) { + logger.debug(() -> format("Failed to execute [%s] while moving to [%s] phase", request, phase.getName()), e); + } + onPhaseFailure(phase.getName(), "", e); + } + } + + protected final void handleNotAllSucceeded(String currentPhase, ShardOperationFailedException[] shardSearchFailures, int numShards) { + // check if there are actual failures in the atomic array since + // successful retries can reset the failures to null + if (shardSearchFailures.length > 0) { + if (logger.isDebugEnabled()) { + int numShardFailures = shardSearchFailures.length; + shardSearchFailures = ExceptionsHelper.groupBy(shardSearchFailures); + Throwable cause = ElasticsearchException.guessRootCauses(shardSearchFailures[0].getCause())[0]; + logger.debug(() -> format("%s shards failed for phase: [%s]", numShardFailures, currentPhase), cause); + } + onPhaseFailure(currentPhase, "Partial shards failure", null); + } else { + int discrepancy = numShards - successfulOps.get(); + assert discrepancy > 0 : "discrepancy: " + discrepancy; + if (logger.isDebugEnabled()) { + logger.debug( + "Partial shards failure (unavailable: {}, successful: {}, skipped: {}, num-shards: {}, phase: {})", + discrepancy, + successfulOps.get(), + toSkipShardsIts.size(), + numShards, + currentPhase + ); + } + onPhaseFailure(currentPhase, "Partial shards failure (" + discrepancy + " shards unavailable)", null); + } + } + + protected BytesReference buildSearchContextId() { + var source = request.source(); + return source != null && source.pointInTimeBuilder() != null && source.pointInTimeBuilder().singleSession() == false + ? 
source.pointInTimeBuilder().getEncodedId() + : null; + } + + /** + * Executed once for every successful shard level request. + * @param result the result returned from the shard + */ + protected void onShardResult(Result result) { + assert result.getShardIndex() != -1 : "shard index is not set"; + assert result.getSearchShardTarget() != null : "search shard target must not be null"; + if (hasShardResponse == false) { + hasShardResponse = true; + } + if (logger.isTraceEnabled()) { + logger.trace("got first-phase result from {}", result.getSearchShardTarget()); + } + results.consumeResult(result, () -> onShardResultConsumed(result)); + } + + /** + * Executed once all shard results have been received and processed + * @see #onShardFailure(int, SearchShardTarget, Exception) + * @see #onShardResult(SearchPhaseResult) + */ + protected abstract void finish(); + + protected void finishShardAndMaybePhase() { + if (finishShard()) { + finish(); + } + } + + private void onShardResultConsumed(Result result) { + successfulOps.incrementAndGet(); + // clean a previous error on this shard group (note, this code will be serialized on the same shardIndex value level + // so it's ok concurrency-wise to miss potentially the shard failures being created because of another failure + // in the #addShardFailure, because by definition, it will happen on *another* shardIndex + AtomicArray shardFailures = this.shardFailures.get(); + if (shardFailures != null) { + shardFailures.set(result.getShardIndex(), null); + } + finishShardAndMaybePhase(); + } + +} diff --git a/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java index f7b258a9f6b75..13cbca27651a5 100644 --- a/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java @@ -202,7 +202,7 @@ private 
synchronized void consumeResult(int shardIndex, boolean canMatch, MinAnd private void checkNoMissingShards(List shards) { assert assertSearchCoordinationThread(); - SearchPhase.doCheckNoMissingShards("can_match", request, shards, SearchPhase::makeMissingShardsError); + AbstractSearchAsyncAction.doCheckNoMissingShards("can_match", request, shards, AbstractSearchAsyncAction::makeMissingShardsError); } private Map> groupByNode(List shards) { diff --git a/server/src/main/java/org/elasticsearch/action/search/CountedCollector.java b/server/src/main/java/org/elasticsearch/action/search/CountedCollector.java index 3d15e11a19d31..6eed47fd7f2cf 100644 --- a/server/src/main/java/org/elasticsearch/action/search/CountedCollector.java +++ b/server/src/main/java/org/elasticsearch/action/search/CountedCollector.java @@ -22,9 +22,9 @@ final class CountedCollector { private final SearchPhaseResults resultConsumer; private final CountDown counter; private final Runnable onFinish; - private final AbstractSearchAsyncAction context; + private final AsyncSearchContext context; - CountedCollector(SearchPhaseResults resultConsumer, int expectedOps, Runnable onFinish, AbstractSearchAsyncAction context) { + CountedCollector(SearchPhaseResults resultConsumer, int expectedOps, Runnable onFinish, AsyncSearchContext context) { this.resultConsumer = resultConsumer; this.counter = new CountDown(expectedOps); this.onFinish = onFinish; diff --git a/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java b/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java index faeb552530e47..048fd77fe538d 100644 --- a/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java @@ -8,6 +8,8 @@ */ package org.elasticsearch.action.search; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.apache.lucene.search.ScoreDoc; import 
org.apache.lucene.search.join.ScoreMode; import org.elasticsearch.common.lucene.Lucene; @@ -42,6 +44,8 @@ final class DfsQueryPhase extends SearchPhase { public static final String NAME = "dfs_query"; + private static final Logger logger = LogManager.getLogger(DfsQueryPhase.class); + private final SearchPhaseResults queryResult; private final List searchResults; private final AggregatedDfs dfs; @@ -71,7 +75,7 @@ final class DfsQueryPhase extends SearchPhase { } @Override - protected void run() { + public void run() { // TODO we can potentially also consume the actual per shard results from the initial phase here in the aggregateDfs // to free up memory early final CountedCollector counter = new CountedCollector<>( @@ -138,7 +142,7 @@ private void shardFailure( SearchShardTarget shardTarget, CountedCollector counter ) { - context.getLogger().debug(() -> "[" + querySearchRequest.contextId() + "] Failed to execute query phase", exception); + logger.debug(() -> "[" + querySearchRequest.contextId() + "] Failed to execute query phase", exception); progressListener.notifyQueryFailure(shardIndex, shardTarget, exception); counter.onFailure(shardIndex, shardTarget, exception); } diff --git a/server/src/main/java/org/elasticsearch/action/search/ExpandSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/ExpandSearchPhase.java index b0b3f15265920..f513242e63d83 100644 --- a/server/src/main/java/org/elasticsearch/action/search/ExpandSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/ExpandSearchPhase.java @@ -34,11 +34,11 @@ final class ExpandSearchPhase extends SearchPhase { static final String NAME = "expand"; - private final AbstractSearchAsyncAction context; + private final AsyncSearchContext context; private final SearchHits searchHits; private final Supplier nextPhase; - ExpandSearchPhase(AbstractSearchAsyncAction context, SearchHits searchHits, Supplier nextPhase) { + ExpandSearchPhase(AsyncSearchContext context, SearchHits 
searchHits, Supplier nextPhase) { super(NAME); this.context = context; this.searchHits = searchHits; diff --git a/server/src/main/java/org/elasticsearch/action/search/FetchLookupFieldsPhase.java b/server/src/main/java/org/elasticsearch/action/search/FetchLookupFieldsPhase.java index 2e98d50196490..cb51aeb6642bc 100644 --- a/server/src/main/java/org/elasticsearch/action/search/FetchLookupFieldsPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/FetchLookupFieldsPhase.java @@ -36,14 +36,14 @@ final class FetchLookupFieldsPhase extends SearchPhase { static final String NAME = "fetch_lookup_fields"; - private final AbstractSearchAsyncAction context; + private final AsyncSearchContext context; private final SearchResponseSections searchResponse; - private final AtomicArray queryResults; + private final AtomicArray queryResults; FetchLookupFieldsPhase( - AbstractSearchAsyncAction context, + AsyncSearchContext context, SearchResponseSections searchResponse, - AtomicArray queryResults + AtomicArray queryResults ) { super(NAME); this.context = context; @@ -132,7 +132,7 @@ public void onResponse(MultiSearchResponse items) { } } if (failure != null) { - context.onPhaseFailure(NAME, "failed to fetch lookup fields", failure); + failPhase(failure); } else { context.sendSearchResponse(searchResponse, queryResults); } @@ -140,8 +140,12 @@ public void onResponse(MultiSearchResponse items) { @Override public void onFailure(Exception e) { - context.onPhaseFailure(NAME, "failed to fetch lookup fields", e); + failPhase(e); } }); } + + private void failPhase(Exception e) { + context.onPhaseFailure(NAME, "failed to fetch lookup fields", e); + } } diff --git a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java index 080295210fced..df72ea632531f 100644 --- a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java +++ 
b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java @@ -8,6 +8,7 @@ */ package org.elasticsearch.action.search; +import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.search.ScoreDoc; import org.elasticsearch.common.util.concurrent.AbstractRunnable; @@ -34,23 +35,24 @@ * Then it reaches out to all relevant shards to fetch the topN hits. */ final class FetchSearchPhase extends SearchPhase { + private static final Logger logger = LogManager.getLogger(FetchSearchPhase.class); static final String NAME = "fetch"; - private final AtomicArray searchPhaseShardResults; - private final BiFunction, SearchPhase> nextPhaseFactory; - private final AbstractSearchAsyncAction context; - private final Logger logger; + private final AtomicArray searchPhaseShardResults; + private final BiFunction, SearchPhase> nextPhaseFactory; + private final AsyncSearchContext context; private final SearchProgressListener progressListener; private final AggregatedDfs aggregatedDfs; @Nullable - private final SearchPhaseResults resultConsumer; + private final SearchPhaseResults resultConsumer; private final SearchPhaseController.ReducedQueryPhase reducedQueryPhase; + private final int numShards; FetchSearchPhase( - SearchPhaseResults resultConsumer, + SearchPhaseResults resultConsumer, AggregatedDfs aggregatedDfs, - AbstractSearchAsyncAction context, + AsyncSearchContext context, @Nullable SearchPhaseController.ReducedQueryPhase reducedQueryPhase ) { this( @@ -67,26 +69,18 @@ final class FetchSearchPhase extends SearchPhase { } FetchSearchPhase( - SearchPhaseResults resultConsumer, + SearchPhaseResults resultConsumer, AggregatedDfs aggregatedDfs, - AbstractSearchAsyncAction context, + AsyncSearchContext context, @Nullable SearchPhaseController.ReducedQueryPhase reducedQueryPhase, - BiFunction, SearchPhase> nextPhaseFactory + BiFunction, SearchPhase> nextPhaseFactory ) { super(NAME); - if (context.getNumShards() != 
resultConsumer.getNumShards()) { - throw new IllegalStateException( - "number of shards must match the length of the query results but doesn't:" - + context.getNumShards() - + "!=" - + resultConsumer.getNumShards() - ); - } this.searchPhaseShardResults = resultConsumer.getAtomicArray(); + this.numShards = resultConsumer.getNumShards(); this.aggregatedDfs = aggregatedDfs; this.nextPhaseFactory = nextPhaseFactory; this.context = context; - this.logger = context.getLogger(); this.progressListener = context.getTask().getProgressListener(); this.reducedQueryPhase = reducedQueryPhase; this.resultConsumer = reducedQueryPhase == null ? resultConsumer : null; @@ -103,23 +97,26 @@ protected void doRun() throws Exception { @Override public void onFailure(Exception e) { - context.onPhaseFailure(NAME, "", e); + failPhase(e); } }); } + private void failPhase(Exception e) { + context.onPhaseFailure(NAME, "", e); + } + private void innerRun() throws Exception { assert this.reducedQueryPhase == null ^ this.resultConsumer == null; // depending on whether we executed the RankFeaturePhase we may or may not have the reduced query result computed already final var reducedQueryPhase = this.reducedQueryPhase == null ? resultConsumer.reduce() : this.reducedQueryPhase; - final int numShards = context.getNumShards(); + var request = context.getRequest(); // Usually when there is a single shard, we force the search type QUERY_THEN_FETCH. But when there's kNN, we might // still use DFS_QUERY_THEN_FETCH, which does not perform the "query and fetch" optimization during the query phase. 
- final boolean queryAndFetchOptimization = searchPhaseShardResults.length() == 1 - && context.getRequest().hasKnnSearch() == false + if (numShards == 1 + && request.hasKnnSearch() == false && reducedQueryPhase.queryPhaseRankCoordinatorContext() == null - && (context.getRequest().source() == null || context.getRequest().source().rankBuilder() == null); - if (queryAndFetchOptimization) { + && (request.source() == null || request.source().rankBuilder() == null)) { assert assertConsistentWithQueryAndFetchOptimization(); // query AND fetch optimization moveToNextPhase(searchPhaseShardResults, reducedQueryPhase); @@ -221,9 +218,9 @@ private void executeFetch( ) { final SearchShardTarget shardTarget = shardPhaseResult.getSearchShardTarget(); final int shardIndex = shardPhaseResult.getShardIndex(); - final ShardSearchContextId contextId = shardPhaseResult.queryResult() != null - ? shardPhaseResult.queryResult().getContextId() - : shardPhaseResult.rankFeatureResult().getContextId(); + final ShardSearchContextId contextId = (shardPhaseResult.queryResult() != null + ? 
shardPhaseResult.queryResult() + : shardPhaseResult.rankFeatureResult()).getContextId(); var listener = new SearchActionListener(shardTarget, shardIndex) { @Override public void innerOnResponse(FetchSearchResult result) { @@ -231,7 +228,7 @@ public void innerOnResponse(FetchSearchResult result) { progressListener.notifyFetchResult(shardIndex); counter.onResult(result); } catch (Exception e) { - context.onPhaseFailure(NAME, "", e); + failPhase(e); } } @@ -239,8 +236,8 @@ public void innerOnResponse(FetchSearchResult result) { public void onFailure(Exception e) { try { logger.debug(() -> "[" + contextId + "] Failed to execute fetch phase", e); - progressListener.notifyFetchFailure(shardIndex, shardTarget, e); - counter.onFailure(shardIndex, shardTarget, e); + progressListener.notifyFetchFailure(shardIndex, searchShardTarget, e); + counter.onFailure(shardIndex, searchShardTarget, e); } finally { // the search context might not be cleared on the node where the fetch was executed for example // because the action was rejected by the thread pool. 
in this case we need to send a dedicated diff --git a/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java b/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java index 9a8dd94dcd324..42860290e9019 100644 --- a/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java +++ b/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java @@ -16,8 +16,15 @@ import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.breaker.CircuitBreakingException; import org.elasticsearch.common.io.stream.DelayableWriteable; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; import org.elasticsearch.common.util.concurrent.AbstractRunnable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchService; import org.elasticsearch.search.SearchShardTarget; @@ -27,6 +34,7 @@ import org.elasticsearch.search.query.QuerySearchResult; import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext; +import java.io.IOException; import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Collections; @@ -66,7 +74,7 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults onPartialMergeFailure; private final int batchReduceSize; - private List buffer = new ArrayList<>(); + List buffer = new ArrayList<>(); private List emptyResults = new ArrayList<>(); // the memory that is accounted in the circuit breaker for this consumer private volatile long circuitBreakerBytes; @@ -76,9 +84,9 @@ public class QueryPhaseResultConsumer extends 
ArraySearchPhaseResults queue = new ArrayDeque<>(); private final AtomicReference runningTask = new AtomicReference<>(); - private final AtomicReference failure = new AtomicReference<>(); + public final AtomicReference failure = new AtomicReference<>(); - private final TopDocsStats topDocsStats; + public final TopDocsStats topDocsStats; private volatile MergeResult mergeResult; private volatile boolean hasPartialReduce; private volatile int numReducePhases; @@ -86,6 +94,8 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults isCanceled, SearchProgressListener progressListener, int expectedResultSize, + int batchReduceSize, Consumer onPartialMergeFailure ) { super(expectedResultSize); @@ -114,7 +125,13 @@ public QueryPhaseResultConsumer( this.hasTopDocs = (source == null || size != 0) && queryPhaseRankCoordinatorContext == null; this.hasAggs = source != null && source.aggregations() != null; this.aggReduceContextBuilder = hasAggs ? controller.getReduceContext(isCanceled, source.aggregations()) : null; - batchReduceSize = (hasAggs || hasTopDocs) ? Math.min(request.getBatchedReduceSize(), expectedResultSize) : expectedResultSize; + if (batchReduceSize >= 0) { + this.batchReduceSize = batchReduceSize; + } else { + this.batchReduceSize = (hasAggs || hasTopDocs) + ? 
Math.min(request.getBatchedReduceSize(), expectedResultSize) + : expectedResultSize; + } topDocsStats = new TopDocsStats(request.resolveTrackTotalHitsUpTo()); } @@ -146,9 +163,37 @@ public void consumeResult(SearchPhaseResult result, Runnable next) { super.consumeResult(result, () -> {}); QuerySearchResult querySearchResult = result.queryResult(); progressListener.notifyQueryResult(querySearchResult.getShardIndex(), querySearchResult); + assert result.getShardIndex() == querySearchResult.getShardIndex(); consume(querySearchResult, next); } + private final List> batchedResults = new ArrayList<>(); + + public MergeResult consumePartialResult() { + var mergeResult = this.mergeResult; + this.mergeResult = null; + assert runningTask.get() == null; + final List buffer; + synchronized (this) { + buffer = this.buffer; + } + if (buffer != null && buffer.isEmpty() == false) { + this.buffer = null; + buffer.sort(RESULT_COMPARATOR); + mergeResult = partialReduce(buffer, emptyResults, topDocsStats, mergeResult, numReducePhases++); + emptyResults = null; + } + return mergeResult; + } + + public void addPartialResult(TopDocsStats topDocsStats, MergeResult mergeResult) { + if (mergeResult.processedShards.isEmpty() == false) { + synchronized (batchedResults) { + batchedResults.add(new Tuple<>(topDocsStats, mergeResult)); + } + } + } + @Override public SearchPhaseController.ReducedQueryPhase reduce() throws Exception { if (hasPendingMerges()) { @@ -166,24 +211,32 @@ public SearchPhaseController.ReducedQueryPhase reduce() throws Exception { buffer = this.buffer; buffer = buffer == null ? Collections.emptyList() : buffer; this.buffer = null; + } // ensure consistent ordering buffer.sort(RESULT_COMPARATOR); final TopDocsStats topDocsStats = this.topDocsStats; var mergeResult = this.mergeResult; - this.mergeResult = null; - final int resultSize = buffer.size() + (mergeResult == null ? 
0 : 1); + final List> batchedResults; + synchronized (this.batchedResults) { + batchedResults = this.batchedResults; + } + final int resultSize = buffer.size() + (mergeResult == null ? 0 : 1) + batchedResults.size(); final List topDocsList = hasTopDocs ? new ArrayList<>(resultSize) : null; final List> aggsList = hasAggs ? new ArrayList<>(resultSize) : null; if (mergeResult != null) { - if (topDocsList != null) { - topDocsList.add(mergeResult.reducedTopDocs); - } - if (aggsList != null) { - aggsList.add(DelayableWriteable.referencing(mergeResult.reducedAggs)); - } + this.mergeResult = null; + consumePartialMergeResult(mergeResult, topDocsList, aggsList); + } + for (int i = 0; i < batchedResults.size(); i++) { + Tuple batchedResult = batchedResults.set(i, null); + consumePartialMergeResult(batchedResult.v2(), topDocsList, aggsList); + topDocsStats.add(batchedResult.v1()); } for (QuerySearchResult result : buffer) { + if (result.isReduced()) { + continue; + } topDocsStats.add(result.topDocs(), result.searchTimedOut(), result.terminatedEarly()); if (topDocsList != null) { TopDocsAndMaxScore topDocs = result.consumeTopDocs(); @@ -206,7 +259,7 @@ public SearchPhaseController.ReducedQueryPhase reduce() throws Exception { aggsList, topDocsList == null ? 
Collections.emptyList() : topDocsList, topDocsStats, - numReducePhases, + 2, false, aggReduceContextBuilder, queryPhaseRankCoordinatorContext, @@ -236,6 +289,19 @@ public SearchPhaseController.ReducedQueryPhase reduce() throws Exception { } + private static void consumePartialMergeResult( + MergeResult partialResult, + List topDocsList, + List> aggsList + ) { + if (topDocsList != null) { + topDocsList.add(partialResult.reducedTopDocs); + } + if (aggsList != null) { + aggsList.add(DelayableWriteable.referencing(partialResult.reducedAggs)); + } + } + private static final Comparator RESULT_COMPARATOR = Comparator.comparingInt(QuerySearchResult::getShardIndex); private MergeResult partialReduce( @@ -284,12 +350,15 @@ private MergeResult partialReduce( } } // we have to merge here in the same way we collect on a shard - newTopDocs = topDocsList == null ? null : mergeTopDocs(topDocsList, topNSize, 0); + newTopDocs = topDocsList == null ? Lucene.EMPTY_TOP_DOCS : mergeTopDocs(topDocsList, topNSize, 0); newAggs = aggsList == null ? 
null : InternalAggregations.topLevelReduceDelayable(aggsList, aggReduceContextBuilder.forPartialReduction()); } finally { releaseAggs(toConsume); + for (QuerySearchResult querySearchResult : toConsume) { + querySearchResult.setReduced(); + } } if (lastMerge != null) { processedShards.addAll(lastMerge.processedShards); @@ -306,7 +375,7 @@ public int getNumReducePhases() { return numReducePhases; } - private boolean hasFailure() { + public boolean hasFailure() { return failure.get() != null; } @@ -351,8 +420,15 @@ private void consume(QuerySearchResult result, Runnable next) { if (hasFailure()) { result.consumeAll(); next.run(); - } else if (result.isNull()) { - result.consumeAll(); + } else if (result.isNull() || result.isReduced()) { + if (result.isReduced()) { + if (result.hasConsumedTopDocs() == false) { + result.consumeTopDocs(); + } + result.releaseAggs(); + } else { + result.consumeAll(); + } SearchShardTarget target = result.getSearchShardTarget(); SearchShard searchShard = new SearchShard(target.getClusterAlias(), target.getShardId()); synchronized (this) { @@ -522,12 +598,33 @@ private static void releaseAggs(List toConsume) { } } - private record MergeResult( + public record MergeResult( List processedShards, TopDocs reducedTopDocs, - InternalAggregations reducedAggs, + @Nullable InternalAggregations reducedAggs, long estimatedSize - ) {} + ) implements Writeable { + + static MergeResult readFrom(StreamInput in) throws IOException { + return new MergeResult( + in.readCollectionAsImmutableList(i -> new SearchShard(i.readOptionalString(), new ShardId(i))), + Lucene.readTopDocsOnly(in), + in.readOptionalWriteable(InternalAggregations::readFrom), + in.readVLong() + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeCollection(processedShards, (o, s) -> { + o.writeOptionalString(s.clusterAlias()); + s.shardId().writeTo(o); + }); + Lucene.writeTopDocsIncludingShardIndex(out, reducedTopDocs); + 
out.writeOptionalWriteable(reducedAggs); + out.writeVLong(estimatedSize); + } + } private static class MergeTask { private final List emptyResults; diff --git a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java index e9302883457e1..1aacf20e6e587 100644 --- a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java @@ -40,34 +40,26 @@ public class RankFeaturePhase extends SearchPhase { static final String NAME = "rank-feature"; private static final Logger logger = LogManager.getLogger(RankFeaturePhase.class); - private final AbstractSearchAsyncAction context; - final SearchPhaseResults queryPhaseResults; + private final AsyncSearchContext context; + final SearchPhaseResults queryPhaseResults; final SearchPhaseResults rankPhaseResults; private final AggregatedDfs aggregatedDfs; private final SearchProgressListener progressListener; private final RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext; RankFeaturePhase( - SearchPhaseResults queryPhaseResults, + SearchPhaseResults queryPhaseResults, AggregatedDfs aggregatedDfs, - AbstractSearchAsyncAction context, + AsyncSearchContext context, RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext ) { super(NAME); assert rankFeaturePhaseRankCoordinatorContext != null; this.rankFeaturePhaseRankCoordinatorContext = rankFeaturePhaseRankCoordinatorContext; - if (context.getNumShards() != queryPhaseResults.getNumShards()) { - throw new IllegalStateException( - "number of shards must match the length of the query results but doesn't:" - + context.getNumShards() - + "!=" - + queryPhaseResults.getNumShards() - ); - } this.context = context; this.queryPhaseResults = queryPhaseResults; this.aggregatedDfs = aggregatedDfs; - this.rankPhaseResults = new 
ArraySearchPhaseResults<>(context.getNumShards()); + this.rankPhaseResults = new ArraySearchPhaseResults<>(queryPhaseResults.getNumShards()); context.addReleasable(rankPhaseResults); this.progressListener = context.getTask().getProgressListener(); } @@ -86,20 +78,24 @@ protected void doRun() throws Exception { @Override public void onFailure(Exception e) { - context.onPhaseFailure(NAME, "", e); + failPhase("", e); } }); } + private void failPhase(String msg, Exception e) { + context.onPhaseFailure(NAME, msg, e); + } + void innerRun(RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext) throws Exception { // if the RankBuilder specifies a QueryPhaseCoordinatorContext, it will be called as part of the reduce call // to operate on the first `rank_window_size * num_shards` results and merge them appropriately. SearchPhaseController.ReducedQueryPhase reducedQueryPhase = queryPhaseResults.reduce(); ScoreDoc[] queryScoreDocs = reducedQueryPhase.sortedTopDocs().scoreDocs(); // rank_window_size - final List[] docIdsToLoad = SearchPhaseController.fillDocIdsToLoad(context.getNumShards(), queryScoreDocs); + final List[] docIdsToLoad = SearchPhaseController.fillDocIdsToLoad(queryPhaseResults.getNumShards(), queryScoreDocs); final CountedCollector rankRequestCounter = new CountedCollector<>( rankPhaseResults, - context.getNumShards(), + queryPhaseResults.getNumShards(), () -> onPhaseDone(rankFeaturePhaseRankCoordinatorContext, reducedQueryPhase), context ); @@ -141,7 +137,7 @@ protected void innerOnResponse(RankFeatureResult response) { progressListener.notifyRankFeatureResult(shardIndex); rankRequestCounter.onResult(response); } catch (Exception e) { - context.onPhaseFailure(NAME, "", e); + failPhase("", e); } } @@ -196,7 +192,7 @@ public void onResponse(RankFeatureDoc[] docsWithUpdatedScores) { @Override public void onFailure(Exception e) { - context.onPhaseFailure(NAME, "Computing updated ranks for results failed", e); + failPhase("Computing updated 
ranks for results failed", e); } } ); diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchActionListener.java b/server/src/main/java/org/elasticsearch/action/search/SearchActionListener.java index 237449881fba1..b44a4b7cc37b3 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchActionListener.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchActionListener.java @@ -19,7 +19,7 @@ abstract class SearchActionListener implements ActionListener { final int requestIndex; - private final SearchShardTarget searchShardTarget; + protected final SearchShardTarget searchShardTarget; protected SearchActionListener(SearchShardTarget searchShardTarget, int shardIndex) { assert shardIndex >= 0 : "shard index must be positive"; diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchContextId.java b/server/src/main/java/org/elasticsearch/action/search/SearchContextId.java index c2f1510341fb0..54e3ba205dccb 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchContextId.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchContextId.java @@ -58,7 +58,7 @@ public boolean contains(ShardSearchContextId contextId) { } public static BytesReference encode( - List searchPhaseResults, + List searchPhaseResults, Map aliasFilter, TransportVersion version, ShardSearchFailure[] shardFailures diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhase.java index 1308a2fb61cfb..4a8c6906f7db7 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhase.java @@ -8,18 +8,21 @@ */ package org.elasticsearch.action.search; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchShardTarget; +import 
org.elasticsearch.search.query.QuerySearchResult; import org.elasticsearch.transport.Transport; -import java.util.List; import java.util.Objects; -import java.util.function.Function; /** * Base class for all individual search phases like collecting distributed frequencies, fetching documents, querying shards. */ abstract class SearchPhase { + private static final Logger logger = LogManager.getLogger(SearchPhase.class); + private final String name; protected SearchPhase(String name) { @@ -35,71 +38,27 @@ public String getName() { return name; } - protected String missingShardsErrorMessage(StringBuilder missingShards) { - return makeMissingShardsError(missingShards); - } - - protected static String makeMissingShardsError(StringBuilder missingShards) { - return "Search rejected due to missing shards [" - + missingShards - + "]. Consider using `allow_partial_search_results` setting to bypass this error."; - } - - protected void doCheckNoMissingShards(String phaseName, SearchRequest request, List shardsIts) { - doCheckNoMissingShards(phaseName, request, shardsIts, this::missingShardsErrorMessage); - } - - protected static void doCheckNoMissingShards( - String phaseName, - SearchRequest request, - List shardsIts, - Function makeErrorMessage - ) { - assert request.allowPartialSearchResults() != null : "SearchRequest missing setting for allowPartialSearchResults"; - if (request.allowPartialSearchResults() == false) { - final StringBuilder missingShards = new StringBuilder(); - // Fail-fast verification of all shards being available - for (int index = 0; index < shardsIts.size(); index++) { - final SearchShardIterator shardRoutings = shardsIts.get(index); - if (shardRoutings.size() == 0) { - if (missingShards.isEmpty() == false) { - missingShards.append(", "); - } - missingShards.append(shardRoutings.shardId()); - } - } - if (missingShards.isEmpty() == false) { - // Status red - shard is missing all copies and would produce partial results for an index search - final String 
msg = makeErrorMessage.apply(missingShards); - throw new SearchPhaseExecutionException(phaseName, msg, null, ShardSearchFailure.EMPTY_ARRAY); - } - } - } - /** * Releases shard targets that are not used in the docsIdsToLoad. */ - protected static void releaseIrrelevantSearchContext(SearchPhaseResult searchPhaseResult, AbstractSearchAsyncAction context) { + protected static void releaseIrrelevantSearchContext(SearchPhaseResult searchPhaseResult, AsyncSearchContext context) { // we only release search context that we did not fetch from, if we are not scrolling // or using a PIT and if it has at least one hit that didn't make it to the global topDocs - if (searchPhaseResult == null) { - return; - } // phaseResult.getContextId() is the same for query & rank feature results SearchPhaseResult phaseResult = searchPhaseResult.queryResult() != null ? searchPhaseResult.queryResult() : searchPhaseResult.rankFeatureResult(); if (phaseResult != null - && phaseResult.hasSearchContext() + && (phaseResult.hasSearchContext() || (phaseResult instanceof QuerySearchResult q && q.isReduced() && q.getContextId() != null)) && context.getRequest().scroll() == null - && (context.isPartOfPointInTime(phaseResult.getContextId()) == false)) { + && (AsyncSearchContext.isPartOfPIT(null, context.getRequest(), phaseResult.getContextId()) == false)) { try { - context.getLogger().trace("trying to release search context [{}]", phaseResult.getContextId()); + logger.trace("trying to release search context [{}]", phaseResult.getContextId()); SearchShardTarget shardTarget = phaseResult.getSearchShardTarget(); Transport.Connection connection = context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId()); context.sendReleaseSearchContext(phaseResult.getContextId(), connection); } catch (Exception e) { - context.getLogger().trace("failed to release context", e); + logger.trace("failed to release context", e); } } } diff --git 
a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java index f8736ab79690e..e174196110ecf 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java @@ -21,6 +21,9 @@ import org.apache.lucene.search.TotalHits.Relation; import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.io.stream.DelayableWriteable; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; import org.elasticsearch.common.util.Maps; @@ -51,6 +54,7 @@ import org.elasticsearch.search.suggest.Suggest.Suggestion; import org.elasticsearch.search.suggest.completion.CompletionSuggestion; +import java.io.IOException; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -483,7 +487,8 @@ static ReducedQueryPhase reducedQueryPhase( } } } - assert bufferedTopDocs.isEmpty() || result.hasConsumedTopDocs() : "firstResult has no aggs but we got non null buffered aggs?"; + assert bufferedTopDocs.isEmpty() || result.hasConsumedTopDocs() || result.isReduced() + : "firstResult has no aggs but we got non null buffered aggs?"; if (hasProfileResults) { profileShardResults.put(result.getSearchShardTarget().toString(), result.consumeProfileResult()); } @@ -689,11 +694,12 @@ SearchPhaseResults newSearchPhaseResults( isCanceled, listener, numShards, + -1, onPartialMergeFailure ); } - public static final class TopDocsStats { + public static final class TopDocsStats implements Writeable { final int trackTotalHitsUpTo; long totalHits; private TotalHits.Relation totalHitsRelation; @@ -733,6 +739,29 @@ TotalHits 
getTotalHits() { } } + /** + * Folds the statistics collected in {@code other} into this instance. + * Total hits and their relation are merged only while total-hit tracking is enabled on this instance; fetch hits are always + * summed, the max score is raised via {@link Math#max} when the other max score is not NaN, and {@code timedOut} turns + * {@code true} once either side timed out. {@code terminatedEarly} is left untouched when {@code other} has no value, is + * adopted from {@code other} when this instance has none yet, and otherwise turns {@code true} as soon as {@code other} + * reports {@code true} (the value of {@code other} is what must be tested here, not this instance's own field). + * + * @param other the stats to merge into this instance + */ + void add(TopDocsStats other) { + if (trackTotalHitsUpTo != SearchContext.TRACK_TOTAL_HITS_DISABLED) { + totalHits += other.totalHits; + if (other.totalHitsRelation == Relation.GREATER_THAN_OR_EQUAL_TO) { + totalHitsRelation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO; + } + } + fetchHits += other.fetchHits; + if (Float.isNaN(other.maxScore) == false) { + maxScore = Math.max(maxScore, other.maxScore); + } + if (other.timedOut) { + this.timedOut = true; + } + if (other.terminatedEarly != null) { + if (this.terminatedEarly == null) { + this.terminatedEarly = other.terminatedEarly; + } else if (other.terminatedEarly) { + this.terminatedEarly = true; + } + } + } + void add(TopDocsAndMaxScore topDocs, boolean timedOut, Boolean terminatedEarly) { if (trackTotalHitsUpTo != SearchContext.TRACK_TOTAL_HITS_DISABLED) { totalHits += topDocs.topDocs.totalHits.value(); @@ -755,6 +784,30 @@ void add(TopDocsAndMaxScore topDocs, boolean timedOut, Boolean terminatedEarly) } } } + + /** + * Writes these stats to the stream. 
{@code maxScore} is written twice (before the total hits and again after the fetch + * hits); {@link #readFrom} reads it at the same two positions, so the two methods must be changed together. + */ + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVInt(trackTotalHitsUpTo); + out.writeFloat(maxScore); + Lucene.writeTotalHits(out, new TotalHits(totalHits, totalHitsRelation)); + out.writeVLong(fetchHits); + out.writeFloat(maxScore); + out.writeBoolean(timedOut); + out.writeOptionalBoolean(terminatedEarly); + } + + public static TopDocsStats readFrom(StreamInput in) throws IOException { + TopDocsStats res = new TopDocsStats(in.readVInt()); + res.maxScore = in.readFloat(); + TotalHits totalHits = Lucene.readTotalHits(in); + res.totalHits = totalHits.value(); + res.totalHitsRelation = totalHits.relation(); + res.fetchHits = in.readVLong(); + res.maxScore = in.readFloat(); + res.timedOut = in.readBoolean(); + res.terminatedEarly = in.readOptionalBoolean(); + return res; + } } public record SortedTopDocs( diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java 
b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java index 088a16deb76dc..45df36847946a 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java @@ -9,29 +9,87 @@ package org.elasticsearch.action.search; +import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopFieldDocs; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.IndicesRequest; +import org.elasticsearch.action.NoShardAvailableActionException; +import org.elasticsearch.action.OriginalIndices; +import org.elasticsearch.action.ShardOperationFailedException; +import org.elasticsearch.action.support.ChannelActionListener; +import org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.action.support.TransportActions; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.NoopCircuitBreaker; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.lucene.Lucene; +import org.elasticsearch.common.util.Maps; +import org.elasticsearch.common.util.concurrent.AbstractRunnable; +import org.elasticsearch.common.util.concurrent.AtomicArray; +import org.elasticsearch.common.util.concurrent.CountDown; +import org.elasticsearch.common.util.concurrent.EsExecutors; 
+import org.elasticsearch.core.RefCounted; +import org.elasticsearch.core.SimpleRefCounted; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.search.SearchPhaseResult; +import org.elasticsearch.search.SearchService; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.dfs.AggregatedDfs; import org.elasticsearch.search.internal.AliasFilter; import org.elasticsearch.search.internal.SearchContext; +import org.elasticsearch.search.internal.ShardSearchContextId; import org.elasticsearch.search.internal.ShardSearchRequest; import org.elasticsearch.search.query.QuerySearchResult; +import org.elasticsearch.tasks.CancellableTask; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.LeakTracker; import org.elasticsearch.transport.Transport; +import org.elasticsearch.transport.TransportChannel; +import org.elasticsearch.transport.TransportException; +import org.elasticsearch.transport.TransportRequest; +import org.elasticsearch.transport.TransportResponse; +import org.elasticsearch.transport.TransportResponseHandler; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiFunction; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import static org.elasticsearch.action.search.AbstractSearchAsyncAction.DEFAULT_INDEX_BOOST; import static org.elasticsearch.action.search.SearchPhaseController.getTopDocsSize; +import static org.elasticsearch.core.Strings.format; -class 
SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction { +public class SearchQueryThenFetchAsyncAction extends AsyncSearchContext { + + private static final String NAME = "query"; + + private static final Logger logger = LogManager.getLogger(SearchQueryThenFetchAsyncAction.class); private final SearchProgressListener progressListener; @@ -42,14 +100,13 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction nodeIdToConnection, Map aliasFilter, Map concreteIndexBoosts, Executor executor, - SearchPhaseResults resultConsumer, + SearchPhaseResults resultConsumer, SearchRequest request, ActionListener listener, List shardsIts, @@ -60,55 +117,318 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction listener + /** + * This is the main entry point for a search. This method starts the search execution of the initial phase. + */ + public final void start() { + if (results.getNumShards() == 0) { + sendZeroShardsResponse(); + return; + } + try { + if (shardsIts.isEmpty()) { + executeNextPhase(NAME, this::getNextPhase); + return; + } + run(); + } catch (Exception e) { + if (logger.isDebugEnabled()) { + logger.debug(() -> format("Failed to execute [%s] while moving to [" + NAME + "] phase", request), e); + } + onPhaseFailure(NAME, "", e); + } + } + + /** + * Builds and sends the final search response back to the user. 
+ * + * @param internalSearchResponse the internal search response + * @param queryResults the results of the query phase + */ + public void sendSearchResponse(SearchResponseSections internalSearchResponse, AtomicArray queryResults) { + ShardSearchFailure[] failures = buildShardFailures(shardFailures); + Boolean allowPartialResults = request.allowPartialSearchResults(); + assert allowPartialResults != null : "SearchRequest missing setting for allowPartialSearchResults"; + if (allowPartialResults == false && failures.length > 0) { + raisePhaseFailure(new SearchPhaseExecutionException("", "Shard failures", null, failures)); + } else { + final String scrollId = request.scroll() != null ? TransportSearchHelper.buildScrollId(queryResults) : null; + ActionListener.respondAndRelease( + listener, + buildSearchResponse(internalSearchResponse, failures, scrollId, buildSearchContextId()) + ); + } + } + + private SearchResponse buildSearchResponse( + SearchResponseSections internalSearchResponse, + ShardSearchFailure[] failures, + String scrollId, + BytesReference searchContextId ) { - ShardSearchRequest request = rewriteShardSearchRequest(super.buildShardSearchRequest(shardIt, listener.requestIndex)); - getSearchTransport().sendExecuteQuery(connection, request, getTask(), listener); + int numSuccess = successfulOps.get(); + int numFailures = failures.length; + assert numSuccess + numFailures == results.getNumShards() + : "numSuccess(" + numSuccess + ") + numFailures(" + numFailures + ") != totalShards(" + results.getNumShards() + ")"; + return new SearchResponse( + internalSearchResponse, + scrollId, + results.getNumShards(), + numSuccess, + toSkipShardsIts.size(), + timeProvider.buildTookInMillis(), + failures, + clusters, + searchContextId + ); + } + + /** + * Builds a request for the initial search phase. 
+ * + * @param shardIndex the index of the shard that is used in the coordinator node to + * tiebreak results with identical sort values + */ + private static ShardSearchRequest buildShardSearchRequest( + ShardId shardId, + String clusterAlias, + int shardIndex, + ShardSearchContextId searchContextId, + OriginalIndices originalIndices, + AliasFilter aliasFilter, + TimeValue searchContextKeepAlive, + float indexBoost, + SearchRequest searchRequest, + int totalShardCount, + long absoluteStartMillis, + boolean hasResponse + ) { + ShardSearchRequest shardRequest = new ShardSearchRequest( + originalIndices, + searchRequest, + shardId, + shardIndex, + totalShardCount, + aliasFilter, + indexBoost, + absoluteStartMillis, + clusterAlias, + searchContextId, + searchContextKeepAlive + ); + // if we already received a search result we can inform the shard that it + // can return a null response if the request rewrites to match none rather + // than creating an empty response in the search thread pool. + // Note that, we have to disable this shortcut for queries that create a context (scroll and search context). + shardRequest.canReturnNullResponseIfMatchNoDocs(hasResponse && shardRequest.scroll() == null); + return shardRequest; + } + + public static class NodeQueryResponse extends TransportResponse { + + private final RefCounted refCounted = LeakTracker.wrap(new SimpleRefCounted()); + + private final Object[] results; + private final SearchPhaseController.TopDocsStats topDocsStats; + private final QueryPhaseResultConsumer.MergeResult mergeResult; + + NodeQueryResponse(StreamInput in) throws IOException { + super(in); + this.results = in.readArray(i -> i.readBoolean() ? 
new QuerySearchResult(i) : i.readException(), Object[]::new); + this.mergeResult = QueryPhaseResultConsumer.MergeResult.readFrom(in); + this.topDocsStats = SearchPhaseController.TopDocsStats.readFrom(in); + } + + NodeQueryResponse( + QueryPhaseResultConsumer.MergeResult mergeResult, + Object[] results, + SearchPhaseController.TopDocsStats topDocsStats + ) { + this.results = results; + for (Object result : results) { + if (result instanceof RefCounted r) { + r.incRef(); + } + } + this.mergeResult = mergeResult; + this.topDocsStats = topDocsStats; + assert Arrays.stream(results).noneMatch(Objects::isNull) : Arrays.toString(results); + } + + // public for tests + public Object[] getResults() { + return results; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeArray((o, v) -> { + if (v instanceof Exception e) { + o.writeBoolean(false); + o.writeException(e); + } else { + o.writeBoolean(true); + assert v instanceof QuerySearchResult : v; + ((QuerySearchResult) v).writeTo(o); + } + }, results); + mergeResult.writeTo(out); + topDocsStats.writeTo(out); + } + + @Override + public void incRef() { + refCounted.incRef(); + } + + @Override + public boolean tryIncRef() { + return refCounted.tryIncRef(); + } + + @Override + public boolean hasReferences() { + return refCounted.hasReferences(); + } + + @Override + public boolean decRef() { + if (refCounted.decRef()) { + for (int i = 0; i < results.length; i++) { + Object result = results[i]; + if (result instanceof RefCounted r) { + r.decRef(); + } + results[i] = null; + } + return true; + } + return false; + } + } + + public static class NodeQueryRequest extends TransportRequest implements IndicesRequest { + private final List shards; + private final SearchRequest searchRequest; + private final Map aliasFilters; + private final int totalShards; + private final long absoluteStartMillis; + + private NodeQueryRequest( + List shards, + SearchRequest searchRequest, + Map aliasFilters, + int 
totalShards, + long absoluteStartMillis + ) { + this.shards = shards; + this.searchRequest = searchRequest; + this.aliasFilters = aliasFilters; + this.totalShards = totalShards; + this.absoluteStartMillis = absoluteStartMillis; + } + + private NodeQueryRequest(StreamInput in) throws IOException { + super(in); + this.shards = in.readCollectionAsImmutableList(ShardToQuery::readFrom); + this.searchRequest = new SearchRequest(in); + this.aliasFilters = in.readImmutableMap(AliasFilter::readFrom); + this.totalShards = in.readVInt(); + this.absoluteStartMillis = in.readLong(); + } + + @Override + public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { + return new SearchShardTask(id, type, action, "NodeQueryRequest", parentTaskId, headers); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeCollection(shards); + searchRequest.writeTo(out); + out.writeMap(aliasFilters, (o, v) -> v.writeTo(o)); + out.writeVInt(totalShards); + out.writeLong(absoluteStartMillis); + } + + @Override + public String[] indices() { + return shards.stream().map(s -> s.originalIndices().indices()).flatMap(Arrays::stream).distinct().toArray(String[]::new); + } + + @Override + public IndicesOptions indicesOptions() { + return shards.getFirst().originalIndices.indicesOptions(); + } + } + + private record ShardToQuery( + float boost, + OriginalIndices originalIndices, + int shardIndex, + ShardId shardId, + ShardSearchContextId contextId + ) implements Writeable { + + static ShardToQuery readFrom(StreamInput in) throws IOException { + return new ShardToQuery( + in.readFloat(), + OriginalIndices.readOriginalIndices(in), + in.readVInt(), + new ShardId(in), + in.readOptionalWriteable(ShardSearchContextId::new) + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeFloat(boost); + OriginalIndices.writeOriginalIndices(originalIndices, out); + out.writeVInt(shardIndex); + 
shardId.writeTo(out); + out.writeOptionalWriteable(contextId); + } } - @Override protected void onShardGroupFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) { progressListener.notifyQueryFailure(shardIndex, shardTarget, exc); } @Override - protected void onShardResult(SearchPhaseResult result, SearchShardIterator shardIt) { + protected void onShardResult(Result result) { QuerySearchResult queryResult = result.queryResult(); if (queryResult.isNull() == false // disable sort optims for scroll requests because they keep track of the last bottom doc locally (per shard) - && getRequest().scroll() == null + && request.scroll() == null // top docs are already consumed if the query was cancelled or in error. && queryResult.hasConsumedTopDocs() == false && queryResult.topDocs() != null @@ -123,13 +443,13 @@ && getRequest().scroll() == null } bottomSortCollector.consumeTopDocs(topDocs, queryResult.sortValueFormats()); } - super.onShardResult(result, shardIt); + super.onShardResult(result); } static SearchPhase nextPhase( Client client, - AbstractSearchAsyncAction context, - SearchPhaseResults queryResults, + AsyncSearchContext context, + SearchPhaseResults queryResults, AggregatedDfs aggregatedDfs ) { var rankFeaturePhaseCoordCtx = RankFeaturePhase.coordinatorContext(context.getRequest().source(), client); @@ -139,12 +459,15 @@ static SearchPhase nextPhase( return new RankFeaturePhase(queryResults, aggregatedDfs, context, rankFeaturePhaseCoordCtx); } - @Override protected SearchPhase getNextPhase() { return nextPhase(client, this, results, null); } - private ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) { + private static ShardSearchRequest rewriteShardSearchRequest( + BottomSortValuesCollector bottomSortCollector, + int trackTotalHitsUpTo, + ShardSearchRequest request + ) { if (bottomSortCollector == null) { return request; } @@ -160,4 +483,565 @@ private ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) 
} return request; } + + private void run() { + // TODO: stupid but we kinda need to fill all of these in with the current logic, do something nicer before merging + final Map shardIndexMap = Maps.newHashMapWithExpectedSize(shardIterators.length); + for (int i = 0; i < shardIterators.length; i++) { + shardIndexMap.put(shardIterators[i], i); + } + final boolean supportsBatchedQuery = minTransportVersion.onOrAfter(TransportVersions.BATCHED_QUERY_PHASE_VERSION); + final Map perNodeQueries = new HashMap<>(); + AbstractSearchAsyncAction.doCheckNoMissingShards(NAME, request, shardsIts, AbstractSearchAsyncAction::makeMissingShardsError); + final String localNodeId = searchTransportService.transportService().getLocalNode().getId(); + final String localClusterAlias = request.getLocalClusterAlias(); + for (int i = 0; i < shardsIts.size(); i++) { + final SearchShardIterator shardRoutings = shardsIts.get(i); + assert shardRoutings.skip() == false; + assert shardIndexMap.containsKey(shardRoutings); + int shardIndex = shardIndexMap.get(shardRoutings); + final SearchShardTarget routing = shardRoutings.nextOrNull(); + if (routing == null) { + failOnUnavailable(shardIndex, shardRoutings); + } else { + String clusterAlias = routing.getClusterAlias(); + final String nodeId = routing.getNodeId(); + if (supportsBatchedQuery + && localNodeId.equals(nodeId) == false // local requests don't need batching as there's no network latency + && (clusterAlias == null || Objects.equals(localClusterAlias, clusterAlias))) { + perNodeQueries.computeIfAbsent( + nodeId, + ignored -> new NodeQueryRequest( + new ArrayList<>(), + request, + aliasFilter, + shardsIts.size(), + timeProvider.absoluteStartMillis() + ) + ).shards.add( + new ShardToQuery( + concreteIndexBoosts.getOrDefault(routing.getShardId().getIndex().getUUID(), DEFAULT_INDEX_BOOST), + getOriginalIndices(shardIndex), + shardIndex, + routing.getShardId(), + shardRoutings.getSearchContextId() + ) + ); + } else { + 
performPhaseOnShard(shardIndex, shardRoutings, routing); + } + } + } + perNodeQueries.forEach((nodeId, request) -> { + if (request.shards.size() == 1) { + var shard = request.shards.getFirst(); + final int sidx = shard.shardIndex; + this.performPhaseOnShard(sidx, shardIterators[sidx], new SearchShardTarget(nodeId, shard.shardId, localClusterAlias)); + return; + } + final Transport.Connection connection; + try { + connection = getConnection(localClusterAlias, nodeId); + } catch (Exception e) { + onNodeQueryFailure(e, request, nodeId); + return; + } + searchTransportService.transportService() + .sendChildRequest(connection, NODE_SEARCH_ACTION_NAME, request, task, new TransportResponseHandler() { + @Override + public NodeQueryResponse read(StreamInput in) throws IOException { + return new NodeQueryResponse(in); + } + + @Override + public Executor executor() { + return EsExecutors.DIRECT_EXECUTOR_SERVICE; + } + + @Override + public void handleResponse(NodeQueryResponse response) { + if (results instanceof QueryPhaseResultConsumer queryPhaseResultConsumer) { + queryPhaseResultConsumer.addPartialResult(response.topDocsStats, response.mergeResult); + } + for (int i = 0; i < response.results.length; i++) { + var s = request.shards.get(i); + int shardIdx = s.shardIndex; + final SearchShardTarget target = new SearchShardTarget(nodeId, s.shardId, localClusterAlias); + switch (response.results[i]) { + case Exception e -> onShardFailure(shardIdx, target, shardIterators[shardIdx], e); + case SearchPhaseResult q -> { + q.setShardIndex(shardIdx); + q.setSearchShardTarget(target); + @SuppressWarnings("unchecked") + var res = (Result) q; + onShardResult(res); + } + case null, default -> { + assert false : "impossible [" + response.results[i] + "]"; + } + } + } + } + + @Override + public void handleException(TransportException e) { + Exception cause = (Exception) ExceptionsHelper.unwrapCause(e); + if (results instanceof QueryPhaseResultConsumer queryPhaseResultConsumer) { + 
queryPhaseResultConsumer.failure.compareAndSet(null, cause); + } + onPhaseFailure(NAME, "", cause); + } + }); + }); + } + + private static Map removeEmptyAliasFilters(Map aliasFilters) { + Map aliasFilterNoEmpty = new HashMap<>(); + aliasFilters.forEach((idx, filter) -> { + if (AliasFilter.EMPTY.equals(filter) == false) { + aliasFilterNoEmpty.put(idx, filter); + } + }); + return Map.copyOf(aliasFilterNoEmpty); + } + + private void onNodeQueryFailure(Exception e, NodeQueryRequest request, String nodeId) { + for (ShardToQuery shard : request.shards) { + int idx = shard.shardIndex; + onShardFailure( + idx, + new SearchShardTarget(nodeId, shard.shardId, request.searchRequest.getLocalClusterAlias()), + shardIterators[idx], + e + ); + } + } + + private void failOnUnavailable(int shardIndex, SearchShardIterator shardIt) { + SearchShardTarget unassignedShard = new SearchShardTarget(null, shardIt.shardId(), shardIt.getClusterAlias()); + onShardFailure(shardIndex, unassignedShard, shardIt, new NoShardAvailableActionException(shardIt.shardId())); + } + + protected void performPhaseOnShard(final int shardIndex, final SearchShardIterator shardIt, final SearchShardTarget shard) { + final Transport.Connection connection; + try { + connection = getConnection(shard.getClusterAlias(), shard.getNodeId()); + } catch (Exception e) { + onShardFailure(shardIndex, shard, shardIt, e); + return; + } + final String indexUUID = shardIt.shardId().getIndex().getUUID(); + searchTransportService.sendExecuteQuery( + connection, + rewriteShardSearchRequest( + bottomSortCollector, + trackTotalHitsUpTo, + buildShardSearchRequest( + shardIt.shardId(), + shardIt.getClusterAlias(), + shardIndex, + shardIt.getSearchContextId(), + shardIt.getOriginalIndices(), + aliasFilter.getOrDefault(indexUUID, AliasFilter.EMPTY), + shardIt.getSearchContextKeepAlive(), + concreteIndexBoosts.getOrDefault(indexUUID, DEFAULT_INDEX_BOOST), + request, + results.getNumShards(), + timeProvider.absoluteStartMillis(), + 
hasShardResponse + ) + ), + task, + new SearchActionListener<>(shard, shardIndex) { + @Override + public void innerOnResponse(SearchPhaseResult result) { + try { + @SuppressWarnings("unchecked") + var res = (Result) result; + onShardResult(res); + } catch (Exception exc) { + // TODO: this looks like a nasty bug were it to actually happen + assert false : exc; + onShardFailure(shardIndex, shard, shardIt, exc); + } + } + + @Override + public void onFailure(Exception e) { + onShardFailure(shardIndex, shard, shardIt, e); + } + } + ); + } + + private void onShardFailure(final int shardIndex, SearchShardTarget shard, final SearchShardIterator shardIt, Exception e) { + // we always add the shard failure for a specific shard instance + // we do make sure to clean it on a successful response from a shard + onShardFailure(shardIndex, shard, e); + final SearchShardTarget nextShard = shardIt.nextOrNull(); + final boolean lastShard = nextShard == null; + logger.debug(() -> format("%s: Failed to execute [%s] lastShard [%s]", shard, request, lastShard), e); + if (lastShard) { + maybeCancelSearchTask(); + onShardGroupFailure(shardIndex, shard, e); + finishShardAndMaybePhase(); + } else { + performPhaseOnShard(shardIndex, shardIt, nextShard); + } + } + + protected void finish() { + executeNextPhase(NAME, this::getNextPhase); + } + + /** + * Executed once for every failed shard level request. This method is invoked before the next replica is tried for the given + * shard target. + * @param shardIndex the internal index for this shard. 
Each shard has an index / ordinal assigned that is used to reference + * it's results + * @param shardTarget the shard target for this failure + * @param e the failure reason + */ + public void onShardFailure(final int shardIndex, SearchShardTarget shardTarget, Exception e) { + if (TransportActions.isShardNotAvailableException(e)) { + // Groups shard not available exceptions under a generic exception that returns a SERVICE_UNAVAILABLE(503) + // temporary error. + e = NoShardAvailableActionException.forOnShardFailureWrapper(e.getMessage()); + } + handleFailedAndCancelled(shardIndex, shardTarget, e); + } + + @Override + public void executeNextPhase(String currentPhase, Supplier nextPhaseSupplier) { + /* This is the main search phase transition where we move to the next phase. If all shards + * failed or if there was a failure and partial results are not allowed, then we immediately + * fail. Otherwise we continue to the next phase. + */ + ShardOperationFailedException[] shardSearchFailures = buildShardFailures(shardFailures); + final int numShards = results.getNumShards(); + if (shardSearchFailures.length == numShards) { + shardSearchFailures = ExceptionsHelper.groupBy(shardSearchFailures); + Throwable cause = shardSearchFailures.length == 0 + ? 
null + : ElasticsearchException.guessRootCauses(shardSearchFailures[0].getCause())[0]; + logger.debug(() -> "All shards failed for phase: [" + currentPhase + "]", cause); + onPhaseFailure(currentPhase, "all shards failed", cause); + } else { + Boolean allowPartialResults = request.allowPartialSearchResults(); + assert allowPartialResults != null : "SearchRequest missing setting for allowPartialSearchResults"; + if (allowPartialResults == false && successfulOps.get() != numShards) { + handleNotAllSucceeded(currentPhase, shardSearchFailures, numShards); + return; + } + var nextPhase = nextPhaseSupplier.get(); + if (logger.isTraceEnabled()) { + logger.trace( + "[{}] Moving to next phase: [{}], based on results from: {}", + currentPhase, + nextPhase.getName(), + results.getSuccessfulResults().map(r -> r.getSearchShardTarget().toString()).collect(Collectors.joining(",")) + ); + } + executePhase(nextPhase); + } + } + + /** + * This method will communicate a fatal phase failure back to the user. 
In contrast to a shard failure + * will this method immediately fail the search request and return the failure to the issuer of the request + * @param phase the phase that failed + * @param msg an optional message + * @param cause the cause of the phase failure + */ + @Override + public void onPhaseFailure(String phase, String msg, Throwable cause) { + raisePhaseFailure(new SearchPhaseExecutionException(phase, msg, cause, buildShardFailures(shardFailures))); + } + + @Override + public void sendReleaseSearchContext(ShardSearchContextId contextId, Transport.Connection connection) { + assert isPartOfPointInTime(contextId) == false : "Must not release point in time context [" + contextId + "]"; + if (connection != null) { + searchTransportService.sendFreeContext(connection, contextId, ActionListener.noop()); + } + } + + public static final String NODE_SEARCH_ACTION_NAME = "indices:data/read/search[query][n]"; + + private static final CircuitBreaker NOOP_CIRCUIT_BREAKER = new NoopCircuitBreaker("request"); + + public static void registerNodeSearchAction(SearchTransportService searchTransportService, SearchService searchService) { + var transportService = searchTransportService.transportService(); + var threadPool = transportService.getThreadPool(); + final Dependencies dependencies = new Dependencies(searchService, threadPool.executor(ThreadPool.Names.SEARCH)); + final int searchPoolMax = threadPool.info(ThreadPool.Names.SEARCH).getMax(); + final SearchPhaseController searchPhaseController = new SearchPhaseController(searchService::aggReduceContextBuilder); + transportService.registerRequestHandler( + NODE_SEARCH_ACTION_NAME, + EsExecutors.DIRECT_EXECUTOR_SERVICE, + NodeQueryRequest::new, + (request, channel, task) -> { + final int shardCount = request.shards.size(); + int workers = Math.min(request.searchRequest.getMaxConcurrentShardRequests(), Math.min(shardCount, searchPoolMax)); + final var state = new QueryPerNodeState( + new AtomicInteger(workers - 1), + new 
QueryPhaseResultConsumer( + request.searchRequest, + dependencies.executor, + NOOP_CIRCUIT_BREAKER, // noop cb for now since we do not have a breaker in this situation in un-batched execution + searchPhaseController, + ((CancellableTask) task)::isCancelled, + SearchProgressListener.NOOP, + shardCount, + Integer.MAX_VALUE, // TODO: intermediary reduces + e -> logger.error("failed to merge on data node", e) + ), + request, + (CancellableTask) task, + channel, + dependencies + ); + for (int i = 0; i < workers; i++) { + dependencies.executor.execute(shardTask(state, i)); + } + } + ); + + } + + private static void maybeRelease(SearchService searchService, NodeQueryRequest request, SearchPhaseResult result) { + var phaseResult = result.queryResult() != null ? result.queryResult() : result.rankFeatureResult(); + if (phaseResult != null + && phaseResult.hasSearchContext() + && request.searchRequest.scroll() == null + && (AsyncSearchContext.isPartOfPIT(null, request.searchRequest, phaseResult.getContextId()) == false)) { + searchService.freeReaderContext(phaseResult.getContextId()); + } + } + + private static AbstractRunnable shardTask(QueryPerNodeState state, int dataNodeLocalIdx) { + return new AbstractRunnable() { + @Override + protected void doRun() { + var request = state.searchRequest; + var searchRequest = request.searchRequest; + var pitBuilder = searchRequest.pointInTimeBuilder(); + var shardToQuery = request.shards.get(dataNodeLocalIdx); + final var shardId = shardToQuery.shardId; + state.dependencies.searchService.executeQueryPhase( + rewriteShardSearchRequest( + state.bottomSortCollector, + state.trackTotalHitsUpTo, + buildShardSearchRequest( + shardId, + searchRequest.getLocalClusterAlias(), + shardToQuery.shardIndex, + shardToQuery.contextId, + shardToQuery.originalIndices, + request.aliasFilters.getOrDefault(shardId.getIndex().getUUID(), AliasFilter.EMPTY), + pitBuilder == null ? 
null : pitBuilder.getKeepAlive(), + shardToQuery.boost, + searchRequest, + request.totalShards, + request.absoluteStartMillis, + state.hasResponse.getAcquire() + ) + ), + state.task, + new ActionListener<>() { + @Override + public void onResponse(SearchPhaseResult searchPhaseResult) { + try { + searchPhaseResult.setShardIndex(dataNodeLocalIdx); + searchPhaseResult.setSearchShardTarget( + new SearchShardTarget(null, shardToQuery.shardId, request.searchRequest.getLocalClusterAlias()) + ); + // no need for any cache effects when we're already flipped to true => plain read + set-release + state.hasResponse.compareAndExchangeRelease(false, true); + state.consumeResult(searchPhaseResult.queryResult()); + state.queryPhaseResultConsumer.consumeResult(searchPhaseResult, state::onDone); + } catch (Exception e) { + setFailure(state, dataNodeLocalIdx, e); + } finally { + maybeNext(); + } + } + + private void setFailure(QueryPerNodeState state, int dataNodeLocalIdx, Exception e) { + state.failures.put(dataNodeLocalIdx, e); + state.onDone(); + } + + @Override + public void onFailure(Exception e) { + try { + // TODO: count down fully and just respond with an exception if partial results aren't allowed + setFailure(state, dataNodeLocalIdx, e); + maybeNext(); + } catch (Throwable expected) { + expected.addSuppressed(e); + throw new AssertionError(expected); + } + } + + private void maybeNext() { + final int shardToQuery = state.currentShardIndex.incrementAndGet(); + if (shardToQuery < request.shards.size()) { + state.dependencies.executor.execute(shardTask(state, shardToQuery)); + } + } + } + ); + } + + @Override + public void onFailure(Exception e) { + // TODO this could be done better now, we probably should only make sure to have a single loop running at + // minimum and ignore + requeue rejections in that case + state.failures.put(dataNodeLocalIdx, e); + state.onDone(); + // TODO SO risk! 
+ maybeNext(); + } + + @Override + public void onRejection(Exception e) { + // TODO this could be done better now: a rejection is currently treated like any other failure, see the TODO in onFailure + onFailure(e); + } + + private void maybeNext() { + final int shardToQuery = state.currentShardIndex.incrementAndGet(); + if (shardToQuery < state.searchRequest.shards.size()) { + state.dependencies.executor.execute(shardTask(state, shardToQuery)); + } + } + }; + } + + private record Dependencies(SearchService searchService, Executor executor) {} + + private static final class QueryPerNodeState { + + private static final QueryPhaseResultConsumer.MergeResult EMPTY_PARTIAL_MERGE_RESULT = new QueryPhaseResultConsumer.MergeResult( + List.of(), + Lucene.EMPTY_TOP_DOCS, + null, + 0L + ); + + private final AtomicInteger currentShardIndex; + private final QueryPhaseResultConsumer queryPhaseResultConsumer; + private final NodeQueryRequest searchRequest; + private final CancellableTask task; + private final ConcurrentHashMap failures = new ConcurrentHashMap<>(); + private final Dependencies dependencies; + private final AtomicBoolean hasResponse = new AtomicBoolean(false); + private final int trackTotalHitsUpTo; + private final int topDocsSize; + private final CountDown countDown; + private final TransportChannel channel; + private volatile BottomSortValuesCollector bottomSortCollector; + + private QueryPerNodeState( + AtomicInteger currentShardIndex, + QueryPhaseResultConsumer queryPhaseResultConsumer, + NodeQueryRequest searchRequest, + CancellableTask task, + TransportChannel channel, + Dependencies dependencies + ) { + this.currentShardIndex = currentShardIndex; + this.queryPhaseResultConsumer = queryPhaseResultConsumer; + this.searchRequest = searchRequest; + this.trackTotalHitsUpTo = searchRequest.searchRequest.resolveTrackTotalHitsUpTo(); + topDocsSize = getTopDocsSize(searchRequest.searchRequest); + this.task = task; + countDown = new CountDown(queryPhaseResultConsumer.getNumShards()); + 
this.channel = channel; + this.dependencies = dependencies; + } + + void onDone() { + if (countDown.countDown() == false) { + return; + } + var channelListener = new ChannelActionListener<>(channel); + try (queryPhaseResultConsumer) { + var failure = queryPhaseResultConsumer.failure.get(); + if (failure != null) { + queryPhaseResultConsumer.getSuccessfulResults() + .forEach(searchPhaseResult -> maybeRelease(dependencies.searchService, searchRequest, searchPhaseResult)); + channelListener.onFailure(failure); + return; + } + final Object[] results = new Object[queryPhaseResultConsumer.getNumShards()]; + for (int i = 0; i < results.length; i++) { + var e = failures.get(i); + var res = queryPhaseResultConsumer.results.get(i); + if (e != null) { + results[i] = e; + assert res == null; + } else { + results[i] = res; + assert results[i] != null; + } + } + final QueryPhaseResultConsumer.MergeResult mergeResult; + try { + mergeResult = Objects.requireNonNullElse(queryPhaseResultConsumer.consumePartialResult(), EMPTY_PARTIAL_MERGE_RESULT); + } catch (Exception e) { + channelListener.onFailure(e); + return; + } + // translate shard indices to those on the coordinator so that it can interpret the merge result without adjustments + final Set relevantShardIndices = new HashSet<>(); + for (ScoreDoc scoreDoc : mergeResult.reducedTopDocs().scoreDocs) { + final int localIndex = scoreDoc.shardIndex; + scoreDoc.shardIndex = searchRequest.shards.get(localIndex).shardIndex; + relevantShardIndices.add(localIndex); + } + for (Object result : results) { + if (result instanceof QuerySearchResult q + && q.getContextId() != null + && relevantShardIndices.contains(q.getShardIndex()) == false + && q.hasSuggestHits() == false + && q.getRankShardResult() == null + && searchRequest.searchRequest.scroll() == null + && (AsyncSearchContext.isPartOfPIT(null, searchRequest.searchRequest, q.getContextId()) == false)) { + if (dependencies.searchService.freeReaderContext(q.getContextId())) { + 
q.clearContextId(); + } + } + } + + ActionListener.respondAndRelease( + channelListener, + new NodeQueryResponse(mergeResult, results, queryPhaseResultConsumer.topDocsStats) + ); + } + } + + void consumeResult(QuerySearchResult queryResult) { + if (queryResult.isNull() == false + // disable sort optims for scroll requests because they keep track of the last bottom doc locally (per shard) + && searchRequest.searchRequest.scroll() == null + // top docs are already consumed if the query was cancelled or in error. + && queryResult.hasConsumedTopDocs() == false + && queryResult.topDocs() != null + && queryResult.topDocs().topDocs.getClass() == TopFieldDocs.class) { + TopFieldDocs topDocs = (TopFieldDocs) queryResult.topDocs().topDocs; + var bottomSortCollector = this.bottomSortCollector; + if (bottomSortCollector == null) { + synchronized (this) { + bottomSortCollector = this.bottomSortCollector; + if (bottomSortCollector == null) { + bottomSortCollector = this.bottomSortCollector = new BottomSortValuesCollector(topDocsSize, topDocs.fields); + } + } + } + bottomSortCollector.consumeTopDocs(topDocs, queryResult.sortValueFormats()); + } + } + } } diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchResponse.java b/server/src/main/java/org/elasticsearch/action/search/SearchResponse.java index 787dc14f6cd96..22cd5252fb066 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchResponse.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchResponse.java @@ -417,9 +417,7 @@ public XContentBuilder headerToXContent(XContentBuilder builder, ToXContent.Para if (isTerminatedEarly() != null) { builder.field(TERMINATED_EARLY.getPreferredName(), isTerminatedEarly()); } - if (getNumReducePhases() != 1) { - builder.field(NUM_REDUCE_PHASES.getPreferredName(), getNumReducePhases()); - } + // TODO: bring back rendering reduce phase count RestActions.buildBroadcastShardsHeader( builder, params, diff --git 
a/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java b/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java index 2041754bc2bcc..d6ae2d9ff05ed 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java @@ -16,6 +16,7 @@ import org.elasticsearch.action.OriginalIndices; import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest; import org.elasticsearch.action.admin.cluster.node.tasks.get.TransportGetTaskAction; +import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksResponse; import org.elasticsearch.action.support.ChannelActionListener; import org.elasticsearch.client.internal.OriginSettingClient; import org.elasticsearch.client.internal.node.NodeClient; @@ -45,6 +46,7 @@ import org.elasticsearch.search.query.ScrollQuerySearchResult; import org.elasticsearch.search.rank.feature.RankFeatureResult; import org.elasticsearch.search.rank.feature.RankFeatureShardRequest; +import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.RemoteClusterService; @@ -176,6 +178,10 @@ public void sendExecuteDfs( ); } + public TransportService transportService() { + return transportService; + } + public void sendExecuteQuery( Transport.Connection connection, final ShardSearchRequest request, @@ -379,16 +385,19 @@ public void writeTo(StreamOutput out) throws IOException { } } - public static void registerRequestHandler(TransportService transportService, SearchService searchService) { + public static void registerRequestHandler(SearchTransportService searchTransportService, SearchService searchService) { final TransportRequestHandler freeContextHandler = (request, channel, task) -> { logger.trace("releasing search context [{}]", request.id()); boolean freed = 
searchService.freeReaderContext(request.id()); channel.sendResponse(SearchFreeContextResponse.of(freed)); }; + var transportService = searchTransportService.transportService; final Executor freeContextExecutor = buildFreeContextExecutor(transportService); transportService.registerRequestHandler( FREE_CONTEXT_SCROLL_ACTION_NAME, freeContextExecutor, + false, + false, ScrollFreeContextRequest::new, freeContextHandler ); @@ -400,7 +409,7 @@ public static void registerRequestHandler(TransportService transportService, Sea ); // TODO: remove this handler once the lowest compatible version stops using it - transportService.registerRequestHandler(FREE_CONTEXT_ACTION_NAME, freeContextExecutor, in -> { + transportService.registerRequestHandler(FREE_CONTEXT_ACTION_NAME, freeContextExecutor, false, false, in -> { var res = new ScrollFreeContextRequest(in); // this handler exists for BwC purposes only, we don't need the original indices to free the context OriginalIndices.readOriginalIndices(in); @@ -428,7 +437,7 @@ public static void registerRequestHandler(TransportService transportService, Sea DFS_ACTION_NAME, EsExecutors.DIRECT_EXECUTOR_SERVICE, ShardSearchRequest::new, - (request, channel, task) -> searchService.executeDfsPhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel)) + (request, channel, task) -> searchService.executeDfsPhase(request, (CancellableTask) task, new ChannelActionListener<>(channel)) ); TransportActionProxy.registerProxyAction(transportService, DFS_ACTION_NAME, true, DfsSearchResult::new); @@ -438,7 +447,7 @@ public static void registerRequestHandler(TransportService transportService, Sea ShardSearchRequest::new, (request, channel, task) -> searchService.executeQueryPhase( request, - (SearchShardTask) task, + (CancellableTask) task, new ChannelActionListener<>(channel) ) ); @@ -455,7 +464,7 @@ public static void registerRequestHandler(TransportService transportService, Sea QuerySearchRequest::new, (request, channel, task) -> 
searchService.executeQueryPhase( request, - (SearchShardTask) task, + (CancellableTask) task, new ChannelActionListener<>(channel), channel.getVersion() ) @@ -468,7 +477,7 @@ public static void registerRequestHandler(TransportService transportService, Sea InternalScrollSearchRequest::new, (request, channel, task) -> searchService.executeQueryPhase( request, - (SearchShardTask) task, + (CancellableTask) task, new ChannelActionListener<>(channel), channel.getVersion() ) @@ -481,14 +490,14 @@ public static void registerRequestHandler(TransportService transportService, Sea InternalScrollSearchRequest::new, (request, channel, task) -> searchService.executeFetchPhase( request, - (SearchShardTask) task, + (CancellableTask) task, new ChannelActionListener<>(channel) ) ); TransportActionProxy.registerProxyAction(transportService, QUERY_FETCH_SCROLL_ACTION_NAME, true, ScrollQueryFetchSearchResult::new); final TransportRequestHandler rankShardFeatureRequest = (request, channel, task) -> searchService - .executeRankFeaturePhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel)); + .executeRankFeaturePhase(request, (CancellableTask) task, new ChannelActionListener<>(channel)); transportService.registerRequestHandler( RANK_FEATURE_SHARD_ACTION_NAME, EsExecutors.DIRECT_EXECUTOR_SERVICE, @@ -498,7 +507,7 @@ public static void registerRequestHandler(TransportService transportService, Sea TransportActionProxy.registerProxyAction(transportService, RANK_FEATURE_SHARD_ACTION_NAME, true, RankFeatureResult::new); final TransportRequestHandler shardFetchRequestHandler = (request, channel, task) -> searchService - .executeFetchPhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel)); + .executeFetchPhase(request, (CancellableTask) task, new ChannelActionListener<>(channel)); transportService.registerRequestHandler( FETCH_ID_SCROLL_ACTION_NAME, EsExecutors.DIRECT_EXECUTOR_SERVICE, @@ -611,10 +620,20 @@ private boolean assertNodePresent() { } } - 
public void cancelSearchTask(SearchTask task, String reason) { - CancelTasksRequest req = new CancelTasksRequest().setTargetTaskId(new TaskId(client.getLocalNodeId(), task.getId())) + public void cancelSearchTask(long taskId, String reason) { + CancelTasksRequest req = new CancelTasksRequest().setTargetTaskId(new TaskId(client.getLocalNodeId(), taskId)) .setReason("Fatal failure during search: " + reason); // force the origin to execute the cancellation as a system user - new OriginSettingClient(client, TransportGetTaskAction.TASKS_ORIGIN).admin().cluster().cancelTasks(req, ActionListener.noop()); + new OriginSettingClient(client, TransportGetTaskAction.TASKS_ORIGIN).admin().cluster().cancelTasks(req, new ActionListener<>() { + @Override + public void onResponse(ListTasksResponse listTasksResponse) { + // intentionally empty: nothing to do on success, only failures are logged + } + + @Override + public void onFailure(Exception e) { + logger.warn("unexpected failure cancelling [" + taskId + "] because of [" + reason + "]", e); + } + }); } } diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java index b8d0a928e05aa..6b3102cf9f47b 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java @@ -268,7 +268,7 @@ protected void executePhaseOnShard( @Override protected SearchPhase getNextPhase() { - return new SearchPhase(getName()) { + return new SearchPhase(name) { @Override protected void run() { sendSearchResponse(SearchResponseSections.EMPTY_WITH_TOTAL_HITS, results.getAtomicArray()); diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java index 6f075c6f35009..128e962889cbb 100644 --- 
a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java @@ -190,7 +190,8 @@ public TransportSearchAction( this.searchPhaseController = searchPhaseController; this.searchTransportService = searchTransportService; this.remoteClusterService = searchTransportService.getRemoteClusterService(); - SearchTransportService.registerRequestHandler(transportService, searchService); + SearchTransportService.registerRequestHandler(searchTransportService, searchService); + SearchQueryThenFetchAsyncAction.registerNodeSearchAction(searchTransportService, searchService); this.clusterService = clusterService; this.transportService = transportService; this.searchService = searchService; @@ -1523,13 +1524,12 @@ public void runNewSearchPhase( task.getProgressListener(), searchRequest, shardIterators.size(), - exc -> searchTransportService.cancelSearchTask(task, "failed to merge result [" + exc.getMessage() + "]") + exc -> searchTransportService.cancelSearchTask(task.getId(), "failed to merge result [" + exc.getMessage() + "]") ); boolean success = false; try { - final AbstractSearchAsyncAction searchPhase; if (searchRequest.searchType() == DFS_QUERY_THEN_FETCH) { - searchPhase = new SearchDfsQueryThenFetchAsyncAction( + var searchPhase = new SearchDfsQueryThenFetchAsyncAction( logger, namedWriteableRegistry, searchTransportService, @@ -1547,10 +1547,11 @@ public void runNewSearchPhase( clusters, client ); + success = true; + searchPhase.start(); } else { assert searchRequest.searchType() == QUERY_THEN_FETCH : searchRequest.searchType(); - searchPhase = new SearchQueryThenFetchAsyncAction( - logger, + var searchPhase = new SearchQueryThenFetchAsyncAction<>( namedWriteableRegistry, searchTransportService, connectionLookup, @@ -1567,9 +1568,9 @@ public void runNewSearchPhase( clusters, client ); + success = true; + searchPhase.start(); } - success = true; - searchPhase.start(); } 
finally { if (success == false) { queryResultConsumer.close(); diff --git a/server/src/main/java/org/elasticsearch/common/lucene/Lucene.java b/server/src/main/java/org/elasticsearch/common/lucene/Lucene.java index 073000979918e..390eaa987a010 100644 --- a/server/src/main/java/org/elasticsearch/common/lucene/Lucene.java +++ b/server/src/main/java/org/elasticsearch/common/lucene/Lucene.java @@ -390,6 +390,12 @@ public static ScoreDoc readScoreDoc(StreamInput in) throws IOException { return new ScoreDoc(in.readVInt(), in.readFloat()); } + public static ScoreDoc readScoreDocWithShardIndex(StreamInput in) throws IOException { + var res = readScoreDoc(in); + res.shardIndex = in.readVInt(); + return res; + } + private static final Class GEO_DISTANCE_SORT_TYPE_CLASS = LatLonDocValuesField.newDistanceSort("some_geo_field", 0, 0).getClass(); public static void writeTotalHits(StreamOutput out, TotalHits totalHits) throws IOException { @@ -397,6 +403,84 @@ public static void writeTotalHits(StreamOutput out, TotalHits totalHits) throws out.writeEnum(totalHits.relation()); } + public static void writeTopDocsIncludingShardIndex(StreamOutput out, TopDocs topDocs) throws IOException { + if (topDocs instanceof TopFieldGroups topFieldGroups) { + out.writeByte((byte) 2); + writeTotalHits(out, topDocs.totalHits); + out.writeString(topFieldGroups.field); + out.writeArray(Lucene::writeSortField, topFieldGroups.fields); + out.writeVInt(topDocs.scoreDocs.length); + for (int i = 0; i < topDocs.scoreDocs.length; i++) { + ScoreDoc doc = topFieldGroups.scoreDocs[i]; + writeFieldDoc(out, (FieldDoc) doc); + writeSortValue(out, topFieldGroups.groupValues[i]); + out.writeVInt(doc.shardIndex); + } + } else if (topDocs instanceof TopFieldDocs topFieldDocs) { + out.writeByte((byte) 1); + writeTotalHits(out, topDocs.totalHits); + out.writeArray(Lucene::writeSortField, topFieldDocs.fields); + out.writeArray((o, doc) -> { + o.writeArray(Lucene::writeSortValue, ((FieldDoc) doc).fields); + 
o.writeVInt(doc.doc); + o.writeFloat(doc.score); + o.writeVInt(doc.shardIndex); + }, topFieldDocs.scoreDocs); + } else { + out.writeByte((byte) 0); + writeTotalHits(out, topDocs.totalHits); + out.writeArray((o, scoreDoc) -> { + writeScoreDoc(o, scoreDoc); + o.writeVInt(scoreDoc.shardIndex); + }, topDocs.scoreDocs); + } + } + + public static TopDocs readTopDocsOnly(StreamInput in) throws IOException { + byte type = in.readByte(); + if (type == 0) { + TotalHits totalHits = readTotalHits(in); + + final int scoreDocCount = in.readVInt(); + final ScoreDoc[] scoreDocs; + if (scoreDocCount == 0) { + scoreDocs = EMPTY_SCORE_DOCS; + } else { + scoreDocs = new ScoreDoc[scoreDocCount]; + for (int i = 0; i < scoreDocs.length; i++) { + scoreDocs[i] = readScoreDocWithShardIndex(in); + } + } + return new TopDocs(totalHits, scoreDocs); + } else if (type == 1) { + TotalHits totalHits = readTotalHits(in); + SortField[] fields = in.readArray(Lucene::readSortField, SortField[]::new); + FieldDoc[] fieldDocs = new FieldDoc[in.readVInt()]; + for (int i = 0; i < fieldDocs.length; i++) { + var fieldDoc = readFieldDoc(in); + fieldDoc.shardIndex = in.readVInt(); + fieldDocs[i] = fieldDoc; + } + return new TopFieldDocs(totalHits, fieldDocs, fields); + } else if (type == 2) { + TotalHits totalHits = readTotalHits(in); + String field = in.readString(); + SortField[] fields = in.readArray(Lucene::readSortField, SortField[]::new); + int size = in.readVInt(); + Object[] collapseValues = new Object[size]; + FieldDoc[] fieldDocs = new FieldDoc[size]; + for (int i = 0; i < fieldDocs.length; i++) { + var doc = readFieldDoc(in); + collapseValues[i] = readSortValue(in); + doc.shardIndex = in.readVInt(); + fieldDocs[i] = doc; + } + return new TopFieldGroups(field, totalHits, fieldDocs, fields, collapseValues); + } else { + throw new IllegalStateException("Unknown type " + type); + } + } + public static void writeTopDocs(StreamOutput out, TopDocsAndMaxScore topDocs) throws IOException { if 
(topDocs.topDocs instanceof TopFieldGroups topFieldGroups) { out.writeByte((byte) 2); diff --git a/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java b/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java index 47d3ed337af73..7095d3ec92c72 100644 --- a/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java +++ b/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java @@ -21,7 +21,6 @@ import org.apache.lucene.search.Query; import org.apache.lucene.search.TotalHits; import org.apache.lucene.util.NumericUtils; -import org.elasticsearch.action.search.SearchShardTask; import org.elasticsearch.action.search.SearchType; import org.elasticsearch.cluster.routing.IndexRouting; import org.elasticsearch.common.lucene.search.Queries; @@ -77,6 +76,7 @@ import org.elasticsearch.search.slice.SliceBuilder; import org.elasticsearch.search.sort.SortAndFormats; import org.elasticsearch.search.suggest.SuggestionSearchContext; +import org.elasticsearch.tasks.CancellableTask; import java.io.IOException; import java.io.UncheckedIOException; @@ -131,7 +131,7 @@ final class DefaultSearchContext extends SearchContext { private CollapseContext collapse; // filter for sliced scroll private SliceBuilder sliceBuilder; - private SearchShardTask task; + private CancellableTask task; private QueryPhaseRankShardContext queryPhaseRankShardContext; /** @@ -433,7 +433,7 @@ public void preProcess() { this.query = buildFilteredQuery(query); if (lowLevelCancellation) { searcher().addQueryCancellation(() -> { - final SearchShardTask task = getTask(); + final CancellableTask task = getTask(); if (task != null) { task.ensureNotCancelled(); } @@ -907,12 +907,12 @@ public void setProfilers(Profilers profilers) { } @Override - public void setTask(SearchShardTask task) { + public void setTask(CancellableTask task) { this.task = task; } @Override - public SearchShardTask getTask() { + public CancellableTask getTask() { return task; } diff 
--git a/server/src/main/java/org/elasticsearch/search/SearchPhaseResult.java b/server/src/main/java/org/elasticsearch/search/SearchPhaseResult.java index 01c1665451996..5c5db4248afe0 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchPhaseResult.java +++ b/server/src/main/java/org/elasticsearch/search/SearchPhaseResult.java @@ -62,6 +62,11 @@ public ShardSearchContextId getContextId() { return contextId; } + public void clearContextId() { + this.shardSearchRequest = null; + this.contextId = null; + } + /** * Returns the shard index in the context of the currently executing search request that is * used for accounting on the coordinating node diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java index efa27b2f3448c..f47fc1b97456e 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchService.java +++ b/server/src/main/java/org/elasticsearch/search/SearchService.java @@ -24,7 +24,6 @@ import org.elasticsearch.action.ResolvedIndices; import org.elasticsearch.action.search.CanMatchNodeRequest; import org.elasticsearch.action.search.CanMatchNodeResponse; -import org.elasticsearch.action.search.SearchShardTask; import org.elasticsearch.action.search.SearchType; import org.elasticsearch.action.support.TransportActions; import org.elasticsearch.cluster.ClusterState; @@ -126,6 +125,7 @@ import org.elasticsearch.search.sort.SortBuilder; import org.elasticsearch.search.suggest.Suggest; import org.elasticsearch.search.suggest.completion.CompletionSuggestion; +import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.telemetry.tracing.Tracer; import org.elasticsearch.threadpool.Scheduler; @@ -543,7 +543,7 @@ static ActionListener maybeWrapListenerForStackTrace( return listener; } - public void executeDfsPhase(ShardSearchRequest request, SearchShardTask task, ActionListener listener) { + public 
void executeDfsPhase(ShardSearchRequest request, CancellableTask task, ActionListener listener) { listener = maybeWrapListenerForStackTrace(listener, request.getChannelVersion(), threadPool); final IndexShard shard = getShard(request); rewriteAndFetchShardRequest(shard, request, listener.delegateFailure((l, rewritten) -> { @@ -552,7 +552,7 @@ public void executeDfsPhase(ShardSearchRequest request, SearchShardTask task, Ac })); } - private DfsSearchResult executeDfsPhase(ShardSearchRequest request, SearchShardTask task) throws IOException { + private DfsSearchResult executeDfsPhase(ShardSearchRequest request, CancellableTask task) throws IOException { ReaderContext readerContext = createOrGetReaderContext(request); try (@SuppressWarnings("unused") // withScope call is necessary to instrument search execution Releasable scope = tracer.withScope(task); @@ -581,7 +581,7 @@ private void loadOrExecuteQueryPhase(final ShardSearchRequest request, final Sea } } - public void executeQueryPhase(ShardSearchRequest request, SearchShardTask task, ActionListener listener) { + public void executeQueryPhase(ShardSearchRequest request, CancellableTask task, ActionListener listener) { ActionListener finalListener = maybeWrapListenerForStackTrace(listener, request.getChannelVersion(), threadPool); assert request.canReturnNullResponseIfMatchNoDocs() == false || request.numberOfShards() > 1 : "empty responses require more than one shard"; @@ -600,7 +600,7 @@ public void executeQueryPhase(ShardSearchRequest request, SearchShardTask task, ); CanMatchShardResponse canMatchResp = canMatch(canMatchContext, false); if (canMatchResp.canMatch() == false) { - finalListener.onResponse(QuerySearchResult.nullInstance()); + l.onResponse(QuerySearchResult.nullInstance()); return; } } @@ -729,7 +729,7 @@ private static void runAsync( * It is the responsibility of the caller to ensure that the ref count is correctly decremented * when the object is no longer needed. 
*/ - private SearchPhaseResult executeQueryPhase(ShardSearchRequest request, SearchShardTask task) throws Exception { + private SearchPhaseResult executeQueryPhase(ShardSearchRequest request, CancellableTask task) throws Exception { final ReaderContext readerContext = createOrGetReaderContext(request); try ( Releasable scope = tracer.withScope(task); @@ -774,7 +774,7 @@ private SearchPhaseResult executeQueryPhase(ShardSearchRequest request, SearchSh } } - public void executeRankFeaturePhase(RankFeatureShardRequest request, SearchShardTask task, ActionListener listener) { + public void executeRankFeaturePhase(RankFeatureShardRequest request, CancellableTask task, ActionListener listener) { listener = maybeWrapListenerForStackTrace(listener, request.getShardSearchRequest().getChannelVersion(), threadPool); final ReaderContext readerContext = findReaderContext(request.contextId(), request); final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.getShardSearchRequest()); @@ -818,7 +818,7 @@ private QueryFetchSearchResult executeFetchPhase(ReaderContext reader, SearchCon public void executeQueryPhase( InternalScrollSearchRequest request, - SearchShardTask task, + CancellableTask task, ActionListener listener, TransportVersion version ) { @@ -860,7 +860,7 @@ public void executeQueryPhase( */ public void executeQueryPhase( QuerySearchRequest request, - SearchShardTask task, + CancellableTask task, ActionListener listener, TransportVersion version ) { @@ -918,7 +918,7 @@ private Executor getExecutor(IndexShard indexShard) { public void executeFetchPhase( InternalScrollSearchRequest request, - SearchShardTask task, + CancellableTask task, ActionListener listener ) { final LegacyReaderContext readerContext = (LegacyReaderContext) findReaderContext(request.contextId(), request); @@ -953,8 +953,14 @@ public void executeFetchPhase( }, wrapFailureListener(listener, readerContext, markAsUsed)); } - public void executeFetchPhase(ShardFetchRequest 
request, SearchShardTask task, ActionListener listener) { - final ReaderContext readerContext = findReaderContext(request.contextId(), request); + public void executeFetchPhase(ShardFetchRequest request, CancellableTask task, ActionListener listener) { + final ReaderContext readerContext; + try { + readerContext = findReaderContext(request.contextId(), request); + } catch (Exception e) { + listener.onFailure(e); + return; + } final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.getShardSearchRequest()); final Releasable markAsUsed = readerContext.markAsUsed(getKeepAlive(shardSearchRequest)); rewriteAndFetchShardRequest(readerContext.indexShard(), shardSearchRequest, listener.delegateFailure((l, rewritten) -> { @@ -991,7 +997,7 @@ public void executeFetchPhase(ShardFetchRequest request, SearchShardTask task, A })); } - protected void checkCancelled(SearchShardTask task) { + protected void checkCancelled(CancellableTask task) { // check cancellation as early as possible, as it avoids opening up a Lucene reader on FrozenEngine try { task.ensureNotCancelled(); @@ -1122,7 +1128,7 @@ public void openReaderContext(ShardId shardId, TimeValue keepAlive, ActionListen protected SearchContext createContext( ReaderContext readerContext, ShardSearchRequest request, - SearchShardTask task, + CancellableTask task, ResultsType resultsType, boolean includeAggregations ) throws IOException { diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantTermsAggregatorFactory.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantTermsAggregatorFactory.java index 080cac9cbfb85..17f63643b83c2 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantTermsAggregatorFactory.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantTermsAggregatorFactory.java @@ -13,7 +13,6 @@ import 
org.apache.lucene.search.MatchNoDocsQuery; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.search.SearchShardTask; import org.elasticsearch.common.logging.DeprecationCategory; import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.index.query.QueryBuilder; @@ -128,7 +127,8 @@ private static SignificantTermsAggregatorSupplier bytesSupplier() { *

* Some searches that will never match can still fall through and we endup running query that will produce no results. * However even in that case we sometimes do expensive things like loading global ordinals. This method should prevent this. - * Note that if {@link org.elasticsearch.search.SearchService#executeQueryPhase(ShardSearchRequest, SearchShardTask, ActionListener)} + * Note that if {@link org.elasticsearch.search.SearchService#executeQueryPhase(ShardSearchRequest, + * org.elasticsearch.tasks.CancellableTask, ActionListener)} * always do a can match then we don't need this code here. */ static boolean matchNoDocs(AggregationContext context, Aggregator parent) { diff --git a/server/src/main/java/org/elasticsearch/search/internal/FilteredSearchContext.java b/server/src/main/java/org/elasticsearch/search/internal/FilteredSearchContext.java index 8c4f912c5988c..5bad06d08f96b 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/FilteredSearchContext.java +++ b/server/src/main/java/org/elasticsearch/search/internal/FilteredSearchContext.java @@ -12,7 +12,6 @@ import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.Query; import org.apache.lucene.search.TotalHits; -import org.elasticsearch.action.search.SearchShardTask; import org.elasticsearch.action.search.SearchType; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.cache.bitset.BitsetFilterCache; @@ -40,6 +39,7 @@ import org.elasticsearch.search.rescore.RescoreContext; import org.elasticsearch.search.sort.SortAndFormats; import org.elasticsearch.search.suggest.SuggestionSearchContext; +import org.elasticsearch.tasks.CancellableTask; import java.util.List; @@ -422,12 +422,12 @@ public SearchExecutionContext getSearchExecutionContext() { } @Override - public void setTask(SearchShardTask task) { + public void setTask(CancellableTask task) { in.setTask(task); } @Override - public SearchShardTask getTask() { + public CancellableTask getTask() { return 
in.getTask(); } diff --git a/server/src/main/java/org/elasticsearch/search/internal/ReaderContext.java b/server/src/main/java/org/elasticsearch/search/internal/ReaderContext.java index c15b604b5b5fc..c525a9f2f0cce 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/ReaderContext.java +++ b/server/src/main/java/org/elasticsearch/search/internal/ReaderContext.java @@ -194,4 +194,9 @@ public void putInContext(String key, Object value) { public long getStartTimeInNano() { return startTimeInNano; } + + @Override + public String toString() { + return "ReaderContext{" + id + " }"; + } } diff --git a/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java b/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java index 7da71b77c6a6f..5e856e8df6d6a 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java +++ b/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java @@ -11,7 +11,6 @@ import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.Query; import org.apache.lucene.search.TotalHits; -import org.elasticsearch.action.search.SearchShardTask; import org.elasticsearch.action.search.SearchType; import org.elasticsearch.core.Assertions; import org.elasticsearch.core.Nullable; @@ -48,6 +47,7 @@ import org.elasticsearch.search.rescore.RescoreContext; import org.elasticsearch.search.sort.SortAndFormats; import org.elasticsearch.search.suggest.SuggestionSearchContext; +import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.transport.LeakTracker; import java.io.IOException; @@ -85,12 +85,14 @@ public abstract class SearchContext implements Releasable { protected SearchContext() {} + public abstract void setTask(CancellableTask task); + public final List getCancellationChecks() { final Runnable timeoutRunnable = QueryPhase.getTimeoutCheck(this); if (lowLevelCancellation()) { // This searching doesn't live beyond this phase, so we don't need to remove 
query cancellation Runnable c = () -> { - final SearchShardTask task = getTask(); + final CancellableTask task = getTask(); if (task != null) { task.ensureNotCancelled(); } @@ -100,9 +102,7 @@ public final List getCancellationChecks() { return timeoutRunnable == null ? List.of() : List.of(timeoutRunnable); } - public abstract void setTask(SearchShardTask task); - - public abstract SearchShardTask getTask(); + public abstract CancellableTask getTask(); public abstract boolean isCancelled(); diff --git a/server/src/main/java/org/elasticsearch/search/internal/ShardSearchRequest.java b/server/src/main/java/org/elasticsearch/search/internal/ShardSearchRequest.java index 5e4ffbdba9ad2..12b9acfc69935 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/ShardSearchRequest.java +++ b/server/src/main/java/org/elasticsearch/search/internal/ShardSearchRequest.java @@ -224,7 +224,7 @@ public ShardSearchRequest( long nowInMillis, @Nullable String clusterAlias, ShardSearchContextId readerId, - TimeValue keepAlive, + @Nullable TimeValue keepAlive, long waitForCheckpoint, TimeValue waitForCheckpointsTimeout, boolean forceSyntheticSource diff --git a/server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java b/server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java index b7e2e361c28b6..a1c83c17ca269 100644 --- a/server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java +++ b/server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java @@ -11,6 +11,7 @@ import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.TotalHits; +import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.DelayableWriteable; import org.elasticsearch.common.io.stream.StreamInput; @@ -68,6 +69,8 @@ public final class QuerySearchResult extends SearchPhaseResult { private long serviceTimeEWMA = -1; private int nodeQueueSize = -1; + private boolean 
reduced; + private final boolean isNull; private final RefCounted refCounted; @@ -91,7 +94,9 @@ public QuerySearchResult(StreamInput in, boolean delayedAggregations) throws IOE super(in); isNull = in.readBoolean(); if (isNull == false) { - ShardSearchContextId id = new ShardSearchContextId(in); + ShardSearchContextId id = in.getTransportVersion().onOrAfter(TransportVersions.BATCHED_QUERY_PHASE_VERSION) + ? in.readOptionalWriteable(ShardSearchContextId::new) + : new ShardSearchContextId(in); readFromWithId(id, in, delayedAggregations); } refCounted = null; @@ -139,6 +144,16 @@ public QuerySearchResult queryResult() { return this; } + public boolean isReduced() { + return reduced; + } + + public void setReduced() { + assert (hasConsumedTopDocs() || topDocsAndMaxScore.topDocs.scoreDocs.length == 0) && aggregations == null + : topDocsAndMaxScore + " " + aggregations; + this.reduced = true; + } + public void searchTimedOut(boolean searchTimedOut) { this.searchTimedOut = searchTimedOut; } @@ -381,7 +396,13 @@ private void readFromWithId(ShardSearchContextId id, StreamInput in, boolean del sortValueFormats[i] = in.readNamedWriteable(DocValueFormat.class); } } - setTopDocs(readTopDocs(in)); + if (in.getTransportVersion().onOrAfter(TransportVersions.BATCHED_QUERY_PHASE_VERSION)) { + if (in.readBoolean()) { + setTopDocs(readTopDocs(in)); + } + } else { + setTopDocs(readTopDocs(in)); + } hasAggs = in.readBoolean(); boolean success = false; try { @@ -405,6 +426,9 @@ private void readFromWithId(ShardSearchContextId id, StreamInput in, boolean del setRescoreDocIds(new RescoreDocIds(in)); if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0)) { rankShardResult = in.readOptionalNamedWriteable(RankShardResult.class); + if (in.getTransportVersion().onOrAfter(TransportVersions.BATCHED_QUERY_PHASE_VERSION)) { + reduced = in.readBoolean(); + } } success = true; } finally { @@ -423,7 +447,11 @@ public void writeTo(StreamOutput out) throws IOException { } 
out.writeBoolean(isNull); if (isNull == false) { - contextId.writeTo(out); + if (out.getTransportVersion().onOrAfter(TransportVersions.BATCHED_QUERY_PHASE_VERSION)) { + out.writeOptionalWriteable(contextId); + } else { + contextId.writeTo(out); + } writeToNoId(out); } } @@ -439,7 +467,16 @@ public void writeToNoId(StreamOutput out) throws IOException { out.writeNamedWriteable(sortValueFormats[i]); } } - writeTopDocs(out, topDocsAndMaxScore); + if (out.getTransportVersion().onOrAfter(TransportVersions.BATCHED_QUERY_PHASE_VERSION)) { + if (topDocsAndMaxScore != null) { + out.writeBoolean(true); + writeTopDocs(out, topDocsAndMaxScore); + } else { + out.writeBoolean(false); + } + } else { + writeTopDocs(out, topDocsAndMaxScore); + } out.writeOptionalWriteable(aggregations); if (suggest == null) { out.writeBoolean(false); @@ -459,6 +496,9 @@ public void writeToNoId(StreamOutput out) throws IOException { } else if (rankShardResult != null) { throw new IllegalArgumentException("cannot serialize [rank] to version [" + out.getTransportVersion().toReleaseVersion() + "]"); } + if (out.getTransportVersion().onOrAfter(TransportVersions.BATCHED_QUERY_PHASE_VERSION)) { + out.writeBoolean(reduced); + } } @Nullable diff --git a/server/src/main/java/org/elasticsearch/search/rank/RankSearchContext.java b/server/src/main/java/org/elasticsearch/search/rank/RankSearchContext.java index ad70e7d39aff8..951a9b0cf3520 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/RankSearchContext.java +++ b/server/src/main/java/org/elasticsearch/search/rank/RankSearchContext.java @@ -12,7 +12,6 @@ import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.Query; import org.apache.lucene.search.TotalHits; -import org.elasticsearch.action.search.SearchShardTask; import org.elasticsearch.action.search.SearchType; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.cache.bitset.BitsetFilterCache; @@ -48,6 +47,7 @@ import org.elasticsearch.search.rescore.RescoreContext; import
org.elasticsearch.search.sort.SortAndFormats; import org.elasticsearch.search.suggest.SuggestionSearchContext; +import org.elasticsearch.tasks.CancellableTask; import java.util.List; @@ -211,12 +211,12 @@ public long getRelativeTimeInMillis() { /* ---- ALL METHODS ARE UNSUPPORTED BEYOND HERE ---- */ @Override - public void setTask(SearchShardTask task) { + public void setTask(CancellableTask task) { throw new UnsupportedOperationException(); } @Override - public SearchShardTask getTask() { + public CancellableTask getTask() { throw new UnsupportedOperationException(); } diff --git a/server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java index 11085558dbe16..10cc2b863c2e1 100644 --- a/server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java @@ -90,6 +90,10 @@ private AbstractSearchAsyncAction createAction( request.getMaxConcurrentShardRequests(), SearchResponse.Clusters.EMPTY ) { + { + assertTrue(finishShard()); // only have a single shard in the iterator, lets finish that one as is expected by tests + } + @Override protected SearchPhase getNextPhase() { return null; @@ -226,13 +230,8 @@ public void testShardNotAvailableWithDisallowPartialFailures() { SearchShardIterator skipIterator = new SearchShardIterator(null, null, Collections.emptyList(), null); skipIterator.skip(true); skipIterator.reset(); - action.skipShard(skipIterator); + action.start(); assertThat(exception.get(), instanceOf(SearchPhaseExecutionException.class)); - SearchPhaseExecutionException searchPhaseExecutionException = (SearchPhaseExecutionException) exception.get(); - assertEquals("Partial shards failure (" + (numShards - 1) + " shards unavailable)", searchPhaseExecutionException.getMessage()); - assertEquals("test", 
searchPhaseExecutionException.getPhaseName()); - assertEquals(0, searchPhaseExecutionException.shardFailures().length); - assertEquals(0, searchPhaseExecutionException.getSuppressed().length); } private static ArraySearchPhaseResults phaseResults( diff --git a/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java index fd60621c7e400..e23b78522b13b 100644 --- a/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java @@ -759,7 +759,7 @@ public void sendExecuteFetch( } - private static BiFunction, SearchPhase> searchPhaseFactory( + private static BiFunction, SearchPhase> searchPhaseFactory( MockSearchPhaseContext mockSearchPhaseContext ) { return (searchResponse, scrollId) -> new SearchPhase("test") { diff --git a/server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java b/server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java index 97d420b7cd3c2..65ec4edfd35e6 100644 --- a/server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java +++ b/server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java @@ -86,7 +86,7 @@ public OriginalIndices getOriginalIndices(int shardIndex) { } @Override - public void sendSearchResponse(SearchResponseSections internalSearchResponse, AtomicArray queryResults) { + public void sendSearchResponse(SearchResponseSections internalSearchResponse, AtomicArray queryResults) { String scrollId = getRequest().scroll() != null ? TransportSearchHelper.buildScrollId(queryResults) : null; BytesReference searchContextId = getRequest().pointInTimeBuilder() != null ? 
new BytesArray(TransportSearchHelper.buildScrollId(queryResults)) @@ -150,7 +150,7 @@ protected void executePhaseOnShard( SearchActionListener listener ) { onShardResult(new SearchPhaseResult() { - }, shardIt); + }); } @Override diff --git a/server/src/test/java/org/elasticsearch/action/search/QueryPhaseResultConsumerTests.java b/server/src/test/java/org/elasticsearch/action/search/QueryPhaseResultConsumerTests.java index e0b68647289b2..c059e5db499fd 100644 --- a/server/src/test/java/org/elasticsearch/action/search/QueryPhaseResultConsumerTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/QueryPhaseResultConsumerTests.java @@ -119,6 +119,7 @@ public void testProgressListenerExceptionsAreCaught() throws Exception { () -> false, searchProgressListener, 10, + -1, e -> onPartialMergeFailure.accumulateAndGet(e, (prev, curr) -> { curr.addSuppressed(prev); return curr; diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java index 647d16977181f..922e96c5e5520 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java @@ -165,13 +165,13 @@ public void testLimitConcurrentShardRequests() throws InterruptedException { request.setMaxConcurrentShardRequests(numConcurrent); boolean doReplicas = randomBoolean(); int numShards = randomIntBetween(5, 10); - Boolean[] shardFailures = new Boolean[numShards]; + Boolean[] sFailures = new Boolean[numShards]; // at least one response otherwise the entire request fails - shardFailures[randomIntBetween(0, shardFailures.length - 1)] = false; - for (int i = 0; i < shardFailures.length; i++) { - if (shardFailures[i] == null) { + sFailures[randomIntBetween(0, sFailures.length - 1)] = false; + for (int i = 0; i < sFailures.length; i++) { + if (sFailures[i] == null) { boolean failure = 
randomBoolean(); - shardFailures[i] = failure; + sFailures[i] = failure; } } CountDownLatch latch = new CountDownLatch(1); @@ -239,7 +239,7 @@ protected void executePhaseOnShard( connection.getNode() ); try { - if (shardFailures[shardIt.shardId().id()]) { + if (sFailures[shardIt.shardId().id()]) { listener.onFailure(new RuntimeException()); } else { listener.onResponse(testSearchPhaseResult); diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java index bf81486087361..fec305ced4fb7 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java @@ -713,7 +713,6 @@ private void consumerTestCase(int numEmptyResponses) throws Exception { } SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); - assertEquals(numTotalReducePhases, reduce.numReducePhases()); assertEquals(numTotalReducePhases, reductions.size()); assertAggReduction(request); Max max = (Max) reduce.aggregations().asList().get(0); diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java index be693a2d7d294..0c7837632aaf8 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java @@ -39,6 +39,7 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.InternalAggregationTestCase; import org.elasticsearch.transport.Transport; +import org.elasticsearch.transport.TransportService; import java.util.Collections; import java.util.List; @@ -51,6 +52,8 @@ import static org.hamcrest.Matchers.equalTo; import static 
org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class SearchQueryThenFetchAsyncActionTests extends ESTestCase { public void testBottomFieldSort() throws Exception { @@ -83,7 +86,9 @@ private void testCase(boolean withScroll, boolean withCollapse) throws Exception AtomicInteger numWithTopDocs = new AtomicInteger(); AtomicInteger successfulOps = new AtomicInteger(); AtomicBoolean canReturnNullResponse = new AtomicBoolean(false); - SearchTransportService searchTransportService = new SearchTransportService(null, null, null) { + var transportService = mock(TransportService.class); + when(transportService.getLocalNode()).thenReturn(primaryNode); + SearchTransportService searchTransportService = new SearchTransportService(transportService, null, null) { @Override public void sendExecuteQuery( Transport.Connection connection, @@ -182,11 +187,11 @@ public void sendExecuteQuery( task::isCancelled, task.getProgressListener(), shardsIter.size(), + -1, exc -> {} ) ) { - SearchQueryThenFetchAsyncAction action = new SearchQueryThenFetchAsyncAction( - logger, + SearchQueryThenFetchAsyncAction action = new SearchQueryThenFetchAsyncAction<>( null, searchTransportService, (clusterAlias, node) -> lookup.get(node), diff --git a/server/src/test/java/org/elasticsearch/index/SearchSlowLogTests.java b/server/src/test/java/org/elasticsearch/index/SearchSlowLogTests.java index 359118c7cb5a1..d4aec300c666b 100644 --- a/server/src/test/java/org/elasticsearch/index/SearchSlowLogTests.java +++ b/server/src/test/java/org/elasticsearch/index/SearchSlowLogTests.java @@ -26,6 +26,7 @@ import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.internal.ShardSearchRequest; +import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; import 
org.elasticsearch.test.ESSingleNodeTestCase; import org.elasticsearch.test.TestSearchContext; @@ -93,7 +94,7 @@ public ShardSearchRequest request() { } @Override - public SearchShardTask getTask() { + public CancellableTask getTask() { return super.getTask(); } }; diff --git a/test/framework/src/main/java/org/elasticsearch/search/MockSearchService.java b/test/framework/src/main/java/org/elasticsearch/search/MockSearchService.java index 79c61cacb58eb..42b11173a3b19 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/MockSearchService.java +++ b/test/framework/src/main/java/org/elasticsearch/search/MockSearchService.java @@ -9,7 +9,6 @@ package org.elasticsearch.search; -import org.elasticsearch.action.search.SearchShardTask; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.core.TimeValue; @@ -23,6 +22,7 @@ import org.elasticsearch.search.internal.ReaderContext; import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.internal.ShardSearchRequest; +import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.telemetry.tracing.Tracer; import org.elasticsearch.threadpool.ThreadPool; @@ -46,7 +46,7 @@ public static class TestPlugin extends Plugin {} private Consumer onCreateSearchContext = context -> {}; - private Function onCheckCancelled = Function.identity(); + private Function onCheckCancelled = Function.identity(); /** Throw an {@link AssertionError} if there are still in-flight contexts. 
*/ public static void assertNoInFlightContext() { @@ -132,7 +132,7 @@ public void setOnCreateSearchContext(Consumer onCreateSearchConte protected SearchContext createContext( ReaderContext readerContext, ShardSearchRequest request, - SearchShardTask task, + CancellableTask task, ResultsType resultsType, boolean includeAggregations ) throws IOException { @@ -154,12 +154,12 @@ public SearchContext createSearchContext(ShardSearchRequest request, TimeValue t return searchContext; } - public void setOnCheckCancelled(Function onCheckCancelled) { + public void setOnCheckCancelled(Function onCheckCancelled) { this.onCheckCancelled = onCheckCancelled; } @Override - protected void checkCancelled(SearchShardTask task) { + protected void checkCancelled(CancellableTask task) { super.checkCancelled(onCheckCancelled.apply(task)); } } diff --git a/test/framework/src/main/java/org/elasticsearch/test/TestSearchContext.java b/test/framework/src/main/java/org/elasticsearch/test/TestSearchContext.java index 103cf1c15abc1..c46442485ff9e 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/TestSearchContext.java +++ b/test/framework/src/main/java/org/elasticsearch/test/TestSearchContext.java @@ -11,7 +11,6 @@ import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.Query; import org.apache.lucene.search.TotalHits; -import org.elasticsearch.action.search.SearchShardTask; import org.elasticsearch.action.search.SearchType; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.IndexService; @@ -49,6 +48,7 @@ import org.elasticsearch.search.rescore.RescoreContext; import org.elasticsearch.search.sort.SortAndFormats; import org.elasticsearch.search.suggest.SuggestionSearchContext; +import org.elasticsearch.tasks.CancellableTask; import java.util.Collections; import java.util.HashMap; @@ -67,7 +67,7 @@ public class TestSearchContext extends SearchContext { ParsedQuery postFilter; Query query; Float minScore; - SearchShardTask task; + 
CancellableTask task; SortAndFormats sort; boolean trackScores = false; int trackTotalHitsUpTo = SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO; @@ -506,12 +506,12 @@ public SearchExecutionContext getSearchExecutionContext() { } @Override - public void setTask(SearchShardTask task) { + public void setTask(CancellableTask task) { this.task = task; } @Override - public SearchShardTask getTask() { + public CancellableTask getTask() { return task; } diff --git a/test/framework/src/main/java/org/elasticsearch/test/hamcrest/ElasticsearchAssertions.java b/test/framework/src/main/java/org/elasticsearch/test/hamcrest/ElasticsearchAssertions.java index 40aee8eed4235..b47410df4bc87 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/hamcrest/ElasticsearchAssertions.java +++ b/test/framework/src/main/java/org/elasticsearch/test/hamcrest/ElasticsearchAssertions.java @@ -522,7 +522,12 @@ public static void assertScrollResponsesAndHitCount( public static void assertResponse(ActionFuture responseFuture, Consumer consumer) throws ExecutionException, InterruptedException { - var res = responseFuture.get(); + final R res; + try { + res = responseFuture.get(); + } catch (Exception e) { + throw new AssertionError(e); + } try { consumer.accept(res); } finally { diff --git a/x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/AsyncSearchErrorTraceIT.java b/x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/AsyncSearchErrorTraceIT.java index 39a6fa1e4b34f..985f8932cff54 100644 --- a/x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/AsyncSearchErrorTraceIT.java +++ b/x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/AsyncSearchErrorTraceIT.java @@ -8,13 +8,17 @@ package org.elasticsearch.xpack.search; import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.action.search.SearchQueryThenFetchAsyncAction; import 
org.elasticsearch.client.Request; import org.elasticsearch.client.Response; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportMessageListener; +import org.elasticsearch.transport.TransportResponse; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xcontent.XContentType; import org.junit.Before; @@ -38,12 +42,41 @@ protected Collection> nodePlugins() { return List.of(AsyncSearch.class); } + @Override + protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { + // TODO: this can be removed once we consistently use the Threadpool provided timestamps in search code. Currently, there is + // a mix of the threadpool timestamps and System.currentTimeMillis etc. in the codebase so we need to force the threadpool to be + // consistent with those APIs. 
+ return super.nodeSettings( + nodeOrdinal, + Settings.builder().put(otherSettings).put(ThreadPool.ESTIMATED_TIME_INTERVAL_SETTING.getKey(), 0).build() + ); + } + private AtomicBoolean transportMessageHasStackTrace; @Before private void setupMessageListener() { internalCluster().getDataNodeInstances(TransportService.class).forEach(ts -> { ts.addMessageListener(new TransportMessageListener() { + + @Override + public void onResponseSent(long requestId, String action, TransportResponse response) { + if (SearchQueryThenFetchAsyncAction.NODE_SEARCH_ACTION_NAME.equals(action)) { + Object[] res = asInstanceOf(SearchQueryThenFetchAsyncAction.NodeQueryResponse.class, response).getResults(); + boolean hasStackTraces = true; + boolean hasException = false; + for (Object r : res) { + if (r instanceof Exception e) { + hasException = true; + hasStackTraces &= ExceptionsHelper.unwrapCausesAndSuppressed(e, t -> t.getStackTrace().length > 0) + .isPresent(); + } + } + transportMessageHasStackTrace.set(hasException && hasStackTraces); + } + } + @Override public void onResponseSent(long requestId, String action, Exception error) { TransportMessageListener.super.onResponseSent(requestId, action, error); diff --git a/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchResponseTests.java b/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchResponseTests.java index 98513f611a5d8..48280e179c031 100644 --- a/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchResponseTests.java +++ b/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchResponseTests.java @@ -268,7 +268,6 @@ public void testToXContentWithSearchResponseAfterCompletion() throws IOException "response" : { "took" : %s, "timed_out" : false, - "num_reduce_phases" : 2, "_shards" : { "total" : 10, "successful" : 9, @@ -304,7 +303,6 @@ public void testToXContentWithSearchResponseAfterCompletion() throws 
IOException "response" : { "took" : %s, "timed_out" : false, - "num_reduce_phases" : 2, "_shards" : { "total" : 10, "successful" : 9, @@ -388,7 +386,6 @@ public void testToXContentWithCCSSearchResponseWhileRunning() throws IOException "response" : { "took" : %s, "timed_out" : false, - "num_reduce_phases" : 2, "_shards" : { "total" : 10, "successful" : 9, @@ -447,7 +444,6 @@ public void testToXContentWithCCSSearchResponseWhileRunning() throws IOException "response" : { "took" : %s, "timed_out" : false, - "num_reduce_phases" : 2, "_shards" : { "total" : 10, "successful" : 9, @@ -623,7 +619,6 @@ public void testToXContentWithCCSSearchResponseAfterCompletion() throws IOExcept "response" : { "took" : %s, "timed_out" : true, - "num_reduce_phases" : 2, "_shards" : { "total" : 10, "successful" : 9, @@ -770,7 +765,6 @@ public void testToXContentWithSearchResponseWhileRunning() throws IOException { "response" : { "took" : %s, "timed_out" : false, - "num_reduce_phases" : 2, "_shards" : { "total" : 10, "successful" : 9, @@ -804,7 +798,6 @@ public void testToXContentWithSearchResponseWhileRunning() throws IOException { "response" : { "took" : %s, "timed_out" : false, - "num_reduce_phases" : 2, "_shards" : { "total" : 10, "successful" : 9, diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/PreAuthorizationUtils.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/PreAuthorizationUtils.java index 221b7a65e1f8f..3c8f5dd5cdff9 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/PreAuthorizationUtils.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/PreAuthorizationUtils.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.IndicesRequest; +import org.elasticsearch.action.search.SearchQueryThenFetchAsyncAction; import 
org.elasticsearch.action.search.SearchTransportService; import org.elasticsearch.action.search.TransportSearchAction; import org.elasticsearch.xpack.core.security.SecurityContext; @@ -46,7 +47,8 @@ public final class PreAuthorizationUtils { SearchTransportService.QUERY_ID_ACTION_NAME, SearchTransportService.FETCH_ID_ACTION_NAME, SearchTransportService.RANK_FEATURE_SHARD_ACTION_NAME, - SearchTransportService.QUERY_CAN_MATCH_NODE_NAME + SearchTransportService.QUERY_CAN_MATCH_NODE_NAME, + SearchQueryThenFetchAsyncAction.NODE_SEARCH_ACTION_NAME ) );