Skip to content

Commit 6da8f92

Browse files
prwhelan, elasticsearchmachine, and davidkyle
authored
[ML] Block trained model updates from inference (elastic#130940)
When the Trained Model has been deployed through the Inference Endpoint API, it can only be updated using the Inference Endpoint API. When the Trained Model has been deployed and then attached to an Inference Endpoint, it can only be updated using the Trained Model API. Fixes elastic#129999. Co-authored-by: elasticsearchmachine <[email protected]> Co-authored-by: David Kyle <[email protected]>
1 parent e2bb47c commit 6da8f92

File tree

7 files changed

+395
-62
lines changed

7 files changed

+395
-62
lines changed

docs/changelog/130940.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 130940
2+
summary: Block trained model updates from inference
3+
area: Machine Learning
4+
type: enhancement
5+
issues:
6+
- 129999

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,12 @@ public final class Messages {
286286
public static final String MODEL_ID_DOES_NOT_MATCH_EXISTING_MODEL_IDS_BUT_MUST_FOR_IN_CLUSTER_SERVICE =
287287
"Requested model ID [{}] does not have a matching trained model and thus cannot be updated.";
288288
public static final String INFERENCE_ENTITY_NON_EXISTANT_NO_UPDATE = "The inference endpoint [{}] does not exist and cannot be updated";
289+
public static final String INFERENCE_REFERENCE_CANNOT_UPDATE_ANOTHER_ENDPOINT =
290+
"Cannot update inference endpoint [{}] for model deployment [{}] as it was created by another inference endpoint. "
291+
+ "The model can only be updated using inference endpoint id [{}].";
292+
public static final String INFERENCE_CAN_ONLY_UPDATE_MODELS_IT_CREATED =
293+
"Cannot update inference endpoint [{}] using model deployment [{}]. "
294+
+ "The model deployment must be updated through the trained models API.";
289295

290296
private Messages() {}
291297

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java

Lines changed: 146 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -51,24 +51,6 @@ public void testAttachToDeployment() throws IOException {
5151
var results = infer(inferenceId, List.of("washing machine"));
5252
assertNotNull(results.get("sparse_embedding"));
5353

54-
var updatedNumAllocations = randomIntBetween(1, 10);
55-
var updatedEndpointConfig = updateEndpoint(inferenceId, updatedEndpointConfig(updatedNumAllocations), TaskType.SPARSE_EMBEDDING);
56-
assertThat(
57-
updatedEndpointConfig.get("service_settings"),
58-
is(
59-
Map.of(
60-
"num_allocations",
61-
updatedNumAllocations,
62-
"num_threads",
63-
1,
64-
"model_id",
65-
"attach_to_deployment",
66-
"deployment_id",
67-
"existing_deployment"
68-
)
69-
)
70-
);
71-
7254
deleteModel(inferenceId);
7355
// assert deployment not stopped
7456
var stats = (List<Map<String, Object>>) getTrainedModelStats(modelId).get("trained_model_stats");
@@ -128,24 +110,6 @@ public void testAttachWithModelId() throws IOException {
128110
var results = infer(inferenceId, List.of("washing machine"));
129111
assertNotNull(results.get("sparse_embedding"));
130112

131-
var updatedNumAllocations = randomIntBetween(1, 10);
132-
var updatedEndpointConfig = updateEndpoint(inferenceId, updatedEndpointConfig(updatedNumAllocations), TaskType.SPARSE_EMBEDDING);
133-
assertThat(
134-
updatedEndpointConfig.get("service_settings"),
135-
is(
136-
Map.of(
137-
"num_allocations",
138-
updatedNumAllocations,
139-
"num_threads",
140-
1,
141-
"model_id",
142-
"attach_with_model_id",
143-
"deployment_id",
144-
"existing_deployment_with_model_id"
145-
)
146-
)
147-
);
148-
149113
forceStopMlNodeDeployment(deploymentId);
150114
}
151115

@@ -180,6 +144,30 @@ public void testDeploymentDoesNotExist() {
180144
assertThat(e.getMessage(), containsString("Cannot find deployment [missing_deployment]"));
181145
}
182146

147+
public void testCreateInferenceUsingSameDeploymentId() throws IOException {
148+
var modelId = "conflicting_ids";
149+
var deploymentId = modelId;
150+
var inferenceId = modelId;
151+
152+
CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client());
153+
var response = startMlNodeDeploymemnt(modelId, deploymentId);
154+
assertStatusOkOrCreated(response);
155+
156+
var responseException = assertThrows(
157+
ResponseException.class,
158+
() -> putModel(inferenceId, endpointConfig(deploymentId), TaskType.SPARSE_EMBEDDING)
159+
);
160+
assertThat(
161+
responseException.getMessage(),
162+
containsString(
163+
"Inference endpoint IDs must be unique. "
164+
+ "Requested inference endpoint ID [conflicting_ids] matches existing trained model ID(s) but must not."
165+
)
166+
);
167+
168+
forceStopMlNodeDeployment(deploymentId);
169+
}
170+
183171
public void testNumAllocationsIsUpdated() throws IOException {
184172
var modelId = "update_num_allocations";
185173
var deploymentId = modelId;
@@ -208,7 +196,16 @@ public void testNumAllocationsIsUpdated() throws IOException {
208196
)
209197
);
210198

211-
assertStatusOkOrCreated(updateMlNodeDeploymemnt(deploymentId, 2));
199+
var responseException = assertThrows(ResponseException.class, () -> updateInference(inferenceId, TaskType.SPARSE_EMBEDDING, 2));
200+
assertThat(
201+
responseException.getMessage(),
202+
containsString(
203+
"Cannot update inference endpoint [test_num_allocations_updated] using model deployment [update_num_allocations]. "
204+
+ "The model deployment must be updated through the trained models API."
205+
)
206+
);
207+
208+
updateMlNodeDeploymemnt(deploymentId, 2);
212209

213210
var updatedServiceSettings = getModel(inferenceId).get("service_settings");
214211
assertThat(
@@ -227,6 +224,92 @@ public void testNumAllocationsIsUpdated() throws IOException {
227224
)
228225
)
229226
);
227+
228+
forceStopMlNodeDeployment(deploymentId);
229+
}
230+
231+
public void testUpdateWhenInferenceEndpointCreatesDeployment() throws IOException {
232+
var modelId = "update_num_allocations_from_created_endpoint";
233+
var inferenceId = "test_created_endpoint_from_model";
234+
var deploymentId = inferenceId;
235+
236+
CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client());
237+
238+
var putModel = putModel(inferenceId, Strings.format("""
239+
{
240+
"service": "elasticsearch",
241+
"service_settings": {
242+
"num_allocations": %s,
243+
"num_threads": %s,
244+
"model_id": "%s"
245+
}
246+
}
247+
""", 1, 1, modelId), TaskType.SPARSE_EMBEDDING);
248+
var serviceSettings = putModel.get("service_settings");
249+
assertThat(putModel.toString(), serviceSettings, is(Map.of("num_allocations", 1, "num_threads", 1, "model_id", modelId)));
250+
251+
updateInference(inferenceId, TaskType.SPARSE_EMBEDDING, 2);
252+
253+
var responseException = assertThrows(ResponseException.class, () -> updateMlNodeDeploymemnt(deploymentId, 2));
254+
assertThat(
255+
responseException.getMessage(),
256+
containsString(
257+
"Cannot update deployment [test_created_endpoint_from_model] as it was created by inference endpoint "
258+
+ "[test_created_endpoint_from_model]. This model deployment must be updated through the inference API."
259+
)
260+
);
261+
262+
var updatedServiceSettings = getModel(inferenceId).get("service_settings");
263+
assertThat(
264+
updatedServiceSettings.toString(),
265+
updatedServiceSettings,
266+
is(Map.of("num_allocations", 2, "num_threads", 1, "model_id", modelId))
267+
);
268+
269+
forceStopMlNodeDeployment(deploymentId);
270+
}
271+
272+
public void testCannotUpdateAnotherInferenceEndpointsCreatedDeployment() throws IOException {
273+
var modelId = "model_deployment_for_endpoint";
274+
var inferenceId = "first_endpoint_for_model_deployment";
275+
var deploymentId = inferenceId;
276+
277+
CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client());
278+
279+
putModel(inferenceId, Strings.format("""
280+
{
281+
"service": "elasticsearch",
282+
"service_settings": {
283+
"num_allocations": %s,
284+
"num_threads": %s,
285+
"model_id": "%s"
286+
}
287+
}
288+
""", 1, 1, modelId), TaskType.SPARSE_EMBEDDING);
289+
290+
var secondInferenceId = "second_endpoint_for_model_deployment";
291+
var putModel = putModel(secondInferenceId, endpointConfig(deploymentId), TaskType.SPARSE_EMBEDDING);
292+
var serviceSettings = putModel.get("service_settings");
293+
assertThat(
294+
putModel.toString(),
295+
serviceSettings,
296+
is(Map.of("num_allocations", 1, "num_threads", 1, "model_id", modelId, "deployment_id", deploymentId))
297+
);
298+
299+
var responseException = assertThrows(
300+
ResponseException.class,
301+
() -> updateInference(secondInferenceId, TaskType.SPARSE_EMBEDDING, 2)
302+
);
303+
assertThat(
304+
responseException.getMessage(),
305+
containsString(
306+
"Cannot update inference endpoint [second_endpoint_for_model_deployment] for model deployment "
307+
+ "[first_endpoint_for_model_deployment] as it was created by another inference endpoint. "
308+
+ "The model can only be updated using inference endpoint id [first_endpoint_for_model_deployment]."
309+
)
310+
);
311+
312+
forceStopMlNodeDeployment(deploymentId);
230313
}
231314

