Skip to content
Merged
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.integration;

import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.MinimalServiceSettings;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.reindex.ReindexPlugin;
import org.elasticsearch.test.ESSingleNodeTestCase;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationHandler;
import org.junit.After;
import org.junit.Before;

import java.util.Collection;
import java.util.EnumSet;
import java.util.List;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.hamcrest.CoreMatchers.is;
import static org.mockito.Mockito.mock;

public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);

private ModelRegistry modelRegistry;
private final MockWebServer webServer = new MockWebServer();
private ThreadPool threadPool;
private String gatewayUrl;

@Before
public void createComponents() throws Exception {
threadPool = createThreadPool(inferenceUtilityPool());
webServer.start();
gatewayUrl = getUrl(webServer);
modelRegistry = new ModelRegistry(client());
}

@After
public void shutdown() {
terminate(threadPool);
webServer.close();
}

@Override
protected boolean resetNodeAfterTest() {
return true;
}

@Override
protected Collection<Class<? extends Plugin>> getPlugins() {
return pluginList(ReindexPlugin.class);
}

public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCorrect() throws Exception {
String responseJson = """
{
"models": [
{
"model_name": "rainbow-sprinkles",
"task_types": ["chat"]
}
]
}
""";

webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

try (var service = createElasticInferenceService()) {
service.waitForAuthorizationToComplete(TIMEOUT);
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
assertThat(
service.defaultConfigIds(),
is(
List.of(
new InferenceService.DefaultConfigId(".rainbow-sprinkles-elastic", MinimalServiceSettings.chatCompletion(), service)
)
)
);
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));

PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
service.defaultConfigs(listener);
assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
}
}

public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty() throws Exception {
{
String responseJson = """
{
"models": [
{
"model_name": "rainbow-sprinkles",
"task_types": ["chat"]
}
]
}
""";

webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

try (var service = createElasticInferenceService()) {
service.waitForAuthorizationToComplete(TIMEOUT);
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
assertThat(
service.defaultConfigIds(),
is(
List.of(
new InferenceService.DefaultConfigId(
".rainbow-sprinkles-elastic",
MinimalServiceSettings.chatCompletion(),
service
)
)
)
);
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));

PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
service.defaultConfigs(listener);
assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));

var getModelListener = new PlainActionFuture<UnparsedModel>();
// persists the default endpoints
modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener);

var inferenceEntity = getModelListener.actionGet(TIMEOUT);
assertThat(inferenceEntity.inferenceEntityId(), is(".rainbow-sprinkles-elastic"));
assertThat(inferenceEntity.taskType(), is(TaskType.CHAT_COMPLETION));
}
}
{
String noAuthorizationResponseJson = """
{
"models": []
}
""";

webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson));

try (var service = createElasticInferenceService()) {
service.waitForAuthorizationToComplete(TIMEOUT);
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
assertTrue(service.defaultConfigIds().isEmpty());
assertThat(service.supportedTaskTypes(), is(EnumSet.noneOf(TaskType.class)));

var getModelListener = new PlainActionFuture<UnparsedModel>();
modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener);

var exception = expectThrows(ResourceNotFoundException.class, () -> getModelListener.actionGet(TIMEOUT));
assertThat(exception.getMessage(), is("Inference endpoint not found [.rainbow-sprinkles-elastic]"));
}
}
}

public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnAuthForIt() throws Exception {
{
String responseJson = """
{
"models": [
{
"model_name": "rainbow-sprinkles",
"task_types": ["chat"]
},
{
"model_name": "elser-v2",
"task_types": ["embed/text/sparse"]
}
]
}
""";

webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

try (var service = createElasticInferenceService()) {
service.waitForAuthorizationToComplete(TIMEOUT);
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
assertThat(
service.defaultConfigIds(),
is(
List.of(
new InferenceService.DefaultConfigId(
".rainbow-sprinkles-elastic",
MinimalServiceSettings.chatCompletion(),
service
)
)
)
);
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING)));

PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
service.defaultConfigs(listener);
assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));

var getModelListener = new PlainActionFuture<UnparsedModel>();
// persists the default endpoints
modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener);

var inferenceEntity = getModelListener.actionGet(TIMEOUT);
assertThat(inferenceEntity.inferenceEntityId(), is(".rainbow-sprinkles-elastic"));
assertThat(inferenceEntity.taskType(), is(TaskType.CHAT_COMPLETION));
}
}
{
String noAuthorizationResponseJson = """
{
"models": [
{
"model_name": "elser-v2",
"task_types": ["embed/text/sparse"]
}
]
}
""";

webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson));

