Skip to content

Commit ef435c9

Browse files
Add TTL to un-deploy model automatically (#2365) (#2374)
* add ttl to un-deploy model automatically
  Signed-off-by: Xun Zhang <[email protected]>
* undeploy only for models that expired in all nodes
  Signed-off-by: Xun Zhang <[email protected]>
* add bwc version for model ttl
  Signed-off-by: Xun Zhang <[email protected]>
* only use minutes in ttl
  Signed-off-by: Xun Zhang <[email protected]>
* move MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL to MLDeploySetting
  Signed-off-by: Xun Zhang <[email protected]>
---------
Signed-off-by: Xun Zhang <[email protected]>
(cherry picked from commit e380395)
Co-authored-by: Xun Zhang <[email protected]>
1 parent 7e76fa9 commit ef435c9

File tree

13 files changed

+153
-13
lines changed

13 files changed

+153
-13
lines changed

common/src/main/java/org/opensearch/ml/common/model/MLDeploySetting.java

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import lombok.Builder;
99
import lombok.Getter;
1010
import lombok.Setter;
11+
import org.opensearch.Version;
1112
import org.opensearch.core.common.io.stream.Writeable;
1213
import org.opensearch.core.common.io.stream.StreamInput;
1314
import org.opensearch.core.common.io.stream.StreamOutput;
@@ -24,25 +25,42 @@
2425
@Getter
2526
public class MLDeploySetting implements ToXContentObject, Writeable {
2627
public static final String IS_AUTO_DEPLOY_ENABLED_FIELD = "is_auto_deploy_enabled";
28+
public static final String MODEL_TTL_MINUTES_FIELD = "model_ttl_minutes";
29+
private static final long DEFAULT_TTL_MINUTES = -1;
30+
public static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL = Version.V_2_14_0;
2731

2832
private Boolean isAutoDeployEnabled;
33+
private Long modelTTLInMinutes; // in minutes
2934

3035
/**
 * Creates a deploy setting.
 *
 * @param isAutoDeployEnabled whether auto deploy is enabled; stored as given (may be null)
 * @param modelTTLInMinutes   model TTL in minutes; a null value is replaced by
 *                            {@code DEFAULT_TTL_MINUTES}
 */
@Builder(toBuilder = true)
public MLDeploySetting(Boolean isAutoDeployEnabled, Long modelTTLInMinutes) {
    this.isAutoDeployEnabled = isAutoDeployEnabled;
    // Fall back to the default TTL when the caller did not supply one.
    this.modelTTLInMinutes = modelTTLInMinutes == null ? Long.valueOf(DEFAULT_TTL_MINUTES) : modelTTLInMinutes;
}
3443

3544
public MLDeploySetting(StreamInput in) throws IOException {
3645
this.isAutoDeployEnabled = in.readOptionalBoolean();
46+
Version streamInputVersion = in.getVersion();
47+
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL)) {
48+
this.modelTTLInMinutes = in.readOptionalLong();
49+
}
3750
}
3851

3952
@Override
4053
public void writeTo(StreamOutput out) throws IOException {
54+
Version streamOutputVersion = out.getVersion();
4155
out.writeOptionalBoolean(isAutoDeployEnabled);
56+
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL)) {
57+
out.writeOptionalLong(modelTTLInMinutes);
58+
}
4259
}
4360

4461
public static MLDeploySetting parse(XContentParser parser) throws IOException {
4562
Boolean isAutoDeployEnabled = null;
63+
Long modelTTLMinutes = null;
4664
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
4765
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
4866
String fieldName = parser.currentName();
@@ -51,12 +69,14 @@ public static MLDeploySetting parse(XContentParser parser) throws IOException {
5169
case IS_AUTO_DEPLOY_ENABLED_FIELD:
5270
isAutoDeployEnabled = parser.booleanValue();
5371
break;
72+
case MODEL_TTL_MINUTES_FIELD:
73+
modelTTLMinutes = parser.longValue();
5474
default:
5575
parser.skipChildren();
5676
break;
5777
}
5878
}
59-
return new MLDeploySetting(isAutoDeployEnabled);
79+
return new MLDeploySetting(isAutoDeployEnabled, modelTTLMinutes);
6080
}
6181

