diff --git a/docs/reference/rest-api/usage.asciidoc b/docs/reference/rest-api/usage.asciidoc
index a54dbe21b46c6..4a8895807f2fa 100644
--- a/docs/reference/rest-api/usage.asciidoc
+++ b/docs/reference/rest-api/usage.asciidoc
@@ -195,7 +195,13 @@ GET /_xpack/usage
         }
       }
     },
-    "node_count" : 1
+    "node_count" : 1,
+    "memory": {
+      "anomaly_detectors_memory_bytes": 0,
+      "data_frame_analytics_memory_bytes": 0,
+      "pytorch_inference_memory_bytes": 0,
+      "total_used_memory_bytes": 0
+    }
   },
   "inference": {
     "available" : true,
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningFeatureSetUsage.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningFeatureSetUsage.java
index 98c31dd9106d0..60484675ec90b 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningFeatureSetUsage.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningFeatureSetUsage.java
@@ -31,11 +31,13 @@ public class MachineLearningFeatureSetUsage extends XPackFeatureSet.Usage {
     public static final String NODE_COUNT = "node_count";
     public static final String DATA_FRAME_ANALYTICS_JOBS_FIELD = "data_frame_analytics_jobs";
     public static final String INFERENCE_FIELD = "inference";
+    public static final String MEMORY_FIELD = "memory";

     private final Map<String, Object> jobsUsage;
     private final Map<String, Object> datafeedsUsage;
     private final Map<String, Object> analyticsUsage;
     private final Map<String, Object> inferenceUsage;
+    private final Map<String, Object> memoryUsage;
     private final int nodeCount;

     public MachineLearningFeatureSetUsage(
@@ -45,6 +47,7 @@ public MachineLearningFeatureSetUsage(
         Map<String, Object> datafeedsUsage,
         Map<String, Object> analyticsUsage,
         Map<String, Object> inferenceUsage,
+        Map<String, Object> memoryUsage,
         int nodeCount
     ) {
         super(XPackField.MACHINE_LEARNING, available, enabled);
@@ -52,6 +55,7 @@ public MachineLearningFeatureSetUsage(
         this.datafeedsUsage = Objects.requireNonNull(datafeedsUsage);
         this.analyticsUsage = Objects.requireNonNull(analyticsUsage);
         this.inferenceUsage = Objects.requireNonNull(inferenceUsage);
+        this.memoryUsage = Objects.requireNonNull(memoryUsage);
         this.nodeCount = nodeCount;
     }

@@ -62,6 +66,11 @@ public MachineLearningFeatureSetUsage(StreamInput in) throws IOException {
         this.analyticsUsage = in.readGenericMap();
         this.inferenceUsage = in.readGenericMap();
         this.nodeCount = in.readInt();
+        if (in.getTransportVersion().onOrAfter(TransportVersions.ML_TELEMETRY_MEMORY_ADDED)) {
+            this.memoryUsage = in.readGenericMap();
+        } else {
+            this.memoryUsage = Map.of();
+        }
     }

     @Override
@@ -77,6 +86,9 @@ public void writeTo(StreamOutput out) throws IOException {
         out.writeGenericMap(analyticsUsage);
         out.writeGenericMap(inferenceUsage);
         out.writeInt(nodeCount);
+        if (out.getTransportVersion().onOrAfter(TransportVersions.ML_TELEMETRY_MEMORY_ADDED)) {
+            out.writeGenericMap(memoryUsage);
+        }
     }

     @Override
@@ -86,9 +98,51 @@ protected void innerXContent(XContentBuilder builder, Params params) throws IOEx
         builder.field(DATAFEEDS_FIELD, datafeedsUsage);
         builder.field(DATA_FRAME_ANALYTICS_JOBS_FIELD, analyticsUsage);
         builder.field(INFERENCE_FIELD, inferenceUsage);
+        builder.field(MEMORY_FIELD, memoryUsage);
         if (nodeCount >= 0) {
             builder.field(NODE_COUNT, nodeCount);
         }
     }

+    public Map<String, Object> getJobsUsage() {
+        return jobsUsage;
+    }
+
+    public Map<String, Object> getDatafeedsUsage() {
+        return datafeedsUsage;
+    }
+
+    public Map<String, Object> getAnalyticsUsage() {
+        return analyticsUsage;
+    }
+
+    public Map<String, Object> getInferenceUsage() {
+        return inferenceUsage;
+    }
+
+    public Map<String, Object> getMemoryUsage() {
+        return memoryUsage;
+    }
+
+    public int getNodeCount() {
+        return nodeCount;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        MachineLearningFeatureSetUsage that = (MachineLearningFeatureSetUsage) o;
+        return nodeCount == that.nodeCount
+            && Objects.equals(jobsUsage, that.jobsUsage)
+            && Objects.equals(datafeedsUsage, that.datafeedsUsage)
+            && Objects.equals(analyticsUsage, that.analyticsUsage)
+            && Objects.equals(inferenceUsage, that.inferenceUsage)
+            && Objects.equals(memoryUsage, that.memoryUsage);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(jobsUsage, datafeedsUsage, analyticsUsage, inferenceUsage, memoryUsage, nodeCount);
+    }
 }
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MachineLearningFeatureSetUsageTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MachineLearningFeatureSetUsageTests.java
new file mode 100644
index 0000000000000..87d658c6f983c
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MachineLearningFeatureSetUsageTests.java
@@ -0,0 +1,75 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.core.Tuple;
+
+import java.io.IOException;
+import java.util.Collections;
+
+public class MachineLearningFeatureSetUsageTests extends AbstractBWCWireSerializationTestCase<MachineLearningFeatureSetUsage> {
+    @Override
+    protected Writeable.Reader<MachineLearningFeatureSetUsage> instanceReader() {
+        return MachineLearningFeatureSetUsage::new;
+    }
+
+    @Override
+    protected MachineLearningFeatureSetUsage createTestInstance() {
+        boolean enabled = randomBoolean();
+
+        if (enabled == false) {
+            return new MachineLearningFeatureSetUsage(
+                randomBoolean(),
+                enabled,
+                Collections.emptyMap(),
+                Collections.emptyMap(),
+                Collections.emptyMap(),
+                Collections.emptyMap(),
+                Collections.emptyMap(),
+                0
+            );
+        } else {
+            return new MachineLearningFeatureSetUsage(
+                randomBoolean(),
+                enabled,
+                randomMap(0, 4, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))),
+                randomMap(0, 4, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))),
+                randomMap(0, 4, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))),
+                randomMap(0, 4, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))),
+                randomMap(0, 4, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))),
+                randomIntBetween(1, 10)
+            );
+        }
+    }
+
+    @Override
+    protected MachineLearningFeatureSetUsage mutateInstance(MachineLearningFeatureSetUsage instance) throws IOException {
+        return null;
+    }
+
+    @Override
+    protected MachineLearningFeatureSetUsage mutateInstanceForVersion(MachineLearningFeatureSetUsage instance, TransportVersion version) {
+        if (version.before(TransportVersions.ML_TELEMETRY_MEMORY_ADDED)) {
+            return new MachineLearningFeatureSetUsage(
+                instance.available(),
+                instance.enabled(),
+                instance.getJobsUsage(),
+                instance.getDatafeedsUsage(),
+                instance.getAnalyticsUsage(),
+                instance.getInferenceUsage(),
+                Collections.emptyMap(),
+                instance.getNodeCount()
+            );
+        }
+
+        return instance;
+    }
+}
diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java
index 34ef0baecccc5..4e92cad1026a3 100644
--- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java
+++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java
@@ -1120,6 +1120,28 @@ public void testStartMultipleLowPriorityDeployments() throws Exception {
         }
     }

