2020import org .elasticsearch .action .OriginalIndices ;
2121import org .elasticsearch .action .ShardOperationFailedException ;
2222import org .elasticsearch .action .search .TransportSearchAction .SearchTimeProvider ;
23+ import org .elasticsearch .action .support .SubscribableListener ;
2324import org .elasticsearch .action .support .TransportActions ;
2425import org .elasticsearch .cluster .ClusterState ;
2526import org .elasticsearch .cluster .routing .GroupShardsIterator ;
2627import org .elasticsearch .common .bytes .BytesReference ;
2728import org .elasticsearch .common .io .stream .NamedWriteableRegistry ;
28- import org .elasticsearch .common .util .concurrent .AbstractRunnable ;
2929import org .elasticsearch .common .util .concurrent .AtomicArray ;
30- import org .elasticsearch .common .util .concurrent .EsRejectedExecutionException ;
3130import org .elasticsearch .core .Releasable ;
3231import org .elasticsearch .core .Releasables ;
3332import org .elasticsearch .index .shard .ShardId ;
4342import org .elasticsearch .tasks .TaskCancelledException ;
4443import org .elasticsearch .transport .Transport ;
4544
46- import java .util .ArrayDeque ;
4745import java .util .ArrayList ;
4846import java .util .Collections ;
4947import java .util .HashMap ;
5048import java .util .List ;
5149import java .util .Map ;
5250import java .util .concurrent .ConcurrentHashMap ;
5351import java .util .concurrent .Executor ;
52+ import java .util .concurrent .LinkedTransferQueue ;
53+ import java .util .concurrent .Semaphore ;
5454import java .util .concurrent .atomic .AtomicBoolean ;
5555import java .util .concurrent .atomic .AtomicInteger ;
5656import java .util .function .BiFunction ;
57+ import java .util .function .Consumer ;
5758import java .util .stream .Collectors ;
5859
5960import static org .elasticsearch .core .Strings .format ;
@@ -238,7 +239,12 @@ public final void run() {
238239 assert shardRoutings .skip () == false ;
239240 assert shardIndexMap .containsKey (shardRoutings );
240241 int shardIndex = shardIndexMap .get (shardRoutings );
241- performPhaseOnShard (shardIndex , shardRoutings , shardRoutings .nextOrNull ());
242+ final SearchShardTarget routing = shardRoutings .nextOrNull ();
243+ if (routing == null ) {
244+ failOnUnavailable (shardIndex , shardRoutings );
245+ } else {
246+ performPhaseOnShard (shardIndex , shardRoutings , routing );
247+ }
242248 }
243249 }
244250 }
@@ -258,7 +264,7 @@ private static boolean assertExecuteOnStartThread() {
258264 int index = 0 ;
259265 assert stackTraceElements [index ++].getMethodName ().equals ("getStackTrace" );
260266 assert stackTraceElements [index ++].getMethodName ().equals ("assertExecuteOnStartThread" );
261- assert stackTraceElements [index ++].getMethodName ().equals ("performPhaseOnShard " );
267+ assert stackTraceElements [index ++].getMethodName ().equals ("failOnUnavailable " );
262268 if (stackTraceElements [index ].getMethodName ().equals ("performPhaseOnShard" )) {
263269 assert stackTraceElements [index ].getClassName ().endsWith ("CanMatchPreFilterSearchPhase" );
264270 index ++;
@@ -277,65 +283,53 @@ private static boolean assertExecuteOnStartThread() {
277283 }
278284
279285 protected void performPhaseOnShard (final int shardIndex , final SearchShardIterator shardIt , final SearchShardTarget shard ) {
280- /*
281- * We capture the thread that this phase is starting on. When we are called back after executing the phase, we are either on the
282- * same thread (because we never went async, or the same thread was selected from the thread pool) or a different thread. If we
283- * continue on the same thread in the case that we never went async and this happens repeatedly we will end up recursing deeply and
284- * could stack overflow. To prevent this, we fork if we are called back on the same thread that execution started on and otherwise
285- * we can continue (cf. InitialSearchPhase#maybeFork).
286- */
287- if (shard == null ) {
288- assert assertExecuteOnStartThread ();
289- SearchShardTarget unassignedShard = new SearchShardTarget (null , shardIt .shardId (), shardIt .getClusterAlias ());
290- onShardFailure (shardIndex , unassignedShard , shardIt , new NoShardAvailableActionException (shardIt .shardId ()));
286+ if (throttleConcurrentRequests ) {
287+ var pendingExecutions = pendingExecutionsPerNode .computeIfAbsent (
288+ shard .getNodeId (),
289+ n -> new PendingExecutions (maxConcurrentRequestsPerNode )
290+ );
291+ pendingExecutions .submit (l -> doPerformPhaseOnShard (shardIndex , shardIt , shard , l ));
291292 } else {
292- final PendingExecutions pendingExecutions = throttleConcurrentRequests
293- ? pendingExecutionsPerNode .computeIfAbsent (shard .getNodeId (), n -> new PendingExecutions (maxConcurrentRequestsPerNode ))
294- : null ;
295- Runnable r = () -> {
296- final Thread thread = Thread .currentThread ();
297- try {
298- executePhaseOnShard (shardIt , shard , new SearchActionListener <>(shard , shardIndex ) {
299- @ Override
300- public void innerOnResponse (Result result ) {
301- try {
302- onShardResult (result , shardIt );
303- } catch (Exception exc ) {
304- onShardFailure (shardIndex , shard , shardIt , exc );
305- } finally {
306- executeNext (pendingExecutions , thread );
307- }
308- }
293+ doPerformPhaseOnShard (shardIndex , shardIt , shard , () -> {});
294+ }
295+ }
309296
310- @ Override
311- public void onFailure (Exception t ) {
312- try {
313- onShardFailure (shardIndex , shard , shardIt , t );
314- } finally {
315- executeNext (pendingExecutions , thread );
316- }
317- }
318- });
319- } catch (final Exception e ) {
320- try {
321- /*
322- * It is possible to run into connection exceptions here because we are getting the connection early and might
323- * run into nodes that are not connected. In this case, on shard failure will move us to the next shard copy.
324- */
325- fork (() -> onShardFailure (shardIndex , shard , shardIt , e ));
326- } finally {
327- executeNext (pendingExecutions , thread );
297+ private void doPerformPhaseOnShard (int shardIndex , SearchShardIterator shardIt , SearchShardTarget shard , Releasable releasable ) {
298+ try {
299+ executePhaseOnShard (shardIt , shard , new SearchActionListener <>(shard , shardIndex ) {
300+ @ Override
301+ public void innerOnResponse (Result result ) {
302+ try (releasable ) {
303+ onShardResult (result , shardIt );
304+ } catch (Exception exc ) {
305+ onShardFailure (shardIndex , shard , shardIt , exc );
328306 }
329307 }
330- };
331- if (throttleConcurrentRequests ) {
332- pendingExecutions .tryRun (r );
333- } else {
334- r .run ();
308+
309+ @ Override
310+ public void onFailure (Exception e ) {
311+ try (releasable ) {
312+ onShardFailure (shardIndex , shard , shardIt , e );
313+ }
314+ }
315+ });
316+ } catch (final Exception e ) {
317+ /*
318+ * It is possible to run into connection exceptions here because we are getting the connection early and might
319+ * run into nodes that are not connected. In this case, on shard failure will move us to the next shard copy.
320+ */
321+ try (releasable ) {
322+ onShardFailure (shardIndex , shard , shardIt , e );
335323 }
336324 }
337325 }
338326
327+ private void failOnUnavailable (int shardIndex , SearchShardIterator shardIt ) {
328+ assert assertExecuteOnStartThread ();
329+ SearchShardTarget unassignedShard = new SearchShardTarget (null , shardIt .shardId (), shardIt .getClusterAlias ());
330+ onShardFailure (shardIndex , unassignedShard , shardIt , new NoShardAvailableActionException (shardIt .shardId ()));
331+ }
332+
339333 /**
340334 * Sends the request to the actual shard.
341335 * @param shardIt the shards iterator
@@ -348,34 +342,6 @@ protected abstract void executePhaseOnShard(
348342 SearchActionListener <Result > listener
349343 );
350344
351- protected void fork (final Runnable runnable ) {
352- executor .execute (new AbstractRunnable () {
353- @ Override
354- public void onFailure (Exception e ) {
355- logger .error (() -> "unexpected error during [" + task + "]" , e );
356- assert false : e ;
357- }
358-
359- @ Override
360- public void onRejection (Exception e ) {
361- // avoid leaks during node shutdown by executing on the current thread if the executor shuts down
362- assert e instanceof EsRejectedExecutionException esre && esre .isExecutorShutdown () : e ;
363- doRun ();
364- }
365-
366- @ Override
367- protected void doRun () {
368- runnable .run ();
369- }
370-
371- @ Override
372- public boolean isForceExecution () {
373- // we can not allow a stuffed queue to reject execution here
374- return true ;
375- }
376- });
377- }
378-
379345 @ Override
380346 public final void executeNextPhase (SearchPhase currentPhase , SearchPhase nextPhase ) {
381347 /* This is the main search phase transition where we move to the next phase. If all shards
@@ -794,61 +760,63 @@ protected final ShardSearchRequest buildShardSearchRequest(SearchShardIterator s
794760 */
795761 protected abstract SearchPhase getNextPhase (SearchPhaseResults <Result > results , SearchPhaseContext context );
796762
797- private void executeNext (PendingExecutions pendingExecutions , Thread originalThread ) {
798- executeNext (pendingExecutions == null ? null : pendingExecutions .finishAndGetNext (), originalThread );
799- }
800-
801- void executeNext (Runnable runnable , Thread originalThread ) {
802- if (runnable != null ) {
803- assert throttleConcurrentRequests ;
804- if (originalThread == Thread .currentThread ()) {
805- fork (runnable );
806- } else {
807- runnable .run ();
808- }
809- }
810- }
811-
812763 private static final class PendingExecutions {
813- private final int permits ;
814- private int permitsTaken = 0 ;
815- private final ArrayDeque <Runnable > queue = new ArrayDeque <>();
764+ private final Semaphore semaphore ;
765+ private final LinkedTransferQueue <Consumer <Releasable >> queue = new LinkedTransferQueue <>();
816766
817767 PendingExecutions (int permits ) {
818768 assert permits > 0 : "not enough permits: " + permits ;
819- this . permits = permits ;
769+ semaphore = new Semaphore ( permits ) ;
820770 }
821771
822- Runnable finishAndGetNext () {
823- synchronized (this ) {
824- permitsTaken --;
825- assert permitsTaken >= 0 : "illegal taken permits: " + permitsTaken ;
772+ void submit (Consumer <Releasable > task ) {
773+ if (semaphore .tryAcquire ()) {
774+ executeAndRelease (task );
775+ } else {
776+ queue .add (task );
777+ if (semaphore .tryAcquire ()) {
778+ task = pollNextTaskOrReleasePermit ();
779+ if (task != null ) {
780+ executeAndRelease (task );
781+ }
782+ }
826783 }
827- return tryQueue ( null );
784+
828785 }
829786
830- void tryRun (Runnable runnable ) {
831- Runnable r = tryQueue (runnable );
832- if (r != null ) {
833- r .run ();
787+ private void executeAndRelease (Consumer <Releasable > task ) {
788+ while (task != null ) {
789+ final SubscribableListener <Void > onDone = new SubscribableListener <>();
790+ task .accept (() -> onDone .onResponse (null ));
791+ if (onDone .isDone ()) {
792+ // keep going on the current thread, no need to fork
793+ task = pollNextTaskOrReleasePermit ();
794+ } else {
795+ onDone .addListener (new ActionListener <>() {
796+ @ Override
797+ public void onResponse (Void unused ) {
798+ final Consumer <Releasable > nextTask = pollNextTaskOrReleasePermit ();
799+ if (nextTask != null ) {
800+ executeAndRelease (nextTask );
801+ }
802+ }
803+
804+ @ Override
805+ public void onFailure (Exception e ) {
806+ assert false : e ;
807+ }
808+ });
809+ return ;
810+ }
834811 }
835812 }
836813
837- private synchronized Runnable tryQueue (Runnable runnable ) {
838- Runnable toExecute = null ;
839- if (permitsTaken < permits ) {
840- permitsTaken ++;
841- toExecute = runnable ;
842- if (toExecute == null ) { // only poll if we don't have anything to execute
843- toExecute = queue .poll ();
844- }
845- if (toExecute == null ) {
846- permitsTaken --;
847- }
848- } else if (runnable != null ) {
849- queue .add (runnable );
814+ private Consumer <Releasable > pollNextTaskOrReleasePermit () {
815+ var task = queue .poll ();
816+ if (task == null ) {
817+ semaphore .release ();
850818 }
851- return toExecute ;
819+ return task ;
852820 }
853821 }
854822}
0 commit comments