Skip to content

Commit b8db0ee

Browse files
authored
[ML] Refactor stream metrics (#125092)
Remove the use of DelegatingProcessor and replace it with an inline processor.
1 parent 397c9c5 commit b8db0ee

File tree

3 files changed

+51
-49
lines changed

3 files changed

+51
-49
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java

Lines changed: 47 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
4040
import org.elasticsearch.xpack.inference.InferencePlugin;
4141
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
42-
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
4342
import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator;
4443
import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
4544
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
@@ -297,8 +296,7 @@ private void inferOnServiceWithMetrics(
297296
);
298297
inferenceResults.publisher().subscribe(taskProcessor);
299298

300-
var instrumentedStream = new PublisherWithMetrics(timer, model, request, localNodeId);
301-
taskProcessor.subscribe(instrumentedStream);
299+
var instrumentedStream = publisherWithMetrics(timer, model, request, localNodeId, taskProcessor);
302300

303301
var streamErrorHandler = streamErrorHandler(instrumentedStream);
304302

@@ -313,7 +311,52 @@ private void inferOnServiceWithMetrics(
313311
}));
314312
}
315313

316-
protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Processor<T, T> upstream) {
314+
private <T> Flow.Publisher<T> publisherWithMetrics(
315+
InferenceTimer timer,
316+
Model model,
317+
Request request,
318+
String localNodeId,
319+
Flow.Processor<T, T> upstream
320+
) {
321+
return downstream -> {
322+
upstream.subscribe(new Flow.Subscriber<>() {
323+
@Override
324+
public void onSubscribe(Flow.Subscription subscription) {
325+
downstream.onSubscribe(new Flow.Subscription() {
326+
@Override
327+
public void request(long n) {
328+
subscription.request(n);
329+
}
330+
331+
@Override
332+
public void cancel() {
333+
recordRequestDurationMetrics(model, timer, request, localNodeId, null);
334+
subscription.cancel();
335+
}
336+
});
337+
}
338+
339+
@Override
340+
public void onNext(T item) {
341+
downstream.onNext(item);
342+
}
343+
344+
@Override
345+
public void onError(Throwable throwable) {
346+
recordRequestDurationMetrics(model, timer, request, localNodeId, throwable);
347+
downstream.onError(throwable);
348+
}
349+
350+
@Override
351+
public void onComplete() {
352+
recordRequestDurationMetrics(model, timer, request, localNodeId, null);
353+
downstream.onComplete();
354+
}
355+
});
356+
};
357+
}
358+
359+
protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Publisher<T> upstream) {
317360
return upstream;
318361
}
319362

@@ -386,44 +429,6 @@ private static ElasticsearchStatusException requestModelTaskTypeMismatchExceptio
386429
);
387430
}
388431

389-
private class PublisherWithMetrics extends DelegatingProcessor<InferenceServiceResults.Result, InferenceServiceResults.Result> {
390-
391-
private final InferenceTimer timer;
392-
private final Model model;
393-
private final Request request;
394-
private final String localNodeId;
395-
396-
private PublisherWithMetrics(InferenceTimer timer, Model model, Request request, String localNodeId) {
397-
this.timer = timer;
398-
this.model = model;
399-
this.request = request;
400-
this.localNodeId = localNodeId;
401-
}
402-
403-
@Override
404-
protected void next(InferenceServiceResults.Result item) {
405-
downstream().onNext(item);
406-
}
407-
408-
@Override
409-
public void onError(Throwable throwable) {
410-
recordRequestDurationMetrics(model, timer, request, localNodeId, throwable);
411-
super.onError(throwable);
412-
}
413-
414-
@Override
415-
protected void onCancel() {
416-
recordRequestDurationMetrics(model, timer, request, localNodeId, null);
417-
super.onCancel();
418-
}
419-
420-
@Override
421-
public void onComplete() {
422-
recordRequestDurationMetrics(model, timer, request, localNodeId, null);
423-
super.onComplete();
424-
}
425-
}
426-
427432
private record NodeRoutingDecision(boolean currentNodeShouldHandleRequest, DiscoveryNode targetNode) {
428433
static NodeRoutingDecision handleLocally() {
429434
return new NodeRoutingDecision(true, null);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ protected void doExecute(Task task, UnifiedCompletionAction.Request request, Act
102102
* as {@link UnifiedChatCompletionException}.
103103
*/
104104
@Override
105-
protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Processor<T, T> upstream) {
105+
protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Publisher<T> upstream) {
106106
return downstream -> {
107107
upstream.subscribe(new Flow.Subscriber<>() {
108108
@Override

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ public void testMetricsAfterInferSuccess() {
291291
}
292292

293293
public void testMetricsAfterStreamInferSuccess() {
294-
mockStreamResponse(Flow.Subscriber::onComplete);
294+
mockStreamResponse(Flow.Subscriber::onComplete).subscribe(mock());
295295
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
296296
assertThat(attributes.get("service"), is(serviceId));
297297
assertThat(attributes.get("task_type"), is(taskType.toString()));
@@ -306,10 +306,7 @@ public void testMetricsAfterStreamInferSuccess() {
306306
public void testMetricsAfterStreamInferFailure() {
307307
var expectedException = new IllegalStateException("hello");
308308
var expectedError = expectedException.getClass().getSimpleName();
309-
mockStreamResponse(subscriber -> {
310-
subscriber.subscribe(mock());
311-
subscriber.onError(expectedException);
312-
});
309+
mockStreamResponse(subscriber -> subscriber.onError(expectedException)).subscribe(mock());
313310
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
314311
assertThat(attributes.get("service"), is(serviceId));
315312
assertThat(attributes.get("task_type"), is(taskType.toString()));
@@ -388,7 +385,7 @@ public void onFailure(Exception e) {}
388385
assertThat(threadContext.getHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER), is(productUseCase));
389386
}
390387

391-
protected Flow.Publisher<InferenceServiceResults.Result> mockStreamResponse(Consumer<Flow.Processor<?, ?>> action) {
388+
protected Flow.Publisher<InferenceServiceResults.Result> mockStreamResponse(Consumer<Flow.Subscriber<?>> action) {
392389
mockService(true, Set.of(), listener -> {
393390
Flow.Processor<ChunkedToXContent, ChunkedToXContent> taskProcessor = mock();
394391
doAnswer(innerAns -> {

0 commit comments

Comments (0)