Skip to content

Commit e171cc1

Browse files
jonathan-buttner and elasticsearchmachine authored
[ML] Gracefully shutdown model deployment when node is removed from assignment routing (#134673) (#135439)
* Initial fix for graceful shutdown for unreferenced nodes
* Update docs/changelog/134673.yaml
* [CI] Auto commit changes from spotless
* Fixing cluster change test and flaky tests
* [CI] Auto commit changes from spotless
* Addressing feedback

---------

Co-authored-by: elasticsearchmachine <[email protected]>
1 parent 8fec5a5 commit e171cc1

File tree

3 files changed: +74 −16 lines changed

docs/changelog/134673.yaml

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,6 @@
1+
pr: 134673
2+
summary: Gracefully shutdown model deployment when node is removed from assignment
3+
routing
4+
area: Machine Learning
5+
type: bug
6+
issues: []

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java

Lines changed: 16 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -526,7 +526,7 @@ private void gracefullyStopDeployment(String deploymentId, String currentNode) {
526526
if (task == null) {
527527
logger.debug(
528528
() -> format(
529-
"[%s] Unable to gracefully stop deployment for shutting down node %s because task does not exit",
529+
"[%s] Unable to gracefully stop deployment for shutting down node %s because task does not exist",
530530
deploymentId,
531531
currentNode
532532
)
@@ -550,7 +550,7 @@ private void gracefullyStopDeployment(String deploymentId, String currentNode) {
550550
routingStateListener
551551
);
552552

553-
stopDeploymentAfterCompletingPendingWorkAsync(task, notifyDeploymentOfStopped);
553+
stopDeploymentAfterCompletingPendingWorkAsync(task, NODE_IS_SHUTTING_DOWN, notifyDeploymentOfStopped);
554554
}
555555

556556
private ActionListener<Void> updateRoutingStateToStoppedListener(
@@ -576,11 +576,18 @@ private void stopUnreferencedDeployment(String deploymentId, String currentNode)
576576
// This model is not routed to the current node at all
577577
TrainedModelDeploymentTask task = deploymentIdToTask.remove(deploymentId);
578578
if (task == null) {
579+
logger.debug(
580+
() -> format(
581+
"[%s] Unable to stop unreferenced deployment for node %s because task does not exist",
582+
deploymentId,
583+
currentNode
584+
)
585+
);
579586
return;
580587
}
581588

582589
logger.debug(() -> format("[%s] Stopping unreferenced deployment for node %s", deploymentId, currentNode));
583-
stopDeploymentAsync(
590+
stopDeploymentAfterCompletingPendingWorkAsync(
584591
task,
585592
NODE_NO_LONGER_REFERENCED,
586593
ActionListener.wrap(
@@ -617,8 +624,12 @@ private void stopDeploymentHelper(
617624
});
618625
}
619626

620-
private void stopDeploymentAfterCompletingPendingWorkAsync(TrainedModelDeploymentTask task, ActionListener<Void> listener) {
621-
stopDeploymentHelper(task, NODE_IS_SHUTTING_DOWN, deploymentManager::stopAfterCompletingPendingWork, listener);
627+
private void stopDeploymentAfterCompletingPendingWorkAsync(
628+
TrainedModelDeploymentTask task,
629+
String reason,
630+
ActionListener<Void> listener
631+
) {
632+
stopDeploymentHelper(task, reason, deploymentManager::stopAfterCompletingPendingWork, listener);
622633
}
623634

624635
private void updateNumberOfAllocations(TrainedModelAssignmentMetadata assignments) {

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java

Lines changed: 52 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -380,8 +380,7 @@ public void testClusterChangedWithResetMode() throws InterruptedException {
380380
verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
381381
}
382382

383-
public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_CallsStopAfterCompletingPendingWork()
384-
throws InterruptedException {
383+
public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNode_CallsStopAfterCompletingPendingWork() throws Exception {
385384
final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
386385
final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build();
387386
String modelOne = "model-1";
@@ -430,17 +429,19 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_CallsStop
430429
fail("Failed waiting for the stop process call to complete");
431430
}
432431

433-
verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(stopParamsCapture.capture());
434-
assertThat(stopParamsCapture.getValue().getModelId(), equalTo(modelOne));
435-
assertThat(stopParamsCapture.getValue().getDeploymentId(), equalTo(deploymentOne));
432+
assertBusy(() -> {
433+
verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(stopParamsCapture.capture());
434+
assertThat(stopParamsCapture.getValue().getModelId(), equalTo(modelOne));
435+
assertThat(stopParamsCapture.getValue().getDeploymentId(), equalTo(deploymentOne));
436+
});
436437
verify(trainedModelAssignmentService, times(1)).updateModelAssignmentState(
437438
any(UpdateTrainedModelAssignmentRoutingInfoAction.Request.class),
438439
any()
439440
);
440441
verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
441442
}
442443

443-
public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_ButOtherAllocationIsNotReady_DoesNotCallStop() {
444+
public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNode_ButOtherAllocationIsNotReady_DoesNotCallStop() {
444445
final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
445446
String node2 = "test-node-2";
446447
final DiscoveryNodes nodes = DiscoveryNodes.builder()
@@ -488,7 +489,7 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_ButOtherA
488489
verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
489490
}
490491

491-
public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNodeButAlreadyRemoved_DoesNotCallStop() {
492+
public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNodeButAlreadyRemoved_DoesNotCallStop() {
492493
final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
493494
final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build();
494495
String modelOne = "model-1";
@@ -529,7 +530,7 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNodeButAlready
529530
verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
530531
}
531532

532-
public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNodeWithStartingState_DoesNotStopTheDeployment() {
533+
public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNodeWithStartingState_DoesNotStopTheDeployment() {
533534
final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
534535
final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build();
535536
String modelOne = "model-1";
@@ -571,7 +572,46 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNodeWithStarti
571572
verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
572573
}
573574

574-
public void testClusterChanged_WhenAssigmentIsStopping_DoesNotAddModelToBeLoaded() throws InterruptedException {
575+
public void testClusterChanged_WhenNodeDoesNotExistInAssignmentRoutingTable_DoesGracefullyStopTheDeployment() throws Exception {
576+
final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
577+
final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build();
578+
String modelOne = "model-1";
579+
String deploymentOne = "deployment-1";
580+
581+
var taskParams = newParams(deploymentOne, modelOne);
582+
583+
ClusterChangedEvent event = new ClusterChangedEvent(
584+
"testClusterChanged",
585+
ClusterState.builder(new ClusterName("testClusterChanged"))
586+
.nodes(nodes)
587+
.metadata(
588+
Metadata.builder()
589+
.putCustom(
590+
TrainedModelAssignmentMetadata.NAME,
591+
TrainedModelAssignmentMetadata.Builder.empty()
592+
.addNewAssignment(deploymentOne, TrainedModelAssignment.Builder.empty(taskParams, null))
593+
.build()
594+
)
595+
.build()
596+
)
597+
.build(),
598+
ClusterState.EMPTY_STATE
599+
);
600+
601+
trainedModelAssignmentNodeService.prepareModelToLoad(taskParams);
602+
trainedModelAssignmentNodeService.clusterChanged(event);
603+
604+
assertBusy(() -> verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(any()));
605+
// This still shouldn't trigger a cluster state update because the routing entry wasn't in the table so we won't add a new routing
606+
// entry for stopping
607+
verify(trainedModelAssignmentService, never()).updateModelAssignmentState(
608+
any(UpdateTrainedModelAssignmentRoutingInfoAction.Request.class),
609+
any()
610+
);
611+
verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
612+
}
613+
614+
public void testClusterChanged_WhenAssignmentIsStopping_DoesNotAddModelToBeLoaded() throws InterruptedException {
575615
final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
576616
final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build();
577617
String modelOne = "model-1";
@@ -603,7 +643,6 @@ public void testClusterChanged_WhenAssigmentIsStopping_DoesNotAddModelToBeLoaded
603643
ClusterState.EMPTY_STATE
604644
);
605645

606-
// trainedModelAssignmentNodeService.prepareModelToLoad(taskParams);
607646
trainedModelAssignmentNodeService.clusterChanged(event);
608647
loadQueuedModels(trainedModelAssignmentNodeService);
609648

@@ -724,7 +763,9 @@ public void testClusterChanged() throws Exception {
724763

725764
assertBusy(() -> {
726765
ArgumentCaptor<TrainedModelDeploymentTask> stoppedTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class);
727-
verify(deploymentManager, times(1)).stopDeployment(stoppedTaskCapture.capture());
766+
// deployment-2 was originally started on node NODE_ID but in the latest cluster event it is no longer on that node so we will
767+
// gracefully stop it
768+
verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(stoppedTaskCapture.capture());
728769
assertThat(stoppedTaskCapture.getAllValues().get(0).getDeploymentId(), equalTo(deploymentTwo));
729770
});
730771
ArgumentCaptor<TrainedModelDeploymentTask> startTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class);

0 commit comments

Comments
 (0)