Skip to content

Commit e5c8da6

Browse files
committed
Fix estimated memory usage for a model with zero allocations.
1 parent 8e26d18 commit e5c8da6

File tree

2 files changed

+79
-0
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,9 @@ public String getDeploymentId() {
623623
* @return the estimated memory (in bytes) required for the model deployment to run
624624
*/
625625
public long estimateMemoryUsageBytes() {
626+
if (numberOfAllocations == 0) {
627+
return 0;
628+
}
626629
// We already take into account 2x the model bytes. If the cache size is larger than the model bytes, then
627630
// we need to take it into account when returning the estimate.
628631
if (cacheSize != null && cacheSize.getBytes() > modelBytes) {

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlAutoscalingResourceTrackerTests.java

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
2323
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
2424
import org.elasticsearch.xpack.core.ml.autoscaling.MlAutoscalingStats;
25+
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
2526
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState;
2627
import org.elasticsearch.xpack.core.ml.inference.assignment.Priority;
2728
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo;
@@ -1800,6 +1801,81 @@ public void testGetMemoryAndProcessorsScaleDownNotPreventedByDummyEntityAsMemory
18001801
);
18011802
}
18021803

1804+
public void testGetMemoryAndProcessorsScaleDownForModelWithZeroAllocations() throws InterruptedException {
1805+
long memory = 1000000000;
1806+
Map<String, String> nodeAttr = Map.of(
1807+
MachineLearning.MACHINE_MEMORY_NODE_ATTR,
1808+
Long.toString(memory),
1809+
MachineLearning.MAX_JVM_SIZE_NODE_ATTR,
1810+
"400000000",
1811+
MachineLearning.ML_CONFIG_VERSION_NODE_ATTR,
1812+
"7.2.0",
1813+
MachineLearning.ALLOCATED_PROCESSORS_NODE_ATTR,
1814+
"2.0"
1815+
);
1816+
1817+
MlAutoscalingContext mlAutoscalingContext = new MlAutoscalingContext(
1818+
List.of(),
1819+
List.of(),
1820+
List.of(),
1821+
Map.of(
1822+
"model-with-zero-allocations",
1823+
TrainedModelAssignment.Builder.empty(
1824+
new StartTrainedModelDeploymentAction.TaskParams(
1825+
"model-with-zero-allocations",
1826+
"model-with-zero-allocations-deployment",
1827+
400,
1828+
0,
1829+
2,
1830+
100,
1831+
null,
1832+
Priority.NORMAL,
1833+
0L,
1834+
0L
1835+
),
1836+
new AdaptiveAllocationsSettings(true, 0, 4)
1837+
).build()
1838+
),
1839+
List.of(
1840+
DiscoveryNodeUtils.builder("ml-node-1")
1841+
.name("ml-node-name-1")
1842+
.address(new TransportAddress(InetAddress.getLoopbackAddress(), 9300))
1843+
.attributes(nodeAttr)
1844+
.roles(Set.of(DiscoveryNodeRole.ML_ROLE))
1845+
.build()
1846+
),
1847+
PersistentTasksCustomMetadata.builder().build()
1848+
);
1849+
MlMemoryTracker mockTracker = mock(MlMemoryTracker.class);
1850+
1851+
this.<MlAutoscalingStats>assertAsync(
1852+
listener -> MlAutoscalingResourceTracker.getMemoryAndProcessors(
1853+
mlAutoscalingContext,
1854+
mockTracker,
1855+
Map.of("ml-node-1", memory),
1856+
600000000,
1857+
2,
1858+
MachineLearning.DEFAULT_MAX_OPEN_JOBS_PER_NODE,
1859+
MlDummyAutoscalingEntity.of(0, 0),
1860+
1,
1861+
listener
1862+
),
1863+
stats -> {
1864+
assertEquals(memory, stats.currentPerNodeMemoryBytes());
1865+
assertEquals(0, stats.currentTotalModelMemoryBytes());
1866+
assertEquals(0, stats.currentTotalProcessorsInUse());
1867+
assertEquals(1, stats.currentTotalNodes());
1868+
assertEquals(0, stats.wantedMinNodes());
1869+
assertEquals(0, stats.wantedExtraPerNodeNodeProcessors());
1870+
assertEquals(0, stats.wantedExtraProcessors());
1871+
assertEquals(0, stats.wantedExtraModelMemoryBytes());
1872+
assertEquals(0, stats.wantedExtraPerNodeMemoryBytes());
1873+
assertEquals(memory, stats.unwantedNodeMemoryBytesToRemove());
1874+
assertEquals(MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD.getBytes(), stats.currentPerNodeMemoryOverheadBytes());
1875+
}
1876+
);
1877+
}
1878+
18031879
private <T> void assertAsync(Consumer<ActionListener<T>> function, Consumer<T> furtherTests) throws InterruptedException {
18041880
CountDownLatch latch = new CountDownLatch(1);
18051881
AtomicBoolean listenerCalled = new AtomicBoolean(false);

0 commit comments

Comments (0)