import org.elasticsearch.action.TaskOperationFailure;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.tasks.TransportTasksAction;
+ import org.elasticsearch.client.internal.Client;
+ import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.util.concurrent.EsExecutors;
+ import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.discovery.MasterNotDiscoveredException;
+ import org.elasticsearch.inference.TaskType;
import org.elasticsearch.ingest.IngestMetadata;
- import org.elasticsearch.ingest.IngestService;
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportResponseHandler;
import org.elasticsearch.transport.TransportService;
+ import org.elasticsearch.xcontent.XContentType;
+ import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata;
@@ -47,6 +52,7 @@
import java.util.Set;

import static org.elasticsearch.core.Strings.format;
+ import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAction.getModelAliases;

/**
@@ -63,7 +69,7 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct

    private static final Logger logger = LogManager.getLogger(TransportStopTrainedModelDeploymentAction.class);

-     private final IngestService ingestService;
+     private final OriginSettingClient client;
    private final TrainedModelAssignmentClusterService trainedModelAssignmentClusterService;
    private final InferenceAuditor auditor;

@@ -72,7 +78,7 @@ public TransportStopTrainedModelDeploymentAction(
        ClusterService clusterService,
        TransportService transportService,
        ActionFilters actionFilters,
-         IngestService ingestService,
+         Client client,
        TrainedModelAssignmentClusterService trainedModelAssignmentClusterService,
        InferenceAuditor auditor
    ) {
@@ -85,7 +91,7 @@ public TransportStopTrainedModelDeploymentAction(
            StopTrainedModelDeploymentAction.Response::new,
            EsExecutors.DIRECT_EXECUTOR_SERVICE
        );
-         this.ingestService = ingestService;
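+         // Wrap the injected client so internal requests (such as listing inference endpoints) run under the ML origin.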
+         this.client = new OriginSettingClient(client, ML_ORIGIN);
        this.trainedModelAssignmentClusterService = trainedModelAssignmentClusterService;
        this.auditor = Objects.requireNonNull(auditor);
    }
@@ -154,21 +160,84 @@ protected void doExecute(

        // NOTE, should only run on Master node
        assert clusterService.localNode().isMasterNode();
+
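+         // Unless the stop is forced, first check that no inference endpoint relies on this deployment.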
+         if (request.isForce() == false) {
+             checkIfUsedByInferenceEndpoint(
+                 request.getId(),
+                 ActionListener.wrap(canStop -> stopDeployment(task, request, maybeAssignment.get(), listener), listener::onFailure)
+             );
+         } else {
+             stopDeployment(task, request, maybeAssignment.get(), listener);
+         }
+     }
+
+     private void stopDeployment(
+         Task task,
+         StopTrainedModelDeploymentAction.Request request,
+         TrainedModelAssignment assignment,
+         ActionListener<StopTrainedModelDeploymentAction.Response> listener
+     ) {
        trainedModelAssignmentClusterService.setModelAssignmentToStopping(
            request.getId(),
-             ActionListener.wrap(
-                 setToStopping -> normalUndeploy(task, request.getId(), maybeAssignment.get(), request, listener),
-                 failure -> {
-                     if (ExceptionsHelper.unwrapCause(failure) instanceof ResourceNotFoundException) {
-                         listener.onResponse(new StopTrainedModelDeploymentAction.Response(true));
-                         return;
-                     }
-                     listener.onFailure(failure);
+             ActionListener.wrap(setToStopping -> normalUndeploy(task, request.getId(), assignment, request, listener), failure -> {
+                 if (ExceptionsHelper.unwrapCause(failure) instanceof ResourceNotFoundException) {
+                     listener.onResponse(new StopTrainedModelDeploymentAction.Response(true));
+                     return;
                }
-             )
+                 listener.onFailure(failure);
+             })
        );
    }

+     private void checkIfUsedByInferenceEndpoint(String deploymentId, ActionListener<Boolean> listener) {
+
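+         // Fetch every inference endpoint (all ids, all task types) so the deployment can be checked against each one.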
+         GetInferenceModelAction.Request getAllEndpoints = new GetInferenceModelAction.Request("*", TaskType.ANY);
+         client.execute(GetInferenceModelAction.INSTANCE, getAllEndpoints, listener.delegateFailureAndWrap((l, response) -> {
+             // filter by the ml node services
+             var mlNodeEndpoints = response.getEndpoints()
+                 .stream()
+                 .filter(model -> model.getService().equals("elasticsearch") || model.getService().equals("elser"))
+                 .toList();
+
+             var endpointOwnsDeployment = mlNodeEndpoints.stream()
+                 .filter(model -> model.getInferenceEntityId().equals(deploymentId))
+                 .findFirst();
+             if (endpointOwnsDeployment.isPresent()) {
+                 l.onFailure(
+                     new ElasticsearchStatusException(
+                         "Cannot stop deployment [{}] as it was created by inference endpoint [{}]",
+                         RestStatus.CONFLICT,
+                         deploymentId,
+                         endpointOwnsDeployment.get().getInferenceEntityId()
+                     )
+                 );
+                 return;
+             }
+
+             // The inference endpoint may have been created by attaching to an existing deployment.
+             for (var endpoint : mlNodeEndpoints) {
+                 var serviceSettingsXContent = XContentHelper.toXContent(endpoint.getServiceSettings(), XContentType.JSON, false);
+                 var settingsMap = XContentHelper.convertToMap(serviceSettingsXContent, false, XContentType.JSON).v2();
+                 // Endpoints with the deployment_id setting are attached to an existing deployment.
+                 var deploymentIdFromSettings = (String) settingsMap.get("deployment_id");
+                 if (deploymentIdFromSettings != null && deploymentIdFromSettings.equals(deploymentId)) {
+                     // The endpoint was created to use this deployment
+                     l.onFailure(
+                         new ElasticsearchStatusException(
+                             "Cannot stop deployment [{}] as it is used by inference endpoint [{}]",
+                             RestStatus.CONFLICT,
+                             deploymentId,
+                             endpoint.getInferenceEntityId()
+                         )
+                     );
+                     return;
+                 }
+             }
+
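+             // No endpoint references this deployment; it is safe to stop.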
+             l.onResponse(true);
+         }));
+     }
+
    private void redirectToMasterNode(
        DiscoveryNode masterNode,
        StopTrainedModelDeploymentAction.Request request,