Skip to content

Commit 6bfea79

Browse files
[ML] Allowing deletion of default endpoints while using force=true (elastic#124781) (elastic#124879)
* Allowing deletion of default endpoints and add warning header * Moving to force logic (cherry picked from commit cbfc100) # Conflicts: # x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointActionTests.java
1 parent b9ab327 commit 6bfea79

File tree

2 files changed

+78
-28
lines changed

2 files changed

+78
-28
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
import org.elasticsearch.cluster.ClusterState;
1919
import org.elasticsearch.cluster.block.ClusterBlockException;
2020
import org.elasticsearch.cluster.block.ClusterBlockLevel;
21-
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
2221
import org.elasticsearch.cluster.service.ClusterService;
22+
import org.elasticsearch.common.Strings;
2323
import org.elasticsearch.common.util.concurrent.EsExecutors;
2424
import org.elasticsearch.inference.InferenceServiceRegistry;
2525
import org.elasticsearch.inference.UnparsedModel;
@@ -53,7 +53,6 @@ public TransportDeleteInferenceEndpointAction(
5353
ClusterService clusterService,
5454
ThreadPool threadPool,
5555
ActionFilters actionFilters,
56-
IndexNameExpressionResolver indexNameExpressionResolver,
5756
ModelRegistry modelRegistry,
5857
InferenceServiceRegistry serviceRegistry
5958
) {
@@ -88,17 +87,6 @@ private void doExecuteForked(
8887
ClusterState state,
8988
ActionListener<DeleteInferenceEndpointAction.Response> masterListener
9089
) {
91-
if (modelRegistry.containsDefaultConfigId(request.getInferenceEndpointId())) {
92-
masterListener.onFailure(
93-
new ElasticsearchStatusException(
94-
"[{}] is a reserved inference endpoint. Cannot delete a reserved inference endpoint.",
95-
RestStatus.BAD_REQUEST,
96-
request.getInferenceEndpointId()
97-
)
98-
);
99-
return;
100-
}
101-
10290
SubscribableListener.<UnparsedModel>newForked(modelConfigListener -> {
10391
// Get the model from the registry
10492

@@ -120,6 +108,18 @@ private void doExecuteForked(
120108
if (errorString != null) {
121109
listener.onFailure(new ElasticsearchStatusException(errorString, RestStatus.CONFLICT));
122110
return;
111+
} else if (isInferenceIdReserved(request.getInferenceEndpointId())) {
112+
listener.onFailure(
113+
new ElasticsearchStatusException(
114+
Strings.format(
115+
"[%s] is a reserved inference endpoint. Use the force=true query parameter "
116+
+ "to delete the inference endpoint.",
117+
request.getInferenceEndpointId()
118+
),
119+
RestStatus.BAD_REQUEST
120+
)
121+
);
122+
return;
123123
}
124124
}
125125

@@ -188,6 +188,10 @@ private static String endpointIsReferencedInPipelinesOrIndexes(final ClusterStat
188188
return null;
189189
}
190190

191+
private boolean isInferenceIdReserved(String inferenceEndpointId) {
192+
return modelRegistry.containsDefaultConfigId(inferenceEndpointId);
193+
}
194+
191195
private static String buildErrorString(String inferenceEndpointId, Set<String> pipelines, Set<String> indexes) {
192196
StringBuilder errorString = new StringBuilder();
193197

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointActionTests.java

Lines changed: 61 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,16 @@
88
package org.elasticsearch.xpack.inference.action;
99

1010
import org.elasticsearch.ElasticsearchStatusException;
11+
import org.elasticsearch.action.ActionListener;
1112
import org.elasticsearch.action.support.ActionFilters;
1213
import org.elasticsearch.action.support.PlainActionFuture;
13-
import org.elasticsearch.client.internal.Client;
1414
import org.elasticsearch.cluster.ClusterState;
15-
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
1615
import org.elasticsearch.cluster.service.ClusterService;
1716
import org.elasticsearch.core.TimeValue;
1817
import org.elasticsearch.inference.InferenceService;
1918
import org.elasticsearch.inference.InferenceServiceRegistry;
20-
import org.elasticsearch.inference.MinimalServiceSettings;
2119
import org.elasticsearch.inference.TaskType;
20+
import org.elasticsearch.inference.UnparsedModel;
2221
import org.elasticsearch.tasks.Task;
2322
import org.elasticsearch.test.ESTestCase;
2423
import org.elasticsearch.threadpool.ThreadPool;
@@ -28,31 +27,39 @@
2827
import org.junit.After;
2928
import org.junit.Before;
3029

30+
import java.util.Map;
31+
import java.util.Optional;
32+
3133
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
3234
import static org.hamcrest.Matchers.is;
35+
import static org.mockito.ArgumentMatchers.any;
36+
import static org.mockito.ArgumentMatchers.anyString;
37+
import static org.mockito.Mockito.doAnswer;
3338
import static org.mockito.Mockito.mock;
39+
import static org.mockito.Mockito.when;
3440

3541
public class TransportDeleteInferenceEndpointActionTests extends ESTestCase {
3642

3743
private static final TimeValue TIMEOUT = TimeValue.timeValueSeconds(30);
3844

3945
private TransportDeleteInferenceEndpointAction action;
4046
private ThreadPool threadPool;
41-
private ModelRegistry modelRegistry;
47+
private ModelRegistry mockModelRegistry;
48+
private InferenceServiceRegistry mockInferenceServiceRegistry;
4249

4350
@Before
4451
public void setUp() throws Exception {
4552
super.setUp();
46-
modelRegistry = new ModelRegistry(mock(Client.class));
4753
threadPool = createThreadPool(inferenceUtilityPool());
54+
mockModelRegistry = mock(ModelRegistry.class);
55+
mockInferenceServiceRegistry = mock(InferenceServiceRegistry.class);
4856
action = new TransportDeleteInferenceEndpointAction(
4957
mock(TransportService.class),
5058
mock(ClusterService.class),
5159
threadPool,
5260
mock(ActionFilters.class),
53-
mock(IndexNameExpressionResolver.class),
54-
modelRegistry,
55-
mock(InferenceServiceRegistry.class)
61+
mockModelRegistry,
62+
mockInferenceServiceRegistry
5663
);
5764
}
5865

@@ -62,24 +69,63 @@ public void tearDown() throws Exception {
6269
terminate(threadPool);
6370
}
6471

65-
public void testFailsToDelete_ADefaultEndpoint() {
66-
modelRegistry.addDefaultIds(
67-
new InferenceService.DefaultConfigId("model-id", MinimalServiceSettings.chatCompletion(), mock(InferenceService.class))
68-
);
72+
public void testFailsToDelete_ADefaultEndpoint_WithoutPassingForceQueryParameter() {
73+
doAnswer(invocationOnMock -> {
74+
ActionListener<UnparsedModel> listener = invocationOnMock.getArgument(1);
75+
listener.onResponse(new UnparsedModel("model_id", TaskType.COMPLETION, "service", Map.of(), Map.of()));
76+
return Void.TYPE;
77+
}).when(mockModelRegistry).getModel(anyString(), any());
78+
when(mockModelRegistry.containsDefaultConfigId(anyString())).thenReturn(true);
6979

7080
var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
7181

7282
action.masterOperation(
7383
mock(Task.class),
74-
new DeleteInferenceEndpointAction.Request("model-id", TaskType.CHAT_COMPLETION, true, false),
75-
mock(ClusterState.class),
84+
new DeleteInferenceEndpointAction.Request("model-id", TaskType.COMPLETION, false, false),
85+
ClusterState.EMPTY_STATE,
7686
listener
7787
);
7888

7989
var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
8090
assertThat(
8191
exception.getMessage(),
82-
is("[model-id] is a reserved inference endpoint. " + "Cannot delete a reserved inference endpoint.")
92+
is("[model-id] is a reserved inference endpoint. Use the force=true query parameter to delete the inference endpoint.")
8393
);
8494
}
95+
96+
public void testDeletesDefaultEndpoint_WhenForceIsTrue() {
97+
doAnswer(invocationOnMock -> {
98+
ActionListener<UnparsedModel> listener = invocationOnMock.getArgument(1);
99+
listener.onResponse(new UnparsedModel("model_id", TaskType.COMPLETION, "service", Map.of(), Map.of()));
100+
return Void.TYPE;
101+
}).when(mockModelRegistry).getModel(anyString(), any());
102+
when(mockModelRegistry.containsDefaultConfigId(anyString())).thenReturn(true);
103+
doAnswer(invocationOnMock -> {
104+
ActionListener<Boolean> listener = invocationOnMock.getArgument(1);
105+
listener.onResponse(true);
106+
return Void.TYPE;
107+
}).when(mockModelRegistry).deleteModel(anyString(), any());
108+
109+
var mockService = mock(InferenceService.class);
110+
doAnswer(invocationOnMock -> {
111+
ActionListener<Boolean> listener = invocationOnMock.getArgument(1);
112+
listener.onResponse(true);
113+
return Void.TYPE;
114+
}).when(mockService).stop(any(), any());
115+
116+
when(mockInferenceServiceRegistry.getService(anyString())).thenReturn(Optional.of(mockService));
117+
118+
var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
119+
120+
action.masterOperation(
121+
mock(Task.class),
122+
new DeleteInferenceEndpointAction.Request("model-id", TaskType.COMPLETION, true, false),
123+
ClusterState.EMPTY_STATE,
124+
listener
125+
);
126+
127+
var response = listener.actionGet(TIMEOUT);
128+
129+
assertTrue(response.isAcknowledged());
130+
}
85131
}

0 commit comments

Comments
 (0)