Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/130940.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 130940
summary: Block trained model updates from inference
area: Machine Learning
type: enhancement
issues:
- 129999
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,12 @@ public final class Messages {
public static final String MODEL_ID_DOES_NOT_MATCH_EXISTING_MODEL_IDS_BUT_MUST_FOR_IN_CLUSTER_SERVICE =
"Requested model ID [{}] does not have a matching trained model and thus cannot be updated.";
public static final String INFERENCE_ENTITY_NON_EXISTANT_NO_UPDATE = "The inference endpoint [{}] does not exist and cannot be updated";
public static final String INFERENCE_REFERENCE_CANNOT_UPDATE_ANOTHER_ENDPOINT =
"Cannot update inference endpoint [{}] for model deployment [{}] as it was created by another inference endpoint. "
+ "The model can only be updated using inference endpoint id [{}].";
public static final String INFERENCE_CAN_ONLY_UPDATE_MODELS_IT_CREATED =
"Cannot update inference endpoint [{}] using model deployment [{}]. "
+ "The model deployment must be updated through the trained models API.";

private Messages() {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,24 +51,6 @@ public void testAttachToDeployment() throws IOException {
var results = infer(inferenceId, List.of("washing machine"));
assertNotNull(results.get("sparse_embedding"));

var updatedNumAllocations = randomIntBetween(1, 10);
var updatedEndpointConfig = updateEndpoint(inferenceId, updatedEndpointConfig(updatedNumAllocations), TaskType.SPARSE_EMBEDDING);
assertThat(
updatedEndpointConfig.get("service_settings"),
is(
Map.of(
"num_allocations",
updatedNumAllocations,
"num_threads",
1,
"model_id",
"attach_to_deployment",
"deployment_id",
"existing_deployment"
)
)
);

deleteModel(inferenceId);
// assert deployment not stopped
var stats = (List<Map<String, Object>>) getTrainedModelStats(modelId).get("trained_model_stats");
Expand Down Expand Up @@ -128,24 +110,6 @@ public void testAttachWithModelId() throws IOException {
var results = infer(inferenceId, List.of("washing machine"));
assertNotNull(results.get("sparse_embedding"));

var updatedNumAllocations = randomIntBetween(1, 10);
var updatedEndpointConfig = updateEndpoint(inferenceId, updatedEndpointConfig(updatedNumAllocations), TaskType.SPARSE_EMBEDDING);
assertThat(
updatedEndpointConfig.get("service_settings"),
is(
Map.of(
"num_allocations",
updatedNumAllocations,
"num_threads",
1,
"model_id",
"attach_with_model_id",
"deployment_id",
"existing_deployment_with_model_id"
)
)
);

forceStopMlNodeDeployment(deploymentId);
}

Expand Down Expand Up @@ -180,6 +144,30 @@ public void testDeploymentDoesNotExist() {
assertThat(e.getMessage(), containsString("Cannot find deployment [missing_deployment]"));
}

/**
 * Creating an inference endpoint whose ID collides with an existing trained model ID
 * must be rejected: the endpoint ID namespace is required to be unique with respect
 * to trained model IDs.
 */
public void testCreateInferenceUsingSameDeploymentId() throws IOException {
    // One identifier deliberately reused for model, deployment and inference endpoint.
    var sharedId = "conflicting_ids";

    CustomElandModelIT.createMlNodeTextExpansionModel(sharedId, client());
    assertStatusOkOrCreated(startMlNodeDeploymemnt(sharedId, sharedId));

    // Attempting to create an inference endpoint with the same ID must fail.
    var thrown = assertThrows(
        ResponseException.class,
        () -> putModel(sharedId, endpointConfig(sharedId), TaskType.SPARSE_EMBEDDING)
    );
    assertThat(
        thrown.getMessage(),
        containsString(
            "Inference endpoint IDs must be unique. "
                + "Requested inference endpoint ID [conflicting_ids] matches existing trained model ID(s) but must not."
        )
    );

    forceStopMlNodeDeployment(sharedId);
}

public void testNumAllocationsIsUpdated() throws IOException {
var modelId = "update_num_allocations";
var deploymentId = modelId;
Expand Down Expand Up @@ -208,7 +196,16 @@ public void testNumAllocationsIsUpdated() throws IOException {
)
);

assertStatusOkOrCreated(updateMlNodeDeploymemnt(deploymentId, 2));
var responseException = assertThrows(ResponseException.class, () -> updateInference(inferenceId, TaskType.SPARSE_EMBEDDING, 2));
assertThat(
responseException.getMessage(),
containsString(
"Cannot update inference endpoint [test_num_allocations_updated] using model deployment [update_num_allocations]. "
+ "The model deployment must be updated through the trained models API."
)
);

updateMlNodeDeploymemnt(deploymentId, 2);

var updatedServiceSettings = getModel(inferenceId).get("service_settings");
assertThat(
Expand All @@ -227,6 +224,92 @@ public void testNumAllocationsIsUpdated() throws IOException {
)
)
);

forceStopMlNodeDeployment(deploymentId);
}

/**
 * When an inference endpoint creates its own model deployment (no {@code deployment_id}
 * supplied in the service settings), allocation updates must flow through the inference
 * API: the inference {@code _update} call succeeds, while updating the same deployment
 * through the trained models API is rejected.
 */
public void testUpdateWhenInferenceEndpointCreatesDeployment() throws IOException {
var modelId = "update_num_allocations_from_created_endpoint";
var inferenceId = "test_created_endpoint_from_model";
// The endpoint creates the deployment itself, so the deployment ID equals the inference ID.
var deploymentId = inferenceId;

CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client());

// No deployment_id in the body: the endpoint starts its own deployment for modelId.
var putModel = putModel(inferenceId, Strings.format("""
{
"service": "elasticsearch",
"service_settings": {
"num_allocations": %s,
"num_threads": %s,
"model_id": "%s"
}
}
""", 1, 1, modelId), TaskType.SPARSE_EMBEDDING);
var serviceSettings = putModel.get("service_settings");
assertThat(putModel.toString(), serviceSettings, is(Map.of("num_allocations", 1, "num_threads", 1, "model_id", modelId)));

// Updating via the inference API is the supported path and must succeed.
updateInference(inferenceId, TaskType.SPARSE_EMBEDDING, 2);

// Updating the same deployment via the trained models API must be rejected.
var responseException = assertThrows(ResponseException.class, () -> updateMlNodeDeploymemnt(deploymentId, 2));
assertThat(
responseException.getMessage(),
containsString(
"Cannot update deployment [test_created_endpoint_from_model] as it was created by inference endpoint "
+ "[test_created_endpoint_from_model]. This model deployment must be updated through the inference API."
)
);

// The inference-API update above must be reflected in the endpoint's service settings.
var updatedServiceSettings = getModel(inferenceId).get("service_settings");
assertThat(
updatedServiceSettings.toString(),
updatedServiceSettings,
is(Map.of("num_allocations", 2, "num_threads", 1, "model_id", modelId))
);

forceStopMlNodeDeployment(deploymentId);
}

