Skip to content

Commit 39e2c27

Browse files
Working response from openai
1 parent 69ba46d commit 39e2c27

15 files changed

+1147
-1988
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/EISUnifiedChatCompletionResponseHandler.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ public EISUnifiedChatCompletionResponseHandler(String requestType, ResponseParse
2323
super(requestType, parseFunction);
2424
}
2525

26+
@Override
27+
public boolean canHandleStreamingResponses() {
28+
return true;
29+
}
30+
2631
@Override
2732
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
2833
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EISUnifiedCompletionRequestManager.java

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import org.elasticsearch.xpack.inference.external.request.elastic.EISUnifiedChatCompletionRequest;
1919
import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity;
2020
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
21+
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
2122

2223
import java.util.Objects;
2324
import java.util.function.Supplier;
@@ -28,15 +29,29 @@ public class EISUnifiedCompletionRequestManager extends ElasticInferenceServiceR
2829

2930
private static final ResponseHandler HANDLER = createCompletionHandler();
3031

31-
public static EISUnifiedCompletionRequestManager of(ElasticInferenceServiceCompletionModel model, ThreadPool threadPool) {
32-
return new EISUnifiedCompletionRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool));
32+
public static EISUnifiedCompletionRequestManager of(
33+
ElasticInferenceServiceCompletionModel model,
34+
ThreadPool threadPool,
35+
TraceContext traceContext
36+
) {
37+
return new EISUnifiedCompletionRequestManager(
38+
Objects.requireNonNull(model),
39+
Objects.requireNonNull(threadPool),
40+
Objects.requireNonNull(traceContext)
41+
);
3342
}
3443

3544
private final ElasticInferenceServiceCompletionModel model;
45+
private final TraceContext traceContext;
3646

