Skip to content

Commit 92b01c8

Browse files
[ML] Allowing deletion of default endpoints while using force=true (elastic#124781) (elastic#124877)
* Allowing deletion of default endpoints and add warning header * Moving to force logic
1 parent 10f1f47 commit 92b01c8

File tree

2 files changed

+78
-24
lines changed

2 files changed

+78
-24
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import org.elasticsearch.cluster.block.ClusterBlockException;
2020
import org.elasticsearch.cluster.block.ClusterBlockLevel;
2121
import org.elasticsearch.cluster.service.ClusterService;
22+
import org.elasticsearch.common.Strings;
2223
import org.elasticsearch.common.util.concurrent.EsExecutors;
2324
import org.elasticsearch.inference.InferenceServiceRegistry;
2425
import org.elasticsearch.inference.UnparsedModel;
@@ -86,17 +87,6 @@ private void doExecuteForked(
8687
ClusterState state,
8788
ActionListener<DeleteInferenceEndpointAction.Response> masterListener
8889
) {
89-
if (modelRegistry.containsDefaultConfigId(request.getInferenceEndpointId())) {
90-
masterListener.onFailure(
91-
new ElasticsearchStatusException(
92-
"[{}] is a reserved inference endpoint. Cannot delete a reserved inference endpoint.",
93-
RestStatus.BAD_REQUEST,
94-
request.getInferenceEndpointId()
95-
)
96-
);
97-
return;
98-
}
99-
10090
SubscribableListener.<UnparsedModel>newForked(modelConfigListener -> {
10191
// Get the model from the registry
10292

@@ -118,6 +108,18 @@ private void doExecuteForked(
118108
if (errorString != null) {
119109
listener.onFailure(new ElasticsearchStatusException(errorString, RestStatus.CONFLICT));
120110
return;
111+
} else if (isInferenceIdReserved(request.getInferenceEndpointId())) {
112+
listener.onFailure(
113+
new ElasticsearchStatusException(
114+
Strings.format(
115+
"[%s] is a reserved inference endpoint. Use the force=true query parameter "
116+
+ "to delete the inference endpoint.",
117+
request.getInferenceEndpointId()
118+
),
119+
RestStatus.BAD_REQUEST
120+
)
121+
);
122+
return;
121123
}
122124
}
123125

@@ -186,6 +188,10 @@ private static String endpointIsReferencedInPipelinesOrIndexes(final ClusterStat
186188
return null;
187189
}
188190

191+
private boolean isInferenceIdReserved(String inferenceEndpointId) {
192+
return modelRegistry.containsDefaultConfigId(inferenceEndpointId);
193+
}
194+
189195
private static String buildErrorString(String inferenceEndpointId, Set<String> pipelines, Set<String> indexes) {
190196
StringBuilder errorString = new StringBuilder();
191197

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointActionTests.java

Lines changed: 61 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,16 @@
88
package org.elasticsearch.xpack.inference.action;
99

1010
import org.elasticsearch.ElasticsearchStatusException;
11+
import org.elasticsearch.action.ActionListener;
1112
import org.elasticsearch.action.support.ActionFilters;
1213
import org.elasticsearch.action.support.PlainActionFuture;
13-
import org.elasticsearch.client.internal.Client;
1414
import org.elasticsearch.cluster.ClusterState;
1515
import org.elasticsearch.cluster.service.ClusterService;
1616
import org.elasticsearch.core.TimeValue;
1717
import org.elasticsearch.inference.InferenceService;
1818
import org.elasticsearch.inference.InferenceServiceRegistry;
19-
import org.elasticsearch.inference.MinimalServiceSettings;
2019
import org.elasticsearch.inference.TaskType;
20+
import org.elasticsearch.inference.UnparsedModel;
2121
import org.elasticsearch.tasks.Task;
2222
import org.elasticsearch.test.ESTestCase;
2323
import org.elasticsearch.threadpool.ThreadPool;
@@ -27,30 +27,39 @@
2727
import org.junit.After;
2828
import org.junit.Before;
2929

30+
import java.util.Map;
31+
import java.util.Optional;
32+
3033
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
3134
import static org.hamcrest.Matchers.is;
35+
import static org.mockito.ArgumentMatchers.any;
36+
import static org.mockito.ArgumentMatchers.anyString;
37+
import static org.mockito.Mockito.doAnswer;
3238
import static org.mockito.Mockito.mock;
39+
import static org.mockito.Mockito.when;
3340

3441
public class TransportDeleteInferenceEndpointActionTests extends ESTestCase {
3542

3643
private static final TimeValue TIMEOUT = TimeValue.timeValueSeconds(30);
3744

3845
private TransportDeleteInferenceEndpointAction action;
3946
private ThreadPool threadPool;
40-
private ModelRegistry modelRegistry;
47+
private ModelRegistry mockModelRegistry;
48+
private InferenceServiceRegistry mockInferenceServiceRegistry;
4149

4250
@Before
4351
public void setUp() throws Exception {
4452
super.setUp();
45-
modelRegistry = new ModelRegistry(mock(Client.class));
4653
threadPool = createThreadPool(inferenceUtilityPool());
54+
mockModelRegistry = mock(ModelRegistry.class);
55+
mockInferenceServiceRegistry = mock(InferenceServiceRegistry.class);
4756
action = new TransportDeleteInferenceEndpointAction(
4857
mock(TransportService.class),
4958
mock(ClusterService.class),
5059
threadPool,
5160
mock(ActionFilters.class),
52-
modelRegistry,
53-
mock(InferenceServiceRegistry.class)
61+
mockModelRegistry,
62+
mockInferenceServiceRegistry
5463
);
5564
}
5665

@@ -60,24 +69,63 @@ public void tearDown() throws Exception {
6069
terminate(threadPool);
6170
}
6271

63-
public void testFailsToDelete_ADefaultEndpoint() {
64-
modelRegistry.addDefaultIds(
65-
new InferenceService.DefaultConfigId("model-id", MinimalServiceSettings.chatCompletion(), mock(InferenceService.class))
66-
);
72+
public void testFailsToDelete_ADefaultEndpoint_WithoutPassingForceQueryParameter() {
73+
doAnswer(invocationOnMock -> {
74+
ActionListener<UnparsedModel> listener = invocationOnMock.getArgument(1);
75+
listener.onResponse(new UnparsedModel("model_id", TaskType.COMPLETION, "service", Map.of(), Map.of()));
76+
return Void.TYPE;
77+
}).when(mockModelRegistry).getModel(anyString(), any());
78+
when(mockModelRegistry.containsDefaultConfigId(anyString())).thenReturn(true);
6779

6880
var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
6981

7082
action.masterOperation(
7183
mock(Task.class),
72-
new DeleteInferenceEndpointAction.Request("model-id", TaskType.CHAT_COMPLETION, true, false),
73-
mock(ClusterState.class),
84+
new DeleteInferenceEndpointAction.Request("model-id", TaskType.COMPLETION, false, false),
85+
ClusterState.EMPTY_STATE,
7486
listener
7587
);
7688

7789
var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
7890
assertThat(
7991
exception.getMessage(),
80-
is("[model-id] is a reserved inference endpoint. " + "Cannot delete a reserved inference endpoint.")
92+
is("[model-id] is a reserved inference endpoint. Use the force=true query parameter to delete the inference endpoint.")
8193
);
8294
}
95+
96+
public void testDeletesDefaultEndpoint_WhenForceIsTrue() {
97+
doAnswer(invocationOnMock -> {
98+
ActionListener<UnparsedModel> listener = invocationOnMock.getArgument(1);
99+
listener.onResponse(new UnparsedModel("model_id", TaskType.COMPLETION, "service", Map.of(), Map.of()));
100+
return Void.TYPE;
101+
}).when(mockModelRegistry).getModel(anyString(), any());
102+
when(mockModelRegistry.containsDefaultConfigId(anyString())).thenReturn(true);
103+
doAnswer(invocationOnMock -> {
104+
ActionListener<Boolean> listener = invocationOnMock.getArgument(1);
105+
listener.onResponse(true);
106+
return Void.TYPE;
107+
}).when(mockModelRegistry).deleteModel(anyString(), any());
108+
109+
var mockService = mock(InferenceService.class);
110+
doAnswer(invocationOnMock -> {
111+
ActionListener<Boolean> listener = invocationOnMock.getArgument(1);
112+
listener.onResponse(true);
113+
return Void.TYPE;
114+
}).when(mockService).stop(any(), any());
115+
116+
when(mockInferenceServiceRegistry.getService(anyString())).thenReturn(Optional.of(mockService));
117+
118+
var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
119+
120+
action.masterOperation(
121+
mock(Task.class),
122+
new DeleteInferenceEndpointAction.Request("model-id", TaskType.COMPLETION, true, false),
123+
ClusterState.EMPTY_STATE,
124+
listener
125+
);
126+
127+
var response = listener.actionGet(TIMEOUT);
128+
129+
assertTrue(response.isAcknowledged());
130+
}
83131
}

0 commit comments

Comments
 (0)