Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/130940.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 130940
summary: Block trained model updates from inference
area: Machine Learning
type: enhancement
issues:
- 129999
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,11 @@ public final class Messages {
public static final String MODEL_ID_DOES_NOT_MATCH_EXISTING_MODEL_IDS_BUT_MUST_FOR_IN_CLUSTER_SERVICE =
"Requested model ID [{}] does not have a matching trained model and thus cannot be updated.";
public static final String INFERENCE_ENTITY_NON_EXISTANT_NO_UPDATE = "The inference endpoint [{}] does not exist and cannot be updated";
public static final String INFERENCE_REFERENCE_CANNOT_UPDATE_ANOTHER_ENDPOINT =
"Cannot update inference [{}] for model deployment [{}] as it was created by another inference endpoint. "
+ "This model can only be updated using inference endpoint id [{}].";
public static final String INFERENCE_CAN_ONLY_UPDATE_MODELS_IT_CREATED =
"Cannot update inference [{}] for model deployment [{}]. This model deployment must be updated through the trained models API.";

private Messages() {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,24 +51,6 @@ public void testAttachToDeployment() throws IOException {
var results = infer(inferenceId, List.of("washing machine"));
assertNotNull(results.get("sparse_embedding"));

var updatedNumAllocations = randomIntBetween(1, 10);
var updatedEndpointConfig = updateEndpoint(inferenceId, updatedEndpointConfig(updatedNumAllocations), TaskType.SPARSE_EMBEDDING);
assertThat(
updatedEndpointConfig.get("service_settings"),
is(
Map.of(
"num_allocations",
updatedNumAllocations,
"num_threads",
1,
"model_id",
"attach_to_deployment",
"deployment_id",
"existing_deployment"
)
)
);

deleteModel(inferenceId);
// assert deployment not stopped
var stats = (List<Map<String, Object>>) getTrainedModelStats(modelId).get("trained_model_stats");
Expand Down Expand Up @@ -128,24 +110,6 @@ public void testAttachWithModelId() throws IOException {
var results = infer(inferenceId, List.of("washing machine"));
assertNotNull(results.get("sparse_embedding"));

var updatedNumAllocations = randomIntBetween(1, 10);
var updatedEndpointConfig = updateEndpoint(inferenceId, updatedEndpointConfig(updatedNumAllocations), TaskType.SPARSE_EMBEDDING);
assertThat(
updatedEndpointConfig.get("service_settings"),
is(
Map.of(
"num_allocations",
updatedNumAllocations,
"num_threads",
1,
"model_id",
"attach_with_model_id",
"deployment_id",
"existing_deployment_with_model_id"
)
)
);

forceStopMlNodeDeployment(deploymentId);
}

Expand Down Expand Up @@ -180,6 +144,30 @@ public void testDeploymentDoesNotExist() {
assertThat(e.getMessage(), containsString("Cannot find deployment [missing_deployment]"));
}

// Creating an inference endpoint whose ID collides with an existing trained model ID must be rejected.
public void testCreateInferenceUsingSameDeploymentId() throws IOException {
    // Model, deployment and inference endpoint all deliberately share one ID to provoke the conflict.
    var sharedId = "conflicting_ids";

    CustomElandModelIT.createMlNodeTextExpansionModel(sharedId, client());
    assertStatusOkOrCreated(startMlNodeDeploymemnt(sharedId, sharedId));

    var thrown = assertThrows(
        ResponseException.class,
        () -> putModel(sharedId, endpointConfig(sharedId), TaskType.SPARSE_EMBEDDING)
    );
    var expectedMessage = "Inference endpoint IDs must be unique. "
        + "Requested inference endpoint ID [conflicting_ids] matches existing trained model ID(s) but must not.";
    assertThat(thrown.getMessage(), containsString(expectedMessage));

    forceStopMlNodeDeployment(sharedId);
}

public void testNumAllocationsIsUpdated() throws IOException {
var modelId = "update_num_allocations";
var deploymentId = modelId;
Expand Down Expand Up @@ -208,7 +196,16 @@ public void testNumAllocationsIsUpdated() throws IOException {
)
);

assertStatusOkOrCreated(updateMlNodeDeploymemnt(deploymentId, 2));
var responseException = assertThrows(ResponseException.class, () -> updateInference(inferenceId, TaskType.SPARSE_EMBEDDING, 2));
assertThat(
responseException.getMessage(),
containsString(
"Cannot update inference [test_num_allocations_updated] for model deployment [update_num_allocations]. "
+ "This model deployment must be updated through the trained models API."
)
);

updateMlNodeDeploymemnt(deploymentId, 2);

var updatedServiceSettings = getModel(inferenceId).get("service_settings");
assertThat(
Expand All @@ -227,6 +224,92 @@ public void testNumAllocationsIsUpdated() throws IOException {
)
)
);

forceStopMlNodeDeployment(deploymentId);
}

/**
 * When an inference endpoint creates its own model deployment, updates must flow through the
 * inference API: updating via the inference API succeeds, while updating the same deployment
 * through the trained models API is rejected with a CONFLICT pointing back to the inference API.
 */
public void testUpdateWhenInferenceEndpointCreatesDeployment() throws IOException {
    var modelId = "update_num_allocations_from_created_endpoint";
    var inferenceId = "test_created_endpoint_from_model";
    // An endpoint-created deployment takes the inference endpoint's ID as its deployment ID.
    var deploymentId = inferenceId;

    CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client());

    var putModel = putModel(inferenceId, Strings.format("""
        {
          "service": "elasticsearch",
          "service_settings": {
            "num_allocations": %s,
            "num_threads": %s,
            "model_id": "%s"
          }
        }
        """, 1, 1, modelId), TaskType.SPARSE_EMBEDDING);
    var serviceSettings = putModel.get("service_settings");
    assertThat(putModel.toString(), serviceSettings, is(Map.of("num_allocations", 1, "num_threads", 1, "model_id", modelId)));

    // Fix: assert the update succeeded instead of silently discarding the response; sibling tests
    // (e.g. testCreateInferenceUsingSameDeploymentId) check the status of every successful call.
    assertStatusOkOrCreated(updateInference(inferenceId, TaskType.SPARSE_EMBEDDING, 2));

    // The trained models API must refuse to touch a deployment owned by an inference endpoint.
    var responseException = assertThrows(ResponseException.class, () -> updateMlNodeDeploymemnt(deploymentId, 2));
    assertThat(
        responseException.getMessage(),
        containsString(
            "Cannot update deployment [test_created_endpoint_from_model] as it was created by inference endpoint "
                + "[test_created_endpoint_from_model]. This model deployment must be updated through the inference API."
        )
    );

    // The inference-API update above must have taken effect despite the rejected trained-models update.
    var updatedServiceSettings = getModel(inferenceId).get("service_settings");
    assertThat(
        updatedServiceSettings.toString(),
        updatedServiceSettings,
        is(Map.of("num_allocations", 2, "num_threads", 1, "model_id", modelId))
    );

    forceStopMlNodeDeployment(deploymentId);
}