+    @SuppressWarnings("unchecked")
+    public void testDeploymentThreadsIncludedInUsage() throws IOException {
+        String modelId = "deployment_threads_in_usage";
+        createPassThroughModel(modelId);
+        putModelDefinition(modelId);
+        putVocabulary(List.of("these", "are", "my", "words"), modelId);
+        startDeployment(modelId);
+
+        Request request = new Request("GET", "/_xpack/usage");
+        var usage = entityAsMap(client().performRequest(request).getEntity());
+
+        var ml = (Map<String, Object>) usage.get("ml");
+        assertNotNull(usage.toString(), ml);
+        var inference = (Map<String, Object>) ml.get("inference");
+        var deployments = (Map<String, Object>) inference.get("deployments");
+        var deploymentStats = (List<Map<String, Object>>) deployments.get("stats_by_model");
+        for (var stat : deploymentStats) {
+            assertThat(stat.toString(), (Integer) stat.get("num_threads"), greaterThanOrEqualTo(1));
+            assertThat(stat.toString(), (Integer) stat.get("num_allocations"), greaterThanOrEqualTo(1));
+        }
+    }
+
     private void putModelDefinition(String modelId) throws IOException {
         putModelDefinition(modelId, BASE_64_ENCODED_MODEL, RAW_MODEL_SIZE);
     }
