Skip to content

Commit f89caa9

Browse files
Enable force-deleting inference endpoints when the model configuration is invalid or when stopping the model deployment fails
1 parent 080a280 commit f89caa9

File tree

2 files changed

+249
-3
lines changed

2 files changed

+249
-3
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.elasticsearch.common.Strings;
2323
import org.elasticsearch.common.util.concurrent.EsExecutors;
2424
import org.elasticsearch.inference.InferenceServiceRegistry;
25+
import org.elasticsearch.inference.Model;
2526
import org.elasticsearch.inference.UnparsedModel;
2627
import org.elasticsearch.injection.guice.Inject;
2728
import org.elasticsearch.rest.RestStatus;
@@ -124,10 +125,38 @@ private void doExecuteForked(
124125
}
125126

126127
var service = serviceRegistry.getService(unparsedModel.service());
128+
Model model;
127129
if (service.isPresent()) {
128-
var model = service.get()
129-
.parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
130-
service.get().stop(model, listener);
130+
try {
131+
model = service.get()
132+
.parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
133+
} catch (Exception e) {
134+
if (request.isForceDelete()) {
135+
listener.onResponse(true);
136+
return;
137+
} else {
138+
listener.onFailure(
139+
new ElasticsearchStatusException(
140+
Strings.format(
141+
"Failed to parse model configuration for inference endpoint [%s]",
142+
request.getInferenceEndpointId()
143+
),
144+
RestStatus.INTERNAL_SERVER_ERROR,
145+
e
146+
)
147+
);
148+
return;
149+
}
150+
}
151+
service.get().stop(model, listener.delegateResponse((l, e) -> {
152+
if (request.isForceDelete()) {
153+
l.onResponse(true);
154+
} else {
155+
l.onFailure(e);
156+
}
157+
}));
158+
} else if (request.isForceDelete()) {
159+
listener.onResponse(true);
131160
} else {
132161
listener.onFailure(
133162
new ElasticsearchStatusException(

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointActionTests.java

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616
import org.elasticsearch.core.TimeValue;
1717
import org.elasticsearch.inference.InferenceService;
1818
import org.elasticsearch.inference.InferenceServiceRegistry;
19+
import org.elasticsearch.inference.Model;
1920
import org.elasticsearch.inference.TaskType;
2021
import org.elasticsearch.inference.UnparsedModel;
22+
import org.elasticsearch.rest.RestStatus;
2123
import org.elasticsearch.tasks.Task;
2224
import org.elasticsearch.test.ESTestCase;
2325
import org.elasticsearch.threadpool.ThreadPool;
@@ -31,11 +33,17 @@
3133
import java.util.Optional;
3234

3335
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
36+
import static org.hamcrest.Matchers.containsString;
3437
import static org.hamcrest.Matchers.is;
3538
import static org.mockito.ArgumentMatchers.any;
3639
import static org.mockito.ArgumentMatchers.anyString;
40+
import static org.mockito.ArgumentMatchers.eq;
3741
import static org.mockito.Mockito.doAnswer;
42+
import static org.mockito.Mockito.doReturn;
43+
import static org.mockito.Mockito.doThrow;
3844
import static org.mockito.Mockito.mock;
45+
import static org.mockito.Mockito.verify;
46+
import static org.mockito.Mockito.verifyNoMoreInteractions;
3947
import static org.mockito.Mockito.when;
4048

4149
public class TransportDeleteInferenceEndpointActionTests extends ESTestCase {
@@ -128,4 +136,213 @@ public void testDeletesDefaultEndpoint_WhenForceIsTrue() {
128136

129137
assertTrue(response.isAcknowledged());
130138
}
139+
140+
/**
 * Verifies that a non-forced delete fails when the service cannot parse the
 * persisted endpoint configuration.
 *
 * The registry returns an unparsed model, the service lookup succeeds, and
 * {@code parsePersistedConfig} throws (see {@code mockUnparsableModel}). The
 * test expects an {@link ElasticsearchStatusException} whose message contains
 * "Failed to parse model configuration for inference endpoint", and that
 * nothing beyond the verified calls happens (in particular, no deleteModel).
 */
public void testFailsToDeleteUnparsableEndpoint_WhenForceIsFalse() {
141+
var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
142+
var serviceName = randomAlphanumericOfLength(10);
143+
var taskType = randomFrom(TaskType.values());
144+
var mockService = mock(InferenceService.class);
145+
mockUnparsableModel(inferenceEndpointId, serviceName, taskType, mockService);
146+
// The endpoint is stubbed as not being a default config; this call is verified below.
when(mockModelRegistry.containsDefaultConfigId(inferenceEndpointId)).thenReturn(false);
147+
148+
var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
149+
action.masterOperation(
150+
mock(Task.class),
151+
// NOTE(review): third argument is presumably the force flag (false here, matching the test name);
// the meaning of the fourth argument is not visible in this view - confirm against the Request class.
new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, false, false),
152+
ClusterState.EMPTY_STATE,
153+
listener
154+
);
155+
156+
var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
157+
assertThat(exception.getMessage(), containsString("Failed to parse model configuration for inference endpoint"));
158+
159+
// Exactly these interactions are expected; anything else fails verifyNoMoreInteractions.
verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
160+
verify(mockInferenceServiceRegistry).getService(eq(serviceName));
161+
verify(mockModelRegistry).containsDefaultConfigId(eq(inferenceEndpointId));
162+
verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
163+
verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService);
164+
}
165+
166+
/**
 * Verifies that a forced delete succeeds even though the service cannot parse
 * the persisted endpoint configuration.
 *
 * {@code parsePersistedConfig} is stubbed to throw (see {@code mockUnparsableModel})
 * and {@code deleteModel} is stubbed to succeed. The test expects an acknowledged
 * response and that {@code deleteModel} is invoked for the endpoint id.
 */
public void testDeletesUnparsableEndpoint_WhenForceIsTrue() {
167+
var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
168+
var serviceName = randomAlphanumericOfLength(10);
169+
var taskType = randomFrom(TaskType.values());
170+
var mockService = mock(InferenceService.class);
171+
mockUnparsableModel(inferenceEndpointId, serviceName, taskType, mockService);
172+
// Make the registry's deleteModel complete its listener (second argument) with success.
doAnswer(invocationOnMock -> {
173+
ActionListener<Boolean> listener = invocationOnMock.getArgument(1);
174+
listener.onResponse(true);
175+
return Void.TYPE;
176+
}).when(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any());
177+
178+
var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
179+
180+
action.masterOperation(
181+
mock(Task.class),
182+
// NOTE(review): third argument is presumably the force flag (true here, matching the test name);
// the meaning of the fourth argument is not visible in this view - confirm against the Request class.
new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, true, false),
183+
ClusterState.EMPTY_STATE,
184+
listener
185+
);
186+
187+
var response = listener.actionGet(TIMEOUT);
188+
assertTrue(response.isAcknowledged());
189+
190+
// Parsing is still attempted, and no stop(...) call is expected on the service
// (verifyNoMoreInteractions covers mockService).
verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
191+
verify(mockInferenceServiceRegistry).getService(eq(serviceName));
192+
verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
193+
verify(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any());
194+
verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService);
195+
}
196+
197+
/**
 * Stubs the mocks so that the endpoint exists and its service is registered,
 * but parsing the persisted configuration fails.
 *
 * @param inferenceEndpointId id the registry's {@code getModel} stub responds to
 * @param serviceName         service name stored in the returned {@link UnparsedModel} and used for the service lookup
 * @param taskType            task type stored in the returned {@link UnparsedModel}
 * @param mockService         service mock returned by the registry; its {@code parsePersistedConfig} is stubbed to throw
 */
