Skip to content

Commit f393dba

Browse files
authored
Fix memory usage estimation for ELSER models (elastic#131630)
* Pass model ID instead of deployment ID to memory estimator * Update docs/changelog/131630.yaml
1 parent cc68b17 commit f393dba

File tree

11 files changed

+329
-82
lines changed

11 files changed

+329
-82
lines changed

docs/changelog/131630.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 131630
2+
summary: Fix memory usage estimation for ELSER models
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ private AssignmentPlan computePlanForNormalPriorityModels(
169169
.collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getTargetAllocations()));
170170
return new AssignmentPlan.Deployment(
171171
assignment.getDeploymentId(),
172+
assignment.getModelId(),
172173
assignment.getTaskParams().getModelBytes(),
173174
assignment.getTaskParams().getNumberOfAllocations(),
174175
assignment.getTaskParams().getThreadsPerAllocation(),
@@ -185,6 +186,7 @@ private AssignmentPlan computePlanForNormalPriorityModels(
185186
planDeployments.add(
186187
new AssignmentPlan.Deployment(
187188
taskParams.getDeploymentId(),
189+
taskParams.getModelId(),
188190
taskParams.getModelBytes(),
189191
taskParams.getNumberOfAllocations(),
190192
taskParams.getThreadsPerAllocation(),
@@ -225,6 +227,7 @@ private AssignmentPlan computePlanForLowPriorityModels(Set<String> assignableNod
225227
.map(
226228
assignment -> new AssignmentPlan.Deployment(
227229
assignment.getDeploymentId(),
230+
assignment.getModelId(),
228231
assignment.getTaskParams().getModelBytes(),
229232
assignment.getTaskParams().getNumberOfAllocations(),
230233
assignment.getTaskParams().getThreadsPerAllocation(),
@@ -242,6 +245,7 @@ private AssignmentPlan computePlanForLowPriorityModels(Set<String> assignableNod
242245
planDeployments.add(
243246
new AssignmentPlan.Deployment(
244247
taskParams.getDeploymentId(),
248+
taskParams.getModelId(),
245249
taskParams.getModelBytes(),
246250
taskParams.getNumberOfAllocations(),
247251
taskParams.getThreadsPerAllocation(),

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AbstractPreserveAllocations.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ Deployment modifyModelPreservingPreviousAssignments(Deployment m) {
5555

5656
return new Deployment(
5757
m.deploymentId(),
58+
m.modelId(),
5859
m.memoryBytes(),
5960
m.allocations() - calculatePreservedAllocations(m),
6061
m.threadsPerAllocation(),

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ public class AssignmentPlan implements Comparable<AssignmentPlan> {
4747
*/
4848
public record Deployment(
4949
String deploymentId,
50+
String modelId,
5051
long memoryBytes,
5152
int allocations,
5253
int threadsPerAllocation,
@@ -59,6 +60,7 @@ public record Deployment(
5960
) {
6061
public Deployment(
6162
String deploymentId,
63+
String modelId,
6264
long modelBytes,
6365
int allocations,
6466
int threadsPerAllocation,
@@ -70,6 +72,7 @@ public Deployment(
7072
) {
7173
this(
7274
deploymentId,
75+
modelId,
7376
modelBytes,
7477
allocations,
7578
threadsPerAllocation,
@@ -96,7 +99,7 @@ boolean hasEverBeenAllocated() {
9699

97100
public long estimateMemoryUsageBytes(int allocations) {
98101
return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(
99-
deploymentId,
102+
modelId,
100103
memoryBytes,
101104
perDeploymentMemoryBytes,
102105
perAllocationMemoryBytes,
@@ -106,24 +109,23 @@ public long estimateMemoryUsageBytes(int allocations) {
106109

107110
long estimateAdditionalMemoryUsageBytes(int allocationsOld, int allocationsNew) {
108111
return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(
109-
deploymentId,
112+
modelId,
110113
memoryBytes,
111114
perDeploymentMemoryBytes,
112115
perAllocationMemoryBytes,
113116
allocationsNew
114117
) - StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(
115-
deploymentId,
118+
modelId,
116119
memoryBytes,
117120
perDeploymentMemoryBytes,
118121
perAllocationMemoryBytes,
119122
allocationsOld
120123
);
121-
122124
}
123125

124126
long minimumMemoryRequiredBytes() {
125127
return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(
126-
deploymentId,
128+
modelId,
127129
memoryBytes,
128130
perDeploymentMemoryBytes,
129131
perAllocationMemoryBytes,

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ private AssignmentPlan solveAllocatingAtLeastOnceModelsThatWerePreviouslyAllocat
115115
.map(
116116
m -> new AssignmentPlan.Deployment(
117117
m.deploymentId(),
118+
m.modelId(),
118119
m.memoryBytes(),
119120
1,
120121
m.threadsPerAllocation(),
@@ -148,6 +149,7 @@ private AssignmentPlan solveAllocatingAtLeastOnceModelsThatWerePreviouslyAllocat
148149
: Map.of();
149150
return new AssignmentPlan.Deployment(
150151
m.deploymentId(),
152+
m.modelId(),
151153
m.memoryBytes(),
152154
m.allocations(),
153155
m.threadsPerAllocation(),

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ private AssignmentPlan computeZonePlan(
127127
d -> new AssignmentPlan.Deployment(
128128
// replace each deployment with a new deployment
129129
d.deploymentId(),
130+
d.modelId(),
130131
d.memoryBytes(),
131132
deploymentIdToTargetAllocationsPerZone.get(d.deploymentId()),
132133
d.threadsPerAllocation(),
@@ -163,6 +164,7 @@ private AssignmentPlan computePlanAcrossAllNodes(List<AssignmentPlan> plans) {
163164
.map(
164165
d -> new AssignmentPlan.Deployment(
165166
d.deploymentId(),
167+
d.modelId(),
166168
d.memoryBytes(),
167169
d.allocations(),
168170
d.threadsPerAllocation(),

0 commit comments

Comments
 (0)