Skip to content

Commit bb15b04

Browse files
authored
[ML] Check for model deployment in inference endpoints before stopping (#129325) (#129909)
(cherry picked from commit 816caf7)
1 parent d9eb665 commit bb15b04

File tree

4 files changed

+163
-15
lines changed

4 files changed

+163
-15
lines changed

docs/changelog/129325.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 129325
2+
summary: Check for model deployment in inference endpoints before stopping
3+
area: Machine Learning
4+
type: bug
5+
issues:
6+
- 128549

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ public void testAttachToDeployment() throws IOException {
7575
var deploymentStats = stats.get(0).get("deployment_stats");
7676
assertNotNull(stats.toString(), deploymentStats);
7777

78-
stopMlNodeDeployment(deploymentId);
78+
forceStopMlNodeDeployment(deploymentId);
7979
}
8080

8181
public void testAttachWithModelId() throws IOException {
@@ -146,7 +146,7 @@ public void testAttachWithModelId() throws IOException {
146146
)
147147
);
148148

149-
stopMlNodeDeployment(deploymentId);
149+
forceStopMlNodeDeployment(deploymentId);
150150
}
151151

152152
public void testModelIdDoesNotMatch() throws IOException {
@@ -229,6 +229,29 @@ public void testNumAllocationsIsUpdated() throws IOException {
229229
);
230230
}
231231

