Skip to content

Commit 3cc5796

Browse files
Optimize threading in AbstractSearchAsyncAction (#113230) (#115643)
Forking when an action completes on the current thread is a needlessly heavy-handed way of preventing stack overflows. Also, we don't need locking/synchronization to deal with a worker-count + queue-length problem. Both of these allow for non-trivial optimization even in the current execution model. This change also helps with moving to a more efficient execution model, in particular by saving needless forking to the search pool. -> Refactored the code to never fork and instead avoid stack-depth issues through use of a `SubscribableListener`. -> Replaced our home-brew queue and semaphore combination with JDK primitives, which saves blocking synchronization on task start and completion.
1 parent b22b9c7 commit 3cc5796

File tree

1 file changed

+94
-126
lines changed

1 file changed

+94
-126
lines changed

server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java

Lines changed: 94 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,13 @@
2121
import org.elasticsearch.action.OriginalIndices;
2222
import org.elasticsearch.action.ShardOperationFailedException;
2323
import org.elasticsearch.action.search.TransportSearchAction.SearchTimeProvider;
24+
import org.elasticsearch.action.support.SubscribableListener;
2425
import org.elasticsearch.action.support.TransportActions;
2526
import org.elasticsearch.cluster.ClusterState;
2627
import org.elasticsearch.cluster.routing.GroupShardsIterator;
2728
import org.elasticsearch.common.bytes.BytesReference;
2829
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
29-
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
3030
import org.elasticsearch.common.util.concurrent.AtomicArray;
31-
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
3231
import org.elasticsearch.core.Releasable;
3332
import org.elasticsearch.core.Releasables;
3433
import org.elasticsearch.index.shard.ShardId;
@@ -44,17 +43,19 @@
4443
import org.elasticsearch.tasks.TaskCancelledException;
4544
import org.elasticsearch.transport.Transport;
4645

47-
import java.util.ArrayDeque;
4846
import java.util.ArrayList;
4947
import java.util.Collections;
5048
import java.util.HashMap;
5149
import java.util.List;
5250
import java.util.Map;
5351
import java.util.concurrent.ConcurrentHashMap;
5452
import java.util.concurrent.Executor;
53+
import java.util.concurrent.LinkedTransferQueue;
54+
import java.util.concurrent.Semaphore;
5555
import java.util.concurrent.atomic.AtomicBoolean;
5656
import java.util.concurrent.atomic.AtomicInteger;
5757
import java.util.function.BiFunction;
58+
import java.util.function.Consumer;
5859
import java.util.stream.Collectors;
5960

6061
import static org.elasticsearch.core.Strings.format;
@@ -248,7 +249,12 @@ public final void run() {
248249
assert shardRoutings.skip() == false;
249250
assert shardIndexMap.containsKey(shardRoutings);
250251
int shardIndex = shardIndexMap.get(shardRoutings);
251-
performPhaseOnShard(shardIndex, shardRoutings, shardRoutings.nextOrNull());
252+
final SearchShardTarget routing = shardRoutings.nextOrNull();
253+
if (routing == null) {
254+
failOnUnavailable(shardIndex, shardRoutings);
255+
} else {
256+
performPhaseOnShard(shardIndex, shardRoutings, routing);
257+
}
252258
}
253259
}
254260
}
@@ -283,7 +289,7 @@ private static boolean assertExecuteOnStartThread() {
283289
int index = 0;
284290
assert stackTraceElements[index++].getMethodName().equals("getStackTrace");
285291
assert stackTraceElements[index++].getMethodName().equals("assertExecuteOnStartThread");
286-
assert stackTraceElements[index++].getMethodName().equals("performPhaseOnShard");
292+
assert stackTraceElements[index++].getMethodName().equals("failOnUnavailable");
287293
if (stackTraceElements[index].getMethodName().equals("performPhaseOnShard")) {
288294
assert stackTraceElements[index].getClassName().endsWith("CanMatchPreFilterSearchPhase");
289295
index++;
@@ -302,65 +308,53 @@ private static boolean assertExecuteOnStartThread() {
302308
}
303309

304310
protected void performPhaseOnShard(final int shardIndex, final SearchShardIterator shardIt, final SearchShardTarget shard) {
305-
/*
306-
* We capture the thread that this phase is starting on. When we are called back after executing the phase, we are either on the
307-
* same thread (because we never went async, or the same thread was selected from the thread pool) or a different thread. If we
308-
* continue on the same thread in the case that we never went async and this happens repeatedly we will end up recursing deeply and
309-
* could stack overflow. To prevent this, we fork if we are called back on the same thread that execution started on and otherwise
310-
* we can continue (cf. InitialSearchPhase#maybeFork).
311-
*/
312-
if (shard == null) {
313-
assert assertExecuteOnStartThread();
314-
SearchShardTarget unassignedShard = new SearchShardTarget(null, shardIt.shardId(), shardIt.getClusterAlias());
315-
onShardFailure(shardIndex, unassignedShard, shardIt, new NoShardAvailableActionException(shardIt.shardId()));
311+
if (throttleConcurrentRequests) {
312+
var pendingExecutions = pendingExecutionsPerNode.computeIfAbsent(
313+
shard.getNodeId(),
314+
n -> new PendingExecutions(maxConcurrentRequestsPerNode)
315+
);
316+
pendingExecutions.submit(l -> doPerformPhaseOnShard(shardIndex, shardIt, shard, l));
316317
} else {
317-
final PendingExecutions pendingExecutions = throttleConcurrentRequests
318-
? pendingExecutionsPerNode.computeIfAbsent(shard.getNodeId(), n -> new PendingExecutions(maxConcurrentRequestsPerNode))
319-
: null;
320-
Runnable r = () -> {
321-
final Thread thread = Thread.currentThread();
322-
try {
323-
executePhaseOnShard(shardIt, shard, new SearchActionListener<>(shard, shardIndex) {
324-
@Override
325-
public void innerOnResponse(Result result) {
326-
try {
327-
onShardResult(result, shardIt);
328-
} catch (Exception exc) {
329-
onShardFailure(shardIndex, shard, shardIt, exc);
330-
} finally {
331-
executeNext(pendingExecutions, thread);
332-
}
333-
}
318+
doPerformPhaseOnShard(shardIndex, shardIt, shard, () -> {});
319+
}
320+
}
334321

335-
@Override
336-
public void onFailure(Exception t) {
337-
try {
338-
onShardFailure(shardIndex, shard, shardIt, t);
339-
} finally {
340-
executeNext(pendingExecutions, thread);
341-
}
342-
}
343-
});
344-
} catch (final Exception e) {
345-
try {
346-
/*
347-
* It is possible to run into connection exceptions here because we are getting the connection early and might
348-
* run into nodes that are not connected. In this case, on shard failure will move us to the next shard copy.
349-
*/
350-
fork(() -> onShardFailure(shardIndex, shard, shardIt, e));
351-
} finally {
352-
executeNext(pendingExecutions, thread);
322+
private void doPerformPhaseOnShard(int shardIndex, SearchShardIterator shardIt, SearchShardTarget shard, Releasable releasable) {
323+
try {
324+
executePhaseOnShard(shardIt, shard, new SearchActionListener<>(shard, shardIndex) {
325+
@Override
326+
public void innerOnResponse(Result result) {
327+
try (releasable) {
328+
onShardResult(result, shardIt);
329+
} catch (Exception exc) {
330+
onShardFailure(shardIndex, shard, shardIt, exc);
353331
}
354332
}
355-
};
356-
if (throttleConcurrentRequests) {
357-
pendingExecutions.tryRun(r);
358-
} else {
359-
r.run();
333+
334+
@Override
335+
public void onFailure(Exception e) {
336+
try (releasable) {
337+
onShardFailure(shardIndex, shard, shardIt, e);
338+
}
339+
}
340+
});
341+
} catch (final Exception e) {
342+
/*
343+
* It is possible to run into connection exceptions here because we are getting the connection early and might
344+
* run into nodes that are not connected. In this case, on shard failure will move us to the next shard copy.
345+
*/
346+
try (releasable) {
347+
onShardFailure(shardIndex, shard, shardIt, e);
360348
}
361349
}
362350
}
363351

352+
private void failOnUnavailable(int shardIndex, SearchShardIterator shardIt) {
353+
assert assertExecuteOnStartThread();
354+
SearchShardTarget unassignedShard = new SearchShardTarget(null, shardIt.shardId(), shardIt.getClusterAlias());
355+
onShardFailure(shardIndex, unassignedShard, shardIt, new NoShardAvailableActionException(shardIt.shardId()));
356+
}
357+
364358
/**
365359
* Sends the request to the actual shard.
366360
* @param shardIt the shards iterator
@@ -373,34 +367,6 @@ protected abstract void executePhaseOnShard(
373367
SearchActionListener<Result> listener
374368
);
375369

376-
protected void fork(final Runnable runnable) {
377-
executor.execute(new AbstractRunnable() {
378-
@Override
379-
public void onFailure(Exception e) {
380-
logger.error(() -> "unexpected error during [" + task + "]", e);
381-
assert false : e;
382-
}
383-
384-
@Override
385-
public void onRejection(Exception e) {
386-
// avoid leaks during node shutdown by executing on the current thread if the executor shuts down
387-
assert e instanceof EsRejectedExecutionException esre && esre.isExecutorShutdown() : e;
388-
doRun();
389-
}
390-
391-
@Override
392-
protected void doRun() {
393-
runnable.run();
394-
}
395-
396-
@Override
397-
public boolean isForceExecution() {
398-
// we can not allow a stuffed queue to reject execution here
399-
return true;
400-
}
401-
});
402-
}
403-
404370
@Override
405371
public final void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPhase) {
406372
/* This is the main search phase transition where we move to the next phase. If all shards
@@ -824,61 +790,63 @@ protected final ShardSearchRequest buildShardSearchRequest(SearchShardIterator s
824790
*/
825791
protected abstract SearchPhase getNextPhase(SearchPhaseResults<Result> results, SearchPhaseContext context);
826792

827-
private void executeNext(PendingExecutions pendingExecutions, Thread originalThread) {
828-
executeNext(pendingExecutions == null ? null : pendingExecutions.finishAndGetNext(), originalThread);
829-
}
830-
831-
void executeNext(Runnable runnable, Thread originalThread) {
832-
if (runnable != null) {
833-
assert throttleConcurrentRequests;
834-
if (originalThread == Thread.currentThread()) {
835-
fork(runnable);
836-
} else {
837-
runnable.run();
838-
}
839-
}
840-
}
841-
842793
private static final class PendingExecutions {
843-
private final int permits;
844-
private int permitsTaken = 0;
845-
private final ArrayDeque<Runnable> queue = new ArrayDeque<>();
794+
private final Semaphore semaphore;
795+
private final LinkedTransferQueue<Consumer<Releasable>> queue = new LinkedTransferQueue<>();
846796

847797
PendingExecutions(int permits) {
848798
assert permits > 0 : "not enough permits: " + permits;
849-
this.permits = permits;
799+
semaphore = new Semaphore(permits);
850800
}
851801

852-
Runnable finishAndGetNext() {
853-
synchronized (this) {
854-
permitsTaken--;
855-
assert permitsTaken >= 0 : "illegal taken permits: " + permitsTaken;
802+
void submit(Consumer<Releasable> task) {
803+
if (semaphore.tryAcquire()) {
804+
executeAndRelease(task);
805+
} else {
806+
queue.add(task);
807+
if (semaphore.tryAcquire()) {
808+
task = pollNextTaskOrReleasePermit();
809+
if (task != null) {
810+
executeAndRelease(task);
811+
}
812+
}
856813
}
857-
return tryQueue(null);
814+
858815
}
859816

860-
void tryRun(Runnable runnable) {
861-
Runnable r = tryQueue(runnable);
862-
if (r != null) {
863-
r.run();
817+
private void executeAndRelease(Consumer<Releasable> task) {
818+
while (task != null) {
819+
final SubscribableListener<Void> onDone = new SubscribableListener<>();
820+
task.accept(() -> onDone.onResponse(null));
821+
if (onDone.isDone()) {
822+
// keep going on the current thread, no need to fork
823+
task = pollNextTaskOrReleasePermit();
824+
} else {
825+
onDone.addListener(new ActionListener<>() {
826+
@Override
827+
public void onResponse(Void unused) {
828+
final Consumer<Releasable> nextTask = pollNextTaskOrReleasePermit();
829+
if (nextTask != null) {
830+
executeAndRelease(nextTask);
831+
}
832+
}
833+
834+
@Override
835+
public void onFailure(Exception e) {
836+
assert false : e;
837+
}
838+
});
839+
return;
840+
}
864841
}
865842
}
866843

867-
private synchronized Runnable tryQueue(Runnable runnable) {
868-
Runnable toExecute = null;
869-
if (permitsTaken < permits) {
870-
permitsTaken++;
871-
toExecute = runnable;
872-
if (toExecute == null) { // only poll if we don't have anything to execute
873-
toExecute = queue.poll();
874-
}
875-
if (toExecute == null) {
876-
permitsTaken--;
877-
}
878-
} else if (runnable != null) {
879-
queue.add(runnable);
844+
private Consumer<Releasable> pollNextTaskOrReleasePermit() {
845+
var task = queue.poll();
846+
if (task == null) {
847+
semaphore.release();
880848
}
881-
return toExecute;
849+
return task;
882850
}
883851
}
884852
}

0 commit comments

Comments
 (0)