@@ -62,7 +62,7 @@
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedDeque;
-import java.util.function.Consumer;
+import java.util.function.BiConsumer;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ASSIGNMENT_TASK_ACTION;
@@ -274,20 +274,27 @@ public void gracefullyStopDeploymentAndNotify(
public void stopDeploymentAndNotify(TrainedModelDeploymentTask task, String reason, ActionListener<AcknowledgedResponse> listener) {
logger.debug(() -> format("[%s] Forcefully stopping deployment due to reason %s", task.getDeploymentId(), reason));

-        stopAndNotifyHelper(task, reason, listener, deploymentManager::stopDeployment);
+        stopAndNotifyHelper(task, reason, listener, (t, l) -> {
+            deploymentManager.stopDeployment(t);
+            l.onResponse(AcknowledgedResponse.TRUE);
+        });
}

private void stopAndNotifyHelper(
TrainedModelDeploymentTask task,
String reason,
ActionListener<AcknowledgedResponse> listener,
-        Consumer<TrainedModelDeploymentTask> stopDeploymentFunc
+        BiConsumer<TrainedModelDeploymentTask, ActionListener<AcknowledgedResponse>> stopDeploymentFunc
) {
// Removing the entry from the map to avoid the possibility of a node shutdown triggering a concurrent graceful stopping of the
// process while we are attempting to forcefully stop the native process
// The graceful stopping will only occur if there is an entry in the map
deploymentIdToTask.remove(task.getDeploymentId());
-        ActionListener<Void> notifyDeploymentOfStopped = updateRoutingStateToStoppedListener(task.getDeploymentId(), reason, listener);
+        ActionListener<AcknowledgedResponse> notifyDeploymentOfStopped = updateRoutingStateToStoppedListener(
+            task.getDeploymentId(),
+            reason,
+            listener
+        );

updateStoredState(
task.getDeploymentId(),
@@ -541,7 +548,7 @@ private void gracefullyStopDeployment(String deploymentId, String currentNode) {
)
);

-        ActionListener<Void> notifyDeploymentOfStopped = updateRoutingStateToStoppedListener(
+        ActionListener<AcknowledgedResponse> notifyDeploymentOfStopped = updateRoutingStateToStoppedListener(
task.getDeploymentId(),
NODE_IS_SHUTTING_DOWN,
routingStateListener
@@ -550,7 +557,7 @@ private void gracefullyStopDeployment(String deploymentId, String currentNode) {
stopDeploymentAfterCompletingPendingWorkAsync(task, NODE_IS_SHUTTING_DOWN, notifyDeploymentOfStopped);
}

-    private ActionListener<Void> updateRoutingStateToStoppedListener(
+    private ActionListener<AcknowledgedResponse> updateRoutingStateToStoppedListener(
String deploymentId,
String reason,
ActionListener<AcknowledgedResponse> listener
@@ -594,27 +601,30 @@ private void stopUnreferencedDeployment(String deploymentId, String currentNode)
);
}

-    private void stopDeploymentAsync(TrainedModelDeploymentTask task, String reason, ActionListener<Void> listener) {
-        stopDeploymentHelper(task, reason, deploymentManager::stopDeployment, listener);
+    private void stopDeploymentAsync(TrainedModelDeploymentTask task, String reason, ActionListener<AcknowledgedResponse> listener) {
+        stopDeploymentHelper(task, reason, (t, l) -> {
+            deploymentManager.stopDeployment(t);
+            l.onResponse(AcknowledgedResponse.TRUE);
+        }, listener);
}

private void stopDeploymentHelper(
TrainedModelDeploymentTask task,
String reason,
-        Consumer<TrainedModelDeploymentTask> stopDeploymentFunc,
-        ActionListener<Void> listener
+        BiConsumer<TrainedModelDeploymentTask, ActionListener<AcknowledgedResponse>> stopDeploymentFunc,
+        ActionListener<AcknowledgedResponse> listener
) {
if (stopped) {
+            listener.onResponse(AcknowledgedResponse.FALSE);
return;
}
task.markAsStopped(reason);

threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> {
try {
-                stopDeploymentFunc.accept(task);
taskManager.unregister(task);
deploymentIdToTask.remove(task.getDeploymentId());
-                listener.onResponse(null);
+                stopDeploymentFunc.accept(task, listener);
} catch (Exception e) {
listener.onFailure(e);
}
@@ -624,7 +634,7 @@ private void stopDeploymentHelper(
private void stopDeploymentAfterCompletingPendingWorkAsync(
TrainedModelDeploymentTask task,
String reason,
-        ActionListener<Void> listener
+        ActionListener<AcknowledgedResponse> listener
) {
stopDeploymentHelper(task, reason, deploymentManager::stopAfterCompletingPendingWork, listener);
}
@@ -769,6 +779,7 @@ private void handleLoadSuccess(ActionListener<Boolean> retryListener, TrainedMod

private void updateStoredState(String deploymentId, RoutingInfoUpdate update, ActionListener<AcknowledgedResponse> listener) {
if (stopped) {
+            listener.onResponse(AcknowledgedResponse.FALSE);
return;
}
trainedModelAssignmentService.updateModelAssignmentState(
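Note on the shape of this refactor: the fire-and-forget Consumer<TrainedModelDeploymentTask> stop function becomes a BiConsumer that also receives the caller's ActionListener<AcknowledgedResponse>, so each stop function now owns the decision of when, and with what result, the caller is answered. A minimal, self-contained sketch of that pattern, using toy Task and Listener types rather than the Elasticsearch classes:

import java.util.function.BiConsumer;

public class StopCallbackSketch {

    interface Listener<T> {                          // stand-in for ActionListener<T>
        void onResponse(T response);
        void onFailure(Exception e);
    }

    record Task(String deploymentId) {}

    // A synchronous stop adapted to the async signature: complete immediately.
    static final BiConsumer<Task, Listener<Boolean>> FORCEFUL_STOP = (task, listener) -> {
        System.out.println("force-stopping " + task.deploymentId());
        listener.onResponse(Boolean.TRUE);           // analogous to AcknowledgedResponse.TRUE
    };

    // The helper no longer completes the listener itself; the stop function does.
    static void stopHelper(Task task, Listener<Boolean> listener, BiConsumer<Task, Listener<Boolean>> stopFunc) {
        try {
            stopFunc.accept(task, listener);
        } catch (Exception e) {
            listener.onFailure(e);
        }
    }

    public static void main(String[] args) {
        stopHelper(new Task("my-deployment"), new Listener<>() {
            public void onResponse(Boolean ack) { System.out.println("acknowledged=" + ack); }
            public void onFailure(Exception e) { e.printStackTrace(); }
        }, FORCEFUL_STOP);
    }
}

The graceful path passes deploymentManager::stopAfterCompletingPendingWork as the stop function, which completes the listener only once the worker queue has drained, without any change to the shared helper.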
@@ -16,6 +16,7 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.TransportSearchAction;
+import org.elasticsearch.action.support.ListenerTimeouts;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.Strings;
@@ -79,6 +80,7 @@ public class DeploymentManager {
private static final Logger logger = LogManager.getLogger(DeploymentManager.class);
private static final AtomicLong requestIdCounter = new AtomicLong(1);
public static final int NUM_RESTART_ATTEMPTS = 3;
+    private static final TimeValue WORKER_QUEUE_COMPLETION_TIMEOUT = TimeValue.timeValueMinutes(5);

private final Client client;
private final NamedXContentRegistry xContentRegistry;
@@ -331,15 +333,15 @@ public void stopDeployment(TrainedModelDeploymentTask task) {
}
}

-    public void stopAfterCompletingPendingWork(TrainedModelDeploymentTask task) {
+    public void stopAfterCompletingPendingWork(TrainedModelDeploymentTask task, ActionListener<AcknowledgedResponse> listener) {
ProcessContext processContext = processContextByAllocation.remove(task.getId());
if (processContext != null) {
logger.info(
"[{}] Stopping deployment after completing pending tasks, reason [{}]",
task.getDeploymentId(),
task.stoppedReason().orElse("unknown")
);
-            processContext.stopProcessAfterCompletingPendingWork();
+            processContext.stopProcessAfterCompletingPendingWork(listener);
} else {
            logger.warn("[{}] No process context to stop gracefully", task.getDeploymentId());
+            // Nothing to stop; complete the listener so the caller's cleanup is not left hanging
+            listener.onResponse(AcknowledgedResponse.TRUE);
        }
@@ -569,7 +571,7 @@ private Consumer<String> onProcessCrashHandleRestarts(AtomicInteger startsCount,

processContextByAllocation.remove(task.getId());
isStopped = true;
-        resultProcessor.stop();
+        resultProcessor.signalIntentToStop();
stateStreamer.cancel();

if (startsCount.get() <= NUM_RESTART_ATTEMPTS) {
@@ -648,7 +650,7 @@ synchronized void forcefullyStopProcess() {

private void prepareInternalStateForShutdown() {
isStopped = true;
-        resultProcessor.stop();
+        resultProcessor.signalIntentToStop();
stateStreamer.cancel();
}

@@ -669,43 +671,46 @@ private void closeNlpTaskProcessor() {
}
}

-    private synchronized void stopProcessAfterCompletingPendingWork() {
+    private synchronized void stopProcessAfterCompletingPendingWork(ActionListener<AcknowledgedResponse> listener) {
        logger.debug(() -> format("[%s] Stopping process after completing its pending work", task.getDeploymentId()));
        prepareInternalStateForShutdown();
-        signalAndWaitForWorkerTermination();
-        stopProcessGracefully();
-        closeNlpTaskProcessor();
-    }
-
-    private void signalAndWaitForWorkerTermination() {
-        try {
-            awaitTerminationAfterCompletingWork();
-        } catch (TimeoutException e) {
-            logger.warn(format("[%s] Timed out waiting for process worker to complete, forcing a shutdown", task.getDeploymentId()), e);
-            // The process failed to stop in the time period allotted, so we'll mark it for shut down
-            priorityProcessWorker.shutdown();
-            priorityProcessWorker.notifyQueueRunnables();
-        }
-    }
-
-    private void awaitTerminationAfterCompletingWork() throws TimeoutException {
-        try {
-            priorityProcessWorker.shutdown();
-            if (priorityProcessWorker.awaitTermination(COMPLETION_TIMEOUT.getMinutes(), TimeUnit.MINUTES) == false) {
-                throw new TimeoutException(
-                    Strings.format("Timed out waiting for process worker to complete for process %s", PROCESS_NAME)
-                );
-            } else {
-                priorityProcessWorker.notifyQueueRunnables();
-            }
-        } catch (InterruptedException e) {
-            Thread.currentThread().interrupt();
-            logger.info(Strings.format("[%s] Interrupted waiting for process worker to complete", PROCESS_NAME));
-        }
-    }
+
+        // Waiting for the process worker to finish the pending work could
+        // take a long time. To avoid blocking the calling thread, register
+        // a function with the process worker queue that is called when the
+        // worker queue is finished, then proceed to closing the native process
+        // and waiting for all results to be processed; the second part can be
+        // done synchronously as it is not expected to take long.
+
+        // This listener closes the native process and waits for the results
+        // after the worker queue has finished
+        var closeProcessListener = listener.delegateFailure((l, r) -> {
+            // process worker stopped within the allotted time, close the process
+            closeProcessAndWaitForResultProcessor();
+            closeNlpTaskProcessor();
+            l.onResponse(AcknowledgedResponse.TRUE);
+        });
+
+        // Fall back to a forceful shutdown if the worker queue does not drain in time
+        var listenWithTimeout = ListenerTimeouts.wrapWithTimeout(
+            threadPool,
+            WORKER_QUEUE_COMPLETION_TIMEOUT,
+            threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME),
+            closeProcessListener,
+            (l) -> {
+                // Stopping the process worker timed out, kill the process and answer
+                // on the original listener so the graceful close logic is skipped
+                logger.warn(
+                    format("[%s] Timed out waiting for process worker to complete, forcing a shutdown", task.getDeploymentId())
+                );
+                forcefullyStopProcess();
+                listener.onResponse(AcknowledgedResponse.FALSE);
+            }
+        );
+
+        priorityProcessWorker.shutdownWithCallback(() -> listenWithTimeout.onResponse(AcknowledgedResponse.TRUE));
+    }

-    private void stopProcessGracefully() {
+    private void closeProcessAndWaitForResultProcessor() {
        try {
            closeProcessIfPresent();
            resultProcessor.awaitCompletion(COMPLETION_TIMEOUT.getMinutes(), TimeUnit.MINUTES);
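The key move in stopProcessAfterCompletingPendingWork is pairing a completion callback with a scheduled timeout so that exactly one of the two paths wins. A self-contained sketch of that idea, with a toy Listener type and a plain ScheduledExecutorService standing in for ListenerTimeouts.wrapWithTimeout and the ML thread pool:

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;

public class TimeoutWrapSketch {

    interface Listener<T> { void onResponse(T value); }

    static <T> Listener<T> wrapWithTimeout(ScheduledExecutorService scheduler,
                                           long timeoutMillis,
                                           Listener<T> delegate,
                                           Consumer<Listener<T>> onTimeout) {
        AtomicBoolean done = new AtomicBoolean(false);
        // If the deadline passes first, the timeout handler wins.
        // (The real implementation also cancels this scheduled task on normal completion.)
        scheduler.schedule(() -> {
            if (done.compareAndSet(false, true)) {
                onTimeout.accept(delegate);
            }
        }, timeoutMillis, TimeUnit.MILLISECONDS);
        // If the wrapped listener completes first, the delegate wins.
        return value -> {
            if (done.compareAndSet(false, true)) {
                delegate.onResponse(value);
            }
        };
    }

    public static void main(String[] args) throws InterruptedException {
        ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();
        CountDownLatch latch = new CountDownLatch(1);

        Listener<Boolean> result = ack -> { System.out.println("graceful=" + ack); latch.countDown(); };
        Listener<Boolean> wrapped = wrapWithTimeout(scheduler, 100, result,
            l -> l.onResponse(false));   // timed out: report "not acknowledged"

        // Simulate a worker queue that never drains: wrapped.onResponse is never called.
        latch.await();                   // prints graceful=false after ~100ms
        scheduler.shutdownNow();
    }
}

A single AtomicBoolean compare-and-set is what prevents double completion: if the queue drains at the same instant the deadline fires, only one branch ever reaches the delegate.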
@@ -313,7 +313,7 @@ public synchronized void updateStats(PyTorchResult result) {
}
}

-    public void stop() {
+    public void signalIntentToStop() {
isStopping = true;
}

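The rename is worth the churn: stop() read as if it halted the processor synchronously, while the method only flips a flag that the processing loop polls. A tiny sketch of that cooperative-shutdown pattern (toy class; assuming the flag is volatile, as it must be for cross-thread visibility):

public class SignalStopSketch {

    static class ResultProcessor implements Runnable {
        private volatile boolean isStopping = false;  // volatile: the signal must be visible across threads

        void signalIntentToStop() {                   // returns immediately; nothing stops synchronously
            isStopping = true;
        }

        @Override
        public void run() {
            while (isStopping == false) {
                Thread.onSpinWait();                  // stand-in for "drain and process one result"
            }
            System.out.println("processor observed the stop signal and exited");
        }
    }

    public static void main(String[] args) throws InterruptedException {
        ResultProcessor processor = new ResultProcessor();
        Thread thread = new Thread(processor);
        thread.start();
        processor.signalIntentToStop();
        thread.join();                                // the loop exits on its next check of the flag
    }
}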
@@ -44,6 +44,7 @@ public abstract class AbstractProcessWorkerExecutorService<T extends Runnable> e
private final AtomicReference<Exception> error = new AtomicReference<>();
private final AtomicBoolean running = new AtomicBoolean(true);
private final AtomicBoolean shouldShutdownAfterCompletingWork = new AtomicBoolean(false);
+    private final AtomicReference<Runnable> onCompletion = new AtomicReference<>();

/**
* @param contextHolder the thread context holder
@@ -78,6 +79,11 @@ public void shutdown() {
shouldShutdownAfterCompletingWork.set(true);
}

+    /**
+     * Shut down once the pending work has drained. The callback runs exactly once,
+     * when the worker loop exits; it is not invoked if the worker has already stopped.
+     */
+    public void shutdownWithCallback(Runnable onCompletion) {
+        this.onCompletion.set(onCompletion);
+        shutdown();
+    }

/**
* Some of the tasks in the returned list of {@link Runnable}s could have run. Some tasks may have run while the queue was being copied.
*
@@ -124,6 +130,10 @@ public void start() {
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
} finally {
+            // run the shutdown callback, if any, now that the worker loop has exited
+            Runnable onComplete = onCompletion.get();
+            if (onComplete != null) {
+                onComplete.run();
+            }
awaitTermination.countDown();
}
}
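To see why the callback is stashed in the AtomicReference before shutdown() flips the flag, here is a stripped-down worker loop with the same shape (a sketch, not the real executor service): the finally block guarantees the callback fires once the queue has drained, whether the loop ended normally or by interruption.

import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;

public class WorkerCallbackSketch {

    private final BlockingQueue<Runnable> queue = new ArrayBlockingQueue<>(64);
    private final AtomicBoolean running = new AtomicBoolean(true);
    private final AtomicReference<Runnable> onCompletion = new AtomicReference<>();

    public void shutdownWithCallback(Runnable callback) {
        onCompletion.set(callback);   // set before the flag so the loop cannot miss it
        running.set(false);
    }

    public boolean submit(Runnable work) {
        return queue.offer(work);
    }

    public void start() {
        try {
            while (running.get()) {
                Runnable work = queue.poll(50, TimeUnit.MILLISECONDS);
                if (work != null) {
                    work.run();
                }
            }
            // drain whatever was queued before shutdown was requested
            for (Runnable work; (work = queue.poll()) != null; ) {
                work.run();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        } finally {
            Runnable callback = onCompletion.get();
            if (callback != null) {
                callback.run();       // fires exactly once, after pending work
            }
        }
    }

    public static void main(String[] args) throws InterruptedException {
        WorkerCallbackSketch worker = new WorkerCallbackSketch();
        Thread thread = new Thread(worker::start);
        thread.start();
        worker.submit(() -> System.out.println("pending work ran"));
        worker.shutdownWithCallback(() -> System.out.println("queue drained, notify listener"));
        thread.join();
    }
}

One caveat the real class shares: if the worker loop has already exited when shutdownWithCallback is called, the callback never runs, which is exactly why the caller in DeploymentManager wraps its listener with a timeout.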
@@ -265,7 +265,11 @@ public void testLoadQueuedModelsWhenTaskIsStopped() throws Exception {
UpdateTrainedModelAssignmentRoutingInfoAction.Request.class
);
verify(deploymentManager, times(1)).startDeployment(startTaskCapture.capture(), any());
-        assertBusy(() -> verify(trainedModelAssignmentService, times(3)).updateModelAssignmentState(requestCapture.capture(), any()));
+        assertBusy(
+            () -> verify(trainedModelAssignmentService, times(3)).updateModelAssignmentState(requestCapture.capture(), any()),
+            3,
+            TimeUnit.SECONDS
+        );

boolean seenStopping = false;
for (int i = 0; i < 3; i++) {
@@ -397,6 +401,13 @@ public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNode_CallsSto
return null;
}).when(trainedModelAssignmentService).updateModelAssignmentState(any(), any());

+        doAnswer(invocationOnMock -> {
+            @SuppressWarnings({ "unchecked", "rawtypes" })
+            ActionListener<AcknowledgedResponse> listener = (ActionListener) invocationOnMock.getArguments()[1];
+            listener.onResponse(AcknowledgedResponse.TRUE);
+            return null;
+        }).when(deploymentManager).stopAfterCompletingPendingWork(any(), any());

var taskParams = newParams(deploymentOne, modelOne);

ClusterChangedEvent event = new ClusterChangedEvent(
@@ -430,7 +441,7 @@ public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNode_CallsSto
}

assertBusy(() -> {
-            verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(stopParamsCapture.capture());
+            verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(stopParamsCapture.capture(), any());
assertThat(stopParamsCapture.getValue().getModelId(), equalTo(modelOne));
assertThat(stopParamsCapture.getValue().getDeploymentId(), equalTo(deploymentOne));
});
@@ -481,7 +492,7 @@ public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNode_ButOther
trainedModelAssignmentNodeService.prepareModelToLoad(taskParams);
trainedModelAssignmentNodeService.clusterChanged(event);

-        verify(deploymentManager, never()).stopAfterCompletingPendingWork(any());
+        verify(deploymentManager, never()).stopAfterCompletingPendingWork(any(), any());
verify(trainedModelAssignmentService, never()).updateModelAssignmentState(
any(UpdateTrainedModelAssignmentRoutingInfoAction.Request.class),
any()
@@ -522,7 +533,7 @@ public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNodeButAlread

trainedModelAssignmentNodeService.clusterChanged(event);

-        verify(deploymentManager, never()).stopAfterCompletingPendingWork(any());
+        verify(deploymentManager, never()).stopAfterCompletingPendingWork(any(), any());
verify(trainedModelAssignmentService, never()).updateModelAssignmentState(
any(UpdateTrainedModelAssignmentRoutingInfoAction.Request.class),
any()
@@ -564,7 +575,7 @@ public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNodeWithStart
trainedModelAssignmentNodeService.prepareModelToLoad(taskParams);
trainedModelAssignmentNodeService.clusterChanged(event);

-        verify(deploymentManager, never()).stopAfterCompletingPendingWork(any());
+        verify(deploymentManager, never()).stopAfterCompletingPendingWork(any(), any());
verify(trainedModelAssignmentService, never()).updateModelAssignmentState(
any(UpdateTrainedModelAssignmentRoutingInfoAction.Request.class),
any()
@@ -601,7 +612,7 @@ public void testClusterChanged_WhenNodeDoesNotExistInAssignmentRoutingTable_Does
trainedModelAssignmentNodeService.prepareModelToLoad(taskParams);
trainedModelAssignmentNodeService.clusterChanged(event);

-        assertBusy(() -> verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(any()));
+        assertBusy(() -> verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(any(), any()));
// This still shouldn't trigger a cluster state update because the routing entry wasn't in the table so we won't add a new routing
// entry for stopping
verify(trainedModelAssignmentService, never()).updateModelAssignmentState(
@@ -765,7 +776,7 @@ public void testClusterChanged() throws Exception {
ArgumentCaptor<TrainedModelDeploymentTask> stoppedTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class);
// deployment-2 was originally started on node NODE_ID but in the latest cluster event it is no longer on that node so we will
// gracefully stop it
-        verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(stoppedTaskCapture.capture());
+        verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(stoppedTaskCapture.capture(), any());
assertThat(stoppedTaskCapture.getAllValues().get(0).getDeploymentId(), equalTo(deploymentTwo));
});
ArgumentCaptor<TrainedModelDeploymentTask> startTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class);
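With the extra ActionListener parameter, every stub of stopAfterCompletingPendingWork must now complete the listener it is handed, or the code under test blocks until assertBusy times out. The doAnswer pattern the updated tests rely on, reduced to a self-contained toy (hypothetical Stopper and Listener interfaces; Mockito as in the real tests):

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;

public class ListenerStubSketch {

    interface Listener<T> { void onResponse(T response); }

    interface Stopper {
        void stopAfterCompletingPendingWork(String deploymentId, Listener<Boolean> listener);
    }

    public static void main(String[] args) {
        Stopper stopper = mock(Stopper.class);

        doAnswer(invocation -> {
            @SuppressWarnings("unchecked")
            Listener<Boolean> listener = (Listener<Boolean>) invocation.getArguments()[1];
            listener.onResponse(Boolean.TRUE);   // complete the callback, as the real implementation would
            return null;
        }).when(stopper).stopAfterCompletingPendingWork(any(), any());

        // The stub answers the callback synchronously, so the test never hangs.
        stopper.stopAfterCompletingPendingWork("my-deployment", ack -> System.out.println("acknowledged=" + ack));
    }
}

The same reasoning explains the widened verify calls: once the production method takes two arguments, every stopAfterCompletingPendingWork(any()) verification must become stopAfterCompletingPendingWork(any(), any()) or Mockito will report zero matching invocations.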