diff --git a/x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlUsageIT.java b/x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlUsageIT.java
new file mode 100644
index 0000000000000..05a307c2dfad3
--- /dev/null
+++ b/x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlUsageIT.java
@@ -0,0 +1,35 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.integration;
+
+import org.elasticsearch.client.Request;
+import org.elasticsearch.test.rest.ESRestTestCase;
+
+import java.io.IOException;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.greaterThanOrEqualTo;
+
+// Test the phone home/telemetry data
+public class MlUsageIT extends ESRestTestCase {
+
+    @SuppressWarnings("unchecked")
+    public void testMLUsage() throws IOException {
+        Request request = new Request("GET", "/_xpack/usage");
+        var usage = entityAsMap(client().performRequest(request).getEntity());
+
+        var ml = (Map<String, Object>) usage.get("ml");
+        assertNotNull(usage.toString(), ml);
+        var memoryUsage = (Map<String, Object>) ml.get("memory");
+        assertNotNull(ml.toString(), memoryUsage);
+        assertThat(memoryUsage.toString(), (Integer) memoryUsage.get("anomaly_detectors_memory_bytes"), greaterThanOrEqualTo(0));
+        assertThat(memoryUsage.toString(), (Integer) memoryUsage.get("data_frame_analytics_memory_bytes"), greaterThanOrEqualTo(0));
+        assertThat(memoryUsage.toString(), (Integer) memoryUsage.get("pytorch_inference_memory_bytes"), greaterThanOrEqualTo(0));
+        assertThat(memoryUsage.toString(), (Integer) memoryUsage.get("total_used_memory_bytes"), greaterThanOrEqualTo(0));
+    }
+}
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearningUsageTransportAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearningUsageTransportAction.java
index 583965e76e542..40e3fbb661db1 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearningUsageTransportAction.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearningUsageTransportAction.java
@@ -39,6 +39,7 @@
 import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
+import org.elasticsearch.xpack.core.ml.action.MlMemoryAction;
 import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
@@ -65,6 +66,7 @@
 import java.util.Map;
 import java.util.Objects;
 import java.util.TreeMap;
+import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Function;
 import java.util.stream.Collectors;