// A second endpoint attached to a deployment created by another inference endpoint must not be able to update it;
// only the creating endpoint may do so.
public void testCannotUpdateAnotherInferenceEndpointsCreatedDeployment() throws IOException {
    var modelId = "model_deployment_for_endpoint";
    var creatorEndpointId = "first_endpoint_for_model_deployment";
    // The creating endpoint's ID doubles as the deployment ID.
    var deploymentId = creatorEndpointId;

    CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client());

    // First endpoint creates the deployment.
    putModel(creatorEndpointId, Strings.format("""
        {
          "service": "elasticsearch",
          "service_settings": {
            "num_allocations": %s,
            "num_threads": %s,
            "model_id": "%s"
          }
        }
        """, 1, 1, modelId), TaskType.SPARSE_EMBEDDING);

    // Second endpoint merely attaches to that deployment.
    var attachedEndpointId = "second_endpoint_for_model_deployment";
    var attachResponse = putModel(attachedEndpointId, endpointConfig(deploymentId), TaskType.SPARSE_EMBEDDING);
    var attachedSettings = attachResponse.get("service_settings");
    assertThat(
        attachResponse.toString(),
        attachedSettings,
        is(Map.of("num_allocations", 1, "num_threads", 1, "model_id", modelId, "deployment_id", deploymentId))
    );

    var thrown = assertThrows(ResponseException.class, () -> updateInference(attachedEndpointId, TaskType.SPARSE_EMBEDDING, 2));
    var expectedMessage = "Cannot update inference [second_endpoint_for_model_deployment] for model deployment "
        + "[first_endpoint_for_model_deployment] as it was created by another inference endpoint. "
        + "This model can only be updated using inference endpoint id [first_endpoint_for_model_deployment].";
    assertThat(thrown.getMessage(), containsString(expectedMessage));

    forceStopMlNodeDeployment(deploymentId);
}

