Skip to content

Commit 92ccd73

Browse files
authored
[ML] Include inference process's RSS memory stat in /_ml/trained_models/_stats output (elastic#142312)
* Measure rss memory returned from ml-cpp * Change memory stat to average * Add bytes unit to variable names * Fix BWC test * Regenerate transport version * Fix unit test and apply spotless * Fix serverless build
1 parent 9852d24 commit 92ccd73

File tree

16 files changed

+343
-56
lines changed

16 files changed

+343
-56
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
9295000
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
dfs_search_timed_out,9294000
1+
assignment_stats_memory_stat,9295000

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStats.java

Lines changed: 110 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.core.ml.inference.assignment;
99

10+
import org.elasticsearch.TransportVersion;
1011
import org.elasticsearch.cluster.node.DiscoveryNode;
1112
import org.elasticsearch.common.Strings;
1213
import org.elasticsearch.common.io.stream.StreamInput;
@@ -27,6 +28,8 @@
2728

2829
public class AssignmentStats implements ToXContentObject, Writeable {
2930

31+
public static final TransportVersion MEMORY_STAT_TRANSPORT_VERSION = TransportVersion.fromName("assignment_stats_memory_stat");
32+
3033
public static class NodeStats implements ToXContentObject, Writeable {
3134
private final DiscoveryNode node;
3235
private final Long inferenceCount;
@@ -46,6 +49,7 @@ public static class NodeStats implements ToXContentObject, Writeable {
4649
private final long throughputLastPeriod;
4750
private final Double avgInferenceTimeLastPeriod;
4851
private final Long cacheHitCountLastPeriod;
52+
private final Long avgInferenceProcessMemoryRssBytes;
4953

5054
public static AssignmentStats.NodeStats forStartedState(
5155
DiscoveryNode node,
@@ -88,6 +92,49 @@ public static AssignmentStats.NodeStats forStartedState(
8892
);
8993
}
9094

95+
public static AssignmentStats.NodeStats forStartedState(
96+
DiscoveryNode node,
97+
long inferenceCount,
98+
Double avgInferenceTime,
99+
Double avgInferenceTimeExcludingCacheHit,
100+
int pendingCount,
101+
int errorCount,
102+
long cacheHitCount,
103+
int rejectedExecutionCount,
104+
int timeoutCount,
105+
Instant lastAccess,
106+
Instant startTime,
107+
Integer threadsPerAllocation,
108+
Integer numberOfAllocations,
109+
long peakThroughput,
110+
long throughputLastPeriod,
111+
Double avgInferenceTimeLastPeriod,
112+
long cacheHitCountLastPeriod,
113+
Long avgInferenceProcessMemoryRssBytes
114+
) {
115+
return new AssignmentStats.NodeStats(
116+
node,
117+
inferenceCount,
118+
avgInferenceTime,
119+
avgInferenceTimeExcludingCacheHit,
120+
lastAccess,
121+
pendingCount,
122+
errorCount,
123+
cacheHitCount,
124+
rejectedExecutionCount,
125+
timeoutCount,
126+
new RoutingStateAndReason(RoutingState.STARTED, null),
127+
Objects.requireNonNull(startTime),
128+
threadsPerAllocation,
129+
numberOfAllocations,
130+
peakThroughput,
131+
throughputLastPeriod,
132+
avgInferenceTimeLastPeriod,
133+
cacheHitCountLastPeriod,
134+
avgInferenceProcessMemoryRssBytes
135+
);
136+
}
137+
91138
public static AssignmentStats.NodeStats forNotStartedState(DiscoveryNode node, RoutingState state, String reason) {
92139
return new AssignmentStats.NodeStats(
93140
node,
@@ -107,6 +154,7 @@ public static AssignmentStats.NodeStats forNotStartedState(DiscoveryNode node, R
107154
0L,
108155
0L,
109156
null,
157+
null,
110158
null
111159
);
112160
}
@@ -129,7 +177,8 @@ public NodeStats(
129177
long peakThroughput,
130178
long throughputLastPeriod,
131179
Double avgInferenceTimeLastPeriod,
132-
Long cacheHitCountLastPeriod
180+
Long cacheHitCountLastPeriod,
181+
Long avgInferenceProcessMemoryRssBytes
133182
) {
134183
this.node = node;
135184
this.inferenceCount = inferenceCount;
@@ -149,11 +198,55 @@ public NodeStats(
149198
this.throughputLastPeriod = throughputLastPeriod;
150199
this.avgInferenceTimeLastPeriod = avgInferenceTimeLastPeriod;
151200
this.cacheHitCountLastPeriod = cacheHitCountLastPeriod;
201+
this.avgInferenceProcessMemoryRssBytes = avgInferenceProcessMemoryRssBytes;
152202

153203
// if lastAccess time is null there have been no inferences
154204
assert this.lastAccess != null || (inferenceCount == null || inferenceCount == 0);
155205
}
156206

207+
public NodeStats(
208+
DiscoveryNode node,
209+
Long inferenceCount,
210+
Double avgInferenceTime,
211+
Double avgInferenceTimeExcludingCacheHit,
212+
@Nullable Instant lastAccess,
213+
Integer pendingCount,
214+
int errorCount,
215+
Long cacheHitCount,
216+
int rejectedExecutionCount,
217+
int timeoutCount,
218+
RoutingStateAndReason routingState,
219+
@Nullable Instant startTime,
220+
@Nullable Integer threadsPerAllocation,
221+
@Nullable Integer numberOfAllocations,
222+
long peakThroughput,
223+
long throughputLastPeriod,
224+
Double avgInferenceTimeLastPeriod,
225+
Long cacheHitCountLastPeriod
226+
) {
227+
this(
228+
node,
229+
inferenceCount,
230+
avgInferenceTime,
231+
avgInferenceTimeExcludingCacheHit,
232+
lastAccess,
233+
pendingCount,
234+
errorCount,
235+
cacheHitCount,
236+
rejectedExecutionCount,
237+
timeoutCount,
238+
routingState,
239+
startTime,
240+
threadsPerAllocation,
241+
numberOfAllocations,
242+
peakThroughput,
243+
throughputLastPeriod,
244+
avgInferenceTimeLastPeriod,
245+
cacheHitCountLastPeriod,
246+
null
247+
);
248+
}
249+
157250
public NodeStats(StreamInput in) throws IOException {
158251
this.node = in.readOptionalWriteable(DiscoveryNode::new);
159252
this.inferenceCount = in.readOptionalLong();
@@ -173,6 +266,11 @@ public NodeStats(StreamInput in) throws IOException {
173266
this.cacheHitCount = in.readOptionalVLong();
174267
this.cacheHitCountLastPeriod = in.readOptionalVLong();
175268
this.avgInferenceTimeExcludingCacheHit = in.readOptionalDouble();
269+
if (in.getTransportVersion().supports(MEMORY_STAT_TRANSPORT_VERSION)) {
270+
this.avgInferenceProcessMemoryRssBytes = in.readOptionalVLong();
271+
} else {
272+
this.avgInferenceProcessMemoryRssBytes = null;
273+
}
176274

177275
}
178276

@@ -260,14 +358,17 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
260358
if (inferenceCount != null) {
261359
builder.field("inference_count", inferenceCount);
262360
}
263-
// avoid reporting the average time as 0 if count < 1
361+
// avoid reporting averages as 0 if count < 1
264362
if (inferenceCount != null && inferenceCount > 0) {
265363
if (avgInferenceTime != null) {
266364
builder.field("average_inference_time_ms", avgInferenceTime);
267365
}
268366
if (avgInferenceTimeExcludingCacheHit != null) {
269367
builder.field("average_inference_time_ms_excluding_cache_hits", avgInferenceTimeExcludingCacheHit);
270368
}
369+
if (avgInferenceProcessMemoryRssBytes != null) {
370+
builder.field("average_inference_process_memory_rss_bytes", avgInferenceProcessMemoryRssBytes);
371+
}
271372
}
272373
if (cacheHitCount != null) {
273374
builder.field("inference_cache_hit_count", cacheHitCount);
@@ -329,6 +430,9 @@ public void writeTo(StreamOutput out) throws IOException {
329430
out.writeOptionalVLong(cacheHitCount);
330431
out.writeOptionalVLong(cacheHitCountLastPeriod);
331432
out.writeOptionalDouble(avgInferenceTimeExcludingCacheHit);
433+
if (out.getTransportVersion().supports(MEMORY_STAT_TRANSPORT_VERSION)) {
434+
out.writeOptionalVLong(avgInferenceProcessMemoryRssBytes);
435+
}
332436
}
333437

334438
@Override
@@ -353,7 +457,8 @@ public boolean equals(Object o) {
353457
&& Objects.equals(peakThroughput, that.peakThroughput)
354458
&& Objects.equals(throughputLastPeriod, that.throughputLastPeriod)
355459
&& Objects.equals(avgInferenceTimeLastPeriod, that.avgInferenceTimeLastPeriod)
356-
&& Objects.equals(cacheHitCountLastPeriod, that.cacheHitCountLastPeriod);
460+
&& Objects.equals(cacheHitCountLastPeriod, that.cacheHitCountLastPeriod)
461+
&& Objects.equals(avgInferenceProcessMemoryRssBytes, that.avgInferenceProcessMemoryRssBytes);
357462
}
358463

359464
@Override
@@ -376,7 +481,8 @@ public int hashCode() {
376481
peakThroughput,
377482
throughputLastPeriod,
378483
avgInferenceTimeLastPeriod,
379-
cacheHitCountLastPeriod
484+
cacheHitCountLastPeriod,
485+
avgInferenceProcessMemoryRssBytes
380486
);
381487
}
382488
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.elasticsearch.xpack.core.action.util.QueryPage;
1414
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
1515
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction.Response;
16+
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
1617
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStatsTests;
1718
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStatsTests;
1819
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelSizeStatsTests;
@@ -24,6 +25,7 @@
2425
import java.util.stream.Stream;
2526

2627
import static org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction.Response.RESULTS_FIELD;
28+
import static org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats.MEMORY_STAT_TRANSPORT_VERSION;
2729

2830
public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSerializationTestCase<Response> {
2931

@@ -93,6 +95,67 @@ protected Writeable.Reader<Response> instanceReader() {
9395

9496
@Override
9597
protected Response mutateInstanceForVersion(Response instance, TransportVersion version) {
98+
if (version.supports(MEMORY_STAT_TRANSPORT_VERSION) == false) {
99+
return new Response(
100+
new QueryPage<>(
101+
instance.getResources()
102+
.results()
103+
.stream()
104+
.map(
105+
stats -> new Response.TrainedModelStats(
106+
stats.getModelId(),
107+
stats.getModelSizeStats(),
108+
stats.getIngestStats(),
109+
stats.getPipelineCount(),
110+
stats.getInferenceStats(),
111+
stats.getDeploymentStats() == null
112+
? null
113+
: new AssignmentStats(
114+
stats.getDeploymentStats().getDeploymentId(),
115+
stats.getDeploymentStats().getModelId(),
116+
stats.getDeploymentStats().getThreadsPerAllocation(),
117+
stats.getDeploymentStats().getNumberOfAllocations(),
118+
stats.getDeploymentStats().getAdaptiveAllocationsSettings(),
119+
stats.getDeploymentStats().getQueueCapacity(),
120+
stats.getDeploymentStats().getCacheSize(),
121+
stats.getDeploymentStats().getStartTime(),
122+
stats.getDeploymentStats()
123+
.getNodeStats()
124+
.stream()
125+
.map(
126+
nodeStats -> new AssignmentStats.NodeStats(
127+
nodeStats.getNode(),
128+
nodeStats.getInferenceCount().orElse(null),
129+
nodeStats.getAvgInferenceTime().orElse(null),
130+
nodeStats.getAvgInferenceTimeExcludingCacheHit().orElse(null),
131+
nodeStats.getLastAccess(),
132+
nodeStats.getPendingCount(),
133+
nodeStats.getErrorCount(),
134+
nodeStats.getCacheHitCount().orElse(null),
135+
nodeStats.getRejectedExecutionCount(),
136+
nodeStats.getTimeoutCount(),
137+
nodeStats.getRoutingState(),
138+
nodeStats.getStartTime(),
139+
nodeStats.getThreadsPerAllocation(),
140+
nodeStats.getNumberOfAllocations(),
141+
nodeStats.getPeakThroughput(),
142+
nodeStats.getThroughputLastPeriod(),
143+
nodeStats.getAvgInferenceTimeLastPeriod(),
144+
nodeStats.getCacheHitCountLastPeriod().orElse(null),
145+
null // avgInferenceProcessMemoryRssBytes is null for old versions
146+
)
147+
)
148+
.toList(),
149+
stats.getDeploymentStats().getPriority()
150+
)
151+
)
152+
)
153+
.toList(),
154+
instance.getResources().count(),
155+
RESULTS_FIELD
156+
)
157+
);
158+
}
96159
return instance;
97160
}
98161
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStatsTests.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ public static AssignmentStats.NodeStats randomNodeStats(DiscoveryNode node) {
9191
randomIntBetween(0, 100),
9292
randomIntBetween(0, 100),
9393
avgInferenceTimeLastPeriod,
94+
randomLongBetween(0, 100),
9495
randomLongBetween(0, 100)
9596
);
9697
}
@@ -125,7 +126,8 @@ public void testGetOverallInferenceStats() {
125126
randomNonNegativeLong(),
126127
randomNonNegativeLong(),
127128
null,
128-
1L
129+
1L,
130+
randomNonNegativeLong()
129131
),
130132
AssignmentStats.NodeStats.forStartedState(
131133
DiscoveryNodeUtils.create("node_started_2"),
@@ -144,7 +146,8 @@ public void testGetOverallInferenceStats() {
144146
randomNonNegativeLong(),
145147
randomNonNegativeLong(),
146148
null,
147-
1L
149+
1L,
150+
randomNonNegativeLong()
148151
),
149152
AssignmentStats.NodeStats.forNotStartedState(
150153
DiscoveryNodeUtils.create("node_not_started_3"),

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,8 @@ protected void taskOperation(
332332
presentValue.peakThroughput(),
333333
presentValue.throughputLastPeriod(),
334334
presentValue.avgInferenceTimeLastPeriod(),
335-
presentValue.cacheHitCountLastPeriod()
335+
presentValue.cacheHitCountLastPeriod(),
336+
presentValue.avgInferenceProcessMemoryRssBytes()
336337
)
337338
);
338339
} else {

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,8 @@ public Optional<ModelStats> getStats(TrainedModelDeploymentTask task) {
130130
stats.peakThroughput(),
131131
recentStats.requestsProcessed(),
132132
recentStats.avgInferenceTime(),
133-
recentStats.cacheHitCount()
133+
recentStats.cacheHitCount(),
134+
Math.round(stats.inferenceProcessMemoryRssBytesStats().getAverage())
134135
);
135136
});
136137
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/ModelStats.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,6 @@ public record ModelStats(
2525
long peakThroughput,
2626
long throughputLastPeriod,
2727
Double avgInferenceTimeLastPeriod,
28-
long cacheHitCountLastPeriod
28+
long cacheHitCountLastPeriod,
29+
Long avgInferenceProcessMemoryRssBytes
2930
) {}

0 commit comments

Comments
 (0)