# NOTE(review): this patch was reflowed from a whitespace-collapsed copy. Indentation and generic type
# arguments (everything inside angle brackets) were missing from that copy and are reconstructed here;
# confirm them against the repository before applying.
# NOTE(review): publisherWithMetrics now records the duration metric at most once per subscription, so the
# post-image blob id on the "index" line below no longer matches the resulting file.
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java
index 536e75e405baa..81086c4488661 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java
@@ -39,7 +39,6 @@
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.inference.InferencePlugin;
 import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
-import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
 import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator;
 import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
@@ -297,8 +296,7 @@ private void inferOnServiceWithMetrics(
                 );
                 inferenceResults.publisher().subscribe(taskProcessor);
 
-                var instrumentedStream = new PublisherWithMetrics(timer, model, request, localNodeId);
-                taskProcessor.subscribe(instrumentedStream);
+                var instrumentedStream = publisherWithMetrics(timer, model, request, localNodeId, taskProcessor);
 
                 var streamErrorHandler = streamErrorHandler(instrumentedStream);
 
@@ -313,7 +311,67 @@ private void inferOnServiceWithMetrics(
         }));
     }
 
-    protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Processor<T, T> upstream) {
+    /**
+     * Wraps {@code upstream} in a publisher that records the request duration metric when the stream ends.
+     * The metric is recorded at most once per subscription, by whichever comes first: {@code onError},
+     * {@code onComplete}, or the downstream subscriber calling {@code cancel()}. Every signal is still
+     * forwarded unchanged, so a repeated {@code cancel()} reaches {@code upstream} without recording again.
+     */
+    private <T> Flow.Publisher<T> publisherWithMetrics(
+        InferenceTimer timer,
+        Model model,
+        Request request,
+        String localNodeId,
+        Flow.Processor<?, T> upstream
+    ) {
+        return downstream -> {
+            upstream.subscribe(new Flow.Subscriber<>() {
+                // cancel() can be called more than once, or after a terminal signal; only the first event records
+                private final java.util.concurrent.atomic.AtomicBoolean recorded = new java.util.concurrent.atomic.AtomicBoolean();
+
+                private void recordOnce(Throwable throwable) {
+                    if (recorded.compareAndSet(false, true)) {
+                        recordRequestDurationMetrics(model, timer, request, localNodeId, throwable);
+                    }
+                }
+
+                @Override
+                public void onSubscribe(Flow.Subscription subscription) {
+                    downstream.onSubscribe(new Flow.Subscription() {
+                        @Override
+                        public void request(long n) {
+                            subscription.request(n);
+                        }
+
+                        @Override
+                        public void cancel() {
+                            recordOnce(null);
+                            subscription.cancel();
+                        }
+                    });
+                }
+
+                @Override
+                public void onNext(T item) {
+                    downstream.onNext(item);
+                }
+
+                @Override
+                public void onError(Throwable throwable) {
+                    recordOnce(throwable);
+                    downstream.onError(throwable);
+                }
+
+                @Override
+                public void onComplete() {
+                    recordOnce(null);
+                    downstream.onComplete();
+                }
+            });
+        };
+    }
+
+    protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Publisher<T> upstream) {
         return upstream;
     }
 
@@ -386,44 +444,6 @@ private static ElasticsearchStatusException requestModelTaskTypeMismatchExceptio
         );
     }
 
-    private class PublisherWithMetrics extends DelegatingProcessor<InferenceServiceResults.Result, InferenceServiceResults.Result> {
-
-        private final InferenceTimer timer;
-        private final Model model;
-        private final Request request;
-        private final String localNodeId;
-
-        private PublisherWithMetrics(InferenceTimer timer, Model model, Request request, String localNodeId) {
-            this.timer = timer;
-            this.model = model;
-            this.request = request;
-            this.localNodeId = localNodeId;
-        }
-
-        @Override
-        protected void next(InferenceServiceResults.Result item) {
-            downstream().onNext(item);
-        }
-
-        @Override
-        public void onError(Throwable throwable) {
-            recordRequestDurationMetrics(model, timer, request, localNodeId, throwable);
-            super.onError(throwable);
-        }
-
-        @Override
-        protected void onCancel() {
-            recordRequestDurationMetrics(model, timer, request, localNodeId, null);
-            super.onCancel();
-        }
-
-        @Override
-        public void onComplete() {
-            recordRequestDurationMetrics(model, timer, request, localNodeId, null);
-            super.onComplete();
-        }
-    }
-
     private record NodeRoutingDecision(boolean currentNodeShouldHandleRequest, DiscoveryNode targetNode) {
         static NodeRoutingDecision handleLocally() {
             return new NodeRoutingDecision(true, null);
# NOTE(review): reflowed from a whitespace-collapsed copy. Indentation and generic type arguments are
# reconstructed; the mockStreamResponse parameter and return types in particular are a best guess chosen so
# the visible test code type-checks. Confirm against the repository before applying.
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java
index 4c8f03fae9184..0d14149d7ab75 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java
@@ -102,7 +102,7 @@ protected void doExecute(Task task, UnifiedCompletionAction.Request request, Act
      * as {@link UnifiedChatCompletionException}.
      */
     @Override
-    protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Processor<T, T> upstream) {
+    protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Publisher<T> upstream) {
         return downstream -> {
             upstream.subscribe(new Flow.Subscriber<>() {
                 @Override
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java
index 9e94f214219cb..caea68bd861d0 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java
@@ -291,7 +291,7 @@ public void testMetricsAfterInferSuccess() {
     }
 
     public void testMetricsAfterStreamInferSuccess() {
-        mockStreamResponse(Flow.Subscriber::onComplete);
+        mockStreamResponse(Flow.Subscriber::onComplete).subscribe(mock());
         verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
             assertThat(attributes.get("service"), is(serviceId));
             assertThat(attributes.get("task_type"), is(taskType.toString()));
@@ -306,10 +306,7 @@ public void testMetricsAfterStreamInferSuccess() {
     public void testMetricsAfterStreamInferFailure() {
         var expectedException = new IllegalStateException("hello");
         var expectedError = expectedException.getClass().getSimpleName();
-        mockStreamResponse(subscriber -> {
-            subscriber.subscribe(mock());
-            subscriber.onError(expectedException);
-        });
+        mockStreamResponse(subscriber -> subscriber.onError(expectedException)).subscribe(mock());
         verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
             assertThat(attributes.get("service"), is(serviceId));
             assertThat(attributes.get("task_type"), is(taskType.toString()));
@@ -388,7 +385,7 @@ public void onFailure(Exception e) {}
         assertThat(threadContext.getHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER), is(productUseCase));
     }
 
-    protected Flow.Publisher<InferenceServiceResults.Result> mockStreamResponse(Consumer<Flow.Processor<?, ?>> action) {
+    protected Flow.Publisher<InferenceServiceResults.Result> mockStreamResponse(Consumer<Flow.Subscriber<?>> action) {
         mockService(true, Set.of(), listener -> {
             Flow.Processor<InferenceServiceResults.Result, InferenceServiceResults.Result> taskProcessor = mock();
             doAnswer(innerAns -> {