
Commit fa32445

cleanups
1 parent 5b2028b commit fa32445

File tree

1 file changed: +70 -76 lines changed


server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java

Lines changed: 70 additions & 76 deletions
@@ -488,8 +488,7 @@ private void run() {
         // TODO: stupid but we kinda need to fill all of these in with the current logic, do something nicer before merging
         final Map<SearchShardIterator, Integer> shardIndexMap = Maps.newHashMapWithExpectedSize(shardIterators.length);
         for (int i = 0; i < shardIterators.length; i++) {
-            var iterator = shardIterators[i];
-            shardIndexMap.put(iterator, i);
+            shardIndexMap.put(shardIterators[i], i);
         }
         final boolean supportsBatchedQuery = minTransportVersion.onOrAfter(TransportVersions.BATCHED_QUERY_PHASE_VERSION);
         final Map<String, NodeQueryRequest> perNodeQueries = new HashMap<>();
@@ -773,11 +772,9 @@ public void sendReleaseSearchContext(ShardSearchContextId contextId, Transport.C
 
     public static void registerNodeSearchAction(SearchTransportService searchTransportService, SearchService searchService) {
         var transportService = searchTransportService.transportService();
-        final Dependencies dependencies = new Dependencies(
-            searchService,
-            transportService.getThreadPool().executor(ThreadPool.Names.SEARCH)
-        );
-        final int searchPoolMax = transportService.getThreadPool().info(ThreadPool.Names.SEARCH).getMax();
+        var threadPool = transportService.getThreadPool();
+        final Dependencies dependencies = new Dependencies(searchService, threadPool.executor(ThreadPool.Names.SEARCH));
+        final int searchPoolMax = threadPool.info(ThreadPool.Names.SEARCH).getMax();
         final SearchPhaseController searchPhaseController = new SearchPhaseController(searchService::aggReduceContextBuilder);
         transportService.registerRequestHandler(
             NODE_SEARCH_ACTION_NAME,
@@ -856,16 +853,13 @@ protected void doRun() {
             public void onResponse(SearchPhaseResult searchPhaseResult) {
                 try {
                     searchPhaseResult.setShardIndex(dataNodeLocalIdx);
-                    final SearchShardTarget target = new SearchShardTarget(
-                        null,
-                        shardToQuery.shardId,
-                        request.searchRequest.getLocalClusterAlias()
+                    searchPhaseResult.setSearchShardTarget(
+                        new SearchShardTarget(null, shardToQuery.shardId, request.searchRequest.getLocalClusterAlias())
                     );
-                    searchPhaseResult.setSearchShardTarget(target);
                     // no need for any cache effects when we're already flipped to ture => plain read + set-release
                     state.hasResponse.compareAndExchangeRelease(false, true);
                     state.consumeResult(searchPhaseResult.queryResult());
-                    state.queryPhaseResultConsumer.consumeResult(searchPhaseResult, state.onDone);
+                    state.queryPhaseResultConsumer.consumeResult(searchPhaseResult, state::onDone);
                 } catch (Exception e) {
                     setFailure(state, dataNodeLocalIdx, e);
                 } finally {
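
The state::onDone change above works because the consumer's callback parameter keeps its Runnable (or similar no-arg functional interface) shape: once the completion logic moves from a stored Runnable field into an instance method, a method reference is enough. A minimal, self-contained sketch of that shape, using hypothetical names (PerNodeState, consumeResult, Demo are illustrative, not the types in this file):

// Illustrative only: hypothetical names, not the Elasticsearch types used in this file.
import java.util.concurrent.atomic.AtomicInteger;

class PerNodeState {
    private final AtomicInteger pending = new AtomicInteger(2);

    // Completion logic lives in an instance method instead of a stored Runnable field.
    void onDone() {
        if (pending.decrementAndGet() == 0) {
            System.out.println("all shards done, responding");
        }
    }
}

class Demo {
    // The consumer still accepts a Runnable; a method reference satisfies it.
    static void consumeResult(String result, Runnable onDone) {
        System.out.println("consumed " + result);
        onDone.run();
    }

    public static void main(String[] args) {
        PerNodeState state = new PerNodeState();
        consumeResult("shard-0", state::onDone);
        consumeResult("shard-1", state::onDone); // last call trips the countdown
    }
}
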
@@ -875,7 +869,7 @@ public void onResponse(SearchPhaseResult searchPhaseResult) {
 
         private void setFailure(QueryPerNodeState state, int dataNodeLocalIdx, Exception e) {
             state.failures.put(dataNodeLocalIdx, e);
-            state.onDone.run();
+            state.onDone();
         }
 
         @Override
@@ -904,7 +898,7 @@ public void onFailure(Exception e) {
                 // TODO this could be done better now, we probably should only make sure to have a single loop running at
                 // minimum and ignore + requeue rejections in that case
                 state.failures.put(dataNodeLocalIdx, e);
-                state.onDone.run();
+                state.onDone();
                 // TODO SO risk!
                 maybeNext();
             }
@@ -941,10 +935,11 @@ private static final class QueryPerNodeState {
         private final CancellableTask task;
         private final ConcurrentHashMap<Integer, Exception> failures = new ConcurrentHashMap<>();
         private final Dependencies dependencies;
-        private final Runnable onDone;
         private final AtomicBoolean hasResponse = new AtomicBoolean(false);
         private final int trackTotalHitsUpTo;
         private final int topDocsSize;
+        private final CountDown countDown;
+        private final TransportChannel channel;
         private volatile BottomSortValuesCollector bottomSortCollector;
 
         private QueryPerNodeState(
@@ -961,70 +956,69 @@ private QueryPerNodeState(
             this.trackTotalHitsUpTo = searchRequest.searchRequest.resolveTrackTotalHitsUpTo();
             topDocsSize = getTopDocsSize(searchRequest.searchRequest);
             this.task = task;
-            final int shardCount = queryPhaseResultConsumer.getNumShards();
-            final CountDown countDown = new CountDown(shardCount);
+            countDown = new CountDown(queryPhaseResultConsumer.getNumShards());
+            this.channel = channel;
             this.dependencies = dependencies;
-            this.onDone = () -> {
-                if (countDown.countDown()) {
-                    var channelListener = new ChannelActionListener<>(channel);
-                    try (queryPhaseResultConsumer) {
-                        var failure = queryPhaseResultConsumer.failure.get();
-                        if (failure != null) {
-                            queryPhaseResultConsumer.getSuccessfulResults()
-                                .forEach(searchPhaseResult -> maybeRelease(dependencies.searchService, searchRequest, searchPhaseResult));
-                            channelListener.onFailure(failure);
-                            return;
-                        }
-                        final Object[] results = new Object[shardCount];
-                        for (int i = 0; i < results.length; i++) {
-                            var e = failures.get(i);
-                            var res = queryPhaseResultConsumer.results.get(i);
-                            if (e != null) {
-                                results[i] = e;
-                                assert res == null;
-                            } else {
-                                results[i] = res;
-                                assert results[i] != null;
-                            }
-                        }
-                        final QueryPhaseResultConsumer.MergeResult mergeResult;
-                        try {
-                            mergeResult = Objects.requireNonNullElse(
-                                queryPhaseResultConsumer.consumePartialResult(),
-                                EMPTY_PARTIAL_MERGE_RESULT
-                            );
-                        } catch (Exception e) {
-                            channelListener.onFailure(e);
-                            return;
-                        }
-                        // translate shard indices to those on the coordinator so that it can interpret the merge result without adjustments
-                        final Set<Integer> relevantShardIndices = new HashSet<>();
-                        for (ScoreDoc scoreDoc : mergeResult.reducedTopDocs().scoreDocs) {
-                            final int localIndex = scoreDoc.shardIndex;
-                            scoreDoc.shardIndex = searchRequest.shards.get(localIndex).shardIndex;
-                            relevantShardIndices.add(localIndex);
-                        }
-                        for (Object result : results) {
-                            if (result instanceof QuerySearchResult q
-                                && q.getContextId() != null
-                                && relevantShardIndices.contains(q.getShardIndex()) == false
-                                && q.hasSuggestHits() == false
-                                && q.getRankShardResult() == null
-                                && searchRequest.searchRequest.scroll() == null
-                                && (AsyncSearchContext.isPartOfPIT(null, searchRequest.searchRequest, q.getContextId()) == false)) {
-                                if (dependencies.searchService.freeReaderContext(q.getContextId())) {
-                                    q.clearContextId();
-                                }
-                            }
-                        }
+        }
 
-                        ActionListener.respondAndRelease(
-                            channelListener,
-                            new NodeQueryResponse(mergeResult, results, queryPhaseResultConsumer.topDocsStats)
-                        );
+        void onDone() {
+            if (countDown.countDown() == false) {
+                return;
+            }
+            var channelListener = new ChannelActionListener<>(channel);
+            try (queryPhaseResultConsumer) {
+                var failure = queryPhaseResultConsumer.failure.get();
+                if (failure != null) {
+                    queryPhaseResultConsumer.getSuccessfulResults()
+                        .forEach(searchPhaseResult -> maybeRelease(dependencies.searchService, searchRequest, searchPhaseResult));
+                    channelListener.onFailure(failure);
+                    return;
+                }
+                final Object[] results = new Object[queryPhaseResultConsumer.getNumShards()];
+                for (int i = 0; i < results.length; i++) {
+                    var e = failures.get(i);
+                    var res = queryPhaseResultConsumer.results.get(i);
+                    if (e != null) {
+                        results[i] = e;
+                        assert res == null;
+                    } else {
+                        results[i] = res;
+                        assert results[i] != null;
                     }
                 }
-            };
+                final QueryPhaseResultConsumer.MergeResult mergeResult;
+                try {
+                    mergeResult = Objects.requireNonNullElse(queryPhaseResultConsumer.consumePartialResult(), EMPTY_PARTIAL_MERGE_RESULT);
+                } catch (Exception e) {
+                    channelListener.onFailure(e);
+                    return;
+                }
+                // translate shard indices to those on the coordinator so that it can interpret the merge result without adjustments
+                final Set<Integer> relevantShardIndices = new HashSet<>();
+                for (ScoreDoc scoreDoc : mergeResult.reducedTopDocs().scoreDocs) {
+                    final int localIndex = scoreDoc.shardIndex;
+                    scoreDoc.shardIndex = searchRequest.shards.get(localIndex).shardIndex;
+                    relevantShardIndices.add(localIndex);
+                }
+                for (Object result : results) {
+                    if (result instanceof QuerySearchResult q
+                        && q.getContextId() != null
+                        && relevantShardIndices.contains(q.getShardIndex()) == false
+                        && q.hasSuggestHits() == false
+                        && q.getRankShardResult() == null
+                        && searchRequest.searchRequest.scroll() == null
+                        && (AsyncSearchContext.isPartOfPIT(null, searchRequest.searchRequest, q.getContextId()) == false)) {
+                        if (dependencies.searchService.freeReaderContext(q.getContextId())) {
+                            q.clearContextId();
+                        }
+                    }
+                }
+
+                ActionListener.respondAndRelease(
+                    channelListener,
+                    new NodeQueryResponse(mergeResult, results, queryPhaseResultConsumer.topDocsStats)
+                );
+            }
         }
 
         void consumeResult(QuerySearchResult queryResult) {
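
The net effect of this commit: the completion logic that was captured in a Runnable lambda inside the QueryPerNodeState constructor now lives in a CountDown-guarded onDone() method, with the TransportChannel kept as a field. Every per-shard completion, success or failure, calls onDone(), and only the final call assembles the per-node results and responds on the channel. A rough standalone sketch of that count-down completion pattern, using hypothetical stand-in types (NodeQueryState, onShardResult, onShardFailure) rather than the real classes:

// Sketch only: AtomicInteger and Consumer stand in for the real CountDown and TransportChannel.
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;

final class NodeQueryState {
    private final AtomicInteger remaining;       // counts outstanding shards, like the CountDown field
    private final Object[] perShard;             // each slot holds either a shard result or an Exception
    private final Consumer<Object[]> channel;    // stands in for responding on the TransportChannel

    NodeQueryState(int shardCount, Consumer<Object[]> channel) {
        this.remaining = new AtomicInteger(shardCount);
        this.perShard = new Object[shardCount];
        this.channel = channel;
    }

    void onShardResult(int shardIdx, Object result) {
        perShard[shardIdx] = result;
        onDone();
    }

    void onShardFailure(int shardIdx, Exception error) {
        perShard[shardIdx] = error;
        onDone();
    }

    // Every shard completion funnels through onDone(); only the last call builds and sends the response.
    private void onDone() {
        if (remaining.decrementAndGet() != 0) {
            return;
        }
        channel.accept(perShard);
    }
}

Usage would look like new NodeQueryState(3, merged -> send(merged)), with each shard callback ending in exactly one onDone() call, mirroring how both onResponse and the failure paths above now call state.onDone().
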

0 commit comments
