-
Notifications
You must be signed in to change notification settings - Fork 25.6k
Limit concurrent node requests #122850
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Limit concurrent node requests #122850
Changes from 15 commits
eec2039
d4850fc
ae96763
15d897f
666f588
34badf7
8c55e8a
b2a66a2
11bd0f0
53f7d60
1bf2715
3296cf0
f0cc4ec
5ce4095
a5084ab
50da087
6809e58
1458611
02ddb1e
4058ef1
e8ad22a
a70aff4
ed54538
7f47bde
f6868a1
f8a54f6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -98,15 +98,25 @@ void startComputeOnDataNodes( | |||
| Runnable runOnTaskFailure, | ||||
| ActionListener<ComputeResponse> outListener | ||||
| ) { | ||||
| final boolean allowPartialResults = configuration.allowPartialResults(); | ||||
| DataNodeRequestSender sender = new DataNodeRequestSender(transportService, esqlExecutor, parentTask, allowPartialResults) { | ||||
| new DataNodeRequestSender( | ||||
| transportService, | ||||
| esqlExecutor, | ||||
| parentTask, | ||||
| configuration.allowPartialResults(), | ||||
| configuration.pragmas().maxConcurrentNodesPerCluster() | ||||
| ) { | ||||
| @Override | ||||
| protected void sendRequest( | ||||
| DiscoveryNode node, | ||||
| List<ShardId> shardIds, | ||||
| Map<Index, AliasFilter> aliasFilters, | ||||
| NodeListener nodeListener | ||||
| ) { | ||||
| if (exchangeSource.isFinished()) { | ||||
| nodeListener.onSkip(true); | ||||
| return; | ||||
| } | ||||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This part prevents us from sending a query to the remaining data nodes if we have collected enough results.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nit: There's one thing here: We'll "skip" it with onSkip(), but the Sender will still continue processing all shards. From what I see, it will continue calling this after every node finishes. Should we instead pass something to the sender so it stops calling sendRequest()? I don't think it matters, computationally speaking, but it feels like we're doing "too much" when we could short-circuit instead (?)
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It looks like we won't send more requests because we do: So we'll only count the number of shards we skip and that's it. I think.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Correct, this was added today in: a5084ab
The total skipped count consists of ones we skipped already ( Line 96 in 50da087
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Since these computations should not be expensive, I wonder if we should skip only here, rather than short-circuiting in other places. The reason is that we might need to be more careful not to short-circuit in other places when allow_partial_results=true.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sounds good. It would also simplify the change. We can always add it back later if we see it is needed. |
||||
|
|
||||
| final AtomicLong pagesFetched = new AtomicLong(); | ||||
| var listener = ActionListener.wrap(nodeListener::onResponse, e -> nodeListener.onFailure(e, pagesFetched.get() > 0)); | ||||
| final Transport.Connection connection; | ||||
|
|
@@ -129,7 +139,7 @@ protected void sendRequest( | |||
| listener.delegateFailureAndWrap((l, unused) -> { | ||||
| final Runnable onGroupFailure; | ||||
| final CancellableTask groupTask; | ||||
| if (allowPartialResults) { | ||||
| if (configuration.allowPartialResults()) { | ||||
| try { | ||||
| groupTask = computeService.createGroupTask( | ||||
| parentTask, | ||||
|
|
@@ -152,7 +162,7 @@ protected void sendRequest( | |||
| final var remoteSink = exchangeService.newRemoteSink(groupTask, childSessionId, transportService, connection); | ||||
| exchangeSource.addRemoteSink( | ||||
| remoteSink, | ||||
| allowPartialResults == false, | ||||
| configuration.allowPartialResults() == false, | ||||
| pagesFetched::incrementAndGet, | ||||
| queryPragmas.concurrentExchangeClients(), | ||||
| computeListener.acquireAvoid() | ||||
|
|
@@ -184,8 +194,7 @@ protected void sendRequest( | |||
| }) | ||||
| ); | ||||
| } | ||||
| }; | ||||
| sender.startComputeOnDataNodes( | ||||
| }.startComputeOnDataNodes( | ||||
| clusterAlias, | ||||
| concreteIndices, | ||||
| originalIndices, | ||||
|
|
||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -57,18 +57,27 @@ abstract class DataNodeRequestSender { | |
| private final Executor esqlExecutor; | ||
| private final CancellableTask rootTask; | ||
| private final boolean allowPartialResults; | ||
| private final Semaphore concurrentRequests; | ||
| private final ReentrantLock sendingLock = new ReentrantLock(); | ||
| private final Queue<ShardId> pendingShardIds = ConcurrentCollections.newQueue(); | ||
| private final Map<DiscoveryNode, Semaphore> nodePermits = new HashMap<>(); | ||
| private final Map<ShardId, ShardFailure> shardFailures = ConcurrentCollections.newConcurrentMap(); | ||
| private final AtomicBoolean changed = new AtomicBoolean(); | ||
| private boolean reportedFailure = false; // guarded by sendingLock | ||
| private volatile boolean skipRemaining = false; | ||
|
|
||
| DataNodeRequestSender(TransportService transportService, Executor esqlExecutor, CancellableTask rootTask, boolean allowPartialResults) { | ||
| DataNodeRequestSender( | ||
| TransportService transportService, | ||
| Executor esqlExecutor, | ||
| CancellableTask rootTask, | ||
| boolean allowPartialResults, | ||
| int concurrentRequests | ||
| ) { | ||
| this.transportService = transportService; | ||
| this.esqlExecutor = esqlExecutor; | ||
| this.rootTask = rootTask; | ||
| this.allowPartialResults = allowPartialResults; | ||
| this.concurrentRequests = concurrentRequests > 0 ? new Semaphore(concurrentRequests) : null; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should we initialize the Semaphore for the -1 case with …
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I made a special case not to keep it at all if we have no limit (most of the cases). But I can do that as well. |
||
| } | ||
|
|
||
| final void startComputeOnDataNodes( | ||
|
|
@@ -126,7 +135,7 @@ private void trySendingRequestsForPendingShards(TargetShards targetShards, Compu | |
| || (allowPartialResults == false && shardFailures.values().stream().anyMatch(shardFailure -> shardFailure.fatal))) { | ||
| reportedFailure = true; | ||
| reportFailures(computeListener); | ||
| } else { | ||
| } else if (skipRemaining == false) { | ||
| for (NodeRequest request : selectNodeRequests(targetShards)) { | ||
| sendOneNodeRequest(targetShards, computeListener, request); | ||
| } | ||
|
|
@@ -159,6 +168,9 @@ private void sendOneNodeRequest(TargetShards targetShards, ComputeListener compu | |
| sendRequest(request.node, request.shardIds, request.aliasFilters, new NodeListener() { | ||
| void onAfter(List<DriverProfile> profiles) { | ||
| nodePermits.get(request.node).release(); | ||
| if (concurrentRequests != null) { | ||
| concurrentRequests.release(); | ||
| } | ||
| trySendingRequestsForPendingShards(targetShards, computeListener); | ||
| listener.onResponse(profiles); | ||
| } | ||
|
|
@@ -187,6 +199,14 @@ public void onFailure(Exception e, boolean receivedData) { | |
| } | ||
| onAfter(List.of()); | ||
| } | ||
|
|
||
| @Override | ||
| public void onSkip(boolean skipRemaining) { | ||
| if (skipRemaining) { | ||
| DataNodeRequestSender.this.skipRemaining = true; | ||
| } | ||
| onAfter(List.of()); | ||
|
||
| } | ||
| }); | ||
| } | ||
|
|
||
|
|
@@ -196,6 +216,8 @@ interface NodeListener { | |
| void onResponse(DataNodeComputeResponse response); | ||
|
|
||
| void onFailure(Exception e, boolean receivedData); | ||
|
|
||
| void onSkip(boolean skipRemaining); | ||
| } | ||
|
|
||
| private static Exception unwrapFailure(Exception e) { | ||
|
|
@@ -256,6 +278,7 @@ private List<NodeRequest> selectNodeRequests(TargetShards targetShards) { | |
| assert sendingLock.isHeldByCurrentThread(); | ||
| final Map<DiscoveryNode, List<ShardId>> nodeToShardIds = new HashMap<>(); | ||
| final Iterator<ShardId> shardsIt = pendingShardIds.iterator(); | ||
|
|
||
| while (shardsIt.hasNext()) { | ||
| ShardId shardId = shardsIt.next(); | ||
| ShardFailure failure = shardFailures.get(shardId); | ||
|
|
@@ -265,31 +288,45 @@ private List<NodeRequest> selectNodeRequests(TargetShards targetShards) { | |
| } | ||
| TargetShard shard = targetShards.getShard(shardId); | ||
| Iterator<DiscoveryNode> nodesIt = shard.remainingNodes.iterator(); | ||
| DiscoveryNode selectedNode = null; | ||
| while (nodesIt.hasNext()) { | ||
| DiscoveryNode node = nodesIt.next(); | ||
| if (nodeToShardIds.containsKey(node) || nodePermits.get(node).tryAcquire()) { | ||
| List<ShardId> pendingRequest = nodeToShardIds.get(node); | ||
| if (pendingRequest != null) { | ||
| pendingRequest.add(shard.shardId); | ||
| nodesIt.remove(); | ||
| shardsIt.remove(); | ||
| selectedNode = node; | ||
| break; | ||
| } | ||
| } | ||
| if (selectedNode != null) { | ||
| nodeToShardIds.computeIfAbsent(selectedNode, unused -> new ArrayList<>()).add(shard.shardId); | ||
|
|
||
| if (concurrentRequests == null || concurrentRequests.tryAcquire()) { | ||
| if (nodePermits.get(node).tryAcquire()) { | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I wonder if instead we want to check all pending requests here before attempting a new one? |
||
| pendingRequest = new ArrayList<>(); | ||
| pendingRequest.add(shard.shardId); | ||
| nodeToShardIds.put(node, pendingRequest); | ||
|
|
||
| nodesIt.remove(); | ||
| shardsIt.remove(); | ||
|
|
||
| break; | ||
| } else if (concurrentRequests != null) { | ||
| concurrentRequests.release(); | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| final List<NodeRequest> nodeRequests = new ArrayList<>(nodeToShardIds.size()); | ||
| for (var e : nodeToShardIds.entrySet()) { | ||
| List<ShardId> shardIds = e.getValue(); | ||
| for (var entry : nodeToShardIds.entrySet()) { | ||
| var node = entry.getKey(); | ||
| var shardIds = entry.getValue(); | ||
| Map<Index, AliasFilter> aliasFilters = new HashMap<>(); | ||
| for (ShardId shardId : shardIds) { | ||
| var aliasFilter = targetShards.getShard(shardId).aliasFilter; | ||
| if (aliasFilter != null) { | ||
| aliasFilters.put(shardId.getIndex(), aliasFilter); | ||
| } | ||
| } | ||
| nodeRequests.add(new NodeRequest(e.getKey(), shardIds, aliasFilters)); | ||
| nodeRequests.add(new NodeRequest(node, shardIds, aliasFilters)); | ||
| } | ||
| return nodeRequests; | ||
| } | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.