public void testStoppingDeploymentAttachedToInferenceEndpoint() throws IOException {
Expand Down Expand Up @@ -300,6 +383,22 @@ private Response startMlNodeDeploymemnt(String modelId, String deploymentId) thr
return client().performRequest(request);
}

/**
 * Updates the {@code num_allocations} service setting of an inference endpoint via the
 * inference {@code _update} API.
 *
 * @param inferenceId    the inference endpoint ID to update (renamed from the misleading
 *                       {@code deploymentId}: every caller passes an inference endpoint ID and
 *                       the request targets an {@code /_inference/} route, not a deployment)
 * @param taskType       the endpoint's task type, used as the first URL path segment
 * @param numAllocations the new number of allocations to request
 * @return the raw client response; callers decide whether to assert its status
 * @throws IOException if the request fails at the transport level
 */
private Response updateInference(String inferenceId, TaskType taskType, int numAllocations) throws IOException {
    String endPoint = Strings.format("/_inference/%s/%s/_update", taskType, inferenceId);

    var body = Strings.format("""
        {
          "service_settings": {
            "num_allocations": %d
          }
        }
        """, numAllocations);

    Request request = new Request("PUT", endPoint);
    request.setJsonEntity(body);
    return client().performRequest(request);
}

private Response updateMlNodeDeploymemnt(String deploymentId, int numAllocations) throws IOException {
String endPoint = "/_ml/trained_models/" + deploymentId + "/deployment/_update";

Expand All @@ -314,6 +413,16 @@ private Response updateMlNodeDeploymemnt(String deploymentId, int numAllocations
return client().performRequest(request);
}

// Raw-JSON-body overload of the trained-model deployment update: POSTs the given body, asserts
// the call succeeded and returns the parsed response map. (Name keeps the existing, misspelled
// "Deploymemnt" convention shared with the sibling helpers so all call sites stay consistent.)
private Map<String, Object> updateMlNodeDeploymemnt(String deploymentId, String body) throws IOException {
    var request = new Request("POST", "/_ml/trained_models/" + deploymentId + "/deployment/_update");
    request.setJsonEntity(body);

    var response = client().performRequest(request);
    assertStatusOkOrCreated(response);
    return entityAsMap(response);
}

protected void stopMlNodeDeployment(String deploymentId) throws IOException {
String endpoint = "/_ml/trained_models/" + deploymentId + "/deployment/_stop";
Request request = new Request("POST", endpoint);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;

Expand Down Expand Up @@ -224,13 +223,10 @@ private Model combineExistingModelWithNewSettings(
if (settingsToUpdate.serviceSettings() != null && existingSecretSettings != null) {
newSecretSettings = existingSecretSettings.newSecretSettings(settingsToUpdate.serviceSettings());
}
if (settingsToUpdate.serviceSettings() != null && settingsToUpdate.serviceSettings().containsKey(NUM_ALLOCATIONS)) {
// In cluster services can only have their num_allocations updated, so this is a special case
if (settingsToUpdate.serviceSettings() != null) {
// In cluster services can have their deployment settings updated, so this is a special case
if (newServiceSettings instanceof ElasticsearchInternalServiceSettings elasticServiceSettings) {
newServiceSettings = new ElasticsearchInternalServiceSettings(
elasticServiceSettings,
(Integer) settingsToUpdate.serviceSettings().get(NUM_ALLOCATIONS)
);
newServiceSettings = elasticServiceSettings.updateServiceSettings(settingsToUpdate.serviceSettings());
}
}
if (settingsToUpdate.taskSettings() != null && existingTaskSettings != null) {
Expand All @@ -257,26 +253,57 @@ private void updateInClusterEndpoint(
Model newModel,
Model existingParsedModel,
ActionListener<Boolean> listener
) throws IOException {
) {
// The model we are trying to update must have a trained model associated with it if it is an in-cluster deployment
var deploymentId = getDeploymentIdForInClusterEndpoint(existingParsedModel);
throwIfTrainedModelDoesntExist(request.getInferenceEntityId(), deploymentId);
var inferenceEntityId = request.getInferenceEntityId();
throwIfTrainedModelDoesntExist(inferenceEntityId, deploymentId);

if (inferenceEntityId.equals(deploymentId) == false) {
modelRegistry.getModel(deploymentId, ActionListener.wrap(unparsedModel -> {
// if this deployment was created by another inference endpoint, then it must be updated using that inference endpoint
listener.onFailure(
new ElasticsearchStatusException(
Messages.INFERENCE_REFERENCE_CANNOT_UPDATE_ANOTHER_ENDPOINT,
RestStatus.CONFLICT,
inferenceEntityId,
deploymentId,
unparsedModel.inferenceEntityId()
)
);
}, e -> {
if (e instanceof ResourceNotFoundException) {
// if this deployment was created by the trained models API, then it must be updated by the trained models API
listener.onFailure(
new ElasticsearchStatusException(
Messages.INFERENCE_CAN_ONLY_UPDATE_MODELS_IT_CREATED,
RestStatus.CONFLICT,
inferenceEntityId,
deploymentId
)
);
}
listener.onFailure(e);
}));
}

Map<String, Object> serviceSettings = request.getContentAsSettings().serviceSettings();
if (serviceSettings != null && serviceSettings.get(NUM_ALLOCATIONS) instanceof Integer numAllocations) {
if (newModel.getServiceSettings() instanceof ElasticsearchInternalServiceSettings elasticServiceSettings) {

UpdateTrainedModelDeploymentAction.Request updateRequest = new UpdateTrainedModelDeploymentAction.Request(deploymentId);
updateRequest.setNumberOfAllocations(numAllocations);
var updateRequest = new UpdateTrainedModelDeploymentAction.Request(deploymentId);
updateRequest.setNumberOfAllocations(elasticServiceSettings.getNumAllocations());
updateRequest.setAdaptiveAllocationsSettings(elasticServiceSettings.getAdaptiveAllocationsSettings());
updateRequest.setIsInternal(true);

var delegate = listener.<CreateTrainedModelAssignmentAction.Response>delegateFailure((l2, response) -> {
modelRegistry.updateModelTransaction(newModel, existingParsedModel, l2);
});

logger.info(
"Updating trained model deployment [{}] for inference entity [{}] with [{}] num_allocations",
"Updating trained model deployment [{}] for inference entity [{}] with [{}] num_allocations and adaptive allocations [{}]",
deploymentId,
request.getInferenceEntityId(),
numAllocations
elasticServiceSettings.getNumAllocations(),
elasticServiceSettings.getAdaptiveAllocationsSettings()
);
client.execute(UpdateTrainedModelDeploymentAction.INSTANCE, updateRequest, delegate);

Expand Down Expand Up @@ -317,7 +344,6 @@ private void throwIfTrainedModelDoesntExist(String inferenceEntityId, String dep
throw ExceptionsHelper.entityNotFoundException(
Messages.MODEL_ID_DOES_NOT_MATCH_EXISTING_MODEL_IDS_BUT_MUST_FOR_IN_CLUSTER_SERVICE,
inferenceEntityId

);
}
}
Expand Down
Loading
Loading