Skip to content

Commit 2854063

Browse files
jonathan-buttnerelasticsearchmachine
andcommitted
[ML] Gracefully shutdown model deployment when node is removed from assignment routing (elastic#134673)
* Initial fix for graceful shutdown for unreferenced nodes * Update docs/changelog/134673.yaml * [CI] Auto commit changes from spotless * Fixing cluster change test and flaky tests * [CI] Auto commit changes from spotless * Addressing feedback --------- Co-authored-by: elasticsearchmachine <[email protected]>
1 parent 6e2e910 commit 2854063

File tree

3 files changed

+74
-16
lines changed

3 files changed

+74
-16
lines changed

docs/changelog/134673.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 134673
2+
summary: Gracefully shutdown model deployment when node is removed from assignment
3+
routing
4+
area: Machine Learning
5+
type: bug
6+
issues: []

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,7 @@ private void gracefullyStopDeployment(String deploymentId, String currentNode) {
523523
if (task == null) {
524524
logger.debug(
525525
() -> format(
526-
"[%s] Unable to gracefully stop deployment for shutting down node %s because task does not exit",
526+
"[%s] Unable to gracefully stop deployment for shutting down node %s because task does not exist",
527527
deploymentId,
528528
currentNode
529529
)
@@ -547,7 +547,7 @@ private void gracefullyStopDeployment(String deploymentId, String currentNode) {
547547
routingStateListener
548548
);
549549

550-
stopDeploymentAfterCompletingPendingWorkAsync(task, notifyDeploymentOfStopped);
550+
stopDeploymentAfterCompletingPendingWorkAsync(task, NODE_IS_SHUTTING_DOWN, notifyDeploymentOfStopped);
551551
}
552552

553553
private ActionListener<Void> updateRoutingStateToStoppedListener(
@@ -573,11 +573,18 @@ private void stopUnreferencedDeployment(String deploymentId, String currentNode)
573573
// This model is not routed to the current node at all
574574
TrainedModelDeploymentTask task = deploymentIdToTask.remove(deploymentId);
575575
if (task == null) {
576+
logger.debug(
577+
() -> format(
578+
"[%s] Unable to stop unreferenced deployment for node %s because task does not exist",
579+
deploymentId,
580+
currentNode
581+
)
582+
);
576583
return;
577584
}
578585

579586
logger.debug(() -> format("[%s] Stopping unreferenced deployment for node %s", deploymentId, currentNode));
580-
stopDeploymentAsync(
587+
stopDeploymentAfterCompletingPendingWorkAsync(
581588
task,
582589
NODE_NO_LONGER_REFERENCED,
583590
ActionListener.wrap(
@@ -614,8 +621,12 @@ private void stopDeploymentHelper(
614621
});
615622
}
616623

617-
private void stopDeploymentAfterCompletingPendingWorkAsync(TrainedModelDeploymentTask task, ActionListener<Void> listener) {
618-
stopDeploymentHelper(task, NODE_IS_SHUTTING_DOWN, deploymentManager::stopAfterCompletingPendingWork, listener);
624+
private void stopDeploymentAfterCompletingPendingWorkAsync(
625+
TrainedModelDeploymentTask task,
626+
String reason,
627+
ActionListener<Void> listener
628+
) {
629+
stopDeploymentHelper(task, reason, deploymentManager::stopAfterCompletingPendingWork, listener);
619630
}
620631

621632
private void updateNumberOfAllocations(TrainedModelAssignmentMetadata assignments) {

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -380,8 +380,7 @@ public void testClusterChangedWithResetMode() throws InterruptedException {
380380
verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
381381
}
382382

383-
public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_CallsStopAfterCompletingPendingWork()
384-
throws InterruptedException {
383+
public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNode_CallsStopAfterCompletingPendingWork() throws Exception {
385384
final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
386385
final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build();
387386
String modelOne = "model-1";
@@ -430,17 +429,19 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_CallsStop
430429
fail("Failed waiting for the stop process call to complete");
431430
}
432431

433-
verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(stopParamsCapture.capture());
434-
assertThat(stopParamsCapture.getValue().getModelId(), equalTo(modelOne));
435-
assertThat(stopParamsCapture.getValue().getDeploymentId(), equalTo(deploymentOne));
432+
assertBusy(() -> {
433+
verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(stopParamsCapture.capture());
434+
assertThat(stopParamsCapture.getValue().getModelId(), equalTo(modelOne));
435+
assertThat(stopParamsCapture.getValue().getDeploymentId(), equalTo(deploymentOne));
436+
});
436437
verify(trainedModelAssignmentService, times(1)).updateModelAssignmentState(
437438
any(UpdateTrainedModelAssignmentRoutingInfoAction.Request.class),
438439
any()
439440
);
440441
verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
441442
}
442443

443-
public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_ButOtherAllocationIsNotReady_DoesNotCallStop() {
444+
public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNode_ButOtherAllocationIsNotReady_DoesNotCallStop() {
444445
final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
445446
String node2 = "test-node-2";
446447
final DiscoveryNodes nodes = DiscoveryNodes.builder()
@@ -488,7 +489,7 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_ButOtherA
488489
verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
489490
}
490491

491-
public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNodeButAlreadyRemoved_DoesNotCallStop() {
492+
public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNodeButAlreadyRemoved_DoesNotCallStop() {
492493
final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
493494
final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build();
494495
String modelOne = "model-1";
@@ -529,7 +530,7 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNodeButAlready
529530
verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
530531
}
531532

532-
public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNodeWithStartingState_DoesNotStopTheDeployment() {
533+
public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNodeWithStartingState_DoesNotStopTheDeployment() {
533534
final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
534535
final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build();
535536
String modelOne = "model-1";
@@ -571,7 +572,46 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNodeWithStarti
571572
verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
572573
}
573574

574-
public void testClusterChanged_WhenAssigmentIsStopping_DoesNotAddModelToBeLoaded() throws InterruptedException {
575+
public void testClusterChanged_WhenNodeDoesNotExistInAssignmentRoutingTable_DoesGracefullyStopTheDeployment() throws Exception {
576+
final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
577+
final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build();
578+
String modelOne = "model-1";
579+
String deploymentOne = "deployment-1";
580+
581+
var taskParams = newParams(deploymentOne, modelOne);
582+
583+
ClusterChangedEvent event = new ClusterChangedEvent(
584+
"testClusterChanged",
585+
ClusterState.builder(new ClusterName("testClusterChanged"))
586+
.nodes(nodes)
587+
.metadata(
588+
Metadata.builder()
589+
.putCustom(
590+
TrainedModelAssignmentMetadata.NAME,
591+
TrainedModelAssignmentMetadata.Builder.empty()
592+
.addNewAssignment(deploymentOne, TrainedModelAssignment.Builder.empty(taskParams, null))
593+
.build()
594+
)
595+
.build()
596+
)
597+
.build(),
598+
ClusterState.EMPTY_STATE
599+
);
600+
601+
trainedModelAssignmentNodeService.prepareModelToLoad(taskParams);
602+
trainedModelAssignmentNodeService.clusterChanged(event);
603+
604+
assertBusy(() -> verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(any()));
605+
// This still shouldn't trigger a cluster state update because the routing entry wasn't in the table so we won't add a new routing
606+
// entry for stopping
607+
verify(trainedModelAssignmentService, never()).updateModelAssignmentState(
608+
any(UpdateTrainedModelAssignmentRoutingInfoAction.Request.class),
609+
any()
610+
);
611+
verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
612+
}
613+
614+
public void testClusterChanged_WhenAssignmentIsStopping_DoesNotAddModelToBeLoaded() throws InterruptedException {
575615
final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
576616
final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build();
577617
String modelOne = "model-1";
@@ -603,7 +643,6 @@ public void testClusterChanged_WhenAssigmentIsStopping_DoesNotAddModelToBeLoaded
603643
ClusterState.EMPTY_STATE
604644
);
605645

606-
// trainedModelAssignmentNodeService.prepareModelToLoad(taskParams);
607646
trainedModelAssignmentNodeService.clusterChanged(event);
608647
loadQueuedModels(trainedModelAssignmentNodeService);
609648

@@ -724,7 +763,9 @@ public void testClusterChanged() throws Exception {
724763

725764
assertBusy(() -> {
726765
ArgumentCaptor<TrainedModelDeploymentTask> stoppedTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class);
727-
verify(deploymentManager, times(1)).stopDeployment(stoppedTaskCapture.capture());
766+
// deployment-2 was originally started on node NODE_ID but in the latest cluster event it is no longer on that node so we will
767+
// gracefully stop it
768+
verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(stoppedTaskCapture.capture());
728769
assertThat(stoppedTaskCapture.getAllValues().get(0).getDeploymentId(), equalTo(deploymentTwo));
729770
});
730771
ArgumentCaptor<TrainedModelDeploymentTask> startTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class);

0 commit comments

Comments
 (0)