diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java index 846fb7a530283..39d8cce6c8e70 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java @@ -62,7 +62,7 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedDeque; -import java.util.function.Consumer; +import java.util.function.BiConsumer; import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ASSIGNMENT_TASK_ACTION; @@ -274,20 +274,27 @@ public void gracefullyStopDeploymentAndNotify( public void stopDeploymentAndNotify(TrainedModelDeploymentTask task, String reason, ActionListener listener) { logger.debug(() -> format("[%s] Forcefully stopping deployment due to reason %s", task.getDeploymentId(), reason)); - stopAndNotifyHelper(task, reason, listener, deploymentManager::stopDeployment); + stopAndNotifyHelper(task, reason, listener, (t, l) -> { + deploymentManager.stopDeployment(t); + l.onResponse(AcknowledgedResponse.TRUE); + }); } private void stopAndNotifyHelper( TrainedModelDeploymentTask task, String reason, ActionListener listener, - Consumer stopDeploymentFunc + BiConsumer> stopDeploymentFunc ) { // Removing the entry from the map to avoid the possibility of a node shutdown triggering a concurrent graceful stopping of the // process while we are attempting to forcefully stop the native process // The graceful stopping will only occur if there is an entry in the map deploymentIdToTask.remove(task.getDeploymentId()); - ActionListener notifyDeploymentOfStopped = updateRoutingStateToStoppedListener(task.getDeploymentId(), reason, listener); + ActionListener notifyDeploymentOfStopped = updateRoutingStateToStoppedListener( + task.getDeploymentId(), + reason, + listener + ); updateStoredState( task.getDeploymentId(), @@ -541,7 +548,7 @@ private void gracefullyStopDeployment(String deploymentId, String currentNode) { ) ); - ActionListener notifyDeploymentOfStopped = updateRoutingStateToStoppedListener( + ActionListener notifyDeploymentOfStopped = updateRoutingStateToStoppedListener( task.getDeploymentId(), NODE_IS_SHUTTING_DOWN, routingStateListener @@ -550,7 +557,7 @@ private void gracefullyStopDeployment(String deploymentId, String currentNode) { stopDeploymentAfterCompletingPendingWorkAsync(task, NODE_IS_SHUTTING_DOWN, notifyDeploymentOfStopped); } - private ActionListener updateRoutingStateToStoppedListener( + private ActionListener updateRoutingStateToStoppedListener( String deploymentId, String reason, ActionListener listener @@ -594,27 +601,30 @@ private void stopUnreferencedDeployment(String deploymentId, String currentNode) ); } - private void stopDeploymentAsync(TrainedModelDeploymentTask task, String reason, ActionListener listener) { - stopDeploymentHelper(task, reason, deploymentManager::stopDeployment, listener); + private void stopDeploymentAsync(TrainedModelDeploymentTask task, String reason, ActionListener listener) { + stopDeploymentHelper(task, reason, (t, l) -> { + deploymentManager.stopDeployment(t); + l.onResponse(AcknowledgedResponse.TRUE); + }, listener); } private void stopDeploymentHelper( TrainedModelDeploymentTask task, String reason, - Consumer stopDeploymentFunc, - ActionListener listener + BiConsumer> stopDeploymentFunc, + ActionListener listener ) { if (stopped) { + listener.onResponse(AcknowledgedResponse.FALSE); return; } task.markAsStopped(reason); threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> { try { - stopDeploymentFunc.accept(task); taskManager.unregister(task); deploymentIdToTask.remove(task.getDeploymentId()); - listener.onResponse(null); + stopDeploymentFunc.accept(task, listener); } catch (Exception e) { listener.onFailure(e); } @@ -624,7 +634,7 @@ private void stopDeploymentHelper( private void stopDeploymentAfterCompletingPendingWorkAsync( TrainedModelDeploymentTask task, String reason, - ActionListener listener + ActionListener listener ) { stopDeploymentHelper(task, reason, deploymentManager::stopAfterCompletingPendingWork, listener); } @@ -769,6 +779,7 @@ private void handleLoadSuccess(ActionListener retryListener, TrainedMod private void updateStoredState(String deploymentId, RoutingInfoUpdate update, ActionListener listener) { if (stopped) { + listener.onResponse(AcknowledgedResponse.FALSE); return; } trainedModelAssignmentService.updateModelAssignmentState( diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java index c6f1ebcc10780..124f78cfad41d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java @@ -16,6 +16,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.TransportSearchAction; +import org.elasticsearch.action.support.ListenerTimeouts; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.Strings; @@ -79,6 +80,7 @@ public class DeploymentManager { private static final Logger logger = LogManager.getLogger(DeploymentManager.class); private static final AtomicLong requestIdCounter = new AtomicLong(1); public static final int NUM_RESTART_ATTEMPTS = 3; + private static final TimeValue WORKER_QUEUE_COMPLETION_TIMEOUT = TimeValue.timeValueMinutes(5); private final Client client; private final NamedXContentRegistry xContentRegistry; @@ -331,7 +333,7 @@ public void stopDeployment(TrainedModelDeploymentTask task) { } } - public void stopAfterCompletingPendingWork(TrainedModelDeploymentTask task) { + public void stopAfterCompletingPendingWork(TrainedModelDeploymentTask task, ActionListener listener) { ProcessContext processContext = processContextByAllocation.remove(task.getId()); if (processContext != null) { logger.info( @@ -339,7 +341,7 @@ public void stopAfterCompletingPendingWork(TrainedModelDeploymentTask task) { task.getDeploymentId(), task.stoppedReason().orElse("unknown") ); - processContext.stopProcessAfterCompletingPendingWork(); + processContext.stopProcessAfterCompletingPendingWork(listener); } else { logger.warn("[{}] No process context to stop gracefully", task.getDeploymentId()); } @@ -569,7 +571,7 @@ private Consumer onProcessCrashHandleRestarts(AtomicInteger startsCount, processContextByAllocation.remove(task.getId()); isStopped = true; - resultProcessor.stop(); + resultProcessor.signalIntentToStop(); stateStreamer.cancel(); if (startsCount.get() <= NUM_RESTART_ATTEMPTS) { @@ -648,7 +650,7 @@ synchronized void forcefullyStopProcess() { private void prepareInternalStateForShutdown() { isStopped = true; - resultProcessor.stop(); + resultProcessor.signalIntentToStop(); stateStreamer.cancel(); } @@ -669,43 +671,46 @@ private void closeNlpTaskProcessor() { } } - private synchronized void stopProcessAfterCompletingPendingWork() { + private synchronized void stopProcessAfterCompletingPendingWork(ActionListener listener) { logger.debug(() -> format("[%s] Stopping process after completing its pending work", task.getDeploymentId())); prepareInternalStateForShutdown(); - signalAndWaitForWorkerTermination(); - stopProcessGracefully(); - closeNlpTaskProcessor(); - } - - private void signalAndWaitForWorkerTermination() { - try { - awaitTerminationAfterCompletingWork(); - } catch (TimeoutException e) { - logger.warn(format("[%s] Timed out waiting for process worker to complete, forcing a shutdown", task.getDeploymentId()), e); - // The process failed to stop in the time period allotted, so we'll mark it for shut down - priorityProcessWorker.shutdown(); - priorityProcessWorker.notifyQueueRunnables(); - } - } - private void awaitTerminationAfterCompletingWork() throws TimeoutException { - try { - priorityProcessWorker.shutdown(); + // Waiting for the process worker to finish the pending work could + // take a long time. To avoid blocking the calling thread register + // a function with the process worker queue that is called when the + // worker queue is finished. Then proceed to closing the native process + // and wait for all results to be processed, the second part can be + // done synchronously as it is not expected to take long. + + // This listener closes the native process and waits for the results + // after the worker queue has finished + var closeProcessListener = listener.delegateFailureAndWrap((l, r) -> { + // process worker stopped within allotted time, close process + closeProcessAndWaitForResultProcessor(); + closeNlpTaskProcessor(); + l.onResponse(AcknowledgedResponse.TRUE); + }); - if (priorityProcessWorker.awaitTermination(COMPLETION_TIMEOUT.getMinutes(), TimeUnit.MINUTES) == false) { - throw new TimeoutException( - Strings.format("Timed out waiting for process worker to complete for process %s", PROCESS_NAME) + // Timeout listener waits + var listenWithTimeout = ListenerTimeouts.wrapWithTimeout( + threadPool, + WORKER_QUEUE_COMPLETION_TIMEOUT, + threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME), + closeProcessListener, + (l) -> { + // Stopping the process worker timed out, kill the process + logger.warn( + format("[%s] Timed out waiting for process worker to complete, forcing a shutdown", task.getDeploymentId()) ); - } else { - priorityProcessWorker.notifyQueueRunnables(); + forcefullyStopProcess(); + l.onResponse(AcknowledgedResponse.FALSE); } - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - logger.info(Strings.format("[%s] Interrupted waiting for process worker to complete", PROCESS_NAME)); - } + ); + + priorityProcessWorker.shutdownWithCallback(() -> listenWithTimeout.onResponse(AcknowledgedResponse.TRUE)); } - private void stopProcessGracefully() { + private void closeProcessAndWaitForResultProcessor() { try { closeProcessIfPresent(); resultProcessor.awaitCompletion(COMPLETION_TIMEOUT.getMinutes(), TimeUnit.MINUTES); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java index 68389e6ca7165..dd622a1fe4ce6 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java @@ -313,7 +313,7 @@ public synchronized void updateStats(PyTorchResult result) { } } - public void stop() { + public void signalIntentToStop() { isStopping = true; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/AbstractProcessWorkerExecutorService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/AbstractProcessWorkerExecutorService.java index debe6586e453e..66a39bde0fe6a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/AbstractProcessWorkerExecutorService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/AbstractProcessWorkerExecutorService.java @@ -44,6 +44,7 @@ public abstract class AbstractProcessWorkerExecutorService e private final AtomicReference error = new AtomicReference<>(); private final AtomicBoolean running = new AtomicBoolean(true); private final AtomicBoolean shouldShutdownAfterCompletingWork = new AtomicBoolean(false); + private final AtomicReference onCompletion = new AtomicReference<>(); /** * @param contextHolder the thread context holder @@ -78,6 +79,11 @@ public void shutdown() { shouldShutdownAfterCompletingWork.set(true); } + public void shutdownWithCallback(Runnable onCompletion) { + this.onCompletion.set(onCompletion); + shutdown(); + } + /** * Some of the tasks in the returned list of {@link Runnable}s could have run. Some tasks may have run while the queue was being copied. * @@ -124,6 +130,10 @@ public void start() { } catch (InterruptedException e) { Thread.currentThread().interrupt(); } finally { + Runnable onComplete = onCompletion.get(); + if (onComplete != null) { + onComplete.run(); + } awaitTermination.countDown(); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java index 7aa4faa6459dc..af15d9d1c6acc 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java @@ -265,7 +265,11 @@ public void testLoadQueuedModelsWhenTaskIsStopped() throws Exception { UpdateTrainedModelAssignmentRoutingInfoAction.Request.class ); verify(deploymentManager, times(1)).startDeployment(startTaskCapture.capture(), any()); - assertBusy(() -> verify(trainedModelAssignmentService, times(3)).updateModelAssignmentState(requestCapture.capture(), any())); + assertBusy( + () -> verify(trainedModelAssignmentService, times(3)).updateModelAssignmentState(requestCapture.capture(), any()), + 3, + TimeUnit.SECONDS + ); boolean seenStopping = false; for (int i = 0; i < 3; i++) { @@ -397,6 +401,13 @@ public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNode_CallsSto return null; }).when(trainedModelAssignmentService).updateModelAssignmentState(any(), any()); + doAnswer(invocationOnMock -> { + @SuppressWarnings({ "unchecked", "rawtypes" }) + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; + listener.onResponse(AcknowledgedResponse.TRUE); + return null; + }).when(deploymentManager).stopAfterCompletingPendingWork(any(), any()); + var taskParams = newParams(deploymentOne, modelOne); ClusterChangedEvent event = new ClusterChangedEvent( @@ -430,7 +441,7 @@ public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNode_CallsSto } assertBusy(() -> { - verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(stopParamsCapture.capture()); + verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(stopParamsCapture.capture(), any()); assertThat(stopParamsCapture.getValue().getModelId(), equalTo(modelOne)); assertThat(stopParamsCapture.getValue().getDeploymentId(), equalTo(deploymentOne)); }); @@ -481,7 +492,7 @@ public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNode_ButOther trainedModelAssignmentNodeService.prepareModelToLoad(taskParams); trainedModelAssignmentNodeService.clusterChanged(event); - verify(deploymentManager, never()).stopAfterCompletingPendingWork(any()); + verify(deploymentManager, never()).stopAfterCompletingPendingWork(any(), any()); verify(trainedModelAssignmentService, never()).updateModelAssignmentState( any(UpdateTrainedModelAssignmentRoutingInfoAction.Request.class), any() @@ -522,7 +533,7 @@ public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNodeButAlread trainedModelAssignmentNodeService.clusterChanged(event); - verify(deploymentManager, never()).stopAfterCompletingPendingWork(any()); + verify(deploymentManager, never()).stopAfterCompletingPendingWork(any(), any()); verify(trainedModelAssignmentService, never()).updateModelAssignmentState( any(UpdateTrainedModelAssignmentRoutingInfoAction.Request.class), any() @@ -564,7 +575,7 @@ public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNodeWithStart trainedModelAssignmentNodeService.prepareModelToLoad(taskParams); trainedModelAssignmentNodeService.clusterChanged(event); - verify(deploymentManager, never()).stopAfterCompletingPendingWork(any()); + verify(deploymentManager, never()).stopAfterCompletingPendingWork(any(), any()); verify(trainedModelAssignmentService, never()).updateModelAssignmentState( any(UpdateTrainedModelAssignmentRoutingInfoAction.Request.class), any() @@ -601,7 +612,7 @@ public void testClusterChanged_WhenNodeDoesNotExistInAssignmentRoutingTable_Does trainedModelAssignmentNodeService.prepareModelToLoad(taskParams); trainedModelAssignmentNodeService.clusterChanged(event); - assertBusy(() -> verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(any())); + assertBusy(() -> verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(any(), any())); // This still shouldn't trigger a cluster state update because the routing entry wasn't in the table so we won't add a new routing // entry for stopping verify(trainedModelAssignmentService, never()).updateModelAssignmentState( @@ -765,7 +776,7 @@ public void testClusterChanged() throws Exception { ArgumentCaptor stoppedTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class); // deployment-2 was originally started on node NODE_ID but in the latest cluster event it is no longer on that node so we will // gracefully stop it - verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(stoppedTaskCapture.capture()); + verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(stoppedTaskCapture.capture(), any()); assertThat(stoppedTaskCapture.getAllValues().get(0).getDeploymentId(), equalTo(deploymentTwo)); }); ArgumentCaptor startTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class);