 import org.elasticsearch.action.OriginalIndices;
 import org.elasticsearch.action.ShardOperationFailedException;
 import org.elasticsearch.action.search.TransportSearchAction.SearchTimeProvider;
+import org.elasticsearch.action.support.SubscribableListener;
 import org.elasticsearch.action.support.TransportActions;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.routing.GroupShardsIterator;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
-import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
-import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.Releasables;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.tasks.TaskCancelledException;
 import org.elasticsearch.transport.Transport;

-import java.util.ArrayDeque;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.Executor;
+import java.util.concurrent.LinkedTransferQueue;
+import java.util.concurrent.Semaphore;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.BiFunction;
+import java.util.function.Consumer;
 import java.util.stream.Collectors;

 import static org.elasticsearch.core.Strings.format;
@@ -248,7 +249,12 @@ public final void run() {
                 assert shardRoutings.skip() == false;
                 assert shardIndexMap.containsKey(shardRoutings);
                 int shardIndex = shardIndexMap.get(shardRoutings);
-                performPhaseOnShard(shardIndex, shardRoutings, shardRoutings.nextOrNull());
+                final SearchShardTarget routing = shardRoutings.nextOrNull();
+                if (routing == null) {
+                    failOnUnavailable(shardIndex, shardRoutings);
+                } else {
+                    performPhaseOnShard(shardIndex, shardRoutings, routing);
+                }
             }
         }
     }
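Reviewer note: splitting out the null next-target case here appears to be required because the reworked performPhaseOnShard below assumes a non-null shard (it dereferences shard.getNodeId() for per-node throttling), so an unassigned shard now fails fast via failOnUnavailable before any throttling bookkeeping happens.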
@@ -283,7 +289,7 @@ private static boolean assertExecuteOnStartThread() {
         int index = 0;
         assert stackTraceElements[index++].getMethodName().equals("getStackTrace");
         assert stackTraceElements[index++].getMethodName().equals("assertExecuteOnStartThread");
-        assert stackTraceElements[index++].getMethodName().equals("performPhaseOnShard");
+        assert stackTraceElements[index++].getMethodName().equals("failOnUnavailable");
         if (stackTraceElements[index].getMethodName().equals("performPhaseOnShard")) {
             assert stackTraceElements[index].getClassName().endsWith("CanMatchPreFilterSearchPhase");
             index++;
@@ -302,65 +308,53 @@ private static boolean assertExecuteOnStartThread() {
     }

     protected void performPhaseOnShard(final int shardIndex, final SearchShardIterator shardIt, final SearchShardTarget shard) {
-        /*
-         * We capture the thread that this phase is starting on. When we are called back after executing the phase, we are either on the
-         * same thread (because we never went async, or the same thread was selected from the thread pool) or a different thread. If we
-         * continue on the same thread in the case that we never went async and this happens repeatedly we will end up recursing deeply and
-         * could stack overflow. To prevent this, we fork if we are called back on the same thread that execution started on and otherwise
-         * we can continue (cf. InitialSearchPhase#maybeFork).
-         */
-        if (shard == null) {
-            assert assertExecuteOnStartThread();
-            SearchShardTarget unassignedShard = new SearchShardTarget(null, shardIt.shardId(), shardIt.getClusterAlias());
-            onShardFailure(shardIndex, unassignedShard, shardIt, new NoShardAvailableActionException(shardIt.shardId()));
+        if (throttleConcurrentRequests) {
+            var pendingExecutions = pendingExecutionsPerNode.computeIfAbsent(
+                shard.getNodeId(),
+                n -> new PendingExecutions(maxConcurrentRequestsPerNode)
+            );
+            pendingExecutions.submit(l -> doPerformPhaseOnShard(shardIndex, shardIt, shard, l));
         } else {
-            final PendingExecutions pendingExecutions = throttleConcurrentRequests
-                ? pendingExecutionsPerNode.computeIfAbsent(shard.getNodeId(), n -> new PendingExecutions(maxConcurrentRequestsPerNode))
-                : null;
-            Runnable r = () -> {
-                final Thread thread = Thread.currentThread();
-                try {
-                    executePhaseOnShard(shardIt, shard, new SearchActionListener<>(shard, shardIndex) {
-                        @Override
-                        public void innerOnResponse(Result result) {
-                            try {
-                                onShardResult(result, shardIt);
-                            } catch (Exception exc) {
-                                onShardFailure(shardIndex, shard, shardIt, exc);
-                            } finally {
-                                executeNext(pendingExecutions, thread);
-                            }
-                        }
+            doPerformPhaseOnShard(shardIndex, shardIt, shard, () -> {});
+        }
+    }

-                        @Override
-                        public void onFailure(Exception t) {
-                            try {
-                                onShardFailure(shardIndex, shard, shardIt, t);
-                            } finally {
-                                executeNext(pendingExecutions, thread);
-                            }
-                        }
-                    });
-                } catch (final Exception e) {
-                    try {
-                        /*
-                         * It is possible to run into connection exceptions here because we are getting the connection early and might
-                         * run into nodes that are not connected. In this case, on shard failure will move us to the next shard copy.
-                         */
-                        fork(() -> onShardFailure(shardIndex, shard, shardIt, e));
-                    } finally {
-                        executeNext(pendingExecutions, thread);
+    private void doPerformPhaseOnShard(int shardIndex, SearchShardIterator shardIt, SearchShardTarget shard, Releasable releasable) {
+        try {
+            executePhaseOnShard(shardIt, shard, new SearchActionListener<>(shard, shardIndex) {
+                @Override
+                public void innerOnResponse(Result result) {
+                    try (releasable) {
+                        onShardResult(result, shardIt);
+                    } catch (Exception exc) {
+                        onShardFailure(shardIndex, shard, shardIt, exc);
                     }
                 }
-            };
-            if (throttleConcurrentRequests) {
-                pendingExecutions.tryRun(r);
-            } else {
-                r.run();
+
+                @Override
+                public void onFailure(Exception e) {
+                    try (releasable) {
+                        onShardFailure(shardIndex, shard, shardIt, e);
+                    }
+                }
+            });
+        } catch (final Exception e) {
+            /*
+             * It is possible to run into connection exceptions here because we are getting the connection early and might
+             * run into nodes that are not connected. In this case, on shard failure will move us to the next shard copy.
+             */
+            try (releasable) {
+                onShardFailure(shardIndex, shard, shardIt, e);
             }
         }
     }

+    private void failOnUnavailable(int shardIndex, SearchShardIterator shardIt) {
+        assert assertExecuteOnStartThread();
+        SearchShardTarget unassignedShard = new SearchShardTarget(null, shardIt.shardId(), shardIt.getClusterAlias());
+        onShardFailure(shardIndex, unassignedShard, shardIt, new NoShardAvailableActionException(shardIt.shardId()));
+    }
+
     /**
      * Sends the request to the actual shard.
      * @param shardIt the shards iterator
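For readers skimming the new shape of this code: the Releasable handed to doPerformPhaseOnShard is the completion signal for the per-node throttle (a no-op lambda when throttling is off), and the try (releasable) blocks guarantee it is released exactly once on both the response and failure paths. A minimal standalone sketch of that contract, with purely illustrative names (ReleasableSignalSketch is not part of this change):

import org.elasticsearch.core.Releasable;

import java.util.function.Consumer;

// Hypothetical illustration only: a task receives a Releasable it must close exactly once,
// mirroring the contract between PendingExecutions.submit and doPerformPhaseOnShard.
class ReleasableSignalSketch {
    static void run(Consumer<Releasable> task, Releasable onDone) {
        task.accept(onDone);
    }

    public static void main(String[] args) {
        Releasable permit = () -> System.out.println("permit released");
        run(releasable -> {
            try (releasable) { // released whether the body completes or throws
                System.out.println("shard request sent");
            }
        }, permit);
    }
}

This works because Releasable is a single-method interface (the diff itself passes the no-op lambda () -> {} as one), so try-with-resources covers every exit path without the old explicit finally blocks.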
@@ -373,34 +367,6 @@ protected abstract void executePhaseOnShard(
         SearchActionListener<Result> listener
     );

-    protected void fork(final Runnable runnable) {
-        executor.execute(new AbstractRunnable() {
-            @Override
-            public void onFailure(Exception e) {
-                logger.error(() -> "unexpected error during [" + task + "]", e);
-                assert false : e;
-            }
-
-            @Override
-            public void onRejection(Exception e) {
-                // avoid leaks during node shutdown by executing on the current thread if the executor shuts down
-                assert e instanceof EsRejectedExecutionException esre && esre.isExecutorShutdown() : e;
-                doRun();
-            }
-
-            @Override
-            protected void doRun() {
-                runnable.run();
-            }
-
-            @Override
-            public boolean isForceExecution() {
-                // we can not allow a stuffed queue to reject execution here
-                return true;
-            }
-        });
-    }
-
     @Override
     public final void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPhase) {
         /* This is the main search phase transition where we move to the next phase. If all shards
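Reviewer note: with both of its remaining callers removed in this change (the old exception path in performPhaseOnShard above and executeNext in the last hunk), the fork helper and its AbstractRunnable/EsRejectedExecutionException imports can be dropped. Continuation scheduling is now handled inside PendingExecutions, which only keeps looping on the calling thread when the previous task has already completed.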
@@ -824,61 +790,63 @@ protected final ShardSearchRequest buildShardSearchRequest(SearchShardIterator s
      */
     protected abstract SearchPhase getNextPhase(SearchPhaseResults<Result> results, SearchPhaseContext context);

-    private void executeNext(PendingExecutions pendingExecutions, Thread originalThread) {
-        executeNext(pendingExecutions == null ? null : pendingExecutions.finishAndGetNext(), originalThread);
-    }
-
-    void executeNext(Runnable runnable, Thread originalThread) {
-        if (runnable != null) {
-            assert throttleConcurrentRequests;
-            if (originalThread == Thread.currentThread()) {
-                fork(runnable);
-            } else {
-                runnable.run();
-            }
-        }
-    }
-
     private static final class PendingExecutions {
-        private final int permits;
-        private int permitsTaken = 0;
-        private final ArrayDeque<Runnable> queue = new ArrayDeque<>();
+        private final Semaphore semaphore;
+        private final LinkedTransferQueue<Consumer<Releasable>> queue = new LinkedTransferQueue<>();

         PendingExecutions(int permits) {
             assert permits > 0 : "not enough permits: " + permits;
-            this.permits = permits;
+            semaphore = new Semaphore(permits);
         }

-        Runnable finishAndGetNext() {
-            synchronized (this) {
-                permitsTaken--;
-                assert permitsTaken >= 0 : "illegal taken permits: " + permitsTaken;
+        void submit(Consumer<Releasable> task) {
+            if (semaphore.tryAcquire()) {
+                executeAndRelease(task);
+            } else {
+                queue.add(task);
+                if (semaphore.tryAcquire()) {
+                    task = pollNextTaskOrReleasePermit();
+                    if (task != null) {
+                        executeAndRelease(task);
+                    }
+                }
             }
-            return tryQueue(null);
+
         }

-        void tryRun(Runnable runnable) {
-            Runnable r = tryQueue(runnable);
-            if (r != null) {
-                r.run();
+        private void executeAndRelease(Consumer<Releasable> task) {
+            while (task != null) {
+                final SubscribableListener<Void> onDone = new SubscribableListener<>();
+                task.accept(() -> onDone.onResponse(null));
+                if (onDone.isDone()) {
+                    // keep going on the current thread, no need to fork
+                    task = pollNextTaskOrReleasePermit();
+                } else {
+                    onDone.addListener(new ActionListener<>() {
+                        @Override
+                        public void onResponse(Void unused) {
+                            final Consumer<Releasable> nextTask = pollNextTaskOrReleasePermit();
+                            if (nextTask != null) {
+                                executeAndRelease(nextTask);
+                            }
+                        }
+
+                        @Override
+                        public void onFailure(Exception e) {
+                            assert false : e;
+                        }
+                    });
+                    return;
+                }
             }
         }

-        private synchronized Runnable tryQueue(Runnable runnable) {
-            Runnable toExecute = null;
-            if (permitsTaken < permits) {
-                permitsTaken++;
-                toExecute = runnable;
-                if (toExecute == null) { // only poll if we don't have anything to execute
-                    toExecute = queue.poll();
-                }
-                if (toExecute == null) {
-                    permitsTaken--;
-                }
-            } else if (runnable != null) {
-                queue.add(runnable);
+        private Consumer<Releasable> pollNextTaskOrReleasePermit() {
+            var task = queue.poll();
+            if (task == null) {
+                semaphore.release();
             }
-            return toExecute;
+            return task;
         }
     }
 }
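To make the new throttling logic easier to review in isolation, here is a minimal, hypothetical sketch of the same idea with synchronous Runnable tasks (class and method names are illustrative; the real PendingExecutions completes tasks asynchronously through a Releasable/SubscribableListener pair). At most `permits` tasks run at once, overflow tasks queue up, and whichever submission or completion frees a permit drains the queue. The second tryAcquire in submit covers the race where every permit holder finishes, and finds an empty queue, between the failed first tryAcquire and the add; without it the queued task could be stranded.

import java.util.concurrent.LinkedTransferQueue;
import java.util.concurrent.Semaphore;

// Hypothetical, simplified model of the PendingExecutions pattern above.
class NodeThrottleSketch {
    private final Semaphore semaphore;
    private final LinkedTransferQueue<Runnable> queue = new LinkedTransferQueue<>();

    NodeThrottleSketch(int permits) {
        this.semaphore = new Semaphore(permits);
    }

    void submit(Runnable task) {
        if (semaphore.tryAcquire()) {
            runAndDrain(task);
        } else {
            queue.add(task);
            // Re-check after enqueueing so a task added just after the last permit
            // was released does not sit in the queue with no one left to drain it.
            if (semaphore.tryAcquire()) {
                Runnable next = queue.poll();
                if (next == null) {
                    semaphore.release();
                } else {
                    runAndDrain(next);
                }
            }
        }
    }

    private void runAndDrain(Runnable task) {
        // Holds one permit for the whole chain; releases it only once the queue is empty.
        while (task != null) {
            task.run();
            task = queue.poll();
            if (task == null) {
                semaphore.release();
            }
        }
    }
}

In the real class, executeAndRelease only keeps looping on the calling thread when the previous task has already completed (onDone.isDone()); otherwise the chain resumes on whichever thread completes the listener, which is what lets this change drop the old fork/executeNext machinery.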