@@ -72,16 +74,20 @@ public class MachineLearningUsageTransportAction extends XPackUsageFeatureTransportAction {

-    private static class ModelStats {
+    private static class DeploymentStats {

         private final String modelId;
         private final String taskType;
         private final StatsAccumulator inferenceCounts = new StatsAccumulator();
         private Instant lastAccess;
+        private final int numThreads;
+        private final int numAllocations;

-        ModelStats(String modelId, String taskType) {
+        DeploymentStats(String modelId, String taskType, int numThreads, int numAllocations) {
             this.modelId = modelId;
             this.taskType = taskType;
+            this.numThreads = numThreads;
+            this.numAllocations = numAllocations;
         }

         void update(AssignmentStats.NodeStats stats) {
@@ -95,6 +101,8 @@ Map asMap() {
             Map<String, Object> result = new HashMap<>();
             result.put("model_id", modelId);
             result.put("task_type", taskType);
+            result.put("num_allocations", numAllocations);
+            result.put("num_threads", numThreads);
             result.put("inference_counts", inferenceCounts.asMap());
             if (lastAccess != null) {
                 result.put("last_access", lastAccess.toString());
@@ -158,6 +166,7 @@ protected void masterOperation(
                 Collections.emptyMap(),
                 Collections.emptyMap(),
                 Collections.emptyMap(),
+                Collections.emptyMap(),
                 0
             );
             listener.onResponse(new XPackUsageFeatureResponse(usage));
@@ -167,11 +176,14 @@
         Map<String, Object> jobsUsage = new LinkedHashMap<>();
         Map<String, Object> datafeedsUsage = new LinkedHashMap<>();
         Map<String, Object> analyticsUsage = new LinkedHashMap<>();
+        AtomicReference<Map<String, Object>> inferenceUsage = new AtomicReference<>(Map.of());
+        int nodeCount = mlNodeCount(state);

-        // Step 5. return final ML usage
-        ActionListener<Map<String, Object>> inferenceUsageListener = ActionListener.wrap(
-            inferenceUsage -> listener.onResponse(
+        // Step 6. return final ML usage
+        ActionListener<MlMemoryAction.Response> memoryUsageListener = ActionListener.wrap(memoryResponse -> {
+            var memoryUsage = extractMemoryUsage(memoryResponse);
+            listener.onResponse(
                 new XPackUsageFeatureResponse(
                     new MachineLearningFeatureSetUsage(
                         MachineLearningField.ML_API_FEATURE.checkWithoutTracking(licenseState),
@@ -179,28 +191,38 @@ protected void masterOperation(
                         jobsUsage,
                         datafeedsUsage,
                         analyticsUsage,
-                        inferenceUsage,
+                        inferenceUsage.get(),
+                        memoryUsage,
                         nodeCount
                     )
                 )
-            ),
-            e -> {
-                logger.warn("Failed to get inference usage to include in ML usage", e);
-                listener.onResponse(
-                    new XPackUsageFeatureResponse(
-                        new MachineLearningFeatureSetUsage(
-                            MachineLearningField.ML_API_FEATURE.checkWithoutTracking(licenseState),
-                            enabled,
-                            jobsUsage,
-                            datafeedsUsage,
-                            analyticsUsage,
-                            Collections.emptyMap(),
-                            nodeCount
-                        )
+            );
+        }, e -> {
+            logger.warn("Failed to get memory usage to include in ML usage", e);
+            listener.onResponse(
+                new XPackUsageFeatureResponse(
+                    new MachineLearningFeatureSetUsage(
+                        MachineLearningField.ML_API_FEATURE.checkWithoutTracking(licenseState),
+                        enabled,
+                        jobsUsage,
+                        datafeedsUsage,
+                        analyticsUsage,
+                        inferenceUsage.get(),
+                        Collections.emptyMap(),
+                        nodeCount
                     )
-                );
-            }
-        );
+                )
+            );
+        });
+
+        // Step 5. Get
+        ActionListener<Map<String, Object>> inferenceUsageListener = ActionListener.wrap(inference -> {
+            inferenceUsage.set(inference);
+            client.execute(MlMemoryAction.INSTANCE, new MlMemoryAction.Request("_all"), memoryUsageListener);
+        }, e -> {
+            logger.warn("Failed to get inference usage to include in ML usage", e);
+            client.execute(MlMemoryAction.INSTANCE, new MlMemoryAction.Request("_all"), memoryUsageListener);
+        });

         // Step 4. Extract usage from data frame analytics configs and then get inference usage
         ActionListener<GetDataFrameAnalyticsStatsAction.Response> dataframeAnalyticsListener = ActionListener.wrap(response -> {
@@ -464,7 +486,7 @@ private static void addDeploymentStats(
         int deploymentsCount = 0;
         double avgTimeSum = 0.0;
         StatsAccumulator nodeDistribution = new StatsAccumulator();
-        Map<String, ModelStats> statsByModel = new TreeMap<>();
+        Map<String, DeploymentStats> statsByModel = new TreeMap<>();
         for (var stats : statsResponse.getResources().results()) {
             AssignmentStats deploymentStats = stats.getDeploymentStats();
             if (deploymentStats == null) {
@@ -478,7 +500,15 @@
             String modelId = deploymentStats.getModelId();
             String taskType = taskTypes.get(deploymentStats.getModelId());
             String mapKey = modelId + ":" + taskType;
-            ModelStats modelStats = statsByModel.computeIfAbsent(mapKey, key -> new ModelStats(modelId, taskType));
+            DeploymentStats modelStats = statsByModel.computeIfAbsent(
+                mapKey,
+                key -> new DeploymentStats(
+                    modelId,
+                    taskType,
+                    deploymentStats.getThreadsPerAllocation(),
+                    deploymentStats.getNumberOfAllocations()
+                )
+            );
             for (var nodeStats : deploymentStats.getNodeStats()) {
                 long nodeInferenceCount = nodeStats.getInferenceCount().orElse(0L);
                 avgTimeSum += nodeStats.getAvgInferenceTime().orElse(0.0) * nodeInferenceCount;
@@ -499,7 +529,7 @@
                 "inference_counts",
                 nodeDistribution.asMap(),
                 "stats_by_model",
-                statsByModel.values().stream().map(ModelStats::asMap).collect(Collectors.toList())
+                statsByModel.values().stream().map(DeploymentStats::asMap).collect(Collectors.toList())
             )
         );
     }
@@ -590,6 +620,21 @@ private static void addInferenceIngestUsage(GetTrainedModelsStatsAction.Response
         inferenceUsage.put("ingest_processors", Collections.singletonMap(MachineLearningFeatureSetUsage.ALL, ingestUsage));
     }

+    private static Map<String, Object> extractMemoryUsage(MlMemoryAction.Response memoryResponse) {
+        var adMem = memoryResponse.getNodes().stream().mapToLong(mem -> mem.getMlAnomalyDetectors().getBytes()).sum();
+        var dfaMem = memoryResponse.getNodes().stream().mapToLong(mem -> mem.getMlDataFrameAnalytics().getBytes()).sum();
+        var pytorchMem = memoryResponse.getNodes().stream().mapToLong(mem -> mem.getMlNativeInference().getBytes()).sum();
+        var nativeOverheadMem = memoryResponse.getNodes().stream().mapToLong(mem -> mem.getMlNativeCodeOverhead().getBytes()).sum();
+        long totalUsedMem = adMem + dfaMem + pytorchMem + nativeOverheadMem;
+
+        var memoryUsage = new LinkedHashMap<String, Object>();
+        memoryUsage.put("anomaly_detectors_memory_bytes", adMem);
+        memoryUsage.put("data_frame_analytics_memory_bytes", dfaMem);
+        memoryUsage.put("pytorch_inference_memory_bytes", pytorchMem);
+        memoryUsage.put("total_used_memory_bytes", totalUsedMem);
+        return memoryUsage;
+    }
+
     private static Map<String, Long> getMinMaxSumAsLongsFromStats(StatsAccumulator stats) {
         Map<String, Long> asMap = Maps.newMapWithExpectedSize(3);
         asMap.put("sum", Double.valueOf(stats.getTotal()).longValue());
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java
index e5575abfeb020..4fdb7d2e5e46c 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java
@@ -10,9 +10,11 @@
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.client.internal.Client;
+import org.elasticsearch.cluster.ClusterName;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
 import org.elasticsearch.cluster.metadata.Metadata;
+import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNodeRole;
 import org.elasticsearch.cluster.node.DiscoveryNodeUtils;
 import org.elasticsearch.cluster.node.DiscoveryNodes;
@@ -46,6 +48,7 @@
 import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
+import org.elasticsearch.xpack.core.ml.action.MlMemoryAction;
 import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig;
 import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
@@ -134,6 +137,27 @@ public void init() {
                 new QueryPage<>(Collections.emptyList(), 0, GetTrainedModelsStatsAction.Response.RESULTS_FIELD)
             )
         );
+        givenMlMemory(
+            new MlMemoryAction.Response(
+                new ClusterName("cluster_foo"),
+                List.of(
+                    new MlMemoryAction.Response.MlMemoryStats(
+                        mock(DiscoveryNode.class),
+                        ByteSizeValue.ofBytes(100L),
+                        ByteSizeValue.ofBytes(1L),
+                        ByteSizeValue.ofBytes(1L),
+                        ByteSizeValue.ofBytes(1L),
+                        ByteSizeValue.ofBytes(20L),
+                        ByteSizeValue.ofBytes(30L),
+                        ByteSizeValue.ofBytes(40L),
+                        ByteSizeValue.ofBytes(1L),
+                        ByteSizeValue.ofBytes(1L),
+                        ByteSizeValue.ofBytes(1L)
+                    )
+                ),
+                List.of()
+            )
+        );
     }

     @After
@@ -343,6 +367,8 @@ public void testUsage() throws Exception {
             assertThat(source.getValue("inference.deployments.inference_counts.avg"), equalTo(4.0));
             assertThat(source.getValue("inference.deployments.stats_by_model.0.model_id"), equalTo("model_3"));
             assertThat(source.getValue("inference.deployments.stats_by_model.0.task_type"), equalTo("ner"));
+            assertThat(source.getValue("inference.deployments.stats_by_model.0.num_allocations"), equalTo(8));
+            assertThat(source.getValue("inference.deployments.stats_by_model.0.num_threads"), equalTo(1));
             assertThat(source.getValue("inference.deployments.stats_by_model.0.last_access"), equalTo(lastAccess(3).toString()));
             assertThat(source.getValue("inference.deployments.stats_by_model.0.inference_counts.total"), equalTo(3.0));
             assertThat(source.getValue("inference.deployments.stats_by_model.0.inference_counts.min"), equalTo(3.0));
@@ -350,6 +376,8 @@ public void testUsage() throws Exception {
             assertThat(source.getValue("inference.deployments.stats_by_model.0.inference_counts.avg"), equalTo(3.0));
             assertThat(source.getValue("inference.deployments.stats_by_model.1.model_id"), equalTo("model_4"));
             assertThat(source.getValue("inference.deployments.stats_by_model.1.task_type"), equalTo("text_expansion"));
+            assertThat(source.getValue("inference.deployments.stats_by_model.1.num_allocations"), equalTo(2));
+            assertThat(source.getValue("inference.deployments.stats_by_model.1.num_threads"), equalTo(2));
             assertThat(source.getValue("inference.deployments.stats_by_model.1.last_access"), equalTo(lastAccess(44).toString()));
             assertThat(source.getValue("inference.deployments.stats_by_model.1.inference_counts.total"), equalTo(9.0));
             assertThat(source.getValue("inference.deployments.stats_by_model.1.inference_counts.min"), equalTo(4.0));
@@ -360,6 +388,11 @@
             assertThat(source.getValue("inference.deployments.model_sizes_bytes.max"), equalTo(1000.0));
             assertThat(source.getValue("inference.deployments.model_sizes_bytes.avg"), equalTo(650.0));
             assertThat(source.getValue("inference.deployments.time_ms.avg"), closeTo(44.0, 1e-10));
+
+            assertThat(source.getValue("memory.anomaly_detectors_memory_bytes"), equalTo(20));
+            assertThat(source.getValue("memory.data_frame_analytics_memory_bytes"), equalTo(30));
+            assertThat(source.getValue("memory.pytorch_inference_memory_bytes"), equalTo(40));
+            assertThat(source.getValue("memory.total_used_memory_bytes"), equalTo(91));
         }
     }

@@ -566,6 +599,8 @@ public void testUsageWithOrphanedTask() throws Exception {
         Job closed1 = buildJob("closed1", Arrays.asList(buildMinDetector("foo"), buildMinDetector("bar"), buildMinDetector("foobar")));
         GetJobsStatsAction.Response.JobStats closed1JobStats = buildJobStats("closed1", JobState.CLOSED, 300L, 0);
         givenJobs(Arrays.asList(opened1, closed1), Arrays.asList(opened1JobStats, opened2JobStats, closed1JobStats));
+        MlMemoryAction.Response memory = new MlMemoryAction.Response(new ClusterName("foo"), List.of(), List.of());
+        givenMlMemory(memory);

         var usageAction = newUsageAction(settings.build(), true, true, true);
         PlainActionFuture<XPackUsageFeatureResponse> future = new PlainActionFuture<>();
@@ -590,6 +625,11 @@ public void testUsageWithOrphanedTask() throws Exception {
         assertThat(source.getValue("jobs._all.model_size.avg"), equalTo(200.0));
         assertThat(source.getValue("jobs._all.created_by.a_cool_module"), equalTo(1));
         assertThat(source.getValue("jobs._all.created_by.unknown"), equalTo(1));
+
+        assertThat(source.getValue("memory.anomaly_detectors_memory_bytes"), equalTo(0));
+        assertThat(source.getValue("memory.data_frame_analytics_memory_bytes"), equalTo(0));
+        assertThat(source.getValue("memory.pytorch_inference_memory_bytes"), equalTo(0));
+        assertThat(source.getValue("memory.total_used_memory_bytes"), equalTo(0));
     }

     public void testUsageDisabledML() throws Exception {
@@ -802,6 +842,15 @@ private void givenTrainedModelStats(GetTrainedModelsStatsAction.Response trained
         }).when(client).execute(same(GetTrainedModelsStatsAction.INSTANCE), any(), any());
     }

+    private void givenMlMemory(MlMemoryAction.Response memoryUsage) {
+        doAnswer(invocationOnMock -> {
+            @SuppressWarnings("unchecked")
+            ActionListener<MlMemoryAction.Response> listener = (ActionListener<MlMemoryAction.Response>) invocationOnMock.getArguments()[2];
+            listener.onResponse(memoryUsage);
+            return Void.TYPE;
+        }).when(client).execute(same(MlMemoryAction.INSTANCE), any(), any());
+    }
+
     private static Detector buildMinDetector(String fieldName) {
         Detector.Builder detectorBuilder = new Detector.Builder();
         detectorBuilder.setFunction("min");
@@ -1004,8 +1053,8 @@ private Map setupComplexMocks() {
                     new AssignmentStats(
                         "deployment_3",
                         "model_3",
-                        null,
-                        null,
+                        1,
+                        8,
                         null,
                         null,
                         null,
@@ -1111,6 +1160,29 @@ private Map setupComplexMocks() {
                 )
             )
         );
+
+        givenMlMemory(
+            new MlMemoryAction.Response(
+                new ClusterName("cluster_foo"),
+                List.of(
+                    new MlMemoryAction.Response.MlMemoryStats(
+                        mock(DiscoveryNode.class),
+                        ByteSizeValue.ofBytes(100L),
+                        ByteSizeValue.ofBytes(1L),
+                        ByteSizeValue.ofBytes(1L),
+                        ByteSizeValue.ofBytes(1L),
+                        ByteSizeValue.ofBytes(20L),
+                        ByteSizeValue.ofBytes(30L),
+                        ByteSizeValue.ofBytes(40L),
+                        ByteSizeValue.ofBytes(1L),
+                        ByteSizeValue.ofBytes(1L),
+                        ByteSizeValue.ofBytes(1L)
+                    )
+                ),
+                List.of()
+            )
+        );
+
         return expectedDfaCountByAnalysis;
     }