diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java new file mode 100644 index 0000000000000..201f1250427a8 --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java @@ -0,0 +1,271 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.integration; + +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.MinimalServiceSettings; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.reindex.ReindexPlugin; +import org.elasticsearch.test.ESSingleNodeTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationHandler; +import org.junit.After; +import org.junit.Before; + +import java.util.Collection; +import java.util.EnumSet; +import java.util.List; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.hamcrest.CoreMatchers.is; +import static org.mockito.Mockito.mock; + +public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + + private ModelRegistry modelRegistry; + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private String gatewayUrl; + + @Before + public void createComponents() throws Exception { + threadPool = createThreadPool(inferenceUtilityPool()); + webServer.start(); + gatewayUrl = getUrl(webServer); + modelRegistry = new ModelRegistry(client()); + } + + @After + public void shutdown() { + terminate(threadPool); + webServer.close(); + } + + @Override + protected boolean resetNodeAfterTest() { + return true; + } + + @Override + protected Collection> getPlugins() { + return pluginList(ReindexPlugin.class); + } + + public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCorrect() throws Exception { + String responseJson = """ + { + "models": [ + { + "model_name": "rainbow-sprinkles", + "task_types": ["chat"] + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + try (var service = createElasticInferenceService()) { + service.waitForAuthorizationToComplete(TIMEOUT); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY))); + assertThat( + service.defaultConfigIds(), + is( + List.of( + new InferenceService.DefaultConfigId(".rainbow-sprinkles-elastic", MinimalServiceSettings.chatCompletion(), service) + ) + ) + ); + assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); + + PlainActionFuture> listener = new PlainActionFuture<>(); + service.defaultConfigs(listener); + assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); + } + } + + public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty() throws Exception { + { + String responseJson = """ + { + "models": [ + { + "model_name": "rainbow-sprinkles", + "task_types": ["chat"] + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + try (var service = createElasticInferenceService()) { + service.waitForAuthorizationToComplete(TIMEOUT); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY))); + assertThat( + service.defaultConfigIds(), + is( + List.of( + new InferenceService.DefaultConfigId( + ".rainbow-sprinkles-elastic", + MinimalServiceSettings.chatCompletion(), + service + ) + ) + ) + ); + assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); + + PlainActionFuture> listener = new PlainActionFuture<>(); + service.defaultConfigs(listener); + assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); + + var getModelListener = new PlainActionFuture(); + // persists the default endpoints + modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener); + + var inferenceEntity = getModelListener.actionGet(TIMEOUT); + assertThat(inferenceEntity.inferenceEntityId(), is(".rainbow-sprinkles-elastic")); + assertThat(inferenceEntity.taskType(), is(TaskType.CHAT_COMPLETION)); + } + } + { + String noAuthorizationResponseJson = """ + { + "models": [] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson)); + + try (var service = createElasticInferenceService()) { + service.waitForAuthorizationToComplete(TIMEOUT); + assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class))); + assertTrue(service.defaultConfigIds().isEmpty()); + assertThat(service.supportedTaskTypes(), is(EnumSet.noneOf(TaskType.class))); + + var getModelListener = new PlainActionFuture(); + modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener); + + var exception = expectThrows(ResourceNotFoundException.class, () -> getModelListener.actionGet(TIMEOUT)); + assertThat(exception.getMessage(), is("Inference endpoint not found [.rainbow-sprinkles-elastic]")); + } + } + } + + public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnAuthForIt() throws Exception { + { + String responseJson = """ + { + "models": [ + { + "model_name": "rainbow-sprinkles", + "task_types": ["chat"] + }, + { + "model_name": "elser-v2", + "task_types": ["embed/text/sparse"] + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + try (var service = createElasticInferenceService()) { + service.waitForAuthorizationToComplete(TIMEOUT); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY))); + assertThat( + service.defaultConfigIds(), + is( + List.of( + new InferenceService.DefaultConfigId( + ".rainbow-sprinkles-elastic", + MinimalServiceSettings.chatCompletion(), + service + ) + ) + ) + ); + assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING))); + + PlainActionFuture> listener = new PlainActionFuture<>(); + service.defaultConfigs(listener); + assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); + + var getModelListener = new PlainActionFuture(); + // persists the default endpoints + modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener); + + var inferenceEntity = getModelListener.actionGet(TIMEOUT); + assertThat(inferenceEntity.inferenceEntityId(), is(".rainbow-sprinkles-elastic")); + assertThat(inferenceEntity.taskType(), is(TaskType.CHAT_COMPLETION)); + } + } + { + String noAuthorizationResponseJson = """ + { + "models": [ + { + "model_name": "elser-v2", + "task_types": ["embed/text/sparse"] + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson)); + + try (var service = createElasticInferenceService()) { + service.waitForAuthorizationToComplete(TIMEOUT); + assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class))); + assertTrue(service.defaultConfigIds().isEmpty()); + assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); + + var getModelListener = new PlainActionFuture(); + modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener); + + var exception = expectThrows(ResourceNotFoundException.class, () -> getModelListener.actionGet(TIMEOUT)); + assertThat(exception.getMessage(), is("Inference endpoint not found [.rainbow-sprinkles-elastic]")); + } + } + } + + private ElasticInferenceService createElasticInferenceService() { + var httpManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, httpManager); + + return new ElasticInferenceService( + senderFactory, + createWithEmptySettings(threadPool), + new ElasticInferenceServiceComponents(gatewayUrl), + modelRegistry, + new ElasticInferenceServiceAuthorizationHandler(gatewayUrl, threadPool) + ); + } +} diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index dfdca6226efd3..8f6e9f8cb5f21 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -10,10 +10,12 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.InferenceService; @@ -51,7 +53,9 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Function; @@ -70,6 +74,7 @@ import static org.mockito.Mockito.mock; public class ModelRegistryIT extends ESSingleNodeTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); private ModelRegistry modelRegistry; @@ -195,6 +200,56 @@ public void testDeleteModel() throws Exception { assertThat(exceptionHolder.get().getMessage(), containsString("Inference endpoint not found [model1]")); } + public void testNonExistentDeleteModel_DoesNotThrowAnException() { + var listener = new PlainActionFuture(); + + modelRegistry.deleteModel("non-existent-model", listener); + assertTrue(listener.actionGet(TIMEOUT)); + } + + public void testRemoveDefaultConfigs_DoesNotThrowAnException_WhenSearchingForNonExistentInferenceEndpointIds() { + var listener = new PlainActionFuture(); + + modelRegistry.deleteModels(Set.of("non-existent-model", "abc"), listener); + assertTrue(listener.actionGet(TIMEOUT)); + } + + public void testRemoveDefaultConfigs_RemovesModelsFromPersistentStorage_AndInMemoryCache() { + var service = mock(InferenceService.class); + + var defaultConfigs = new ArrayList(); + var defaultIds = new ArrayList(); + for (var id : new String[] { "model1", "model2", "model3" }) { + var modelSettings = ModelRegistryTests.randomMinimalServiceSettings(); + defaultConfigs.add(createModel(id, modelSettings.taskType(), "name")); + defaultIds.add(new InferenceService.DefaultConfigId(id, modelSettings, service)); + } + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(0); + listener.onResponse(defaultConfigs); + return Void.TYPE; + }).when(service).defaultConfigs(any()); + + defaultIds.forEach(modelRegistry::addDefaultIds); + + var getModelsListener = new PlainActionFuture>(); + modelRegistry.getAllModels(true, getModelsListener); + var unparsedModels = getModelsListener.actionGet(TIMEOUT); + assertThat(unparsedModels.size(), is(3)); + + var removeModelsListener = new PlainActionFuture(); + + modelRegistry.removeDefaultConfigs(Set.of("model1", "model2", "model3"), removeModelsListener); + assertTrue(removeModelsListener.actionGet(TIMEOUT)); + + var getModelsAfterDeleteListener = new PlainActionFuture>(); + // the models should have been removed from the in memory cache, if not they they will be persisted again by this call + modelRegistry.getAllModels(true, getModelsAfterDeleteListener); + var unparsedModelsAfterDelete = getModelsAfterDeleteListener.actionGet(TIMEOUT); + assertThat(unparsedModelsAfterDelete.size(), is(0)); + } + public void testGetModelsByTaskType() throws InterruptedException { var service = "foo"; var sparseAndTextEmbeddingModels = new ArrayList(); @@ -315,8 +370,7 @@ public void testGetAllModels_WithDefaults() throws Exception { } doAnswer(invocation -> { - @SuppressWarnings("unchecked") - var listener = (ActionListener>) invocation.getArguments()[0]; + ActionListener> listener = invocation.getArgument(0); listener.onResponse(defaultConfigs); return Void.TYPE; }).when(service).defaultConfigs(any()); @@ -381,8 +435,7 @@ public void testGetAllModels_OnlyDefaults() throws Exception { } doAnswer(invocation -> { - @SuppressWarnings("unchecked") - var listener = (ActionListener>) invocation.getArguments()[0]; + ActionListener> listener = invocation.getArgument(0); listener.onResponse(defaultConfigs); return Void.TYPE; }).when(service).defaultConfigs(any()); @@ -424,8 +477,7 @@ public void testGetAllModels_withDoNotPersist() throws Exception { } doAnswer(invocation -> { - @SuppressWarnings("unchecked") - var listener = (ActionListener>) invocation.getArguments()[0]; + ActionListener> listener = invocation.getArgument(0); listener.onResponse(defaultConfigs); return Void.TYPE; }).when(service).defaultConfigs(any()); @@ -466,8 +518,7 @@ public void testGet_WithDefaults() throws InterruptedException { ); doAnswer(invocation -> { - @SuppressWarnings("unchecked") - var listener = (ActionListener>) invocation.getArguments()[0]; + ActionListener> listener = invocation.getArgument(0); listener.onResponse(defaultConfigs); return Void.TYPE; }).when(service).defaultConfigs(any()); @@ -520,8 +571,7 @@ public void testGetByTaskType_WithDefaults() throws Exception { defaultIds.add(new InferenceService.DefaultConfigId("default-chat", MinimalServiceSettings.completion(), service)); doAnswer(invocation -> { - @SuppressWarnings("unchecked") - var listener = (ActionListener>) invocation.getArguments()[0]; + ActionListener> listener = invocation.getArgument(0); listener.onResponse(List.of(defaultSparse, defaultChat, defaultText)); return Void.TYPE; }).when(service).defaultConfigs(any()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java index 0e441e78fb986..63601315cf45e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java @@ -88,6 +88,17 @@ private void doExecuteForked( ClusterState state, ActionListener masterListener ) { + if (modelRegistry.containsDefaultConfigId(request.getInferenceEndpointId())) { + masterListener.onFailure( + new ElasticsearchStatusException( + "[{}] is a reserved inference endpoint. Cannot delete a reserved inference endpoint.", + RestStatus.BAD_REQUEST, + request.getInferenceEndpointId() + ) + ); + return; + } + SubscribableListener.newForked(modelConfigListener -> { // Get the model from the registry diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index a9642a685aec9..2bcb130ddccbd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -28,6 +28,7 @@ import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.client.internal.Client; import org.elasticsearch.client.internal.OriginSettingClient; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.index.engine.VersionConflictEngineException; @@ -61,6 +62,7 @@ import java.util.Collections; import java.util.Comparator; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; @@ -111,7 +113,7 @@ public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap) public ModelRegistry(Client client) { this.client = new OriginSettingClient(client, ClientHelper.INFERENCE_ORIGIN); - defaultConfigIds = new HashMap<>(); + defaultConfigIds = new ConcurrentHashMap<>(); } /** @@ -644,11 +646,32 @@ private static BulkItemResponse.Failure getFirstBulkFailure(BulkResponse bulkRes return null; } + public synchronized void removeDefaultConfigs(Set inferenceEntityIds, ActionListener listener) { + if (inferenceEntityIds.isEmpty()) { + listener.onResponse(true); + return; + } + + defaultConfigIds.keySet().removeAll(inferenceEntityIds); + deleteModels(inferenceEntityIds, listener); + } + public void deleteModel(String inferenceEntityId, ActionListener listener) { - if (preventDeletionLock.contains(inferenceEntityId)) { + deleteModels(Set.of(inferenceEntityId), listener); + } + + public void deleteModels(Set inferenceEntityIds, ActionListener listener) { + var lockedInferenceIds = new HashSet<>(inferenceEntityIds); + lockedInferenceIds.retainAll(preventDeletionLock); + + if (lockedInferenceIds.isEmpty() == false) { listener.onFailure( new ElasticsearchStatusException( - "Model is currently being updated, you may delete the model once the update completes", + Strings.format( + "The inference endpoint(s) %s are currently being updated, please wait until after they are " + + "finished updating to delete.", + lockedInferenceIds + ), RestStatus.CONFLICT ) ); @@ -657,7 +680,7 @@ public void deleteModel(String inferenceEntityId, ActionListener listen DeleteByQueryRequest request = new DeleteByQueryRequest().setAbortOnVersionConflict(false); request.indices(InferenceIndex.INDEX_PATTERN, InferenceSecretsIndex.INDEX_PATTERN); - request.setQuery(documentIdQuery(inferenceEntityId)); + request.setQuery(documentIdsQuery(inferenceEntityIds)); request.setRefresh(true); client.execute(DeleteByQueryAction.INSTANCE, request, listener.delegateFailureAndWrap((l, r) -> l.onResponse(Boolean.TRUE))); @@ -695,6 +718,11 @@ private QueryBuilder documentIdQuery(String inferenceEntityId) { return QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(Model.documentId(inferenceEntityId))); } + private QueryBuilder documentIdsQuery(Set inferenceEntityIds) { + var documentIdsArray = inferenceEntityIds.stream().map(Model::documentId).toArray(String[]::new); + return QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(documentIdsArray)); + } + static Optional idMatchedDefault( String inferenceId, List defaultConfigIds diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index f010c2f85a063..799f2c3eaa905 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -68,8 +68,10 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException; +import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; @@ -158,10 +160,7 @@ static AuthorizedContent empty() { private void getAuthorization() { try { - ActionListener listener = ActionListener.wrap(result -> { - setAuthorizedContent(result); - authorizationCompletedLatch.countDown(); - }, e -> { + ActionListener listener = ActionListener.wrap(this::setAuthorizedContent, e -> { // we don't need to do anything if there was a failure, everything is disabled by default authorizationCompletedLatch.countDown(); }); @@ -177,18 +176,30 @@ private synchronized void setAuthorizedContent(ElasticInferenceServiceAuthorizat var authorizedTaskTypesAndModels = auth.newLimitedToTaskTypes(EnumSet.copyOf(IMPLEMENTED_TASK_TYPES)); // recalculate which default config ids and models are authorized now - var authorizedDefaultConfigIds = getAuthorizedDefaultConfigIds(auth); - var authorizedDefaultModelObjects = getAuthorizedDefaultModelsObjects(auth); + var authorizedDefaultModelIds = getAuthorizedDefaultModelIds(auth); + + var authorizedDefaultConfigIds = getAuthorizedDefaultConfigIds(authorizedDefaultModelIds, auth); + var authorizedDefaultModelObjects = getAuthorizedDefaultModelsObjects(authorizedDefaultModelIds); authRef.set(new AuthorizedContent(authorizedTaskTypesAndModels, authorizedDefaultConfigIds, authorizedDefaultModelObjects)); configuration = new Configuration(authRef.get().taskTypesAndModels.getAuthorizedTaskTypes()); defaultConfigIds().forEach(modelRegistry::addDefaultIds); + handleRevokedDefaultConfigs(authorizedDefaultModelIds); } - private List getAuthorizedDefaultConfigIds(ElasticInferenceServiceAuthorization auth) { - var authorizedDefaultModelIds = getAuthorizedDefaultModelIds(auth); + private Set getAuthorizedDefaultModelIds(ElasticInferenceServiceAuthorization auth) { + var authorizedModels = auth.getAuthorizedModelIds(); + var authorizedDefaultModelIds = new HashSet<>(defaultModelsConfigs.keySet()); + authorizedDefaultModelIds.retainAll(authorizedModels); + return authorizedDefaultModelIds; + } + + private List getAuthorizedDefaultConfigIds( + Set authorizedDefaultModelIds, + ElasticInferenceServiceAuthorization auth + ) { var authorizedConfigIds = new ArrayList(); for (var id : authorizedDefaultModelIds) { var modelConfig = defaultModelsConfigs.get(id); @@ -210,17 +221,7 @@ private List getAuthorizedDefaultConfigIds(ElasticInferenceServ return authorizedConfigIds; } - private Set getAuthorizedDefaultModelIds(ElasticInferenceServiceAuthorization auth) { - var authorizedModels = auth.getAuthorizedModelIds(); - var authorizedDefaultModelIds = new HashSet<>(defaultModelsConfigs.keySet()); - authorizedDefaultModelIds.retainAll(authorizedModels); - - return authorizedDefaultModelIds; - } - - private List getAuthorizedDefaultModelsObjects(ElasticInferenceServiceAuthorization auth) { - var authorizedDefaultModelIds = getAuthorizedDefaultModelIds(auth); - + private List getAuthorizedDefaultModelsObjects(Set authorizedDefaultModelIds) { var authorizedModels = new ArrayList(); for (var id : authorizedDefaultModelIds) { var modelConfig = defaultModelsConfigs.get(id); @@ -232,8 +233,39 @@ private List getAuthorizedDefaultModelsObjects(ElasticInfere return authorizedModels; } - // Default for testing - void waitForAuthorizationToComplete(TimeValue waitTime) { + private void handleRevokedDefaultConfigs(Set authorizedDefaultModelIds) { + // if a model was initially returned in the authorization response but is absent, then we'll assume authorization was revoked + var unauthorizedDefaultModelIds = new HashSet<>(defaultModelsConfigs.keySet()); + unauthorizedDefaultModelIds.removeAll(authorizedDefaultModelIds); + + // get all the default inference endpoint ids for the unauthorized model ids + var unauthorizedDefaultInferenceEndpointIds = unauthorizedDefaultModelIds.stream() + .map(defaultModelsConfigs::get) // get all the model configs + .filter(Objects::nonNull) // limit to only non-null + .map(modelConfig -> modelConfig.model.getInferenceEntityId()) // get the inference ids + .collect(Collectors.toSet()); + + var deleteInferenceEndpointsListener = ActionListener.wrap(result -> { + logger.trace(Strings.format("Successfully revoked access to default inference endpoint IDs: %s", unauthorizedDefaultModelIds)); + authorizationCompletedLatch.countDown(); + }, e -> { + logger.warn( + Strings.format("Failed to revoke access to default inference endpoint IDs: %s, error: %s", unauthorizedDefaultModelIds, e) + ); + authorizationCompletedLatch.countDown(); + }); + + getServiceComponents().threadPool() + .executor(UTILITY_THREAD_POOL_NAME) + .execute(() -> modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener)); + } + + /** + * Waits the specified amount of time for the authorization call to complete. This is mainly to make testing easier. + * @param waitTime the max time to wait + * @throws IllegalStateException if the wait time is exceeded or the call receives an {@link InterruptedException} + */ + public void waitForAuthorizationToComplete(TimeValue waitTime) { try { if (authorizationCompletedLatch.await(waitTime.getSeconds(), TimeUnit.SECONDS) == false) { throw new IllegalStateException("The wait time has expired for authorization to complete."); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointActionTests.java new file mode 100644 index 0000000000000..a640e64c2022d --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointActionTests.java @@ -0,0 +1,85 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.action; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.MinimalServiceSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.inference.action.DeleteInferenceEndpointAction; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.junit.After; +import org.junit.Before; + +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class TransportDeleteInferenceEndpointActionTests extends ESTestCase { + + private static final TimeValue TIMEOUT = TimeValue.timeValueSeconds(30); + + private TransportDeleteInferenceEndpointAction action; + private ThreadPool threadPool; + private ModelRegistry modelRegistry; + + @Before + public void setUp() throws Exception { + super.setUp(); + modelRegistry = new ModelRegistry(mock(Client.class)); + threadPool = createThreadPool(inferenceUtilityPool()); + action = new TransportDeleteInferenceEndpointAction( + mock(TransportService.class), + mock(ClusterService.class), + threadPool, + mock(ActionFilters.class), + mock(IndexNameExpressionResolver.class), + modelRegistry, + mock(InferenceServiceRegistry.class) + ); + } + + @After + public void tearDown() throws Exception { + super.tearDown(); + terminate(threadPool); + } + + public void testFailsToDelete_ADefaultEndpoint() { + modelRegistry.addDefaultIds( + new InferenceService.DefaultConfigId("model-id", MinimalServiceSettings.chatCompletion(), mock(InferenceService.class)) + ); + + var listener = new PlainActionFuture(); + + action.masterOperation( + mock(Task.class), + new DeleteInferenceEndpointAction.Request("model-id", TaskType.CHAT_COMPLETION, true, false), + mock(ClusterState.class), + listener + ); + + var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + exception.getMessage(), + is("[model-id] is a reserved inference endpoint. " + "Cannot delete a reserved inference endpoint.") + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java index 162bcc8f09713..65e4d049ef58b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java @@ -28,6 +28,7 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.SearchResponseUtils; @@ -41,6 +42,7 @@ import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Map; +import java.util.Set; import java.util.concurrent.TimeUnit; import static org.elasticsearch.core.Strings.format; @@ -52,6 +54,8 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class ModelRegistryTests extends ESTestCase { @@ -295,6 +299,37 @@ public void testStoreModel_ThrowsException_WhenFailureIsNotAVersionConflict() { ); } + public void testRemoveDefaultConfigs_DoesNotCallClient_WhenPassedAnEmptySet() { + var client = mock(Client.class); + + var registry = new ModelRegistry(client); + var listener = new PlainActionFuture(); + + registry.removeDefaultConfigs(Set.of(), listener); + + assertTrue(listener.actionGet(TIMEOUT)); + verify(client, times(0)).execute(any(), any(), any()); + } + + public void testDeleteModels_Returns_ConflictException_WhenModelIsBeingAdded() { + var client = mockClient(); + + var registry = new ModelRegistry(client); + var model = TestModel.createRandomInstance(); + var newModel = TestModel.createRandomInstance(); + registry.updateModelTransaction(newModel, model, new PlainActionFuture<>()); + + var listener = new PlainActionFuture(); + + registry.deleteModels(Set.of(newModel.getInferenceEntityId()), listener); + var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + exception.getMessage(), + containsString("are currently being updated, please wait until after they are finished updating to delete.") + ); + assertThat(exception.status(), is(RestStatus.CONFLICT)); + } + public void testIdMatchedDefault() { var defaultConfigIds = new ArrayList(); defaultConfigIds.add( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 77ea3059a7b56..104e3ecd4ad35 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -366,6 +366,14 @@ public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException private ModelRegistry mockModelRegistry() { var client = mock(Client.class); when(client.threadPool()).thenReturn(threadPool); + + doAnswer(invocationOnMock -> { + @SuppressWarnings("unchecked") + var listener = (ActionListener) invocationOnMock.getArgument(2); + listener.onResponse(true); + + return Void.TYPE; + }).when(client).execute(any(), any(), any()); return new ModelRegistry(client); }