try (var service = createElasticInferenceService()) {
service.waitForAuthorizationToComplete(TIMEOUT);
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
assertTrue(service.defaultConfigIds().isEmpty());
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING)));

var getModelListener = new PlainActionFuture<UnparsedModel>();
modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener);

var exception = expectThrows(ResourceNotFoundException.class, () -> getModelListener.actionGet(TIMEOUT));
assertThat(exception.getMessage(), is("Inference endpoint not found [.rainbow-sprinkles-elastic]"));
}
}
}

private ElasticInferenceService createElasticInferenceService() {
var httpManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, httpManager);

return new ElasticInferenceService(
senderFactory,
createWithEmptySettings(threadPool),
new ElasticInferenceServiceComponents(gatewayUrl),
modelRegistry,
new ElasticInferenceServiceAuthorizationHandler(gatewayUrl, threadPool)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.IndexNotFoundException;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.InferenceService;
Expand Down Expand Up @@ -51,7 +53,9 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Function;
Expand All @@ -70,6 +74,7 @@
import static org.mockito.Mockito.mock;

public class ModelRegistryIT extends ESSingleNodeTestCase {
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);

private ModelRegistry modelRegistry;

Expand Down Expand Up @@ -195,6 +200,57 @@ public void testDeleteModel() throws Exception {
assertThat(exceptionHolder.get().getMessage(), containsString("Inference endpoint not found [model1]"));
}

public void testNonExistentDeleteModel_DoesNotThrowAnException() {
var listener = new PlainActionFuture<Boolean>();

modelRegistry.deleteModel("non-existent-model", listener);
assertTrue(listener.actionGet(TIMEOUT));
}

public void testRemoveDefaultConfigs_DoesNotThrowAnException_WhenSearchingForNonExistentInferenceEndpointIds() {
var listener = new PlainActionFuture<Boolean>();

modelRegistry.deleteModels(Set.of("non-existent-model", "abc"), listener);
assertTrue(listener.actionGet(TIMEOUT));
}

public void testRemoveDefaultConfigs_RemovesModelsFromPersistentStorage_AndInMemoryCache() {
var service = mock(InferenceService.class);

var defaultConfigs = new ArrayList<Model>();
var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
for (var id : new String[] { "model1", "model2", "model3" }) {
var modelSettings = ModelRegistryTests.randomMinimalServiceSettings();
defaultConfigs.add(createModel(id, modelSettings.taskType(), "name"));
defaultIds.add(new InferenceService.DefaultConfigId(id, modelSettings, service));
}

doAnswer(invocation -> {
@SuppressWarnings("unchecked")
var listener = (ActionListener<List<Model>>) invocation.getArguments()[0];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's another API that does the implicit typecasting magic, if you prefer:

ActionListener<List<Model>> listener = invocation.getArgument(0);

listener.onResponse(defaultConfigs);
return Void.TYPE;
}).when(service).defaultConfigs(any());

defaultIds.forEach(modelRegistry::addDefaultIds);

var getModelsListener = new PlainActionFuture<List<UnparsedModel>>();
modelRegistry.getAllModels(true, getModelsListener);
var unparsedModels = getModelsListener.actionGet(TIMEOUT);
assertThat(unparsedModels.size(), is(3));

var removeModelsListener = new PlainActionFuture<Boolean>();

modelRegistry.removeDefaultConfigs(Set.of("model1", "model2", "model3"), removeModelsListener);
assertTrue(removeModelsListener.actionGet(TIMEOUT));

var getModelsAfterDeleteListener = new PlainActionFuture<List<UnparsedModel>>();
// the models should have been removed from the in memory cache, if not they they will be persisted again by this call
modelRegistry.getAllModels(true, getModelsAfterDeleteListener);
var unparsedModelsAfterDelete = getModelsAfterDeleteListener.actionGet(TIMEOUT);
assertThat(unparsedModelsAfterDelete.size(), is(0));
}

public void testGetModelsByTaskType() throws InterruptedException {
var service = "foo";
var sparseAndTextEmbeddingModels = new ArrayList<Model>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,17 @@ private void doExecuteForked(
ClusterState state,
ActionListener<DeleteInferenceEndpointAction.Response> masterListener
) {
if (modelRegistry.containsDefaultConfigId(request.getInferenceEndpointId())) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change will prevent REST calls from deleting default inference endpoints.

masterListener.onFailure(
new ElasticsearchStatusException(
"[{}] is a reserved inference endpoint. Cannot delete a reserved inference endpoint.",
RestStatus.BAD_REQUEST,
request.getInferenceEndpointId()
)
);
return;
}

SubscribableListener.<UnparsedModel>newForked(modelConfigListener -> {
// Get the model from the registry

Expand Down
Loading