Skip to content

Commit d6cc223

Browse files
author
Max Hniebergall
committed
separate out unified request and combine inputs
1 parent 9cb401c commit d6cc223

File tree

5 files changed

+106
-134
lines changed

5 files changed

+106
-134
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CompletionInputs.java

Lines changed: 0 additions & 32 deletions
This file was deleted.
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.request;
9+
10+
import org.elasticsearch.core.Nullable;
11+
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionRequest;
12+
13+
import java.util.List;
14+
15+
public record UnifiedRequest(
16+
List<UnifiedCompletionRequest.Message> messages,
17+
@Nullable String model,
18+
@Nullable Long maxCompletionTokens,
19+
@Nullable Integer n,
20+
@Nullable UnifiedCompletionRequest.Stop stop,
21+
@Nullable Float temperature,
22+
@Nullable UnifiedCompletionRequest.ToolChoice toolChoice,
23+
@Nullable List<UnifiedCompletionRequest.Tool> tool,
24+
@Nullable Float topP,
25+
@Nullable String user,
26+
boolean stream
27+
) {}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java

Lines changed: 66 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -7,132 +7,99 @@
77

88
package org.elasticsearch.xpack.inference.external.request.openai;
99

10-
import org.elasticsearch.common.Strings;
11-
import org.elasticsearch.core.Nullable;
1210
import org.elasticsearch.xcontent.ToXContentObject;
1311
import org.elasticsearch.xcontent.XContentBuilder;
1412
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionRequest;
1513
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
14+
import org.elasticsearch.xpack.inference.external.request.UnifiedRequest;
1615

1716
import java.io.IOException;
1817
import java.util.List;
19-
import java.util.Objects;
2018

2119
public class OpenAiUnifiedChatCompletionRequestEntity implements ToXContentObject {
2220

23-
private static final String MESSAGES_FIELD = "messages";
24-
private static final String MODEL_FIELD = "model";
25-
21+
public static final String NAME_FIELD = "name";
22+
public static final String TOOL_CALL_ID_FIELD = "tool_call_id";
23+
public static final String TOOL_CALLS_FIELD = "tool_calls";
24+
public static final String ID_FIELD = "id";
25+
public static final String FUNCTION_FIELD = "function";
26+
public static final String ARGUMENTS_FIELD = "arguments";
27+
public static final String DESCRIPTION_FIELD = "description";
28+
public static final String PARAMETERS_FIELD = "parameters";
29+
public static final String STRICT_FIELD = "strict";
30+
public static final String TOP_P_FIELD = "top_p";
31+
public static final String USER_FIELD = "user";
32+
public static final String STREAM_FIELD = "stream";
2633
private static final String NUMBER_OF_RETURNED_CHOICES_FIELD = "n";
27-
34+
private static final String MODEL_FIELD = "model";
35+
public static final String MESSAGES_FIELD = "messages";
2836
private static final String ROLE_FIELD = "role";
29-
private static final String USER_FIELD = "user";
3037
private static final String CONTENT_FIELD = "content";
31-
private static final String STREAM_FIELD = "stream";
3238
private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens";
3339
private static final String STOP_FIELD = "stop";
3440
private static final String TEMPERATURE_FIELD = "temperature";
3541
private static final String TOOL_CHOICE_FIELD = "tool_choice";
3642
private static final String TOOL_FIELD = "tool";
37-
private static final String TOP_P_FIELD = "top_p";
43+
private static final String TEXT_FIELD = "text";
44+
private static final String TYPE_FIELD = "type";
3845

39-
private final String user;
46+
private final UnifiedRequest unifiedRequest;
4047

41-
public boolean isStream() {
42-
return stream;
48+
public OpenAiUnifiedChatCompletionRequestEntity(UnifiedRequest unifiedRequest) {
49+
this.unifiedRequest = unifiedRequest;
4350
}
4451

45-
private final boolean stream;
46-
private final Long maxCompletionTokens;
47-
private final Integer n;
48-
private final UnifiedCompletionRequest.Stop stop;
49-
private final Float temperature;
50-
private final UnifiedCompletionRequest.ToolChoice toolChoice;
51-
private final List<UnifiedCompletionRequest.Tool> tool;
52-
private final Float topP;
53-
private final List<UnifiedCompletionRequest.Message> messages;
54-
private final String model;
55-
5652
public OpenAiUnifiedChatCompletionRequestEntity(DocumentsOnlyInput input) {
57-
this(convertDocumentsOnlyInputToMessages(input), null, null, null, null, null, null, null, null, null);
53+
this(new UnifiedRequest(convertDocumentsOnlyInputToMessages(input), null, null, null, null, null, null, null, null, null, true));
5854
}
5955

6056
private static List<UnifiedCompletionRequest.Message> convertDocumentsOnlyInputToMessages(DocumentsOnlyInput input) {
6157
return input.getInputs()
6258
.stream()
63-
.map(doc -> new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString(doc), "user", null, null, null))
59+
.map(doc -> new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString(doc), USER_FIELD, null, null, null))
6460
.toList();
6561
}
6662

