diff --git a/docs/changelog/131442.yaml b/docs/changelog/131442.yaml new file mode 100644 index 0000000000000..23d00cd7d028d --- /dev/null +++ b/docs/changelog/131442.yaml @@ -0,0 +1,5 @@ +pr: 131442 +summary: Track inference deployments +area: Machine Learning +type: enhancement +issues: [] diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index a6f6f622070f1..90cd3c669a52c 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -484,4 +484,5 @@ exports org.elasticsearch.index.codec.perfield; exports org.elasticsearch.index.codec.vectors to org.elasticsearch.test.knn; exports org.elasticsearch.index.codec.vectors.es818 to org.elasticsearch.test.knn; + exports org.elasticsearch.inference.telemetry; } diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceExtension.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceExtension.java index 3274bf571d10a..a6857b82a747f 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceExtension.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceExtension.java @@ -12,6 +12,7 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.threadpool.ThreadPool; import java.util.List; @@ -23,7 +24,13 @@ public interface InferenceServiceExtension { List getInferenceServiceFactories(); - record InferenceServiceFactoryContext(Client client, ThreadPool threadPool, ClusterService clusterService, Settings settings) {} + record InferenceServiceFactoryContext( + Client client, + ThreadPool threadPool, + ClusterService clusterService, + Settings settings, + InferenceStats inferenceStats + ) {} interface Factory { /** diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceStats.java b/server/src/main/java/org/elasticsearch/inference/telemetry/InferenceStats.java similarity index 65% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceStats.java rename to server/src/main/java/org/elasticsearch/inference/telemetry/InferenceStats.java index 17c91b81233fb..e73b1ad9c5ff6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceStats.java +++ b/server/src/main/java/org/elasticsearch/inference/telemetry/InferenceStats.java @@ -1,11 +1,13 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
*/ -package org.elasticsearch.xpack.inference.telemetry; +package org.elasticsearch.inference.telemetry; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.core.Nullable; @@ -14,17 +16,17 @@ import org.elasticsearch.telemetry.metric.LongCounter; import org.elasticsearch.telemetry.metric.LongHistogram; import org.elasticsearch.telemetry.metric.MeterRegistry; -import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest; import java.util.HashMap; import java.util.Map; import java.util.Objects; -public record InferenceStats(LongCounter requestCount, LongHistogram inferenceDuration) { +public record InferenceStats(LongCounter requestCount, LongHistogram inferenceDuration, LongHistogram deploymentDuration) { public InferenceStats { Objects.requireNonNull(requestCount); Objects.requireNonNull(inferenceDuration); + Objects.requireNonNull(deploymentDuration); } public static InferenceStats create(MeterRegistry meterRegistry) { @@ -38,6 +40,11 @@ public static InferenceStats create(MeterRegistry meterRegistry) { "es.inference.requests.time", "Inference API request counts for a particular service, task type, model ID", "ms" + ), + meterRegistry.registerLongHistogram( + "es.inference.trained_model.deployment.time", + "Inference API time spent waiting for Trained Model Deployments", + "ms" ) ); } @@ -54,8 +61,8 @@ public static Map modelAttributes(Model model) { return modelAttributesMap; } - public static Map routingAttributes(BaseInferenceActionRequest request, String nodeIdHandlingRequest) { - return Map.of("rerouted", request.hasBeenRerouted(), "node_id", nodeIdHandlingRequest); + public static Map routingAttributes(boolean hasBeenRerouted, String nodeIdHandlingRequest) { + return Map.of("rerouted", hasBeenRerouted, "node_id", nodeIdHandlingRequest); } public static Map modelAttributes(UnparsedModel model) { @@ -73,4 +80,11 @@ public static Map responseAttributes(@Nullable Throwable throwab return Map.of("error.type", throwable.getClass().getSimpleName()); } + + public static Map modelAndResponseAttributes(Model model, @Nullable Throwable throwable) { + var metricAttributes = new HashMap(); + metricAttributes.putAll(modelAttributes(model)); + metricAttributes.putAll(responseAttributes(throwable)); + return metricAttributes; + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/InferenceStatsTests.java b/server/src/test/java/org/elasticsearch/inference/telemetry/InferenceStatsTests.java similarity index 87% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/InferenceStatsTests.java rename to server/src/test/java/org/elasticsearch/inference/telemetry/InferenceStatsTests.java index f3800f91d9a54..0d71165823e89 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/InferenceStatsTests.java +++ b/server/src/test/java/org/elasticsearch/inference/telemetry/InferenceStatsTests.java @@ -1,11 +1,13 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". */ -package org.elasticsearch.xpack.inference.telemetry; +package org.elasticsearch.inference.telemetry; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.inference.Model; @@ -22,9 +24,9 @@ import java.util.HashMap; import java.util.Map; -import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.create; -import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes; -import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes; +import static org.elasticsearch.inference.telemetry.InferenceStats.create; +import static org.elasticsearch.inference.telemetry.InferenceStats.modelAttributes; +import static org.elasticsearch.inference.telemetry.InferenceStats.responseAttributes; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.nullValue; import static org.mockito.ArgumentMatchers.assertArg; @@ -35,9 +37,13 @@ public class InferenceStatsTests extends ESTestCase { + public static InferenceStats mockInferenceStats() { + return new InferenceStats(mock(), mock(), mock()); + } + public void testRecordWithModel() { var longCounter = mock(LongCounter.class); - var stats = new InferenceStats(longCounter, mock()); + var stats = new InferenceStats(longCounter, mock(), mock()); stats.requestCount().incrementBy(1, modelAttributes(model("service", TaskType.ANY, "modelId"))); @@ -49,7 +55,7 @@ public void testRecordWithModel() { public void testRecordWithoutModel() { var longCounter = mock(LongCounter.class); - var stats = new InferenceStats(longCounter, mock()); + var stats = new InferenceStats(longCounter, mock(), mock()); stats.requestCount().incrementBy(1, modelAttributes(model("service", TaskType.ANY, null))); @@ -63,7 +69,7 @@ public void testCreation() { public void testRecordDurationWithoutError() { var expectedLong = randomLong(); var histogramCounter = mock(LongHistogram.class); - var stats = new InferenceStats(mock(), histogramCounter); + var stats = new InferenceStats(mock(), histogramCounter, mock()); Map metricAttributes = new HashMap<>(); metricAttributes.putAll(modelAttributes(model("service", TaskType.ANY, "modelId"))); @@ -88,7 +94,7 @@ public void testRecordDurationWithoutError() { public void testRecordDurationWithElasticsearchStatusException() { var expectedLong = randomLong(); var histogramCounter = mock(LongHistogram.class); - var stats = new InferenceStats(mock(), histogramCounter); + var stats = new InferenceStats(mock(), histogramCounter, mock()); var statusCode = RestStatus.BAD_REQUEST; var exception = new ElasticsearchStatusException("hello", statusCode); var expectedError = String.valueOf(statusCode.getStatus()); @@ -116,7 +122,7 @@ public void testRecordDurationWithElasticsearchStatusException() { public void testRecordDurationWithOtherException() { var expectedLong = randomLong(); var histogramCounter = mock(LongHistogram.class); - var stats = new InferenceStats(mock(), histogramCounter); + var stats = new InferenceStats(mock(), histogramCounter, mock()); var exception = new IllegalStateException("ahh"); var expectedError = exception.getClass().getSimpleName(); @@ -138,7 +144,7 @@ public void 
testRecordDurationWithOtherException() { public void testRecordDurationWithUnparsedModelAndElasticsearchStatusException() { var expectedLong = randomLong(); var histogramCounter = mock(LongHistogram.class); - var stats = new InferenceStats(mock(), histogramCounter); + var stats = new InferenceStats(mock(), histogramCounter, mock()); var statusCode = RestStatus.BAD_REQUEST; var exception = new ElasticsearchStatusException("hello", statusCode); var expectedError = String.valueOf(statusCode.getStatus()); @@ -163,7 +169,7 @@ public void testRecordDurationWithUnparsedModelAndElasticsearchStatusException() public void testRecordDurationWithUnparsedModelAndOtherException() { var expectedLong = randomLong(); var histogramCounter = mock(LongHistogram.class); - var stats = new InferenceStats(mock(), histogramCounter); + var stats = new InferenceStats(mock(), histogramCounter, mock()); var exception = new IllegalStateException("ahh"); var expectedError = exception.getClass().getSimpleName(); @@ -187,7 +193,7 @@ public void testRecordDurationWithUnparsedModelAndOtherException() { public void testRecordDurationWithUnknownModelAndElasticsearchStatusException() { var expectedLong = randomLong(); var histogramCounter = mock(LongHistogram.class); - var stats = new InferenceStats(mock(), histogramCounter); + var stats = new InferenceStats(mock(), histogramCounter, mock()); var statusCode = RestStatus.BAD_REQUEST; var exception = new ElasticsearchStatusException("hello", statusCode); var expectedError = String.valueOf(statusCode.getStatus()); @@ -206,7 +212,7 @@ public void testRecordDurationWithUnknownModelAndElasticsearchStatusException() public void testRecordDurationWithUnknownModelAndOtherException() { var expectedLong = randomLong(); var histogramCounter = mock(LongHistogram.class); - var stats = new InferenceStats(mock(), histogramCounter); + var stats = new InferenceStats(mock(), histogramCounter, mock()); var exception = new IllegalStateException("ahh"); var expectedError = exception.getClass().getSimpleName(); diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index e56782bd00ef5..22aebee72df0c 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -31,6 +31,7 @@ import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.inference.telemetry.InferenceStatsTests; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.reindex.ReindexPlugin; import org.elasticsearch.test.ESSingleNodeTestCase; @@ -129,7 +130,8 @@ public void testGetModel() throws Exception { mock(Client.class), mock(ThreadPool.class), mock(ClusterService.class), - Settings.EMPTY + Settings.EMPTY, + InferenceStatsTests.mockInferenceStats() ) ); ElasticsearchInternalModel roundTripModel = (ElasticsearchInternalModel) elserService.parsePersistedConfigWithSecrets( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index b729857c91f81..a35d64ab84c7f 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -30,6 +30,7 @@ import org.elasticsearch.indices.SystemIndexDescriptor; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.license.License; import org.elasticsearch.license.LicensedFeature; import org.elasticsearch.license.XPackLicenseState; @@ -140,7 +141,6 @@ import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModelBuilder; import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas; import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIService; -import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import java.util.ArrayList; import java.util.Collection; @@ -328,11 +328,16 @@ public Collection createComponents(PluginServices services) { ) ); + var meterRegistry = services.telemetryProvider().getMeterRegistry(); + var inferenceStats = InferenceStats.create(meterRegistry); + var inferenceStatsBinding = new PluginComponentBinding<>(InferenceStats.class, inferenceStats); + var factoryContext = new InferenceServiceExtension.InferenceServiceFactoryContext( services.client(), services.threadPool(), services.clusterService(), - settings + settings, + inferenceStats ); // This must be done after the HttpRequestSenderFactory is created so that the services can get the @@ -344,10 +349,6 @@ public Collection createComponents(PluginServices services) { } inferenceServiceRegistry.set(serviceRegistry); - var meterRegistry = services.telemetryProvider().getMeterRegistry(); - var inferenceStats = InferenceStats.create(meterRegistry); - var inferenceStatsBinding = new PluginComponentBinding<>(InferenceStats.class, inferenceStats); - var actionFilter = new ShardBulkInferenceActionFilter( services.clusterService(), serviceRegistry, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java index dec6d0d928b97..269e0f27fd461 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java @@ -26,6 +26,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.RestStatus; @@ -42,7 +43,6 @@ import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator; import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import org.elasticsearch.xpack.inference.telemetry.InferenceTimer; import java.io.IOException; @@ -57,10 +57,11 @@ import static org.elasticsearch.ExceptionsHelper.unwrapCause; import static org.elasticsearch.core.Strings.format; +import static 
org.elasticsearch.inference.telemetry.InferenceStats.modelAndResponseAttributes; +import static org.elasticsearch.inference.telemetry.InferenceStats.modelAttributes; +import static org.elasticsearch.inference.telemetry.InferenceStats.responseAttributes; +import static org.elasticsearch.inference.telemetry.InferenceStats.routingAttributes; import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE; -import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes; -import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes; -import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.routingAttributes; /** * Base class for transport actions that handle inference requests. @@ -274,15 +275,11 @@ public InferenceAction.Response read(StreamInput in) throws IOException { } private void recordRequestDurationMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) { - try { - Map metricAttributes = new HashMap<>(); - metricAttributes.putAll(modelAttributes(model)); - metricAttributes.putAll(responseAttributes(unwrapCause(t))); - - inferenceStats.inferenceDuration().record(timer.elapsedMillis(), metricAttributes); - } catch (Exception e) { - log.atDebug().withThrowable(e).log("Failed to record metrics with an unparsed model, dropping metrics"); - } + Map metricAttributes = new HashMap<>(); + metricAttributes.putAll(modelAttributes(model)); + metricAttributes.putAll(responseAttributes(unwrapCause(t))); + + inferenceStats.inferenceDuration().record(timer.elapsedMillis(), metricAttributes); } private void inferOnServiceWithMetrics( @@ -369,7 +366,7 @@ protected Flow.Publisher streamErrorHandler(Flow.Publisher upstream) { private void recordRequestCountMetrics(Model model, Request request, String localNodeId) { Map requestCountAttributes = new HashMap<>(); requestCountAttributes.putAll(modelAttributes(model)); - requestCountAttributes.putAll(routingAttributes(request, localNodeId)); + requestCountAttributes.putAll(routingAttributes(request.hasBeenRerouted(), localNodeId)); inferenceStats.requestCount().incrementBy(1, requestCountAttributes); } @@ -381,16 +378,11 @@ private void recordRequestDurationMetrics( String localNodeId, @Nullable Throwable t ) { - try { - Map metricAttributes = new HashMap<>(); - metricAttributes.putAll(modelAttributes(model)); - metricAttributes.putAll(routingAttributes(request, localNodeId)); - metricAttributes.putAll(responseAttributes(unwrapCause(t))); - - inferenceStats.inferenceDuration().record(timer.elapsedMillis(), metricAttributes); - } catch (Exception e) { - log.atDebug().withThrowable(e).log("Failed to record metrics with a parsed model, dropping metrics"); - } + Map metricAttributes = new HashMap<>(); + metricAttributes.putAll(modelAndResponseAttributes(model, unwrapCause(t))); + metricAttributes.putAll(routingAttributes(request.hasBeenRerouted(), localNodeId)); + + inferenceStats.inferenceDuration().record(timer.elapsedMillis(), metricAttributes); } private void inferOnService(Model model, Request request, InferenceService service, ActionListener listener) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index 7d24b7766baa3..f14d679ba7d26 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -16,6 +16,7 @@ import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.threadpool.ThreadPool; @@ -24,7 +25,6 @@ import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.telemetry.InferenceStats; public class TransportInferenceAction extends BaseTransportInferenceAction { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java index bfa8141d312cf..d0eef677ca1d3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java @@ -17,6 +17,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.RestStatus; @@ -29,7 +30,6 @@ import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import java.util.concurrent.Flow; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index ca55434454422..ecf73ed004194 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -46,6 +46,7 @@ import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.RestStatus; @@ -63,7 +64,6 @@ import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextUtils; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import java.io.IOException; import java.util.ArrayList; @@ -76,11 +76,10 @@ import java.util.Map; import java.util.stream.Collectors; +import static 
org.elasticsearch.inference.telemetry.InferenceStats.modelAndResponseAttributes; import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunksLegacy; -import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes; -import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes; /** * A {@link MappedActionFilter} that intercepts {@link BulkShardRequest} to apply inference on fields specified @@ -459,8 +458,7 @@ public void onFailure(Exception exc) { private void recordRequestCountMetrics(Model model, int incrementBy, Throwable throwable) { Map requestCountAttributes = new HashMap<>(); - requestCountAttributes.putAll(modelAttributes(model)); - requestCountAttributes.putAll(responseAttributes(throwable)); + requestCountAttributes.putAll(modelAndResponseAttributes(model, throwable)); requestCountAttributes.put("inference_source", "semantic_text_bulk"); inferenceStats.requestCount().incrementBy(incrementBy, requestCountAttributes); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java index ee4221157388e..4aaf3c2db2e61 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java @@ -11,7 +11,6 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.ElasticsearchTimeoutException; -import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.SubscribableListener; @@ -23,6 +22,7 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.MachineLearningField; @@ -38,13 +38,16 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil; import org.elasticsearch.xpack.inference.InferencePlugin; +import org.elasticsearch.xpack.inference.telemetry.InferenceTimer; import java.io.IOException; import java.util.List; import java.util.concurrent.ExecutorService; import java.util.function.Consumer; +import static org.elasticsearch.ExceptionsHelper.unwrapCause; import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.inference.telemetry.InferenceStats.modelAndResponseAttributes; import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; @@ -55,6 +58,7 @@ public abstract class BaseElasticsearchInternalService implements InferenceServi protected final ExecutorService inferenceExecutor; protected final Consumer> 
preferredModelVariantFn; private final ClusterService clusterService; + private final InferenceStats inferenceStats; public enum PreferredModelVariant { LINUX_X86_OPTIMIZED, @@ -69,10 +73,11 @@ public BaseElasticsearchInternalService(InferenceServiceExtension.InferenceServi this.inferenceExecutor = context.threadPool().executor(InferencePlugin.UTILITY_THREAD_POOL_NAME); this.preferredModelVariantFn = this::preferredVariantFromPlatformArchitecture; this.clusterService = context.clusterService(); + this.inferenceStats = context.inferenceStats(); } // For testing. - // platformArchFn enables similating different architectures + // platformArchFn enables simulating different architectures // without extensive mocking on the client to simulate the nodes info response. // TODO make package private once the elser service is moved to the Elasticsearch // service package. @@ -85,6 +90,7 @@ public BaseElasticsearchInternalService( this.inferenceExecutor = context.threadPool().executor(InferencePlugin.UTILITY_THREAD_POOL_NAME); this.preferredModelVariantFn = preferredModelVariantFn; this.clusterService = context.clusterService(); + this.inferenceStats = context.inferenceStats(); } @Override @@ -103,6 +109,7 @@ public void start(Model model, TimeValue timeout, ActionListener finalL return; } + var timer = InferenceTimer.start(); // instead of a subscribably listener, use some wait to wait for the first one. var subscribableListener = SubscribableListener.newForked( forkedListener -> { isBuiltinModelPut(model, forkedListener); } @@ -118,21 +125,25 @@ public void start(Model model, TimeValue timeout, ActionListener finalL client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, responseListener); }); subscribableListener.addTimeout(timeout, threadPool, inferenceExecutor); - subscribableListener.addListener(finalListener.delegateResponse((l, e) -> { + subscribableListener.addListener(ActionListener.wrap(started -> { + inferenceStats.deploymentDuration().record(timer.elapsedMillis(), modelAndResponseAttributes(model, null)); + finalListener.onResponse(started); + }, e -> { if (e instanceof ElasticsearchTimeoutException) { - l.onFailure( - new ModelDeploymentTimeoutException( - format( - "Timed out after [%s] waiting for trained model deployment for inference endpoint [%s] to start. " - + "The inference endpoint can not be used to perform inference until the deployment has started. " - + "Use the trained model stats API to track the state of the deployment.", - timeout, - model.getInferenceEntityId() - ) + var timeoutException = new ModelDeploymentTimeoutException( + format( + "Timed out after [%s] waiting for trained model deployment for inference endpoint [%s] to start. " + + "The inference endpoint can not be used to perform inference until the deployment has started. 
" + + "Use the trained model stats API to track the state of the deployment.", + timeout, + model.getInferenceEntityId() ) ); + inferenceStats.deploymentDuration().record(timer.elapsedMillis(), modelAndResponseAttributes(model, timeoutException)); + finalListener.onFailure(timeoutException); } else { - l.onFailure(e); + inferenceStats.deploymentDuration().record(timer.elapsedMillis(), modelAndResponseAttributes(model, unwrapCause(e))); + finalListener.onFailure(e); } })); @@ -323,7 +334,7 @@ protected void maybeStartDeployment( InferModelAction.Request request, ActionListener listener ) { - if (isDefaultId(model.getInferenceEntityId()) && ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { + if (isDefaultId(model.getInferenceEntityId()) && unwrapCause(e) instanceof ResourceNotFoundException) { this.start(model, request.getInferenceTimeout(), listener.delegateFailureAndWrap((l, started) -> { client.execute(InferModelAction.INSTANCE, request, listener); })); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java index 70499c7987965..812cd1e3c6d7f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java @@ -21,6 +21,8 @@ import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.inference.telemetry.InferenceStats; +import org.elasticsearch.inference.telemetry.InferenceStatsTests; import org.elasticsearch.license.MockLicenseState; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; @@ -33,7 +35,6 @@ import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -88,7 +89,7 @@ public void setUp() throws Exception { licenseState = mock(); modelRegistry = mock(); serviceRegistry = mock(); - inferenceStats = new InferenceStats(mock(), mock()); + inferenceStats = InferenceStatsTests.mockInferenceStats(); streamingTaskManager = mock(); action = createAction( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java index 4d986cf0a837f..547078d93acc4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.license.MockLicenseState; import org.elasticsearch.threadpool.ThreadPool; import 
org.elasticsearch.transport.TransportException; @@ -22,7 +23,6 @@ import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; import org.elasticsearch.xpack.inference.common.RateLimitAssignment; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import java.util.List; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java index f26d0675487a5..9e6f4a6260936 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.license.MockLicenseState; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.threadpool.ThreadPool; @@ -20,7 +21,6 @@ import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import java.util.Optional; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 08334de00543d..e96fda569aa12 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -49,6 +49,8 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.inference.telemetry.InferenceStats; +import org.elasticsearch.inference.telemetry.InferenceStatsTests; import org.elasticsearch.license.MockLicenseState; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; @@ -66,7 +68,6 @@ import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.model.TestModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import org.junit.After; import org.junit.Before; import org.mockito.stubbing.Answer; @@ -148,7 +149,7 @@ public void tearDownThreadPool() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testFilterNoop() throws Exception { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); ShardBulkInferenceActionFilter filter = createFilter( threadPool, Map.of(), @@ -181,7 +182,7 @@ public void testFilterNoop() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void 
testLicenseInvalidForInference() throws InterruptedException { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); StaticModel model = StaticModel.createRandomInstance(); ShardBulkInferenceActionFilter filter = createFilter( threadPool, @@ -227,7 +228,7 @@ public void testLicenseInvalidForInference() throws InterruptedException { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testInferenceNotFound() throws Exception { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); StaticModel model = StaticModel.createRandomInstance(); ShardBulkInferenceActionFilter filter = createFilter( threadPool, @@ -275,7 +276,7 @@ public void testInferenceNotFound() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testItemFailures() throws Exception { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); StaticModel model = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); ShardBulkInferenceActionFilter filter = createFilter( threadPool, @@ -364,7 +365,7 @@ public void testItemFailures() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testExplicitNull() throws Exception { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); StaticModel model = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom"))); model.putResult("I am a success", randomChunkedInferenceEmbedding(model, List.of("I am a success"))); @@ -440,7 +441,7 @@ public void testExplicitNull() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testHandleEmptyInput() throws Exception { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); StaticModel model = StaticModel.createRandomInstance(); ShardBulkInferenceActionFilter filter = createFilter( threadPool, @@ -495,7 +496,7 @@ public void testHandleEmptyInput() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testManyRandomDocs() throws Exception { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); Map inferenceModelMap = new HashMap<>(); int numModels = randomIntBetween(1, 3); for (int i = 0; i < numModels; i++) { @@ -559,7 +560,7 @@ public void testManyRandomDocs() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testIndexingPressure() throws Exception { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(Settings.EMPTY); final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); final StaticModel denseModel = StaticModel.createRandomInstance(TaskType.TEXT_EMBEDDING); @@ -677,7 +678,7 @@ public void testIndexingPressure() throws Exception { @SuppressWarnings("unchecked") public void 
testIndexingPressureTripsOnInferenceRequestGeneration() throws Exception { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure( Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), "1b").build() ); @@ -765,7 +766,7 @@ public void testIndexingPressureTripsOnInferenceResponseHandling() throws Except Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), (length(doc1Source) + 1) + "b").build() ); - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); sparseModel.putResult("bar", randomChunkedInferenceEmbedding(sparseModel, List.of("bar"))); @@ -881,7 +882,7 @@ public void testIndexingPressurePartialFailure() throws Exception { .build() ); - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); final ShardBulkInferenceActionFilter filter = createFilter( threadPool, Map.of(sparseModel.getInferenceEntityId(), sparseModel), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index d2c22cdcf6f57..6c01145701d92 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -11,6 +11,7 @@ import org.apache.logging.log4j.Level; import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.ElasticsearchTimeoutException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.LatchedActionListener; import org.elasticsearch.action.support.ActionTestUtils; @@ -37,6 +38,8 @@ import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.telemetry.InferenceStats; +import org.elasticsearch.inference.telemetry.InferenceStatsTests; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; @@ -57,10 +60,12 @@ import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.inference.ModelDeploymentTimeoutException; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats; +import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentTests; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; import 
org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResultsTests; @@ -98,6 +103,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; @@ -113,6 +119,7 @@ import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.assertArg; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.same; @@ -124,12 +131,16 @@ public class ElasticsearchInternalServiceTests extends ESTestCase { - String randomInferenceEntityId = randomAlphaOfLength(10); + private String randomInferenceEntityId; + private InferenceStats inferenceStats; private static ThreadPool threadPool; @Before - public void setUpThreadPool() { + public void setUp() throws Exception { + super.setUp(); + randomInferenceEntityId = randomAlphaOfLength(10); + inferenceStats = InferenceStatsTests.mockInferenceStats(); threadPool = createThreadPool(InferencePlugin.inferenceUtilityExecutor(Settings.EMPTY)); } @@ -1813,7 +1824,8 @@ public void testUpdateWithoutMlEnabled() throws IOException, InterruptedExceptio mock(), threadPool, cs, - Settings.builder().put("xpack.ml.enabled", false).build() + Settings.builder().put("xpack.ml.enabled", false).build(), + inferenceStats ); try (var service = new ElasticsearchInternalService(context)) { var models = List.of(mock(Model.class)); @@ -1855,7 +1867,8 @@ public void testUpdateWithMlEnabled() throws IOException, InterruptedException { client, threadPool, cs, - Settings.builder().put("xpack.ml.enabled", true).build() + Settings.builder().put("xpack.ml.enabled", true).build(), + inferenceStats ); try (var service = new ElasticsearchInternalService(context)) { List models = List.of(model); @@ -1869,7 +1882,82 @@ public void testUpdateWithMlEnabled() throws IOException, InterruptedException { } public void testStart_OnFailure_WhenTimeoutOccurs() throws IOException { - var model = new ElserInternalModel( + var model = mockModel(); + + var client = mockClientForStart( + listener -> listener.onFailure(new ElasticsearchStatusException("failed", RestStatus.GATEWAY_TIMEOUT)) + ); + + try (var service = createService(client)) { + var actionListener = new PlainActionFuture(); + service.start(model, TimeValue.timeValueSeconds(30), actionListener); + var exception = expectThrows( + ElasticsearchStatusException.class, + () -> actionListener.actionGet(TimeValue.timeValueSeconds(30)) + ); + + assertThat(exception.getMessage(), is("failed")); + verify(inferenceStats.deploymentDuration()).record(anyLong(), assertArg(attributes -> { + assertNotNull(attributes); + assertThat(attributes.get("error.type"), is("504")); + })); + } + } + + public void testStart_OnFailure_WhenDeploymentTimeoutOccurs() throws IOException { + var model = mockModel(); + + var client = mockClientForStart( + listener -> listener.onFailure(new ElasticsearchTimeoutException("failed", RestStatus.GATEWAY_TIMEOUT)) + ); + + try (var service = createService(client)) { + var actionListener = new PlainActionFuture(); + service.start(model, 
TimeValue.timeValueSeconds(30), actionListener); + var exception = expectThrows( + ModelDeploymentTimeoutException.class, + () -> actionListener.actionGet(TimeValue.timeValueSeconds(30)) + ); + + assertThat( + exception.getMessage(), + is( + "Timed out after [30s] waiting for trained model deployment for inference endpoint [inference_id] to start. " + + "The inference endpoint can not be used to perform inference until the deployment has started. " + + "Use the trained model stats API to track the state of the deployment." + ) + ); + verify(inferenceStats.deploymentDuration()).record(anyLong(), assertArg(attributes -> { + assertNotNull(attributes); + assertThat(attributes.get("error.type"), is("408")); + })); + } + } + + public void testStart() throws IOException { + var model = mockModel(); + + var client = mockClientForStart(listener -> { + var response = mock(CreateTrainedModelAssignmentAction.Response.class); + when(response.getTrainedModelAssignment()).thenReturn(TrainedModelAssignmentTests.randomInstance()); + listener.onResponse(response); + }); + + try (var service = createService(client)) { + var actionListener = new PlainActionFuture(); + service.start(model, TimeValue.timeValueSeconds(30), actionListener); + assertTrue(actionListener.actionGet(TimeValue.timeValueSeconds(30))); + + verify(inferenceStats.deploymentDuration()).record(anyLong(), assertArg(attributes -> { + assertNotNull(attributes); + assertNull(attributes.get("error.type")); + assertThat(attributes.get("status_code"), is(200)); + })); + } + } + + private ElserInternalModel mockModel() { + return new ElserInternalModel( "inference_id", TaskType.SPARSE_EMBEDDING, "elasticsearch", @@ -1879,7 +1967,9 @@ public void testStart_OnFailure_WhenTimeoutOccurs() throws IOException { new ElserMlNodeTaskSettings(), null ); + } + private Client mockClientForStart(Consumer> startModelListener) { var client = mock(Client.class); when(client.threadPool()).thenReturn(threadPool); @@ -1895,27 +1985,18 @@ public void testStart_OnFailure_WhenTimeoutOccurs() throws IOException { doAnswer(invocationOnMock -> { ActionListener listener = invocationOnMock.getArgument(2); - listener.onFailure(new ElasticsearchStatusException("failed", RestStatus.GATEWAY_TIMEOUT)); + startModelListener.accept(listener); return Void.TYPE; }).when(client).execute(eq(StartTrainedModelDeploymentAction.INSTANCE), any(), any()); - try (var service = createService(client)) { - var actionListener = new PlainActionFuture(); - service.start(model, TimeValue.timeValueSeconds(30), actionListener); - var exception = expectThrows( - ElasticsearchStatusException.class, - () -> actionListener.actionGet(TimeValue.timeValueSeconds(30)) - ); - - assertThat(exception.getMessage(), is("failed")); - } + return client; } private ElasticsearchInternalService createService(Client client) { var cs = mock(ClusterService.class); var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES)); when(cs.getClusterSettings()).thenReturn(cSettings); - var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client, threadPool, cs, Settings.EMPTY); + var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client, threadPool, cs, Settings.EMPTY, inferenceStats); return new ElasticsearchInternalService(context); } @@ -1924,7 +2005,8 @@ private ElasticsearchInternalService createService(Client client, BaseElasticsea client, threadPool, mock(ClusterService.class), - Settings.EMPTY + Settings.EMPTY, + inferenceStats ); return 
new ElasticsearchInternalService(context, l -> l.onResponse(modelVariant)); }
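For context, a minimal sketch of how the InferenceStats instance now carried by InferenceServiceExtension.InferenceServiceFactoryContext can be consumed outside this plugin, following the same pattern BaseElasticsearchInternalService.start() uses to record the new es.inference.trained_model.deployment.time histogram. MyInferenceService, recordDeploymentWait, and failureOrNull are hypothetical, illustrative names; only the InferenceStats/context API shown in the diff is assumed.

    import org.elasticsearch.inference.InferenceServiceExtension;
    import org.elasticsearch.inference.Model;
    import org.elasticsearch.inference.telemetry.InferenceStats;

    import static org.elasticsearch.inference.telemetry.InferenceStats.modelAndResponseAttributes;

    public class MyInferenceService {
        private final InferenceStats inferenceStats;

        public MyInferenceService(InferenceServiceExtension.InferenceServiceFactoryContext context) {
            // The factory context record now exposes the shared stats object via its generated accessor.
            this.inferenceStats = context.inferenceStats();
        }

        void recordDeploymentWait(Model model, long elapsedMillis, Throwable failureOrNull) {
            // modelAndResponseAttributes() merges the model/service/task attributes with the
            // response attributes (an "error.type" entry when failureOrNull is non-null).
            var attributes = modelAndResponseAttributes(model, failureOrNull);

            // Record how long the caller waited for the trained model deployment, in milliseconds,
            // matching the "ms" unit registered for the histogram in InferenceStats.create().
            inferenceStats.deploymentDuration().record(elapsedMillis, attributes);
        }
    }

The elapsed time would typically come from InferenceTimer.start() / timer.elapsedMillis(), as in BaseElasticsearchInternalService.start(); it is passed in here only to keep the sketch free of x-pack-internal imports.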