Skip to content

Commit b4c5e16

Browse files
committed
[ML] Refactor stream metrics (elastic#125092)
Remove the use of DelegatingProcessor and replace it with an inline processor.
1 parent ab8b1f0 commit b4c5e16

File tree

3 files changed

+45
-45
lines changed

3 files changed

+45
-45
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java

Lines changed: 41 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
4040
import org.elasticsearch.xpack.inference.InferencePlugin;
4141
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
42-
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
4342
import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator;
4443
import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
4544
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
@@ -289,8 +288,7 @@ private void inferOnServiceWithMetrics(
289288
);
290289
inferenceResults.publisher().subscribe(taskProcessor);
291290

292-
var instrumentedStream = new PublisherWithMetrics(timer, model);
293-
taskProcessor.subscribe(instrumentedStream);
291+
var instrumentedStream = publisherWithMetrics(timer, model, taskProcessor);
294292

295293
var streamErrorHandler = streamErrorHandler(instrumentedStream);
296294

@@ -305,7 +303,46 @@ private void inferOnServiceWithMetrics(
305303
}));
306304
}
307305

308-
protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Processor<T, T> upstream) {
306+
private <T> Flow.Publisher<T> publisherWithMetrics(InferenceTimer timer, Model model, Flow.Processor<T, T> upstream) {
307+
return downstream -> {
308+
upstream.subscribe(new Flow.Subscriber<>() {
309+
@Override
310+
public void onSubscribe(Flow.Subscription subscription) {
311+
downstream.onSubscribe(new Flow.Subscription() {
312+
@Override
313+
public void request(long n) {
314+
subscription.request(n);
315+
}
316+
317+
@Override
318+
public void cancel() {
319+
recordMetrics(model, timer, null);
320+
subscription.cancel();
321+
}
322+
});
323+
}
324+
325+
@Override
326+
public void onNext(T item) {
327+
downstream.onNext(item);
328+
}
329+
330+
@Override
331+
public void onError(Throwable throwable) {
332+
recordMetrics(model, timer, throwable);
333+
downstream.onError(throwable);
334+
}
335+
336+
@Override
337+
public void onComplete() {
338+
recordMetrics(model, timer, null);
339+
downstream.onComplete();
340+
}
341+
});
342+
};
343+
}
344+
345+
protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Publisher<T> upstream) {
309346
return upstream;
310347
}
311348

@@ -359,40 +396,6 @@ private static ElasticsearchStatusException requestModelTaskTypeMismatchExceptio
359396
);
360397
}
361398

362-
private class PublisherWithMetrics extends DelegatingProcessor<InferenceServiceResults.Result, InferenceServiceResults.Result> {
363-
364-
private final InferenceTimer timer;
365-
private final Model model;
366-
367-
private PublisherWithMetrics(InferenceTimer timer, Model model) {
368-
this.timer = timer;
369-
this.model = model;
370-
}
371-
372-
@Override
373-
protected void next(InferenceServiceResults.Result item) {
374-
downstream().onNext(item);
375-
}
376-
377-
@Override
378-
public void onError(Throwable throwable) {
379-
recordMetrics(model, timer, throwable);
380-
super.onError(throwable);
381-
}
382-
383-
@Override
384-
protected void onCancel() {
385-
recordMetrics(model, timer, null);
386-
super.onCancel();
387-
}
388-
389-
@Override
390-
public void onComplete() {
391-
recordMetrics(model, timer, null);
392-
super.onComplete();
393-
}
394-
}
395-
396399
private record NodeRoutingDecision(boolean currentNodeShouldHandleRequest, DiscoveryNode targetNode) {
397400
static NodeRoutingDecision handleLocally() {
398401
return new NodeRoutingDecision(true, null);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ protected void doExecute(Task task, UnifiedCompletionAction.Request request, Act
102102
* as {@link UnifiedChatCompletionException}.
103103
*/
104104
@Override
105-
protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Processor<T, T> upstream) {
105+
protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Publisher<T> upstream) {
106106
return downstream -> {
107107
upstream.subscribe(new Flow.Subscriber<>() {
108108
@Override

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ public void testMetricsAfterInferSuccess() {
277277
}
278278

279279
public void testMetricsAfterStreamInferSuccess() {
280-
mockStreamResponse(Flow.Subscriber::onComplete);
280+
mockStreamResponse(Flow.Subscriber::onComplete).subscribe(mock());
281281
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
282282
assertThat(attributes.get("service"), is(serviceId));
283283
assertThat(attributes.get("task_type"), is(taskType.toString()));
@@ -290,10 +290,7 @@ public void testMetricsAfterStreamInferSuccess() {
290290
public void testMetricsAfterStreamInferFailure() {
291291
var expectedException = new IllegalStateException("hello");
292292
var expectedError = expectedException.getClass().getSimpleName();
293-
mockStreamResponse(subscriber -> {
294-
subscriber.subscribe(mock());
295-
subscriber.onError(expectedException);
296-
});
293+
mockStreamResponse(subscriber -> subscriber.onError(expectedException)).subscribe(mock());
297294
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
298295
assertThat(attributes.get("service"), is(serviceId));
299296
assertThat(attributes.get("task_type"), is(taskType.toString()));
@@ -368,7 +365,7 @@ public void onFailure(Exception e) {}
368365
assertThat(threadContext.getHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER), is(productUseCase));
369366
}
370367

371-
protected Flow.Publisher<InferenceServiceResults.Result> mockStreamResponse(Consumer<Flow.Processor<?, ?>> action) {
368+
protected Flow.Publisher<InferenceServiceResults.Result> mockStreamResponse(Consumer<Flow.Subscriber<?>> action) {
372369
mockService(true, Set.of(), listener -> {
373370
Flow.Processor<ChunkedToXContent, ChunkedToXContent> taskProcessor = mock();
374371
doAnswer(innerAns -> {

0 commit comments

Comments
 (0)