Skip to content

Commit 2846942

Browse files
author
Max Hniebergall
committed
Add outbound request writing (WIP)
1 parent 1e30c6d commit 2846942

File tree

6 files changed

+284
-84
lines changed

6 files changed

+284
-84
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequest.java

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import org.elasticsearch.xcontent.ConstructingObjectParser;
1717
import org.elasticsearch.xcontent.ObjectParser;
1818
import org.elasticsearch.xcontent.ParseField;
19+
import org.elasticsearch.xcontent.ToXContent;
20+
import org.elasticsearch.xcontent.XContentBuilder;
1921
import org.elasticsearch.xcontent.XContentParseException;
2022
import org.elasticsearch.xcontent.XContentParser;
2123

@@ -40,6 +42,10 @@ public record UnifiedCompletionRequest(
4042
@Nullable String user
4143
) implements Writeable {
4244

45+
public sealed interface Content extends NamedWriteable permits ContentObjects, ContentString {
46+
void toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException;
47+
}
48+
4349
@SuppressWarnings("unchecked")
4450
static final ConstructingObjectParser<UnifiedCompletionRequest, Void> PARSER = new ConstructingObjectParser<>(
4551
InferenceAction.NAME,
@@ -158,8 +164,6 @@ public void writeTo(StreamOutput out) throws IOException {
158164
}
159165
}
160166

161-
public sealed interface Content extends NamedWriteable permits ContentObjects, ContentString {}
162-
163167
public record ContentObjects(List<ContentObject> contentObjects) implements Content, Writeable {
164168

165169
public static final String NAME = "content_objects";
@@ -173,6 +177,17 @@ public void writeTo(StreamOutput out) throws IOException {
173177
out.writeCollection(contentObjects);
174178
}
175179

180+
@Override
181+
public void toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
182+
builder.startArray();
183+
for (ContentObject contentObject : contentObjects) {
184+
builder.startObject();
185+
contentObject.toXContentObject(builder, params);
186+
builder.endObject();
187+
}
188+
builder.endArray();
189+
}
190+
176191
@Override
177192
public String getWriteableName() {
178193
return NAME;
@@ -199,6 +214,14 @@ public void writeTo(StreamOutput out) throws IOException {
199214
out.writeString(text);
200215
out.writeString(type);
201216
}
217+
218+
public XContentBuilder toXContentObject(XContentBuilder builder, ToXContent.Params params) throws IOException {
219+
builder.startObject();
220+
builder.field("text", text);
221+
builder.field("type", type);
222+
builder.endObject();
223+
return builder;
224+
}
202225
}
203226

204227
public record ContentString(String content) implements Content, NamedWriteable {
@@ -222,6 +245,10 @@ public void writeTo(StreamOutput out) throws IOException {
222245
public String getWriteableName() {
223246
return NAME;
224247
}
248+
249+
public void toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
250+
builder.value(content);
251+
}
225252
}
226253

227254
public record ToolCall(String id, FunctionField function, String type) implements Writeable {
@@ -437,7 +464,7 @@ public void writeTo(StreamOutput out) throws IOException {
437464
public record FunctionField(
438465
@Nullable String description,
439466
String name,
440-
@Nullable Map<String, Object> parameters,
467+
@Nullable Map<String, Object> parameters, // TODO can we parse this as a string?
441468
@Nullable Boolean strict
442469
) implements Writeable {
443470

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
1616
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
1717
import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseHandler;
18-
import org.elasticsearch.xpack.inference.external.request.openai.OpenAiChatCompletionRequest;
18+
import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequest;
1919
import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity;
2020
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel;
2121

@@ -35,7 +35,7 @@ public static OpenAiCompletionRequestManager of(OpenAiChatCompletionModel model,
3535
private final OpenAiChatCompletionModel model;
3636

3737
private OpenAiCompletionRequestManager(OpenAiChatCompletionModel model, ThreadPool threadPool) {
38-
super(threadPool, model, OpenAiChatCompletionRequest::buildDefaultUri);
38+
super(threadPool, model, OpenAiUnifiedChatCompletionRequest::buildDefaultUri);
3939
this.model = Objects.requireNonNull(model);
4040
}
4141

@@ -46,10 +46,11 @@ public void execute(
4646
Supplier<Boolean> hasRequestCompletedFunction,
4747
ActionListener<InferenceServiceResults> listener
4848
) {
49-
var docsOnly = DocumentsOnlyInput.of(inferenceInputs);
50-
var docsInput = docsOnly.getInputs();
51-
var stream = docsOnly.stream();
52-
OpenAiChatCompletionRequest request = new OpenAiChatCompletionRequest(docsInput, model, stream);
49+
50+
OpenAiUnifiedChatCompletionRequest request = new OpenAiUnifiedChatCompletionRequest(
51+
UnifiedChatInput.of(inferenceInputs).getRequestEntity(),
52+
model
53+
);
5354

5455
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
5556
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiUnifiedCompletionRequestManager.java

Lines changed: 0 additions & 61 deletions
This file was deleted.
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.sender;
9+
10+
import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequestEntity;
11+
12+
import java.util.Objects;
13+
14+
public class UnifiedChatInput extends InferenceInputs {
15+
16+
public static UnifiedChatInput of(InferenceInputs inferenceInputs) {
17+
18+
if (inferenceInputs instanceof DocumentsOnlyInput docsOnly) {
19+
return new UnifiedChatInput(new OpenAiUnifiedChatCompletionRequestEntity(docsOnly));
20+
} else if (inferenceInputs instanceof UnifiedChatInput == false) {
21+
throw createUnsupportedTypeException(inferenceInputs);
22+
}
23+
24+
return (UnifiedChatInput) inferenceInputs;
25+
}
26+
27+
public OpenAiUnifiedChatCompletionRequestEntity getRequestEntity() {
28+
return requestEntity;
29+
}
30+
31+
private final OpenAiUnifiedChatCompletionRequestEntity requestEntity;
32+
33+
public UnifiedChatInput(OpenAiUnifiedChatCompletionRequestEntity requestEntity) {
34+
this.requestEntity = Objects.requireNonNull(requestEntity);
35+
}
36+
37+
public boolean stream() {
38+
return requestEntity.isStream();
39+
}
40+
}
Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,35 +21,28 @@
2121
import java.net.URI;
2222
import java.net.URISyntaxException;
2323
import java.nio.charset.StandardCharsets;
24-
import java.util.List;
2524
import java.util.Objects;
2625

2726
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
2827
import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.createOrgHeader;
2928

30-
public class OpenAiChatCompletionRequest implements OpenAiRequest {
29+
public class OpenAiUnifiedChatCompletionRequest implements OpenAiRequest {
3130

3231
private final OpenAiAccount account;
33-
private final List<String> input;
32+
private final OpenAiUnifiedChatCompletionRequestEntity requestEntity;
3433
private final OpenAiChatCompletionModel model;
35-
private final boolean stream;
3634

37-
public OpenAiChatCompletionRequest(List<String> input, OpenAiChatCompletionModel model, boolean stream) {
38-
this.account = OpenAiAccount.of(model, OpenAiChatCompletionRequest::buildDefaultUri);
39-
this.input = Objects.requireNonNull(input);
35+
public OpenAiUnifiedChatCompletionRequest(OpenAiUnifiedChatCompletionRequestEntity requestEntity, OpenAiChatCompletionModel model) {
36+
this.account = OpenAiAccount.of(model, OpenAiUnifiedChatCompletionRequest::buildDefaultUri);
37+
this.requestEntity = Objects.requireNonNull(requestEntity);
4038
this.model = Objects.requireNonNull(model);
41-
this.stream = stream;
4239
}
4340

4441
@Override
4542
public HttpRequest createHttpRequest() {
4643
HttpPost httpPost = new HttpPost(account.uri());
4744

48-
ByteArrayEntity byteEntity = new ByteArrayEntity(
49-
Strings.toString(
50-
new OpenAiChatCompletionRequestEntity(input, model.getServiceSettings().modelId(), model.getTaskSettings().user(), stream)
51-
).getBytes(StandardCharsets.UTF_8)
52-
);
45+
ByteArrayEntity byteEntity = new ByteArrayEntity(Strings.toString(requestEntity).getBytes(StandardCharsets.UTF_8));
5346
httpPost.setEntity(byteEntity);
5447

5548
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
@@ -87,7 +80,7 @@ public String getInferenceEntityId() {
8780

8881
@Override
8982
public boolean isStreaming() {
90-
return stream;
83+
return requestEntity.isStream();
9184
}
9285

9386
public static URI buildDefaultUri() throws URISyntaxException {

0 commit comments

Comments
 (0)