Commit c5beb06

Add memory consumption estimation for models in profile API. (#853) (#856)
Signed-off-by: Jing Zhang <[email protected]>
(cherry picked from commit dd2799a)
Co-authored-by: Jing Zhang <[email protected]>
1 parent 3ddec6e commit c5beb06

6 files changed: +110 additions, −2 deletions
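
At a glance: the commit estimates a deployed model's memory footprint by scaling its stored content size with a format-dependent factor (1.5x for ONNX, 1.2x for TorchScript, 1.0x for anything else) and exposes the result as new fields in the profile API. A minimal standalone sketch of that rule, assuming only the MLModelFormat enum (the method name here is invented; the factors are the ones the diff below adds to MLModelCacheHelper):

import org.opensearch.ml.common.model.MLModelFormat;

// Sketch only -- hypothetical helper; factors match MLModelCacheHelper#getMemSizeEstimation below.
static long estimateMemBytes(MLModelFormat format, long contentSizeBytes) {
    double scale = 1.0;                        // default for other formats
    switch (format) {
        case ONNX:
            scale = 1.5;                       // ONNX runtime overhead
            break;
        case TORCH_SCRIPT:
            scale = 1.2;                       // TorchScript runtime overhead
            break;
    }
    return (long) (scale * contentSizeBytes);  // e.g. 1000 TorchScript bytes -> 1200
}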

plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java

Lines changed: 4 additions & 0 deletions
@@ -36,6 +36,8 @@ public class MLModelCache {
     private final Set<String> workerNodes;
     private final Queue<Double> modelInferenceDurationQueue;
     private final Queue<Double> predictRequestDurationQueue;
+    private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Long memSizeEstimationCPU;
+    private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Long memSizeEstimationGPU;
 
     public MLModelCache() {
         targetWorkerNodes = ConcurrentHashMap.newKeySet();
@@ -90,6 +92,8 @@ public void clear() {
         if (predictor != null) {
             predictor.close();
         }
+        memSizeEstimationCPU = 0L;
+        memSizeEstimationGPU = 0L;
         if (executor != null) {
             executor.close();
         }
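
Note that clear() resets the two estimates to 0L rather than null, so a model whose cache entry has been cleared reports a zero estimate instead of omitting the fields.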

plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java

Lines changed: 56 additions & 0 deletions
@@ -21,6 +21,7 @@
 import org.opensearch.common.settings.Settings;
 import org.opensearch.ml.common.FunctionName;
 import org.opensearch.ml.common.exception.MLLimitExceededException;
+import org.opensearch.ml.common.model.MLModelFormat;
 import org.opensearch.ml.common.model.MLModelState;
 import org.opensearch.ml.engine.MLExecutable;
 import org.opensearch.ml.engine.Predictable;
@@ -66,6 +67,59 @@ public synchronized void setModelState(String modelId, MLModelState state) {
         getExistingModelCache(modelId).setModelState(state);
     }
 
+    /**
+     * Set memory size estimation CPU/GPU
+     * @param modelId model id
+     * @param format model format like onnx
+     * @param size memory size
+     */
+    public synchronized void setMemSizeEstimation(String modelId, MLModelFormat format, Long size) {
+        Long memSize = getMemSizeEstimation(format, size);
+        log.debug("Updating memSizeEstimation of Model {} to {}", modelId, memSize);
+        getExistingModelCache(modelId).setMemSizeEstimationCPU(memSize);
+        getExistingModelCache(modelId).setMemSizeEstimationGPU(memSize);
+    }
+
+    private Long getMemSizeEstimation(MLModelFormat format, Long size) {
+        Double scale = 1.0;
+        switch (format) {
+            case ONNX:
+                scale = 1.5;
+                break;
+            case TORCH_SCRIPT:
+                scale = 1.2;
+                break;
+        }
+        Long memSize = Double.valueOf(scale * size).longValue();
+        return memSize;
+    }
+
+    /**
+     * Get CPU memory estimation.
+     * @param modelId model id
+     * @return Long
+     */
+    public Long getMemEstCPU(String modelId) {
+        MLModelCache modelCache = modelCaches.get(modelId);
+        if (modelCache == null) {
+            return null;
+        }
+        return modelCache.getMemSizeEstimationCPU();
+    }
+
+    /**
+     * Get GPU memory estimation.
+     * @param modelId model id
+     * @return Long
+     */
+    public Long getMemEstGPU(String modelId) {
+        MLModelCache modelCache = modelCaches.get(modelId);
+        if (modelCache == null) {
+            return null;
+        }
+        return modelCache.getMemSizeEstimationGPU();
+    }
+
     /**
      * Check if model deployed on node.
      * @param modelId model id
@@ -293,6 +347,8 @@ public MLModelProfile getModelProfile(String modelId) {
         }
         builder.modelInferenceStats(modelCache.getInferenceStats(true));
         builder.predictRequestStats(modelCache.getInferenceStats(false));
+        builder.memSizeEstimationCPU(modelCache.getMemSizeEstimationCPU());
+        builder.memSizeEstimationGPU(modelCache.getMemSizeEstimationGPU());
         return builder.build();
     }
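
A hypothetical usage sketch of the new helper methods (the model id is invented; the values mirror the unit tests at the bottom of this commit). Both slots currently receive the same number, since setMemSizeEstimation writes one estimate to the CPU and GPU fields alike:

// Hypothetical usage; assumes the model was initialized in the cache first.
cacheHelper.setMemSizeEstimation("my-model-id", MLModelFormat.TORCH_SCRIPT, 1000L);
Long cpuEst = cacheHelper.getMemEstCPU("my-model-id"); // 1200L
Long gpuEst = cacheHelper.getMemEstGPU("my-model-id"); // 1200L (same value; CPU/GPU not yet differentiated)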

plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java

Lines changed: 1 addition & 0 deletions
@@ -568,6 +568,7 @@ public void deployModel(
             modelCacheHelper.setPredictor(modelId, predictable);
             mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).increment();
             modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED);
+            modelCacheHelper.setMemSizeEstimation(modelId, mlModel.getModelFormat(), mlModel.getModelContentSizeInBytes());
             listener.onResponse("successful");
         } catch (Exception e) {
             log.error("Failed to add predictor to cache", e);
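
Note: the estimate is computed once per successful deployment, from mlModel.getModelContentSizeInBytes() scaled by the format factor, so it is derived from the size of the stored model artifact rather than a live measurement of the loaded model.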

plugin/src/main/java/org/opensearch/ml/profile/MLModelProfile.java

Lines changed: 17 additions & 1 deletion
@@ -28,6 +28,8 @@ public class MLModelProfile implements ToXContentFragment, Writeable {
     private final String[] workerNodes;
     private final MLPredictRequestStats modelInferenceStats;
     private final MLPredictRequestStats predictRequestStats;
+    private final Long memSizeEstimationCPU;
+    private final Long memSizeEstimationGPU;
 
     @Builder
     public MLModelProfile(
@@ -36,14 +38,18 @@ public MLModelProfile(
         String[] targetWorkerNodes,
         String[] workerNodes,
         MLPredictRequestStats modelInferenceStats,
-        MLPredictRequestStats predictRequestStats
+        MLPredictRequestStats predictRequestStats,
+        Long memSizeEstimationCPU,
+        Long memSizeEstimationGPU
     ) {
         this.modelState = modelState;
         this.predictor = predictor;
         this.targetWorkerNodes = targetWorkerNodes;
         this.workerNodes = workerNodes;
         this.modelInferenceStats = modelInferenceStats;
         this.predictRequestStats = predictRequestStats;
+        this.memSizeEstimationCPU = memSizeEstimationCPU;
+        this.memSizeEstimationGPU = memSizeEstimationGPU;
     }
 
     @Override
@@ -67,6 +73,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
         if (predictRequestStats != null) {
             builder.field("predict_request_stats", predictRequestStats);
         }
+        if (memSizeEstimationCPU != null) {
+            builder.field("mem_size_estimation_cpu", memSizeEstimationCPU);
+        }
+        if (memSizeEstimationGPU != null) {
+            builder.field("mem_size_estimation_gpu", memSizeEstimationGPU);
+        }
         builder.endObject();
         return builder;
     }
@@ -90,6 +102,8 @@ public MLModelProfile(StreamInput in) throws IOException {
         } else {
             this.predictRequestStats = null;
         }
+        this.memSizeEstimationCPU = in.readOptionalLong();
+        this.memSizeEstimationGPU = in.readOptionalLong();
     }
 
     @Override
@@ -115,5 +129,7 @@ public void writeTo(StreamOutput out) throws IOException {
         } else {
             out.writeBoolean(false);
         }
+        out.writeOptionalLong(memSizeEstimationCPU);
+        out.writeOptionalLong(memSizeEstimationGPU);
     }
 }
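
For illustration, once both estimates are set, the fragment toXContent adds to a model's profile entry would look like the following (surrounding profile fields elided; the byte values are hypothetical):

"mem_size_estimation_cpu": 1200,
"mem_size_estimation_gpu": 1200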

plugin/src/main/java/org/opensearch/ml/rest/RestMLProfileAction.java

Lines changed: 3 additions & 1 deletion
@@ -161,7 +161,9 @@ private Map<String, MLProfileModelResponse> buildModelCentricResult(List<MLProfi
                 null,
                 null,
                 entry.getValue().getModelInferenceStats(),
-                entry.getValue().getPredictRequestStats()
+                entry.getValue().getPredictRequestStats(),
+                entry.getValue().getMemSizeEstimationCPU(),
+                entry.getValue().getMemSizeEstimationGPU()
             );
             mlProfileModelResponse.getMlModelProfileMap().putAll(ImmutableMap.of(nodeId, modelProfile));
         }
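
This threads the two new values through the model-centric profile response via the expanded MLModelProfile constructor shown above; when an estimate is absent the constructor receives null and toXContent simply omits the field.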

plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java

Lines changed: 29 additions & 0 deletions
@@ -27,6 +27,7 @@
 import org.opensearch.common.settings.Settings;
 import org.opensearch.ml.common.FunctionName;
 import org.opensearch.ml.common.exception.MLLimitExceededException;
+import org.opensearch.ml.common.model.MLModelFormat;
 import org.opensearch.ml.common.model.MLModelState;
 import org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingModel;
 import org.opensearch.ml.profile.MLModelProfile;
@@ -78,6 +79,34 @@ public void testModelState() {
         assertEquals(FunctionName.TEXT_EMBEDDING, cacheHelper.getFunctionName(modelId));
     }
 
+    public void testMemSizeEstimationCPU() {
+        cacheHelper.initModelState(modelId, MLModelState.DEPLOYING, FunctionName.TEXT_EMBEDDING, targetWorkerNodes);
+        assertTrue(cacheHelper.getMemEstCPU(modelId) == null);
+        cacheHelper.setMemSizeEstimation(modelId, MLModelFormat.TORCH_SCRIPT, 1000L);
+        assertTrue(cacheHelper.getMemEstCPU(modelId) == 1200L);
+    }
+
+    public void testMemSizeEstimationCPUONNX() {
+        cacheHelper.initModelState(modelId, MLModelState.DEPLOYING, FunctionName.TEXT_EMBEDDING, targetWorkerNodes);
+        assertTrue(cacheHelper.getMemEstCPU(modelId) == null);
+        cacheHelper.setMemSizeEstimation(modelId, MLModelFormat.ONNX, 1000L);
+        assertTrue(cacheHelper.getMemEstCPU(modelId) == 1500L);
+    }
+
+    public void testMemSizeEstimationGPU() {
+        cacheHelper.initModelState(modelId, MLModelState.DEPLOYING, FunctionName.TEXT_EMBEDDING, targetWorkerNodes);
+        assertTrue(cacheHelper.getMemEstGPU(modelId) == null);
+        cacheHelper.setMemSizeEstimation(modelId, MLModelFormat.TORCH_SCRIPT, 1000L);
+        assertTrue(cacheHelper.getMemEstGPU(modelId) == 1200L);
+    }
+
+    public void testMemSizeEstimationGPUONNX() {
+        cacheHelper.initModelState(modelId, MLModelState.DEPLOYING, FunctionName.TEXT_EMBEDDING, targetWorkerNodes);
+        assertTrue(cacheHelper.getMemEstGPU(modelId) == null);
+        cacheHelper.setMemSizeEstimation(modelId, MLModelFormat.ONNX, 1000L);
+        assertTrue(cacheHelper.getMemEstGPU(modelId) == 1500L);
+    }
+
     public void testModelState_DuplicateError() {
         expectedEx.expect(MLLimitExceededException.class);
         expectedEx.expectMessage("Duplicate deploy model task");
