Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/131442.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 131442
summary: Track inference deployments
area: Machine Learning
type: enhancement
issues: []
1 change: 1 addition & 0 deletions server/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -484,4 +484,5 @@
exports org.elasticsearch.index.codec.perfield;
exports org.elasticsearch.index.codec.vectors to org.elasticsearch.test.knn;
exports org.elasticsearch.index.codec.vectors.es818 to org.elasticsearch.test.knn;
exports org.elasticsearch.inference.telemetry;
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.inference.telemetry.InferenceStats;
import org.elasticsearch.threadpool.ThreadPool;

import java.util.List;
Expand All @@ -23,7 +24,13 @@ public interface InferenceServiceExtension {

List<Factory> getInferenceServiceFactories();

record InferenceServiceFactoryContext(Client client, ThreadPool threadPool, ClusterService clusterService, Settings settings) {}
record InferenceServiceFactoryContext(
Client client,
ThreadPool threadPool,
ClusterService clusterService,
Settings settings,
InferenceStats inferenceStats
) {}

interface Factory {
/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.xpack.inference.telemetry;
package org.elasticsearch.inference.telemetry;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.core.Nullable;
Expand All @@ -14,17 +16,17 @@
import org.elasticsearch.telemetry.metric.LongCounter;
import org.elasticsearch.telemetry.metric.LongHistogram;
import org.elasticsearch.telemetry.metric.MeterRegistry;
import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest;

import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

public record InferenceStats(LongCounter requestCount, LongHistogram inferenceDuration) {
public record InferenceStats(LongCounter requestCount, LongHistogram inferenceDuration, LongHistogram deploymentDuration) {

public InferenceStats {
Objects.requireNonNull(requestCount);
Objects.requireNonNull(inferenceDuration);
Objects.requireNonNull(deploymentDuration);
}

public static InferenceStats create(MeterRegistry meterRegistry) {
Expand All @@ -38,6 +40,11 @@ public static InferenceStats create(MeterRegistry meterRegistry) {
"es.inference.requests.time",
"Inference API request counts for a particular service, task type, model ID",
"ms"
),
meterRegistry.registerLongHistogram(
"es.inference.trained_model.deployment.time",
"Inference API time spent waiting for Trained Model Deployments",
"ms"
)
);
}
Expand All @@ -54,8 +61,8 @@ public static Map<String, Object> modelAttributes(Model model) {
return modelAttributesMap;
}

public static Map<String, Object> routingAttributes(BaseInferenceActionRequest request, String nodeIdHandlingRequest) {
return Map.of("rerouted", request.hasBeenRerouted(), "node_id", nodeIdHandlingRequest);
public static Map<String, Object> routingAttributes(boolean hasBeenRerouted, String nodeIdHandlingRequest) {
return Map.of("rerouted", hasBeenRerouted, "node_id", nodeIdHandlingRequest);
}

public static Map<String, Object> modelAttributes(UnparsedModel model) {
Expand All @@ -73,4 +80,11 @@ public static Map<String, Object> responseAttributes(@Nullable Throwable throwab

return Map.of("error.type", throwable.getClass().getSimpleName());
}

public static Map<String, Object> modelAndResponseAttributes(Model model, @Nullable Throwable throwable) {
var metricAttributes = new HashMap<String, Object>();
metricAttributes.putAll(modelAttributes(model));
metricAttributes.putAll(responseAttributes(throwable));
return metricAttributes;
}
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.xpack.inference.telemetry;
package org.elasticsearch.inference.telemetry;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.inference.Model;
Expand All @@ -22,9 +24,9 @@
import java.util.HashMap;
import java.util.Map;

import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.create;
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes;
import static org.elasticsearch.inference.telemetry.InferenceStats.create;
import static org.elasticsearch.inference.telemetry.InferenceStats.modelAttributes;
import static org.elasticsearch.inference.telemetry.InferenceStats.responseAttributes;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.ArgumentMatchers.assertArg;
Expand All @@ -35,9 +37,13 @@

public class InferenceStatsTests extends ESTestCase {

public static InferenceStats mockInferenceStats() {
return new InferenceStats(mock(), mock(), mock());
}

public void testRecordWithModel() {
var longCounter = mock(LongCounter.class);
var stats = new InferenceStats(longCounter, mock());
var stats = new InferenceStats(longCounter, mock(), mock());

stats.requestCount().incrementBy(1, modelAttributes(model("service", TaskType.ANY, "modelId")));

Expand All @@ -49,7 +55,7 @@ public void testRecordWithModel() {

public void testRecordWithoutModel() {
var longCounter = mock(LongCounter.class);
var stats = new InferenceStats(longCounter, mock());
var stats = new InferenceStats(longCounter, mock(), mock());

stats.requestCount().incrementBy(1, modelAttributes(model("service", TaskType.ANY, null)));

Expand All @@ -63,7 +69,7 @@ public void testCreation() {
public void testRecordDurationWithoutError() {
var expectedLong = randomLong();
var histogramCounter = mock(LongHistogram.class);
var stats = new InferenceStats(mock(), histogramCounter);
var stats = new InferenceStats(mock(), histogramCounter, mock());

Map<String, Object> metricAttributes = new HashMap<>();
metricAttributes.putAll(modelAttributes(model("service", TaskType.ANY, "modelId")));
Expand All @@ -88,7 +94,7 @@ public void testRecordDurationWithoutError() {
public void testRecordDurationWithElasticsearchStatusException() {
var expectedLong = randomLong();
var histogramCounter = mock(LongHistogram.class);
var stats = new InferenceStats(mock(), histogramCounter);
var stats = new InferenceStats(mock(), histogramCounter, mock());
var statusCode = RestStatus.BAD_REQUEST;
var exception = new ElasticsearchStatusException("hello", statusCode);
var expectedError = String.valueOf(statusCode.getStatus());
Expand Down Expand Up @@ -116,7 +122,7 @@ public void testRecordDurationWithElasticsearchStatusException() {
public void testRecordDurationWithOtherException() {
var expectedLong = randomLong();
var histogramCounter = mock(LongHistogram.class);
var stats = new InferenceStats(mock(), histogramCounter);
var stats = new InferenceStats(mock(), histogramCounter, mock());
var exception = new IllegalStateException("ahh");
var expectedError = exception.getClass().getSimpleName();

Expand All @@ -138,7 +144,7 @@ public void testRecordDurationWithOtherException() {
public void testRecordDurationWithUnparsedModelAndElasticsearchStatusException() {
var expectedLong = randomLong();
var histogramCounter = mock(LongHistogram.class);
var stats = new InferenceStats(mock(), histogramCounter);
var stats = new InferenceStats(mock(), histogramCounter, mock());
var statusCode = RestStatus.BAD_REQUEST;
var exception = new ElasticsearchStatusException("hello", statusCode);
var expectedError = String.valueOf(statusCode.getStatus());
Expand All @@ -163,7 +169,7 @@ public void testRecordDurationWithUnparsedModelAndElasticsearchStatusException()
public void testRecordDurationWithUnparsedModelAndOtherException() {
var expectedLong = randomLong();
var histogramCounter = mock(LongHistogram.class);
var stats = new InferenceStats(mock(), histogramCounter);
var stats = new InferenceStats(mock(), histogramCounter, mock());
var exception = new IllegalStateException("ahh");
var expectedError = exception.getClass().getSimpleName();

Expand All @@ -187,7 +193,7 @@ public void testRecordDurationWithUnparsedModelAndOtherException() {
public void testRecordDurationWithUnknownModelAndElasticsearchStatusException() {
var expectedLong = randomLong();
var histogramCounter = mock(LongHistogram.class);
var stats = new InferenceStats(mock(), histogramCounter);
var stats = new InferenceStats(mock(), histogramCounter, mock());
var statusCode = RestStatus.BAD_REQUEST;
var exception = new ElasticsearchStatusException("hello", statusCode);
var expectedError = String.valueOf(statusCode.getStatus());
Expand All @@ -206,7 +212,7 @@ public void testRecordDurationWithUnknownModelAndElasticsearchStatusException()
public void testRecordDurationWithUnknownModelAndOtherException() {
var expectedLong = randomLong();
var histogramCounter = mock(LongHistogram.class);
var stats = new InferenceStats(mock(), histogramCounter);
var stats = new InferenceStats(mock(), histogramCounter, mock());
var exception = new IllegalStateException("ahh");
var expectedError = exception.getClass().getSimpleName();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.inference.telemetry.InferenceStatsTests;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.reindex.ReindexPlugin;
import org.elasticsearch.test.ESSingleNodeTestCase;
Expand Down Expand Up @@ -129,7 +130,8 @@ public void testGetModel() throws Exception {
mock(Client.class),
mock(ThreadPool.class),
mock(ClusterService.class),
Settings.EMPTY
Settings.EMPTY,
InferenceStatsTests.mockInferenceStats()
)
);
ElasticsearchInternalModel roundTripModel = (ElasticsearchInternalModel) elserService.parsePersistedConfigWithSecrets(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.elasticsearch.indices.SystemIndexDescriptor;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.inference.telemetry.InferenceStats;
import org.elasticsearch.license.License;
import org.elasticsearch.license.LicensedFeature;
import org.elasticsearch.license.XPackLicenseState;
Expand Down Expand Up @@ -140,7 +141,6 @@
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModelBuilder;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIService;
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;

import java.util.ArrayList;
import java.util.Collection;
Expand Down Expand Up @@ -328,11 +328,16 @@ public Collection<?> createComponents(PluginServices services) {
)
);

var meterRegistry = services.telemetryProvider().getMeterRegistry();
var inferenceStats = InferenceStats.create(meterRegistry);
var inferenceStatsBinding = new PluginComponentBinding<>(InferenceStats.class, inferenceStats);

var factoryContext = new InferenceServiceExtension.InferenceServiceFactoryContext(
services.client(),
services.threadPool(),
services.clusterService(),
settings
settings,
inferenceStats
);

// This must be done after the HttpRequestSenderFactory is created so that the services can get the
Expand All @@ -344,10 +349,6 @@ public Collection<?> createComponents(PluginServices services) {
}
inferenceServiceRegistry.set(serviceRegistry);

var meterRegistry = services.telemetryProvider().getMeterRegistry();
var inferenceStats = InferenceStats.create(meterRegistry);
var inferenceStatsBinding = new PluginComponentBinding<>(InferenceStats.class, inferenceStats);

var actionFilter = new ShardBulkInferenceActionFilter(
services.clusterService(),
serviceRegistry,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.inference.telemetry.InferenceStats;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestStatus;
Expand All @@ -42,7 +43,6 @@
import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator;
import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
import org.elasticsearch.xpack.inference.telemetry.InferenceTimer;

import java.io.IOException;
Expand All @@ -57,10 +57,11 @@

import static org.elasticsearch.ExceptionsHelper.unwrapCause;
import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.inference.telemetry.InferenceStats.modelAndResponseAttributes;
import static org.elasticsearch.inference.telemetry.InferenceStats.modelAttributes;
import static org.elasticsearch.inference.telemetry.InferenceStats.responseAttributes;
import static org.elasticsearch.inference.telemetry.InferenceStats.routingAttributes;
import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes;
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.routingAttributes;

/**
* Base class for transport actions that handle inference requests.
Expand Down Expand Up @@ -274,15 +275,11 @@ public InferenceAction.Response read(StreamInput in) throws IOException {
}

private void recordRequestDurationMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) {
try {
Map<String, Object> metricAttributes = new HashMap<>();
metricAttributes.putAll(modelAttributes(model));
metricAttributes.putAll(responseAttributes(unwrapCause(t)));

inferenceStats.inferenceDuration().record(timer.elapsedMillis(), metricAttributes);
} catch (Exception e) {
log.atDebug().withThrowable(e).log("Failed to record metrics with an unparsed model, dropping metrics");
}
Map<String, Object> metricAttributes = new HashMap<>();
metricAttributes.putAll(modelAttributes(model));
metricAttributes.putAll(responseAttributes(unwrapCause(t)));

inferenceStats.inferenceDuration().record(timer.elapsedMillis(), metricAttributes);
}

private void inferOnServiceWithMetrics(
Expand Down Expand Up @@ -369,7 +366,7 @@ protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Publisher<T> upstream) {
private void recordRequestCountMetrics(Model model, Request request, String localNodeId) {
Map<String, Object> requestCountAttributes = new HashMap<>();
requestCountAttributes.putAll(modelAttributes(model));
requestCountAttributes.putAll(routingAttributes(request, localNodeId));
requestCountAttributes.putAll(routingAttributes(request.hasBeenRerouted(), localNodeId));

inferenceStats.requestCount().incrementBy(1, requestCountAttributes);
}
Expand All @@ -381,16 +378,11 @@ private void recordRequestDurationMetrics(
String localNodeId,
@Nullable Throwable t
) {
try {
Map<String, Object> metricAttributes = new HashMap<>();
metricAttributes.putAll(modelAttributes(model));
metricAttributes.putAll(routingAttributes(request, localNodeId));
metricAttributes.putAll(responseAttributes(unwrapCause(t)));

inferenceStats.inferenceDuration().record(timer.elapsedMillis(), metricAttributes);
} catch (Exception e) {
log.atDebug().withThrowable(e).log("Failed to record metrics with a parsed model, dropping metrics");
}
Map<String, Object> metricAttributes = new HashMap<>();
metricAttributes.putAll(modelAndResponseAttributes(model, unwrapCause(t)));
metricAttributes.putAll(routingAttributes(request.hasBeenRerouted(), localNodeId));

inferenceStats.inferenceDuration().record(timer.elapsedMillis(), metricAttributes);
}

private void inferOnService(Model model, Request request, InferenceService service, ActionListener<InferenceServiceResults> listener) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.inference.telemetry.InferenceStats;
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.threadpool.ThreadPool;
Expand All @@ -24,7 +25,6 @@
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;

public class TransportInferenceAction extends BaseTransportInferenceAction<InferenceAction.Request> {

Expand Down
Loading
Loading