@@ -62,7 +62,7 @@
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.function.Consumer;
import java.util.function.BiConsumer;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ASSIGNMENT_TASK_ACTION;
@@ -274,20 +274,27 @@ public void gracefullyStopDeploymentAndNotify(
public void stopDeploymentAndNotify(TrainedModelDeploymentTask task, String reason, ActionListener<AcknowledgedResponse> listener) {
logger.debug(() -> format("[%s] Forcefully stopping deployment due to reason %s", task.getDeploymentId(), reason));

stopAndNotifyHelper(task, reason, listener, deploymentManager::stopDeployment);
stopAndNotifyHelper(task, reason, listener, (t, l) -> {
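// stopDeployment is synchronous here, so the listener can be completed as soon as it returns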
deploymentManager.stopDeployment(t);
l.onResponse(AcknowledgedResponse.TRUE);
});
}

private void stopAndNotifyHelper(
TrainedModelDeploymentTask task,
String reason,
ActionListener<AcknowledgedResponse> listener,
Consumer<TrainedModelDeploymentTask> stopDeploymentFunc
BiConsumer<TrainedModelDeploymentTask, ActionListener<AcknowledgedResponse>> stopDeploymentFunc
) {
// Remove the entry from the map so that a node shutdown cannot trigger a concurrent graceful stop of the
// process while we are attempting to forcefully stop it; a graceful stop only occurs if the deployment
// still has an entry in the map
deploymentIdToTask.remove(task.getDeploymentId());
ActionListener<Void> notifyDeploymentOfStopped = updateRoutingStateToStoppedListener(task.getDeploymentId(), reason, listener);
ActionListener<AcknowledgedResponse> notifyDeploymentOfStopped = updateRoutingStateToStoppedListener(
task.getDeploymentId(),
reason,
listener
);

updateStoredState(
task.getDeploymentId(),
@@ -541,7 +548,7 @@ private void gracefullyStopDeployment(String deploymentId, String currentNode) {
)
);

ActionListener<Void> notifyDeploymentOfStopped = updateRoutingStateToStoppedListener(
ActionListener<AcknowledgedResponse> notifyDeploymentOfStopped = updateRoutingStateToStoppedListener(
task.getDeploymentId(),
NODE_IS_SHUTTING_DOWN,
routingStateListener
@@ -550,7 +557,7 @@ private void gracefullyStopDeployment(String deploymentId, String currentNode) {
stopDeploymentAfterCompletingPendingWorkAsync(task, NODE_IS_SHUTTING_DOWN, notifyDeploymentOfStopped);
}

private ActionListener<Void> updateRoutingStateToStoppedListener(
private ActionListener<AcknowledgedResponse> updateRoutingStateToStoppedListener(
String deploymentId,
String reason,
ActionListener<AcknowledgedResponse> listener
@@ -594,27 +601,30 @@ private void stopUnreferencedDeployment(String deploymentId, String currentNode)
);
}

private void stopDeploymentAsync(TrainedModelDeploymentTask task, String reason, ActionListener<Void> listener) {
stopDeploymentHelper(task, reason, deploymentManager::stopDeployment, listener);
private void stopDeploymentAsync(TrainedModelDeploymentTask task, String reason, ActionListener<AcknowledgedResponse> listener) {
stopDeploymentHelper(task, reason, (t, l) -> {
deploymentManager.stopDeployment(t);
l.onResponse(AcknowledgedResponse.TRUE);
}, listener);
}

private void stopDeploymentHelper(
TrainedModelDeploymentTask task,
String reason,
Consumer<TrainedModelDeploymentTask> stopDeploymentFunc,
ActionListener<Void> listener
BiConsumer<TrainedModelDeploymentTask, ActionListener<AcknowledgedResponse>> stopDeploymentFunc,
ActionListener<AcknowledgedResponse> listener
) {
if (stopped) {
listener.onResponse(AcknowledgedResponse.FALSE);
return;
}
task.markAsStopped(reason);

threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> {
try {
stopDeploymentFunc.accept(task);
taskManager.unregister(task);
deploymentIdToTask.remove(task.getDeploymentId());
listener.onResponse(null);
stopDeploymentFunc.accept(task, listener);
} catch (Exception e) {
listener.onFailure(e);
}
@@ -624,7 +634,7 @@ private void stopDeploymentHelper(
private void stopDeploymentAfterCompletingPendingWorkAsync(
TrainedModelDeploymentTask task,
String reason,
ActionListener<Void> listener
ActionListener<AcknowledgedResponse> listener
) {
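// stopAfterCompletingPendingWork completes the listener itself, once the pending work finishes or the shutdown times out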
stopDeploymentHelper(task, reason, deploymentManager::stopAfterCompletingPendingWork, listener);
}
@@ -769,6 +779,7 @@ private void handleLoadSuccess(ActionListener<Boolean> retryListener, TrainedMod

private void updateStoredState(String deploymentId, RoutingInfoUpdate update, ActionListener<AcknowledgedResponse> listener) {
if (stopped) {
listener.onResponse(AcknowledgedResponse.FALSE);
return;
}
trainedModelAssignmentService.updateModelAssignmentState(
@@ -331,15 +331,15 @@ public void stopDeployment(TrainedModelDeploymentTask task) {
}
}

public void stopAfterCompletingPendingWork(TrainedModelDeploymentTask task) {
public void stopAfterCompletingPendingWork(TrainedModelDeploymentTask task, ActionListener<AcknowledgedResponse> listener) {
ProcessContext processContext = processContextByAllocation.remove(task.getId());
if (processContext != null) {
logger.info(
"[{}] Stopping deployment after completing pending tasks, reason [{}]",
task.getDeploymentId(),
task.stoppedReason().orElse("unknown")
);
processContext.stopProcessAfterCompletingPendingWork();
processContext.stopProcessAfterCompletingPendingWork(listener);
} else {
logger.warn("[{}] No process context to stop gracefully", task.getDeploymentId());
}
@@ -569,7 +569,7 @@ private Consumer<String> onProcessCrashHandleRestarts(AtomicInteger startsCount,

processContextByAllocation.remove(task.getId());
isStopped = true;
resultProcessor.stop();
resultProcessor.signalIntentToStop();
stateStreamer.cancel();

if (startsCount.get() <= NUM_RESTART_ATTEMPTS) {
@@ -648,7 +648,7 @@ synchronized void forcefullyStopProcess() {

private void prepareInternalStateForShutdown() {
isStopped = true;
resultProcessor.stop();
resultProcessor.signalIntentToStop();
stateStreamer.cancel();
}

Expand All @@ -669,43 +669,33 @@ private void closeNlpTaskProcessor() {
}
}

private synchronized void stopProcessAfterCompletingPendingWork() {
private synchronized void stopProcessAfterCompletingPendingWork(ActionListener<AcknowledgedResponse> listener) {
logger.debug(() -> format("[%s] Stopping process after completing its pending work", task.getDeploymentId()));
prepareInternalStateForShutdown();
signalAndWaitForWorkerTermination();
stopProcessGracefully();
stopProcessGracefully(listener);
closeNlpTaskProcessor();
}

private void signalAndWaitForWorkerTermination() {
try {
awaitTerminationAfterCompletingWork();
} catch (TimeoutException e) {
logger.warn(format("[%s] Timed out waiting for process worker to complete, forcing a shutdown", task.getDeploymentId()), e);
// The process failed to stop in the time period allotted, so we'll mark it for shut down
priorityProcessWorker.shutdown();
priorityProcessWorker.notifyQueueRunnables();
}
}

private void awaitTerminationAfterCompletingWork() throws TimeoutException {
try {
priorityProcessWorker.shutdown();

if (priorityProcessWorker.awaitTermination(COMPLETION_TIMEOUT.getMinutes(), TimeUnit.MINUTES) == false) {
throw new TimeoutException(
Strings.format("Timed out waiting for process worker to complete for process %s", PROCESS_NAME)
);
} else {
priorityProcessWorker.notifyQueueRunnables();
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
logger.info(Strings.format("[%s] Interrupted waiting for process worker to complete", PROCESS_NAME));
}
}

private void stopProcessGracefully() {
private void stopProcessGracefully(ActionListener<AcknowledgedResponse> listener) {
// Waiting for the process worker to finish its pending work could
// take a long time, so it is best not to block this thread. Instead,
// register a callback with the process worker that is invoked when the
// work is finished, then proceed to close the native process and wait
// for all results to be processed. The second part can be done
// synchronously as it is not expected to take long.
// The ShutdownTracker handles this: it initiates the worker shutdown
// and starts a race between the worker completing and a timeout.
new ShutdownTracker(() -> {
// stopping the process worker timed out, so kill the process
logger.warn(format("[%s] Timed out waiting for process worker to complete, forcing a shutdown", task.getDeploymentId()));
forcefullyStopProcess();
}, () -> {
// the process worker stopped within the allotted time; close the process
closeProcessAndWaitForResultProcessor();
closeNlpTaskProcessor();
}, threadPool, priorityProcessWorker, listener);
}

private void closeProcessAndWaitForResultProcessor() {
try {
closeProcessIfPresent();
resultProcessor.awaitCompletion(COMPLETION_TIMEOUT.getMinutes(), TimeUnit.MINUTES);
@@ -0,0 +1,75 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.ml.inference.deployment;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.threadpool.Scheduler;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.pytorch.PriorityProcessWorkerExecutorService;

import java.util.concurrent.atomic.AtomicBoolean;

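/**
 * Races the graceful draining of a process worker's queue against a fixed
 * timeout. Exactly one of the two callbacks runs: the worker-queue-completed
 * callback if the queue drains within {@code COMPLETION_TIMEOUT}, otherwise
 * the timeout callback. The supplied listener is notified exactly once, with
 * {@code AcknowledgedResponse.TRUE} on completion and {@code FALSE} on timeout.
 */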
public class ShutdownTracker {

private final ActionListener<AcknowledgedResponse> everythingStoppedListener;

private final Scheduler.Cancellable timeoutHandler;
private final Runnable onWorkerQueueCompletedCallback;
private final Runnable onTimeoutCallback;
private final Object monitor = new Object();
private final AtomicBoolean timedOutOrCompleted = new AtomicBoolean();

private static final TimeValue COMPLETION_TIMEOUT = TimeValue.timeValueMinutes(5);

public ShutdownTracker(
Runnable onTimeoutCallback,
Runnable onWorkerQueueCompletedCallback,
ThreadPool threadPool,
PriorityProcessWorkerExecutorService workerQueue,
ActionListener<AcknowledgedResponse> everythingStoppedListener
) {
this.onTimeoutCallback = onTimeoutCallback;
this.onWorkerQueueCompletedCallback = onWorkerQueueCompletedCallback;
this.everythingStoppedListener = ActionListener.notifyOnce(everythingStoppedListener);

// initiate the worker shutdown and register a callback to run when it completes
workerQueue.shutdownWithCallback(this::workerQueueCompleted);
// start the race between the timeout and the worker completing
this.timeoutHandler = threadPool.schedule(
this::onTimeout,
COMPLETION_TIMEOUT,
threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)
);
}

private void onTimeout() {
synchronized (monitor) { // TODO remove the lock as the atomic should be sufficient
if (timedOutOrCompleted.compareAndSet(false, true) == false) {
// already completed
return;
}
onTimeoutCallback.run();
everythingStoppedListener.onResponse(AcknowledgedResponse.FALSE);
}
}

private void workerQueueCompleted() {
synchronized (monitor) { // TODO remove the lock as the atomic should be sufficient
if (timedOutOrCompleted.compareAndSet(false, true) == false) {
// the timeout has already fired
return;
}
timeoutHandler.cancel();
onWorkerQueueCompletedCallback.run();
everythingStoppedListener.onResponse(AcknowledgedResponse.TRUE);
}
}
}
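A brief usage sketch for review context (illustrative wiring only; the names logger, threadPool, and priorityProcessWorker stand in for the caller's own fields, mirroring the stopProcessGracefully call site above):

    // Race the worker's drain against the timeout. The listener resolves to
    // TRUE if the queue drains in time, or FALSE if the timeout fires first.
    new ShutdownTracker(
        () -> logger.warn("worker did not drain in time; killing the process"), // timeout path
        () -> logger.info("worker drained; closing the process"),               // completion path
        threadPool,
        priorityProcessWorker,
        ActionListener.wrap(
            ack -> logger.info("stop acknowledged: {}", ack.isAcknowledged()),
            e -> logger.error("stop failed", e)
        )
    );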
@@ -313,7 +313,7 @@ public synchronized void updateStats(PyTorchResult result) {
}
}

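/**
 * Signals that a stop has been requested. This only sets a flag; it does not
 * block waiting for in-flight results to be processed (see awaitCompletion).
 */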
public void stop() {
public void signalIntentToStop() {
isStopping = true;
}

@@ -44,6 +44,7 @@ public abstract class AbstractProcessWorkerExecutorService<T extends Runnable> e
private final AtomicReference<Exception> error = new AtomicReference<>();
private final AtomicBoolean running = new AtomicBoolean(true);
private final AtomicBoolean shouldShutdownAfterCompletingWork = new AtomicBoolean(false);
private final AtomicReference<Runnable> onCompletion = new AtomicReference<>();

/**
* @param contextHolder the thread context holder
@@ -78,6 +79,11 @@ public void shutdown() {
shouldShutdownAfterCompletingWork.set(true);
}

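/**
 * Initiates an orderly shutdown, registering a callback that the worker
 * thread runs once it has finished processing the queued work, just before
 * termination is signalled.
 */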
public void shutdownWithCallback(Runnable onCompletion) {
this.onCompletion.set(onCompletion);
shutdown();
}

/**
* Some of the tasks in the returned list of {@link Runnable}s could have run. Some tasks may have run while the queue was being copied.
*
@@ -124,6 +130,10 @@ public void start() {
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
} finally {
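// run the registered shutdown callback, if any, before signalling termination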
Runnable onComplete = onCompletion.get();
if (onComplete != null) {
onComplete.run();
}
awaitTermination.countDown();
}
}