|
10 | 10 | import org.elasticsearch.ElasticsearchStatusException; |
11 | 11 | import org.elasticsearch.TransportVersion; |
12 | 12 | import org.elasticsearch.action.ActionListener; |
| 13 | +import org.elasticsearch.action.support.PlainActionFuture; |
13 | 14 | import org.elasticsearch.client.internal.Client; |
14 | 15 | import org.elasticsearch.cluster.service.ClusterService; |
15 | 16 | import org.elasticsearch.common.io.stream.StreamOutput; |
16 | 17 | import org.elasticsearch.common.settings.Settings; |
| 18 | +import org.elasticsearch.core.TimeValue; |
17 | 19 | import org.elasticsearch.index.IndexNotFoundException; |
18 | 20 | import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; |
19 | 21 | import org.elasticsearch.inference.InferenceService; |
|
51 | 53 | import java.util.HashMap; |
52 | 54 | import java.util.List; |
53 | 55 | import java.util.Map; |
| 56 | +import java.util.Set; |
54 | 57 | import java.util.concurrent.CountDownLatch; |
| 58 | +import java.util.concurrent.TimeUnit; |
55 | 59 | import java.util.concurrent.atomic.AtomicReference; |
56 | 60 | import java.util.function.Consumer; |
57 | 61 | import java.util.function.Function; |
|
70 | 74 | import static org.mockito.Mockito.mock; |
71 | 75 |
|
72 | 76 | public class ModelRegistryIT extends ESSingleNodeTestCase { |
| 77 | + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); |
73 | 78 |
|
74 | 79 | private ModelRegistry modelRegistry; |
75 | 80 |
|
@@ -195,6 +200,57 @@ public void testDeleteModel() throws Exception { |
195 | 200 | assertThat(exceptionHolder.get().getMessage(), containsString("Inference endpoint not found [model1]")); |
196 | 201 | } |
197 | 202 |
|
| 203 | + public void testNonExistentDeleteModel_DoesNotThrowAnException() { |
| 204 | + var listener = new PlainActionFuture<Boolean>(); |
| 205 | + |
| 206 | + modelRegistry.deleteModel("non-existent-model", listener); |
| 207 | + assertTrue(listener.actionGet(TIMEOUT)); |
| 208 | + } |
| 209 | + |
| 210 | + public void testRemoveDefaultConfigs_DoesNotThrowAnException_WhenSearchingForNonExistentInferenceEndpointIds() { |
| 211 | + var listener = new PlainActionFuture<Boolean>(); |
| 212 | + |
| 213 | + modelRegistry.deleteModels(Set.of("non-existent-model", "abc"), listener); |
| 214 | + assertTrue(listener.actionGet(TIMEOUT)); |
| 215 | + } |
| 216 | + |
| 217 | + public void testRemoveDefaultConfigs_RemovesModelsFromPersistentStorage_AndInMemoryCache() { |
| 218 | + var service = mock(InferenceService.class); |
| 219 | + |
| 220 | + var defaultConfigs = new ArrayList<Model>(); |
| 221 | + var defaultIds = new ArrayList<InferenceService.DefaultConfigId>(); |
| 222 | + for (var id : new String[] { "model1", "model2", "model3" }) { |
| 223 | + var modelSettings = ModelRegistryTests.randomMinimalServiceSettings(); |
| 224 | + defaultConfigs.add(createModel(id, modelSettings.taskType(), "name")); |
| 225 | + defaultIds.add(new InferenceService.DefaultConfigId(id, modelSettings, service)); |
| 226 | + } |
| 227 | + |
| 228 | + doAnswer(invocation -> { |
| 229 | + @SuppressWarnings("unchecked") |
| 230 | + var listener = (ActionListener<List<Model>>) invocation.getArguments()[0]; |
| 231 | + listener.onResponse(defaultConfigs); |
| 232 | + return Void.TYPE; |
| 233 | + }).when(service).defaultConfigs(any()); |
| 234 | + |
| 235 | + defaultIds.forEach(modelRegistry::addDefaultIds); |
| 236 | + |
| 237 | + var getModelsListener = new PlainActionFuture<List<UnparsedModel>>(); |
| 238 | + modelRegistry.getAllModels(true, getModelsListener); |
| 239 | + var unparsedModels = getModelsListener.actionGet(TIMEOUT); |
| 240 | + assertThat(unparsedModels.size(), is(3)); |
| 241 | + |
| 242 | + var removeModelsListener = new PlainActionFuture<Boolean>(); |
| 243 | + |
| 244 | + modelRegistry.removeDefaultConfigs(Set.of("model1", "model2", "model3"), removeModelsListener); |
| 245 | + assertTrue(removeModelsListener.actionGet(TIMEOUT)); |
| 246 | + |
| 247 | + var getModelsAfterDeleteListener = new PlainActionFuture<List<UnparsedModel>>(); |
| 248 | + // the models should have been removed from the in memory cache, if not they they will be persisted again by this call |
| 249 | + modelRegistry.getAllModels(true, getModelsAfterDeleteListener); |
| 250 | + var unparsedModelsAfterDelete = getModelsAfterDeleteListener.actionGet(TIMEOUT); |
| 251 | + assertThat(unparsedModelsAfterDelete.size(), is(0)); |
| 252 | + } |
| 253 | + |
198 | 254 | public void testGetModelsByTaskType() throws InterruptedException { |
199 | 255 | var service = "foo"; |
200 | 256 | var sparseAndTextEmbeddingModels = new ArrayList<Model>(); |
|
0 commit comments