6282
@Override
@@ -65,6 +85,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
6585
if (isAutoDeployEnabled != null) {
6686
builder.field(IS_AUTO_DEPLOY_ENABLED_FIELD, isAutoDeployEnabled);
6787
}
88+
if (modelTTLInMinutes != null) {
89+
builder.field(MODEL_TTL_MINUTES_FIELD, modelTTLInMinutes);
90+
}
6891
builder.endObject();
6992
return builder;
7093
}

common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpInput.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import lombok.Builder;
99
import lombok.Data;
10+
import org.opensearch.Version;
1011
import org.opensearch.core.common.io.stream.StreamInput;
1112
import org.opensearch.core.common.io.stream.StreamOutput;
1213
import org.opensearch.core.common.io.stream.Writeable;

common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponse.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@
77

88
import lombok.Getter;
99
import lombok.extern.log4j.Log4j2;
10+
import org.opensearch.Version;
1011
import org.opensearch.action.support.nodes.BaseNodeResponse;
1112
import org.opensearch.cluster.node.DiscoveryNode;
1213
import org.opensearch.core.common.io.stream.StreamInput;
1314
import org.opensearch.core.common.io.stream.StreamOutput;
15+
import org.opensearch.ml.common.model.MLDeploySetting;
1416

1517
import java.io.IOException;
1618

