Skip to content

Commit d878fbd

Browse files
authored
[backport 2.x] Fix model still deployed after calling undeploy API (#2510) (#2531)
* Fix model still deployed after calling undeploy API (#2510)

* Fix model still deployed after calling undeploy API

Signed-off-by: Sicheng Song <[email protected]>

* Add UT coverage

Signed-off-by: Sicheng Song <[email protected]>

* Fix style

Signed-off-by: Sicheng Song <[email protected]>

* Add UT coverage

Signed-off-by: Sicheng Song <[email protected]>

* Add UT coverage

Signed-off-by: Sicheng Song <[email protected]>

---------

Signed-off-by: Sicheng Song <[email protected]>

* Fix IT

Signed-off-by: Sicheng Song <[email protected]>

---------

Signed-off-by: Sicheng Song <[email protected]>
1 parent 40c5edc commit d878fbd

File tree

3 files changed

+438
-122
lines changed

3 files changed

+438
-122
lines changed

plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java

Lines changed: 108 additions & 93 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,6 @@
2929
import org.opensearch.common.util.concurrent.ThreadContext;
3030
import org.opensearch.core.action.ActionListener;
3131
import org.opensearch.core.common.io.stream.StreamInput;
32-
import org.opensearch.core.xcontent.NamedXContentRegistry;
3332
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
3433
import org.opensearch.ml.common.FunctionName;
3534
import org.opensearch.ml.common.MLModel;
@@ -42,10 +41,10 @@
4241
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodeResponse;
4342
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesRequest;
4443
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse;
45-
import org.opensearch.ml.helper.ModelAccessControlHelper;
4644
import org.opensearch.ml.model.MLModelManager;
4745
import org.opensearch.ml.stats.MLNodeLevelStat;
4846
import org.opensearch.ml.stats.MLStats;
47+
import org.opensearch.tasks.Task;
4948
import org.opensearch.threadpool.ThreadPool;
5049
import org.opensearch.transport.TransportService;
5150

@@ -59,11 +58,8 @@ public class TransportUndeployModelAction extends
5958
private final MLModelManager mlModelManager;
6059
private final ClusterService clusterService;
6160
private final Client client;
62-
private DiscoveryNodeHelper nodeFilter;
61+
private final DiscoveryNodeHelper nodeFilter;
6362
private final MLStats mlStats;
64-
private NamedXContentRegistry xContentRegistry;
65-
66-
private ModelAccessControlHelper modelAccessControlHelper;
6763

6864
@Inject
6965
public TransportUndeployModelAction(
@@ -74,9 +70,7 @@ public TransportUndeployModelAction(
7470
ThreadPool threadPool,
7571
Client client,
7672
DiscoveryNodeHelper nodeFilter,
77-
MLStats mlStats,
78-
NamedXContentRegistry xContentRegistry,
79-
ModelAccessControlHelper modelAccessControlHelper
73+
MLStats mlStats
8074
) {
8175
super(
8276
MLUndeployModelAction.NAME,
@@ -90,107 +84,128 @@ public TransportUndeployModelAction(
9084
MLUndeployModelNodeResponse.class
9185
);
9286
this.mlModelManager = mlModelManager;
87+
9388
this.clusterService = clusterService;
9489
this.client = client;
9590
this.nodeFilter = nodeFilter;
9691
this.mlStats = mlStats;
97-
this.xContentRegistry = xContentRegistry;
98-
this.modelAccessControlHelper = modelAccessControlHelper;
9992
}
10093

10194
@Override
102-
protected MLUndeployModelNodesResponse newResponse(
103-
MLUndeployModelNodesRequest nodesRequest,
104-
List<MLUndeployModelNodeResponse> responses,
105-
List<FailedNodeException> failures
95+
protected void doExecute(Task task, MLUndeployModelNodesRequest request, ActionListener<MLUndeployModelNodesResponse> listener) {
96+
ActionListener<MLUndeployModelNodesResponse> wrappedListener = ActionListener.wrap(undeployModelNodesResponse -> {
97+
processUndeployModelResponseAndUpdate(undeployModelNodesResponse, listener);
98+
}, listener::onFailure);
99+
super.doExecute(task, request, wrappedListener);
100+
}
101+
102+
void processUndeployModelResponseAndUpdate(
103+
MLUndeployModelNodesResponse undeployModelNodesResponse,
104+
ActionListener<MLUndeployModelNodesResponse> listener
106105
) {
107-
if (responses != null) {
108-
Map<String, List<String>> actualRemovedNodesMap = new HashMap<>();
109-
Map<String, String[]> modelWorkNodesBeforeRemoval = new HashMap<>();
110-
responses.forEach(r -> {
111-
Map<String, String[]> nodeCounts = r.getModelWorkerNodeBeforeRemoval();
112-
113-
if (nodeCounts != null) {
114-
for (Map.Entry<String, String[]> entry : nodeCounts.entrySet()) {
115-
// when undeploy a undeployed model, the entry.getvalue() is null
116-
if (entry.getValue() != null
117-
&& (!modelWorkNodesBeforeRemoval.containsKey(entry.getKey())
118-
|| modelWorkNodesBeforeRemoval.get(entry.getKey()).length < entry.getValue().length)) {
119-
modelWorkNodesBeforeRemoval.put(entry.getKey(), entry.getValue());
120-
}
106+
List<MLUndeployModelNodeResponse> responses = undeployModelNodesResponse.getNodes();
107+
if (responses == null || responses.isEmpty()) {
108+
listener.onResponse(undeployModelNodesResponse);
109+
return;
110+
}
111+
112+
Map<String, List<String>> actualRemovedNodesMap = new HashMap<>();
113+
Map<String, String[]> modelWorkNodesBeforeRemoval = new HashMap<>();
114+
responses.forEach(r -> {
115+
Map<String, String[]> nodeCounts = r.getModelWorkerNodeBeforeRemoval();
116+
117+
if (nodeCounts != null) {
118+
for (Map.Entry<String, String[]> entry : nodeCounts.entrySet()) {
119+
// when undeploy an undeployed model, the entry.getvalue() is null
120+
if (entry.getValue() != null
121+
&& (!modelWorkNodesBeforeRemoval.containsKey(entry.getKey())
122+
|| modelWorkNodesBeforeRemoval.get(entry.getKey()).length < entry.getValue().length)) {
123+
modelWorkNodesBeforeRemoval.put(entry.getKey(), entry.getValue());
121124
}
122125
}
126+
}
123127

124-
Map<String, String> modelUndeployStatus = r.getModelUndeployStatus();
125-
for (Map.Entry<String, String> entry : modelUndeployStatus.entrySet()) {
126-
String status = entry.getValue();
127-
if (UNDEPLOYED.equals(status)) {
128-
String modelId = entry.getKey();
129-
if (!actualRemovedNodesMap.containsKey(modelId)) {
130-
actualRemovedNodesMap.put(modelId, new ArrayList<>());
131-
}
132-
actualRemovedNodesMap.get(modelId).add(r.getNode().getId());
128+
Map<String, String> modelUndeployStatus = r.getModelUndeployStatus();
129+
for (Map.Entry<String, String> entry : modelUndeployStatus.entrySet()) {
130+
String status = entry.getValue();
131+
if (UNDEPLOYED.equals(status)) {
132+
String modelId = entry.getKey();
133+
if (!actualRemovedNodesMap.containsKey(modelId)) {
134+
actualRemovedNodesMap.put(modelId, new ArrayList<>());
133135
}
136+
actualRemovedNodesMap.get(modelId).add(r.getNode().getId());
134137
}
135-
});
136-
137-
MLSyncUpInput syncUpInput = MLSyncUpInput
138-
.builder()
139-
.removedWorkerNodes(covertRemoveNodesMapForSyncUp(actualRemovedNodesMap))
140-
.build();
141-
142-
MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest(nodeFilter.getAllNodes(), syncUpInput);
143-
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
144-
if (actualRemovedNodesMap.size() > 0) {
145-
BulkRequest bulkRequest = new BulkRequest();
146-
Map<String, Boolean> deployToAllNodes = new HashMap<>();
147-
for (String modelId : actualRemovedNodesMap.keySet()) {
148-
UpdateRequest updateRequest = new UpdateRequest();
149-
List<String> removedNodes = actualRemovedNodesMap.get(modelId);
150-
int removedNodeCount = removedNodes.size();
151-
/**
152-
* If allow custom deploy is false, user can only undeploy all nodes and status is undeployed.
153-
* If allow custom deploy is true, user can undeploy all nodes and status is undeployed,
154-
* or undeploy partial nodes, and status is deployed, this case means user created a new deployment plan, and
155-
* we need to update both planning worker nodes (count) and current worker nodes (count)
156-
* and deployToAllNodes value in model index.
157-
*/
158-
Map<String, Object> updateDocument = new HashMap<>();
159-
if (modelWorkNodesBeforeRemoval.get(modelId).length == removedNodeCount) { // undeploy all nodes.
160-
updateDocument.put(MLModel.PLANNING_WORKER_NODES_FIELD, ImmutableList.of());
161-
updateDocument.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, 0);
162-
updateDocument.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, 0);
163-
updateDocument.put(MLModel.MODEL_STATE_FIELD, MLModelState.UNDEPLOYED);
164-
} else { // undeploy partial nodes.
165-
// TODO (to fix) when undeploy partial nodes, the original model status could be partially_deployed,
166-
// and the user could be undeploying not running model nodes, and we should update model status to deployed.
167-
updateDocument.put(MLModel.DEPLOY_TO_ALL_NODES_FIELD, false);
168-
List<String> newPlanningWorkerNodes = Arrays
169-
.stream(modelWorkNodesBeforeRemoval.get(modelId))
170-
.filter(x -> !removedNodes.contains(x))
171-
.collect(Collectors.toList());
172-
updateDocument.put(MLModel.PLANNING_WORKER_NODES_FIELD, newPlanningWorkerNodes);
173-
updateDocument.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, newPlanningWorkerNodes.size());
174-
updateDocument.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, newPlanningWorkerNodes.size());
175-
deployToAllNodes.put(modelId, false);
176-
}
177-
updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(updateDocument);
178-
bulkRequest.add(updateRequest).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
138+
}
139+
});
140+
141+
MLSyncUpInput syncUpInput = MLSyncUpInput
142+
.builder()
143+
.removedWorkerNodes(covertRemoveNodesMapForSyncUp(actualRemovedNodesMap))
144+
.build();
145+
146+
MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest(nodeFilter.getAllNodes(), syncUpInput);
147+
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
148+
if (actualRemovedNodesMap.size() > 0) {
149+
BulkRequest bulkRequest = new BulkRequest();
150+
Map<String, Boolean> deployToAllNodes = new HashMap<>();
151+
for (String modelId : actualRemovedNodesMap.keySet()) {
152+
UpdateRequest updateRequest = new UpdateRequest();
153+
List<String> removedNodes = actualRemovedNodesMap.get(modelId);
154+
int removedNodeCount = removedNodes.size();
155+
/**
156+
* If allow custom deploy is false, user can only undeploy all nodes and status is undeployed.
157+
* If allow custom deploy is true, user can undeploy all nodes and status is undeployed,
158+
* or undeploy partial nodes, and status is deployed, this case means user created a new deployment plan, and
159+
* we need to update both planning worker nodes (count) and current worker nodes (count)
160+
* and deployToAllNodes value in model index.
161+
*/
162+
Map<String, Object> updateDocument = new HashMap<>();
163+
if (modelWorkNodesBeforeRemoval.get(modelId).length == removedNodeCount) { // undeploy all nodes.
164+
updateDocument.put(MLModel.PLANNING_WORKER_NODES_FIELD, ImmutableList.of());
165+
updateDocument.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, 0);
166+
updateDocument.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, 0);
167+
updateDocument.put(MLModel.MODEL_STATE_FIELD, MLModelState.UNDEPLOYED);
168+
} else { // undeploy partial nodes.
169+
// TODO (to fix) when undeploy partial nodes, the original model status could be partially_deployed,
170+
// and the user could be undeploying not running model nodes, and we should update model status to deployed.
171+
updateDocument.put(MLModel.DEPLOY_TO_ALL_NODES_FIELD, false);
172+
List<String> newPlanningWorkerNodes = Arrays
173+
.stream(modelWorkNodesBeforeRemoval.get(modelId))
174+
.filter(x -> !removedNodes.contains(x))
175+
.collect(Collectors.toList());
176+
updateDocument.put(MLModel.PLANNING_WORKER_NODES_FIELD, newPlanningWorkerNodes);
177+
updateDocument.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, newPlanningWorkerNodes.size());
178+
updateDocument.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, newPlanningWorkerNodes.size());
179+
deployToAllNodes.put(modelId, false);
179180
}
180-
syncUpInput.setDeployToAllNodes(deployToAllNodes);
181-
ActionListener<BulkResponse> actionListener = ActionListener.wrap(r -> {
182-
log
183-
.debug(
184-
"updated model state as undeployed for : {}",
185-
Arrays.toString(actualRemovedNodesMap.keySet().toArray(new String[0]))
186-
);
187-
}, e -> { log.error("Failed to update model state as undeployed", e); });
188-
client.bulk(bulkRequest, ActionListener.runAfter(actionListener, () -> { syncUpUndeployedModels(syncUpRequest); }));
189-
} else {
190-
syncUpUndeployedModels(syncUpRequest);
181+
updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(updateDocument);
182+
bulkRequest.add(updateRequest).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
191183
}
184+
syncUpInput.setDeployToAllNodes(deployToAllNodes);
185+
ActionListener<BulkResponse> actionListener = ActionListener.wrap(r -> {
186+
log
187+
.debug(
188+
"updated model state as undeployed for : {}",
189+
Arrays.toString(actualRemovedNodesMap.keySet().toArray(new String[0]))
190+
);
191+
}, e -> { log.error("Failed to update model state as undeployed", e); });
192+
client.bulk(bulkRequest, ActionListener.runAfter(actionListener, () -> {
193+
syncUpUndeployedModels(syncUpRequest);
194+
listener.onResponse(undeployModelNodesResponse);
195+
}));
196+
} else {
197+
syncUpUndeployedModels(syncUpRequest);
198+
listener.onResponse(undeployModelNodesResponse);
192199
}
193200
}
201+
}
202+
203+
@Override
204+
protected MLUndeployModelNodesResponse newResponse(
205+
MLUndeployModelNodesRequest nodesRequest,
206+
List<MLUndeployModelNodeResponse> responses,
207+
List<FailedNodeException> failures
208+
) {
194209
return new MLUndeployModelNodesResponse(clusterService.getClusterName(), responses, failures);
195210
}
196211

0 commit comments

Comments
 (0)