@@ -39,7 +39,6 @@
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator;
import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
@@ -297,8 +296,7 @@ private void inferOnServiceWithMetrics(
);
inferenceResults.publisher().subscribe(taskProcessor);

var instrumentedStream = new PublisherWithMetrics(timer, model, request, localNodeId);
taskProcessor.subscribe(instrumentedStream);
var instrumentedStream = publisherWithMetrics(timer, model, request, localNodeId, taskProcessor);

var streamErrorHandler = streamErrorHandler(instrumentedStream);

@@ -313,7 +311,52 @@ private void inferOnServiceWithMetrics(
}));
}

protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Processor<T, T> upstream) {
private <T> Flow.Publisher<T> publisherWithMetrics(
InferenceTimer timer,
Model model,
Request request,
String localNodeId,
Flow.Processor<T, T> upstream
) {
return downstream -> {
Member Author: Publisher has a single method, subscribe, so when our downstream calls subscribe we forward it to upstream.subscribe, tying the upstream and downstream together with this function.
upstream.subscribe(new Flow.Subscriber<>() {
@Override
public void onSubscribe(Flow.Subscription subscription) {
downstream.onSubscribe(new Flow.Subscription() {
@Override
public void request(long n) {
subscription.request(n);
}

@Override
public void cancel() {
recordRequestDurationMetrics(model, timer, request, localNodeId, null);
subscription.cancel();
}
});
}

@Override
public void onNext(T item) {
downstream.onNext(item);
}

@Override
public void onError(Throwable throwable) {
recordRequestDurationMetrics(model, timer, request, localNodeId, throwable);
downstream.onError(throwable);
}

@Override
public void onComplete() {
recordRequestDurationMetrics(model, timer, request, localNodeId, null);
downstream.onComplete();
}
});
};
}

protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Publisher<T> upstream) {
return upstream;
}
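To make the pattern described in the comment above concrete, here is a minimal, self-contained sketch of the same idea outside the PR: a helper that wraps an upstream Flow.Publisher and runs a callback on every terminal signal (error, completion, or downstream cancellation), in the same way publisherWithMetrics records the request duration. The names TerminalSignalPublisher, wrap, and terminalAction are illustrative stand-ins, not code from this change.

import java.util.concurrent.Flow;
import java.util.concurrent.SubmissionPublisher;
import java.util.function.Consumer;

public class TerminalSignalPublisher {

    // Wrap an upstream publisher so that terminalAction runs when the stream
    // ends: on error, on completion, or when the downstream cancels.
    static <T> Flow.Publisher<T> wrap(Flow.Publisher<T> upstream, Consumer<Throwable> terminalAction) {
        // Publisher has a single method (subscribe), so a lambda is enough:
        // when the downstream subscribes we forward to upstream.subscribe
        // with an intercepting subscriber in between.
        return downstream -> upstream.subscribe(new Flow.Subscriber<T>() {
            @Override
            public void onSubscribe(Flow.Subscription subscription) {
                downstream.onSubscribe(new Flow.Subscription() {
                    @Override
                    public void request(long n) {
                        subscription.request(n);
                    }

                    @Override
                    public void cancel() {
                        terminalAction.accept(null); // cancellation is a terminal signal too
                        subscription.cancel();
                    }
                });
            }

            @Override
            public void onNext(T item) {
                downstream.onNext(item);
            }

            @Override
            public void onError(Throwable throwable) {
                terminalAction.accept(throwable);
                downstream.onError(throwable);
            }

            @Override
            public void onComplete() {
                terminalAction.accept(null);
                downstream.onComplete();
            }
        });
    }

    public static void main(String[] args) throws InterruptedException {
        try (SubmissionPublisher<String> upstream = new SubmissionPublisher<>()) {
            Flow.Publisher<String> wrapped = wrap(upstream, t -> System.out.println("terminal, error=" + t));
            wrapped.subscribe(new Flow.Subscriber<String>() {
                @Override public void onSubscribe(Flow.Subscription s) { s.request(Long.MAX_VALUE); }
                @Override public void onNext(String item) { System.out.println("next: " + item); }
                @Override public void onError(Throwable t) { }
                @Override public void onComplete() { System.out.println("complete"); }
            });
            upstream.submit("hello");
        }
        Thread.sleep(200); // SubmissionPublisher delivers asynchronously
    }
}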

@@ -386,44 +429,6 @@ private static ElasticsearchStatusException requestModelTaskTypeMismatchExceptio
);
}

private class PublisherWithMetrics extends DelegatingProcessor<InferenceServiceResults.Result, InferenceServiceResults.Result> {

private final InferenceTimer timer;
private final Model model;
private final Request request;
private final String localNodeId;

private PublisherWithMetrics(InferenceTimer timer, Model model, Request request, String localNodeId) {
this.timer = timer;
this.model = model;
this.request = request;
this.localNodeId = localNodeId;
}

@Override
protected void next(InferenceServiceResults.Result item) {
downstream().onNext(item);
}

@Override
public void onError(Throwable throwable) {
recordRequestDurationMetrics(model, timer, request, localNodeId, throwable);
super.onError(throwable);
}

@Override
protected void onCancel() {
recordRequestDurationMetrics(model, timer, request, localNodeId, null);
super.onCancel();
}

@Override
public void onComplete() {
recordRequestDurationMetrics(model, timer, request, localNodeId, null);
super.onComplete();
}
}

private record NodeRoutingDecision(boolean currentNodeShouldHandleRequest, DiscoveryNode targetNode) {
static NodeRoutingDecision handleLocally() {
return new NodeRoutingDecision(true, null);
@@ -102,7 +102,7 @@ protected void doExecute(Task task, UnifiedCompletionAction.Request request, Act
* as {@link UnifiedChatCompletionException}.
*/
@Override
protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Processor<T, T> upstream) {
protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Publisher<T> upstream) {
return downstream -> {
upstream.subscribe(new Flow.Subscriber<>() {
@Override
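The override above now decorates a plain Flow.Publisher instead of a Flow.Processor, rethrowing stream failures as UnifiedChatCompletionException per its javadoc. As a hedged sketch of that wrapping style in isolation (ErrorMappingPublisher and mapErrors are illustrative names, not code from this PR), an error-translating decorator only needs to intercept onError and pass everything else through:

import java.util.concurrent.Flow;
import java.util.function.Function;

final class ErrorMappingPublisher {
    // Wraps upstream so any failure is translated by mapError before it
    // reaches the downstream subscriber; data, completion, and demand
    // pass through untouched.
    static <T> Flow.Publisher<T> mapErrors(Flow.Publisher<T> upstream, Function<Throwable, Throwable> mapError) {
        return downstream -> upstream.subscribe(new Flow.Subscriber<T>() {
            @Override public void onSubscribe(Flow.Subscription s) { downstream.onSubscribe(s); }
            @Override public void onNext(T item) { downstream.onNext(item); }
            @Override public void onError(Throwable t) { downstream.onError(mapError.apply(t)); }
            @Override public void onComplete() { downstream.onComplete(); }
        });
    }
}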
@@ -291,7 +291,7 @@ public void testMetricsAfterInferSuccess() {
}

public void testMetricsAfterStreamInferSuccess() {
mockStreamResponse(Flow.Subscriber::onComplete);
mockStreamResponse(Flow.Subscriber::onComplete).subscribe(mock());
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
assertThat(attributes.get("service"), is(serviceId));
assertThat(attributes.get("task_type"), is(taskType.toString()));
@@ -306,10 +306,7 @@ public void testMetricsAfterStreamInferSuccess() {
public void testMetricsAfterStreamInferFailure() {
var expectedException = new IllegalStateException("hello");
var expectedError = expectedException.getClass().getSimpleName();
mockStreamResponse(subscriber -> {
subscriber.subscribe(mock());
subscriber.onError(expectedException);
});
mockStreamResponse(subscriber -> subscriber.onError(expectedException)).subscribe(mock());
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
assertThat(attributes.get("service"), is(serviceId));
assertThat(attributes.get("task_type"), is(taskType.toString()));
@@ -388,7 +385,7 @@ public void onFailure(Exception e) {}
assertThat(threadContext.getHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER), is(productUseCase));
}

protected Flow.Publisher<InferenceServiceResults.Result> mockStreamResponse(Consumer<Flow.Processor<?, ?>> action) {
protected Flow.Publisher<InferenceServiceResults.Result> mockStreamResponse(Consumer<Flow.Subscriber<?>> action) {
mockService(true, Set.of(), listener -> {
Flow.Processor<ChunkedToXContent, ChunkedToXContent> taskProcessor = mock();
doAnswer(innerAns -> {
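One consequence of switching from an eagerly subscribed Processor to a lazy Publisher shows up in the tests above: nothing flows, and no metric is recorded, until something subscribes, hence the added .subscribe(mock()) calls and the Consumer<Flow.Subscriber<?>> signature of mockStreamResponse. A hedged, synchronous illustration of that laziness, reusing the wrap(...) sketch from earlier (all names illustrative):

import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicInteger;

public class LazyPublisherDemo {
    public static void main(String[] args) {
        AtomicInteger terminations = new AtomicInteger();

        // A trivial synchronous upstream that completes as soon as it is subscribed.
        Flow.Publisher<String> upstream = subscriber -> {
            subscriber.onSubscribe(new Flow.Subscription() {
                @Override public void request(long n) { }
                @Override public void cancel() { }
            });
            subscriber.onComplete();
        };

        Flow.Publisher<String> wrapped = TerminalSignalPublisher.wrap(upstream, t -> terminations.incrementAndGet());

        System.out.println(terminations.get()); // 0 -- the wrapper is lazy, nothing has subscribed yet

        wrapped.subscribe(new Flow.Subscriber<String>() {
            @Override public void onSubscribe(Flow.Subscription s) { s.request(Long.MAX_VALUE); }
            @Override public void onNext(String item) { }
            @Override public void onError(Throwable t) { }
            @Override public void onComplete() { }
        });

        System.out.println(terminations.get()); // 1 -- onComplete propagated through the wrapper
    }
}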