private void mockUnparsableModel(String inferenceEndpointId, String serviceName, TaskType taskType, InferenceService mockService) {
198+
// getModel completes its listener (second argument) with an UnparsedModel built from empty maps.
doAnswer(invocationOnMock -> {
199+
ActionListener<UnparsedModel> listener = invocationOnMock.getArgument(1);
200+
listener.onResponse(new UnparsedModel(inferenceEndpointId, taskType, serviceName, Map.of(), Map.of()));
201+
return Void.TYPE;
202+
}).when(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
203+
// Parsing throws an exception with a random message and a 500 status.
doThrow(new ElasticsearchStatusException(randomAlphanumericOfLength(10), RestStatus.INTERNAL_SERVER_ERROR)).when(mockService)
204+
.parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
205+
when(mockInferenceServiceRegistry.getService(serviceName)).thenReturn(Optional.of(mockService));
206+
}
207+
208+
/**
 * Verifies that a forced delete succeeds when no service is registered under
 * the endpoint's service name.
 *
 * The service lookup returns {@code Optional.empty()} (see {@code mockNoService})
 * and {@code deleteModel} is stubbed to succeed. The test expects an acknowledged
 * response and that {@code deleteModel} is invoked for the endpoint id.
 */
public void testDeletesEndpointWithNoService_WhenForceIsTrue() {
209+
var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
210+
var serviceName = randomAlphanumericOfLength(10);
211+
var taskType = randomFrom(TaskType.values());
212+
mockNoService(inferenceEndpointId, serviceName, taskType);
213+
// Make the registry's deleteModel complete its listener (second argument) with success, for any id.
doAnswer(invocationOnMock -> {
214+
ActionListener<Boolean> listener = invocationOnMock.getArgument(1);
215+
listener.onResponse(true);
216+
return Void.TYPE;
217+
}).when(mockModelRegistry).deleteModel(anyString(), any());
218+
219+
var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
220+
221+
action.masterOperation(
222+
mock(Task.class),
223+
// NOTE(review): third argument is presumably the force flag (true here, matching the test name);
// the meaning of the fourth argument is not visible in this view - confirm against the Request class.
new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, true, false),
224+
ClusterState.EMPTY_STATE,
225+
listener
226+
);
227+
228+
var response = listener.actionGet(TIMEOUT);
229+
assertTrue(response.isAcknowledged());
230+
verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
231+
verify(mockInferenceServiceRegistry).getService(eq(serviceName));
232+
verify(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any());
233+
verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry);
234+
}
235+
236+
/**
 * Verifies that a non-forced delete fails when no service is registered under
 * the endpoint's service name.
 *
 * The service lookup returns {@code Optional.empty()} (see {@code mockNoService}).
 * The test expects an {@link ElasticsearchStatusException} whose message contains
 * "No service found for this inference endpoint", and that nothing beyond the
 * verified calls happens (in particular, no deleteModel).
 */