232315
public void testStoppingDeploymentAttachedToInferenceEndpoint() throws IOException {
@@ -300,6 +383,22 @@ private Response startMlNodeDeploymemnt(String modelId, String deploymentId) thr
300383
return client().performRequest(request);
301384
}
302385

386+
private Response updateInference(String deploymentId, TaskType taskType, int numAllocations) throws IOException {
387+
String endPoint = Strings.format("/_inference/%s/%s/_update", taskType, deploymentId);
388+
389+
var body = Strings.format("""
390+
{
391+
"service_settings": {
392+
"num_allocations": %d
393+
}
394+
}
395+
""", numAllocations);
396+
397+
Request request = new Request("PUT", endPoint);
398+
request.setJsonEntity(body);
399+
return client().performRequest(request);
400+
}
401+
303402
private Response updateMlNodeDeploymemnt(String deploymentId, int numAllocations) throws IOException {
304403
String endPoint = "/_ml/trained_models/" + deploymentId + "/deployment/_update";
305404

@@ -314,6 +413,16 @@ private Response updateMlNodeDeploymemnt(String deploymentId, int numAllocations
314413
return client().performRequest(request);
315414
}
316415

416+
private Map<String, Object> updateMlNodeDeploymemnt(String deploymentId, String body) throws IOException {
417+
String endPoint = "/_ml/trained_models/" + deploymentId + "/deployment/_update";
418+
419+
Request request = new Request("POST", endPoint);
420+
request.setJsonEntity(body);
421+
var response = client().performRequest(request);
422+
assertStatusOkOrCreated(response);
423+
return entityAsMap(response);
424+
}
425+
317426
protected void stopMlNodeDeployment(String deploymentId) throws IOException {
318427
String endpoint = "/_ml/trained_models/" + deploymentId + "/deployment/_stop";
319428
Request request = new Request("POST", endpoint);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelAction.java

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@
5858
import java.io.IOException;
5959
import java.util.HashMap;
6060
import java.util.List;
61-
import java.util.Map;
6261
import java.util.Optional;
6362
import java.util.concurrent.atomic.AtomicReference;
6463

@@ -224,13 +223,10 @@ private Model combineExistingModelWithNewSettings(
224223
if (settingsToUpdate.serviceSettings() != null && existingSecretSettings != null) {
225224
newSecretSettings = existingSecretSettings.newSecretSettings(settingsToUpdate.serviceSettings());
226225
}
227-
if (settingsToUpdate.serviceSettings() != null && settingsToUpdate.serviceSettings().containsKey(NUM_ALLOCATIONS)) {
228-
// In cluster services can only have their num_allocations updated, so this is a special case
226+
if (settingsToUpdate.serviceSettings() != null) {
227+
// In cluster services can have their deployment settings updated, so this is a special case
229228
if (newServiceSettings instanceof ElasticsearchInternalServiceSettings elasticServiceSettings) {
230-
newServiceSettings = new ElasticsearchInternalServiceSettings(
231-
elasticServiceSettings,
232-
(Integer) settingsToUpdate.serviceSettings().get(NUM_ALLOCATIONS)
233-
);
229+
newServiceSettings = elasticServiceSettings.updateServiceSettings(settingsToUpdate.serviceSettings());
234230
}
235231
}
236232
if (settingsToUpdate.taskSettings() != null && existingTaskSettings != null) {
@@ -257,26 +253,59 @@ private void updateInClusterEndpoint(
257253
Model newModel,
258254
Model existingParsedModel,
259255
ActionListener<Boolean> listener
260-
) throws IOException {
256+
) {
261257
// The model we are trying to update must have a trained model associated with it if it is an in-cluster deployment
262258
var deploymentId = getDeploymentIdForInClusterEndpoint(existingParsedModel);
263-
throwIfTrainedModelDoesntExist(request.getInferenceEntityId(), deploymentId);
259+
var inferenceEntityId = request.getInferenceEntityId();
260+
throwIfTrainedModelDoesntExist(inferenceEntityId, deploymentId);
264261

265-
Map<String, Object> serviceSettings = request.getContentAsSettings().serviceSettings();
266-
if (serviceSettings != null && serviceSettings.get(NUM_ALLOCATIONS) instanceof Integer numAllocations) {
262+
if (inferenceEntityId.equals(deploymentId) == false) {
263+
modelRegistry.getModel(deploymentId, ActionListener.wrap(unparsedModel -> {
264+
// if this deployment was created by another inference endpoint, then it must be updated using that inference endpoint
265+
listener.onFailure(
266+
new ElasticsearchStatusException(
267+
Messages.INFERENCE_REFERENCE_CANNOT_UPDATE_ANOTHER_ENDPOINT,
268+
RestStatus.CONFLICT,
269+
inferenceEntityId,
270+
deploymentId,
271+
unparsedModel.inferenceEntityId()
272+
)
273+
);
274+
}, e -> {
275+
if (e instanceof ResourceNotFoundException) {
276+
// if this deployment was created by the trained models API, then it must be updated by the trained models API
277+
listener.onFailure(
278+
new ElasticsearchStatusException(
279+
Messages.INFERENCE_CAN_ONLY_UPDATE_MODELS_IT_CREATED,
280+
RestStatus.CONFLICT,
281+
inferenceEntityId,
282+
deploymentId
283+
)
284+
);
285+
return;
286+
}
287+
listener.onFailure(e);
288+
}));
289+
return;
290+
}
291+
292+
if (newModel.getServiceSettings() instanceof ElasticsearchInternalServiceSettings elasticServiceSettings) {
267293

268-
UpdateTrainedModelDeploymentAction.Request updateRequest = new UpdateTrainedModelDeploymentAction.Request(deploymentId);
269-
updateRequest.setNumberOfAllocations(numAllocations);
294+
var updateRequest = new UpdateTrainedModelDeploymentAction.Request(deploymentId);
295+
updateRequest.setNumberOfAllocations(elasticServiceSettings.getNumAllocations());
296+
updateRequest.setAdaptiveAllocationsSettings(elasticServiceSettings.getAdaptiveAllocationsSettings());
297+
updateRequest.setIsInternal(true);
270298

271299
var delegate = listener.<CreateTrainedModelAssignmentAction.Response>delegateFailure((l2, response) -> {
272300
modelRegistry.updateModelTransaction(newModel, existingParsedModel, l2);
273301
});
274302

275303
logger.info(
276-
"Updating trained model deployment [{}] for inference entity [{}] with [{}] num_allocations",
304+
"Updating trained model deployment [{}] for inference entity [{}] with [{}] num_allocations and adaptive allocations [{}]",
277305
deploymentId,
278306
request.getInferenceEntityId(),
279-
numAllocations
307+
elasticServiceSettings.getNumAllocations(),
308+
elasticServiceSettings.getAdaptiveAllocationsSettings()
280309
);
281310
client.execute(UpdateTrainedModelDeploymentAction.INSTANCE, updateRequest, delegate);
282311

@@ -317,7 +346,6 @@ private void throwIfTrainedModelDoesntExist(String inferenceEntityId, String dep
317346
throw ExceptionsHelper.entityNotFoundException(
318347
Messages.MODEL_ID_DOES_NOT_MATCH_EXISTING_MODEL_IDS_BUT_MUST_FOR_IN_CLUSTER_SERVICE,
319348
inferenceEntityId
320-
321349
);
322350
}
323351
}

0 commit comments

Comments (0)