Skip to content

Commit 1aea049

Browse files
prwhelan, elasticmachine, and davidkyle
authored
[ML] Avoid ModelAssignment deadlock (#109684)
The model loading scheduled thread iterates through the model queue and deploys each model. Rather than block and wait on each deployment, the thread will attach a listener that will either iterate to the next model (if one is in the queue) or reschedule the thread. This change should not impact: 1. the iterative nature of the model deployment process - each model is still deployed one at a time, and no additional threads are consumed per model. 2. the 1s delay between model deployment tries - if a deployment fails but can be retried, the retry is added to the next batch of models that are consumed after the 1s scheduled delay. Co-authored-by: Elastic Machine <[email protected]> Co-authored-by: David Kyle <[email protected]>
1 parent 5d53c9a commit 1aea049

File tree

3 files changed

+150
-89
lines changed

3 files changed

+150
-89
lines changed

docs/changelog/109684.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 109684
2+
summary: Avoid `ModelAssignment` deadlock
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java

Lines changed: 71 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@
1212
import org.elasticsearch.ResourceNotFoundException;
1313
import org.elasticsearch.action.ActionListener;
1414
import org.elasticsearch.action.search.SearchPhaseExecutionException;
15-
import org.elasticsearch.action.support.PlainActionFuture;
16-
import org.elasticsearch.action.support.UnsafePlainActionFuture;
15+
import org.elasticsearch.action.support.SubscribableListener;
1716
import org.elasticsearch.action.support.master.AcknowledgedResponse;
1817
import org.elasticsearch.cluster.ClusterChangedEvent;
1918
import org.elasticsearch.cluster.ClusterState;
@@ -53,7 +52,6 @@
5352
import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask;
5453
import org.elasticsearch.xpack.ml.task.AbstractJobPersistentTasksExecutor;
5554

56-
import java.util.ArrayDeque;
5755
import java.util.ArrayList;
5856
import java.util.Collections;
5957
import java.util.Deque;
@@ -154,26 +152,38 @@ public void beforeStop() {
154152
this.expressionResolver = expressionResolver;
155153
}
156154

157-
public void start() {
155+
void start() {
158156
stopped = false;
159-
scheduledFuture = threadPool.scheduleWithFixedDelay(
160-
this::loadQueuedModels,
161-
MODEL_LOADING_CHECK_INTERVAL,
162-
threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)
163-
);
157+
schedule(false);
164158
}
165159

166-
public void stop() {
160+
private void schedule(boolean runImmediately) {
161+
if (stopped) {
162+
// do not schedule when stopped
163+
return;
164+
}
165+
166+
var rescheduleListener = ActionListener.wrap(this::schedule, e -> this.schedule(false));
167+
Runnable loadQueuedModels = () -> loadQueuedModels(rescheduleListener);
168+
var executor = threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME);
169+
170+
if (runImmediately) {
171+
executor.execute(loadQueuedModels);
172+
} else {
173+
scheduledFuture = threadPool.schedule(loadQueuedModels, MODEL_LOADING_CHECK_INTERVAL, executor);
174+
}
175+
}
176+
177+
void stop() {
167178
stopped = true;
168179
ThreadPool.Cancellable cancellable = this.scheduledFuture;
169180
if (cancellable != null) {
170181
cancellable.cancel();
171182
}
172183
}
173184

174-
void loadQueuedModels() {
175-
TrainedModelDeploymentTask loadingTask;
176-
if (loadingModels.isEmpty()) {
185+
void loadQueuedModels(ActionListener<Boolean> rescheduleImmediately) {
186+
if (stopped) {
177187
return;
178188
}
179189
if (latestState != null) {
@@ -188,39 +198,49 @@ void loadQueuedModels() {
188198
);
189199
if (unassignedIndices.size() > 0) {
190200
logger.trace("not loading models as indices {} primary shards are unassigned", unassignedIndices);
201+
rescheduleImmediately.onResponse(false);
191202
return;
192203
}
193204
}
194-
logger.trace("attempting to load all currently queued models");
195-
// NOTE: As soon as this method exits, the timer for the scheduler starts ticking
196-
Deque<TrainedModelDeploymentTask> loadingToRetry = new ArrayDeque<>();
197-
while ((loadingTask = loadingModels.poll()) != null) {
198-
final String deploymentId = loadingTask.getDeploymentId();
199-
if (loadingTask.isStopped()) {
200-
if (logger.isTraceEnabled()) {
201-
String reason = loadingTask.stoppedReason().orElse("_unknown_");
202-
logger.trace("[{}] attempted to load stopped task with reason [{}]", deploymentId, reason);
203-
}
204-
continue;
205+
206+
var loadingTask = loadingModels.poll();
207+
if (loadingTask == null) {
208+
rescheduleImmediately.onResponse(false);
209+
return;
210+
}
211+
212+
loadModel(loadingTask, ActionListener.wrap(retry -> {
213+
if (retry != null && retry) {
214+
loadingModels.offer(loadingTask);
215+
// don't reschedule immediately if the next task is the one we just queued, instead wait a bit to retry
216+
rescheduleImmediately.onResponse(loadingModels.peek() != loadingTask);
217+
} else {
218+
rescheduleImmediately.onResponse(loadingModels.isEmpty() == false);
205219
}
206-
if (stopped) {
207-
return;
220+
}, e -> rescheduleImmediately.onResponse(loadingModels.isEmpty() == false)));
221+
}
222+
223+
void loadModel(TrainedModelDeploymentTask loadingTask, ActionListener<Boolean> retryListener) {
224+
if (loadingTask.isStopped()) {
225+
if (logger.isTraceEnabled()) {
226+
logger.trace(
227+
"[{}] attempted to load stopped task with reason [{}]",
228+
loadingTask.getDeploymentId(),
229+
loadingTask.stoppedReason().orElse("_unknown_")
230+
);
208231
}
209-
final PlainActionFuture<TrainedModelDeploymentTask> listener = new UnsafePlainActionFuture<>(
210-
MachineLearning.UTILITY_THREAD_POOL_NAME
211-
);
212-
try {
213-
deploymentManager.startDeployment(loadingTask, listener);
214-
// This needs to be synchronous here in the utility thread to keep queueing order
215-
TrainedModelDeploymentTask deployedTask = listener.actionGet();
216-
// kicks off asynchronous cluster state update
217-
handleLoadSuccess(deployedTask);
218-
} catch (Exception ex) {
232+
retryListener.onResponse(false);
233+
return;
234+
}
235+
SubscribableListener.<TrainedModelDeploymentTask>newForked(l -> deploymentManager.startDeployment(loadingTask, l))
236+
.andThen(threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME), threadPool.getThreadContext(), this::handleLoadSuccess)
237+
.addListener(retryListener.delegateResponse((retryL, ex) -> {
238+
var deploymentId = loadingTask.getDeploymentId();
219239
logger.warn(() -> "[" + deploymentId + "] Start deployment failed", ex);
220240
if (ExceptionsHelper.unwrapCause(ex) instanceof ResourceNotFoundException) {
221-
String modelId = loadingTask.getParams().getModelId();
241+
var modelId = loadingTask.getParams().getModelId();
222242
logger.debug(() -> "[" + deploymentId + "] Start deployment failed as model [" + modelId + "] was not found", ex);
223-
handleLoadFailure(loadingTask, ExceptionsHelper.missingTrainedModel(modelId, ex));
243+
handleLoadFailure(loadingTask, ExceptionsHelper.missingTrainedModel(modelId, ex), retryL);
224244
} else if (ExceptionsHelper.unwrapCause(ex) instanceof SearchPhaseExecutionException) {
225245
/*
226246
* This case will not catch the ElasticsearchException generated from the ChunkedTrainedModelRestorer in a scenario
@@ -232,13 +252,11 @@ void loadQueuedModels() {
232252
// A search phase execution failure should be retried, push task back to the queue
233253

234254
// This will cause the entire model to be reloaded (all the chunks)
235-
loadingToRetry.add(loadingTask);
255+
retryL.onResponse(true);
236256
} else {
237-
handleLoadFailure(loadingTask, ex);
257+
handleLoadFailure(loadingTask, ex, retryL);
238258
}
239-
}
240-
}
241-
loadingModels.addAll(loadingToRetry);
259+
}), threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME), threadPool.getThreadContext());
242260
}
243261

244262
public void gracefullyStopDeploymentAndNotify(
@@ -680,14 +698,14 @@ void prepareModelToLoad(StartTrainedModelDeploymentAction.TaskParams taskParams)
680698
);
681699
// threadsafe check to verify we are not loading/loaded the model
682700
if (deploymentIdToTask.putIfAbsent(taskParams.getDeploymentId(), task) == null) {
683-
loadingModels.add(task);
701+
loadingModels.offer(task);
684702
} else {
685703
// If there is already a task for the deployment, unregister the new task
686704
taskManager.unregister(task);
687705
}
688706
}
689707

690-
private void handleLoadSuccess(TrainedModelDeploymentTask task) {
708+
private void handleLoadSuccess(ActionListener<Boolean> retryListener, TrainedModelDeploymentTask task) {
691709
logger.debug(
692710
() -> "["
693711
+ task.getParams().getDeploymentId()
@@ -704,13 +722,16 @@ private void handleLoadSuccess(TrainedModelDeploymentTask task) {
704722
task.stoppedReason().orElse("_unknown_")
705723
)
706724
);
725+
retryListener.onResponse(false);
707726
return;
708727
}
709728

710729
updateStoredState(
711730
task.getDeploymentId(),
712731
RoutingInfoUpdate.updateStateAndReason(new RoutingStateAndReason(RoutingState.STARTED, "")),
713-
ActionListener.wrap(r -> logger.debug(() -> "[" + task.getDeploymentId() + "] model loaded and accepting routes"), e -> {
732+
ActionListener.runAfter(ActionListener.wrap(r -> {
733+
logger.debug(() -> "[" + task.getDeploymentId() + "] model loaded and accepting routes");
734+
}, e -> {
714735
// This means that either the assignment has been deleted, or this node's particular route has been removed
715736
if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
716737
logger.debug(
@@ -732,7 +753,7 @@ private void handleLoadSuccess(TrainedModelDeploymentTask task) {
732753
e
733754
);
734755
}
735-
})
756+
}), () -> retryListener.onResponse(false))
736757
);
737758
}
738759

@@ -752,7 +773,7 @@ private void updateStoredState(String deploymentId, RoutingInfoUpdate update, Ac
752773
);
753774
}
754775

755-
private void handleLoadFailure(TrainedModelDeploymentTask task, Exception ex) {
776+
private void handleLoadFailure(TrainedModelDeploymentTask task, Exception ex, ActionListener<Boolean> retryListener) {
756777
logger.error(() -> "[" + task.getDeploymentId() + "] model [" + task.getParams().getModelId() + "] failed to load", ex);
757778
if (task.isStopped()) {
758779
logger.debug(
@@ -769,14 +790,14 @@ private void handleLoadFailure(TrainedModelDeploymentTask task, Exception ex) {
769790
Runnable stopTask = () -> stopDeploymentAsync(
770791
task,
771792
"model failed to load; reason [" + ex.getMessage() + "]",
772-
ActionListener.noop()
793+
ActionListener.running(() -> retryListener.onResponse(false))
773794
);
774795
updateStoredState(
775796
task.getDeploymentId(),
776797
RoutingInfoUpdate.updateStateAndReason(
777798
new RoutingStateAndReason(RoutingState.FAILED, ExceptionsHelper.unwrapCause(ex).getMessage())
778799
),
779-
ActionListener.wrap(r -> stopTask.run(), e -> stopTask.run())
800+
ActionListener.running(stopTask)
780801
);
781802
}
782803

0 commit comments

Comments (0)