Skip to content

Commit 1e30c6d

Browse files
Creating a new action
1 parent bd59543 commit 1e30c6d

File tree

4 files changed

+457
-1
lines changed

4 files changed

+457
-1
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.elasticsearch.common.xcontent.ChunkedToXContent;
2121
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
2222
import org.elasticsearch.common.xcontent.ChunkedToXContentObject;
23+
import org.elasticsearch.core.Nullable;
2324
import org.elasticsearch.core.TimeValue;
2425
import org.elasticsearch.inference.InferenceResults;
2526
import org.elasticsearch.inference.InferenceServiceResults;
@@ -94,6 +95,7 @@ public static Builder parseRequest(String inferenceEntityId, TaskType taskType,
9495
private final TimeValue inferenceTimeout;
9596
private final boolean stream;
9697
private final boolean isUnifiedCompletionMode;
98+
private final UnifiedCompletionRequest unifiedCompletionRequest;
9799

98100
public Request(
99101
TaskType taskType,
@@ -105,6 +107,32 @@ public Request(
105107
TimeValue inferenceTimeout,
106108
boolean stream,
107109
boolean isUnifiedCompletionsMode
110+
) {
111+
this(
112+
taskType,
113+
inferenceEntityId,
114+
query,
115+
input,
116+
taskSettings,
117+
inputType,
118+
inferenceTimeout,
119+
stream,
120+
isUnifiedCompletionsMode,
121+
null
122+
);
123+
}
124+
125+
public Request(
126+
TaskType taskType,
127+
String inferenceEntityId,
128+
String query,
129+
List<String> input,
130+
Map<String, Object> taskSettings,
131+
InputType inputType,
132+
TimeValue inferenceTimeout,
133+
boolean stream,
134+
boolean isUnifiedCompletionsMode,
135+
@Nullable UnifiedCompletionRequest unifiedCompletionRequest
108136
) {
109137
this.taskType = taskType;
110138
this.inferenceEntityId = inferenceEntityId;
@@ -115,6 +143,7 @@ public Request(
115143
this.inferenceTimeout = inferenceTimeout;
116144
this.stream = stream;
117145
this.isUnifiedCompletionMode = isUnifiedCompletionsMode;
146+
this.unifiedCompletionRequest = unifiedCompletionRequest;
118147
}
119148

120149
public Request(StreamInput in) throws IOException {
@@ -143,8 +172,10 @@ public Request(StreamInput in) throws IOException {
143172

144173
if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_UNIFIED_COMPLETIONS_API)) {
145174
this.isUnifiedCompletionMode = in.readBoolean();
175+
this.unifiedCompletionRequest = in.readOptionalWriteable(UnifiedCompletionRequest::new);
146176
} else {
147177
this.isUnifiedCompletionMode = false;
178+
this.unifiedCompletionRequest = null;
148179
}
149180

150181
// streaming is not supported yet for transport traffic
@@ -244,6 +275,7 @@ public void writeTo(StreamOutput out) throws IOException {
244275

245276
if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_UNIFIED_COMPLETIONS_API)) {
246277
out.writeBoolean(isUnifiedCompletionMode);
278+
out.writeOptionalWriteable(unifiedCompletionRequest);
247279
}
248280
}
249281

@@ -300,6 +332,7 @@ public static class Builder {
300332
private TimeValue timeout = DEFAULT_TIMEOUT;
301333
private boolean stream = false;
302334
private boolean unifiedCompletionMode = false;
335+
private UnifiedCompletionRequest unifiedCompletionRequest;
303336

304337
private Builder() {}
305338

@@ -352,6 +385,11 @@ public Builder setUnifiedCompletionMode(boolean unified) {
352385
return this;
353386
}
354387

388+
public Builder setUnifiedCompletionRequest(UnifiedCompletionRequest unifiedCompletionRequest) {
389+
this.unifiedCompletionRequest = unifiedCompletionRequest;
390+
return this;
391+
}
392+
355393
public Request build() {
356394
return new Request(
357395
taskType,
@@ -362,7 +400,8 @@ public Request build() {
362400
inputType,
363401
timeout,
364402
stream,
365-
unifiedCompletionMode
403+
unifiedCompletionMode,
404+
unifiedCompletionRequest
366405
);
367406
}
368407
}
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.inference.action;
9+
10+
import org.elasticsearch.action.ActionRequest;
11+
import org.elasticsearch.action.ActionRequestValidationException;
12+
import org.elasticsearch.action.ActionType;
13+
import org.elasticsearch.common.io.stream.StreamInput;
14+
import org.elasticsearch.common.io.stream.StreamOutput;
15+
import org.elasticsearch.inference.TaskType;
16+
import org.elasticsearch.xcontent.XContentParser;
17+
18+
import java.io.IOException;
19+
import java.util.Objects;
20+
21+
public class UnifiedCompletionAction extends ActionType<InferenceAction.Response> {
22+
public static final UnifiedCompletionAction INSTANCE = new UnifiedCompletionAction();
23+
public static final String NAME = "cluster:monitor/xpack/inference/unified";
24+
25+
public UnifiedCompletionAction() {
26+
super(NAME);
27+
}
28+
29+
public static class Request extends ActionRequest {
30+
public static Request parseRequest(String inferenceEntityId, TaskType taskType, XContentParser parser) throws IOException {
31+
var unifiedRequest = UnifiedCompletionRequest.PARSER.apply(parser, null);
32+
return new Request(inferenceEntityId, taskType, unifiedRequest);
33+
}
34+
35+
private final String inferenceEntityId;
36+
private final TaskType taskType;
37+
private final UnifiedCompletionRequest unifiedCompletionRequest;
38+
39+
public Request(String inferenceEntityId, TaskType taskType, UnifiedCompletionRequest unifiedCompletionRequest) {
40+
this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId);
41+
this.taskType = Objects.requireNonNull(taskType);
42+
this.unifiedCompletionRequest = Objects.requireNonNull(unifiedCompletionRequest);
43+
}
44+
45+
public Request(StreamInput in) throws IOException {
46+
super(in);
47+
this.inferenceEntityId = in.readString();
48+
this.taskType = TaskType.fromStream(in);
49+
this.unifiedCompletionRequest = new UnifiedCompletionRequest(in);
50+
}
51+
52+
public TaskType getTaskType() {
53+
return taskType;
54+
}
55+
56+
public String getInferenceEntityId() {
57+
return inferenceEntityId;
58+
}
59+
60+
public UnifiedCompletionRequest getUnifiedCompletionRequest() {
61+
return unifiedCompletionRequest;
62+
}
63+
64+
public boolean isStreaming() {
65+
return Objects.requireNonNullElse(unifiedCompletionRequest.stream(), false);
66+
}
67+
68+
@Override
69+
public ActionRequestValidationException validate() {
70+
if (unifiedCompletionRequest == null || unifiedCompletionRequest.messages() == null) {
71+
var e = new ActionRequestValidationException();
72+
e.addValidationError("Field [messages] cannot be null");
73+
return e;
74+
}
75+
76+
if (unifiedCompletionRequest.messages().isEmpty()) {
77+
var e = new ActionRequestValidationException();
78+
e.addValidationError("Field [messages] cannot be an empty array");
79+
return e;
80+
}
81+
82+
if (taskType != TaskType.COMPLETION) {
83+
var e = new ActionRequestValidationException();
84+
e.addValidationError("Field [taskType] must be [completion]");
85+
return e;
86+
}
87+
88+
return null;
89+
}
90+
91+
@Override
92+
public void writeTo(StreamOutput out) throws IOException {
93+
super.writeTo(out);
94+
out.writeString(inferenceEntityId);
95+
taskType.writeTo(out);
96+
unifiedCompletionRequest.writeTo(out);
97+
}
98+
99+
@Override
100+
public boolean equals(Object o) {
101+
if (o == null || getClass() != o.getClass()) return false;
102+
Request request = (Request) o;
103+
return Objects.equals(inferenceEntityId, request.inferenceEntityId)
104+
&& taskType == request.taskType
105+
&& Objects.equals(unifiedCompletionRequest, request.unifiedCompletionRequest);
106+
}
107+
108+
@Override
109+
public int hashCode() {
110+
return Objects.hash(inferenceEntityId, taskType, unifiedCompletionRequest);
111+
}
112+
}
113+
114+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.action;
9+
10+
import org.elasticsearch.action.ActionRequest;
11+
import org.elasticsearch.inference.TaskType;
12+
13+
public abstract class BaseInferenceActionRequest extends ActionRequest {
14+
public abstract boolean isStreaming();
15+
16+
public abstract TaskType getTaskType();
17+
18+
public abstract String getInferenceEntityId();
19+
}

0 commit comments

Comments
 (0)