public void testFailsToDeleteEndpointWithNoService_WhenForceIsFalse() {
237+
var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
238+
var serviceName = randomAlphanumericOfLength(10);
239+
var taskType = randomFrom(TaskType.values());
240+
mockNoService(inferenceEndpointId, serviceName, taskType);
241+
// The endpoint is stubbed as not being a default config; this call is verified below.
when(mockModelRegistry.containsDefaultConfigId(inferenceEndpointId)).thenReturn(false);
242+
243+
var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
244+
245+
action.masterOperation(
246+
mock(Task.class),
247+
// NOTE(review): third argument is presumably the force flag (false here, matching the test name);
// the meaning of the fourth argument is not visible in this view - confirm against the Request class.
new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, false, false),
248+
ClusterState.EMPTY_STATE,
249+
listener
250+
);
251+
252+
var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
253+
assertThat(exception.getMessage(), containsString("No service found for this inference endpoint"));
254+
verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
255+
verify(mockInferenceServiceRegistry).getService(eq(serviceName));
256+
verify(mockModelRegistry).containsDefaultConfigId(eq(inferenceEndpointId));
257+
verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry);
258+
}
259+
260+
/**
 * Stubs the mocks so that the endpoint exists in the model registry but no
 * service is registered under its service name.
 *
 * @param inferenceEndpointId id the registry's {@code getModel} stub responds to
 * @param serviceName         service name stored in the returned {@link UnparsedModel}; its lookup yields {@code Optional.empty()}
 * @param taskType            task type stored in the returned {@link UnparsedModel}
 */
private void mockNoService(String inferenceEndpointId, String serviceName, TaskType taskType) {
261+
// getModel completes its listener (second argument) with an UnparsedModel built from empty maps.
doAnswer(invocationOnMock -> {
262+
ActionListener<UnparsedModel> listener = invocationOnMock.getArgument(1);
263+
listener.onResponse(new UnparsedModel(inferenceEndpointId, taskType, serviceName, Map.of(), Map.of()));
264+
return Void.TYPE;
265+
}).when(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
266+
when(mockInferenceServiceRegistry.getService(serviceName)).thenReturn(Optional.empty());
267+
}
268+
269+
/**
 * Verifies that a non-forced delete fails when the service's {@code stop}
 * call reports a failure.
 *
 * Parsing succeeds and returns {@code mockModel}, but {@code stop} fails its
 * listener (see {@code mockStopDeploymentFails}). The test expects an
 * {@link ElasticsearchStatusException} whose message contains
 * "Failed to stop model deployment", and that nothing beyond the verified
 * calls happens (in particular, no deleteModel).
 */
public void testFailsToDeleteEndpointIfModelDeploymentStopFails_WhenForceIsFalse() {
270+
var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
271+
var serviceName = randomAlphanumericOfLength(10);
272+
var taskType = randomFrom(TaskType.values());
273+
var mockService = mock(InferenceService.class);
274+
var mockModel = mock(Model.class);
275+
mockStopDeploymentFails(inferenceEndpointId, serviceName, taskType, mockService, mockModel);
276+
// The endpoint is stubbed as not being a default config; this call is verified below.
when(mockModelRegistry.containsDefaultConfigId(inferenceEndpointId)).thenReturn(false);
277+
278+
var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
279+
action.masterOperation(
280+
mock(Task.class),
281+
// NOTE(review): third argument is presumably the force flag (false here, matching the test name);
// the meaning of the fourth argument is not visible in this view - confirm against the Request class.
new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, false, false),
282+
ClusterState.EMPTY_STATE,
283+
listener
284+
);
285+
286+
var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
287+
assertThat(exception.getMessage(), containsString("Failed to stop model deployment"));
288+
verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
289+
verify(mockInferenceServiceRegistry).getService(eq(serviceName));
290+
verify(mockModelRegistry).containsDefaultConfigId(eq(inferenceEndpointId));
291+
verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
292+
verify(mockService).stop(eq(mockModel), any());
293+
verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService, mockModel);
294+
}
295+
296+
/**
 * Verifies that a forced delete succeeds even though the service's
 * {@code stop} call reports a failure.
 *
 * Parsing succeeds and returns {@code mockModel}, {@code stop} fails its
 * listener (see {@code mockStopDeploymentFails}), and {@code deleteModel} is
 * stubbed to succeed. The test expects an acknowledged response and that
 * {@code deleteModel} is invoked for the endpoint id.
 */
