Skip to content

Commit 7d10123

Browse files
Incrementing reference count for body content and fixing tests
1 parent e99a781 commit 7d10123

File tree

4 files changed

+28
-27
lines changed

4 files changed

+28
-27
lines changed

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -360,8 +360,7 @@ protected Deque<ServerSentEvent> unifiedCompletionInferOnMockService(
360360
List<String> input,
361361
@Nullable Consumer<Response> responseConsumerCallback
362362
) throws Exception {
363-
var route = randomBoolean() ? "_stream" : "_unified"; // TODO remove unified route
364-
var endpoint = Strings.format("_inference/%s/%s/%s", taskType, modelId, route);
363+
var endpoint = Strings.format("_inference/%s/%s/_stream", taskType, modelId);
365364
return callAsyncUnified(endpoint, input, "user", responseConsumerCallback);
366365
}
367366

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java

Lines changed: 25 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@
2424
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
2525
import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy;
2626
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
27+
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
2728
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
2829

2930
import java.io.IOException;
@@ -77,27 +78,33 @@ protected void doExecute(Task task, InferenceActionProxy.Request request, Action
7778
}
7879
}
7980

80-
private void sendUnifiedCompletionRequest(InferenceActionProxy.Request request, ActionListener<InferenceAction.Response> listener)
81-
throws IOException {
81+
private void sendUnifiedCompletionRequest(InferenceActionProxy.Request request, ActionListener<InferenceAction.Response> listener) {
82+
// format any validation exceptions from the rest -> transport path as UnifiedChatCompletionException
83+
var unifiedErrorFormatListener = listener.delegateResponse((l, e) -> l.onFailure(UnifiedChatCompletionException.fromThrowable(e)));
8284

83-
if (request.isStreaming() == false) {
84-
throw new ElasticsearchStatusException(
85-
"The [chat_completion] task type only supports streaming, please try again with the _stream API",
86-
RestStatus.BAD_REQUEST
87-
);
88-
}
85+
try {
86+
if (request.isStreaming() == false) {
87+
unifiedErrorFormatListener.onFailure(new ElasticsearchStatusException(
88+
"The [chat_completion] task type only supports streaming, please try again with the _stream API",
89+
RestStatus.BAD_REQUEST
90+
));
91+
return;
92+
}
8993

90-
UnifiedCompletionAction.Request unifiedRequest;
91-
try (var parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, request.getContent(), request.getContentType())) {
92-
unifiedRequest = UnifiedCompletionAction.Request.parseRequest(
93-
request.getInferenceEntityId(),
94-
request.getTaskType(),
95-
request.getTimeout(),
96-
parser
97-
);
98-
}
94+
UnifiedCompletionAction.Request unifiedRequest;
95+
try (var parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, request.getContent(), request.getContentType())) {
96+
unifiedRequest = UnifiedCompletionAction.Request.parseRequest(
97+
request.getInferenceEntityId(),
98+
request.getTaskType(),
99+
request.getTimeout(),
100+
parser
101+
);
102+
}
99103

100-
executeAsyncWithOrigin(client, INFERENCE_ORIGIN, UnifiedCompletionAction.INSTANCE, unifiedRequest, listener);
104+
executeAsyncWithOrigin(client, INFERENCE_ORIGIN, UnifiedCompletionAction.INSTANCE, unifiedRequest, unifiedErrorFormatListener);
105+
} catch (Exception e) {
106+
unifiedErrorFormatListener.onFailure(e);
107+
}
101108
}
102109

103110
private void sendInferenceActionRequest(InferenceActionProxy.Request request, ActionListener<InferenceAction.Response> listener)

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,7 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
5454
shouldStream()
5555
);
5656

57-
return channel -> client.execute(InferenceActionProxy.INSTANCE, request, listener(channel));
57+
return channel -> client.execute(InferenceActionProxy.INSTANCE, request, ActionListener.withRef(listener(channel), content));
5858
}
5959

6060
protected abstract boolean shouldStream();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java

Lines changed: 1 addition & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,6 @@
1414
import org.elasticsearch.rest.ServerlessScope;
1515
import org.elasticsearch.threadpool.ThreadPool;
1616
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
17-
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
1817

1918
import java.util.List;
2019
import java.util.Objects;
@@ -44,11 +43,7 @@ public List<Route> routes() {
4443

4544
@Override
4645
protected ActionListener<InferenceAction.Response> listener(RestChannel channel) {
47-
// TODO confirm with Pat that this will work
48-
return new ServerSentEventsRestActionListener(channel, threadPool).delegateResponse((l, e) -> {
49-
// format any validation exceptions from the rest -> transport path as UnifiedChatCompletionException
50-
l.onFailure(UnifiedChatCompletionException.fromThrowable(e));
51-
});
46+
return new ServerSentEventsRestActionListener(channel, threadPool);
5247
}
5348

5449
@Override

0 commit comments

Comments
 (0)