|
9 | 9 |
|
10 | 10 | import org.apache.lucene.util.SetOnce; |
11 | 11 | import org.elasticsearch.action.ActionListener; |
| 12 | +import org.elasticsearch.client.internal.node.NodeClient; |
| 13 | +import org.elasticsearch.inference.TaskType; |
12 | 14 | import org.elasticsearch.rest.RestChannel; |
| 15 | +import org.elasticsearch.rest.RestRequest; |
13 | 16 | import org.elasticsearch.rest.Scope; |
14 | 17 | import org.elasticsearch.rest.ServerlessScope; |
15 | 18 | import org.elasticsearch.threadpool.ThreadPool; |
16 | 19 | import org.elasticsearch.xpack.core.inference.action.InferenceAction; |
| 20 | +import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; |
17 | 21 |
|
| 22 | +import java.io.IOException; |
18 | 23 | import java.util.List; |
19 | 24 | import java.util.Objects; |
20 | 25 |
|
@@ -50,4 +55,32 @@ protected InferenceAction.Request prepareInferenceRequest(InferenceAction.Reques |
50 | 55 | protected ActionListener<InferenceAction.Response> listener(RestChannel channel) { |
51 | 56 | return new ServerSentEventsRestActionListener(channel, threadPool); |
52 | 57 | } |
| 58 | + |
| 59 | + @Override |
| 60 | + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { |
| 61 | + var params = parseParams(restRequest); |
| 62 | + var inferTimeout = parseTimeout(restRequest); |
| 63 | + |
| 64 | + if (params.taskType() == TaskType.CHAT_COMPLETION) { |
| 65 | + UnifiedCompletionAction.Request request; |
| 66 | + try (var parser = restRequest.contentParser()) { |
| 67 | + request = UnifiedCompletionAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), inferTimeout, parser); |
| 68 | + } |
| 69 | + |
| 70 | + return channel -> client.execute( |
| 71 | + UnifiedCompletionAction.INSTANCE, |
| 72 | + request, |
| 73 | + new ServerSentEventsRestActionListener(channel, threadPool) |
| 74 | + ); |
| 75 | + } else { |
| 76 | + InferenceAction.Request.Builder requestBuilder; |
| 77 | + try (var parser = restRequest.contentParser()) { |
| 78 | + requestBuilder = InferenceAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), parser); |
| 79 | + } |
| 80 | + |
| 81 | + requestBuilder.setInferenceTimeout(inferTimeout); |
| 82 | + var request = prepareInferenceRequest(requestBuilder); |
| 83 | + return channel -> client.execute(InferenceAction.INSTANCE, request, listener(channel)); |
| 84 | + } |
| 85 | + } |
53 | 86 | } |
0 commit comments