Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/134673.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 134673
summary: Gracefully shutdown model deployment when node is removed from assignment
routing
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ private void gracefullyStopDeployment(String deploymentId, String currentNode) {
if (task == null) {
logger.debug(
() -> format(
"[%s] Unable to gracefully stop deployment for shutting down node %s because task does not exit",
"[%s] Unable to gracefully stop deployment for shutting down node %s because task does not exist",
deploymentId,
currentNode
)
Expand All @@ -547,7 +547,7 @@ private void gracefullyStopDeployment(String deploymentId, String currentNode) {
routingStateListener
);

stopDeploymentAfterCompletingPendingWorkAsync(task, notifyDeploymentOfStopped);
stopDeploymentAfterCompletingPendingWorkAsync(task, NODE_IS_SHUTTING_DOWN, notifyDeploymentOfStopped);
}

private ActionListener<Void> updateRoutingStateToStoppedListener(
Expand All @@ -573,11 +573,18 @@ private void stopUnreferencedDeployment(String deploymentId, String currentNode)
// This model is not routed to the current node at all
TrainedModelDeploymentTask task = deploymentIdToTask.remove(deploymentId);
if (task == null) {
logger.debug(
() -> format(
"[%s] Unable to stop unreferenced deployment for node %s because task does not exist",
deploymentId,
currentNode
)
);
return;
}

logger.debug(() -> format("[%s] Stopping unreferenced deployment for node %s", deploymentId, currentNode));
stopDeploymentAsync(
stopDeploymentAfterCompletingPendingWorkAsync(
task,
NODE_NO_LONGER_REFERENCED,
ActionListener.wrap(
Expand Down Expand Up @@ -614,8 +621,12 @@ private void stopDeploymentHelper(
});
}

private void stopDeploymentAfterCompletingPendingWorkAsync(TrainedModelDeploymentTask task, ActionListener<Void> listener) {
stopDeploymentHelper(task, NODE_IS_SHUTTING_DOWN, deploymentManager::stopAfterCompletingPendingWork, listener);
private void stopDeploymentAfterCompletingPendingWorkAsync(
TrainedModelDeploymentTask task,
String reason,
ActionListener<Void> listener
) {
stopDeploymentHelper(task, reason, deploymentManager::stopAfterCompletingPendingWork, listener);
}

private void updateNumberOfAllocations(TrainedModelAssignmentMetadata assignments) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,7 @@ public void testClusterChangedWithResetMode() throws InterruptedException {
verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
}

public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_CallsStopAfterCompletingPendingWork()
throws InterruptedException {
public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNode_CallsStopAfterCompletingPendingWork() throws Exception {
final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build();
String modelOne = "model-1";
Expand Down Expand Up @@ -430,17 +429,19 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_CallsStop
fail("Failed waiting for the stop process call to complete");
}

verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(stopParamsCapture.capture());
assertThat(stopParamsCapture.getValue().getModelId(), equalTo(modelOne));
assertThat(stopParamsCapture.getValue().getDeploymentId(), equalTo(deploymentOne));
assertBusy(() -> {
verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(stopParamsCapture.capture());
assertThat(stopParamsCapture.getValue().getModelId(), equalTo(modelOne));
assertThat(stopParamsCapture.getValue().getDeploymentId(), equalTo(deploymentOne));
});
verify(trainedModelAssignmentService, times(1)).updateModelAssignmentState(
any(UpdateTrainedModelAssignmentRoutingInfoAction.Request.class),
any()
);
verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
}

public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_ButOtherAllocationIsNotReady_DoesNotCallStop() {
public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNode_ButOtherAllocationIsNotReady_DoesNotCallStop() {
final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
String node2 = "test-node-2";
final DiscoveryNodes nodes = DiscoveryNodes.builder()
Expand Down Expand Up @@ -488,7 +489,7 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_ButOtherA
verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
}

public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNodeButAlreadyRemoved_DoesNotCallStop() {
public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNodeButAlreadyRemoved_DoesNotCallStop() {
final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build();
String modelOne = "model-1";
Expand Down Expand Up @@ -529,7 +530,7 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNodeButAlready
verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
}

public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNodeWithStartingState_DoesNotStopTheDeployment() {
public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNodeWithStartingState_DoesNotStopTheDeployment() {
final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build();
String modelOne = "model-1";
Expand Down Expand Up @@ -571,7 +572,46 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNodeWithStarti
verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
}

public void testClusterChanged_WhenAssigmentIsStopping_DoesNotAddModelToBeLoaded() throws InterruptedException {
public void testClusterChanged_WhenNodeDoesNotExistInAssignmentRoutingTable_DoesGracefullyStopTheDeployment() throws Exception {
final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build();
String modelOne = "model-1";
String deploymentOne = "deployment-1";

var taskParams = newParams(deploymentOne, modelOne);

ClusterChangedEvent event = new ClusterChangedEvent(
"testClusterChanged",
ClusterState.builder(new ClusterName("testClusterChanged"))
.nodes(nodes)
.metadata(
Metadata.builder()
.putCustom(
TrainedModelAssignmentMetadata.NAME,
TrainedModelAssignmentMetadata.Builder.empty()
.addNewAssignment(deploymentOne, TrainedModelAssignment.Builder.empty(taskParams, null))
.build()
)
.build()
)
.build(),
ClusterState.EMPTY_STATE
);

trainedModelAssignmentNodeService.prepareModelToLoad(taskParams);
trainedModelAssignmentNodeService.clusterChanged(event);

assertBusy(() -> verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(any()));
// This still shouldn't trigger a cluster state update because the routing entry wasn't in the table so we won't add a new routing
// entry for stopping
verify(trainedModelAssignmentService, never()).updateModelAssignmentState(
any(UpdateTrainedModelAssignmentRoutingInfoAction.Request.class),
any()
);
verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
}

public void testClusterChanged_WhenAssignmentIsStopping_DoesNotAddModelToBeLoaded() throws InterruptedException {
final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build();
String modelOne = "model-1";
Expand Down Expand Up @@ -603,7 +643,6 @@ public void testClusterChanged_WhenAssigmentIsStopping_DoesNotAddModelToBeLoaded
ClusterState.EMPTY_STATE
);

// trainedModelAssignmentNodeService.prepareModelToLoad(taskParams);
trainedModelAssignmentNodeService.clusterChanged(event);
loadQueuedModels(trainedModelAssignmentNodeService);

Expand Down Expand Up @@ -724,7 +763,9 @@ public void testClusterChanged() throws Exception {

assertBusy(() -> {
ArgumentCaptor<TrainedModelDeploymentTask> stoppedTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class);
verify(deploymentManager, times(1)).stopDeployment(stoppedTaskCapture.capture());
// deployment-2 was originally started on node NODE_ID but in the latest cluster event it is no longer on that node so we will
// gracefully stop it
verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(stoppedTaskCapture.capture());
assertThat(stoppedTaskCapture.getAllValues().get(0).getDeploymentId(), equalTo(deploymentTwo));
});
ArgumentCaptor<TrainedModelDeploymentTask> startTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class);
Expand Down