@@ -488,8 +488,7 @@ private void run() {
488488 // TODO: stupid but we kinda need to fill all of these in with the current logic, do something nicer before merging
489489 final Map <SearchShardIterator , Integer > shardIndexMap = Maps .newHashMapWithExpectedSize (shardIterators .length );
490490 for (int i = 0 ; i < shardIterators .length ; i ++) {
491- var iterator = shardIterators [i ];
492- shardIndexMap .put (iterator , i );
491+ shardIndexMap .put (shardIterators [i ], i );
493492 }
494493 final boolean supportsBatchedQuery = minTransportVersion .onOrAfter (TransportVersions .BATCHED_QUERY_PHASE_VERSION );
495494 final Map <String , NodeQueryRequest > perNodeQueries = new HashMap <>();
@@ -773,11 +772,9 @@ public void sendReleaseSearchContext(ShardSearchContextId contextId, Transport.C
773772
774773 public static void registerNodeSearchAction (SearchTransportService searchTransportService , SearchService searchService ) {
775774 var transportService = searchTransportService .transportService ();
776- final Dependencies dependencies = new Dependencies (
777- searchService ,
778- transportService .getThreadPool ().executor (ThreadPool .Names .SEARCH )
779- );
780- final int searchPoolMax = transportService .getThreadPool ().info (ThreadPool .Names .SEARCH ).getMax ();
775+ var threadPool = transportService .getThreadPool ();
776+ final Dependencies dependencies = new Dependencies (searchService , threadPool .executor (ThreadPool .Names .SEARCH ));
777+ final int searchPoolMax = threadPool .info (ThreadPool .Names .SEARCH ).getMax ();
781778 final SearchPhaseController searchPhaseController = new SearchPhaseController (searchService ::aggReduceContextBuilder );
782779 transportService .registerRequestHandler (
783780 NODE_SEARCH_ACTION_NAME ,
@@ -856,16 +853,13 @@ protected void doRun() {
856853 public void onResponse (SearchPhaseResult searchPhaseResult ) {
857854 try {
858855 searchPhaseResult .setShardIndex (dataNodeLocalIdx );
859- final SearchShardTarget target = new SearchShardTarget (
860- null ,
861- shardToQuery .shardId ,
862- request .searchRequest .getLocalClusterAlias ()
856+ searchPhaseResult .setSearchShardTarget (
857+ new SearchShardTarget (null , shardToQuery .shardId , request .searchRequest .getLocalClusterAlias ())
863858 );
864- searchPhaseResult .setSearchShardTarget (target );
865859 // no need for any cache effects when we're already flipped to ture => plain read + set-release
866860 state .hasResponse .compareAndExchangeRelease (false , true );
867861 state .consumeResult (searchPhaseResult .queryResult ());
868- state .queryPhaseResultConsumer .consumeResult (searchPhaseResult , state . onDone );
862+ state .queryPhaseResultConsumer .consumeResult (searchPhaseResult , state :: onDone );
869863 } catch (Exception e ) {
870864 setFailure (state , dataNodeLocalIdx , e );
871865 } finally {
@@ -875,7 +869,7 @@ public void onResponse(SearchPhaseResult searchPhaseResult) {
875869
876870 private void setFailure (QueryPerNodeState state , int dataNodeLocalIdx , Exception e ) {
877871 state .failures .put (dataNodeLocalIdx , e );
878- state .onDone . run ();
872+ state .onDone ();
879873 }
880874
881875 @ Override
@@ -904,7 +898,7 @@ public void onFailure(Exception e) {
904898 // TODO this could be done better now, we probably should only make sure to have a single loop running at
905899 // minimum and ignore + requeue rejections in that case
906900 state .failures .put (dataNodeLocalIdx , e );
907- state .onDone . run ();
901+ state .onDone ();
908902 // TODO SO risk!
909903 maybeNext ();
910904 }
@@ -941,10 +935,11 @@ private static final class QueryPerNodeState {
941935 private final CancellableTask task ;
942936 private final ConcurrentHashMap <Integer , Exception > failures = new ConcurrentHashMap <>();
943937 private final Dependencies dependencies ;
944- private final Runnable onDone ;
945938 private final AtomicBoolean hasResponse = new AtomicBoolean (false );
946939 private final int trackTotalHitsUpTo ;
947940 private final int topDocsSize ;
941+ private final CountDown countDown ;
942+ private final TransportChannel channel ;
948943 private volatile BottomSortValuesCollector bottomSortCollector ;
949944
950945 private QueryPerNodeState (
@@ -961,70 +956,69 @@ private QueryPerNodeState(
961956 this .trackTotalHitsUpTo = searchRequest .searchRequest .resolveTrackTotalHitsUpTo ();
962957 topDocsSize = getTopDocsSize (searchRequest .searchRequest );
963958 this .task = task ;
964- final int shardCount = queryPhaseResultConsumer .getNumShards ();
965- final CountDown countDown = new CountDown ( shardCount ) ;
959+ countDown = new CountDown ( queryPhaseResultConsumer .getNumShards () );
960+ this . channel = channel ;
966961 this .dependencies = dependencies ;
967- this .onDone = () -> {
968- if (countDown .countDown ()) {
969- var channelListener = new ChannelActionListener <>(channel );
970- try (queryPhaseResultConsumer ) {
971- var failure = queryPhaseResultConsumer .failure .get ();
972- if (failure != null ) {
973- queryPhaseResultConsumer .getSuccessfulResults ()
974- .forEach (searchPhaseResult -> maybeRelease (dependencies .searchService , searchRequest , searchPhaseResult ));
975- channelListener .onFailure (failure );
976- return ;
977- }
978- final Object [] results = new Object [shardCount ];
979- for (int i = 0 ; i < results .length ; i ++) {
980- var e = failures .get (i );
981- var res = queryPhaseResultConsumer .results .get (i );
982- if (e != null ) {
983- results [i ] = e ;
984- assert res == null ;
985- } else {
986- results [i ] = res ;
987- assert results [i ] != null ;
988- }
989- }
990- final QueryPhaseResultConsumer .MergeResult mergeResult ;
991- try {
992- mergeResult = Objects .requireNonNullElse (
993- queryPhaseResultConsumer .consumePartialResult (),
994- EMPTY_PARTIAL_MERGE_RESULT
995- );
996- } catch (Exception e ) {
997- channelListener .onFailure (e );
998- return ;
999- }
1000- // translate shard indices to those on the coordinator so that it can interpret the merge result without adjustments
1001- final Set <Integer > relevantShardIndices = new HashSet <>();
1002- for (ScoreDoc scoreDoc : mergeResult .reducedTopDocs ().scoreDocs ) {
1003- final int localIndex = scoreDoc .shardIndex ;
1004- scoreDoc .shardIndex = searchRequest .shards .get (localIndex ).shardIndex ;
1005- relevantShardIndices .add (localIndex );
1006- }
1007- for (Object result : results ) {
1008- if (result instanceof QuerySearchResult q
1009- && q .getContextId () != null
1010- && relevantShardIndices .contains (q .getShardIndex ()) == false
1011- && q .hasSuggestHits () == false
1012- && q .getRankShardResult () == null
1013- && searchRequest .searchRequest .scroll () == null
1014- && (AsyncSearchContext .isPartOfPIT (null , searchRequest .searchRequest , q .getContextId ()) == false )) {
1015- if (dependencies .searchService .freeReaderContext (q .getContextId ())) {
1016- q .clearContextId ();
1017- }
1018- }
1019- }
962+ }
1020963
1021- ActionListener .respondAndRelease (
1022- channelListener ,
1023- new NodeQueryResponse (mergeResult , results , queryPhaseResultConsumer .topDocsStats )
1024- );
964+ void onDone () {
965+ if (countDown .countDown () == false ) {
966+ return ;
967+ }
968+ var channelListener = new ChannelActionListener <>(channel );
969+ try (queryPhaseResultConsumer ) {
970+ var failure = queryPhaseResultConsumer .failure .get ();
971+ if (failure != null ) {
972+ queryPhaseResultConsumer .getSuccessfulResults ()
973+ .forEach (searchPhaseResult -> maybeRelease (dependencies .searchService , searchRequest , searchPhaseResult ));
974+ channelListener .onFailure (failure );
975+ return ;
976+ }
977+ final Object [] results = new Object [queryPhaseResultConsumer .getNumShards ()];
978+ for (int i = 0 ; i < results .length ; i ++) {
979+ var e = failures .get (i );
980+ var res = queryPhaseResultConsumer .results .get (i );
981+ if (e != null ) {
982+ results [i ] = e ;
983+ assert res == null ;
984+ } else {
985+ results [i ] = res ;
986+ assert results [i ] != null ;
1025987 }
1026988 }
1027- };
989+ final QueryPhaseResultConsumer .MergeResult mergeResult ;
990+ try {
991+ mergeResult = Objects .requireNonNullElse (queryPhaseResultConsumer .consumePartialResult (), EMPTY_PARTIAL_MERGE_RESULT );
992+ } catch (Exception e ) {
993+ channelListener .onFailure (e );
994+ return ;
995+ }
996+ // translate shard indices to those on the coordinator so that it can interpret the merge result without adjustments
997+ final Set <Integer > relevantShardIndices = new HashSet <>();
998+ for (ScoreDoc scoreDoc : mergeResult .reducedTopDocs ().scoreDocs ) {
999+ final int localIndex = scoreDoc .shardIndex ;
1000+ scoreDoc .shardIndex = searchRequest .shards .get (localIndex ).shardIndex ;
1001+ relevantShardIndices .add (localIndex );
1002+ }
1003+ for (Object result : results ) {
1004+ if (result instanceof QuerySearchResult q
1005+ && q .getContextId () != null
1006+ && relevantShardIndices .contains (q .getShardIndex ()) == false
1007+ && q .hasSuggestHits () == false
1008+ && q .getRankShardResult () == null
1009+ && searchRequest .searchRequest .scroll () == null
1010+ && (AsyncSearchContext .isPartOfPIT (null , searchRequest .searchRequest , q .getContextId ()) == false )) {
1011+ if (dependencies .searchService .freeReaderContext (q .getContextId ())) {
1012+ q .clearContextId ();
1013+ }
1014+ }
1015+ }
1016+
1017+ ActionListener .respondAndRelease (
1018+ channelListener ,
1019+ new NodeQueryResponse (mergeResult , results , queryPhaseResultConsumer .topDocsStats )
1020+ );
1021+ }
10281022 }
10291023
10301024 void consumeResult (QuerySearchResult queryResult ) {
0 commit comments