/**
 * A deployment created by one inference endpoint cannot be updated through a second
 * endpoint that merely attaches to it: the update must be issued against the endpoint
 * that originally created the deployment.
 */
public void testCannotUpdateAnotherInferenceEndpointsCreatedDeployment() throws IOException {
var modelId = "model_deployment_for_endpoint";
var inferenceId = "first_endpoint_for_model_deployment";
// The first endpoint creates the deployment, so the deployment ID equals that endpoint's ID.
var deploymentId = inferenceId;

CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client());

// First endpoint: no deployment_id supplied, so it creates its own deployment.
putModel(inferenceId, Strings.format("""
{
"service": "elasticsearch",
"service_settings": {
"num_allocations": %s,
"num_threads": %s,
"model_id": "%s"
}
}
""", 1, 1, modelId), TaskType.SPARSE_EMBEDDING);

// Second endpoint attaches to the deployment created by the first endpoint.
var secondInferenceId = "second_endpoint_for_model_deployment";
var putModel = putModel(secondInferenceId, endpointConfig(deploymentId), TaskType.SPARSE_EMBEDDING);
var serviceSettings = putModel.get("service_settings");
assertThat(
putModel.toString(),
serviceSettings,
is(Map.of("num_allocations", 1, "num_threads", 1, "model_id", modelId, "deployment_id", deploymentId))
);

