diff --git a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java
index 44752d6f33600..b5f7a3facb4ba 100644
--- a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java
+++ b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java
@@ -196,7 +196,7 @@ long buildTookInMillis() {
      * This is the main entry point for a search. This method starts the search execution of the initial phase.
      */
     public final void start() {
-        if (getNumShards() == 0) {
+        if (results.getNumShards() == 0) {
             // no search shards to search on, bail with empty response
             // (it happens with search across _all with no indices around and consistent with broadcast operations)
             int trackTotalHitsUpTo = request.source() == null ? SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO
@@ -341,7 +341,8 @@ protected void executeNextPhase(String currentPhase, Supplier nextP
          * fail. Otherwise we continue to the next phase.
          */
         ShardOperationFailedException[] shardSearchFailures = buildShardFailures();
-        if (shardSearchFailures.length == getNumShards()) {
+        final int numShards = results.getNumShards();
+        if (shardSearchFailures.length == numShards) {
             shardSearchFailures = ExceptionsHelper.groupBy(shardSearchFailures);
             Throwable cause = shardSearchFailures.length == 0
                 ? null
@@ -351,7 +352,7 @@ protected void executeNextPhase(String currentPhase, Supplier nextP
         } else {
             Boolean allowPartialResults = request.allowPartialSearchResults();
             assert allowPartialResults != null : "SearchRequest missing setting for allowPartialSearchResults";
-            if (allowPartialResults == false && successfulOps.get() != getNumShards()) {
+            if (allowPartialResults == false && successfulOps.get() != numShards) {
                 // check if there are actual failures in the atomic array since
                 // successful retries can reset the failures to null
                 if (shardSearchFailures.length > 0) {
@@ -363,7 +364,7 @@ protected void executeNextPhase(String currentPhase, Supplier nextP
                     }
                     onPhaseFailure(currentPhase, "Partial shards failure", null);
                 } else {
-                    int discrepancy = getNumShards() - successfulOps.get();
+                    int discrepancy = numShards - successfulOps.get();
                     assert discrepancy > 0 : "discrepancy: " + discrepancy;
                     if (logger.isDebugEnabled()) {
                         logger.debug(
@@ -371,7 +372,7 @@ protected void executeNextPhase(String currentPhase, Supplier nextP
                             discrepancy,
                             successfulOps.get(),
                             toSkipShardsIts.size(),
-                            getNumShards(),
+                            numShards,
                             currentPhase
                         );
                     }
@@ -483,7 +484,7 @@ void onShardFailure(final int shardIndex, SearchShardTarget shardTarget, Excepti
             synchronized (shardFailuresMutex) {
                 shardFailures = this.shardFailures.get(); // read again otherwise somebody else has created it?
                 if (shardFailures == null) { // still null so we are the first and create a new instance
-                    shardFailures = new AtomicArray<>(getNumShards());
+                    shardFailures = new AtomicArray<>(results.getNumShards());
                     this.shardFailures.set(shardFailures);
                 }
             }
@@ -550,13 +551,6 @@ private void successfulShardExecution() {
         }
     }
 
-    /**
-     * Returns the total number of shards to the current search across all indices
-     */
-    public final int getNumShards() {
-        return results.getNumShards();
-    }
-
     /**
      * Returns a logger for this context to prevent each individual phase to create their own logger.
      */
@@ -606,12 +600,13 @@ private SearchResponse buildSearchResponse(
     ) {
         int numSuccess = successfulOps.get();
         int numFailures = failures.length;
-        assert numSuccess + numFailures == getNumShards()
-            : "numSuccess(" + numSuccess + ") + numFailures(" + numFailures + ") != totalShards(" + getNumShards() + ")";
+        final int numShards = results.getNumShards();
+        assert numSuccess + numFailures == numShards
+            : "numSuccess(" + numSuccess + ") + numFailures(" + numFailures + ") != totalShards(" + numShards + ")";
         return new SearchResponse(
             internalSearchResponse,
             scrollId,
-            getNumShards(),
+            numShards,
             numSuccess,
             toSkipShardsIts.size(),
             buildTookInMillis(),
@@ -746,7 +741,7 @@ protected final ShardSearchRequest buildShardSearchRequest(SearchShardIterator s
             request,
             shardIt.shardId(),
             shardIndex,
-            getNumShards(),
+            results.getNumShards(),
             filter,
             indexBoost,
             timeProvider.absoluteStartMillis(),
diff --git a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java
index 080295210fced..3986f4a8b507c 100644
--- a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java
+++ b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java
@@ -74,14 +74,6 @@ final class FetchSearchPhase extends SearchPhase {
         BiFunction, SearchPhase> nextPhaseFactory
     ) {
         super(NAME);
-        if (context.getNumShards() != resultConsumer.getNumShards()) {
-            throw new IllegalStateException(
-                "number of shards must match the length of the query results but doesn't:"
-                    + context.getNumShards()
-                    + "!="
-                    + resultConsumer.getNumShards()
-            );
-        }
         this.searchPhaseShardResults = resultConsumer.getAtomicArray();
         this.aggregatedDfs = aggregatedDfs;
         this.nextPhaseFactory = nextPhaseFactory;
@@ -112,10 +104,10 @@ private void innerRun() throws Exception {
         assert this.reducedQueryPhase == null ^ this.resultConsumer == null;
         // depending on whether we executed the RankFeaturePhase we may or may not have the reduced query result computed already
         final var reducedQueryPhase = this.reducedQueryPhase == null ? resultConsumer.reduce() : this.reducedQueryPhase;
-        final int numShards = context.getNumShards();
         // Usually when there is a single shard, we force the search type QUERY_THEN_FETCH. But when there's kNN, we might
         // still use DFS_QUERY_THEN_FETCH, which does not perform the "query and fetch" optimization during the query phase.
-        final boolean queryAndFetchOptimization = searchPhaseShardResults.length() == 1
+        final int numShards = searchPhaseShardResults.length();
+        final boolean queryAndFetchOptimization = numShards == 1
             && context.getRequest().hasKnnSearch() == false
             && reducedQueryPhase.queryPhaseRankCoordinatorContext() == null
             && (context.getRequest().source() == null || context.getRequest().source().rankBuilder() == null);
@@ -130,7 +122,7 @@ private void innerRun() throws Exception {
                 // we have to release contexts here to free up resources
                 searchPhaseShardResults.asList()
                     .forEach(searchPhaseShardResult -> releaseIrrelevantSearchContext(searchPhaseShardResult, context));
-                moveToNextPhase(new AtomicArray<>(numShards), reducedQueryPhase);
+                moveToNextPhase(new AtomicArray<>(0), reducedQueryPhase);
             } else {
                 innerRunFetch(scoreDocs, numShards, reducedQueryPhase);
             }
diff --git a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java
index e9302883457e1..2c903fee16c1b 100644
--- a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java
+++ b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java
@@ -56,18 +56,10 @@ public class RankFeaturePhase extends SearchPhase {
         super(NAME);
         assert rankFeaturePhaseRankCoordinatorContext != null;
         this.rankFeaturePhaseRankCoordinatorContext = rankFeaturePhaseRankCoordinatorContext;
-        if (context.getNumShards() != queryPhaseResults.getNumShards()) {
-            throw new IllegalStateException(
-                "number of shards must match the length of the query results but doesn't:"
-                    + context.getNumShards()
-                    + "!="
-                    + queryPhaseResults.getNumShards()
-            );
-        }
         this.context = context;
         this.queryPhaseResults = queryPhaseResults;
         this.aggregatedDfs = aggregatedDfs;
-        this.rankPhaseResults = new ArraySearchPhaseResults<>(context.getNumShards());
+        this.rankPhaseResults = new ArraySearchPhaseResults<>(queryPhaseResults.getNumShards());
         context.addReleasable(rankPhaseResults);
         this.progressListener = context.getTask().getProgressListener();
     }
@@ -96,10 +88,10 @@ void innerRun(RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordin
         // to operate on the first `rank_window_size * num_shards` results and merge them appropriately.
         SearchPhaseController.ReducedQueryPhase reducedQueryPhase = queryPhaseResults.reduce();
         ScoreDoc[] queryScoreDocs = reducedQueryPhase.sortedTopDocs().scoreDocs(); // rank_window_size
-        final List[] docIdsToLoad = SearchPhaseController.fillDocIdsToLoad(context.getNumShards(), queryScoreDocs);
+        final int numShards = queryPhaseResults.getNumShards();
+        final List[] docIdsToLoad = SearchPhaseController.fillDocIdsToLoad(numShards, queryScoreDocs);
         final CountedCollector rankRequestCounter = new CountedCollector<>(
             rankPhaseResults,
-            context.getNumShards(),
+            numShards,
             () -> onPhaseDone(rankFeaturePhaseRankCoordinatorContext, reducedQueryPhase),
             context
         );
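The change is the same across all three files: the delegating getNumShards() accessor on AbstractSearchAsyncAction is removed, callers read the shard count from the SearchPhaseResults they already hold (results.getNumShards(), resultConsumer.getNumShards(), queryPhaseResults.getNumShards()), and methods that use the value several times cache it in a local numShards. A minimal sketch of that pattern, using hypothetical PhaseResults and ExampleAsyncAction classes rather than the real Elasticsearch types:

final class PhaseResults {
    private final int numShards;

    PhaseResults(int numShards) {
        this.numShards = numShards;
    }

    // The shard count lives only here; the coordinating action no longer re-exposes it.
    int getNumShards() {
        return numShards;
    }
}

final class ExampleAsyncAction {
    private final PhaseResults results;
    private int successfulOps;
    private int failedOps;

    ExampleAsyncAction(PhaseResults results) {
        this.results = results;
    }

    // Before: callers went through a public getNumShards() wrapper on the action itself.
    // After: they read results.getNumShards() directly and, when a method needs the value
    // more than once, cache it in a local, as executeNextPhase does in the diff above.
    void executeNextPhase() {
        final int numShards = results.getNumShards();
        if (failedOps == numShards) {
            System.out.println("all " + numShards + " shards failed");
        } else if (successfulOps != numShards) {
            System.out.println((numShards - successfulOps) + " of " + numShards + " shards did not respond");
        }
    }
}

Under this pattern the IllegalStateException consistency checks removed from the FetchSearchPhase and RankFeaturePhase constructors have nothing left to guard: the shard count is taken from the query results themselves, so there is no second copy of the value that could disagree with them.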