Skip to content

Commit 3e2154d

Browse files
author
Hendrik Muhs
authored
[ML] set model memory usage for ELSER to 2004MB #96101
This change hardcodes the memory usage of the pretrained/fixed ELSER model to 2004MB until we find a better way to set/calculate model memory usage.
1 parent 4f4614a commit 3e2154d

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,13 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
6161
*/
6262
private static final ByteSizeValue MEMORY_OVERHEAD = ByteSizeValue.ofMb(240);
6363

64+
/**
65+
* The ELSER model turned out to use more memory than what we usually estimate.
66+
* We override the estimate with this static value for ELSER V1 for now. Soon to be
67+
* replaced with a better estimate provided by the model.
68+
*/
69+
private static final ByteSizeValue ELSER_1_MEMORY_USAGE = ByteSizeValue.ofMb(2004);
70+
6471
public StartTrainedModelDeploymentAction() {
6572
super(NAME, CreateTrainedModelAssignmentAction.Response::new);
6673
}
@@ -514,9 +521,10 @@ public long estimateMemoryUsageBytes() {
514521
// We already take into account 2x the model bytes. If the cache size is larger than the model bytes, then
515522
// we need to take it into account when returning the estimate.
516523
if (cacheSize != null && cacheSize.getBytes() > modelBytes) {
517-
return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(modelBytes) + (cacheSize.getBytes() - modelBytes);
524+
return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(modelId, modelBytes) + (cacheSize.getBytes()
525+
- modelBytes);
518526
}
519-
return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(modelBytes);
527+
return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(modelId, modelBytes);
520528
}
521529

522530
public Version getMinimalSupportedVersion() {
@@ -641,8 +649,12 @@ static boolean match(Task task, String expectedId) {
641649
}
642650
}
643651

644-
public static long estimateMemoryUsageBytes(long totalDefinitionLength) {
652+
public static long estimateMemoryUsageBytes(String modelId, long totalDefinitionLength) {
645653
// While loading the model in the process we need twice the model size.
646-
return MEMORY_OVERHEAD.getBytes() + 2 * totalDefinitionLength;
654+
return isElserModel(modelId) ? ELSER_1_MEMORY_USAGE.getBytes() : MEMORY_OVERHEAD.getBytes() + 2 * totalDefinitionLength;
655+
}
656+
657+
private static boolean isElserModel(String modelId) {
658+
return modelId.startsWith(".elser_model_1");
647659
}
648660
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ private void modelSizeStats(
290290
new TrainedModelSizeStats(
291291
totalDefinitionLength,
292292
totalDefinitionLength > 0L
293-
? StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(totalDefinitionLength)
293+
? StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(model.getModelId(), totalDefinitionLength)
294294
: 0L
295295
)
296296
);

0 commit comments

Comments
 (0)