Skip to content

Commit 0a04210

Browse files
Adding test for deleting default inference endpoint via rest call
1 parent b9e20b6 commit 0a04210

File tree

2 files changed

+141
-0
lines changed

2 files changed

+141
-0
lines changed

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@
1010
import org.elasticsearch.ElasticsearchStatusException;
1111
import org.elasticsearch.TransportVersion;
1212
import org.elasticsearch.action.ActionListener;
13+
import org.elasticsearch.action.support.PlainActionFuture;
1314
import org.elasticsearch.client.internal.Client;
1415
import org.elasticsearch.cluster.service.ClusterService;
1516
import org.elasticsearch.common.io.stream.StreamOutput;
1617
import org.elasticsearch.common.settings.Settings;
18+
import org.elasticsearch.core.TimeValue;
1719
import org.elasticsearch.index.IndexNotFoundException;
1820
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
1921
import org.elasticsearch.inference.InferenceService;
@@ -51,7 +53,9 @@
5153
import java.util.HashMap;
5254
import java.util.List;
5355
import java.util.Map;
56+
import java.util.Set;
5457
import java.util.concurrent.CountDownLatch;
58+
import java.util.concurrent.TimeUnit;
5559
import java.util.concurrent.atomic.AtomicReference;
5660
import java.util.function.Consumer;
5761
import java.util.function.Function;
@@ -70,6 +74,7 @@
7074
import static org.mockito.Mockito.mock;
7175