232+
public void testStoppingDeploymentAttachedToInferenceEndpoint() throws IOException {
233+
var modelId = "try_stop_attach_to_deployment";
234+
var deploymentId = "test_stop_attach_to_deployment";
235+
236+
CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client());
237+
var response = startMlNodeDeploymemnt(modelId, deploymentId);
238+
assertStatusOkOrCreated(response);
239+
240+
var inferenceId = "test_stop_inference_on_existing_deployment";
241+
putModel(inferenceId, endpointConfig(deploymentId), TaskType.SPARSE_EMBEDDING);
242+
243+
var stopShouldNotSucceed = expectThrows(ResponseException.class, () -> stopMlNodeDeployment(deploymentId));
244+
assertThat(
245+
stopShouldNotSucceed.getMessage(),
246+
containsString(
247+
Strings.format("Cannot stop deployment [%s] as it is used by inference endpoint [%s]", deploymentId, inferenceId)
248+
)
249+
);
250+
251+
// Force stop will stop the deployment
252+
forceStopMlNodeDeployment(deploymentId);
253+
}
254+
232255
private String endpointConfig(String deploymentId) {
233256
return Strings.format("""
234257
{
@@ -292,6 +315,12 @@ private Response updateMlNodeDeploymemnt(String deploymentId, int numAllocations
292315
}
293316

294317
protected void stopMlNodeDeployment(String deploymentId) throws IOException {
318+
String endpoint = "/_ml/trained_models/" + deploymentId + "/deployment/_stop";
319+
Request request = new Request("POST", endpoint);
320+
client().performRequest(request);
321+
}
322+
323+
protected void forceStopMlNodeDeployment(String deploymentId) throws IOException {
295324
String endpoint = "/_ml/trained_models/" + deploymentId + "/deployment/_stop";
296325
Request request = new Request("POST", endpoint);
297326
request.addParameter("force", "true");

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CustomElandModelIT.java

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.inference;
99

1010
import org.elasticsearch.client.Request;
11+
import org.elasticsearch.client.ResponseException;
1112
import org.elasticsearch.client.RestClient;
1213
import org.elasticsearch.core.Strings;
1314
import org.elasticsearch.inference.TaskType;
@@ -18,6 +19,8 @@
1819
import java.util.List;
1920
import java.util.stream.Collectors;
2021

22+
import static org.hamcrest.Matchers.containsString;
23+
2124
public class CustomElandModelIT extends InferenceBaseRestTest {
2225

2326
// The model definition is taken from org.elasticsearch.xpack.ml.integration.TextExpansionQueryIT
@@ -92,6 +95,47 @@ public void testSparse() throws IOException {
9295
assertNotNull(results.get("sparse_embedding"));
9396
}
9497

98+
public void testCannotStopDeployment() throws IOException {
99+
String modelId = "custom-model-that-cannot-be-stopped";
100+
101+
createTextExpansionModel(modelId, client());
102+
putModelDefinition(modelId, BASE_64_ENCODED_MODEL, RAW_MODEL_SIZE, client());
103+
putVocabulary(
104+
List.of("these", "are", "my", "words", "the", "washing", "machine", "is", "leaking", "octopus", "comforter", "smells"),
105+
modelId,
106+
client()
107+
);
108+
109+
var inferenceConfig = """
110+
{
111+
"service": "elasticsearch",
112+
"service_settings": {
113+
"model_id": "custom-model-that-cannot-be-stopped",
114+
"num_allocations": 1,
115+
"num_threads": 1
116+
}
117+
}
118+
""";
119+
120+
var inferenceId = "sparse-inf";
121+
putModel(inferenceId, inferenceConfig, TaskType.SPARSE_EMBEDDING);
122+
infer(inferenceId, List.of("washing", "machine"));
123+
124+
// Stopping the deployment using the ML trained models API should fail
125+
// because the deployment was created by the inference endpoint API
126+
String stopEndpoint = org.elasticsearch.common.Strings.format("_ml/trained_models/%s/deployment/_stop?error_trace", inferenceId);
127+
Request stopRequest = new Request("POST", stopEndpoint);
128+
var e = expectThrows(ResponseException.class, () -> client().performRequest(stopRequest));
129+
assertThat(
130+
e.getMessage(),
131+
containsString("Cannot stop deployment [sparse-inf] as it was created by inference endpoint [sparse-inf]")
132+
);
133+
134+
// Force stop works
135+
String forceStopEndpoint = org.elasticsearch.common.Strings.format("_ml/trained_models/%s/deployment/_stop?force", inferenceId);
136+
assertStatusOkOrCreated(client().performRequest(new Request("POST", forceStopEndpoint)));
137+
}
138+
95139
static void createTextExpansionModel(String modelId, RestClient client) throws IOException {
96140
// with_special_tokens: false for this test with limited vocab
97141
Request request = new Request("PUT", "/_ml/trained_models/" + modelId);

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java

Lines changed: 82 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,25 @@
1717
import org.elasticsearch.action.TaskOperationFailure;
1818
import org.elasticsearch.action.support.ActionFilters;
1919
import org.elasticsearch.action.support.tasks.TransportTasksAction;
20+
import org.elasticsearch.client.internal.Client;
21+
import org.elasticsearch.client.internal.OriginSettingClient;
2022
import org.elasticsearch.cluster.ClusterState;
2123
import org.elasticsearch.cluster.node.DiscoveryNode;
2224
import org.elasticsearch.cluster.node.DiscoveryNodes;
2325
import org.elasticsearch.cluster.service.ClusterService;
2426
import org.elasticsearch.common.util.concurrent.EsExecutors;
27+
import org.elasticsearch.common.xcontent.XContentHelper;
2528
import org.elasticsearch.discovery.MasterNotDiscoveredException;
29+
import org.elasticsearch.inference.TaskType;
2630
import org.elasticsearch.ingest.IngestMetadata;
27-
import org.elasticsearch.ingest.IngestService;
2831
import org.elasticsearch.injection.guice.Inject;
2932
import org.elasticsearch.rest.RestStatus;
3033
import org.elasticsearch.tasks.CancellableTask;
3134
import org.elasticsearch.tasks.Task;
3235
import org.elasticsearch.transport.TransportResponseHandler;
3336
import org.elasticsearch.transport.TransportService;
37+
import org.elasticsearch.xcontent.XContentType;
38+
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
3439
import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction;
3540
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
3641
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata;
@@ -47,6 +52,7 @@
4752
import java.util.Set;
4853

4954
import static org.elasticsearch.core.Strings.format;
55+
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
5056
import static org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAction.getModelAliases;
5157

5258
/**
@@ -63,7 +69,7 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct
6369

6470
private static final Logger logger = LogManager.getLogger(TransportStopTrainedModelDeploymentAction.class);
6571

66-
private final IngestService ingestService;
72+
private final OriginSettingClient client;
6773
private final TrainedModelAssignmentClusterService trainedModelAssignmentClusterService;
6874
private final InferenceAuditor auditor;
6975

@@ -72,7 +78,7 @@ public TransportStopTrainedModelDeploymentAction(
7278
ClusterService clusterService,
7379
TransportService transportService,
7480
ActionFilters actionFilters,
75-
IngestService ingestService,
81+
Client client,
7682
TrainedModelAssignmentClusterService trainedModelAssignmentClusterService,
7783
InferenceAuditor auditor
7884
) {
@@ -85,7 +91,7 @@ public TransportStopTrainedModelDeploymentAction(
8591
StopTrainedModelDeploymentAction.Response::new,
8692
EsExecutors.DIRECT_EXECUTOR_SERVICE
8793
);
88-
this.ingestService = ingestService;
94+
this.client = new OriginSettingClient(client, ML_ORIGIN);
8995
this.trainedModelAssignmentClusterService = trainedModelAssignmentClusterService;
9096
this.auditor = Objects.requireNonNull(auditor);
9197
}
@@ -154,21 +160,84 @@ protected void doExecute(
154160

155161
// NOTE, should only run on Master node
156162
assert clusterService.localNode().isMasterNode();
163+
164+
if (request.isForce() == false) {
165+
checkIfUsedByInferenceEndpoint(
166+
request.getId(),
167+
ActionListener.wrap(canStop -> stopDeployment(task, request, maybeAssignment.get(), listener), listener::onFailure)
168+
);
169+
} else {
170+
stopDeployment(task, request, maybeAssignment.get(), listener);
171+
}
172+
}
173+
174+
private void stopDeployment(
175+
Task task,
176+
StopTrainedModelDeploymentAction.Request request,
177+
TrainedModelAssignment assignment,
178+
ActionListener<StopTrainedModelDeploymentAction.Response> listener
179+
) {
157180
trainedModelAssignmentClusterService.setModelAssignmentToStopping(
158181
request.getId(),
159-
ActionListener.wrap(
160-
setToStopping -> normalUndeploy(task, request.getId(), maybeAssignment.get(), request, listener),
161-
failure -> {
162-
if (ExceptionsHelper.unwrapCause(failure) instanceof ResourceNotFoundException) {
163-
listener.onResponse(new StopTrainedModelDeploymentAction.Response(true));
164-
return;
165-
}
166-
listener.onFailure(failure);
182+
ActionListener.wrap(setToStopping -> normalUndeploy(task, request.getId(), assignment, request, listener), failure -> {
183+
if (ExceptionsHelper.unwrapCause(failure) instanceof ResourceNotFoundException) {
184+
listener.onResponse(new StopTrainedModelDeploymentAction.Response(true));
185+
return;
167186
}
168-
)
187+
listener.onFailure(failure);
188+
})
169189
);
170190
}
171191

192+
private void checkIfUsedByInferenceEndpoint(String deploymentId, ActionListener<Boolean> listener) {
193+
194+
GetInferenceModelAction.Request getAllEndpoints = new GetInferenceModelAction.Request("*", TaskType.ANY);
195+
client.execute(GetInferenceModelAction.INSTANCE, getAllEndpoints, listener.delegateFailureAndWrap((l, response) -> {
196+
// filter by the ml node services
197+
var mlNodeEndpoints = response.getEndpoints()
198+
.stream()
199+
.filter(model -> model.getService().equals("elasticsearch") || model.getService().equals("elser"))
200+
.toList();
201+
202+
var endpointOwnsDeployment = mlNodeEndpoints.stream()
203+
.filter(model -> model.getInferenceEntityId().equals(deploymentId))
204+
.findFirst();
205+
if (endpointOwnsDeployment.isPresent()) {
206+
l.onFailure(
207+
new ElasticsearchStatusException(
208+
"Cannot stop deployment [{}] as it was created by inference endpoint [{}]",
209+
RestStatus.CONFLICT,
210+
deploymentId,
211+
endpointOwnsDeployment.get().getInferenceEntityId()
212+
)
213+
);
214+
return;
215+
}
216+
217+
// The inference endpoint may have been created by attaching to an existing deployment.
218+
for (var endpoint : mlNodeEndpoints) {
219+
var serviceSettingsXContent = XContentHelper.toXContent(endpoint.getServiceSettings(), XContentType.JSON, false);
220+
var settingsMap = XContentHelper.convertToMap(serviceSettingsXContent, false, XContentType.JSON).v2();
221+
// Endpoints with the deployment_id setting are attached to an existing deployment.
222+
var deploymentIdFromSettings = (String) settingsMap.get("deployment_id");
223+
if (deploymentIdFromSettings != null && deploymentIdFromSettings.equals(deploymentId)) {
224+
// The endpoint was created to use this deployment
225+
l.onFailure(
226+
new ElasticsearchStatusException(
227+
"Cannot stop deployment [{}] as it is used by inference endpoint [{}]",
228+
RestStatus.CONFLICT,
229+
deploymentId,
230+
endpoint.getInferenceEntityId()
231+
)
232+
);
233+
return;
234+
}
235+
}
236+
237+
l.onResponse(true);
238+
}));
239+
}
240+
172241
private void redirectToMasterNode(
173242
DiscoveryNode masterNode,
174243
StopTrainedModelDeploymentAction.Request request,

0 commit comments

Comments (0)