diff --git a/docs/changelog/129140.yaml b/docs/changelog/129140.yaml new file mode 100644 index 0000000000000..e7ee59122c34f --- /dev/null +++ b/docs/changelog/129140.yaml @@ -0,0 +1,5 @@ +pr: 129140 +summary: Increment inference stats counter for shard bulk inference calls +area: Machine Learning +type: enhancement +issues: [] diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 915a4d3f7af9b..2709d9de19c5c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -344,22 +344,24 @@ public Collection createComponents(PluginServices services) { } inferenceServiceRegistry.set(serviceRegistry); + var meterRegistry = services.telemetryProvider().getMeterRegistry(); + var inferenceStats = InferenceStats.create(meterRegistry); + var inferenceStatsBinding = new PluginComponentBinding<>(InferenceStats.class, inferenceStats); + var actionFilter = new ShardBulkInferenceActionFilter( services.clusterService(), serviceRegistry, modelRegistry.get(), getLicenseState(), - services.indexingPressure() + services.indexingPressure(), + inferenceStats ); shardBulkInferenceActionFilter.set(actionFilter); - var meterRegistry = services.telemetryProvider().getMeterRegistry(); - var inferenceStats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry)); - components.add(serviceRegistry); components.add(modelRegistry.get()); components.add(httpClientManager); - components.add(inferenceStats); + components.add(inferenceStatsBinding); // Only add InferenceServiceNodeLocalRateLimitCalculator (which is a ClusterStateListener) for cluster aware rate limiting, // if the rate limiting feature flags are enabled, otherwise provide noop implementation diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index a4ab8663e8664..082ece347208a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -63,6 +63,7 @@ import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextUtils; import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import java.io.IOException; import java.util.ArrayList; @@ -78,6 +79,8 @@ import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunksLegacy; +import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes; +import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes; /** * A {@link MappedActionFilter} that intercepts {@link BulkShardRequest} to apply inference on fields specified @@ -112,6 +115,7 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter { private final ModelRegistry modelRegistry; private final XPackLicenseState licenseState; private final IndexingPressure indexingPressure; + private final InferenceStats inferenceStats; private volatile long batchSizeInBytes; public ShardBulkInferenceActionFilter( @@ -119,13 +123,15 @@ public ShardBulkInferenceActionFilter( InferenceServiceRegistry inferenceServiceRegistry, ModelRegistry 
modelRegistry, XPackLicenseState licenseState, - IndexingPressure indexingPressure + IndexingPressure indexingPressure, + InferenceStats inferenceStats ) { this.clusterService = clusterService; this.inferenceServiceRegistry = inferenceServiceRegistry; this.modelRegistry = modelRegistry; this.licenseState = licenseState; this.indexingPressure = indexingPressure; + this.inferenceStats = inferenceStats; this.batchSizeInBytes = INDICES_INFERENCE_BATCH_SIZE.get(clusterService.getSettings()).getBytes(); clusterService.getClusterSettings().addSettingsUpdateConsumer(INDICES_INFERENCE_BATCH_SIZE, this::setBatchSize); } @@ -386,10 +392,12 @@ public void onFailure(Exception exc) { public void onResponse(List results) { try (onFinish) { var requestsIterator = requests.iterator(); + int success = 0; for (ChunkedInference result : results) { var request = requestsIterator.next(); var acc = inferenceResults.get(request.bulkItemIndex); if (result instanceof ChunkedInferenceError error) { + recordRequestCountMetrics(inferenceProvider.model, 1, error.exception()); acc.addFailure( new InferenceException( "Exception when running inference id [{}] on field [{}]", @@ -399,6 +407,7 @@ public void onResponse(List results) { ) ); } else { + success++; acc.addOrUpdateResponse( new FieldInferenceResponse( request.field(), @@ -412,12 +421,16 @@ public void onResponse(List results) { ); } } + if (success > 0) { + recordRequestCountMetrics(inferenceProvider.model, success, null); + } } } @Override public void onFailure(Exception exc) { try (onFinish) { + recordRequestCountMetrics(inferenceProvider.model, requests.size(), exc); for (FieldInferenceRequest request : requests) { addInferenceResponseFailure( request.bulkItemIndex, @@ -444,6 +457,14 @@ public void onFailure(Exception exc) { ); } + private void recordRequestCountMetrics(Model model, int incrementBy, Throwable throwable) { + Map requestCountAttributes = new HashMap<>(); + requestCountAttributes.putAll(modelAttributes(model)); + 
requestCountAttributes.putAll(responseAttributes(throwable)); + requestCountAttributes.put("inference_source", "semantic_text_bulk"); + inferenceStats.requestCount().incrementBy(incrementBy, requestCountAttributes); + } + /** * Adds all inference requests associated with their respective inference IDs to the given {@code requestsMap} * for the specified {@code item}. diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index f592774b7a356..a7cb0234aee59 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -66,6 +66,7 @@ import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.model.TestModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import org.junit.After; import org.junit.Before; import org.mockito.stubbing.Answer; @@ -80,6 +81,7 @@ import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import static org.elasticsearch.index.IndexingPressure.MAX_COORDINATING_BYTES; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; @@ -103,9 +105,11 @@ import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.assertArg; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.longThat; import 
static org.mockito.Mockito.any; +import static org.mockito.Mockito.atMost; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -127,7 +131,9 @@ public ShardBulkInferenceActionFilterTests(boolean useLegacyFormat) { @ParametersFactory public static Iterable parameters() throws Exception { - return List.of(new Object[] { true }, new Object[] { false }); + List<Object[]> lst = new ArrayList<>(); + lst.add(new Object[] { true }); lst.add(new Object[] { false }); + return lst; } @Before @@ -142,7 +148,15 @@ public void tearDownThreadPool() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testFilterNoop() throws Exception { - ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), NOOP_INDEXING_PRESSURE, useLegacyFormat, true); + final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + ShardBulkInferenceActionFilter filter = createFilter( + threadPool, + Map.of(), + NOOP_INDEXING_PRESSURE, + useLegacyFormat, + true, + inferenceStats + ); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { @@ -167,8 +181,16 @@ public void testFilterNoop() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testLicenseInvalidForInference() throws InterruptedException { + final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); StaticModel model = StaticModel.createRandomInstance(); - ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), NOOP_INDEXING_PRESSURE, useLegacyFormat, false); + ShardBulkInferenceActionFilter filter = createFilter( + threadPool, + Map.of(), + NOOP_INDEXING_PRESSURE, + useLegacyFormat, + false, + inferenceStats + ); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { @@ -205,13 +227,15 @@ public void testLicenseInvalidForInference() throws
InterruptedException { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testInferenceNotFound() throws Exception { + final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); StaticModel model = StaticModel.createRandomInstance(); ShardBulkInferenceActionFilter filter = createFilter( threadPool, Map.of(model.getInferenceEntityId(), model), NOOP_INDEXING_PRESSURE, useLegacyFormat, - true + true, + inferenceStats ); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { @@ -251,14 +275,15 @@ public void testInferenceNotFound() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testItemFailures() throws Exception { + final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); StaticModel model = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); - ShardBulkInferenceActionFilter filter = createFilter( threadPool, Map.of(model.getInferenceEntityId(), model), NOOP_INDEXING_PRESSURE, useLegacyFormat, - true + true, + inferenceStats ); model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom"))); model.putResult("I am a success", randomChunkedInferenceEmbedding(model, List.of("I am a success"))); @@ -316,10 +341,30 @@ public void testItemFailures() throws Exception { request.setInferenceFieldMap(inferenceFieldMap); filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + + AtomicInteger success = new AtomicInteger(0); + AtomicInteger failed = new AtomicInteger(0); + verify(inferenceStats.requestCount(), atMost(3)).incrementBy(anyLong(), assertArg(attributes -> { + var statusCode = attributes.get("status_code"); + if (statusCode == null) { + failed.incrementAndGet(); + assertThat(attributes.get("error.type"), is("IllegalArgumentException")); + } else { + success.incrementAndGet(); + 
assertThat(statusCode, is(200)); + } + assertThat(attributes.get("task_type"), is(model.getTaskType().toString())); + assertThat(attributes.get("model_id"), is(model.getServiceSettings().modelId())); + assertThat(attributes.get("service"), is(model.getConfigurations().getService())); + assertThat(attributes.get("inference_source"), is("semantic_text_bulk")); + })); + assertThat(success.get(), equalTo(1)); + assertThat(failed.get(), equalTo(2)); } @SuppressWarnings({ "unchecked", "rawtypes" }) public void testExplicitNull() throws Exception { + final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); StaticModel model = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom"))); model.putResult("I am a success", randomChunkedInferenceEmbedding(model, List.of("I am a success"))); @@ -329,7 +374,8 @@ public void testExplicitNull() throws Exception { Map.of(model.getInferenceEntityId(), model), NOOP_INDEXING_PRESSURE, useLegacyFormat, - true + true, + inferenceStats ); CountDownLatch chainExecuted = new CountDownLatch(1); @@ -394,13 +440,15 @@ public void testExplicitNull() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testHandleEmptyInput() throws Exception { + final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); StaticModel model = StaticModel.createRandomInstance(); ShardBulkInferenceActionFilter filter = createFilter( threadPool, Map.of(model.getInferenceEntityId(), model), NOOP_INDEXING_PRESSURE, useLegacyFormat, - true + true, + inferenceStats ); CountDownLatch chainExecuted = new CountDownLatch(1); @@ -447,6 +495,7 @@ public void testHandleEmptyInput() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testManyRandomDocs() throws Exception { + final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); Map inferenceModelMap = new HashMap<>(); int 
numModels = randomIntBetween(1, 3); for (int i = 0; i < numModels; i++) { @@ -471,7 +520,14 @@ public void testManyRandomDocs() throws Exception { modifiedRequests[id] = res[1]; } - ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap, NOOP_INDEXING_PRESSURE, useLegacyFormat, true); + ShardBulkInferenceActionFilter filter = createFilter( + threadPool, + inferenceModelMap, + NOOP_INDEXING_PRESSURE, + useLegacyFormat, + true, + inferenceStats + ); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { @@ -503,6 +559,7 @@ public void testManyRandomDocs() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testIndexingPressure() throws Exception { + final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(Settings.EMPTY); final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); final StaticModel denseModel = StaticModel.createRandomInstance(TaskType.TEXT_EMBEDDING); @@ -511,7 +568,8 @@ public void testIndexingPressure() throws Exception { Map.of(sparseModel.getInferenceEntityId(), sparseModel, denseModel.getInferenceEntityId(), denseModel), indexingPressure, useLegacyFormat, - true + true, + inferenceStats ); XContentBuilder doc0Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "a test value"); @@ -619,6 +677,7 @@ public void testIndexingPressure() throws Exception { @SuppressWarnings("unchecked") public void testIndexingPressureTripsOnInferenceRequestGeneration() throws Exception { + final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure( Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), "1b").build() ); @@ -628,7 +687,8 @@ public void 
testIndexingPressureTripsOnInferenceRequestGeneration() throws Excep Map.of(sparseModel.getInferenceEntityId(), sparseModel), indexingPressure, useLegacyFormat, - true + true, + inferenceStats ); XContentBuilder doc1Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "bar"); @@ -702,6 +762,7 @@ public void testIndexingPressureTripsOnInferenceResponseHandling() throws Except Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), (bytesUsed(doc1Source) + 1) + "b").build() ); + final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); sparseModel.putResult("bar", randomChunkedInferenceEmbedding(sparseModel, List.of("bar"))); @@ -710,7 +771,8 @@ public void testIndexingPressureTripsOnInferenceResponseHandling() throws Except Map.of(sparseModel.getInferenceEntityId(), sparseModel), indexingPressure, useLegacyFormat, - true + true, + inferenceStats ); CountDownLatch chainExecuted = new CountDownLatch(1); @@ -813,12 +875,14 @@ public void testIndexingPressurePartialFailure() throws Exception { .build() ); + final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); final ShardBulkInferenceActionFilter filter = createFilter( threadPool, Map.of(sparseModel.getInferenceEntityId(), sparseModel), indexingPressure, useLegacyFormat, - true + true, + inferenceStats ); CountDownLatch chainExecuted = new CountDownLatch(1); @@ -893,7 +957,8 @@ private static ShardBulkInferenceActionFilter createFilter( Map modelMap, IndexingPressure indexingPressure, boolean useLegacyFormat, - boolean isLicenseValidForInference + boolean isLicenseValidForInference, + InferenceStats inferenceStats ) { ModelRegistry modelRegistry = mock(ModelRegistry.class); Answer unparsedModelAnswer = invocationOnMock -> { @@ -970,7 +1035,8 @@ private static ShardBulkInferenceActionFilter createFilter( inferenceServiceRegistry, modelRegistry, licenseState, - 
indexingPressure + indexingPressure, + inferenceStats ); }