From 89c25f63ddbfd348ad26ca26cb8a7de6c46c030c Mon Sep 17 00:00:00 2001 From: Armin Braun Date: Thu, 6 Feb 2025 03:45:57 +0100 Subject: [PATCH] Remove redundant numShards getter from AbstractSearchAsyncAction In preparation of batched execution changes, remove this redundant interface. We always have other ways of getting the total shard count and even assert that they are equivalent. No need to have a redundant interface method like this around before extracting an interface from this class. --- .../search/AbstractSearchAsyncAction.java | 29 ++++++++----------- .../action/search/FetchSearchPhase.java | 14 ++------- .../action/search/RankFeaturePhase.java | 15 +++------- 3 files changed, 19 insertions(+), 39 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java index 44752d6f33600..b5f7a3facb4ba 100644 --- a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java @@ -196,7 +196,7 @@ long buildTookInMillis() { * This is the main entry point for a search. This method starts the search execution of the initial phase. */ public final void start() { - if (getNumShards() == 0) { + if (results.getNumShards() == 0) { // no search shards to search on, bail with empty response // (it happens with search across _all with no indices around and consistent with broadcast operations) int trackTotalHitsUpTo = request.source() == null ? SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO @@ -341,7 +341,8 @@ protected void executeNextPhase(String currentPhase, Supplier nextP * fail. Otherwise we continue to the next phase. */ ShardOperationFailedException[] shardSearchFailures = buildShardFailures(); - if (shardSearchFailures.length == getNumShards()) { + final int numShards = results.getNumShards(); + if (shardSearchFailures.length == numShards) { shardSearchFailures = ExceptionsHelper.groupBy(shardSearchFailures); Throwable cause = shardSearchFailures.length == 0 ? null @@ -351,7 +352,7 @@ protected void executeNextPhase(String currentPhase, Supplier nextP } else { Boolean allowPartialResults = request.allowPartialSearchResults(); assert allowPartialResults != null : "SearchRequest missing setting for allowPartialSearchResults"; - if (allowPartialResults == false && successfulOps.get() != getNumShards()) { + if (allowPartialResults == false && successfulOps.get() != numShards) { // check if there are actual failures in the atomic array since // successful retries can reset the failures to null if (shardSearchFailures.length > 0) { @@ -363,7 +364,7 @@ protected void executeNextPhase(String currentPhase, Supplier nextP } onPhaseFailure(currentPhase, "Partial shards failure", null); } else { - int discrepancy = getNumShards() - successfulOps.get(); + int discrepancy = numShards - successfulOps.get(); assert discrepancy > 0 : "discrepancy: " + discrepancy; if (logger.isDebugEnabled()) { logger.debug( @@ -371,7 +372,7 @@ protected void executeNextPhase(String currentPhase, Supplier nextP discrepancy, successfulOps.get(), toSkipShardsIts.size(), - getNumShards(), + numShards, currentPhase ); } @@ -483,7 +484,7 @@ void onShardFailure(final int shardIndex, SearchShardTarget shardTarget, Excepti synchronized (shardFailuresMutex) { shardFailures = this.shardFailures.get(); // read again otherwise somebody else has created it? if (shardFailures == null) { // still null so we are the first and create a new instance - shardFailures = new AtomicArray<>(getNumShards()); + shardFailures = new AtomicArray<>(results.getNumShards()); this.shardFailures.set(shardFailures); } } @@ -550,13 +551,6 @@ private void successfulShardExecution() { } } - /** - * Returns the total number of shards to the current search across all indices - */ - public final int getNumShards() { - return results.getNumShards(); - } - /** * Returns a logger for this context to prevent each individual phase to create their own logger. */ @@ -606,12 +600,13 @@ private SearchResponse buildSearchResponse( ) { int numSuccess = successfulOps.get(); int numFailures = failures.length; - assert numSuccess + numFailures == getNumShards() - : "numSuccess(" + numSuccess + ") + numFailures(" + numFailures + ") != totalShards(" + getNumShards() + ")"; + final int numShards = results.getNumShards(); + assert numSuccess + numFailures == numShards + : "numSuccess(" + numSuccess + ") + numFailures(" + numFailures + ") != totalShards(" + numShards + ")"; return new SearchResponse( internalSearchResponse, scrollId, - getNumShards(), + numShards, numSuccess, toSkipShardsIts.size(), buildTookInMillis(), @@ -746,7 +741,7 @@ protected final ShardSearchRequest buildShardSearchRequest(SearchShardIterator s request, shardIt.shardId(), shardIndex, - getNumShards(), + results.getNumShards(), filter, indexBoost, timeProvider.absoluteStartMillis(), diff --git a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java index 080295210fced..3986f4a8b507c 100644 --- a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java @@ -74,14 +74,6 @@ final class FetchSearchPhase extends SearchPhase { BiFunction, SearchPhase> nextPhaseFactory ) { super(NAME); - if (context.getNumShards() != resultConsumer.getNumShards()) { - throw new IllegalStateException( - "number of shards must match the length of the query results but doesn't:" - + context.getNumShards() - + "!=" - + resultConsumer.getNumShards() - ); - } this.searchPhaseShardResults = resultConsumer.getAtomicArray(); this.aggregatedDfs = aggregatedDfs; this.nextPhaseFactory = nextPhaseFactory; @@ -112,10 +104,10 @@ private void innerRun() throws Exception { assert this.reducedQueryPhase == null ^ this.resultConsumer == null; // depending on whether we executed the RankFeaturePhase we may or may not have the reduced query result computed already final var reducedQueryPhase = this.reducedQueryPhase == null ? resultConsumer.reduce() : this.reducedQueryPhase; - final int numShards = context.getNumShards(); // Usually when there is a single shard, we force the search type QUERY_THEN_FETCH. But when there's kNN, we might // still use DFS_QUERY_THEN_FETCH, which does not perform the "query and fetch" optimization during the query phase. - final boolean queryAndFetchOptimization = searchPhaseShardResults.length() == 1 + final int numShards = searchPhaseShardResults.length(); + final boolean queryAndFetchOptimization = numShards == 1 && context.getRequest().hasKnnSearch() == false && reducedQueryPhase.queryPhaseRankCoordinatorContext() == null && (context.getRequest().source() == null || context.getRequest().source().rankBuilder() == null); @@ -130,7 +122,7 @@ private void innerRun() throws Exception { // we have to release contexts here to free up resources searchPhaseShardResults.asList() .forEach(searchPhaseShardResult -> releaseIrrelevantSearchContext(searchPhaseShardResult, context)); - moveToNextPhase(new AtomicArray<>(numShards), reducedQueryPhase); + moveToNextPhase(new AtomicArray<>(0), reducedQueryPhase); } else { innerRunFetch(scoreDocs, numShards, reducedQueryPhase); } diff --git a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java index e9302883457e1..2c903fee16c1b 100644 --- a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java @@ -56,18 +56,10 @@ public class RankFeaturePhase extends SearchPhase { super(NAME); assert rankFeaturePhaseRankCoordinatorContext != null; this.rankFeaturePhaseRankCoordinatorContext = rankFeaturePhaseRankCoordinatorContext; - if (context.getNumShards() != queryPhaseResults.getNumShards()) { - throw new IllegalStateException( - "number of shards must match the length of the query results but doesn't:" - + context.getNumShards() - + "!=" - + queryPhaseResults.getNumShards() - ); - } this.context = context; this.queryPhaseResults = queryPhaseResults; this.aggregatedDfs = aggregatedDfs; - this.rankPhaseResults = new ArraySearchPhaseResults<>(context.getNumShards()); + this.rankPhaseResults = new ArraySearchPhaseResults<>(queryPhaseResults.getNumShards()); context.addReleasable(rankPhaseResults); this.progressListener = context.getTask().getProgressListener(); } @@ -96,10 +88,11 @@ void innerRun(RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordin // to operate on the first `rank_window_size * num_shards` results and merge them appropriately. SearchPhaseController.ReducedQueryPhase reducedQueryPhase = queryPhaseResults.reduce(); ScoreDoc[] queryScoreDocs = reducedQueryPhase.sortedTopDocs().scoreDocs(); // rank_window_size - final List[] docIdsToLoad = SearchPhaseController.fillDocIdsToLoad(context.getNumShards(), queryScoreDocs); + final int numShards = queryPhaseResults.getNumShards(); + final List[] docIdsToLoad = SearchPhaseController.fillDocIdsToLoad(numShards, queryScoreDocs); final CountedCollector rankRequestCounter = new CountedCollector<>( rankPhaseResults, - context.getNumShards(), + numShards, () -> onPhaseDone(rankFeaturePhaseRankCoordinatorContext, reducedQueryPhase), context );