7276
public class ModelRegistryIT extends ESSingleNodeTestCase {
77+
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
7378

7479
private ModelRegistry modelRegistry;
7580

@@ -195,6 +200,57 @@ public void testDeleteModel() throws Exception {
195200
assertThat(exceptionHolder.get().getMessage(), containsString("Inference endpoint not found [model1]"));
196201
}
197202

203+
public void testNonExistentDeleteModel_DoesNotThrowAnException() {
204+
var listener = new PlainActionFuture<Boolean>();
205+
206+
modelRegistry.deleteModel("non-existent-model", listener);
207+
assertTrue(listener.actionGet(TIMEOUT));
208+
}
209+
210+
public void testRemoveDefaultConfigs_DoesNotThrowAnException_WhenSearchingForNonExistentInferenceEndpointIds() {
211+
var listener = new PlainActionFuture<Boolean>();
212+
213+
modelRegistry.deleteModels(Set.of("non-existent-model", "abc"), listener);
214+
assertTrue(listener.actionGet(TIMEOUT));
215+
}
216+
217+
public void testRemoveDefaultConfigs_RemovesModelsFromPersistentStorage_AndInMemoryCache() {
218+
var service = mock(InferenceService.class);
219+
220+
var defaultConfigs = new ArrayList<Model>();
221+
var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
222+
for (var id : new String[] { "model1", "model2", "model3" }) {
223+
var modelSettings = ModelRegistryTests.randomMinimalServiceSettings();
224+
defaultConfigs.add(createModel(id, modelSettings.taskType(), "name"));
225+
defaultIds.add(new InferenceService.DefaultConfigId(id, modelSettings, service));
226+
}
227+
228+
doAnswer(invocation -> {
229+
@SuppressWarnings("unchecked")
230+
var listener = (ActionListener<List<Model>>) invocation.getArguments()[0];
231+
listener.onResponse(defaultConfigs);
232+
return Void.TYPE;
233+
}).when(service).defaultConfigs(any());
234+
235+
defaultIds.forEach(modelRegistry::addDefaultIds);
236+
237+
var getModelsListener = new PlainActionFuture<List<UnparsedModel>>();
238+
modelRegistry.getAllModels(true, getModelsListener);
239+
var unparsedModels = getModelsListener.actionGet(TIMEOUT);
240+
assertThat(unparsedModels.size(), is(3));
241+
242+
var removeModelsListener = new PlainActionFuture<Boolean>();
243+
244+
modelRegistry.removeDefaultConfigs(Set.of("model1", "model2", "model3"), removeModelsListener);
245+
assertTrue(removeModelsListener.actionGet(TIMEOUT));
246+
247+
var getModelsAfterDeleteListener = new PlainActionFuture<List<UnparsedModel>>();
248+
// the models should have been removed from the in memory cache, if not they they will be persisted again by this call
249+
modelRegistry.getAllModels(true, getModelsAfterDeleteListener);
250+
var unparsedModelsAfterDelete = getModelsAfterDeleteListener.actionGet(TIMEOUT);
251+
assertThat(unparsedModelsAfterDelete.size(), is(0));
252+
}
253+
198254
public void testGetModelsByTaskType() throws InterruptedException {
199255
var service = "foo";
200256
var sparseAndTextEmbeddingModels = new ArrayList<Model>();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.action;
9+
10+
import org.elasticsearch.ElasticsearchStatusException;
11+
import org.elasticsearch.action.support.ActionFilters;
12+
import org.elasticsearch.action.support.PlainActionFuture;
13+
import org.elasticsearch.client.internal.Client;
14+
import org.elasticsearch.cluster.ClusterState;
15+
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
16+
import org.elasticsearch.cluster.service.ClusterService;
17+
import org.elasticsearch.core.TimeValue;
18+
import org.elasticsearch.inference.InferenceService;
19+
import org.elasticsearch.inference.InferenceServiceRegistry;
20+
import org.elasticsearch.inference.MinimalServiceSettings;
21+
import org.elasticsearch.inference.TaskType;
22+
import org.elasticsearch.tasks.Task;
23+
import org.elasticsearch.test.ESTestCase;
24+
import org.elasticsearch.threadpool.ThreadPool;
25+
import org.elasticsearch.transport.TransportService;
26+
import org.elasticsearch.xpack.core.inference.action.DeleteInferenceEndpointAction;
27+
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
28+
import org.junit.After;
29+
import org.junit.Before;
30+
31+
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
32+
import static org.hamcrest.Matchers.is;
33+
import static org.mockito.Mockito.mock;
34+
35+
public class TransportDeleteInferenceEndpointActionTests extends ESTestCase {
36+
37+
private static final TimeValue TIMEOUT = TimeValue.timeValueSeconds(30);
38+
39+
private TransportDeleteInferenceEndpointAction action;
40+
private ThreadPool threadPool;
41+
private ModelRegistry modelRegistry;
42+
43+
@Before
44+
public void setUp() throws Exception {
45+
super.setUp();
46+
modelRegistry = new ModelRegistry(mock(Client.class));
47+
threadPool = createThreadPool(inferenceUtilityPool());
48+
action = new TransportDeleteInferenceEndpointAction(
49+
mock(TransportService.class),
50+
mock(ClusterService.class),
51+
threadPool,
52+
mock(ActionFilters.class),
53+
mock(IndexNameExpressionResolver.class),
54+
modelRegistry,
55+
mock(InferenceServiceRegistry.class)
56+
);
57+
}
58+
59+
@After
60+
public void tearDown() throws Exception {
61+
super.tearDown();
62+
terminate(threadPool);
63+
}
64+
65+
public void testFailsToDelete_ADefaultEndpoint() {
66+
modelRegistry.addDefaultIds(
67+
new InferenceService.DefaultConfigId("model-id", MinimalServiceSettings.chatCompletion(), mock(InferenceService.class))
68+
);
69+
70+
var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
71+
72+
action.masterOperation(
73+
mock(Task.class),
74+
new DeleteInferenceEndpointAction.Request("model-id", TaskType.CHAT_COMPLETION, true, false),
75+
mock(ClusterState.class),
76+
listener
77+
);
78+
79+
var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
80+
assertThat(
81+
exception.getMessage(),
82+
is("[model-id] is a reserved inference endpoint. " + "Cannot delete a reserved inference endpoint.")
83+
);
84+
}
85+
}

0 commit comments

Comments
 (0)