Skip to content

Commit c25d116

Browse files
Add model auto redeploy feature (#852) (#859)
* Add model auto reload feature * Add model auto redeploy feature * Add model auto redeploy feature * Add model auto redeploy feature * Add model auto redeploy feature * Add model auto redeploy feature * Add model auto redeploy feature * Add model auto redeploy feature * Add model auto redeploy feature * Add model auto redeploy feature * Addressed comments and reverted the getEligibleNodes logic * Change method regarding clear model retry times in model index and addressed naming convention comments * Fix import * issue * Rebase upstream code and fix UT failure issue --------- (cherry picked from commit 1e6b6c4) Signed-off-by: Zan Niu <[email protected]> Signed-off-by: Yaliang Wu <[email protected]> Co-authored-by: zane-neo <[email protected]>
1 parent 3e33f0d commit c25d116

34 files changed

+1628
-196
lines changed

common/src/main/java/org/opensearch/ml/common/CommonValue.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,9 @@ public class CommonValue {
110110
+ MLModel.MODEL_CONTENT_HASH_VALUE_FIELD
111111
+ "\" : {\"type\": \"keyword\"},\n"
112112
+ " \""
113+
+ MLModel.AUTO_REDEPLOY_RETRY_TIMES_FIELD
114+
+ "\" : {\"type\": \"integer\"},\n"
115+
+ " \""
113116
+ MLModel.CREATED_TIME_FIELD
114117
+ "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n"
115118
+ " \""

common/src/main/java/org/opensearch/ml/common/MLModel.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ public class MLModel implements ToXContentObject {
4646
public static final String MODEL_CONTENT_SIZE_IN_BYTES_FIELD = "model_content_size_in_bytes";
4747
//SHA256 hash value of model content.
4848
public static final String MODEL_CONTENT_HASH_VALUE_FIELD = "model_content_hash_value";
49+
4950
public static final String MODEL_CONFIG_FIELD = "model_config";
5051
public static final String CREATED_TIME_FIELD = "created_time";
5152
public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time";
@@ -60,6 +61,8 @@ public class MLModel implements ToXContentObject {
6061
public static final String LAST_UNDEPLOYED_TIME_FIELD = "last_undeployed_time";
6162

6263
public static final String MODEL_ID_FIELD = "model_id";
64+
// auto redeploy retry times for this model.
65+
public static final String AUTO_REDEPLOY_RETRY_TIMES_FIELD = "auto_redeploy_retry_times";
6366
public static final String CHUNK_NUMBER_FIELD = "chunk_number";
6467
public static final String TOTAL_CHUNKS_FIELD = "total_chunks";
6568
public static final String PLANNING_WORKER_NODE_COUNT_FIELD = "planning_worker_node_count";
@@ -88,6 +91,8 @@ public class MLModel implements ToXContentObject {
8891

8992
@Setter
9093
private String modelId; // model chunk doc only
94+
95+
private Integer autoRedeployRetryTimes;
9196
private Integer chunkNumber; // model chunk doc only
9297
private Integer totalChunks; // model chunk doc only
9398
private Integer planningWorkerNodeCount; // plan to deploy model to how many nodes
@@ -112,6 +117,7 @@ public MLModel(String name,
112117
Instant lastRegisteredTime,
113118
Instant lastDeployedTime,
114119
Instant lastUndeployedTime,
120+
Integer autoRedeployRetryTimes,
115121
String modelId, Integer chunkNumber,
116122
Integer totalChunks,
117123
Integer planningWorkerNodeCount,
@@ -135,6 +141,7 @@ public MLModel(String name,
135141
this.lastDeployedTime = lastDeployedTime;
136142
this.lastUndeployedTime = lastUndeployedTime;
137143
this.modelId = modelId;
144+
this.autoRedeployRetryTimes = autoRedeployRetryTimes;
138145
this.chunkNumber = chunkNumber;
139146
this.totalChunks = totalChunks;
140147
this.planningWorkerNodeCount = planningWorkerNodeCount;
@@ -172,6 +179,7 @@ public MLModel(StreamInput input) throws IOException{
172179
lastDeployedTime = input.readOptionalInstant();
173180
lastUndeployedTime = input.readOptionalInstant();
174181
modelId = input.readOptionalString();
182+
autoRedeployRetryTimes = input.readOptionalInt();
175183
chunkNumber = input.readOptionalInt();
176184
totalChunks = input.readOptionalInt();
177185
planningWorkerNodeCount = input.readOptionalInt();
@@ -219,6 +227,7 @@ public void writeTo(StreamOutput out) throws IOException {
219227
out.writeOptionalInstant(lastDeployedTime);
220228
out.writeOptionalInstant(lastUndeployedTime);
221229
out.writeOptionalString(modelId);
230+
out.writeOptionalInt(autoRedeployRetryTimes);
222231
out.writeOptionalInt(chunkNumber);
223232
out.writeOptionalInt(totalChunks);
224233
out.writeOptionalInt(planningWorkerNodeCount);
@@ -281,6 +290,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
281290
if (modelId != null) {
282291
builder.field(MODEL_ID_FIELD, modelId);
283292
}
293+
if (autoRedeployRetryTimes != null) {
294+
builder.field(AUTO_REDEPLOY_RETRY_TIMES_FIELD, autoRedeployRetryTimes);
295+
}
284296
if (chunkNumber != null) {
285297
builder.field(CHUNK_NUMBER_FIELD, chunkNumber);
286298
}
@@ -327,6 +339,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
327339
Instant lastDeployedTime = null;
328340
Instant lastUndeployedTime = null;
329341
String modelId = null;
342+
Integer autoRedeployRetryTimes = null;
330343
Integer chunkNumber = null;
331344
Integer totalChunks = null;
332345
Integer planningWorkerNodeCount = null;
@@ -370,6 +383,9 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
370383
case MODEL_ID_FIELD:
371384
modelId = parser.text();
372385
break;
386+
case AUTO_REDEPLOY_RETRY_TIMES_FIELD:
387+
autoRedeployRetryTimes = parser.intValue();
388+
break;
373389
case DESCRIPTION_FIELD:
374390
description = parser.text();
375391
break;
@@ -454,6 +470,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
454470
.lastDeployedTime(lastDeployedTime == null? lastLoadedTime : lastDeployedTime)
455471
.lastUndeployedTime(lastUndeployedTime == null? lastUnloadedTime : lastUndeployedTime)
456472
.modelId(modelId)
473+
.autoRedeployRetryTimes(autoRedeployRetryTimes)
457474
.chunkNumber(chunkNumber)
458475
.totalChunks(totalChunks)
459476
.planningWorkerNodeCount(planningWorkerNodeCount)

common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelInput.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ public class MLDeployModelInput implements Writeable {
2121
private String modelContentHash;
2222
private Integer nodeCount;
2323
private String coordinatingNodeId;
24+
private Boolean isDeployToAllNodes;
2425
private MLTask mlTask;
2526

2627
public MLDeployModelInput(StreamInput in) throws IOException {
@@ -29,16 +30,18 @@ public MLDeployModelInput(StreamInput in) throws IOException {
2930
this.modelContentHash = in.readOptionalString();
3031
this.nodeCount = in.readInt();
3132
this.coordinatingNodeId = in.readString();
33+
this.isDeployToAllNodes = in.readOptionalBoolean();
3234
this.mlTask = new MLTask(in);
3335
}
3436

3537
@Builder
36-
public MLDeployModelInput(String modelId, String taskId, String modelContentHash, Integer nodeCount, String coordinatingNodeId, MLTask mlTask) {
38+
public MLDeployModelInput(String modelId, String taskId, String modelContentHash, Integer nodeCount, String coordinatingNodeId, Boolean isDeployToAllNodes, MLTask mlTask) {
3739
this.modelId = modelId;
3840
this.taskId = taskId;
3941
this.modelContentHash = modelContentHash;
4042
this.nodeCount = nodeCount;
4143
this.coordinatingNodeId = coordinatingNodeId;
44+
this.isDeployToAllNodes = isDeployToAllNodes;
4245
this.mlTask = mlTask;
4346
}
4447

@@ -52,6 +55,7 @@ public void writeTo(StreamOutput out) throws IOException {
5255
out.writeOptionalString(modelContentHash);
5356
out.writeInt(nodeCount);
5457
out.writeString(coordinatingNodeId);
58+
out.writeOptionalBoolean(isDeployToAllNodes);
5559
mlTask.writeTo(out);
5660
}
5761

common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpInput.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,19 +31,25 @@ public class MLSyncUpInput implements Writeable {
3131
// sync running deploy model tasks
3232
private boolean syncRunningDeployModelTasks;
3333

34+
// deployToAll flag for models: when deploying/undeploying a model, this will be passed to each node to update the cached value, to make sure
35+
// the profile API has data consistent with the model index.
36+
private Map<String, Boolean> deployToAllNodes;
37+
3438
@Builder
3539
public MLSyncUpInput(boolean getDeployedModels,
3640
Map<String, String[]> addedWorkerNodes,
3741
Map<String, String[]> removedWorkerNodes,
3842
Map<String, Set<String>> modelRoutingTable,
3943
Map<String, Set<String>> runningDeployModelTasks,
44+
Map<String, Boolean> deployToAllNodes,
4045
boolean clearRoutingTable,
4146
boolean syncRunningDeployModelTasks) {
4247
this.getDeployedModels = getDeployedModels;
4348
this.addedWorkerNodes = addedWorkerNodes;
4449
this.removedWorkerNodes = removedWorkerNodes;
4550
this.modelRoutingTable = modelRoutingTable;
4651
this.runningDeployModelTasks = runningDeployModelTasks;
52+
this.deployToAllNodes = deployToAllNodes;
4753
this.clearRoutingTable = clearRoutingTable;
4854
this.syncRunningDeployModelTasks = syncRunningDeployModelTasks;
4955
}
@@ -64,6 +70,9 @@ public MLSyncUpInput(StreamInput in) throws IOException {
6470
if (in.readBoolean()) {
6571
runningDeployModelTasks = in.readMap(StreamInput::readString, s -> s.readSet(StreamInput::readString));
6672
}
73+
if (in.readBoolean()) {
74+
deployToAllNodes = in.readMap(StreamInput::readString, StreamInput::readOptionalBoolean);
75+
}
6776
this.clearRoutingTable = in.readBoolean();
6877
this.syncRunningDeployModelTasks = in.readBoolean();
6978
}
@@ -95,6 +104,12 @@ public void writeTo(StreamOutput out) throws IOException {
95104
} else {
96105
out.writeBoolean(false);
97106
}
107+
if (deployToAllNodes != null && deployToAllNodes.size() > 0) {
108+
out.writeBoolean(true);
109+
out.writeMap(deployToAllNodes, StreamOutput::writeString, StreamOutput::writeOptionalBoolean);
110+
} else {
111+
out.writeBoolean(false);
112+
}
98113
out.writeBoolean(clearRoutingTable);
99114
out.writeBoolean(syncRunningDeployModelTasks);
100115
}

common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodeResponse.java

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,20 @@
2020
@Getter
2121
public class MLUndeployModelNodeResponse extends BaseNodeResponse implements ToXContentFragment {
2222

23+
// Model undeploy status: if the model is running and is successfully undeployed, the status is undeployed; if the model is not
24+
// running on the current node, the status is not_found.
2325
private Map<String, String> modelUndeployStatus;
24-
private Map<String, Integer> modelWorkerNodeCounts;
26+
// Records which nodes were worker nodes before the model was undeployed.
27+
private Map<String, String[]> modelWorkerNodeBeforeRemoval;
2528

2629
public MLUndeployModelNodeResponse(DiscoveryNode node,
2730
Map<String, String> modelUndeployStatus,
28-
Map<String, Integer> modelWorkerNodeCounts) {
31+
Map<String, String[]> modelWorkerNodeBeforeRemoval
32+
) {
2933
super(node);
3034
this.modelUndeployStatus = modelUndeployStatus;
31-
this.modelWorkerNodeCounts = modelWorkerNodeCounts;
35+
//
36+
this.modelWorkerNodeBeforeRemoval = modelWorkerNodeBeforeRemoval;
3237
}
3338

3439
public MLUndeployModelNodeResponse(StreamInput in) throws IOException {
@@ -37,7 +42,7 @@ public MLUndeployModelNodeResponse(StreamInput in) throws IOException {
3742
this.modelUndeployStatus = in.readMap(s -> s.readString(), s-> s.readString());
3843
}
3944
if (in.readBoolean()) {
40-
this.modelWorkerNodeCounts = in.readMap(s -> s.readString(), s-> s.readInt());
45+
this.modelWorkerNodeBeforeRemoval = in.readMap(s -> s.readString(), s-> s.readOptionalStringArray());
4146
}
4247
}
4348

@@ -55,9 +60,9 @@ public void writeTo(StreamOutput out) throws IOException {
5560
} else {
5661
out.writeBoolean(false);
5762
}
58-
if (modelWorkerNodeCounts != null) {
63+
if (modelWorkerNodeBeforeRemoval != null) {
5964
out.writeBoolean(true);
60-
out.writeMap(modelWorkerNodeCounts, StreamOutput::writeString, StreamOutput::writeInt);
65+
out.writeMap(modelWorkerNodeBeforeRemoval, StreamOutput::writeString, StreamOutput::writeOptionalStringArray);
6166
} else {
6267
out.writeBoolean(false);
6368
}

common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesRequestTest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ public void setUp() throws Exception {
8484
@Test
8585
public void testConstructorSerialization1() throws IOException {
8686
String [] nodeIds = {"id1", "id2", "id3"};
87-
MLDeployModelInput deployModelInput = new MLDeployModelInput("modelId", "taskId", "modelContentHash", 3, "coordinatingNodeId", mlTask);
87+
MLDeployModelInput deployModelInput = new MLDeployModelInput("modelId", "taskId", "modelContentHash", 3, "coordinatingNodeId", true, mlTask);
8888
MLDeployModelNodeRequest MLDeployModelNodeRequest = new MLDeployModelNodeRequest(
8989
new MLDeployModelNodesRequest(nodeIds, deployModelInput)
9090
);
@@ -104,7 +104,7 @@ public void testConstructorSerialization1() throws IOException {
104104
@Test
105105
public void testConstructorSerialization2() throws IOException {
106106
DiscoveryNode [] nodeIds = {localNode1, localNode2, localNode3};
107-
MLDeployModelInput deployModelInput = new MLDeployModelInput("modelId", "taskId", "modelContentHash", 3, "coordinatingNodeId", mlTask);
107+
MLDeployModelInput deployModelInput = new MLDeployModelInput("modelId", "taskId", "modelContentHash", 3, "coordinatingNodeId", true, mlTask);
108108
MLDeployModelNodeRequest MLDeployModelNodeRequest = new MLDeployModelNodeRequest(
109109
new MLDeployModelNodesRequest(nodeIds, deployModelInput)
110110
);
@@ -140,7 +140,7 @@ public void testConstructorSerialization3() throws IOException {
140140
@Test
141141
public void testConstructorFromInputStream() throws IOException {
142142
String [] nodeIds = {"id1", "id2", "id3"};
143-
MLDeployModelInput deployModelInput = new MLDeployModelInput("modelId", "taskId", "modelContentHash", 3, "coordinatingNodeId", mlTask);
143+
MLDeployModelInput deployModelInput = new MLDeployModelInput("modelId", "taskId", "modelContentHash", 3, "coordinatingNodeId", true, mlTask);
144144
MLDeployModelNodeRequest MLDeployModelNodeRequest = new MLDeployModelNodeRequest(
145145
new MLDeployModelNodesRequest(nodeIds, deployModelInput)
146146
);

common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodeResponseTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ public class MLUndeployModelNodeResponseTest {
2525
@Mock
2626
private DiscoveryNode localNode;
2727

28-
private Map<String, Integer> modelWorkerNodeCounts;
28+
private Map<String, String[]> modelWorkerNodeCounts;
2929

3030
@Before
3131
public void setUp() throws Exception {
@@ -38,7 +38,7 @@ public void setUp() throws Exception {
3838
Version.CURRENT
3939
);
4040
modelWorkerNodeCounts = new HashMap<>();
41-
modelWorkerNodeCounts.put("modelId1", 1);
41+
modelWorkerNodeCounts.put("modelId1", new String[]{"node"});
4242
}
4343

4444
@Test

common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesResponseTest.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,14 +77,14 @@ public void testToXContent() throws IOException {
7777

7878
Map<String, String> modelToUndeployStatus1 = new HashMap<>();
7979
modelToUndeployStatus1.put("modelId1", "response");
80-
Map<String, Integer> modelWorkerNodeCounts1 = new HashMap<>();
81-
modelWorkerNodeCounts1.put("modelId1", 1);
80+
Map<String, String[]> modelWorkerNodeCounts1 = new HashMap<>();
81+
modelWorkerNodeCounts1.put("modelId1", new String[]{"mockNode1"});
8282
nodes.add(new MLUndeployModelNodeResponse(node1, modelToUndeployStatus1, modelWorkerNodeCounts1));
8383

8484
Map<String, String> modelToUndeployStatus2 = new HashMap<>();
8585
modelToUndeployStatus2.put("modelId2", "response");
86-
Map<String, Integer> modelWorkerNodeCounts2 = new HashMap<>();
87-
modelWorkerNodeCounts2.put("modelId2", 2);
86+
Map<String, String[]> modelWorkerNodeCounts2 = new HashMap<>();
87+
modelWorkerNodeCounts2.put("modelId2", new String[]{"mockNode2"});
8888
nodes.add(new MLUndeployModelNodeResponse(node2, modelToUndeployStatus2, modelWorkerNodeCounts2));
8989

9090
List<FailedNodeException> failures = new ArrayList<>();

plugin/build.gradle

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,8 @@ List<String> jacocoExclusions = [
277277
'org.opensearch.ml.action.profile.MLProfileTransportAction',
278278
'org.opensearch.ml.action.models.DeleteModelTransportAction.1',
279279
'org.opensearch.ml.rest.RestMLPredictionAction',
280-
'org.opensearch.ml.breaker.DiskCircuitBreaker'
280+
'org.opensearch.ml.breaker.DiskCircuitBreaker',
281+
'org.opensearch.ml.autoredeploy.MLModelAutoReDeployer.SearchRequestBuilderFactory'
281282
]
282283

283284
jacocoTestCoverageVerification {

plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@ void updateModelDeployStatusAndTriggerOnNodesAction(
248248
mlModel.getModelContentHash(),
249249
eligibleNodes.size(),
250250
localNodeId,
251+
deployToAllNodes,
251252
mlTask
252253
);
253254
MLDeployModelNodesRequest deployModelRequest = new MLDeployModelNodesRequest(

0 commit comments

Comments
 (0)