// Updating through the second (attaching) endpoint must be rejected.
var responseException = assertThrows(
ResponseException.class,
() -> updateInference(secondInferenceId, TaskType.SPARSE_EMBEDDING, 2)
);
assertThat(
responseException.getMessage(),
containsString(
"Cannot update inference endpoint [second_endpoint_for_model_deployment] for model deployment "
+ "[first_endpoint_for_model_deployment] as it was created by another inference endpoint. "
+ "The model can only be updated using inference endpoint id [first_endpoint_for_model_deployment]."
)
);

forceStopMlNodeDeployment(deploymentId);
}

public void testStoppingDeploymentAttachedToInferenceEndpoint() throws IOException {
Expand Down Expand Up @@ -300,6 +383,22 @@ private Response startMlNodeDeploymemnt(String modelId, String deploymentId) thr
return client().performRequest(request);
}

/**
 * Updates the {@code num_allocations} service setting of an inference endpoint via the
 * inference {@code _update} API.
 *
 * @param inferenceId    the inference endpoint ID to update (note: this is an endpoint ID,
 *                       not a trained model deployment ID — callers pass endpoint IDs)
 * @param taskType       the task type path segment of the inference URL
 * @param numAllocations the requested number of allocations
 * @return the raw HTTP response
 * @throws IOException if the request cannot be executed
 */
private Response updateInference(String inferenceId, TaskType taskType, int numAllocations) throws IOException {
    String endPoint = Strings.format("/_inference/%s/%s/_update", taskType, inferenceId);

    var body = Strings.format("""
        {
          "service_settings": {
            "num_allocations": %d
          }
        }
        """, numAllocations);

    Request request = new Request("PUT", endPoint);
    request.setJsonEntity(body);
    return client().performRequest(request);
}

