Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator;
import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
Expand Down Expand Up @@ -289,8 +288,7 @@ private void inferOnServiceWithMetrics(
);
inferenceResults.publisher().subscribe(taskProcessor);

var instrumentedStream = new PublisherWithMetrics(timer, model);
taskProcessor.subscribe(instrumentedStream);
var instrumentedStream = publisherWithMetrics(timer, model, taskProcessor);

var streamErrorHandler = streamErrorHandler(instrumentedStream);

Expand All @@ -305,7 +303,46 @@ private void inferOnServiceWithMetrics(
}));
}

protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Processor<T, T> upstream) {
/**
 * Wraps {@code upstream} in a publisher that forwards every signal unchanged while
 * recording inference metrics exactly once per terminal event: on error, on
 * completion, or when the downstream subscriber cancels.
 *
 * @param timer    timer for the in-flight request (assumed started by the caller — confirm)
 * @param model    model whose attributes are attached to the recorded metric
 * @param upstream processor emitting the streamed inference results
 * @return a publisher delegating to {@code upstream} with metrics recording attached
 */
private <T> Flow.Publisher<T> publisherWithMetrics(InferenceTimer timer, Model model, Flow.Processor<T, T> upstream) {
    return new Flow.Publisher<>() {
        @Override
        public void subscribe(Flow.Subscriber<? super T> subscriber) {
            upstream.subscribe(new Flow.Subscriber<T>() {
                @Override
                public void onSubscribe(Flow.Subscription upstreamSubscription) {
                    subscriber.onSubscribe(new Flow.Subscription() {
                        @Override
                        public void request(long demand) {
                            upstreamSubscription.request(demand);
                        }

                        @Override
                        public void cancel() {
                            // Client-initiated cancellation is terminal: record metrics
                            // (no error) before propagating the cancel upstream.
                            recordMetrics(model, timer, null);
                            upstreamSubscription.cancel();
                        }
                    });
                }

                @Override
                public void onNext(T result) {
                    subscriber.onNext(result);
                }

                @Override
                public void onError(Throwable failure) {
                    recordMetrics(model, timer, failure);
                    subscriber.onError(failure);
                }

                @Override
                public void onComplete() {
                    recordMetrics(model, timer, null);
                    subscriber.onComplete();
                }
            });
        }
    };
}

/**
 * Extension point letting subclasses intercept or rewrap errors flowing through the
 * streaming result publisher (e.g. the unified completion action overrides this to
 * convert stream errors). The base implementation is a pass-through that returns
 * {@code upstream} unchanged.
 *
 * @param upstream the publisher of streamed results
 * @return the publisher to hand to the response listener; here, {@code upstream} itself
 */
protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Publisher<T> upstream) {
    return upstream;
}

Expand Down Expand Up @@ -359,40 +396,6 @@ private static ElasticsearchStatusException requestModelTaskTypeMismatchExceptio
);
}

/**
 * Delegating processor that passes inference results through untouched while recording
 * inference metrics when the stream terminates: on error, on completion, or when the
 * downstream subscriber cancels.
 * NOTE(review): subscription plumbing (downstream(), onCancel hook) comes from
 * DelegatingProcessor, which is not visible here — confirm its semantics there.
 */
private class PublisherWithMetrics extends DelegatingProcessor<InferenceServiceResults.Result, InferenceServiceResults.Result> {

    // Timer for the in-flight request (assumed started by the caller — confirm).
    private final InferenceTimer timer;
    // Model whose attributes are attached to the recorded metric.
    private final Model model;

    private PublisherWithMetrics(InferenceTimer timer, Model model) {
        this.timer = timer;
        this.model = model;
    }

    // Forward each result downstream unchanged.
    @Override
    protected void next(InferenceServiceResults.Result item) {
        downstream().onNext(item);
    }

    // Record the failure before propagating it.
    @Override
    public void onError(Throwable throwable) {
        recordMetrics(model, timer, throwable);
        super.onError(throwable);
    }

    // Downstream cancelled: terminal from the client's perspective, so record
    // metrics with no error before delegating.
    @Override
    protected void onCancel() {
        recordMetrics(model, timer, null);
        super.onCancel();
    }

    @Override
    public void onComplete() {
        recordMetrics(model, timer, null);
        super.onComplete();
    }
}

private record NodeRoutingDecision(boolean currentNodeShouldHandleRequest, DiscoveryNode targetNode) {
static NodeRoutingDecision handleLocally() {
return new NodeRoutingDecision(true, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ protected void doExecute(Task task, UnifiedCompletionAction.Request request, Act
* as {@link UnifiedChatCompletionException}.
*/
@Override
protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Processor<T, T> upstream) {
protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Publisher<T> upstream) {
return downstream -> {
upstream.subscribe(new Flow.Subscriber<>() {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ public void testMetricsAfterInferSuccess() {
}

public void testMetricsAfterStreamInferSuccess() {
mockStreamResponse(Flow.Subscriber::onComplete);
mockStreamResponse(Flow.Subscriber::onComplete).subscribe(mock());
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
assertThat(attributes.get("service"), is(serviceId));
assertThat(attributes.get("task_type"), is(taskType.toString()));
Expand All @@ -290,10 +290,7 @@ public void testMetricsAfterStreamInferSuccess() {
public void testMetricsAfterStreamInferFailure() {
var expectedException = new IllegalStateException("hello");
var expectedError = expectedException.getClass().getSimpleName();
mockStreamResponse(subscriber -> {
subscriber.subscribe(mock());
subscriber.onError(expectedException);
});
mockStreamResponse(subscriber -> subscriber.onError(expectedException)).subscribe(mock());
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
assertThat(attributes.get("service"), is(serviceId));
assertThat(attributes.get("task_type"), is(taskType.toString()));
Expand Down Expand Up @@ -368,7 +365,7 @@ public void onFailure(Exception e) {}
assertThat(threadContext.getHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER), is(productUseCase));
}

protected Flow.Publisher<InferenceServiceResults.Result> mockStreamResponse(Consumer<Flow.Processor<?, ?>> action) {
protected Flow.Publisher<InferenceServiceResults.Result> mockStreamResponse(Consumer<Flow.Subscriber<?>> action) {
mockService(true, Set.of(), listener -> {
Flow.Processor<ChunkedToXContent, ChunkedToXContent> taskProcessor = mock();
doAnswer(innerAns -> {
Expand Down