Skip to content

Commit feafb3a

Browse files
authored
[ML] Track inference deployments (#131442)
Record duration and errors when Inference Endpoints deploy Trained Models. The new metric is `es.inference.trained_model.deployment.time`. Refactored `InferenceStats` into the server module so that it can be used in `InferenceServiceExtension` and passed to InferenceServices, rather than remaining at the Transport layer.
1 parent a6d6d3e commit feafb3a

File tree

17 files changed

+232
-111
lines changed

17 files changed

+232
-111
lines changed

docs/changelog/131442.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 131442
2+
summary: Track inference deployments
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/module-info.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,4 +484,5 @@
484484
exports org.elasticsearch.index.codec.perfield;
485485
exports org.elasticsearch.index.codec.vectors to org.elasticsearch.test.knn;
486486
exports org.elasticsearch.index.codec.vectors.es818 to org.elasticsearch.test.knn;
487+
exports org.elasticsearch.inference.telemetry;
487488
}

server/src/main/java/org/elasticsearch/inference/InferenceServiceExtension.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.elasticsearch.client.internal.Client;
1313
import org.elasticsearch.cluster.service.ClusterService;
1414
import org.elasticsearch.common.settings.Settings;
15+
import org.elasticsearch.inference.telemetry.InferenceStats;
1516
import org.elasticsearch.threadpool.ThreadPool;
1617

