Skip to content

Commit 6aecbf5

Browse files
authored
[ML] Add _stream path for chat_completions task (elastic#121006) (elastic#121102)
1 parent 86de107 commit 6aecbf5

File tree

3 files changed: 36 additions and 1 deletion.

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,8 @@ protected Deque<ServerSentEvent> unifiedCompletionInferOnMockService(
355355
List<String> input,
356356
@Nullable Consumer<Response> responseConsumerCallback
357357
) throws Exception {
358-
var endpoint = Strings.format("_inference/%s/%s/_unified", taskType, modelId);
358+
var route = randomBoolean() ? "_stream" : "_unified"; // TODO remove unified route
359+
var endpoint = Strings.format("_inference/%s/%s/%s", taskType, modelId, route);
359360
return callAsyncUnified(endpoint, input, "user", responseConsumerCallback);
360361
}
361362

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ public final class Paths {
3131
+ INFERENCE_ID
3232
+ "}/_stream";
3333

34+
// TODO remove the _unified path
3435
public static final String UNIFIED_SUFFIX = "_unified";
3536
static final String UNIFIED_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/" + UNIFIED_SUFFIX;
3637
static final String UNIFIED_TASK_TYPE_INFERENCE_ID_PATH = "_inference/{"

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,17 @@
99

1010
import org.apache.lucene.util.SetOnce;
1111
import org.elasticsearch.action.ActionListener;
12+
import org.elasticsearch.client.internal.node.NodeClient;
13+
import org.elasticsearch.inference.TaskType;
1214
import org.elasticsearch.rest.RestChannel;
15+
import org.elasticsearch.rest.RestRequest;
1316
import org.elasticsearch.rest.Scope;
1417
import org.elasticsearch.rest.ServerlessScope;
1518
import org.elasticsearch.threadpool.ThreadPool;
1619
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
20+
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
1721

22+
import java.io.IOException;
1823
import java.util.List;
1924
import java.util.Objects;
2025

@@ -50,4 +55,32 @@ protected InferenceAction.Request prepareInferenceRequest(InferenceAction.Reques
5055
protected ActionListener<InferenceAction.Response> listener(RestChannel channel) {
5156
return new ServerSentEventsRestActionListener(channel, threadPool);
5257
}
58+
59+
/**
 * Builds the channel consumer for a streaming inference REST request, choosing the transport
 * action from the task type parsed out of the request.
 *
 * <p>When the task type is {@link TaskType#CHAT_COMPLETION}, the body is parsed as a
 * {@link UnifiedCompletionAction.Request} and executed through {@link UnifiedCompletionAction};
 * this is what lets the {@code _stream} path serve chat completion requests. Every other task
 * type is parsed as an {@link InferenceAction.Request} and executed through
 * {@link InferenceAction}.
 *
 * @param restRequest the incoming REST request; supplies the path parameters, the timeout and the body
 * @param client      the node client used to execute the selected action
 * @return a consumer that executes the already-parsed request against the supplied channel
 * @throws IOException declared by the overridden signature; presumably raised while reading the
 *                     request body - NOTE(review): confirm against the parser contract
 */
@Override
60+
protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
61+
var params = parseParams(restRequest);
62+
var inferTimeout = parseTimeout(restRequest);
63+
64+
// Chat completion uses the unified completion request format and action instead of InferenceAction.
if (params.taskType() == TaskType.CHAT_COMPLETION) {
65+
UnifiedCompletionAction.Request request;
66+
// The body is parsed here, before the consumer is returned; try-with-resources closes the parser.
try (var parser = restRequest.contentParser()) {
67+
request = UnifiedCompletionAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), inferTimeout, parser);
68+
}
69+
70+
// This branch constructs the server-sent-events listener directly rather than calling listener(channel).
return channel -> client.execute(
71+
UnifiedCompletionAction.INSTANCE,
72+
request,
73+
new ServerSentEventsRestActionListener(channel, threadPool)
74+
);
75+
} else {
76+
InferenceAction.Request.Builder requestBuilder;
77+
try (var parser = restRequest.contentParser()) {
78+
requestBuilder = InferenceAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), parser);
79+
}
80+
81+
// Unlike the chat completion branch, the timeout is set on the builder after parsing.
requestBuilder.setInferenceTimeout(inferTimeout);
82+
// prepareInferenceRequest turns the builder into the final request for this action.
var request = prepareInferenceRequest(requestBuilder);
83+
return channel -> client.execute(InferenceAction.INSTANCE, request, listener(channel));
84+
}
85+
}
5386
}

0 commit comments

Comments
 (0)