Skip to content

Commit bf817d0

Browse files
Reworking unified inputs
1 parent 7986c81 commit bf817d0

File tree

12 files changed

+84
-153
lines changed

12 files changed

+84
-153
lines changed

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,14 +115,13 @@ void infer(
115115
* Perform completion inference on the model using the unified schema.
116116
*
117117
* @param model The model
118-
* @param parameters Parameters for the request
118+
* @param request The unified completion request to perform
119119
* @param timeout The timeout for the request
120120
* @param listener Inference result listener
121121
*/
122-
void completionInfer(
122+
void unifiedCompletionInfer(
123123
Model model,
124-
// TODO create the class for this object
125-
Object parameters,
124+
UnifiedCompletionRequest request,
126125
TimeValue timeout,
127126
ActionListener<InferenceServiceResults> listener
128127
);
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
* 2.0.
66
*/
77

8-
package org.elasticsearch.xpack.core.inference.action;
8+
package org.elasticsearch.inference;
99

1010
import org.elasticsearch.common.io.stream.NamedWriteable;
1111
import org.elasticsearch.common.io.stream.StreamInput;
@@ -44,8 +44,8 @@ public record UnifiedCompletionRequest(
4444
public sealed interface Content extends NamedWriteable permits ContentObjects, ContentString {}
4545

4646
@SuppressWarnings("unchecked")
47-
static final ConstructingObjectParser<UnifiedCompletionRequest, Void> PARSER = new ConstructingObjectParser<>(
48-
InferenceAction.NAME,
47+
public static final ConstructingObjectParser<UnifiedCompletionRequest, Void> PARSER = new ConstructingObjectParser<>(
48+
UnifiedCompletionRequest.class.getSimpleName(),
4949
args -> new UnifiedCompletionRequest(
5050
(List<Message>) args[0],
5151
(String) args[1],

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java

Lines changed: 4 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import org.elasticsearch.common.xcontent.ChunkedToXContent;
2020
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
2121
import org.elasticsearch.common.xcontent.ChunkedToXContentObject;
22-
import org.elasticsearch.core.Nullable;
2322
import org.elasticsearch.core.TimeValue;
2423
import org.elasticsearch.inference.InferenceResults;
2524
import org.elasticsearch.inference.InferenceServiceResults;
@@ -93,8 +92,6 @@ public static Builder parseRequest(String inferenceEntityId, TaskType taskType,
9392
private final InputType inputType;
9493
private final TimeValue inferenceTimeout;
9594
private final boolean stream;
96-
private final boolean isUnifiedCompletionMode;
97-
private final UnifiedCompletionRequest unifiedCompletionRequest;
9895

9996
public Request(
10097
TaskType taskType,
@@ -104,34 +101,7 @@ public Request(
104101
Map<String, Object> taskSettings,
105102
InputType inputType,
106103
TimeValue inferenceTimeout,
107-
boolean stream,
108-
boolean isUnifiedCompletionsMode
109-
) {
110-
this(
111-
taskType,
112-
inferenceEntityId,
113-
query,
114-
input,
115-
taskSettings,
116-
inputType,
117-
inferenceTimeout,
118-
stream,
119-
isUnifiedCompletionsMode,
120-
null
121-
);
122-
}
123-
124-
public Request(
125-
TaskType taskType,
126-
String inferenceEntityId,
127-
String query,
128-
List<String> input,
129-
Map<String, Object> taskSettings,
130-
InputType inputType,
131-
TimeValue inferenceTimeout,
132-
boolean stream,
133-
boolean isUnifiedCompletionsMode,
134-
@Nullable UnifiedCompletionRequest unifiedCompletionRequest
104+
boolean stream
135105
) {
136106
this.taskType = taskType;
137107
this.inferenceEntityId = inferenceEntityId;
@@ -141,8 +111,6 @@ public Request(
141111
this.inputType = inputType;
142112
this.inferenceTimeout = inferenceTimeout;
143113
this.stream = stream;
144-
this.isUnifiedCompletionMode = isUnifiedCompletionsMode;
145-
this.unifiedCompletionRequest = unifiedCompletionRequest;
146114
}
147115

148116
public Request(StreamInput in) throws IOException {
@@ -169,14 +137,6 @@ public Request(StreamInput in) throws IOException {
169137
this.inferenceTimeout = DEFAULT_TIMEOUT;
170138
}
171139

172-
if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_UNIFIED_COMPLETIONS_API)) {
173-
this.isUnifiedCompletionMode = in.readBoolean();
174-
this.unifiedCompletionRequest = in.readOptionalWriteable(UnifiedCompletionRequest::new);
175-
} else {
176-
this.isUnifiedCompletionMode = false;
177-
this.unifiedCompletionRequest = null;
178-
}
179-
180140
// streaming is not supported yet for transport traffic
181141
this.stream = false;
182142
}
@@ -213,10 +173,6 @@ public boolean isStreaming() {
213173
return stream;
214174
}
215175

216-
public boolean isUnifiedCompletionMode() {
217-
return isUnifiedCompletionMode;
218-
}
219-
220176
@Override
221177
public ActionRequestValidationException validate() {
222178
if (input == null) {
@@ -242,10 +198,6 @@ public ActionRequestValidationException validate() {
242198
e.addValidationError(format("Field [query] cannot be empty for task type [%s]", TaskType.RERANK));
243199
return e;
244200
}
245-
} else if (query != null) {
246-
var e = new ActionRequestValidationException();
247-
e.addValidationError(format("Task type [%s] does not support field [query]", TaskType.RERANK));
248-
return e;
249201
}
250202

251203
return null;
@@ -271,11 +223,6 @@ public void writeTo(StreamOutput out) throws IOException {
271223
out.writeOptionalString(query);
272224
out.writeTimeValue(inferenceTimeout);
273225
}
274-
275-
if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_UNIFIED_COMPLETIONS_API)) {
276-
out.writeBoolean(isUnifiedCompletionMode);
277-
out.writeOptionalWriteable(unifiedCompletionRequest);
278-
}
279226
}
280227

281228
// default for easier testing
@@ -302,22 +249,12 @@ public boolean equals(Object o) {
302249
&& Objects.equals(taskSettings, request.taskSettings)
303250
&& Objects.equals(inputType, request.inputType)
304251
&& Objects.equals(query, request.query)
305-
&& Objects.equals(inferenceTimeout, request.inferenceTimeout)
306-
&& Objects.equals(isUnifiedCompletionMode, request.isUnifiedCompletionMode);
252+
&& Objects.equals(inferenceTimeout, request.inferenceTimeout);
307253
}
308254

309255
@Override
310256
public int hashCode() {
311-
return Objects.hash(
312-
taskType,
313-
inferenceEntityId,
314-
input,
315-
taskSettings,
316-
inputType,
317-
query,
318-
inferenceTimeout,
319-
isUnifiedCompletionMode
320-
);
257+
return Objects.hash(taskType, inferenceEntityId, input, taskSettings, inputType, query, inferenceTimeout);
321258
}
322259

323260
public static class Builder {
@@ -330,8 +267,6 @@ public static class Builder {
330267
private String query;
331268
private TimeValue timeout = DEFAULT_TIMEOUT;
332269
private boolean stream = false;
333-
private boolean unifiedCompletionMode = false;
334-
private UnifiedCompletionRequest unifiedCompletionRequest;
335270

336271
private Builder() {}
337272

@@ -379,29 +314,8 @@ public Builder setStream(boolean stream) {
379314
return this;
380315
}
381316

382-
public Builder setUnifiedCompletionMode(boolean unified) {
383-
this.unifiedCompletionMode = unified;
384-
return this;
385-
}
386-
387-
public Builder setUnifiedCompletionRequest(UnifiedCompletionRequest unifiedCompletionRequest) {
388-
this.unifiedCompletionRequest = unifiedCompletionRequest;
389-
return this;
390-
}
391-
392317
public Request build() {
393-
return new Request(
394-
taskType,
395-
inferenceEntityId,
396-
query,
397-
input,
398-
taskSettings,
399-
inputType,
400-
timeout,
401-
stream,
402-
unifiedCompletionMode,
403-
unifiedCompletionRequest
404-
);
318+
return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout, stream);
405319
}
406320
}
407321

@@ -420,8 +334,6 @@ public String toString() {
420334
+ this.getInputType()
421335
+ ", timeout="
422336
+ this.getInferenceTimeout()
423-
+ ", isUnifiedCompletionsMode="
424-
+ this.isUnifiedCompletionMode()
425337
+ ")";
426338
}
427339
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.elasticsearch.common.io.stream.StreamOutput;
1414
import org.elasticsearch.core.TimeValue;
1515
import org.elasticsearch.inference.TaskType;
16+
import org.elasticsearch.inference.UnifiedCompletionRequest;
1617
import org.elasticsearch.xcontent.XContentParser;
1718

1819
import java.io.IOException;
@@ -27,7 +28,8 @@ public UnifiedCompletionAction() {
2728
}
2829

2930
public static class Request extends BaseInferenceActionRequest {
30-
public static Request parseRequest(String inferenceEntityId, TaskType taskType, TimeValue timeout, XContentParser parser) throws IOException {
31+
public static Request parseRequest(String inferenceEntityId, TaskType taskType, TimeValue timeout, XContentParser parser)
32+
throws IOException {
3133
var unifiedRequest = UnifiedCompletionRequest.PARSER.apply(parser, null);
3234
return new Request(inferenceEntityId, taskType, unifiedRequest, timeout);
3335
}
@@ -110,8 +112,8 @@ public boolean equals(Object o) {
110112
Request request = (Request) o;
111113
return Objects.equals(inferenceEntityId, request.inferenceEntityId)
112114
&& taskType == request.taskType
113-
&& Objects.equals(unifiedCompletionRequest, request.unifiedCompletionRequest) &&
114-
Objects.equals(timeout, request.timeout);
115+
&& Objects.equals(unifiedCompletionRequest, request.unifiedCompletionRequest)
116+
&& Objects.equals(timeout, request.timeout);
115117
}
116118

117119
@Override

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import org.elasticsearch.TransportVersion;
1111
import org.elasticsearch.common.io.stream.Writeable;
12+
import org.elasticsearch.inference.UnifiedCompletionRequest;
1213
import org.elasticsearch.xcontent.json.JsonXContent;
1314
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
1415

@@ -99,7 +100,6 @@ public void testParseAllFields() throws IOException {
99100
100L,
100101
1,
101102
new UnifiedCompletionRequest.StopValues(List.of("stop")),
102-
true,
103103
0.1F,
104104
new UnifiedCompletionRequest.ToolChoiceObject(
105105
"function",
@@ -168,7 +168,6 @@ public void testParsing() throws IOException {
168168
null,
169169
new UnifiedCompletionRequest.StopString("none"),
170170
null,
171-
null,
172171
new UnifiedCompletionRequest.ToolChoiceString("auto"),
173172
List.of(
174173
new UnifiedCompletionRequest.Tool(
@@ -196,7 +195,6 @@ public static UnifiedCompletionRequest randomUnifiedCompletionRequest() {
196195
randomNullOrLong(),
197196
randomNullOrInt(),
198197
randomNullOrStop(),
199-
randomOptionalBoolean(),
200198
randomNullOrFloat(),
201199
randomNullOrToolChoice(),
202200
randomList(5, UnifiedCompletionRequestTests::randomTool),

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,6 @@ protected void doInference(
7171
InferenceService service,
7272
ActionListener<InferenceServiceResults> listener
7373
) {
74-
service.completionInfer(model, request.getUnifiedCompletionRequest(), null, listener);
74+
service.unifiedCompletionInfer(model, request.getUnifiedCompletionRequest(), null, listener);
7575
}
7676
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@ public static IllegalArgumentException createUnsupportedTypeException(InferenceI
1414
return new IllegalArgumentException(Strings.format("Unsupported inference inputs type: [%s]", inferenceInputs.getClass()));
1515
}
1616

17-
public static <T> T abc(InferenceInputs inputs, Class<T> clazz) {
18-
if (inputs.getClass().isInstance(clazz) == false) {
19-
throw createUnsupportedTypeException(inputs);
17+
public <T> T castTo(Class<T> clazz) {
18+
if (clazz.isInstance(this) == false) { // test this object against clazz; Class#isInstance takes the candidate object as its argument
19+
throw createUnsupportedTypeException(this);
2020
}
2121

22-
return clazz.cast(inputs);
22+
return clazz.cast(this);
2323
}
2424
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,9 @@ public void execute(
4747
ActionListener<InferenceServiceResults> listener
4848
) {
4949

50-
OpenAiUnifiedChatCompletionRequest request = new OpenAiUnifiedChatCompletionRequest(
51-
UnifiedChatInput.of(inferenceInputs).getRequestEntity(),
52-
model
53-
);
50+
// TODO check and see if this works
51+
// OpenAiUnifiedChatCompletionRequest request = new OpenAiUnifiedChatCompletionRequest(UnifiedChatInput.of(inferenceInputs), model);
52+
OpenAiUnifiedChatCompletionRequest request = new OpenAiUnifiedChatCompletionRequest(inferenceInputs.castTo(UnifiedChatInput.class), model);
5453

5554
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
5655
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,34 +7,60 @@
77

88
package org.elasticsearch.xpack.inference.external.http.sender;
99

10-
import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequestEntity;
10+
import org.elasticsearch.inference.UnifiedCompletionRequest;
1111

12+
import java.util.List;
1213
import java.util.Objects;
1314

15+
import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD;
16+
1417
public class UnifiedChatInput extends InferenceInputs {
1518

1619
public static UnifiedChatInput of(InferenceInputs inferenceInputs) {
17-
18-
if (inferenceInputs instanceof DocumentsOnlyInput docsOnly) {
19-
return new UnifiedChatInput(new OpenAiUnifiedChatCompletionRequestEntity(docsOnly));
20-
} else if (inferenceInputs instanceof UnifiedChatInput == false) {
20+
if (inferenceInputs instanceof UnifiedChatInput == false) {
2121
throw createUnsupportedTypeException(inferenceInputs);
2222
}
2323

2424
return (UnifiedChatInput) inferenceInputs;
2525
}
2626

27-
public OpenAiUnifiedChatCompletionRequestEntity getRequestEntity() {
28-
return requestEntity;
27+
public static UnifiedChatInput of(List<String> input, boolean stream) {
28+
var unifiedRequest = new UnifiedCompletionRequest(
29+
convertToMessages(input),
30+
null,
31+
null,
32+
null,
33+
null,
34+
null,
35+
null,
36+
null,
37+
null,
38+
// TODO we need to get the user field from task settings if it is there
39+
null
40+
);
41+
42+
return new UnifiedChatInput(unifiedRequest, stream);
2943
}
3044

31-
private final OpenAiUnifiedChatCompletionRequestEntity requestEntity;
45+
private static List<UnifiedCompletionRequest.Message> convertToMessages(List<String> inputs) {
46+
return inputs.stream()
47+
.map(doc -> new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString(doc), USER_FIELD, null, null, null))
48+
.toList();
49+
}
50+
51+
private final UnifiedCompletionRequest request;
52+
private final boolean stream;
53+
54+
public UnifiedChatInput(UnifiedCompletionRequest request, boolean stream) {
55+
this.request = Objects.requireNonNull(request);
56+
this.stream = stream;
57+
}
3258

33-
public UnifiedChatInput(OpenAiUnifiedChatCompletionRequestEntity requestEntity) {
34-
this.requestEntity = Objects.requireNonNull(requestEntity);
59+
public UnifiedCompletionRequest getRequest() {
60+
return request;
3561
}
3662

3763
public boolean stream() {
38-
return requestEntity.isStream();
64+
return stream;
3965
}
4066
}

0 commit comments

Comments
 (0)