37-
private EISUnifiedCompletionRequestManager(ElasticInferenceServiceCompletionModel model, ThreadPool threadPool) {
47+
private EISUnifiedCompletionRequestManager(
48+
ElasticInferenceServiceCompletionModel model,
49+
ThreadPool threadPool,
50+
TraceContext traceContext
51+
) {
3852
super(threadPool, model);
39-
this.model = Objects.requireNonNull(model);
53+
this.model = model;
54+
this.traceContext = traceContext;
4055
}
4156

4257
@Override
@@ -50,7 +65,7 @@ public void execute(
5065
EISUnifiedChatCompletionRequest request = new EISUnifiedChatCompletionRequest(
5166
inferenceInputs.castTo(UnifiedChatInput.class),
5267
model,
53-
null // TODO
68+
traceContext
5469
);
5570

5671
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequest.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
import java.nio.charset.StandardCharsets;
2626
import java.util.Objects;
2727

28+
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
29+
2830
public class EISUnifiedChatCompletionRequest implements OpenAiRequest {
2931

3032
private final ElasticInferenceServiceCompletionModel model;
@@ -47,7 +49,10 @@ public EISUnifiedChatCompletionRequest(
4749
@Override
4850
public HttpRequest createHttpRequest() {
4951
var httpPost = new HttpPost(uri);
50-
var requestEntity = Strings.toString(new EISUnifiedChatCompletionRequestEntity(unifiedChatInput));
52+
var requestEntity = Strings.toString(
53+
// TODO remove the modelId() call if not used
54+
new EISUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId())
55+
);
5156

5257
ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
5358
httpPost.setEntity(byteEntity);
@@ -57,6 +62,8 @@ public HttpRequest createHttpRequest() {
5762
}
5863

5964
httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));
65+
// TODO remove EIS doesn't use an API key
66+
httpPost.setHeader(createAuthBearerHeader(model.getSecretSettings().apiKey()));
6067

6168
return new HttpRequest(httpPost, getInferenceEntityId());
6269
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequestEntity.java

Lines changed: 10 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -7,168 +7,32 @@
77

88
package org.elasticsearch.xpack.inference.external.request.elastic;
99

10-
import org.elasticsearch.inference.UnifiedCompletionRequest;
1110
import org.elasticsearch.xcontent.ToXContentObject;
1211
import org.elasticsearch.xcontent.XContentBuilder;
1312
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
13+
import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity;
1414

1515
import java.io.IOException;
1616
import java.util.Objects;
1717

1818
public class EISUnifiedChatCompletionRequestEntity implements ToXContentObject {
19-
20-
public static final String NAME_FIELD = "name";
21-
public static final String TOOL_CALL_ID_FIELD = "tool_call_id";
22-
public static final String TOOL_CALLS_FIELD = "tool_calls";
23-
public static final String ID_FIELD = "id";
24-
public static final String FUNCTION_FIELD = "function";
25-
public static final String ARGUMENTS_FIELD = "arguments";
26-
public static final String DESCRIPTION_FIELD = "description";
27-
public static final String PARAMETERS_FIELD = "parameters";
28-
public static final String STRICT_FIELD = "strict";
29-
public static final String TOP_P_FIELD = "top_p";
30-
public static final String USER_FIELD = "user";
31-
public static final String STREAM_FIELD = "stream";
32-
private static final String NUMBER_OF_RETURNED_CHOICES_FIELD = "n";
19+
// TODO remove this if EIS doesn't use it
3320
private static final String MODEL_FIELD = "model";
34-
public static final String MESSAGES_FIELD = "messages";
35-
private static final String ROLE_FIELD = "role";
36-
private static final String CONTENT_FIELD = "content";
37-
private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens";
38-
private static final String STOP_FIELD = "stop";
39-
private static final String TEMPERATURE_FIELD = "temperature";
40-
private static final String TOOL_CHOICE_FIELD = "tool_choice";
41-
private static final String TOOL_FIELD = "tools";
42-
private static final String TEXT_FIELD = "text";
43-
private static final String TYPE_FIELD = "type";
44-
private static final String STREAM_OPTIONS_FIELD = "stream_options";
45-
private static final String INCLUDE_USAGE_FIELD = "include_usage";
46-
47-
private final UnifiedCompletionRequest unifiedRequest;
48-
private final boolean stream;
4921

50-
public EISUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput) {
51-
Objects.requireNonNull(unifiedChatInput);
22+
private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;
23+
private final String modelId;
5224

53-
this.unifiedRequest = unifiedChatInput.getRequest();
54-
this.stream = unifiedChatInput.stream();
25+
public EISUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, String modelId) {
26+
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(Objects.requireNonNull(unifiedChatInput));
27+
this.modelId = Objects.requireNonNull(modelId);
5528
}
5629

5730
@Override
5831
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
5932
builder.startObject();
60-
builder.startArray(MESSAGES_FIELD);
61-
{
62-
for (UnifiedCompletionRequest.Message message : unifiedRequest.messages()) {
63-
builder.startObject();
64-
{
65-
switch (message.content()) {
66-
case UnifiedCompletionRequest.ContentString contentString -> builder.field(CONTENT_FIELD, contentString.content());
67-
case UnifiedCompletionRequest.ContentObjects contentObjects -> {
68-
builder.startArray(CONTENT_FIELD);
69-
for (UnifiedCompletionRequest.ContentObject contentObject : contentObjects.contentObjects()) {
70-
builder.startObject();
71-
builder.field(TEXT_FIELD, contentObject.text());
72-
builder.field(TYPE_FIELD, contentObject.type());
73-
builder.endObject();
74-
}
75-
builder.endArray();
76-
}
77-
}
78-
79-
builder.field(ROLE_FIELD, message.role());
80-
if (message.name() != null) {
81-
builder.field(NAME_FIELD, message.name());
82-
}
83-
if (message.toolCallId() != null) {
84-
builder.field(TOOL_CALL_ID_FIELD, message.toolCallId());
85-
}
86-
if (message.toolCalls() != null) {
87-
builder.startArray(TOOL_CALLS_FIELD);
88-
for (UnifiedCompletionRequest.ToolCall toolCall : message.toolCalls()) {
89-
builder.startObject();
90-
{
91-
builder.field(ID_FIELD, toolCall.id());
92-
builder.startObject(FUNCTION_FIELD);
93-
{
94-
builder.field(ARGUMENTS_FIELD, toolCall.function().arguments());
95-
builder.field(NAME_FIELD, toolCall.function().name());
96-
}
97-
builder.endObject();
98-
builder.field(TYPE_FIELD, toolCall.type());
99-
}
100-
builder.endObject();
101-
}
102-
builder.endArray();
103-
}
104-
}
105-
builder.endObject();
106-
}
107-
}
108-
builder.endArray();
109-
110-
if (unifiedRequest.maxCompletionTokens() != null) {
111-
builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedRequest.maxCompletionTokens());
112-
}
113-
114-
builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1);
115-
116-
if (unifiedRequest.stop() != null && unifiedRequest.stop().isEmpty() == false) {
117-
builder.field(STOP_FIELD, unifiedRequest.stop());
118-
}
119-
if (unifiedRequest.temperature() != null) {
120-
builder.field(TEMPERATURE_FIELD, unifiedRequest.temperature());
121-
}
122-
if (unifiedRequest.toolChoice() != null) {
123-
if (unifiedRequest.toolChoice() instanceof UnifiedCompletionRequest.ToolChoiceString) {
124-
builder.field(TOOL_CHOICE_FIELD, ((UnifiedCompletionRequest.ToolChoiceString) unifiedRequest.toolChoice()).value());
125-
} else if (unifiedRequest.toolChoice() instanceof UnifiedCompletionRequest.ToolChoiceObject) {
126-
builder.startObject(TOOL_CHOICE_FIELD);
127-
{
128-
builder.field(TYPE_FIELD, ((UnifiedCompletionRequest.ToolChoiceObject) unifiedRequest.toolChoice()).type());
129-
builder.startObject(FUNCTION_FIELD);
130-
{
131-
builder.field(
132-
NAME_FIELD,
133-
((UnifiedCompletionRequest.ToolChoiceObject) unifiedRequest.toolChoice()).function().name()
134-
);
135-
}
136-
builder.endObject();
137-
}
138-
builder.endObject();
139-
}
140-
}
141-
if (unifiedRequest.tools() != null && unifiedRequest.tools().isEmpty() == false) {
142-
builder.startArray(TOOL_FIELD);
143-
for (UnifiedCompletionRequest.Tool t : unifiedRequest.tools()) {
144-
builder.startObject();
145-
{
146-
builder.field(TYPE_FIELD, t.type());
147-
builder.startObject(FUNCTION_FIELD);
148-
{
149-
builder.field(DESCRIPTION_FIELD, t.function().description());
150-
builder.field(NAME_FIELD, t.function().name());
151-
builder.field(PARAMETERS_FIELD, t.function().parameters());
152-
if (t.function().strict() != null) {
153-
builder.field(STRICT_FIELD, t.function().strict());
154-
}
155-
}
156-
builder.endObject();
157-
}
158-
builder.endObject();
159-
}
160-
builder.endArray();
161-
}
162-
if (unifiedRequest.topP() != null) {
163-
builder.field(TOP_P_FIELD, unifiedRequest.topP());
164-
}
165-
166-
builder.field(STREAM_FIELD, stream);
167-
if (stream) {
168-
builder.startObject(STREAM_OPTIONS_FIELD);
169-
builder.field(INCLUDE_USAGE_FIELD, true);
170-
builder.endObject();
171-
}
33+
unifiedRequestEntity.toXContent(builder, params);
34+
// TODO remove this if EIS doesn't use it
35+
builder.field(MODEL_FIELD, modelId);
17236
builder.endObject();
17337

17438
return builder;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,7 @@ public HttpRequest createHttpRequest() {
4444
HttpPost httpPost = new HttpPost(account.uri());
4545

4646
ByteArrayEntity byteEntity = new ByteArrayEntity(
47-
Strings.toString(
48-
new OpenAiUnifiedChatCompletionRequestEntity(
49-
unifiedChatInput,
50-
new OpenAiUnifiedChatCompletionRequestEntity.ModelFields(
51-
model.getServiceSettings().modelId(),
52-
model.getTaskSettings().user()
53-
)
54-
)
55-
).getBytes(StandardCharsets.UTF_8)
47+
Strings.toString(new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model)).getBytes(StandardCharsets.UTF_8)
5648
);
5749
httpPost.setEntity(byteEntity);
5850

0 commit comments

Comments
 (0)