@@ -22,22 +24,28 @@ public class MLSyncUpNodeResponse extends BaseNodeResponse {
2224
private String[] deployedModelIds;
2325
private String[] runningDeployModelIds; // model ids which have deploying model task running
2426
private String[] runningDeployModelTaskIds; // deploy model task ids which is running
27+
private String[] expiredModelIds;
2528

2629
/**
 * Creates a sync-up response for one node.
 *
 * @param node                      node this response comes from
 * @param modelStatus               status string reported by the node
 * @param deployedModelIds          ids of models deployed on the node
 * @param runningDeployModelIds     model ids which have a deploying model task running
 * @param runningDeployModelTaskIds deploy model task ids which are running
 * @param expiredModelIds           ids of models reported as expired by the node
 */
public MLSyncUpNodeResponse(
    DiscoveryNode node,
    String modelStatus,
    String[] deployedModelIds,
    String[] runningDeployModelIds,
    String[] runningDeployModelTaskIds,
    String[] expiredModelIds
) {
    super(node);
    this.modelStatus = modelStatus;
    this.deployedModelIds = deployedModelIds;
    this.expiredModelIds = expiredModelIds;
    this.runningDeployModelIds = runningDeployModelIds;
    this.runningDeployModelTaskIds = runningDeployModelTaskIds;
}
3438

3539
public MLSyncUpNodeResponse(StreamInput in) throws IOException {
3640
super(in);
41+
Version streamInputVersion = in.getVersion();
3742
this.modelStatus = in.readOptionalString();
3843
this.deployedModelIds = in.readOptionalStringArray();
3944
this.runningDeployModelIds = in.readOptionalStringArray();
4045
this.runningDeployModelTaskIds = in.readOptionalStringArray();
46+
if (streamInputVersion.onOrAfter(MLDeploySetting.MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL)) {
47+
this.expiredModelIds = in.readOptionalStringArray();
48+
}
4149
}
4250

4351
public static MLSyncUpNodeResponse readStats(StreamInput in) throws IOException {
@@ -46,11 +54,14 @@ public static MLSyncUpNodeResponse readStats(StreamInput in) throws IOException
4654

4755
@Override
4856
public void writeTo(StreamOutput out) throws IOException {
57+
Version streamOutputVersion = out.getVersion();
4958
super.writeTo(out);
5059
out.writeOptionalString(modelStatus);
5160
out.writeOptionalStringArray(deployedModelIds);
5261
out.writeOptionalStringArray(runningDeployModelIds);
5362
out.writeOptionalStringArray(runningDeployModelTaskIds);
63+
if (streamOutputVersion.onOrAfter(MLDeploySetting.MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL)) {
64+
out.writeOptionalStringArray(expiredModelIds);
65+
}
5466
}
55-
5667
}

common/src/test/java/org/opensearch/ml/common/model/MLDeployingSettingTests.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ public class MLDeployingSettingTests {
3636

3737
private MLDeploySetting deploySettingNull;
3838

39-
private final String expectedInputStr = "{\"is_auto_deploy_enabled\":true}";
39+
private final String expectedInputStr = "{\"is_auto_deploy_enabled\":true,\"model_ttl_minutes\":-1}";
4040

4141
@Rule
4242
public ExpectedException exceptionRule = ExpectedException.none();
@@ -66,7 +66,7 @@ public void testToXContent() throws Exception {
6666

6767
@Test
6868
public void testToXContentIncomplete() throws Exception {
69-
final String expectedIncompleteInputStr = "{}";
69+
final String expectedIncompleteInputStr = "{\"model_ttl_minutes\":-1}";
7070

7171
String jsonStr = serializationWithToXContent(deploySettingNull);
7272
assertEquals(expectedIncompleteInputStr, jsonStr);
@@ -109,12 +109,12 @@ public void parseWithIllegalArgumentInteger() throws Exception {
109109

110110
@Test
111111
public void parseWithIllegalField() throws Exception {
112-
final String expectedInputStrWithIllegalField = "{\"is_auto_deploy_enabled\":true," +
112+
final String expectedInputStrWithIllegalField = "{\"is_auto_deploy_enabled\":true," + "\"model_ttl_hours\":0," +
113113
"\"illegal_field\":\"This field need to be skipped.\"}";
114114

115115
testParseFromJsonString(expectedInputStrWithIllegalField, parsedInput -> {
116116
try {
117-
assertEquals(expectedInputStr, serializationWithToXContent(parsedInput));
117+
assertEquals("{\"is_auto_deploy_enabled\":true,\"model_ttl_minutes\":-1}", serializationWithToXContent(parsedInput));
118118
} catch (IOException e) {
119119
throw new RuntimeException(e);
120120
}

common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponseTest.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ public class MLSyncUpNodeResponseTest {
2424
private final String[] loadedModelIds = {"loadedModelIds"};
2525
private final String[] runningLoadModelTaskIds = {"runningLoadModelTaskIds"};
2626
private final String[] runningLoadModelIds = {"modelid1"};
27+
28+
private final String[] expiredModelIds = {"modelExpired"};
2729
@Before
2830
public void setUp() throws Exception {
2931
localNode = new DiscoveryNode(
@@ -38,7 +40,7 @@ public void setUp() throws Exception {
3840

3941
@Test
4042
public void testSerializationDeserialization() throws IOException {
41-
MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds);
43+
MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds, expiredModelIds);
4244
BytesStreamOutput output = new BytesStreamOutput();
4345
response.writeTo(output);
4446
MLSyncUpNodeResponse newResponse = new MLSyncUpNodeResponse(output.bytes().streamInput());
@@ -51,7 +53,7 @@ public void testSerializationDeserialization() throws IOException {
5153

5254
@Test
5355
public void testReadProfile() throws IOException {
54-
MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds);
56+
MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds, expiredModelIds);
5557
BytesStreamOutput output = new BytesStreamOutput();
5658
response.writeTo(output);
5759
MLSyncUpNodeResponse newResponse = MLSyncUpNodeResponse.readStats(output.bytes().streamInput());

plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ private void executePredict(
223223
long endTime = System.nanoTime();
224224
double durationInMs = (endTime - startTime) / 1e6;
225225
modelCacheHelper.addPredictRequestDuration(modelId, durationInMs);
226+
modelCacheHelper.refreshLastAccessTime(modelId);
226227
log.debug("completed predict request " + requestId + " for model " + modelId);
227228
})
228229
);

plugin/src/main/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeAction.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,11 +162,13 @@ private MLSyncUpNodeResponse createSyncUpNodeResponse(MLSyncUpNodesRequest syncU
162162
String[] deployedModelIds = null;
163163
String[] runningDeployModelTaskIds = null;
164164
String[] runningDeployModelIds = null;
165+
String[] expiredModelIds = null;
165166
if (syncUpInput.isGetDeployedModels()) {
166167
deployedModelIds = mlModelManager.getLocalDeployedModels();
167168
List<String[]> localRunningDeployModel = mlTaskManager.getLocalRunningDeployModelTasks();
168169
runningDeployModelTaskIds = localRunningDeployModel.get(0);
169170
runningDeployModelIds = localRunningDeployModel.get(1);
171+
expiredModelIds = mlModelManager.getExpiredModels();
170172
}
171173

172174
if (syncUpInput.isClearRoutingTable()) {
@@ -186,7 +188,8 @@ private MLSyncUpNodeResponse createSyncUpNodeResponse(MLSyncUpNodesRequest syncU
186188
"ok",
187189
deployedModelIds,
188190
runningDeployModelIds,
189-
runningDeployModelTaskIds
191+
runningDeployModelTaskIds,
192+
expiredModelIds
190193
);
191194
}
192195

plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
import org.opensearch.ml.common.transport.sync.MLSyncUpInput;
4242
import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse;
4343
import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest;
44+
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelAction;
45+
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesRequest;
4446
import org.opensearch.ml.engine.encryptor.Encryptor;
4547
import org.opensearch.ml.engine.indices.MLIndicesHandler;
4648
import org.opensearch.search.SearchHit;
@@ -101,8 +103,17 @@ public void run() {
101103
Map<String, Set<String>> runningDeployModelTasks = new HashMap<>();
102104
// key is model id, value is set of worker node ids
103105
Map<String, Set<String>> deployingModels = new HashMap<>();
106+
// key is expired model_id, value is set of worker node ids
107+
Map<String, Set<String>> expiredModelToNodes = new HashMap<>();
104108
for (MLSyncUpNodeResponse response : responses) {
105109
String nodeId = response.getNode().getId();
110+
String[] expiredModelIds = response.getExpiredModelIds();
111+
if (expiredModelIds != null && expiredModelIds.length > 0) {
112+
Arrays
113+
.stream(expiredModelIds)
114+
.forEach(modelId -> { expiredModelToNodes.computeIfAbsent(modelId, it -> new HashSet<>()).add(nodeId); });
115+
}
116+
106117
String[] deployedModelIds = response.getDeployedModelIds();
107118
if (deployedModelIds != null && deployedModelIds.length > 0) {
108119
for (String modelId : deployedModelIds) {
@@ -126,6 +137,17 @@ public void run() {
126137
}
127138
}
128139
}
140+
141+
Set<String> modelsToUndeploy = new HashSet<>();
142+
for (String modelId : expiredModelToNodes.keySet()) {
143+
if (modelWorkerNodes.containsKey(modelId)
144+
&& expiredModelToNodes.get(modelId).size() == modelWorkerNodes.get(modelId).size()) {
145+
// this model has expired in all the nodes
146+
modelWorkerNodes.remove(modelId);
147+
modelsToUndeploy.add(modelId);
148+
}
149+
}
150+
129151
for (Map.Entry<String, Set<String>> entry : modelWorkerNodes.entrySet()) {
130152
String modelId = entry.getKey();
131153
log.debug("will sync model worker nodes for model: {}: {}", modelId, entry.getValue().toArray(new String[0]));
@@ -154,6 +176,8 @@ public void run() {
154176
log.error("Failed to sync model routing", ex);
155177
})
156178
);
179+
// Undeploy expired models
180+
undeployExpiredModels(modelsToUndeploy, modelWorkerNodes);
157181

158182
// refresh model status
159183
mlIndicesHandler
@@ -163,6 +187,20 @@ public void run() {
163187
}, e -> { log.error("Failed to sync model routing", e); }));
164188
}
165189

190+
private void undeployExpiredModels(Set<String> expiredModels, Map<String, Set<String>> modelWorkerNodes) {
191+
expiredModels.forEach(modelId -> {
192+
String[] targetNodeIds = modelWorkerNodes.keySet().toArray(new String[0]);
193+
194+
MLUndeployModelNodesRequest mlUndeployModelNodesRequest = new MLUndeployModelNodesRequest(
195+
targetNodeIds,
196+
new String[] { modelId }
197+
);
198+
client.execute(MLUndeployModelAction.INSTANCE, mlUndeployModelNodesRequest, ActionListener.wrap(r -> {
199+
log.debug("model {} is un_deployed", modelId);
200+
}, e -> { log.error("Failed to undeploy model {}", modelId, e); }));
201+
});
202+
}
203+
166204
@VisibleForTesting
167205
void initMLConfig() {
168206
if (mlConfigInited) {

plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
package org.opensearch.ml.model;
77

8+
import java.time.Instant;
89
import java.util.DoubleSummaryStatistics;
910
import java.util.List;
1011
import java.util.Map;
@@ -51,6 +52,7 @@ public class MLModelCache {
5152
// In rare case, this could be null, e.g. model info not synced up yet a predict request comes in.
5253
@Setter
5354
private Boolean deployToAllNodes;
55+
private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Instant lastAccessTime;
5456

5557
public MLModelCache() {
5658
targetWorkerNodes = ConcurrentHashMap.newKeySet();

plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MONITORING_REQUEST_COUNT;
99

10+
import java.time.Duration;
11+
import java.time.Instant;
1012
import java.util.HashSet;
1113
import java.util.List;
1214
import java.util.Map;
@@ -68,6 +70,7 @@ public synchronized void initModelState(
6870
modelCache.setFunctionName(functionName);
6971
modelCache.setTargetWorkerNodes(targetWorkerNodes);
7072
modelCache.setDeployToAllNodes(deployToAllNodes);
73+
modelCache.setLastAccessTime(Instant.now());
7174
modelCaches.put(modelId, modelCache);
7275
}
7376

@@ -87,6 +90,7 @@ public synchronized void initModelStateLocal(
8790
modelCache.setFunctionName(functionName);
8891
modelCache.setTargetWorkerNodes(targetWorkerNodes);
8992
modelCache.setDeployToAllNodes(false);
93+
modelCache.setLastAccessTime(Instant.now());
9094
modelCaches.put(modelId, modelCache);
9195
}
9296

@@ -341,6 +345,29 @@ public String[] getLocalDeployedModels() {
341345
.toArray(new String[0]);
342346
}
343347

348+
/**
349+
* Get expired models on node.
350+
*
351+
* @return array of expired model id
352+
*/
353+
public String[] getExpiredModels() {
354+
return modelCaches.entrySet().stream().filter(entry -> {
355+
MLModel mlModel = entry.getValue().getCachedModelInfo();
356+
if (mlModel.getDeploySetting() == null) {
357+
return false; // no TTL, never expire
358+
}
359+
Duration liveDuration = Duration.between(entry.getValue().getLastAccessTime(), Instant.now());
360+
Long ttlInMinutes = mlModel.getDeploySetting().getModelTTLInMinutes();
361+
if (ttlInMinutes < 0) {
362+
return false;
363+
}
364+
Duration ttl = Duration.ofMinutes(ttlInMinutes);
365+
boolean isModelExpired = liveDuration.getSeconds() >= ttl.getSeconds();
366+
return isModelExpired
367+
&& (mlModel.getModelState() == MLModelState.DEPLOYED || mlModel.getModelState() == MLModelState.PARTIALLY_DEPLOYED);
368+
}).map(entry -> entry.getKey()).collect(Collectors.toList()).toArray(new String[0]);
369+
}
370+
344371
/**
345372
* Check if model is running on node.
346373
*
@@ -403,6 +430,16 @@ public void setTargetWorkerNodes(String modelId, List<String> targetWorkerNodes)
403430
}
404431
}
405432

433+
/**
434+
* Set the last access time to Instant.now()
435+
*
436+
* @param modelId model id
437+
*/
438+
public void refreshLastAccessTime(String modelId) {
439+
MLModelCache modelCache = modelCaches.get(modelId);
440+
modelCache.setLastAccessTime(Instant.now());
441+
}
442+
406443
/**
407444
* Remove model.
408445
*

0 commit comments

Comments
 (0)