From c40e3e51fa9720ae536ef12545a19c8be7cadf52 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Thu, 12 Jun 2025 10:32:46 +0100 Subject: [PATCH 1/3] Check before stop --- .../inference/CreateFromDeploymentIT.java | 33 ++++++- .../xpack/inference/CustomElandModelIT.java | 44 +++++++++ ...sportStopTrainedModelDeploymentAction.java | 93 ++++++++++++++++--- 3 files changed, 155 insertions(+), 15 deletions(-) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java index 47f34fa486daf..5571729626fb9 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java @@ -75,7 +75,7 @@ public void testAttachToDeployment() throws IOException { var deploymentStats = stats.get(0).get("deployment_stats"); assertNotNull(stats.toString(), deploymentStats); - stopMlNodeDeployment(deploymentId); + forceStopMlNodeDeployment(deploymentId); } public void testAttachWithModelId() throws IOException { @@ -146,7 +146,7 @@ public void testAttachWithModelId() throws IOException { ) ); - stopMlNodeDeployment(deploymentId); + forceStopMlNodeDeployment(deploymentId); } public void testModelIdDoesNotMatch() throws IOException { @@ -229,6 +229,29 @@ public void testNumAllocationsIsUpdated() throws IOException { ); } + public void testStoppingDeploymentAttachedToInferenceEndpoint() throws IOException { + var modelId = "try_stop_attach_to_deployment"; + var deploymentId = "test_stop_attach_to_deployment"; + + CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client()); + var response = startMlNodeDeploymemnt(modelId, deploymentId); + assertStatusOkOrCreated(response); + + var inferenceId = "test_stop_inference_on_existing_deployment"; + putModel(inferenceId, endpointConfig(deploymentId), TaskType.SPARSE_EMBEDDING); + + var stopShouldNotSucceed = expectThrows(ResponseException.class, () -> stopMlNodeDeployment(deploymentId)); + assertThat( + stopShouldNotSucceed.getMessage(), + containsString( + Strings.format("Cannot stop deployment [%s] as it is used by inference endpoint [%s]", deploymentId, inferenceId) + ) + ); + + // Force stop will stop the deployment + forceStopMlNodeDeployment(deploymentId); + } + private String endpointConfig(String deploymentId) { return Strings.format(""" { @@ -292,6 +315,12 @@ private Response updateMlNodeDeploymemnt(String deploymentId, int numAllocations } protected void stopMlNodeDeployment(String deploymentId) throws IOException { + String endpoint = "/_ml/trained_models/" + deploymentId + "/deployment/_stop"; + Request request = new Request("POST", endpoint); + client().performRequest(request); + } + + protected void forceStopMlNodeDeployment(String deploymentId) throws IOException { String endpoint = "/_ml/trained_models/" + deploymentId + "/deployment/_stop"; Request request = new Request("POST", endpoint); request.addParameter("force", "true"); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CustomElandModelIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CustomElandModelIT.java index e6d959bafea3f..a8e409ff1fd0f 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CustomElandModelIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CustomElandModelIT.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference; import org.elasticsearch.client.Request; +import org.elasticsearch.client.ResponseException; import org.elasticsearch.client.RestClient; import org.elasticsearch.core.Strings; import org.elasticsearch.inference.TaskType; @@ -18,6 +19,8 @@ import java.util.List; import java.util.stream.Collectors; +import static org.hamcrest.Matchers.containsString; + public class CustomElandModelIT extends InferenceBaseRestTest { // The model definition is taken from org.elasticsearch.xpack.ml.integration.TextExpansionQueryIT @@ -92,6 +95,47 @@ public void testSparse() throws IOException { assertNotNull(results.get("sparse_embedding")); } + public void testCannotStopDeployment() throws IOException { + String modelId = "custom-model-that-cannot-be-stopped"; + + createTextExpansionModel(modelId, client()); + putModelDefinition(modelId, BASE_64_ENCODED_MODEL, RAW_MODEL_SIZE, client()); + putVocabulary( + List.of("these", "are", "my", "words", "the", "washing", "machine", "is", "leaking", "octopus", "comforter", "smells"), + modelId, + client() + ); + + var inferenceConfig = """ + { + "service": "elasticsearch", + "service_settings": { + "model_id": "custom-model-that-cannot-be-stopped", + "num_allocations": 1, + "num_threads": 1 + } + } + """; + + var inferenceId = "sparse-inf"; + putModel(inferenceId, inferenceConfig, TaskType.SPARSE_EMBEDDING); + infer(inferenceId, List.of("washing", "machine")); + + // Stopping the deployment using the ML trained models API should fail + // because the deployment was created by the inference endpoint API + String stopEndpoint = org.elasticsearch.common.Strings.format("_ml/trained_models/%s/deployment/_stop?error_trace", inferenceId); + Request stopRequest = new Request("POST", stopEndpoint); + var e = expectThrows(ResponseException.class, () -> client().performRequest(stopRequest)); + assertThat( + e.getMessage(), + containsString("Cannot stop deployment [sparse-inf] as it was created by inference endpoint [sparse-inf]") + ); + + // Force stop works + String forceStopEndpoint = org.elasticsearch.common.Strings.format("_ml/trained_models/%s/deployment/_stop?force", inferenceId); + assertStatusOkOrCreated(client().performRequest(new Request("POST", forceStopEndpoint))); + } + static void createTextExpansionModel(String modelId, RestClient client) throws IOException { // with_special_tokens: false for this test with limited vocab Request request = new Request("PUT", "/_ml/trained_models/" + modelId); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java index 483295de89ceb..612809950254a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java @@ -17,20 +17,24 @@ import org.elasticsearch.action.TaskOperationFailure; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.tasks.TransportTasksAction; +import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.discovery.MasterNotDiscoveredException; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.ingest.IngestMetadata; -import org.elasticsearch.ingest.IngestService; import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportResponseHandler; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata; @@ -63,7 +67,7 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct private static final Logger logger = LogManager.getLogger(TransportStopTrainedModelDeploymentAction.class); - private final IngestService ingestService; + private final Client client; private final TrainedModelAssignmentClusterService trainedModelAssignmentClusterService; private final InferenceAuditor auditor; @@ -72,7 +76,7 @@ public TransportStopTrainedModelDeploymentAction( ClusterService clusterService, TransportService transportService, ActionFilters actionFilters, - IngestService ingestService, + Client client, TrainedModelAssignmentClusterService trainedModelAssignmentClusterService, InferenceAuditor auditor ) { @@ -85,7 +89,7 @@ public TransportStopTrainedModelDeploymentAction( StopTrainedModelDeploymentAction.Response::new, EsExecutors.DIRECT_EXECUTOR_SERVICE ); - this.ingestService = ingestService; + this.client = client; this.trainedModelAssignmentClusterService = trainedModelAssignmentClusterService; this.auditor = Objects.requireNonNull(auditor); } @@ -154,21 +158,84 @@ protected void doExecute( // NOTE, should only run on Master node assert clusterService.localNode().isMasterNode(); + + if (request.isForce() == false) { + checkIfUsedByInferenceEndpoint( + request.getId(), + ActionListener.wrap(canStop -> stopDeployment(task, request, maybeAssignment.get(), listener), listener::onFailure) + ); + } else { + stopDeployment(task, request, maybeAssignment.get(), listener); + } + } + + private void stopDeployment( + Task task, + StopTrainedModelDeploymentAction.Request request, + TrainedModelAssignment assignment, + ActionListener listener + ) { trainedModelAssignmentClusterService.setModelAssignmentToStopping( request.getId(), - ActionListener.wrap( - setToStopping -> normalUndeploy(task, request.getId(), maybeAssignment.get(), request, listener), - failure -> { - if (ExceptionsHelper.unwrapCause(failure) instanceof ResourceNotFoundException) { - listener.onResponse(new StopTrainedModelDeploymentAction.Response(true)); - return; - } - listener.onFailure(failure); + ActionListener.wrap(setToStopping -> normalUndeploy(task, request.getId(), assignment, request, listener), failure -> { + if (ExceptionsHelper.unwrapCause(failure) instanceof ResourceNotFoundException) { + listener.onResponse(new StopTrainedModelDeploymentAction.Response(true)); + return; } - ) + listener.onFailure(failure); + }) ); } + private void checkIfUsedByInferenceEndpoint(String deploymentId, ActionListener listener) { + + GetInferenceModelAction.Request getAllEndpoints = new GetInferenceModelAction.Request("*", TaskType.ANY); + client.execute(GetInferenceModelAction.INSTANCE, getAllEndpoints, listener.delegateFailureAndWrap((l, response) -> { + // filter by the ml node services + var mlNodeEndpoints = response.getEndpoints() + .stream() + .filter(model -> model.getService().equals("elasticsearch") || model.getService().equals("elser")) + .toList(); + + var endpointOwnsDeployment = mlNodeEndpoints.stream() + .filter(model -> model.getInferenceEntityId().equals(deploymentId)) + .findFirst(); + if (endpointOwnsDeployment.isPresent()) { + l.onFailure( + new ElasticsearchStatusException( + "Cannot stop deployment [{}] as it was created by inference endpoint [{}]", + RestStatus.CONFLICT, + deploymentId, + endpointOwnsDeployment.get().getInferenceEntityId() + ) + ); + return; + } + + // The inference endpoint may have been created by attaching to an existing deployment. + for (var endpoint : mlNodeEndpoints) { + var serviceSettingsXContent = XContentHelper.toXContent(endpoint.getServiceSettings(), XContentType.JSON, false); + var settingsMap = XContentHelper.convertToMap(serviceSettingsXContent, false, XContentType.JSON).v2(); + // Endpoints with the deployment_id setting are attached to an existing deployment. + var deploymentIdFromSettings = (String) settingsMap.get("deployment_id"); + if (deploymentIdFromSettings != null && deploymentIdFromSettings.equals(deploymentId)) { + // The endpoint was created to use this deployment + l.onFailure( + new ElasticsearchStatusException( + "Cannot stop deployment [{}] as it is used by inference endpoint [{}]", + RestStatus.CONFLICT, + deploymentId, + endpoint.getInferenceEntityId() + ) + ); + return; + } + } + + l.onResponse(true); + })); + } + private void redirectToMasterNode( DiscoveryNode masterNode, StopTrainedModelDeploymentAction.Request request, From f87ac853bda09c422c7637df9a6563e7a5ef152b Mon Sep 17 00:00:00 2001 From: David Kyle Date: Thu, 12 Jun 2025 11:00:15 +0100 Subject: [PATCH 2/3] Update docs/changelog/129325.yaml --- docs/changelog/129325.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 docs/changelog/129325.yaml diff --git a/docs/changelog/129325.yaml b/docs/changelog/129325.yaml new file mode 100644 index 0000000000000..cbbb309dceee0 --- /dev/null +++ b/docs/changelog/129325.yaml @@ -0,0 +1,6 @@ +pr: 129325 +summary: Check for model deployment in inference endpoints before stopping +area: Machine Learning +type: bug +issues: + - 128549 From 255019deb7ed00cc78c9dbd77ebff0aaa6bb139f Mon Sep 17 00:00:00 2001 From: David Kyle Date: Thu, 12 Jun 2025 12:07:15 +0100 Subject: [PATCH 3/3] use internal ml user --- .../action/TransportStopTrainedModelDeploymentAction.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java index 612809950254a..188bdbfecc55e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java @@ -18,6 +18,7 @@ import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.tasks.TransportTasksAction; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.client.internal.OriginSettingClient; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodes; @@ -51,6 +52,7 @@ import java.util.Set; import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAction.getModelAliases; /** @@ -67,7 +69,7 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct private static final Logger logger = LogManager.getLogger(TransportStopTrainedModelDeploymentAction.class); - private final Client client; + private final OriginSettingClient client; private final TrainedModelAssignmentClusterService trainedModelAssignmentClusterService; private final InferenceAuditor auditor; @@ -89,7 +91,7 @@ public TransportStopTrainedModelDeploymentAction( StopTrainedModelDeploymentAction.Response::new, EsExecutors.DIRECT_EXECUTOR_SERVICE ); - this.client = client; + this.client = new OriginSettingClient(client, ML_ORIGIN); this.trainedModelAssignmentClusterService = trainedModelAssignmentClusterService; this.auditor = Objects.requireNonNull(auditor); }