88package org .elasticsearch .xpack .inference .action ;
99
1010import org .elasticsearch .ElasticsearchStatusException ;
11+ import org .elasticsearch .action .ActionListener ;
1112import org .elasticsearch .action .support .ActionFilters ;
1213import org .elasticsearch .action .support .PlainActionFuture ;
13- import org .elasticsearch .client .internal .Client ;
1414import org .elasticsearch .cluster .ClusterState ;
15- import org .elasticsearch .cluster .metadata .IndexNameExpressionResolver ;
1615import org .elasticsearch .cluster .service .ClusterService ;
1716import org .elasticsearch .core .TimeValue ;
1817import org .elasticsearch .inference .InferenceService ;
1918import org .elasticsearch .inference .InferenceServiceRegistry ;
20- import org .elasticsearch .inference .MinimalServiceSettings ;
2119import org .elasticsearch .inference .TaskType ;
20+ import org .elasticsearch .inference .UnparsedModel ;
2221import org .elasticsearch .tasks .Task ;
2322import org .elasticsearch .test .ESTestCase ;
2423import org .elasticsearch .threadpool .ThreadPool ;
2827import org .junit .After ;
2928import org .junit .Before ;
3029
30+ import java .util .Map ;
31+ import java .util .Optional ;
32+
3133import static org .elasticsearch .xpack .inference .Utils .inferenceUtilityPool ;
3234import static org .hamcrest .Matchers .is ;
35+ import static org .mockito .ArgumentMatchers .any ;
36+ import static org .mockito .ArgumentMatchers .anyString ;
37+ import static org .mockito .Mockito .doAnswer ;
3338import static org .mockito .Mockito .mock ;
39+ import static org .mockito .Mockito .when ;
3440
3541public class TransportDeleteInferenceEndpointActionTests extends ESTestCase {
3642
3743 private static final TimeValue TIMEOUT = TimeValue .timeValueSeconds (30 );
3844
3945 private TransportDeleteInferenceEndpointAction action ;
4046 private ThreadPool threadPool ;
41- private ModelRegistry modelRegistry ;
47+ private ModelRegistry mockModelRegistry ;
48+ private InferenceServiceRegistry mockInferenceServiceRegistry ;
4249
4350 @ Before
4451 public void setUp () throws Exception {
4552 super .setUp ();
46- modelRegistry = new ModelRegistry (mock (Client .class ));
4753 threadPool = createThreadPool (inferenceUtilityPool ());
54+ mockModelRegistry = mock (ModelRegistry .class );
55+ mockInferenceServiceRegistry = mock (InferenceServiceRegistry .class );
4856 action = new TransportDeleteInferenceEndpointAction (
4957 mock (TransportService .class ),
5058 mock (ClusterService .class ),
5159 threadPool ,
5260 mock (ActionFilters .class ),
53- mock (IndexNameExpressionResolver .class ),
54- modelRegistry ,
55- mock (InferenceServiceRegistry .class )
61+ mockModelRegistry ,
62+ mockInferenceServiceRegistry
5663 );
5764 }
5865
@@ -62,24 +69,63 @@ public void tearDown() throws Exception {
6269 terminate (threadPool );
6370 }
6471
65- public void testFailsToDelete_ADefaultEndpoint () {
66- modelRegistry .addDefaultIds (
67- new InferenceService .DefaultConfigId ("model-id" , MinimalServiceSettings .chatCompletion (), mock (InferenceService .class ))
68- );
72+ public void testFailsToDelete_ADefaultEndpoint_WithoutPassingForceQueryParameter () {
73+ doAnswer (invocationOnMock -> {
74+ ActionListener <UnparsedModel > listener = invocationOnMock .getArgument (1 );
75+ listener .onResponse (new UnparsedModel ("model_id" , TaskType .COMPLETION , "service" , Map .of (), Map .of ()));
76+ return Void .TYPE ;
77+ }).when (mockModelRegistry ).getModel (anyString (), any ());
78+ when (mockModelRegistry .containsDefaultConfigId (anyString ())).thenReturn (true );
6979
7080 var listener = new PlainActionFuture <DeleteInferenceEndpointAction .Response >();
7181
7282 action .masterOperation (
7383 mock (Task .class ),
74- new DeleteInferenceEndpointAction .Request ("model-id" , TaskType .CHAT_COMPLETION , true , false ),
75- mock ( ClusterState .class ) ,
84+ new DeleteInferenceEndpointAction .Request ("model-id" , TaskType .COMPLETION , false , false ),
85+ ClusterState .EMPTY_STATE ,
7686 listener
7787 );
7888
7989 var exception = expectThrows (ElasticsearchStatusException .class , () -> listener .actionGet (TIMEOUT ));
8090 assertThat (
8191 exception .getMessage (),
82- is ("[model-id] is a reserved inference endpoint. " + "Cannot delete a reserved inference endpoint." )
92+ is ("[model-id] is a reserved inference endpoint. Use the force=true query parameter to delete the inference endpoint." )
8393 );
8494 }
95+
96+ public void testDeletesDefaultEndpoint_WhenForceIsTrue () {
97+ doAnswer (invocationOnMock -> {
98+ ActionListener <UnparsedModel > listener = invocationOnMock .getArgument (1 );
99+ listener .onResponse (new UnparsedModel ("model_id" , TaskType .COMPLETION , "service" , Map .of (), Map .of ()));
100+ return Void .TYPE ;
101+ }).when (mockModelRegistry ).getModel (anyString (), any ());
102+ when (mockModelRegistry .containsDefaultConfigId (anyString ())).thenReturn (true );
103+ doAnswer (invocationOnMock -> {
104+ ActionListener <Boolean > listener = invocationOnMock .getArgument (1 );
105+ listener .onResponse (true );
106+ return Void .TYPE ;
107+ }).when (mockModelRegistry ).deleteModel (anyString (), any ());
108+
109+ var mockService = mock (InferenceService .class );
110+ doAnswer (invocationOnMock -> {
111+ ActionListener <Boolean > listener = invocationOnMock .getArgument (1 );
112+ listener .onResponse (true );
113+ return Void .TYPE ;
114+ }).when (mockService ).stop (any (), any ());
115+
116+ when (mockInferenceServiceRegistry .getService (anyString ())).thenReturn (Optional .of (mockService ));
117+
118+ var listener = new PlainActionFuture <DeleteInferenceEndpointAction .Response >();
119+
120+ action .masterOperation (
121+ mock (Task .class ),
122+ new DeleteInferenceEndpointAction .Request ("model-id" , TaskType .COMPLETION , true , false ),
123+ ClusterState .EMPTY_STATE ,
124+ listener
125+ );
126+
127+ var response = listener .actionGet (TIMEOUT );
128+
129+ assertTrue (response .isAcknowledged ());
130+ }
85131}
0 commit comments