public void testDeletesEndpointIfModelDeploymentStopFails_WhenForceIsTrue() {
297+
var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
298+
var serviceName = randomAlphanumericOfLength(10);
299+
var taskType = randomFrom(TaskType.values());
300+
var mockService = mock(InferenceService.class);
301+
var mockModel = mock(Model.class);
302+
mockStopDeploymentFails(inferenceEndpointId, serviceName, taskType, mockService, mockModel);
303+
// Make the registry's deleteModel complete its listener (second argument) with success.
doAnswer(invocationOnMock -> {
304+
ActionListener<Boolean> listener = invocationOnMock.getArgument(1);
305+
listener.onResponse(true);
306+
return Void.TYPE;
307+
}).when(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any());
308+
309+
var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
310+
action.masterOperation(
311+
mock(Task.class),
312+
// NOTE(review): third argument is presumably the force flag (true here, matching the test name);
// the meaning of the fourth argument is not visible in this view - confirm against the Request class.
new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, true, false),
313+
ClusterState.EMPTY_STATE,
314+
listener
315+
);
316+
317+
var response = listener.actionGet(TIMEOUT);
318+
assertTrue(response.isAcknowledged());
319+
verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
320+
verify(mockInferenceServiceRegistry).getService(eq(serviceName));
321+
verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
322+
// stop(...) is still attempted; its failure must not prevent the deleteModel call verified next.
verify(mockService).stop(eq(mockModel), any());
323+
verify(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any());
324+
verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService, mockModel);
325+
}
326+
327+
/**
 * Stubs the mocks so that the endpoint exists, its service is registered and
 * parses the configuration successfully, but stopping the model fails.
 *
 * @param inferenceEndpointId id the registry's {@code getModel} stub responds to
 * @param serviceName         service name stored in the returned {@link UnparsedModel} and used for the service lookup
 * @param taskType            task type stored in the returned {@link UnparsedModel}
 * @param mockService         service mock returned by the registry; its {@code stop} is stubbed to fail
 * @param mockModel           model returned by the stubbed {@code parsePersistedConfig} and matched in the {@code stop} stub
 */
private void mockStopDeploymentFails(
328+
String inferenceEndpointId,
329+
String serviceName,
330+
TaskType taskType,
331+
InferenceService mockService,
332+
Model mockModel
333+
) {
334+
// getModel completes its listener (second argument) with an UnparsedModel built from empty maps.
doAnswer(invocationOnMock -> {
335+
ActionListener<UnparsedModel> listener = invocationOnMock.getArgument(1);
336+
listener.onResponse(new UnparsedModel(inferenceEndpointId, taskType, serviceName, Map.of(), Map.of()));
337+
return Void.TYPE;
338+
}).when(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
339+
when(mockInferenceServiceRegistry.getService(serviceName)).thenReturn(Optional.of(mockService));
340+
doReturn(mockModel).when(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
341+
// stop(...) reports the failure through its listener (second argument) rather than by throwing.
doAnswer(invocationOnMock -> {
342+
ActionListener<Boolean> listener = invocationOnMock.getArgument(1);
343+
listener.onFailure(new ElasticsearchStatusException("Failed to stop model deployment", RestStatus.INTERNAL_SERVER_ERROR));
344+
return Void.TYPE;
345+
}).when(mockService).stop(eq(mockModel), any());
346+
}
347+
131348
}

0 commit comments

Comments
 (0)