Skip to content

Commit 7986c81

Browse files
Merge branch 'ml-inference-unified-api-elastic' of github.com:elastic/elasticsearch into ml-inference-unified-api-elastic
2 parents 1e0eb20 + d6cc223 commit 7986c81

File tree

10 files changed

+284
-123
lines changed

10 files changed

+284
-123
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequest.java

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import org.elasticsearch.xcontent.ConstructingObjectParser;
1717
import org.elasticsearch.xcontent.ObjectParser;
1818
import org.elasticsearch.xcontent.ParseField;
19+
import org.elasticsearch.xcontent.ToXContent;
20+
import org.elasticsearch.xcontent.XContentBuilder;
1921
import org.elasticsearch.xcontent.XContentParseException;
2022
import org.elasticsearch.xcontent.XContentParser;
2123

@@ -39,6 +41,8 @@ public record UnifiedCompletionRequest(
3941
@Nullable String user
4042
) implements Writeable {
4143

44+
public sealed interface Content extends NamedWriteable permits ContentObjects, ContentString {}
45+
4246
@SuppressWarnings("unchecked")
4347
static final ConstructingObjectParser<UnifiedCompletionRequest, Void> PARSER = new ConstructingObjectParser<>(
4448
InferenceAction.NAME,
@@ -153,8 +157,6 @@ public void writeTo(StreamOutput out) throws IOException {
153157
}
154158
}
155159

156-
public sealed interface Content extends NamedWriteable permits ContentObjects, ContentString {}
157-
158160
public record ContentObjects(List<ContentObject> contentObjects) implements Content, Writeable {
159161

160162
public static final String NAME = "content_objects";
@@ -194,6 +196,7 @@ public void writeTo(StreamOutput out) throws IOException {
194196
out.writeString(text);
195197
out.writeString(type);
196198
}
199+
197200
}
198201

199202
public record ContentString(String content) implements Content, NamedWriteable {
@@ -217,6 +220,10 @@ public void writeTo(StreamOutput out) throws IOException {
217220
public String getWriteableName() {
218221
return NAME;
219222
}
223+
224+
public void toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
225+
builder.value(content);
226+
}
220227
}
221228

222229
public record ToolCall(String id, FunctionField function, String type) implements Writeable {
@@ -432,7 +439,7 @@ public void writeTo(StreamOutput out) throws IOException {
432439
public record FunctionField(
433440
@Nullable String description,
434441
String name,
435-
@Nullable Map<String, Object> parameters,
442+
@Nullable Map<String, Object> parameters, // TODO can we parse this as a string?
436443
@Nullable Boolean strict
437444
) implements Writeable {
438445

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CompletionInputs.java

Lines changed: 0 additions & 32 deletions
This file was deleted.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
1616
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
1717
import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseHandler;
18-
import org.elasticsearch.xpack.inference.external.request.openai.OpenAiChatCompletionRequest;
18+
import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequest;
1919
import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity;
2020
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel;
2121

@@ -35,7 +35,7 @@ public static OpenAiCompletionRequestManager of(OpenAiChatCompletionModel model,
3535
private final OpenAiChatCompletionModel model;
3636

3737
private OpenAiCompletionRequestManager(OpenAiChatCompletionModel model, ThreadPool threadPool) {
38-
super(threadPool, model, OpenAiChatCompletionRequest::buildDefaultUri);
38+
super(threadPool, model, OpenAiUnifiedChatCompletionRequest::buildDefaultUri);
3939
this.model = Objects.requireNonNull(model);
4040
}
4141

@@ -46,10 +46,11 @@ public void execute(
4646
Supplier<Boolean> hasRequestCompletedFunction,
4747
ActionListener<InferenceServiceResults> listener
4848
) {
49-
var docsOnly = DocumentsOnlyInput.of(inferenceInputs);
50-
var docsInput = docsOnly.getInputs();
51-
var stream = docsOnly.stream();
52-
OpenAiChatCompletionRequest request = new OpenAiChatCompletionRequest(docsInput, model, stream);
49+
50+
OpenAiUnifiedChatCompletionRequest request = new OpenAiUnifiedChatCompletionRequest(
51+
UnifiedChatInput.of(inferenceInputs).getRequestEntity(),
52+
model
53+
);
5354

5455
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
5556
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiUnifiedCompletionRequestManager.java

Lines changed: 0 additions & 61 deletions
This file was deleted.
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.sender;
9+
10+
import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequestEntity;
11+
12+
import java.util.Objects;
13+
14+
public class UnifiedChatInput extends InferenceInputs {
15+
16+
public static UnifiedChatInput of(InferenceInputs inferenceInputs) {
17+
18+
if (inferenceInputs instanceof DocumentsOnlyInput docsOnly) {
19+
return new UnifiedChatInput(new OpenAiUnifiedChatCompletionRequestEntity(docsOnly));
20+
} else if (inferenceInputs instanceof UnifiedChatInput == false) {
21+
throw createUnsupportedTypeException(inferenceInputs);
22+
}
23+
24+
return (UnifiedChatInput) inferenceInputs;
25+
}
26+
27+
public OpenAiUnifiedChatCompletionRequestEntity getRequestEntity() {
28+
return requestEntity;
29+
}
30+
31+
private final OpenAiUnifiedChatCompletionRequestEntity requestEntity;
32+
33+
public UnifiedChatInput(OpenAiUnifiedChatCompletionRequestEntity requestEntity) {
34+
this.requestEntity = Objects.requireNonNull(requestEntity);
35+
}
36+
37+
public boolean stream() {
38+
return requestEntity.isStream();
39+
}
40+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.request;
9+
10+
import org.elasticsearch.core.Nullable;
11+
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionRequest;
12+
13+
import java.util.List;
14+
15+
public record UnifiedRequest(
16+
List<UnifiedCompletionRequest.Message> messages,
17+
@Nullable String model,
18+
@Nullable Long maxCompletionTokens,
19+
@Nullable Integer n,
20+
@Nullable UnifiedCompletionRequest.Stop stop,
21+
@Nullable Float temperature,
22+
@Nullable UnifiedCompletionRequest.ToolChoice toolChoice,
23+
@Nullable List<UnifiedCompletionRequest.Tool> tool,
24+
@Nullable Float topP,
25+
@Nullable String user,
26+
boolean stream
27+
) {}
Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,35 +21,28 @@
2121
import java.net.URI;
2222
import java.net.URISyntaxException;
2323
import java.nio.charset.StandardCharsets;
24-
import java.util.List;
2524
import java.util.Objects;
2625

2726
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
2827
import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.createOrgHeader;
2928

30-
public class OpenAiChatCompletionRequest implements OpenAiRequest {
29+
public class OpenAiUnifiedChatCompletionRequest implements OpenAiRequest {
3130

3231
private final OpenAiAccount account;
33-
private final List<String> input;
32+
private final OpenAiUnifiedChatCompletionRequestEntity requestEntity;
3433
private final OpenAiChatCompletionModel model;
35-
private final boolean stream;
3634

37-
public OpenAiChatCompletionRequest(List<String> input, OpenAiChatCompletionModel model, boolean stream) {
38-
this.account = OpenAiAccount.of(model, OpenAiChatCompletionRequest::buildDefaultUri);
39-
this.input = Objects.requireNonNull(input);
35+
public OpenAiUnifiedChatCompletionRequest(OpenAiUnifiedChatCompletionRequestEntity requestEntity, OpenAiChatCompletionModel model) {
36+
this.account = OpenAiAccount.of(model, OpenAiUnifiedChatCompletionRequest::buildDefaultUri);
37+
this.requestEntity = Objects.requireNonNull(requestEntity);
4038
this.model = Objects.requireNonNull(model);
41-
this.stream = stream;
4239
}
4340

4441
@Override
4542
public HttpRequest createHttpRequest() {
4643
HttpPost httpPost = new HttpPost(account.uri());
4744

48-
ByteArrayEntity byteEntity = new ByteArrayEntity(
49-
Strings.toString(
50-
new OpenAiChatCompletionRequestEntity(input, model.getServiceSettings().modelId(), model.getTaskSettings().user(), stream)
51-
).getBytes(StandardCharsets.UTF_8)
52-
);
45+
ByteArrayEntity byteEntity = new ByteArrayEntity(Strings.toString(requestEntity).getBytes(StandardCharsets.UTF_8));
5346
httpPost.setEntity(byteEntity);
5447

5548
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
@@ -87,7 +80,7 @@ public String getInferenceEntityId() {
8780

8881
@Override
8982
public boolean isStreaming() {
90-
return stream;
83+
return requestEntity.isStream();
9184
}
9285

9386
public static URI buildDefaultUri() throws URISyntaxException {

0 commit comments

Comments
 (0)