Skip to content

Commit 6d2c3ef

Browse files
authored
[ML] Don't block thread while waiting for work to finish on graceful shutdown (#135350)
A model deployment that is gracefully shut down will wait until the queued-up work is done (or times out) before terminating the inference process. This change avoids blocking a thread while waiting for that work to finish.
1 parent 73f40a4 commit 6d2c3ef

File tree

5 files changed

+91
-54
lines changed

5 files changed

+91
-54
lines changed

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
import java.util.Set;
6363
import java.util.concurrent.ConcurrentHashMap;
6464
import java.util.concurrent.ConcurrentLinkedDeque;
65-
import java.util.function.Consumer;
65+
import java.util.function.BiConsumer;
6666

6767
import static org.elasticsearch.core.Strings.format;
6868
import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ASSIGNMENT_TASK_ACTION;
@@ -274,20 +274,27 @@ public void gracefullyStopDeploymentAndNotify(
274274
public void stopDeploymentAndNotify(TrainedModelDeploymentTask task, String reason, ActionListener<AcknowledgedResponse> listener) {
275275
logger.debug(() -> format("[%s] Forcefully stopping deployment due to reason %s", task.getDeploymentId(), reason));
276276

277-
stopAndNotifyHelper(task, reason, listener, deploymentManager::stopDeployment);
277+
stopAndNotifyHelper(task, reason, listener, (t, l) -> {
278+
deploymentManager.stopDeployment(t);
279+
l.onResponse(AcknowledgedResponse.TRUE);
280+
});
278281
}
279282

280283
private void stopAndNotifyHelper(
281284
TrainedModelDeploymentTask task,
282285
String reason,
283286
ActionListener<AcknowledgedResponse> listener,
284-
Consumer<TrainedModelDeploymentTask> stopDeploymentFunc
287+
BiConsumer<TrainedModelDeploymentTask, ActionListener<AcknowledgedResponse>> stopDeploymentFunc
285288
) {
286289
// Removing the entry from the map to avoid the possibility of a node shutdown triggering a concurrent graceful stopping of the
287290
// process while we are attempting to forcefully stop the native process
288291
// The graceful stopping will only occur if there is an entry in the map
289292
deploymentIdToTask.remove(task.getDeploymentId());
290-
ActionListener<Void> notifyDeploymentOfStopped = updateRoutingStateToStoppedListener(task.getDeploymentId(), reason, listener);
293+
ActionListener<AcknowledgedResponse> notifyDeploymentOfStopped = updateRoutingStateToStoppedListener(
294+
task.getDeploymentId(),
295+
reason,
296+
listener
297+
);
291298

292299
updateStoredState(
293300
task.getDeploymentId(),
@@ -541,7 +548,7 @@ private void gracefullyStopDeployment(String deploymentId, String currentNode) {
541548
)
542549
);
543550

544-
ActionListener<Void> notifyDeploymentOfStopped = updateRoutingStateToStoppedListener(
551+
ActionListener<AcknowledgedResponse> notifyDeploymentOfStopped = updateRoutingStateToStoppedListener(
545552
task.getDeploymentId(),
546553
NODE_IS_SHUTTING_DOWN,
547554
routingStateListener
@@ -550,7 +557,7 @@ private void gracefullyStopDeployment(String deploymentId, String currentNode) {
550557
stopDeploymentAfterCompletingPendingWorkAsync(task, NODE_IS_SHUTTING_DOWN, notifyDeploymentOfStopped);
551558
}
552559

553-
private ActionListener<Void> updateRoutingStateToStoppedListener(
560+
private ActionListener<AcknowledgedResponse> updateRoutingStateToStoppedListener(
554561
String deploymentId,
555562
String reason,
556563
ActionListener<AcknowledgedResponse> listener
@@ -594,27 +601,30 @@ private void stopUnreferencedDeployment(String deploymentId, String currentNode)
594601
);
595602
}
596603

597-
private void stopDeploymentAsync(TrainedModelDeploymentTask task, String reason, ActionListener<Void> listener) {
598-
stopDeploymentHelper(task, reason, deploymentManager::stopDeployment, listener);
604+
private void stopDeploymentAsync(TrainedModelDeploymentTask task, String reason, ActionListener<AcknowledgedResponse> listener) {
605+
stopDeploymentHelper(task, reason, (t, l) -> {
606+
deploymentManager.stopDeployment(t);
607+
l.onResponse(AcknowledgedResponse.TRUE);
608+
}, listener);
599609
}
600610

601611
private void stopDeploymentHelper(
602612
TrainedModelDeploymentTask task,
603613
String reason,
604-
Consumer<TrainedModelDeploymentTask> stopDeploymentFunc,
605-
ActionListener<Void> listener
614+
BiConsumer<TrainedModelDeploymentTask, ActionListener<AcknowledgedResponse>> stopDeploymentFunc,
615+
ActionListener<AcknowledgedResponse> listener
606616
) {
607617
if (stopped) {
618+
listener.onResponse(AcknowledgedResponse.FALSE);
608619
return;
609620
}
610621
task.markAsStopped(reason);
611622

612623
threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> {
613624
try {
614-
stopDeploymentFunc.accept(task);
615625
taskManager.unregister(task);
616626
deploymentIdToTask.remove(task.getDeploymentId());
617-
listener.onResponse(null);
627+
stopDeploymentFunc.accept(task, listener);
618628
} catch (Exception e) {
619629
listener.onFailure(e);
620630
}
@@ -624,7 +634,7 @@ private void stopDeploymentHelper(
624634
private void stopDeploymentAfterCompletingPendingWorkAsync(
625635
TrainedModelDeploymentTask task,
626636
String reason,
627-
ActionListener<Void> listener
637+
ActionListener<AcknowledgedResponse> listener
628638
) {
629639
stopDeploymentHelper(task, reason, deploymentManager::stopAfterCompletingPendingWork, listener);
630640
}
@@ -769,6 +779,7 @@ private void handleLoadSuccess(ActionListener<Boolean> retryListener, TrainedMod
769779

770780
private void updateStoredState(String deploymentId, RoutingInfoUpdate update, ActionListener<AcknowledgedResponse> listener) {
771781
if (stopped) {
782+
listener.onResponse(AcknowledgedResponse.FALSE);
772783
return;
773784
}
774785
trainedModelAssignmentService.updateModelAssignmentState(

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

Lines changed: 38 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.elasticsearch.action.ActionListener;
1717
import org.elasticsearch.action.search.SearchRequest;
1818
import org.elasticsearch.action.search.TransportSearchAction;
19+
import org.elasticsearch.action.support.ListenerTimeouts;
1920
import org.elasticsearch.action.support.master.AcknowledgedResponse;
2021
import org.elasticsearch.client.internal.Client;
2122
import org.elasticsearch.common.Strings;
@@ -79,6 +80,7 @@ public class DeploymentManager {
7980
private static final Logger logger = LogManager.getLogger(DeploymentManager.class);
8081
private static final AtomicLong requestIdCounter = new AtomicLong(1);
8182
public static final int NUM_RESTART_ATTEMPTS = 3;
83+
private static final TimeValue WORKER_QUEUE_COMPLETION_TIMEOUT = TimeValue.timeValueMinutes(5);
8284

8385
private final Client client;
8486
private final NamedXContentRegistry xContentRegistry;
@@ -331,15 +333,15 @@ public void stopDeployment(TrainedModelDeploymentTask task) {
331333
}
332334
}
333335

334-
public void stopAfterCompletingPendingWork(TrainedModelDeploymentTask task) {
336+
public void stopAfterCompletingPendingWork(TrainedModelDeploymentTask task, ActionListener<AcknowledgedResponse> listener) {
335337
ProcessContext processContext = processContextByAllocation.remove(task.getId());
336338
if (processContext != null) {
337339
logger.info(
338340
"[{}] Stopping deployment after completing pending tasks, reason [{}]",
339341
task.getDeploymentId(),
340342
task.stoppedReason().orElse("unknown")
341343
);
342-
processContext.stopProcessAfterCompletingPendingWork();
344+
processContext.stopProcessAfterCompletingPendingWork(listener);
343345
} else {
344346
logger.warn("[{}] No process context to stop gracefully", task.getDeploymentId());
345347
}
@@ -569,7 +571,7 @@ private Consumer<String> onProcessCrashHandleRestarts(AtomicInteger startsCount,
569571

570572
processContextByAllocation.remove(task.getId());
571573
isStopped = true;
572-
resultProcessor.stop();
574+
resultProcessor.signalIntentToStop();
573575
stateStreamer.cancel();
574576

575577
if (startsCount.get() <= NUM_RESTART_ATTEMPTS) {
@@ -648,7 +650,7 @@ synchronized void forcefullyStopProcess() {
648650

649651
private void prepareInternalStateForShutdown() {
650652
isStopped = true;
651-
resultProcessor.stop();
653+
resultProcessor.signalIntentToStop();
652654
stateStreamer.cancel();
653655
}
654656

@@ -669,43 +671,46 @@ private void closeNlpTaskProcessor() {
669671
}
670672
}
671673

672-
private synchronized void stopProcessAfterCompletingPendingWork() {
674+
private synchronized void stopProcessAfterCompletingPendingWork(ActionListener<AcknowledgedResponse> listener) {
673675
logger.debug(() -> format("[%s] Stopping process after completing its pending work", task.getDeploymentId()));
674676
prepareInternalStateForShutdown();
675-
signalAndWaitForWorkerTermination();
676-
stopProcessGracefully();
677-
closeNlpTaskProcessor();
678-
}
679-
680-
private void signalAndWaitForWorkerTermination() {
681-
try {
682-
awaitTerminationAfterCompletingWork();
683-
} catch (TimeoutException e) {
684-
logger.warn(format("[%s] Timed out waiting for process worker to complete, forcing a shutdown", task.getDeploymentId()), e);
685-
// The process failed to stop in the time period allotted, so we'll mark it for shut down
686-
priorityProcessWorker.shutdown();
687-
priorityProcessWorker.notifyQueueRunnables();
688-
}
689-
}
690677

691-
private void awaitTerminationAfterCompletingWork() throws TimeoutException {
692-
try {
693-
priorityProcessWorker.shutdown();
678+
// Waiting for the process worker to finish the pending work could
679+
// take a long time. To avoid blocking the calling thread register
680+
// a function with the process worker queue that is called when the
681+
// worker queue is finished. Then proceed to closing the native process
682+
// and wait for all results to be processed, the second part can be
683+
// done synchronously as it is not expected to take long.
684+
685+
// This listener closes the native process and waits for the results
686+
// after the worker queue has finished
687+
var closeProcessListener = listener.delegateFailureAndWrap((l, r) -> {
688+
// process worker stopped within allotted time, close process
689+
closeProcessAndWaitForResultProcessor();
690+
closeNlpTaskProcessor();
691+
l.onResponse(AcknowledgedResponse.TRUE);
692+
});
694693

695-
if (priorityProcessWorker.awaitTermination(COMPLETION_TIMEOUT.getMinutes(), TimeUnit.MINUTES) == false) {
696-
throw new TimeoutException(
697-
Strings.format("Timed out waiting for process worker to complete for process %s", PROCESS_NAME)
694+
// Timeout listener waits
695+
var listenWithTimeout = ListenerTimeouts.wrapWithTimeout(
696+
threadPool,
697+
WORKER_QUEUE_COMPLETION_TIMEOUT,
698+
threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME),
699+
closeProcessListener,
700+
(l) -> {
701+
// Stopping the process worker timed out, kill the process
702+
logger.warn(
703+
format("[%s] Timed out waiting for process worker to complete, forcing a shutdown", task.getDeploymentId())
698704
);
699-
} else {
700-
priorityProcessWorker.notifyQueueRunnables();
705+
forcefullyStopProcess();
706+
l.onResponse(AcknowledgedResponse.FALSE);
701707
}
702-
} catch (InterruptedException e) {
703-
Thread.currentThread().interrupt();
704-
logger.info(Strings.format("[%s] Interrupted waiting for process worker to complete", PROCESS_NAME));
705-
}
708+
);
709+
710+
priorityProcessWorker.shutdownWithCallback(() -> listenWithTimeout.onResponse(AcknowledgedResponse.TRUE));
706711
}
707712

708-
private void stopProcessGracefully() {
713+
private void closeProcessAndWaitForResultProcessor() {
709714
try {
710715
closeProcessIfPresent();
711716
resultProcessor.awaitCompletion(COMPLETION_TIMEOUT.getMinutes(), TimeUnit.MINUTES);

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ public synchronized void updateStats(PyTorchResult result) {
313313
}
314314
}
315315

316-
public void stop() {
316+
public void signalIntentToStop() {
317317
isStopping = true;
318318
}
319319

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/AbstractProcessWorkerExecutorService.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ public abstract class AbstractProcessWorkerExecutorService<T extends Runnable> e
4444
private final AtomicReference<Exception> error = new AtomicReference<>();
4545
private final AtomicBoolean running = new AtomicBoolean(true);
4646
private final AtomicBoolean shouldShutdownAfterCompletingWork = new AtomicBoolean(false);
47+
private final AtomicReference<Runnable> onCompletion = new AtomicReference<>();
4748

4849
/**
4950
* @param contextHolder the thread context holder
@@ -78,6 +79,11 @@ public void shutdown() {
7879
shouldShutdownAfterCompletingWork.set(true);
7980
}
8081

82+
public void shutdownWithCallback(Runnable onCompletion) {
83+
this.onCompletion.set(onCompletion);
84+
shutdown();
85+
}
86+
8187
/**
8288
* Some of the tasks in the returned list of {@link Runnable}s could have run. Some tasks may have run while the queue was being copied.
8389
*
@@ -124,6 +130,10 @@ public void start() {
124130
} catch (InterruptedException e) {
125131
Thread.currentThread().interrupt();
126132
} finally {
133+
Runnable onComplete = onCompletion.get();
134+
if (onComplete != null) {
135+
onComplete.run();
136+
}
127137
awaitTermination.countDown();
128138
}
129139
}

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,11 @@ public void testLoadQueuedModelsWhenTaskIsStopped() throws Exception {
265265
UpdateTrainedModelAssignmentRoutingInfoAction.Request.class
266266
);
267267
verify(deploymentManager, times(1)).startDeployment(startTaskCapture.capture(), any());
268-
assertBusy(() -> verify(trainedModelAssignmentService, times(3)).updateModelAssignmentState(requestCapture.capture(), any()));
268+
assertBusy(
269+
() -> verify(trainedModelAssignmentService, times(3)).updateModelAssignmentState(requestCapture.capture(), any()),
270+
3,
271+
TimeUnit.SECONDS
272+
);
269273

270274
boolean seenStopping = false;
271275
for (int i = 0; i < 3; i++) {
@@ -397,6 +401,13 @@ public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNode_CallsSto
397401
return null;
398402
}).when(trainedModelAssignmentService).updateModelAssignmentState(any(), any());
399403

404+
doAnswer(invocationOnMock -> {
405+
@SuppressWarnings({ "unchecked", "rawtypes" })
406+
ActionListener<AcknowledgedResponse> listener = (ActionListener) invocationOnMock.getArguments()[1];
407+
listener.onResponse(AcknowledgedResponse.TRUE);
408+
return null;
409+
}).when(deploymentManager).stopAfterCompletingPendingWork(any(), any());
410+
400411
var taskParams = newParams(deploymentOne, modelOne);
401412

402413
ClusterChangedEvent event = new ClusterChangedEvent(
@@ -430,7 +441,7 @@ public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNode_CallsSto
430441
}
431442

432443
assertBusy(() -> {
433-
verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(stopParamsCapture.capture());
444+
verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(stopParamsCapture.capture(), any());
434445
assertThat(stopParamsCapture.getValue().getModelId(), equalTo(modelOne));
435446
assertThat(stopParamsCapture.getValue().getDeploymentId(), equalTo(deploymentOne));
436447
});
@@ -481,7 +492,7 @@ public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNode_ButOther
481492
trainedModelAssignmentNodeService.prepareModelToLoad(taskParams);
482493
trainedModelAssignmentNodeService.clusterChanged(event);
483494

484-
verify(deploymentManager, never()).stopAfterCompletingPendingWork(any());
495+
verify(deploymentManager, never()).stopAfterCompletingPendingWork(any(), any());
485496
verify(trainedModelAssignmentService, never()).updateModelAssignmentState(
486497
any(UpdateTrainedModelAssignmentRoutingInfoAction.Request.class),
487498
any()
@@ -522,7 +533,7 @@ public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNodeButAlread
522533

523534
trainedModelAssignmentNodeService.clusterChanged(event);
524535

525-
verify(deploymentManager, never()).stopAfterCompletingPendingWork(any());
536+
verify(deploymentManager, never()).stopAfterCompletingPendingWork(any(), any());
526537
verify(trainedModelAssignmentService, never()).updateModelAssignmentState(
527538
any(UpdateTrainedModelAssignmentRoutingInfoAction.Request.class),
528539
any()
@@ -564,7 +575,7 @@ public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNodeWithStart
564575
trainedModelAssignmentNodeService.prepareModelToLoad(taskParams);
565576
trainedModelAssignmentNodeService.clusterChanged(event);
566577

567-
verify(deploymentManager, never()).stopAfterCompletingPendingWork(any());
578+
verify(deploymentManager, never()).stopAfterCompletingPendingWork(any(), any());
568579
verify(trainedModelAssignmentService, never()).updateModelAssignmentState(
569580
any(UpdateTrainedModelAssignmentRoutingInfoAction.Request.class),
570581
any()
@@ -601,7 +612,7 @@ public void testClusterChanged_WhenNodeDoesNotExistInAssignmentRoutingTable_Does
601612
trainedModelAssignmentNodeService.prepareModelToLoad(taskParams);
602613
trainedModelAssignmentNodeService.clusterChanged(event);
603614

604-
assertBusy(() -> verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(any()));
615+
assertBusy(() -> verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(any(), any()));
605616
// This still shouldn't trigger a cluster state update because the routing entry wasn't in the table so we won't add a new routing
606617
// entry for stopping
607618
verify(trainedModelAssignmentService, never()).updateModelAssignmentState(
@@ -765,7 +776,7 @@ public void testClusterChanged() throws Exception {
765776
ArgumentCaptor<TrainedModelDeploymentTask> stoppedTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class);
766777
// deployment-2 was originally started on node NODE_ID but in the latest cluster event it is no longer on that node so we will
767778
// gracefully stop it
768-
verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(stoppedTaskCapture.capture());
779+
verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(stoppedTaskCapture.capture(), any());
769780
assertThat(stoppedTaskCapture.getAllValues().get(0).getDeploymentId(), equalTo(deploymentTwo));
770781
});
771782
ArgumentCaptor<TrainedModelDeploymentTask> startTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class);

0 commit comments

Comments
 (0)