Skip to content

Commit c40e3e5

Browse files
committed
Check before stop
1 parent 916cd05 commit c40e3e5

File tree

3 files changed

+155
-15
lines changed

3 files changed

+155
-15
lines changed

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ public void testAttachToDeployment() throws IOException {
7575
var deploymentStats = stats.get(0).get("deployment_stats");
7676
assertNotNull(stats.toString(), deploymentStats);
7777

78-
stopMlNodeDeployment(deploymentId);
78+
forceStopMlNodeDeployment(deploymentId);
7979
}
8080

8181
public void testAttachWithModelId() throws IOException {
@@ -146,7 +146,7 @@ public void testAttachWithModelId() throws IOException {
146146
)
147147
);
148148

149-
stopMlNodeDeployment(deploymentId);
149+
forceStopMlNodeDeployment(deploymentId);
150150
}
151151

152152
public void testModelIdDoesNotMatch() throws IOException {
@@ -229,6 +229,29 @@ public void testNumAllocationsIsUpdated() throws IOException {
229229
);
230230
}
231231

232+
public void testStoppingDeploymentAttachedToInferenceEndpoint() throws IOException {
233+
var modelId = "try_stop_attach_to_deployment";
234+
var deploymentId = "test_stop_attach_to_deployment";
235+
236+
CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client());
237+
var response = startMlNodeDeploymemnt(modelId, deploymentId);
238+
assertStatusOkOrCreated(response);
239+
240+
var inferenceId = "test_stop_inference_on_existing_deployment";
241+
putModel(inferenceId, endpointConfig(deploymentId), TaskType.SPARSE_EMBEDDING);
242+
243+
var stopShouldNotSucceed = expectThrows(ResponseException.class, () -> stopMlNodeDeployment(deploymentId));
244+
assertThat(
245+
stopShouldNotSucceed.getMessage(),
246+
containsString(
247+
Strings.format("Cannot stop deployment [%s] as it is used by inference endpoint [%s]", deploymentId, inferenceId)
248+
)
249+
);
250+
251+
// Force stop will stop the deployment
252+
forceStopMlNodeDeployment(deploymentId);
253+
}
254+
232255
private String endpointConfig(String deploymentId) {
233256
return Strings.format("""
234257
{
@@ -292,6 +315,12 @@ private Response updateMlNodeDeploymemnt(String deploymentId, int numAllocations
292315
}
293316

294317
protected void stopMlNodeDeployment(String deploymentId) throws IOException {
318+
String endpoint = "/_ml/trained_models/" + deploymentId + "/deployment/_stop";
319+
Request request = new Request("POST", endpoint);
320+
client().performRequest(request);
321+
}
322+
323+
protected void forceStopMlNodeDeployment(String deploymentId) throws IOException {
295324
String endpoint = "/_ml/trained_models/" + deploymentId + "/deployment/_stop";
296325
Request request = new Request("POST", endpoint);
297326
request.addParameter("force", "true");

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CustomElandModelIT.java

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.inference;
99

1010
import org.elasticsearch.client.Request;
11+
import org.elasticsearch.client.ResponseException;
1112
import org.elasticsearch.client.RestClient;
1213
import org.elasticsearch.core.Strings;
1314
import org.elasticsearch.inference.TaskType;
@@ -18,6 +19,8 @@
1819
import java.util.List;
1920
import java.util.stream.Collectors;
2021

22+
import static org.hamcrest.Matchers.containsString;
23+
2124
public class CustomElandModelIT extends InferenceBaseRestTest {
2225

2326
// The model definition is taken from org.elasticsearch.xpack.ml.integration.TextExpansionQueryIT
@@ -92,6 +95,47 @@ public void testSparse() throws IOException {
9295
assertNotNull(results.get("sparse_embedding"));
9396
}
9497

98+
public void testCannotStopDeployment() throws IOException {
99+
String modelId = "custom-model-that-cannot-be-stopped";
100+
101+
createTextExpansionModel(modelId, client());
102+
putModelDefinition(modelId, BASE_64_ENCODED_MODEL, RAW_MODEL_SIZE, client());
103+
putVocabulary(
104+
List.of("these", "are", "my", "words", "the", "washing", "machine", "is", "leaking", "octopus", "comforter", "smells"),
105+
modelId,
106+
client()
107+
);
108+
109+
var inferenceConfig = """
110+
{
111+
"service": "elasticsearch",
112+
"service_settings": {
113+
"model_id": "custom-model-that-cannot-be-stopped",
114+
"num_allocations": 1,
115+
"num_threads": 1
116+
}
117+
}
118+
""";
119+
120+
var inferenceId = "sparse-inf";
121+
putModel(inferenceId, inferenceConfig, TaskType.SPARSE_EMBEDDING);
122+
infer(inferenceId, List.of("washing", "machine"));
123+
124+
// Stopping the deployment using the ML trained models API should fail
125+
// because the deployment was created by the inference endpoint API
126+
String stopEndpoint = org.elasticsearch.common.Strings.format("_ml/trained_models/%s/deployment/_stop?error_trace", inferenceId);
127+
Request stopRequest = new Request("POST", stopEndpoint);
128+
var e = expectThrows(ResponseException.class, () -> client().performRequest(stopRequest));
129+
assertThat(
130+
e.getMessage(),
131+
containsString("Cannot stop deployment [sparse-inf] as it was created by inference endpoint [sparse-inf]")
132+
);
133+
134+
// Force stop works
135+
String forceStopEndpoint = org.elasticsearch.common.Strings.format("_ml/trained_models/%s/deployment/_stop?force", inferenceId);
136+
assertStatusOkOrCreated(client().performRequest(new Request("POST", forceStopEndpoint)));
137+
}
138+
95139
static void createTextExpansionModel(String modelId, RestClient client) throws IOException {
96140
// with_special_tokens: false for this test with limited vocab
97141
Request request = new Request("PUT", "/_ml/trained_models/" + modelId);

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java

Lines changed: 80 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,24 @@
1717
import org.elasticsearch.action.TaskOperationFailure;
1818
import org.elasticsearch.action.support.ActionFilters;
1919
import org.elasticsearch.action.support.tasks.TransportTasksAction;
20+
import org.elasticsearch.client.internal.Client;
2021
import org.elasticsearch.cluster.ClusterState;
2122
import org.elasticsearch.cluster.node.DiscoveryNode;
2223
import org.elasticsearch.cluster.node.DiscoveryNodes;
2324
import org.elasticsearch.cluster.service.ClusterService;
2425
import org.elasticsearch.common.util.concurrent.EsExecutors;
26+
import org.elasticsearch.common.xcontent.XContentHelper;
2527
import org.elasticsearch.discovery.MasterNotDiscoveredException;
28+
import org.elasticsearch.inference.TaskType;
2629
import org.elasticsearch.ingest.IngestMetadata;
27-
import org.elasticsearch.ingest.IngestService;
2830
import org.elasticsearch.injection.guice.Inject;
2931
import org.elasticsearch.rest.RestStatus;
3032
import org.elasticsearch.tasks.CancellableTask;
3133
import org.elasticsearch.tasks.Task;
3234
import org.elasticsearch.transport.TransportResponseHandler;
3335
import org.elasticsearch.transport.TransportService;
36+
import org.elasticsearch.xcontent.XContentType;
37+
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
3438
import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction;
3539
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
3640
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata;
@@ -63,7 +67,7 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct
6367

6468
private static final Logger logger = LogManager.getLogger(TransportStopTrainedModelDeploymentAction.class);
6569

66-
private final IngestService ingestService;
70+
private final Client client;
6771
private final TrainedModelAssignmentClusterService trainedModelAssignmentClusterService;
6872
private final InferenceAuditor auditor;
6973

@@ -72,7 +76,7 @@ public TransportStopTrainedModelDeploymentAction(
7276
ClusterService clusterService,
7377
TransportService transportService,
7478
ActionFilters actionFilters,
75-
IngestService ingestService,
79+
Client client,
7680
TrainedModelAssignmentClusterService trainedModelAssignmentClusterService,
7781
InferenceAuditor auditor
7882
) {
@@ -85,7 +89,7 @@ public TransportStopTrainedModelDeploymentAction(
8589
StopTrainedModelDeploymentAction.Response::new,
8690
EsExecutors.DIRECT_EXECUTOR_SERVICE
8791
);
88-
this.ingestService = ingestService;
92+
this.client = client;
8993
this.trainedModelAssignmentClusterService = trainedModelAssignmentClusterService;
9094
this.auditor = Objects.requireNonNull(auditor);
9195
}
@@ -154,21 +158,84 @@ protected void doExecute(
154158

155159
// NOTE, should only run on Master node
156160
assert clusterService.localNode().isMasterNode();
161+
162+
if (request.isForce() == false) {
163+
checkIfUsedByInferenceEndpoint(
164+
request.getId(),
165+
ActionListener.wrap(canStop -> stopDeployment(task, request, maybeAssignment.get(), listener), listener::onFailure)
166+
);
167+
} else {
168+
stopDeployment(task, request, maybeAssignment.get(), listener);
169+
}
170+
}
171+
172+
/**
 * Marks the assignment identified by {@code request.getId()} as stopping and, when that
 * succeeds, continues with {@code normalUndeploy} for the supplied assignment.
 * A failure whose unwrapped cause is a {@link ResourceNotFoundException} is answered with
 * a {@code Response(true)}; every other failure is passed to {@code listener.onFailure}.
 */
private void stopDeployment(
173+
Task task,
174+
StopTrainedModelDeploymentAction.Request request,
175+
TrainedModelAssignment assignment,
176+
ActionListener<StopTrainedModelDeploymentAction.Response> listener
177+
) {
157178
trainedModelAssignmentClusterService.setModelAssignmentToStopping(
158179
request.getId(),
159-
ActionListener.wrap(
160-
setToStopping -> normalUndeploy(task, request.getId(), maybeAssignment.get(), request, listener),
161-
failure -> {
162-
if (ExceptionsHelper.unwrapCause(failure) instanceof ResourceNotFoundException) {
163-
listener.onResponse(new StopTrainedModelDeploymentAction.Response(true));
164-
return;
165-
}
166-
listener.onFailure(failure);
180+
ActionListener.wrap(setToStopping -> normalUndeploy(task, request.getId(), assignment, request, listener), failure -> {
181+
if (ExceptionsHelper.unwrapCause(failure) instanceof ResourceNotFoundException) {
182+
listener.onResponse(new StopTrainedModelDeploymentAction.Response(true));
183+
return;
167184
}
168-
)
185+
listener.onFailure(failure);
186+
})
169187
);
170188
}
171189

190+
private void checkIfUsedByInferenceEndpoint(String deploymentId, ActionListener<Boolean> listener) {
191+
192+
GetInferenceModelAction.Request getAllEndpoints = new GetInferenceModelAction.Request("*", TaskType.ANY);
193+
client.execute(GetInferenceModelAction.INSTANCE, getAllEndpoints, listener.delegateFailureAndWrap((l, response) -> {
194+
// filter by the ml node services
195+
var mlNodeEndpoints = response.getEndpoints()
196+
.stream()
197+
.filter(model -> model.getService().equals("elasticsearch") || model.getService().equals("elser"))
198+
.toList();
199+
200+
var endpointOwnsDeployment = mlNodeEndpoints.stream()
201+
.filter(model -> model.getInferenceEntityId().equals(deploymentId))
202+
.findFirst();
203+
if (endpointOwnsDeployment.isPresent()) {
204+
l.onFailure(
205+
new ElasticsearchStatusException(
206+
"Cannot stop deployment [{}] as it was created by inference endpoint [{}]",
207+
RestStatus.CONFLICT,
208+
deploymentId,
209+
endpointOwnsDeployment.get().getInferenceEntityId()
210+
)
211+
);
212+
return;
213+
}
214+
215+
// The inference endpoint may have been created by attaching to an existing deployment.
216+
for (var endpoint : mlNodeEndpoints) {
217+
var serviceSettingsXContent = XContentHelper.toXContent(endpoint.getServiceSettings(), XContentType.JSON, false);
218+
var settingsMap = XContentHelper.convertToMap(serviceSettingsXContent, false, XContentType.JSON).v2();
219+
// Endpoints with the deployment_id setting are attached to an existing deployment.
220+
var deploymentIdFromSettings = (String) settingsMap.get("deployment_id");
221+
if (deploymentIdFromSettings != null && deploymentIdFromSettings.equals(deploymentId)) {
222+
// The endpoint was created to use this deployment
223+
l.onFailure(
224+
new ElasticsearchStatusException(
225+
"Cannot stop deployment [{}] as it is used by inference endpoint [{}]",
226+
RestStatus.CONFLICT,
227+
deploymentId,
228+
endpoint.getInferenceEntityId()
229+
)
230+
);
231+
return;
232+
}
233+
}
234+
235+
l.onResponse(true);
236+
}));
237+
}
238+
172239
private void redirectToMasterNode(
173240
DiscoveryNode masterNode,
174241
StopTrainedModelDeploymentAction.Request request,

0 commit comments

Comments
 (0)