
Commit 2d85476

Optimize threading in AbstractSearchAsyncAction (#113230)
Forking whenever an action completes on the current thread is a needlessly heavy-handed way of preventing stack overflows, and we don't need locking/synchronization to manage what is essentially a worker-count plus queue-length problem. Fixing both allows for non-trivial optimization even in the current execution model, and it also helps the move to a more efficient execution model by avoiding needless forking to the search pool in particular.

-> refactored the code to never fork, instead avoiding stack-depth issues through use of a `SubscribableListener`
-> replaced our home-brewed queue-and-semaphore combination with JDK primitives, which saves blocking synchronization on task start and completion
1 parent f1de84b commit 2d85476
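
The queue-and-permit scheme the message describes can be illustrated in isolation. Below is a minimal sketch (simplified to synchronous Runnable tasks; the class and method names are illustrative, not from the commit) of the pattern the new PendingExecutions uses: Semaphore.tryAcquire bounds concurrency without ever blocking, a LinkedTransferQueue absorbs overflow, and a second tryAcquire after enqueueing closes the race where every permit was released between the first check and the add.

import java.util.concurrent.LinkedTransferQueue;
import java.util.concurrent.Semaphore;

final class ThrottledExecutions {
    private final Semaphore semaphore;
    private final LinkedTransferQueue<Runnable> queue = new LinkedTransferQueue<>();

    ThrottledExecutions(int permits) {
        semaphore = new Semaphore(permits);
    }

    void submit(Runnable task) {
        if (semaphore.tryAcquire()) {
            runAndContinue(task);
        } else {
            queue.add(task);
            // Re-check: all permits may have been released between the failed
            // tryAcquire above and the add, which would strand the queued task.
            if (semaphore.tryAcquire()) {
                Runnable next = pollOrRelease();
                if (next != null) {
                    runAndContinue(next);
                }
            }
        }
    }

    private void runAndContinue(Runnable task) {
        // Drain with a loop rather than recursion so a long queue
        // cannot overflow the stack.
        while (task != null) {
            task.run();
            task = pollOrRelease();
        }
    }

    private Runnable pollOrRelease() {
        Runnable next = queue.poll();
        if (next == null) {
            semaphore.release(); // nothing queued: hand the permit back
        }
        return next;
    }
}

Note that neither submission nor completion takes a lock here; the replaced implementation synchronized on the PendingExecutions instance for both (tryQueue was a synchronized method, and finishAndGetNext held the same monitor).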

File tree

1 file changed: +94 -126 lines


server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java

Lines changed: 94 additions & 126 deletions
@@ -20,14 +20,13 @@
 import org.elasticsearch.action.OriginalIndices;
 import org.elasticsearch.action.ShardOperationFailedException;
 import org.elasticsearch.action.search.TransportSearchAction.SearchTimeProvider;
+import org.elasticsearch.action.support.SubscribableListener;
 import org.elasticsearch.action.support.TransportActions;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.routing.GroupShardsIterator;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
-import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
-import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.Releasables;
 import org.elasticsearch.index.shard.ShardId;
@@ -43,17 +42,19 @@
 import org.elasticsearch.tasks.TaskCancelledException;
 import org.elasticsearch.transport.Transport;
 
-import java.util.ArrayDeque;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.Executor;
+import java.util.concurrent.LinkedTransferQueue;
+import java.util.concurrent.Semaphore;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.BiFunction;
+import java.util.function.Consumer;
 import java.util.stream.Collectors;
 
 import static org.elasticsearch.core.Strings.format;
@@ -238,7 +239,12 @@ public final void run() {
                 assert shardRoutings.skip() == false;
                 assert shardIndexMap.containsKey(shardRoutings);
                 int shardIndex = shardIndexMap.get(shardRoutings);
-                performPhaseOnShard(shardIndex, shardRoutings, shardRoutings.nextOrNull());
+                final SearchShardTarget routing = shardRoutings.nextOrNull();
+                if (routing == null) {
+                    failOnUnavailable(shardIndex, shardRoutings);
+                } else {
+                    performPhaseOnShard(shardIndex, shardRoutings, routing);
+                }
             }
         }
     }
@@ -258,7 +264,7 @@ private static boolean assertExecuteOnStartThread() {
         int index = 0;
         assert stackTraceElements[index++].getMethodName().equals("getStackTrace");
        assert stackTraceElements[index++].getMethodName().equals("assertExecuteOnStartThread");
-        assert stackTraceElements[index++].getMethodName().equals("performPhaseOnShard");
+        assert stackTraceElements[index++].getMethodName().equals("failOnUnavailable");
         if (stackTraceElements[index].getMethodName().equals("performPhaseOnShard")) {
             assert stackTraceElements[index].getClassName().endsWith("CanMatchPreFilterSearchPhase");
             index++;
@@ -277,65 +283,53 @@ private static boolean assertExecuteOnStartThread() {
     }
 
     protected void performPhaseOnShard(final int shardIndex, final SearchShardIterator shardIt, final SearchShardTarget shard) {
-        /*
-         * We capture the thread that this phase is starting on. When we are called back after executing the phase, we are either on the
-         * same thread (because we never went async, or the same thread was selected from the thread pool) or a different thread. If we
-         * continue on the same thread in the case that we never went async and this happens repeatedly we will end up recursing deeply and
-         * could stack overflow. To prevent this, we fork if we are called back on the same thread that execution started on and otherwise
-         * we can continue (cf. InitialSearchPhase#maybeFork).
-         */
-        if (shard == null) {
-            assert assertExecuteOnStartThread();
-            SearchShardTarget unassignedShard = new SearchShardTarget(null, shardIt.shardId(), shardIt.getClusterAlias());
-            onShardFailure(shardIndex, unassignedShard, shardIt, new NoShardAvailableActionException(shardIt.shardId()));
+        if (throttleConcurrentRequests) {
+            var pendingExecutions = pendingExecutionsPerNode.computeIfAbsent(
+                shard.getNodeId(),
+                n -> new PendingExecutions(maxConcurrentRequestsPerNode)
+            );
+            pendingExecutions.submit(l -> doPerformPhaseOnShard(shardIndex, shardIt, shard, l));
         } else {
-            final PendingExecutions pendingExecutions = throttleConcurrentRequests
-                ? pendingExecutionsPerNode.computeIfAbsent(shard.getNodeId(), n -> new PendingExecutions(maxConcurrentRequestsPerNode))
-                : null;
-            Runnable r = () -> {
-                final Thread thread = Thread.currentThread();
-                try {
-                    executePhaseOnShard(shardIt, shard, new SearchActionListener<>(shard, shardIndex) {
-                        @Override
-                        public void innerOnResponse(Result result) {
-                            try {
-                                onShardResult(result, shardIt);
-                            } catch (Exception exc) {
-                                onShardFailure(shardIndex, shard, shardIt, exc);
-                            } finally {
-                                executeNext(pendingExecutions, thread);
-                            }
-                        }
+            doPerformPhaseOnShard(shardIndex, shardIt, shard, () -> {});
+        }
+    }
 
-                        @Override
-                        public void onFailure(Exception t) {
-                            try {
-                                onShardFailure(shardIndex, shard, shardIt, t);
-                            } finally {
-                                executeNext(pendingExecutions, thread);
-                            }
-                        }
-                    });
-                } catch (final Exception e) {
-                    try {
-                        /*
-                         * It is possible to run into connection exceptions here because we are getting the connection early and might
-                         * run into nodes that are not connected. In this case, on shard failure will move us to the next shard copy.
-                         */
-                        fork(() -> onShardFailure(shardIndex, shard, shardIt, e));
-                    } finally {
-                        executeNext(pendingExecutions, thread);
+    private void doPerformPhaseOnShard(int shardIndex, SearchShardIterator shardIt, SearchShardTarget shard, Releasable releasable) {
+        try {
+            executePhaseOnShard(shardIt, shard, new SearchActionListener<>(shard, shardIndex) {
+                @Override
+                public void innerOnResponse(Result result) {
+                    try (releasable) {
+                        onShardResult(result, shardIt);
+                    } catch (Exception exc) {
+                        onShardFailure(shardIndex, shard, shardIt, exc);
                     }
                 }
-            };
-            if (throttleConcurrentRequests) {
-                pendingExecutions.tryRun(r);
-            } else {
-                r.run();
+
+                @Override
+                public void onFailure(Exception e) {
+                    try (releasable) {
+                        onShardFailure(shardIndex, shard, shardIt, e);
+                    }
+                }
+            });
+        } catch (final Exception e) {
+            /*
+             * It is possible to run into connection exceptions here because we are getting the connection early and might
+             * run into nodes that are not connected. In this case, on shard failure will move us to the next shard copy.
+             */
+            try (releasable) {
+                onShardFailure(shardIndex, shard, shardIt, e);
             }
         }
     }
 
+    private void failOnUnavailable(int shardIndex, SearchShardIterator shardIt) {
+        assert assertExecuteOnStartThread();
+        SearchShardTarget unassignedShard = new SearchShardTarget(null, shardIt.shardId(), shardIt.getClusterAlias());
+        onShardFailure(shardIndex, unassignedShard, shardIt, new NoShardAvailableActionException(shardIt.shardId()));
+    }
+
     /**
      * Sends the request to the actual shard.
     * @param shardIt the shards iterator
@@ -348,34 +342,6 @@ protected abstract void executePhaseOnShard(
         SearchActionListener<Result> listener
     );
 
-    protected void fork(final Runnable runnable) {
-        executor.execute(new AbstractRunnable() {
-            @Override
-            public void onFailure(Exception e) {
-                logger.error(() -> "unexpected error during [" + task + "]", e);
-                assert false : e;
-            }
-
-            @Override
-            public void onRejection(Exception e) {
-                // avoid leaks during node shutdown by executing on the current thread if the executor shuts down
-                assert e instanceof EsRejectedExecutionException esre && esre.isExecutorShutdown() : e;
-                doRun();
-            }
-
-            @Override
-            protected void doRun() {
-                runnable.run();
-            }
-
-            @Override
-            public boolean isForceExecution() {
-                // we can not allow a stuffed queue to reject execution here
-                return true;
-            }
-        });
-    }
-
     @Override
     public final void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPhase) {
         /* This is the main search phase transition where we move to the next phase. If all shards
@@ -794,61 +760,63 @@ protected final ShardSearchRequest buildShardSearchRequest(SearchShardIterator s
      */
     protected abstract SearchPhase getNextPhase(SearchPhaseResults<Result> results, SearchPhaseContext context);
 
-    private void executeNext(PendingExecutions pendingExecutions, Thread originalThread) {
-        executeNext(pendingExecutions == null ? null : pendingExecutions.finishAndGetNext(), originalThread);
-    }
-
-    void executeNext(Runnable runnable, Thread originalThread) {
-        if (runnable != null) {
-            assert throttleConcurrentRequests;
-            if (originalThread == Thread.currentThread()) {
-                fork(runnable);
-            } else {
-                runnable.run();
-            }
-        }
-    }
-
     private static final class PendingExecutions {
-        private final int permits;
-        private int permitsTaken = 0;
-        private final ArrayDeque<Runnable> queue = new ArrayDeque<>();
+        private final Semaphore semaphore;
+        private final LinkedTransferQueue<Consumer<Releasable>> queue = new LinkedTransferQueue<>();
 
         PendingExecutions(int permits) {
             assert permits > 0 : "not enough permits: " + permits;
-            this.permits = permits;
+            semaphore = new Semaphore(permits);
        }
 
-        Runnable finishAndGetNext() {
-            synchronized (this) {
-                permitsTaken--;
-                assert permitsTaken >= 0 : "illegal taken permits: " + permitsTaken;
+        void submit(Consumer<Releasable> task) {
+            if (semaphore.tryAcquire()) {
+                executeAndRelease(task);
+            } else {
+                queue.add(task);
+                if (semaphore.tryAcquire()) {
+                    task = pollNextTaskOrReleasePermit();
+                    if (task != null) {
+                        executeAndRelease(task);
+                    }
+                }
             }
-            return tryQueue(null);
+
         }
 
-        void tryRun(Runnable runnable) {
-            Runnable r = tryQueue(runnable);
-            if (r != null) {
-                r.run();
+        private void executeAndRelease(Consumer<Releasable> task) {
+            while (task != null) {
+                final SubscribableListener<Void> onDone = new SubscribableListener<>();
+                task.accept(() -> onDone.onResponse(null));
+                if (onDone.isDone()) {
+                    // keep going on the current thread, no need to fork
+                    task = pollNextTaskOrReleasePermit();
+                } else {
+                    onDone.addListener(new ActionListener<>() {
+                        @Override
+                        public void onResponse(Void unused) {
+                            final Consumer<Releasable> nextTask = pollNextTaskOrReleasePermit();
+                            if (nextTask != null) {
+                                executeAndRelease(nextTask);
+                            }
+                        }

+                        @Override
+                        public void onFailure(Exception e) {
+                            assert false : e;
+                        }
+                    });
+                    return;
+                }
             }
         }
 
-        private synchronized Runnable tryQueue(Runnable runnable) {
-            Runnable toExecute = null;
-            if (permitsTaken < permits) {
-                permitsTaken++;
-                toExecute = runnable;
-                if (toExecute == null) { // only poll if we don't have anything to execute
-                    toExecute = queue.poll();
-                }
-                if (toExecute == null) {
-                    permitsTaken--;
-                }
-            } else if (runnable != null) {
-                queue.add(runnable);
+        private Consumer<Releasable> pollNextTaskOrReleasePermit() {
+            var task = queue.poll();
+            if (task == null) {
+                semaphore.release();
             }
-            return toExecute;
+            return task;
        }
     }
 }
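
Two details of the new code are worth calling out. First, doPerformPhaseOnShard receives the permit as a Releasable and frees it with try (releasable) { ... }: Releasable is a Closeable whose close() throws no checked exception, so try-with-resources guarantees the release on both the response and the failure path without a finally block. Second, executeAndRelease avoids both forking and unbounded recursion: each task signals completion through a listener, and the caller checks isDone() immediately after invoking it. A synchronous completion keeps draining the queue in the while loop on the current thread; an asynchronous completion re-enters via the listener from whichever thread finished the work, so stack depth no longer grows with queue length. Below is a minimal stand-alone sketch of that trampoline, using CompletableFuture in place of Elasticsearch's SubscribableListener (all names here are illustrative, not from the commit, and the permit accounting is elided).

import java.util.Queue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.function.Consumer;

final class CompletionTrampoline {
    // Tasks accept a "done" callback instead of returning a value,
    // mirroring the Consumer<Releasable> tasks in the commit.
    private final Queue<Consumer<Runnable>> queue = new ConcurrentLinkedQueue<>();

    void drain(Consumer<Runnable> task) {
        while (task != null) {
            CompletableFuture<Void> onDone = new CompletableFuture<>();
            task.accept(() -> onDone.complete(null));
            if (onDone.isDone()) {
                // Completed on this thread: iterate instead of recursing.
                task = queue.poll();
            } else {
                // Will complete elsewhere: resume from the completing thread.
                onDone.thenRun(() -> {
                    Consumer<Runnable> next = queue.poll();
                    if (next != null) {
                        drain(next);
                    }
                });
                return;
            }
        }
    }
}

The recursive drain call inside the listener is safe for the same reason as in the commit: it happens at most once per asynchronous hop, on a fresh stack frame of the completing thread, while runs of synchronous completions are absorbed by the loop.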
