Skip to content

Commit 69ba46d

Browse files
author
Max Hniebergall
committed
initial implementation of request and response handling, manager, and entity
1 parent 467747f commit 69ba46d

File tree

7 files changed

+1317
-13
lines changed

7 files changed

+1317
-13
lines changed
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.elastic;
9+
10+
import org.elasticsearch.inference.InferenceServiceResults;
11+
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
12+
import org.elasticsearch.xpack.inference.external.http.HttpResult;
13+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
14+
import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedStreamingProcessor;
15+
import org.elasticsearch.xpack.inference.external.request.Request;
16+
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
17+
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;
18+
19+
import java.util.concurrent.Flow;
20+
21+
public class EISUnifiedChatCompletionResponseHandler extends ElasticInferenceServiceResponseHandler {
22+
public EISUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
23+
super(requestType, parseFunction);
24+
}
25+
26+
@Override
27+
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
28+
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
29+
var openAiProcessor = new OpenAiUnifiedStreamingProcessor(); // EIS uses the unified API spec
30+
31+
flow.subscribe(serverSentEventProcessor);
32+
serverSentEventProcessor.subscribe(openAiProcessor);
33+
return new StreamingUnifiedChatCompletionResults(openAiProcessor);
34+
}
35+
}
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.sender;
9+
10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.action.ActionListener;
13+
import org.elasticsearch.inference.InferenceServiceResults;
14+
import org.elasticsearch.threadpool.ThreadPool;
15+
import org.elasticsearch.xpack.inference.external.elastic.EISUnifiedChatCompletionResponseHandler;
16+
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
17+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
18+
import org.elasticsearch.xpack.inference.external.request.elastic.EISUnifiedChatCompletionRequest;
19+
import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity;
20+
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
21+
22+
import java.util.Objects;
23+
import java.util.function.Supplier;
24+
25+
public class EISUnifiedCompletionRequestManager extends ElasticInferenceServiceRequestManager {
26+
27+
private static final Logger logger = LogManager.getLogger(EISUnifiedCompletionRequestManager.class);
28+
29+
private static final ResponseHandler HANDLER = createCompletionHandler();
30+
31+
public static EISUnifiedCompletionRequestManager of(ElasticInferenceServiceCompletionModel model, ThreadPool threadPool) {
32+
return new EISUnifiedCompletionRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool));
33+
}
34+
35+
private final ElasticInferenceServiceCompletionModel model;
36+
37+
private EISUnifiedCompletionRequestManager(ElasticInferenceServiceCompletionModel model, ThreadPool threadPool) {
38+
super(threadPool, model);
39+
this.model = Objects.requireNonNull(model);
40+
}
41+
42+
@Override
43+
public void execute(
44+
InferenceInputs inferenceInputs,
45+
RequestSender requestSender,
46+
Supplier<Boolean> hasRequestCompletedFunction,
47+
ActionListener<InferenceServiceResults> listener
48+
) {
49+
50+
EISUnifiedChatCompletionRequest request = new EISUnifiedChatCompletionRequest(
51+
inferenceInputs.castTo(UnifiedChatInput.class),
52+
model,
53+
null // TODO
54+
);
55+
56+
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
57+
}
58+
59+
private static ResponseHandler createCompletionHandler() {
60+
return new EISUnifiedChatCompletionResponseHandler("eis completion", OpenAiChatCompletionResponseEntity::fromResponse);
61+
}
62+
}
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,107 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.request.elastic;
9+
10+
import org.apache.http.HttpHeaders;
11+
import org.apache.http.client.methods.HttpPost;
12+
import org.apache.http.entity.ByteArrayEntity;
13+
import org.apache.http.message.BasicHeader;
14+
import org.elasticsearch.common.Strings;
15+
import org.elasticsearch.tasks.Task;
16+
import org.elasticsearch.xcontent.XContentType;
17+
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
18+
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
19+
import org.elasticsearch.xpack.inference.external.request.Request;
20+
import org.elasticsearch.xpack.inference.external.request.openai.OpenAiRequest;
21+
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
22+
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
23+
24+
import java.net.URI;
25+
import java.nio.charset.StandardCharsets;
26+
import java.util.Objects;
27+
28+
public class EISUnifiedChatCompletionRequest implements OpenAiRequest {
29+
30+
private final ElasticInferenceServiceCompletionModel model;
31+
private final UnifiedChatInput unifiedChatInput;
32+
private final URI uri;
33+
private final TraceContext traceContext;
34+
35+
public EISUnifiedChatCompletionRequest(
36+
UnifiedChatInput unifiedChatInput,
37+
ElasticInferenceServiceCompletionModel model,
38+
TraceContext traceContext
39+
) {
40+
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
41+
this.model = Objects.requireNonNull(model);
42+
this.uri = model.uri();
43+
this.traceContext = traceContext;
44+
45+
}
46+
47+
@Override
48+
public HttpRequest createHttpRequest() {
49+
var httpPost = new HttpPost(uri);
50+
var requestEntity = Strings.toString(new EISUnifiedChatCompletionRequestEntity(unifiedChatInput));
51+
52+
ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
53+
httpPost.setEntity(byteEntity);
54+
55+
if (traceContext != null) {
56+
propagateTraceContext(httpPost);
57+
}
58+
59+
httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));
60+
61+
return new HttpRequest(httpPost, getInferenceEntityId());
62+
}
63+
64+
@Override
65+
public URI getURI() {
66+
return uri;
67+
}
68+
69+
@Override
70+
public Request truncate() {
71+
// No truncation for OpenAI chat completions
72+
return this;
73+
}
74+
75+
@Override
76+
public boolean[] getTruncationInfo() {
77+
// No truncation for OpenAI chat completions
78+
return null;
79+
}
80+
81+
@Override
82+
public String getInferenceEntityId() {
83+
return model.getInferenceEntityId();
84+
}
85+
86+
@Override
87+
public boolean isStreaming() {
88+
return true;
89+
}
90+
91+
public TraceContext getTraceContext() {
92+
return traceContext;
93+
}
94+
95+
private void propagateTraceContext(HttpPost httpPost) {
96+
var traceParent = traceContext.traceParent();
97+
var traceState = traceContext.traceState();
98+
99+
if (traceParent != null) {
100+
httpPost.setHeader(Task.TRACE_PARENT_HTTP_HEADER, traceParent);
101+
}
102+
103+
if (traceState != null) {
104+
httpPost.setHeader(Task.TRACE_STATE, traceState);
105+
}
106+
}
107+
}
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,176 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.request.elastic;
9+
10+
import org.elasticsearch.inference.UnifiedCompletionRequest;
11+
import org.elasticsearch.xcontent.ToXContentObject;
12+
import org.elasticsearch.xcontent.XContentBuilder;
13+
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
14+
15+
import java.io.IOException;
16+
import java.util.Objects;
17+
18+
public class EISUnifiedChatCompletionRequestEntity implements ToXContentObject {
19+
20+
public static final String NAME_FIELD = "name";
21+
public static final String TOOL_CALL_ID_FIELD = "tool_call_id";
22+
public static final String TOOL_CALLS_FIELD = "tool_calls";
23+
public static final String ID_FIELD = "id";
24+
public static final String FUNCTION_FIELD = "function";
25+
public static final String ARGUMENTS_FIELD = "arguments";
26+
public static final String DESCRIPTION_FIELD = "description";
27+
public static final String PARAMETERS_FIELD = "parameters";
28+
public static final String STRICT_FIELD = "strict";
29+
public static final String TOP_P_FIELD = "top_p";
30+
public static final String USER_FIELD = "user";
31+
public static final String STREAM_FIELD = "stream";
32+
private static final String NUMBER_OF_RETURNED_CHOICES_FIELD = "n";
33+
private static final String MODEL_FIELD = "model";
34+
public static final String MESSAGES_FIELD = "messages";
35+
private static final String ROLE_FIELD = "role";
36+
private static final String CONTENT_FIELD = "content";
37+
private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens";
38+
private static final String STOP_FIELD = "stop";
39+
private static final String TEMPERATURE_FIELD = "temperature";
40+
private static final String TOOL_CHOICE_FIELD = "tool_choice";
41+
private static final String TOOL_FIELD = "tools";
42+
private static final String TEXT_FIELD = "text";
43+
private static final String TYPE_FIELD = "type";
44+
private static final String STREAM_OPTIONS_FIELD = "stream_options";
45+
private static final String INCLUDE_USAGE_FIELD = "include_usage";
46+
47+
private final UnifiedCompletionRequest unifiedRequest;
48+
private final boolean stream;
49+
50+
public EISUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput) {
51+
Objects.requireNonNull(unifiedChatInput);
52+
53+
this.unifiedRequest = unifiedChatInput.getRequest();
54+
this.stream = unifiedChatInput.stream();
55+
}
56+
57+
@Override
58+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
59+
builder.startObject();
60+
builder.startArray(MESSAGES_FIELD);
61+
{
62+
for (UnifiedCompletionRequest.Message message : unifiedRequest.messages()) {
63+
builder.startObject();
64+
{
65+
switch (message.content()) {
66+
case UnifiedCompletionRequest.ContentString contentString -> builder.field(CONTENT_FIELD, contentString.content());
67+
case UnifiedCompletionRequest.ContentObjects contentObjects -> {
68+
builder.startArray(CONTENT_FIELD);
69+
for (UnifiedCompletionRequest.ContentObject contentObject : contentObjects.contentObjects()) {
70+
builder.startObject();
71+
builder.field(TEXT_FIELD, contentObject.text());
72+
builder.field(TYPE_FIELD, contentObject.type());
73+
builder.endObject();
74+
}
75+
builder.endArray();
76+
}
77+
}
78+
79+
builder.field(ROLE_FIELD, message.role());
80+
if (message.name() != null) {
81+
builder.field(NAME_FIELD, message.name());
82+
}
83+
if (message.toolCallId() != null) {
84+
builder.field(TOOL_CALL_ID_FIELD, message.toolCallId());
85+
}
86+
if (message.toolCalls() != null) {
87+
builder.startArray(TOOL_CALLS_FIELD);
88+
for (UnifiedCompletionRequest.ToolCall toolCall : message.toolCalls()) {
89+
builder.startObject();
90+
{
91+
builder.field(ID_FIELD, toolCall.id());
92+
builder.startObject(FUNCTION_FIELD);
93+
{
94+
builder.field(ARGUMENTS_FIELD, toolCall.function().arguments());
95+
builder.field(NAME_FIELD, toolCall.function().name());
96+
}
97+
builder.endObject();
98+
builder.field(TYPE_FIELD, toolCall.type());
99+
}
100+
builder.endObject();
101+
}
102+
builder.endArray();
103+
}
104+
}
105+
builder.endObject();
106+
}
107+
}
108+
builder.endArray();
109+
110+
if (unifiedRequest.maxCompletionTokens() != null) {
111+
builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedRequest.maxCompletionTokens());
112+
}
113+
114+
builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1);
115+
116+
if (unifiedRequest.stop() != null && unifiedRequest.stop().isEmpty() == false) {
117+
builder.field(STOP_FIELD, unifiedRequest.stop());
118+
}
119+
if (unifiedRequest.temperature() != null) {
120+
builder.field(TEMPERATURE_FIELD, unifiedRequest.temperature());
121+
}
122+
if (unifiedRequest.toolChoice() != null) {
123+
if (unifiedRequest.toolChoice() instanceof UnifiedCompletionRequest.ToolChoiceString) {
124+
builder.field(TOOL_CHOICE_FIELD, ((UnifiedCompletionRequest.ToolChoiceString) unifiedRequest.toolChoice()).value());
125+
} else if (unifiedRequest.toolChoice() instanceof UnifiedCompletionRequest.ToolChoiceObject) {
126+
builder.startObject(TOOL_CHOICE_FIELD);
127+
{
128+
builder.field(TYPE_FIELD, ((UnifiedCompletionRequest.ToolChoiceObject) unifiedRequest.toolChoice()).type());
129+
builder.startObject(FUNCTION_FIELD);
130+
{
131+
builder.field(
132+
NAME_FIELD,
133+
((UnifiedCompletionRequest.ToolChoiceObject) unifiedRequest.toolChoice()).function().name()
134+
);
135+
}
136+
builder.endObject();
137+
}
138+
builder.endObject();
139+
}
140+
}
141+
if (unifiedRequest.tools() != null && unifiedRequest.tools().isEmpty() == false) {
142+
builder.startArray(TOOL_FIELD);
143+
for (UnifiedCompletionRequest.Tool t : unifiedRequest.tools()) {
144+
builder.startObject();
145+
{
146+
builder.field(TYPE_FIELD, t.type());
147+
builder.startObject(FUNCTION_FIELD);
148+
{
149+
builder.field(DESCRIPTION_FIELD, t.function().description());
150+
builder.field(NAME_FIELD, t.function().name());
151+
builder.field(PARAMETERS_FIELD, t.function().parameters());
152+
if (t.function().strict() != null) {
153+
builder.field(STRICT_FIELD, t.function().strict());
154+
}
155+
}
156+
builder.endObject();
157+
}
158+
builder.endObject();
159+
}
160+
builder.endArray();
161+
}
162+
if (unifiedRequest.topP() != null) {
163+
builder.field(TOP_P_FIELD, unifiedRequest.topP());
164+
}
165+
166+
builder.field(STREAM_FIELD, stream);
167+
if (stream) {
168+
builder.startObject(STREAM_OPTIONS_FIELD);
169+
builder.field(INCLUDE_USAGE_FIELD, true);
170+
builder.endObject();
171+
}
172+
builder.endObject();
173+
174+
return builder;
175+
}
176+
}

0 commit comments

Comments
 (0)