private Response updateMlNodeDeploymemnt(String deploymentId, int numAllocations) throws IOException {
String endPoint = "/_ml/trained_models/" + deploymentId + "/deployment/_update";

Expand All @@ -314,6 +413,16 @@ private Response updateMlNodeDeploymemnt(String deploymentId, int numAllocations
return client().performRequest(request);
}

/**
 * Posts an arbitrary JSON body to the trained models deployment {@code _update} API,
 * asserting a 2xx response and returning the parsed response body as a map.
 */
private Map<String, Object> updateMlNodeDeploymemnt(String deploymentId, String body) throws IOException {
    var updateRequest = new Request("POST", "/_ml/trained_models/" + deploymentId + "/deployment/_update");
    updateRequest.setJsonEntity(body);

    var updateResponse = client().performRequest(updateRequest);
    assertStatusOkOrCreated(updateResponse);
    return entityAsMap(updateResponse);
}

protected void stopMlNodeDeployment(String deploymentId) throws IOException {
String endpoint = "/_ml/trained_models/" + deploymentId + "/deployment/_stop";
Request request = new Request("POST", endpoint);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;

Expand Down Expand Up @@ -224,13 +223,10 @@ private Model combineExistingModelWithNewSettings(
if (settingsToUpdate.serviceSettings() != null && existingSecretSettings != null) {
newSecretSettings = existingSecretSettings.newSecretSettings(settingsToUpdate.serviceSettings());
}
if (settingsToUpdate.serviceSettings() != null && settingsToUpdate.serviceSettings().containsKey(NUM_ALLOCATIONS)) {
// In cluster services can only have their num_allocations updated, so this is a special case
if (settingsToUpdate.serviceSettings() != null) {
// In cluster services can have their deployment settings updated, so this is a special case
if (newServiceSettings instanceof ElasticsearchInternalServiceSettings elasticServiceSettings) {
newServiceSettings = new ElasticsearchInternalServiceSettings(
elasticServiceSettings,
(Integer) settingsToUpdate.serviceSettings().get(NUM_ALLOCATIONS)
);
newServiceSettings = elasticServiceSettings.updateServiceSettings(settingsToUpdate.serviceSettings());
}
}
if (settingsToUpdate.taskSettings() != null && existingTaskSettings != null) {
Expand All @@ -257,26 +253,59 @@ private void updateInClusterEndpoint(
Model newModel,
Model existingParsedModel,
ActionListener<Boolean> listener
) throws IOException {
) {
// The model we are trying to update must have a trained model associated with it if it is an in-cluster deployment
var deploymentId = getDeploymentIdForInClusterEndpoint(existingParsedModel);
throwIfTrainedModelDoesntExist(request.getInferenceEntityId(), deploymentId);
var inferenceEntityId = request.getInferenceEntityId();
throwIfTrainedModelDoesntExist(inferenceEntityId, deploymentId);

Map<String, Object> serviceSettings = request.getContentAsSettings().serviceSettings();
if (serviceSettings != null && serviceSettings.get(NUM_ALLOCATIONS) instanceof Integer numAllocations) {
if (inferenceEntityId.equals(deploymentId) == false) {
modelRegistry.getModel(deploymentId, ActionListener.wrap(unparsedModel -> {
// if this deployment was created by another inference endpoint, then it must be updated using that inference endpoint
listener.onFailure(
new ElasticsearchStatusException(
Messages.INFERENCE_REFERENCE_CANNOT_UPDATE_ANOTHER_ENDPOINT,
RestStatus.CONFLICT,
inferenceEntityId,
deploymentId,
unparsedModel.inferenceEntityId()
)
);
}, e -> {
if (e instanceof ResourceNotFoundException) {
// if this deployment was created by the trained models API, then it must be updated by the trained models API
listener.onFailure(
new ElasticsearchStatusException(
Messages.INFERENCE_CAN_ONLY_UPDATE_MODELS_IT_CREATED,
RestStatus.CONFLICT,
inferenceEntityId,
deploymentId
)
);
return;
}
listener.onFailure(e);
}));
return;
}

if (newModel.getServiceSettings() instanceof ElasticsearchInternalServiceSettings elasticServiceSettings) {

UpdateTrainedModelDeploymentAction.Request updateRequest = new UpdateTrainedModelDeploymentAction.Request(deploymentId);
updateRequest.setNumberOfAllocations(numAllocations);
var updateRequest = new UpdateTrainedModelDeploymentAction.Request(deploymentId);
updateRequest.setNumberOfAllocations(elasticServiceSettings.getNumAllocations());
updateRequest.setAdaptiveAllocationsSettings(elasticServiceSettings.getAdaptiveAllocationsSettings());
updateRequest.setIsInternal(true);

var delegate = listener.<CreateTrainedModelAssignmentAction.Response>delegateFailure((l2, response) -> {
modelRegistry.updateModelTransaction(newModel, existingParsedModel, l2);
});

logger.info(
"Updating trained model deployment [{}] for inference entity [{}] with [{}] num_allocations",
"Updating trained model deployment [{}] for inference entity [{}] with [{}] num_allocations and adaptive allocations [{}]",
deploymentId,
request.getInferenceEntityId(),
numAllocations
elasticServiceSettings.getNumAllocations(),
elasticServiceSettings.getAdaptiveAllocationsSettings()
);
client.execute(UpdateTrainedModelDeploymentAction.INSTANCE, updateRequest, delegate);

Expand Down Expand Up @@ -317,7 +346,6 @@ private void throwIfTrainedModelDoesntExist(String inferenceEntityId, String dep
throw ExceptionsHelper.entityNotFoundException(
Messages.MODEL_ID_DOES_NOT_MATCH_EXISTING_MODEL_IDS_BUT_MUST_FOR_IN_CLUSTER_SERVICE,
inferenceEntityId

);
}
}
Expand Down
Loading
Loading