Skip to content

Commit 8a94037

Browse files
Fix unnecessary context switch in RankFeaturePhase (#113232)
If we don't actually execute this phase we shouldn't fork the phase unnecessarily. We can compute the RankFeaturePhaseRankCoordinatorContext on the transport thread and move on to fetch without forking. Fetch itself will then fork, and we can run the reduce as part of fetch instead of in a separate search pool task (this is the way it worked up until the recent introduction of RankFeaturePhase; this fixes that regression).
1 parent b00129a commit 8a94037

File tree

4 files changed

+61
-45
lines changed

4 files changed

+61
-45
lines changed

server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.apache.lucene.search.ScoreDoc;
1313
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
1414
import org.elasticsearch.common.util.concurrent.AtomicArray;
15+
import org.elasticsearch.core.Nullable;
1516
import org.elasticsearch.search.SearchPhaseResult;
1617
import org.elasticsearch.search.SearchShardTarget;
1718
import org.elasticsearch.search.dfs.AggregatedDfs;
@@ -39,13 +40,15 @@ final class FetchSearchPhase extends SearchPhase {
3940
private final Logger logger;
4041
private final SearchProgressListener progressListener;
4142
private final AggregatedDfs aggregatedDfs;
43+
@Nullable
44+
private final SearchPhaseResults<SearchPhaseResult> resultConsumer;
4245
private final SearchPhaseController.ReducedQueryPhase reducedQueryPhase;
4346

4447
FetchSearchPhase(
4548
SearchPhaseResults<SearchPhaseResult> resultConsumer,
4649
AggregatedDfs aggregatedDfs,
4750
SearchPhaseContext context,
48-
SearchPhaseController.ReducedQueryPhase reducedQueryPhase
51+
@Nullable SearchPhaseController.ReducedQueryPhase reducedQueryPhase
4952
) {
5053
this(
5154
resultConsumer,
@@ -64,7 +67,7 @@ final class FetchSearchPhase extends SearchPhase {
6467
SearchPhaseResults<SearchPhaseResult> resultConsumer,
6568
AggregatedDfs aggregatedDfs,
6669
SearchPhaseContext context,
67-
SearchPhaseController.ReducedQueryPhase reducedQueryPhase,
70+
@Nullable SearchPhaseController.ReducedQueryPhase reducedQueryPhase,
6871
BiFunction<SearchResponseSections, AtomicArray<SearchPhaseResult>, SearchPhase> nextPhaseFactory
6972
) {
7073
super("fetch");
@@ -85,14 +88,15 @@ final class FetchSearchPhase extends SearchPhase {
8588
this.logger = context.getLogger();
8689
this.progressListener = context.getTask().getProgressListener();
8790
this.reducedQueryPhase = reducedQueryPhase;
91+
this.resultConsumer = reducedQueryPhase == null ? resultConsumer : null;
8892
}
8993

9094
@Override
9195
public void run() {
9296
context.execute(new AbstractRunnable() {
9397

9498
@Override
95-
protected void doRun() {
99+
protected void doRun() throws Exception {
96100
innerRun();
97101
}
98102

@@ -103,7 +107,10 @@ public void onFailure(Exception e) {
103107
});
104108
}
105109

106-
private void innerRun() {
110+
private void innerRun() throws Exception {
111+
assert this.reducedQueryPhase == null ^ this.resultConsumer == null;
112+
// depending on whether we executed the RankFeaturePhase we may or may not have the reduced query result computed already
113+
final var reducedQueryPhase = this.reducedQueryPhase == null ? resultConsumer.reduce() : this.reducedQueryPhase;
107114
final int numShards = context.getNumShards();
108115
// Usually when there is a single shard, we force the search type QUERY_THEN_FETCH. But when there's kNN, we might
109116
// still use DFS_QUERY_THEN_FETCH, which does not perform the "query and fetch" optimization during the query phase.
@@ -113,15 +120,15 @@ private void innerRun() {
113120
if (queryAndFetchOptimization) {
114121
assert assertConsistentWithQueryAndFetchOptimization();
115122
// query AND fetch optimization
116-
moveToNextPhase(searchPhaseShardResults);
123+
moveToNextPhase(searchPhaseShardResults, reducedQueryPhase);
117124
} else {
118125
ScoreDoc[] scoreDocs = reducedQueryPhase.sortedTopDocs().scoreDocs();
119126
// no docs to fetch -- sidestep everything and return
120127
if (scoreDocs.length == 0) {
121128
// we have to release contexts here to free up resources
122129
searchPhaseShardResults.asList()
123130
.forEach(searchPhaseShardResult -> releaseIrrelevantSearchContext(searchPhaseShardResult, context));
124-
moveToNextPhase(fetchResults.getAtomicArray());
131+
moveToNextPhase(fetchResults.getAtomicArray(), reducedQueryPhase);
125132
} else {
126133
final boolean shouldExplainRank = shouldExplainRankScores(context.getRequest());
127134
final List<Map<Integer, RankDoc>> rankDocsPerShard = false == shouldExplainRank
@@ -134,7 +141,7 @@ private void innerRun() {
134141
final CountedCollector<FetchSearchResult> counter = new CountedCollector<>(
135142
fetchResults,
136143
docIdsToLoad.length, // we count down every shard in the result no matter if we got any results or not
137-
() -> moveToNextPhase(fetchResults.getAtomicArray()),
144+
() -> moveToNextPhase(fetchResults.getAtomicArray(), reducedQueryPhase),
138145
context
139146
);
140147
for (int i = 0; i < docIdsToLoad.length; i++) {
@@ -243,7 +250,10 @@ public void onFailure(Exception e) {
243250
);
244251
}
245252

246-
private void moveToNextPhase(AtomicArray<? extends SearchPhaseResult> fetchResultsArr) {
253+
private void moveToNextPhase(
254+
AtomicArray<? extends SearchPhaseResult> fetchResultsArr,
255+
SearchPhaseController.ReducedQueryPhase reducedQueryPhase
256+
) {
247257
var resp = SearchPhaseController.merge(context.getRequest().scroll() != null, reducedQueryPhase, fetchResultsArr);
248258
context.addReleasable(resp::decRef);
249259
fetchResults.close();

server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java

Lines changed: 28 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -70,14 +70,20 @@ public class RankFeaturePhase extends SearchPhase {
7070

7171
@Override
7272
public void run() {
73+
RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext = coordinatorContext(context.getRequest().source());
74+
if (rankFeaturePhaseRankCoordinatorContext == null) {
75+
moveToNextPhase(queryPhaseResults, null);
76+
return;
77+
}
78+
7379
context.execute(new AbstractRunnable() {
7480
@Override
7581
protected void doRun() throws Exception {
7682
// we need to reduce the results at this point instead of fetch phase, so we fork this process similarly to how
7783
// was set up at FetchSearchPhase.
7884

7985
// we do the heavy lifting in this inner run method where we reduce aggs etc
80-
innerRun();
86+
innerRun(rankFeaturePhaseRankCoordinatorContext);
8187
}
8288

8389
@Override
@@ -87,51 +93,39 @@ public void onFailure(Exception e) {
8793
});
8894
}
8995

90-
void innerRun() throws Exception {
96+
void innerRun(RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext) throws Exception {
9197
// if the RankBuilder specifies a QueryPhaseCoordinatorContext, it will be called as part of the reduce call
9298
// to operate on the first `rank_window_size * num_shards` results and merge them appropriately.
9399
SearchPhaseController.ReducedQueryPhase reducedQueryPhase = queryPhaseResults.reduce();
94-
RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext = coordinatorContext(context.getRequest().source());
95-
if (rankFeaturePhaseRankCoordinatorContext != null) {
96-
ScoreDoc[] queryScoreDocs = reducedQueryPhase.sortedTopDocs().scoreDocs(); // rank_window_size
97-
final List<Integer>[] docIdsToLoad = SearchPhaseController.fillDocIdsToLoad(context.getNumShards(), queryScoreDocs);
98-
final CountedCollector<SearchPhaseResult> rankRequestCounter = new CountedCollector<>(
99-
rankPhaseResults,
100-
context.getNumShards(),
101-
() -> onPhaseDone(rankFeaturePhaseRankCoordinatorContext, reducedQueryPhase),
102-
context
103-
);
100+
ScoreDoc[] queryScoreDocs = reducedQueryPhase.sortedTopDocs().scoreDocs(); // rank_window_size
101+
final List<Integer>[] docIdsToLoad = SearchPhaseController.fillDocIdsToLoad(context.getNumShards(), queryScoreDocs);
102+
final CountedCollector<SearchPhaseResult> rankRequestCounter = new CountedCollector<>(
103+
rankPhaseResults,
104+
context.getNumShards(),
105+
() -> onPhaseDone(rankFeaturePhaseRankCoordinatorContext, reducedQueryPhase),
106+
context
107+
);
104108

105-
// we send out a request to each shard in order to fetch the needed feature info
106-
for (int i = 0; i < docIdsToLoad.length; i++) {
107-
List<Integer> entry = docIdsToLoad[i];
108-
SearchPhaseResult queryResult = queryPhaseResults.getAtomicArray().get(i);
109-
if (entry == null || entry.isEmpty()) {
110-
if (queryResult != null) {
111-
releaseIrrelevantSearchContext(queryResult, context);
112-
progressListener.notifyRankFeatureResult(i);
113-
}
114-
rankRequestCounter.countDown();
115-
} else {
116-
executeRankFeatureShardPhase(queryResult, rankRequestCounter, entry);
109+
// we send out a request to each shard in order to fetch the needed feature info
110+
for (int i = 0; i < docIdsToLoad.length; i++) {
111+
List<Integer> entry = docIdsToLoad[i];
112+
SearchPhaseResult queryResult = queryPhaseResults.getAtomicArray().get(i);
113+
if (entry == null || entry.isEmpty()) {
114+
if (queryResult != null) {
115+
releaseIrrelevantSearchContext(queryResult, context);
116+
progressListener.notifyRankFeatureResult(i);
117117
}
118+
rankRequestCounter.countDown();
119+
} else {
120+
executeRankFeatureShardPhase(queryResult, rankRequestCounter, entry);
118121
}
119-
} else {
120-
moveToNextPhase(queryPhaseResults, reducedQueryPhase);
121122
}
122123
}
123124

124125
private RankFeaturePhaseRankCoordinatorContext coordinatorContext(SearchSourceBuilder source) {
125126
return source == null || source.rankBuilder() == null
126127
? null
127-
: context.getRequest()
128-
.source()
129-
.rankBuilder()
130-
.buildRankFeaturePhaseCoordinatorContext(
131-
context.getRequest().source().size(),
132-
context.getRequest().source().from(),
133-
client
134-
);
128+
: source.rankBuilder().buildRankFeaturePhaseCoordinatorContext(source.size(), source.from(), client);
135129
}
136130

137131
private void executeRankFeatureShardPhase(

server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -536,7 +536,7 @@ public void sendExecuteRankFeature(
536536
// override the RankFeaturePhase to raise an exception
537537
RankFeaturePhase rankFeaturePhase = new RankFeaturePhase(results, null, mockSearchPhaseContext, null) {
538538
@Override
539-
void innerRun() {
539+
void innerRun(RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext) {
540540
throw new IllegalArgumentException("simulated failure");
541541
}
542542

@@ -1142,7 +1142,13 @@ public void moveToNextPhase(
11421142
) {
11431143
// this is called after the RankFeaturePhaseCoordinatorContext has been executed
11441144
phaseDone.set(true);
1145-
finalResults[0] = reducedQueryPhase.sortedTopDocs().scoreDocs();
1145+
try {
1146+
finalResults[0] = reducedQueryPhase == null
1147+
? queryPhaseResults.reduce().sortedTopDocs().scoreDocs()
1148+
: reducedQueryPhase.sortedTopDocs().scoreDocs();
1149+
} catch (Exception e) {
1150+
throw new AssertionError(e);
1151+
}
11461152
logger.debug("Skipping moving to next phase");
11471153
}
11481154
};

server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.apache.lucene.search.TotalHits;
1515
import org.elasticsearch.TransportVersion;
1616
import org.elasticsearch.TransportVersions;
17+
import org.elasticsearch.action.ActionListener;
1718
import org.elasticsearch.client.internal.Client;
1819
import org.elasticsearch.common.document.DocumentField;
1920
import org.elasticsearch.common.io.stream.StreamOutput;
@@ -170,7 +171,12 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId)
170171
// no work to be done on the coordinator node for the rank feature phase
171172
@Override
172173
public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from, Client client) {
173-
return null;
174+
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) {
175+
@Override
176+
protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
177+
throw new AssertionError("not expected");
178+
}
179+
};
174180
}
175181

176182
@Override

0 commit comments

Comments
 (0)