@@ -83,6 +83,9 @@ public class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction<S
8383 private static final Logger logger = LogManager .getLogger (SearchQueryThenFetchAsyncAction .class );
8484
8585 private static final TransportVersion BATCHED_QUERY_PHASE_VERSION = TransportVersion .fromName ("batched_query_phase_version" );
86+ private static final TransportVersion BATCHED_RESPONSE_MIGHT_INCLUDE_REDUCTION_FAILURE = TransportVersion .fromName (
87+ "batched_response_might_include_reduction_failure"
88+ );
8689
8790 private final SearchProgressListener progressListener ;
8891
@@ -221,20 +224,32 @@ public static final class NodeQueryResponse extends TransportResponse {
221224 private final RefCounted refCounted = LeakTracker .wrap (new SimpleRefCounted ());
222225
223226 private final Object [] results ;
227+ private final Exception reductionFailure ;
224228 private final SearchPhaseController .TopDocsStats topDocsStats ;
225229 private final QueryPhaseResultConsumer .MergeResult mergeResult ;
226230
227231 public NodeQueryResponse (StreamInput in ) throws IOException {
228232 this .results = in .readArray (i -> i .readBoolean () ? new QuerySearchResult (i ) : i .readException (), Object []::new );
229- this .mergeResult = QueryPhaseResultConsumer .MergeResult .readFrom (in );
230- this .topDocsStats = SearchPhaseController .TopDocsStats .readFrom (in );
233+ if (in .getTransportVersion ().supports (BATCHED_RESPONSE_MIGHT_INCLUDE_REDUCTION_FAILURE ) && in .readBoolean ()) {
234+ this .reductionFailure = in .readException ();
235+ this .mergeResult = null ;
236+ this .topDocsStats = null ;
237+ } else {
238+ this .reductionFailure = null ;
239+ this .mergeResult = QueryPhaseResultConsumer .MergeResult .readFrom (in );
240+ this .topDocsStats = SearchPhaseController .TopDocsStats .readFrom (in );
241+ }
231242 }
232243
233244 // public for tests
234245 public Object [] getResults () {
235246 return results ;
236247 }
237248
249+ Exception getReductionFailure () {
250+ return reductionFailure ;
251+ }
252+
238253 @ Override
239254 public void writeTo (StreamOutput out ) throws IOException {
240255 out .writeVInt (results .length );
@@ -245,7 +260,17 @@ public void writeTo(StreamOutput out) throws IOException {
245260 writePerShardResult (out , (QuerySearchResult ) result );
246261 }
247262 }
248- writeMergeResult (out , mergeResult , topDocsStats );
263+ if (out .getTransportVersion ().supports (BATCHED_RESPONSE_MIGHT_INCLUDE_REDUCTION_FAILURE )) {
264+ boolean hasReductionFailure = reductionFailure != null ;
265+ out .writeBoolean (hasReductionFailure );
266+ if (hasReductionFailure ) {
267+ out .writeException (reductionFailure );
268+ } else {
269+ writeMergeResult (out , mergeResult , topDocsStats );
270+ }
271+ } else {
272+ writeMergeResult (out , mergeResult , topDocsStats );
273+ }
249274 }
250275
251276 @ Override
@@ -498,7 +523,12 @@ public Executor executor() {
498523 @ Override
499524 public void handleResponse (NodeQueryResponse response ) {
500525 if (results instanceof QueryPhaseResultConsumer queryPhaseResultConsumer ) {
501- queryPhaseResultConsumer .addBatchedPartialResult (response .topDocsStats , response .mergeResult );
526+ Exception reductionFailure = response .getReductionFailure ();
527+ if (reductionFailure != null ) {
528+ queryPhaseResultConsumer .failure .compareAndSet (null , reductionFailure );
529+ } else {
530+ queryPhaseResultConsumer .addBatchedPartialResult (response .topDocsStats , response .mergeResult );
531+ }
502532 }
503533 for (int i = 0 ; i < response .results .length ; i ++) {
504534 var s = request .shards .get (i );
@@ -520,6 +550,21 @@ public void handleResponse(NodeQueryResponse response) {
520550
521551 @ Override
522552 public void handleException (TransportException e ) {
553+ if (connection .getTransportVersion ().supports (BATCHED_RESPONSE_MIGHT_INCLUDE_REDUCTION_FAILURE ) == false ) {
554+ bwcHandleException (e );
555+ return ;
556+ }
557+ Exception cause = (Exception ) ExceptionsHelper .unwrapCause (e );
558+ logger .debug ("handling node search exception coming from [" + nodeId + "]" , cause );
559+ onNodeQueryFailure (e , request , routing );
560+ }
561+
562+ /**
563+ * This code is strictly for _snapshot_ backwards compatibility. The feature flag
564+ * {@link SearchService#BATCHED_QUERY_PHASE_FEATURE_FLAG} was not turned on when the transport version
565+ * {@link SearchQueryThenFetchAsyncAction#BATCHED_RESPONSE_MIGHT_INCLUDE_REDUCTION_FAILURE} was introduced.
566+ */
567+ private void bwcHandleException (TransportException e ) {
523568 Exception cause = (Exception ) ExceptionsHelper .unwrapCause (e );
524569 logger .debug ("handling node search exception coming from [" + nodeId + "]" , cause );
525570 if (e instanceof SendRequestTransportException || cause instanceof TaskCancelledException ) {
@@ -791,13 +836,98 @@ void onShardDone() {
791836 if (countDown .countDown () == false ) {
792837 return ;
793838 }
839+ if (channel .getVersion ().supports (BATCHED_RESPONSE_MIGHT_INCLUDE_REDUCTION_FAILURE ) == false ) {
840+ bwcRespond ();
841+ return ;
842+ }
843+ var channelListener = new ChannelActionListener <>(channel );
844+ RecyclerBytesStreamOutput out = dependencies .transportService .newNetworkBytesStream ();
845+ out .setTransportVersion (channel .getVersion ());
846+ try (queryPhaseResultConsumer ) {
847+ Exception reductionFailure = queryPhaseResultConsumer .failure .get ();
848+ if (reductionFailure == null ) {
849+ writeSuccessfulResponse (out );
850+ } else {
851+ writeReductionFailureResponse (out , reductionFailure );
852+ }
853+ } catch (IOException e ) {
854+ releaseAllResultsContexts ();
855+ channelListener .onFailure (e );
856+ return ;
857+ }
858+ ActionListener .respondAndRelease (channelListener , new BytesTransportResponse (out .moveToBytesReference ()));
859+ }
860+
861+ // Writes the "successful" response (see NodeQueryResponse for the corresponding read logic)
862+ private void writeSuccessfulResponse (RecyclerBytesStreamOutput out ) throws IOException {
863+ final QueryPhaseResultConsumer .MergeResult mergeResult ;
864+ try {
865+ mergeResult = Objects .requireNonNullElse (
866+ queryPhaseResultConsumer .consumePartialMergeResultDataNode (),
867+ EMPTY_PARTIAL_MERGE_RESULT
868+ );
869+ } catch (Exception e ) {
870+ writeReductionFailureResponse (out , e );
871+ return ;
872+ }
873+ // translate shard indices to those on the coordinator so that it can interpret the merge result without adjustments,
874+ // also collect the set of indices that may be part of a subsequent fetch operation here so that we can release all other
875+ // indices without a roundtrip to the coordinating node
876+ final BitSet relevantShardIndices = new BitSet (searchRequest .shards .size ());
877+ if (mergeResult .reducedTopDocs () != null ) {
878+ for (ScoreDoc scoreDoc : mergeResult .reducedTopDocs ().scoreDocs ) {
879+ final int localIndex = scoreDoc .shardIndex ;
880+ scoreDoc .shardIndex = searchRequest .shards .get (localIndex ).shardIndex ;
881+ relevantShardIndices .set (localIndex );
882+ }
883+ }
884+ final int resultCount = queryPhaseResultConsumer .getNumShards ();
885+ out .writeVInt (resultCount );
886+ for (int i = 0 ; i < resultCount ; i ++) {
887+ var result = queryPhaseResultConsumer .results .get (i );
888+ if (result == null ) {
889+ NodeQueryResponse .writePerShardException (out , failures .remove (i ));
890+ } else {
891+ // free context id and remove it from the result right away in case we don't need it anymore
892+ maybeFreeContext (result , relevantShardIndices , namedWriteableRegistry );
893+ NodeQueryResponse .writePerShardResult (out , result );
894+ }
895+ }
896+ out .writeBoolean (false ); // does not have a reduction failure
897+ NodeQueryResponse .writeMergeResult (out , mergeResult , queryPhaseResultConsumer .topDocsStats );
898+ }
899+
900+ // Writes the "reduction failure" response (see NodeQueryResponse for the corresponding read logic)
901+ private void writeReductionFailureResponse (RecyclerBytesStreamOutput out , Exception reductionFailure ) throws IOException {
902+ final int resultCount = queryPhaseResultConsumer .getNumShards ();
903+ out .writeVInt (resultCount );
904+ for (int i = 0 ; i < resultCount ; i ++) {
905+ var result = queryPhaseResultConsumer .results .get (i );
906+ if (result == null ) {
907+ NodeQueryResponse .writePerShardException (out , failures .remove (i ));
908+ } else {
909+ NodeQueryResponse .writePerShardResult (out , result );
910+ }
911+ }
912+ out .writeBoolean (true ); // does have a reduction failure
913+ out .writeException (reductionFailure );
914+ releaseAllResultsContexts ();
915+ }
916+
917+ /**
918+ * This code is strictly for _snapshot_ backwards compatibility. The feature flag
919+ * {@link SearchService#BATCHED_QUERY_PHASE_FEATURE_FLAG} was not turned on when the transport version
920+ * {@link SearchQueryThenFetchAsyncAction#BATCHED_RESPONSE_MIGHT_INCLUDE_REDUCTION_FAILURE} was introduced.
921+ */
922+ void bwcRespond () {
794923 RecyclerBytesStreamOutput out = null ;
795924 boolean success = false ;
796925 var channelListener = new ChannelActionListener <>(channel );
797926 try (queryPhaseResultConsumer ) {
798927 var failure = queryPhaseResultConsumer .failure .get ();
799928 if (failure != null ) {
800- handleMergeFailure (failure , channelListener , namedWriteableRegistry );
929+ releaseAllResultsContexts ();
930+ channelListener .onFailure (failure );
801931 return ;
802932 }
803933 final QueryPhaseResultConsumer .MergeResult mergeResult ;
@@ -807,7 +937,8 @@ void onShardDone() {
807937 EMPTY_PARTIAL_MERGE_RESULT
808938 );
809939 } catch (Exception e ) {
810- handleMergeFailure (e , channelListener , namedWriteableRegistry );
940+ releaseAllResultsContexts ();
941+ channelListener .onFailure (e );
811942 return ;
812943 }
813944 // translate shard indices to those on the coordinator so that it can interpret the merge result without adjustments,
@@ -839,7 +970,8 @@ void onShardDone() {
839970 NodeQueryResponse .writeMergeResult (out , mergeResult , queryPhaseResultConsumer .topDocsStats );
840971 success = true ;
841972 } catch (IOException e ) {
842- handleMergeFailure (e , channelListener , namedWriteableRegistry );
973+ releaseAllResultsContexts ();
974+ channelListener .onFailure (e );
843975 return ;
844976 }
845977 } finally {
@@ -868,11 +1000,7 @@ && isPartOfPIT(searchRequest.searchRequest, q.getContextId(), namedWriteableRegi
8681000 }
8691001 }
8701002
871- private void handleMergeFailure (
872- Exception e ,
873- ChannelActionListener <TransportResponse > channelListener ,
874- NamedWriteableRegistry namedWriteableRegistry
875- ) {
1003+ private void releaseAllResultsContexts () {
8761004 queryPhaseResultConsumer .getSuccessfulResults ()
8771005 .forEach (
8781006 searchPhaseResult -> releaseLocalContext (
@@ -882,7 +1010,6 @@ private void handleMergeFailure(
8821010 namedWriteableRegistry
8831011 )
8841012 );
885- channelListener .onFailure (e );
8861013 }
8871014
8881015 void consumeResult (QuerySearchResult queryResult ) {
0 commit comments