67-
public OpenAiUnifiedChatCompletionRequestEntity(
68-
List<UnifiedCompletionRequest.Message> messages,
69-
@Nullable String model,
70-
@Nullable Long maxCompletionTokens,
71-
@Nullable Integer n,
72-
@Nullable UnifiedCompletionRequest.Stop stop,
73-
@Nullable Float temperature,
74-
@Nullable UnifiedCompletionRequest.ToolChoice toolChoice,
75-
@Nullable List<UnifiedCompletionRequest.Tool> tool,
76-
@Nullable Float topP,
77-
@Nullable String user
78-
) {
79-
Objects.requireNonNull(messages);
80-
Objects.requireNonNull(model);
81-
82-
this.user = user;
83-
this.stream = true; // always stream in unified API
84-
this.maxCompletionTokens = maxCompletionTokens;
85-
this.n = n;
86-
this.stop = stop;
87-
this.temperature = temperature;
88-
this.toolChoice = toolChoice;
89-
this.tool = tool;
90-
this.topP = topP;
91-
this.messages = messages;
92-
this.model = model;
93-
94-
}
95-
9663
@Override
9764
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
9865
builder.startObject();
9966
builder.startArray(MESSAGES_FIELD);
10067
{
101-
for (UnifiedCompletionRequest.Message message : messages) {
68+
for (UnifiedCompletionRequest.Message message : unifiedRequest.messages()) {
10269
builder.startObject();
10370
{
10471
switch (message.content()) {
10572
case UnifiedCompletionRequest.ContentString contentString -> builder.field(CONTENT_FIELD, contentString.content());
10673
case UnifiedCompletionRequest.ContentObjects contentObjects -> {
10774
for (UnifiedCompletionRequest.ContentObject contentObject : contentObjects.contentObjects()) {
10875
builder.startObject(CONTENT_FIELD);
109-
builder.field("text", contentObject.text());
110-
builder.field("type", contentObject.type());
76+
builder.field(TEXT_FIELD, contentObject.text());
77+
builder.field(TYPE_FIELD, contentObject.type());
11178
builder.endObject();
11279
}
11380
}
11481
}
11582

11683
builder.field(ROLE_FIELD, message.role());
11784
if (message.name() != null) {
118-
builder.field("name", message.name());
85+
builder.field(NAME_FIELD, message.name());
11986
}
12087
if (message.toolCallId() != null) {
121-
builder.field("tool_call_id", message.toolCallId());
88+
builder.field(TOOL_CALL_ID_FIELD, message.toolCallId());
12289
}
12390
if (message.toolCalls() != null) {
124-
builder.startArray("tool_calls");
91+
builder.startArray(TOOL_CALLS_FIELD);
12592
for (UnifiedCompletionRequest.ToolCall toolCall : message.toolCalls()) {
12693
builder.startObject();
12794
{
128-
builder.field("id", toolCall.id());
129-
builder.startObject("function");
95+
builder.field(ID_FIELD, toolCall.id());
96+
builder.startObject(FUNCTION_FIELD);
13097
{
131-
builder.field("arguments", toolCall.function().arguments());
132-
builder.field("name", toolCall.function().name());
98+
builder.field(ARGUMENTS_FIELD, toolCall.function().arguments());
99+
builder.field(NAME_FIELD, toolCall.function().name());
133100
}
134101
builder.endObject();
135-
builder.field("type", toolCall.type());
102+
builder.field(TYPE_FIELD, toolCall.type());
136103
}
137104
builder.endObject();
138105
}
@@ -144,65 +111,69 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
144111
}
145112
builder.endArray();
146113

147-
if (model != null) {
148-
builder.field(MODEL_FIELD, model);
114+
if (unifiedRequest.model() != null) {
115+
builder.field(MODEL_FIELD, unifiedRequest.model());
149116
}
150-
if (maxCompletionTokens != null) {
151-
builder.field(MAX_COMPLETION_TOKENS_FIELD, maxCompletionTokens);
117+
if (unifiedRequest.maxCompletionTokens() != null) {
118+
builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedRequest.maxCompletionTokens());
152119
}
153-
if (n != null) {
154-
builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, n);
120+
if (unifiedRequest.n() != null) {
121+
builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, unifiedRequest.n());
155122
}
156-
if (stop != null) {
157-
switch (stop) {
123+
if (unifiedRequest.stop() != null) {
124+
switch (unifiedRequest.stop()) {
158125
case UnifiedCompletionRequest.StopString stopString -> builder.field(STOP_FIELD, stopString.value());
159126
case UnifiedCompletionRequest.StopValues stopValues -> builder.field(STOP_FIELD, stopValues.values());
160127
}
161128
}
162-
if (temperature != null) {
163-
builder.field(TEMPERATURE_FIELD, temperature);
129+
if (unifiedRequest.temperature() != null) {
130+
builder.field(TEMPERATURE_FIELD, unifiedRequest.temperature());
164131
}
165-
if (toolChoice != null) {
166-
if (toolChoice instanceof UnifiedCompletionRequest.ToolChoiceString) {
167-
builder.field(TOOL_CHOICE_FIELD, ((UnifiedCompletionRequest.ToolChoiceString) toolChoice).value());
168-
} else if (toolChoice instanceof UnifiedCompletionRequest.ToolChoiceObject) {
132+
if (unifiedRequest.toolChoice() != null) {
133+
if (unifiedRequest.toolChoice() instanceof UnifiedCompletionRequest.ToolChoiceString) {
134+
builder.field(TOOL_CHOICE_FIELD, ((UnifiedCompletionRequest.ToolChoiceString) unifiedRequest.toolChoice()).value());
135+
} else if (unifiedRequest.toolChoice() instanceof UnifiedCompletionRequest.ToolChoiceObject) {
169136
builder.startObject(TOOL_CHOICE_FIELD);
170137
{
171-
builder.field("type", ((UnifiedCompletionRequest.ToolChoiceObject) toolChoice).type());
172-
builder.startObject("function");
138+
builder.field(TYPE_FIELD, ((UnifiedCompletionRequest.ToolChoiceObject) unifiedRequest.toolChoice()).type());
139+
builder.startObject(FUNCTION_FIELD);
173140
{
174-
builder.field("name", ((UnifiedCompletionRequest.ToolChoiceObject) toolChoice).function().name());
141+
builder.field(
142+
NAME_FIELD,
143+
((UnifiedCompletionRequest.ToolChoiceObject) unifiedRequest.toolChoice()).function().name()
144+
);
175145
}
176146
builder.endObject();
177147
}
178148
builder.endObject();
179149
}
180150
}
181-
if (tool != null) {
151+
if (unifiedRequest.tool() != null) {
182152
builder.startArray(TOOL_FIELD);
183-
for (UnifiedCompletionRequest.Tool t : tool) {
153+
for (UnifiedCompletionRequest.Tool t : unifiedRequest.tool()) {
184154
builder.startObject();
185155
{
186-
builder.field("type", t.type());
187-
builder.startObject("function");
156+
builder.field(TYPE_FIELD, t.type());
157+
builder.startObject(FUNCTION_FIELD);
188158
{
189-
builder.field("description", t.function().description());
190-
builder.field("name", t.function().name());
191-
builder.field("parameters", t.function().parameters());
192-
builder.field("strict", t.function().strict());
159+
builder.field(DESCRIPTION_FIELD, t.function().description());
160+
builder.field(NAME_FIELD, t.function().name());
161+
builder.field(PARAMETERS_FIELD, t.function().parameters());
162+
builder.field(STRICT_FIELD, t.function().strict());
193163
}
194164
builder.endObject();
195165
}
196166
builder.endObject();
197167
}
198168
builder.endArray();
199169
}
200-
if (topP != null) {
201-
builder.field(TOP_P_FIELD, topP);
170+
if (unifiedRequest.topP() != null) {
171+
builder.field(TOP_P_FIELD, unifiedRequest.topP());
202172
}
203-
if (Strings.isNullOrEmpty(user) == false) {
204-
builder.field(USER_FIELD, user);
173+
if (unifiedRequest.user() != null && unifiedRequest.user().isEmpty() == false) {
174+
builder.field(USER_FIELD, unifiedRequest.user());
205175
}
176+
builder.field(STREAM_FIELD, unifiedRequest.stream());
206177
builder.endObject();
207178
return builder;
208179
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,13 @@
2121
import org.elasticsearch.inference.Model;
2222
import org.elasticsearch.inference.TaskType;
2323
import org.elasticsearch.rest.RestStatus;
24-
import org.elasticsearch.xpack.inference.external.http.sender.CompletionInputs;
2524
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
2625
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
2726
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
2827
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
2928
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
29+
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
30+
import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequestEntity;
3031

3132
import java.io.IOException;
3233
import java.util.EnumSet;
@@ -73,7 +74,7 @@ public void infer(
7374
private static InferenceInputs createInput(Model model, List<String> input, @Nullable String query, boolean stream) {
7475
return switch (model.getTaskType()) {
7576
// TODO implement parameters
76-
case COMPLETION -> new CompletionInputs(null);
77+
case COMPLETION -> new UnifiedChatInput(null);
7778
case RERANK -> new QueryAndDocsInputs(query, input, stream);
7879
case TEXT_EMBEDDING -> new DocumentsOnlyInput(input, stream);
7980
default -> throw new ElasticsearchStatusException(
@@ -84,9 +85,14 @@ private static InferenceInputs createInput(Model model, List<String> input, @Nul
8485
}
8586

8687
@Override
87-
public void completionInfer(Model model, Object parameters, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
88+
public void completionInfer(
89+
Model model,
90+
OpenAiUnifiedChatCompletionRequestEntity parameters,
91+
TimeValue timeout,
92+
ActionListener<InferenceServiceResults> listener
93+
) {
8894
init();
89-
doUnifiedCompletionInfer(model, new CompletionInputs(parameters), timeout, listener);
95+
doUnifiedCompletionInfer(model, new UnifiedChatInput(parameters), timeout, listener);
9096
}
9197

9298
@Override
@@ -116,7 +122,7 @@ protected abstract void doInfer(
116122

117123
protected abstract void doUnifiedCompletionInfer(
118124
Model model,
119-
CompletionInputs inputs,
125+
UnifiedChatInput inputs,
120126
TimeValue timeout,
121127
ActionListener<InferenceServiceResults> listener
122128
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,11 @@
3535
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
3636
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
3737
import org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionCreator;
38-
import org.elasticsearch.xpack.inference.external.http.sender.CompletionInputs;
3938
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
4039
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
4140
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
4241
import org.elasticsearch.xpack.inference.external.http.sender.OpenAiCompletionRequestManager;
42+
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
4343
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
4444
import org.elasticsearch.xpack.inference.services.SenderService;
4545
import org.elasticsearch.xpack.inference.services.ServiceComponents;
@@ -266,7 +266,7 @@ public void doInfer(
266266
@Override
267267
public void doUnifiedCompletionInfer(
268268
Model model,
269-
CompletionInputs inputs,
269+
UnifiedChatInput inputs,
270270
TimeValue timeout,
271271
ActionListener<InferenceServiceResults> listener
272272
) {

0 commit comments

Comments
 (0)