1718
import java.util.List;
@@ -23,7 +24,13 @@ public interface InferenceServiceExtension {
2324

2425
List<Factory> getInferenceServiceFactories();
2526

26-
record InferenceServiceFactoryContext(Client client, ThreadPool threadPool, ClusterService clusterService, Settings settings) {}
27+
record InferenceServiceFactoryContext(
28+
Client client,
29+
ThreadPool threadPool,
30+
ClusterService clusterService,
31+
Settings settings,
32+
InferenceStats inferenceStats
33+
) {}
2734

2835
interface Factory {
2936
/**
Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
/*
22
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3-
* or more contributor license agreements. Licensed under the Elastic License
4-
* 2.0; you may not use this file except in compliance with the Elastic License
5-
* 2.0.
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
68
*/
79

8-
package org.elasticsearch.xpack.inference.telemetry;
10+
package org.elasticsearch.inference.telemetry;
911

1012
import org.elasticsearch.ElasticsearchStatusException;
1113
import org.elasticsearch.core.Nullable;
@@ -14,17 +16,17 @@
1416
import org.elasticsearch.telemetry.metric.LongCounter;
1517
import org.elasticsearch.telemetry.metric.LongHistogram;
1618
import org.elasticsearch.telemetry.metric.MeterRegistry;
17-
import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest;
1819

1920
import java.util.HashMap;
2021
import java.util.Map;
2122
import java.util.Objects;
2223

23-
public record InferenceStats(LongCounter requestCount, LongHistogram inferenceDuration) {
24+
public record InferenceStats(LongCounter requestCount, LongHistogram inferenceDuration, LongHistogram deploymentDuration) {
2425

2526
public InferenceStats {
2627
Objects.requireNonNull(requestCount);
2728
Objects.requireNonNull(inferenceDuration);
29+
Objects.requireNonNull(deploymentDuration);
2830
}
2931

3032
public static InferenceStats create(MeterRegistry meterRegistry) {
@@ -38,6 +40,11 @@ public static InferenceStats create(MeterRegistry meterRegistry) {
3840
"es.inference.requests.time",
3941
"Inference API request counts for a particular service, task type, model ID",
4042
"ms"
43+
),
44+
meterRegistry.registerLongHistogram(
45+
"es.inference.trained_model.deployment.time",
46+
"Inference API time spent waiting for Trained Model Deployments",
47+
"ms"
4148
)
4249
);
4350
}
@@ -54,8 +61,8 @@ public static Map<String, Object> modelAttributes(Model model) {
5461
return modelAttributesMap;
5562
}
5663

57-
public static Map<String, Object> routingAttributes(BaseInferenceActionRequest request, String nodeIdHandlingRequest) {
58-
return Map.of("rerouted", request.hasBeenRerouted(), "node_id", nodeIdHandlingRequest);
64+
public static Map<String, Object> routingAttributes(boolean hasBeenRerouted, String nodeIdHandlingRequest) {
65+
return Map.of("rerouted", hasBeenRerouted, "node_id", nodeIdHandlingRequest);
5966
}
6067

6168
public static Map<String, Object> modelAttributes(UnparsedModel model) {
@@ -73,4 +80,11 @@ public static Map<String, Object> responseAttributes(@Nullable Throwable throwab
7380

7481
return Map.of("error.type", throwable.getClass().getSimpleName());
7582
}
83+
84+
public static Map<String, Object> modelAndResponseAttributes(Model model, @Nullable Throwable throwable) {
85+
var metricAttributes = new HashMap<String, Object>();
86+
metricAttributes.putAll(modelAttributes(model));
87+
metricAttributes.putAll(responseAttributes(throwable));
88+
return metricAttributes;
89+
}
7690
}
Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
/*
22
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3-
* or more contributor license agreements. Licensed under the Elastic License
4-
* 2.0; you may not use this file except in compliance with the Elastic License
5-
* 2.0.
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
68
*/
79

8-
package org.elasticsearch.xpack.inference.telemetry;
10+
package org.elasticsearch.inference.telemetry;
911

1012
import org.elasticsearch.ElasticsearchStatusException;
1113
import org.elasticsearch.inference.Model;
@@ -22,9 +24,9 @@
2224
import java.util.HashMap;
2325
import java.util.Map;
2426

25-
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.create;
26-
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
27-
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes;
27+
import static org.elasticsearch.inference.telemetry.InferenceStats.create;
28+
import static org.elasticsearch.inference.telemetry.InferenceStats.modelAttributes;
29+
import static org.elasticsearch.inference.telemetry.InferenceStats.responseAttributes;
2830
import static org.hamcrest.Matchers.is;
2931
import static org.hamcrest.Matchers.nullValue;
3032
import static org.mockito.ArgumentMatchers.assertArg;
@@ -35,9 +37,13 @@
3537

3638
public class InferenceStatsTests extends ESTestCase {
3739

40+
public static InferenceStats mockInferenceStats() {
41+
return new InferenceStats(mock(), mock(), mock());
42+
}
43+
3844
public void testRecordWithModel() {
3945
var longCounter = mock(LongCounter.class);
40-
var stats = new InferenceStats(longCounter, mock());
46+
var stats = new InferenceStats(longCounter, mock(), mock());
4147

4248
stats.requestCount().incrementBy(1, modelAttributes(model("service", TaskType.ANY, "modelId")));
4349

@@ -49,7 +55,7 @@ public void testRecordWithModel() {
4955

5056
public void testRecordWithoutModel() {
5157
var longCounter = mock(LongCounter.class);
52-
var stats = new InferenceStats(longCounter, mock());
58+
var stats = new InferenceStats(longCounter, mock(), mock());
5359

5460
stats.requestCount().incrementBy(1, modelAttributes(model("service", TaskType.ANY, null)));
5561

@@ -63,7 +69,7 @@ public void testCreation() {
6369
public void testRecordDurationWithoutError() {
6470
var expectedLong = randomLong();
6571
var histogramCounter = mock(LongHistogram.class);
66-
var stats = new InferenceStats(mock(), histogramCounter);
72+
var stats = new InferenceStats(mock(), histogramCounter, mock());
6773

6874
Map<String, Object> metricAttributes = new HashMap<>();
6975
metricAttributes.putAll(modelAttributes(model("service", TaskType.ANY, "modelId")));
@@ -88,7 +94,7 @@ public void testRecordDurationWithoutError() {
8894
public void testRecordDurationWithElasticsearchStatusException() {
8995
var expectedLong = randomLong();
9096
var histogramCounter = mock(LongHistogram.class);
91-
var stats = new InferenceStats(mock(), histogramCounter);
97+
var stats = new InferenceStats(mock(), histogramCounter, mock());
9298
var statusCode = RestStatus.BAD_REQUEST;
9399
var exception = new ElasticsearchStatusException("hello", statusCode);
94100
var expectedError = String.valueOf(statusCode.getStatus());
@@ -116,7 +122,7 @@ public void testRecordDurationWithElasticsearchStatusException() {
116122
public void testRecordDurationWithOtherException() {
117123
var expectedLong = randomLong();
118124
var histogramCounter = mock(LongHistogram.class);
119-
var stats = new InferenceStats(mock(), histogramCounter);
125+
var stats = new InferenceStats(mock(), histogramCounter, mock());
120126
var exception = new IllegalStateException("ahh");
121127
var expectedError = exception.getClass().getSimpleName();
122128

@@ -138,7 +144,7 @@ public void testRecordDurationWithOtherException() {
138144
public void testRecordDurationWithUnparsedModelAndElasticsearchStatusException() {
139145
var expectedLong = randomLong();
140146
var histogramCounter = mock(LongHistogram.class);
141-
var stats = new InferenceStats(mock(), histogramCounter);
147+
var stats = new InferenceStats(mock(), histogramCounter, mock());
142148
var statusCode = RestStatus.BAD_REQUEST;
143149
var exception = new ElasticsearchStatusException("hello", statusCode);
144150
var expectedError = String.valueOf(statusCode.getStatus());
@@ -163,7 +169,7 @@ public void testRecordDurationWithUnparsedModelAndElasticsearchStatusException()
163169
public void testRecordDurationWithUnparsedModelAndOtherException() {
164170
var expectedLong = randomLong();
165171
var histogramCounter = mock(LongHistogram.class);
166-
var stats = new InferenceStats(mock(), histogramCounter);
172+
var stats = new InferenceStats(mock(), histogramCounter, mock());
167173
var exception = new IllegalStateException("ahh");
168174
var expectedError = exception.getClass().getSimpleName();
169175

@@ -187,7 +193,7 @@ public void testRecordDurationWithUnparsedModelAndOtherException() {
187193
public void testRecordDurationWithUnknownModelAndElasticsearchStatusException() {
188194
var expectedLong = randomLong();
189195
var histogramCounter = mock(LongHistogram.class);
190-
var stats = new InferenceStats(mock(), histogramCounter);
196+
var stats = new InferenceStats(mock(), histogramCounter, mock());
191197
var statusCode = RestStatus.BAD_REQUEST;
192198
var exception = new ElasticsearchStatusException("hello", statusCode);
193199
var expectedError = String.valueOf(statusCode.getStatus());
@@ -206,7 +212,7 @@ public void testRecordDurationWithUnknownModelAndElasticsearchStatusException()
206212
public void testRecordDurationWithUnknownModelAndOtherException() {
207213
var expectedLong = randomLong();
208214
var histogramCounter = mock(LongHistogram.class);
209-
var stats = new InferenceStats(mock(), histogramCounter);
215+
var stats = new InferenceStats(mock(), histogramCounter, mock());
210216
var exception = new IllegalStateException("ahh");
211217
var expectedError = exception.getClass().getSimpleName();
212218

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.elasticsearch.inference.TaskSettings;
3232
import org.elasticsearch.inference.TaskType;
3333
import org.elasticsearch.inference.UnparsedModel;
34+
import org.elasticsearch.inference.telemetry.InferenceStatsTests;
3435
import org.elasticsearch.plugins.Plugin;
3536
import org.elasticsearch.reindex.ReindexPlugin;
3637
import org.elasticsearch.test.ESSingleNodeTestCase;
@@ -129,7 +130,8 @@ public void testGetModel() throws Exception {
129130
mock(Client.class),
130131
mock(ThreadPool.class),
131132
mock(ClusterService.class),
132-
Settings.EMPTY
133+
Settings.EMPTY,
134+
InferenceStatsTests.mockInferenceStats()
133135
)
134136
);
135137
ElasticsearchInternalModel roundTripModel = (ElasticsearchInternalModel) elserService.parsePersistedConfigWithSecrets(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.elasticsearch.indices.SystemIndexDescriptor;
3131
import org.elasticsearch.inference.InferenceServiceExtension;
3232
import org.elasticsearch.inference.InferenceServiceRegistry;
33+
import org.elasticsearch.inference.telemetry.InferenceStats;
3334
import org.elasticsearch.license.License;
3435
import org.elasticsearch.license.LicensedFeature;
3536
import org.elasticsearch.license.XPackLicenseState;
@@ -140,7 +141,6 @@
140141
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModelBuilder;
141142
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
142143
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIService;
143-
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
144144

145145
import java.util.ArrayList;
146146
import java.util.Collection;
@@ -328,11 +328,16 @@ public Collection<?> createComponents(PluginServices services) {
328328
)
329329
);
330330

331+
var meterRegistry = services.telemetryProvider().getMeterRegistry();
332+
var inferenceStats = InferenceStats.create(meterRegistry);
333+
var inferenceStatsBinding = new PluginComponentBinding<>(InferenceStats.class, inferenceStats);
334+
331335
var factoryContext = new InferenceServiceExtension.InferenceServiceFactoryContext(
332336
services.client(),
333337
services.threadPool(),
334338
services.clusterService(),
335-
settings
339+
settings,
340+
inferenceStats
336341
);
337342

338343
// This must be done after the HttpRequestSenderFactory is created so that the services can get the
@@ -344,10 +349,6 @@ public Collection<?> createComponents(PluginServices services) {
344349
}
345350
inferenceServiceRegistry.set(serviceRegistry);
346351

347-
var meterRegistry = services.telemetryProvider().getMeterRegistry();
348-
var inferenceStats = InferenceStats.create(meterRegistry);
349-
var inferenceStatsBinding = new PluginComponentBinding<>(InferenceStats.class, inferenceStats);
350-
351352
var actionFilter = new ShardBulkInferenceActionFilter(
352353
services.clusterService(),
353354
serviceRegistry,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.elasticsearch.inference.Model;
2727
import org.elasticsearch.inference.TaskType;
2828
import org.elasticsearch.inference.UnparsedModel;
29+
import org.elasticsearch.inference.telemetry.InferenceStats;
2930
import org.elasticsearch.license.LicenseUtils;
3031
import org.elasticsearch.license.XPackLicenseState;
3132
import org.elasticsearch.rest.RestStatus;
@@ -42,7 +43,6 @@
4243
import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator;
4344
import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
4445
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
45-
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
4646
import org.elasticsearch.xpack.inference.telemetry.InferenceTimer;
4747

4848
import java.io.IOException;
@@ -57,10 +57,11 @@
5757

5858
import static org.elasticsearch.ExceptionsHelper.unwrapCause;
5959
import static org.elasticsearch.core.Strings.format;
60+
import static org.elasticsearch.inference.telemetry.InferenceStats.modelAndResponseAttributes;
61+
import static org.elasticsearch.inference.telemetry.InferenceStats.modelAttributes;
62+
import static org.elasticsearch.inference.telemetry.InferenceStats.responseAttributes;
63+
import static org.elasticsearch.inference.telemetry.InferenceStats.routingAttributes;
6064
import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
61-
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
62-
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes;
63-
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.routingAttributes;
6465

6566
/**
6667
* Base class for transport actions that handle inference requests.
@@ -274,15 +275,11 @@ public InferenceAction.Response read(StreamInput in) throws IOException {
274275
}
275276

276277
private void recordRequestDurationMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) {
277-
try {
278-
Map<String, Object> metricAttributes = new HashMap<>();
279-
metricAttributes.putAll(modelAttributes(model));
280-
metricAttributes.putAll(responseAttributes(unwrapCause(t)));
281-
282-
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), metricAttributes);
283-
} catch (Exception e) {
284-
log.atDebug().withThrowable(e).log("Failed to record metrics with an unparsed model, dropping metrics");
285-
}
278+
Map<String, Object> metricAttributes = new HashMap<>();
279+
metricAttributes.putAll(modelAttributes(model));
280+
metricAttributes.putAll(responseAttributes(unwrapCause(t)));
281+
282+
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), metricAttributes);
286283
}
287284

288285
private void inferOnServiceWithMetrics(
@@ -369,7 +366,7 @@ protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Publisher<T> upstream) {
369366
private void recordRequestCountMetrics(Model model, Request request, String localNodeId) {
370367
Map<String, Object> requestCountAttributes = new HashMap<>();
371368
requestCountAttributes.putAll(modelAttributes(model));
372-
requestCountAttributes.putAll(routingAttributes(request, localNodeId));
369+
requestCountAttributes.putAll(routingAttributes(request.hasBeenRerouted(), localNodeId));
373370

374371
inferenceStats.requestCount().incrementBy(1, requestCountAttributes);
375372
}
@@ -381,16 +378,11 @@ private void recordRequestDurationMetrics(
381378
String localNodeId,
382379
@Nullable Throwable t
383380
) {
384-
try {
385-
Map<String, Object> metricAttributes = new HashMap<>();
386-
metricAttributes.putAll(modelAttributes(model));
387-
metricAttributes.putAll(routingAttributes(request, localNodeId));
388-
metricAttributes.putAll(responseAttributes(unwrapCause(t)));
389-
390-
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), metricAttributes);
391-
} catch (Exception e) {
392-
log.atDebug().withThrowable(e).log("Failed to record metrics with a parsed model, dropping metrics");
393-
}
381+
Map<String, Object> metricAttributes = new HashMap<>();
382+
metricAttributes.putAll(modelAndResponseAttributes(model, unwrapCause(t)));
383+
metricAttributes.putAll(routingAttributes(request.hasBeenRerouted(), localNodeId));
384+
385+
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), metricAttributes);
394386
}
395387

396388
private void inferOnService(Model model, Request request, InferenceService service, ActionListener<InferenceServiceResults> listener) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.elasticsearch.inference.InferenceServiceResults;
1717
import org.elasticsearch.inference.Model;
1818
import org.elasticsearch.inference.UnparsedModel;
19+
import org.elasticsearch.inference.telemetry.InferenceStats;
1920
import org.elasticsearch.injection.guice.Inject;
2021
import org.elasticsearch.license.XPackLicenseState;
2122
import org.elasticsearch.threadpool.ThreadPool;
@@ -24,7 +25,6 @@
2425
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
2526
import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
2627
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
27-
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
2828

2929
public class TransportInferenceAction extends BaseTransportInferenceAction<InferenceAction.Request> {
3030

0 commit comments

Comments
 (0)