Skip to content

Commit 633ea4d

Browse files
[8.x] [Inference API] Introduce Update API to change some aspects of existing inference endpoints (elastic#114457) (elastic#114734)
* [Inference API] Introduce Update API to change some aspects of existing inference endpoints (elastic#114457) (cherry picked from commit 6b714e2) * Fix syntax error caused by old JDK?
1 parent 98209e4 commit 633ea4d

File tree

68 files changed

+1745
-102
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

68 files changed

+1745
-102
lines changed

docs/changelog/114457.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 114457
2+
summary: "[Inference API] Introduce Update API to change some aspects of existing\
3+
\ inference endpoints"
4+
area: Machine Learning
5+
type: enhancement
6+
issues: []

server/src/main/java/org/elasticsearch/inference/EmptySecretSettings.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.elasticsearch.xcontent.XContentBuilder;
1717

1818
import java.io.IOException;
19+
import java.util.Map;
1920

2021
/**
2122
* This class defines an empty secret settings object. This is useful for services that do not have any secret settings.
@@ -48,4 +49,9 @@ public TransportVersion getMinimalSupportedVersion() {
4849

4950
@Override
5051
public void writeTo(StreamOutput out) throws IOException {}
52+
53+
@Override
54+
public SecretSettings newSecretSettings(Map<String, Object> newSecrets) {
55+
return INSTANCE;
56+
}
5157
}

server/src/main/java/org/elasticsearch/inference/EmptyTaskSettings.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.elasticsearch.xcontent.XContentBuilder;
1717

1818
import java.io.IOException;
19+
import java.util.Map;
1920

2021
/**
2122
* This class defines an empty task settings object. This is useful for services that do not have any task settings.
@@ -53,4 +54,9 @@ public TransportVersion getMinimalSupportedVersion() {
5354

5455
@Override
5556
public void writeTo(StreamOutput out) throws IOException {}
57+
58+
@Override
59+
public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
60+
return INSTANCE;
61+
}
5662
}

server/src/main/java/org/elasticsearch/inference/SecretSettings.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
1313
import org.elasticsearch.xcontent.ToXContentObject;
1414

15+
import java.util.Map;
16+
1517
public interface SecretSettings extends ToXContentObject, VersionedNamedWriteable {
1618

19+
SecretSettings newSecretSettings(Map<String, Object> newSecrets);
1720
}

server/src/main/java/org/elasticsearch/inference/TaskSettings.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
1313
import org.elasticsearch.xcontent.ToXContentObject;
1414

15+
import java.util.Map;
16+
1517
public interface TaskSettings extends ToXContentObject, VersionedNamedWriteable {
18+
1619
boolean isEmpty();
20+
21+
TaskSettings updatedTaskSettings(Map<String, Object> newSettings);
1722
}
Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.inference.action;
9+
10+
import org.elasticsearch.ElasticsearchStatusException;
11+
import org.elasticsearch.action.ActionRequestValidationException;
12+
import org.elasticsearch.action.ActionResponse;
13+
import org.elasticsearch.action.ActionType;
14+
import org.elasticsearch.action.support.master.AcknowledgedRequest;
15+
import org.elasticsearch.common.bytes.BytesReference;
16+
import org.elasticsearch.common.io.stream.StreamInput;
17+
import org.elasticsearch.common.io.stream.StreamOutput;
18+
import org.elasticsearch.common.xcontent.XContentHelper;
19+
import org.elasticsearch.core.Nullable;
20+
import org.elasticsearch.core.TimeValue;
21+
import org.elasticsearch.inference.ModelConfigurations;
22+
import org.elasticsearch.inference.TaskType;
23+
import org.elasticsearch.rest.RestStatus;
24+
import org.elasticsearch.xcontent.ToXContentObject;
25+
import org.elasticsearch.xcontent.XContentBuilder;
26+
import org.elasticsearch.xcontent.XContentType;
27+
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
28+
import org.elasticsearch.xpack.core.ml.utils.MlStrings;
29+
30+
import java.io.IOException;
31+
import java.util.Collections;
32+
import java.util.HashMap;
33+
import java.util.Map;
34+
import java.util.Objects;
35+
36+
import static org.elasticsearch.inference.ModelConfigurations.SERVICE_SETTINGS;
37+
import static org.elasticsearch.inference.ModelConfigurations.TASK_SETTINGS;
38+
39+
public class UpdateInferenceModelAction extends ActionType<UpdateInferenceModelAction.Response> {
40+
41+
public static final UpdateInferenceModelAction INSTANCE = new UpdateInferenceModelAction();
42+
public static final String NAME = "cluster:admin/xpack/inference/update";
43+
44+
public UpdateInferenceModelAction() {
45+
super(NAME);
46+
}
47+
48+
public record Settings(
49+
@Nullable Map<String, Object> serviceSettings,
50+
@Nullable Map<String, Object> taskSettings,
51+
@Nullable TaskType taskType
52+
) {}
53+
54+
public static class Request extends AcknowledgedRequest<Request> {
55+
56+
private final String inferenceEntityId;
57+
private final BytesReference content;
58+
private final XContentType contentType;
59+
private final TaskType taskType;
60+
private Settings settings;
61+
62+
public Request(String inferenceEntityId, BytesReference content, XContentType contentType, TaskType taskType, TimeValue timeout) {
63+
super(timeout, DEFAULT_ACK_TIMEOUT);
64+
this.inferenceEntityId = inferenceEntityId;
65+
this.content = content;
66+
this.contentType = contentType;
67+
this.taskType = taskType;
68+
}
69+
70+
public Request(StreamInput in) throws IOException {
71+
super(in);
72+
this.inferenceEntityId = in.readString();
73+
this.content = in.readBytesReference();
74+
this.taskType = TaskType.fromStream(in);
75+
this.contentType = in.readEnum(XContentType.class);
76+
}
77+
78+
public String getInferenceEntityId() {
79+
return inferenceEntityId;
80+
}
81+
82+
public TaskType getTaskType() {
83+
return taskType;
84+
}
85+
86+
/**
87+
* The body of the request.
88+
* For in-cluster models, this is expected to contain some of the following:
89+
* "number_of_allocations": `an integer`
90+
*
91+
* For third-party services, this is expected to contain:
92+
* "service_settings": {
93+
* "api_key": `a string` // service settings can only contain an api key
94+
* }
95+
* "task_settings": { a map of settings }
96+
*
97+
*/
98+
public BytesReference getContent() {
99+
return content;
100+
}
101+
102+
/**
103+
* The body of the request as a map.
104+
* The map is validated such that only allowed fields are present.
105+
* If any fields in the body are not on the allow list, this function will throw an exception.
106+
*/
107+
public Settings getContentAsSettings() {
108+
if (settings == null) { // settings is deterministic on content, so we only need to compute it once
109+
Map<String, Object> unvalidatedMap = XContentHelper.convertToMap(content, false, contentType).v2();
110+
Map<String, Object> serviceSettings = new HashMap<>();
111+
Map<String, Object> taskSettings = new HashMap<>();
112+
TaskType taskType = null;
113+
114+
if (unvalidatedMap.isEmpty()) {
115+
throw new ElasticsearchStatusException("Request body is empty", RestStatus.BAD_REQUEST);
116+
}
117+
118+
if (unvalidatedMap.containsKey("task_type")) {
119+
if (unvalidatedMap.get("task_type") instanceof String taskTypeString) {
120+
taskType = TaskType.fromStringOrStatusException(taskTypeString);
121+
} else {
122+
throw new ElasticsearchStatusException(
123+
"Failed to parse [task_type] in update request [{}]",
124+
RestStatus.INTERNAL_SERVER_ERROR,
125+
unvalidatedMap.toString()
126+
);
127+
}
128+
unvalidatedMap.remove("task_type");
129+
}
130+
131+
if (unvalidatedMap.containsKey(SERVICE_SETTINGS)) {
132+
if (unvalidatedMap.get(SERVICE_SETTINGS) instanceof Map<?, ?> tempMap) {
133+
for (Map.Entry<?, ?> entry : (tempMap).entrySet()) {
134+
if (entry.getKey() instanceof String key) {
135+
serviceSettings.put(key, entry.getValue());
136+
} else {
137+
throw new ElasticsearchStatusException(
138+
"Failed to parse update request [{}]",
139+
RestStatus.INTERNAL_SERVER_ERROR,
140+
unvalidatedMap.toString()
141+
);
142+
}
143+
}
144+
unvalidatedMap.remove(SERVICE_SETTINGS);
145+
} else {
146+
throw new ElasticsearchStatusException(
147+
"Unable to parse service settings in the request [{}]",
148+
RestStatus.BAD_REQUEST,
149+
unvalidatedMap.toString()
150+
);
151+
}
152+
}
153+
154+
if (unvalidatedMap.containsKey(TASK_SETTINGS)) {
155+
if (unvalidatedMap.get(TASK_SETTINGS) instanceof Map<?, ?> tempMap) {
156+
for (Map.Entry<?, ?> entry : (tempMap).entrySet()) {
157+
if (entry.getKey() instanceof String key) {
158+
taskSettings.put(key, entry.getValue());
159+
} else {
160+
throw new ElasticsearchStatusException(
161+
"Failed to parse update request [{}]",
162+
RestStatus.INTERNAL_SERVER_ERROR,
163+
unvalidatedMap.toString()
164+
);
165+
}
166+
}
167+
unvalidatedMap.remove(TASK_SETTINGS);
168+
} else {
169+
throw new ElasticsearchStatusException(
170+
"Unable to parse task settings in the request [{}]",
171+
RestStatus.BAD_REQUEST,
172+
unvalidatedMap.toString()
173+
);
174+
}
175+
}
176+
177+
if (unvalidatedMap.isEmpty() == false) {
178+
throw new ElasticsearchStatusException(
179+
"Request contained fields which cannot be updated, remove these fields and try again [{}]",
180+
RestStatus.BAD_REQUEST,
181+
unvalidatedMap.toString()
182+
);
183+
}
184+
185+
this.settings = new Settings(
186+
serviceSettings.isEmpty() == false ? Collections.unmodifiableMap(serviceSettings) : null,
187+
taskSettings.isEmpty() == false ? Collections.unmodifiableMap(taskSettings) : null,
188+
taskType
189+
);
190+
}
191+
return this.settings;
192+
}
193+
194+
public XContentType getContentType() {
195+
return contentType;
196+
}
197+
198+
@Override
199+
public void writeTo(StreamOutput out) throws IOException {
200+
super.writeTo(out);
201+
out.writeString(inferenceEntityId);
202+
taskType.writeTo(out);
203+
out.writeBytesReference(content);
204+
XContentHelper.writeTo(out, contentType);
205+
}
206+
207+
@Override
208+
public ActionRequestValidationException validate() {
209+
ActionRequestValidationException validationException = new ActionRequestValidationException();
210+
if (MlStrings.isValidId(this.inferenceEntityId) == false) {
211+
validationException.addValidationError(Messages.getMessage(Messages.INVALID_ID, "inference_id", this.inferenceEntityId));
212+
}
213+
214+
if (validationException.validationErrors().isEmpty() == false) {
215+
return validationException;
216+
} else {
217+
return null;
218+
}
219+
}
220+
221+
@Override
222+
public boolean equals(Object o) {
223+
if (this == o) return true;
224+
if (o == null || getClass() != o.getClass()) return false;
225+
Request request = (Request) o;
226+
return Objects.equals(inferenceEntityId, request.inferenceEntityId)
227+
&& Objects.equals(content, request.content)
228+
&& contentType == request.contentType
229+
&& taskType == request.taskType;
230+
}
231+
232+
@Override
233+
public int hashCode() {
234+
return Objects.hash(inferenceEntityId, content, contentType, taskType);
235+
}
236+
}
237+
238+
public static class Response extends ActionResponse implements ToXContentObject {
239+
240+
private final ModelConfigurations model;
241+
242+
public Response(ModelConfigurations model) {
243+
this.model = model;
244+
}
245+
246+
public Response(StreamInput in) throws IOException {
247+
super(in);
248+
model = new ModelConfigurations(in);
249+
}
250+
251+
public ModelConfigurations getModel() {
252+
return model;
253+
}
254+
255+
@Override
256+
public void writeTo(StreamOutput out) throws IOException {
257+
model.writeTo(out);
258+
}
259+
260+
@Override
261+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
262+
return model.toFilteredXContent(builder, params);
263+
}
264+
265+
@Override
266+
public boolean equals(Object o) {
267+
if (this == o) return true;
268+
if (o == null || getClass() != o.getClass()) return false;
269+
Response response = (Response) o;
270+
return Objects.equals(model, response.model);
271+
}
272+
273+
@Override
274+
public int hashCode() {
275+
return Objects.hash(model);
276+
}
277+
}
278+
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,9 @@ public final class Messages {
281281
public static final String FIELD_CANNOT_BE_NULL = "Field [{0}] cannot be null";
282282
public static final String MODEL_ID_MATCHES_EXISTING_MODEL_IDS_BUT_MUST_NOT =
283283
"Model IDs must be unique. Requested model ID [{}] matches existing model IDs but must not.";
284+
public static final String MODEL_ID_DOES_NOT_MATCH_EXISTING_MODEL_IDS_BUT_MUST_FOR_IN_CLUSTER_SERVICE =
285+
"Requested model ID [{}] does not have a matching trained model and thus cannot be updated.";
286+
public static final String INFERENCE_ENTITY_NON_EXISTANT_NO_UPDATE = "The inference endpoint [{}] does not exist and cannot be updated";
284287

285288
private Messages() {}
286289

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ public static ElasticsearchStatusException badRequestException(String msg, Objec
9898
return new ElasticsearchStatusException(msg, RestStatus.BAD_REQUEST, args);
9999
}
100100

101+
public static ElasticsearchStatusException entityNotFoundException(String msg, Object... args) {
102+
return new ElasticsearchStatusException(msg, RestStatus.NOT_FOUND, args);
103+
}
104+
101105
public static ElasticsearchStatusException taskOperationFailureToStatusException(TaskOperationFailure failure) {
102106
return new ElasticsearchStatusException(failure.getCause().getMessage(), failure.getStatus(), failure.getCause());
103107
}

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,21 @@ static String mockSparseServiceModelConfig(@Nullable TaskType taskTypeInBody) {
8181
""", taskType);
8282
}
8383

84+
static String updateConfig(@Nullable TaskType taskTypeInBody, String apiKey, int temperature) {
85+
var taskType = taskTypeInBody == null ? "" : "\"task_type\": \"" + taskTypeInBody + "\",";
86+
return Strings.format("""
87+
{
88+
%s
89+
"service_settings": {
90+
"api_key": "%s"
91+
},
92+
"task_settings": {
93+
"temperature": %d
94+
}
95+
}
96+
""", taskType, apiKey, temperature);
97+
}
98+
8499
static String mockCompletionServiceModelConfig(@Nullable TaskType taskTypeInBody) {
85100
var taskType = taskTypeInBody == null ? "" : "\"task_type\": \"" + taskTypeInBody + "\",";
86101
return Strings.format("""
@@ -196,6 +211,11 @@ protected Map<String, Object> putModel(String modelId, String modelConfig, TaskT
196211
return putRequest(endpoint, modelConfig);
197212
}
198213

214+
protected Map<String, Object> updateEndpoint(String inferenceID, String modelConfig, TaskType taskType) throws IOException {
215+
String endpoint = Strings.format("_inference/%s/%s/_update", taskType, inferenceID);
216+
return putRequest(endpoint, modelConfig);
217+
}
218+
199219
protected Map<String, Object> putPipeline(String pipelineId, String modelId) throws IOException {
200220
String endpoint = Strings.format("_ingest/pipeline/%s", pipelineId);
201221
String body = """

0 commit comments

Comments
 (0)