88package org .elasticsearch .xpack .inference .action ;
99
1010import org .elasticsearch .ElasticsearchStatusException ;
11+ import org .elasticsearch .action .ActionListener ;
1112import org .elasticsearch .action .support .ActionFilters ;
1213import org .elasticsearch .action .support .PlainActionFuture ;
13- import org .elasticsearch .client .internal .Client ;
1414import org .elasticsearch .cluster .ClusterState ;
1515import org .elasticsearch .cluster .service .ClusterService ;
1616import org .elasticsearch .core .TimeValue ;
1717import org .elasticsearch .inference .InferenceService ;
1818import org .elasticsearch .inference .InferenceServiceRegistry ;
19- import org .elasticsearch .inference .MinimalServiceSettings ;
2019import org .elasticsearch .inference .TaskType ;
20+ import org .elasticsearch .inference .UnparsedModel ;
2121import org .elasticsearch .tasks .Task ;
2222import org .elasticsearch .test .ESTestCase ;
2323import org .elasticsearch .threadpool .ThreadPool ;
2727import org .junit .After ;
2828import org .junit .Before ;
2929
30+ import java .util .Map ;
31+ import java .util .Optional ;
32+
3033import static org .elasticsearch .xpack .inference .Utils .inferenceUtilityPool ;
3134import static org .hamcrest .Matchers .is ;
35+ import static org .mockito .ArgumentMatchers .any ;
36+ import static org .mockito .ArgumentMatchers .anyString ;
37+ import static org .mockito .Mockito .doAnswer ;
3238import static org .mockito .Mockito .mock ;
39+ import static org .mockito .Mockito .when ;
3340
3441public class TransportDeleteInferenceEndpointActionTests extends ESTestCase {
3542
3643 private static final TimeValue TIMEOUT = TimeValue .timeValueSeconds (30 );
3744
3845 private TransportDeleteInferenceEndpointAction action ;
3946 private ThreadPool threadPool ;
40- private ModelRegistry modelRegistry ;
47+ private ModelRegistry mockModelRegistry ;
48+ private InferenceServiceRegistry mockInferenceServiceRegistry ;
4149
4250 @ Before
4351 public void setUp () throws Exception {
4452 super .setUp ();
45- modelRegistry = new ModelRegistry (mock (Client .class ));
4653 threadPool = createThreadPool (inferenceUtilityPool ());
54+ mockModelRegistry = mock (ModelRegistry .class );
55+ mockInferenceServiceRegistry = mock (InferenceServiceRegistry .class );
4756 action = new TransportDeleteInferenceEndpointAction (
4857 mock (TransportService .class ),
4958 mock (ClusterService .class ),
5059 threadPool ,
5160 mock (ActionFilters .class ),
52- modelRegistry ,
53- mock ( InferenceServiceRegistry . class )
61+ mockModelRegistry ,
62+ mockInferenceServiceRegistry
5463 );
5564 }
5665
@@ -60,24 +69,63 @@ public void tearDown() throws Exception {
6069 terminate (threadPool );
6170 }
6271
63- public void testFailsToDelete_ADefaultEndpoint () {
64- modelRegistry .addDefaultIds (
65- new InferenceService .DefaultConfigId ("model-id" , MinimalServiceSettings .chatCompletion (), mock (InferenceService .class ))
66- );
72+ public void testFailsToDelete_ADefaultEndpoint_WithoutPassingForceQueryParameter () {
73+ doAnswer (invocationOnMock -> {
74+ ActionListener <UnparsedModel > listener = invocationOnMock .getArgument (1 );
75+ listener .onResponse (new UnparsedModel ("model_id" , TaskType .COMPLETION , "service" , Map .of (), Map .of ()));
76+ return Void .TYPE ;
77+ }).when (mockModelRegistry ).getModel (anyString (), any ());
78+ when (mockModelRegistry .containsDefaultConfigId (anyString ())).thenReturn (true );
6779
6880 var listener = new PlainActionFuture <DeleteInferenceEndpointAction .Response >();
6981
7082 action .masterOperation (
7183 mock (Task .class ),
72- new DeleteInferenceEndpointAction .Request ("model-id" , TaskType .CHAT_COMPLETION , true , false ),
73- mock ( ClusterState .class ) ,
84+ new DeleteInferenceEndpointAction .Request ("model-id" , TaskType .COMPLETION , false , false ),
85+ ClusterState .EMPTY_STATE ,
7486 listener
7587 );
7688
7789 var exception = expectThrows (ElasticsearchStatusException .class , () -> listener .actionGet (TIMEOUT ));
7890 assertThat (
7991 exception .getMessage (),
80- is ("[model-id] is a reserved inference endpoint. " + "Cannot delete a reserved inference endpoint." )
92+ is ("[model-id] is a reserved inference endpoint. Use the force=true query parameter to delete the inference endpoint." )
8193 );
8294 }
95+
96+ public void testDeletesDefaultEndpoint_WhenForceIsTrue () {
97+ doAnswer (invocationOnMock -> {
98+ ActionListener <UnparsedModel > listener = invocationOnMock .getArgument (1 );
99+ listener .onResponse (new UnparsedModel ("model_id" , TaskType .COMPLETION , "service" , Map .of (), Map .of ()));
100+ return Void .TYPE ;
101+ }).when (mockModelRegistry ).getModel (anyString (), any ());
102+ when (mockModelRegistry .containsDefaultConfigId (anyString ())).thenReturn (true );
103+ doAnswer (invocationOnMock -> {
104+ ActionListener <Boolean > listener = invocationOnMock .getArgument (1 );
105+ listener .onResponse (true );
106+ return Void .TYPE ;
107+ }).when (mockModelRegistry ).deleteModel (anyString (), any ());
108+
109+ var mockService = mock (InferenceService .class );
110+ doAnswer (invocationOnMock -> {
111+ ActionListener <Boolean > listener = invocationOnMock .getArgument (1 );
112+ listener .onResponse (true );
113+ return Void .TYPE ;
114+ }).when (mockService ).stop (any (), any ());
115+
116+ when (mockInferenceServiceRegistry .getService (anyString ())).thenReturn (Optional .of (mockService ));
117+
118+ var listener = new PlainActionFuture <DeleteInferenceEndpointAction .Response >();
119+
120+ action .masterOperation (
121+ mock (Task .class ),
122+ new DeleteInferenceEndpointAction .Request ("model-id" , TaskType .COMPLETION , true , false ),
123+ ClusterState .EMPTY_STATE ,
124+ listener
125+ );
126+
127+ var response = listener .actionGet (TIMEOUT );
128+
129+ assertTrue (response .isAcknowledged ());
130+ }
83131}
0 commit comments