Skip to content
Merged
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.integration;

import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.MinimalServiceSettings;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.reindex.ReindexPlugin;
import org.elasticsearch.test.ESSingleNodeTestCase;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationHandler;
import org.junit.After;
import org.junit.Before;

import java.util.Collection;
import java.util.EnumSet;
import java.util.List;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.hamcrest.CoreMatchers.is;
import static org.mockito.Mockito.mock;

/**
 * Single-node integration test for the default chat completion endpoint exposed by
 * {@link ElasticInferenceService}, driven by authorization payloads served from a {@link MockWebServer}.
 * <p>
 * Every test enqueues an authorization payload, builds a service pointed at the mock server, waits for
 * authorization to finish and then checks the advertised task types and default endpoints. The
 * {@code testRemoves_*} tests run a second service with a payload that no longer lists the chat model and
 * expect the endpoint previously fetched through the {@link ModelRegistry} to be reported as not found.
 */
public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
    private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);

    /** Inference entity id the tests expect for the default chat completion endpoint. */
    private static final String CHAT_ENDPOINT_ID = ".rainbow-sprinkles-elastic";

    /** Authorization payload that lists only the chat model. */
    private static final String CHAT_ONLY_JSON = """
        {
        "models": [
        {
        "model_name": "rainbow-sprinkles",
        "task_types": ["chat"]
        }
        ]
        }
        """;

    /** Authorization payload that lists both the chat model and the sparse embedding model. */
    private static final String CHAT_AND_SPARSE_JSON = """
        {
        "models": [
        {
        "model_name": "rainbow-sprinkles",
        "task_types": ["chat"]
        },
        {
        "model_name": "elser-v2",
        "task_types": ["embed/text/sparse"]
        }
        ]
        }
        """;

    /** Authorization payload that lists only the sparse embedding model. */
    private static final String SPARSE_ONLY_JSON = """
        {
        "models": [
        {
        "model_name": "elser-v2",
        "task_types": ["embed/text/sparse"]
        }
        ]
        }
        """;

    /** Authorization payload with an empty model list. */
    private static final String NO_MODELS_JSON = """
        {
        "models": []
        }
        """;

    private ModelRegistry modelRegistry;
    private final MockWebServer webServer = new MockWebServer();
    private ThreadPool threadPool;
    private String gatewayUrl;

    /** Starts the thread pool and mock web server, then wires the registry to the node client. */
    @Before
    public void createComponents() throws Exception {
        threadPool = createThreadPool(inferenceUtilityPool());
        webServer.start();
        // The URL is only read after the server has been started.
        gatewayUrl = getUrl(webServer);
        modelRegistry = new ModelRegistry(client());
    }

    /** Terminates the thread pool and closes the mock web server. */
    @After
    public void shutdown() {
        terminate(threadPool);
        webServer.close();
    }

    @Override
    protected boolean resetNodeAfterTest() {
        return true;
    }

    @Override
    protected Collection<Class<? extends Plugin>> getPlugins() {
        return pluginList(ReindexPlugin.class);
    }

    public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCorrect() throws Exception {
        enqueueAuthorization(CHAT_ONLY_JSON);

        try (var authorized = createElasticInferenceService()) {
            authorized.waitForAuthorizationToComplete(TIMEOUT);
            assertChatEndpointIsDefault(authorized, EnumSet.of(TaskType.CHAT_COMPLETION));
        }
    }

    public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty() throws Exception {
        // Phase 1: the chat model is authorized and its default endpoint can be read from the registry.
        enqueueAuthorization(CHAT_ONLY_JSON);
        try (var authorized = createElasticInferenceService()) {
            authorized.waitForAuthorizationToComplete(TIMEOUT);
            assertChatEndpointIsDefault(authorized, EnumSet.of(TaskType.CHAT_COMPLETION));
            assertChatEndpointStoredInRegistry();
        }

        // Phase 2: a new service sees an empty authorization and the endpoint must be gone.
        enqueueAuthorization(NO_MODELS_JSON);
        try (var revoked = createElasticInferenceService()) {
            revoked.waitForAuthorizationToComplete(TIMEOUT);
            assertChatEndpointRevoked(revoked, EnumSet.noneOf(TaskType.class));
        }
    }

    public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnAuthForIt() throws Exception {
        // Phase 1: both models are authorized; only the chat endpoint is expected as a default.
        enqueueAuthorization(CHAT_AND_SPARSE_JSON);
        try (var authorized = createElasticInferenceService()) {
            authorized.waitForAuthorizationToComplete(TIMEOUT);
            assertChatEndpointIsDefault(authorized, EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING));
            assertChatEndpointStoredInRegistry();
        }

        // Phase 2: the chat model is missing from the authorization while sparse embedding remains.
        enqueueAuthorization(SPARSE_ONLY_JSON);
        try (var revoked = createElasticInferenceService()) {
            revoked.waitForAuthorizationToComplete(TIMEOUT);
            assertChatEndpointRevoked(revoked, EnumSet.of(TaskType.SPARSE_EMBEDDING));
        }
    }

    /** Queues a 200 response carrying {@code body} on the mock web server. */
    private void enqueueAuthorization(String body) {
        webServer.enqueue(new MockResponse().setResponseCode(200).setBody(body));
    }

    /**
     * Asserts that {@code service} streams chat completion, advertises the chat endpoint as its only
     * default config id, supports exactly {@code expectedTaskTypes}, and returns the chat endpoint as the
     * first default config.
     */
    private void assertChatEndpointIsDefault(ElasticInferenceService service, EnumSet<TaskType> expectedTaskTypes) {
        assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
        assertThat(
            service.defaultConfigIds(),
            is(List.of(new InferenceService.DefaultConfigId(CHAT_ENDPOINT_ID, MinimalServiceSettings.chatCompletion(), service)))
        );
        assertThat(service.supportedTaskTypes(), is(expectedTaskTypes));

        PlainActionFuture<List<Model>> defaultsFuture = new PlainActionFuture<>();
        service.defaultConfigs(defaultsFuture);
        var firstDefault = defaultsFuture.actionGet(TIMEOUT).get(0);
        assertThat(firstDefault.getConfigurations().getInferenceEntityId(), is(CHAT_ENDPOINT_ID));
    }

    /** Fetches the chat endpoint through the registry and checks its id and task type. */
    private void assertChatEndpointStoredInRegistry() {
        var lookup = new PlainActionFuture<UnparsedModel>();
        // persists the default endpoints
        modelRegistry.getModel(CHAT_ENDPOINT_ID, lookup);

        var stored = lookup.actionGet(TIMEOUT);
        assertThat(stored.inferenceEntityId(), is(CHAT_ENDPOINT_ID));
        assertThat(stored.taskType(), is(TaskType.CHAT_COMPLETION));
    }

    /**
     * Asserts that {@code service} streams nothing, has no default config ids, supports exactly
     * {@code expectedTaskTypes}, and that a registry lookup of the chat endpoint fails as not found.
     */
    private void assertChatEndpointRevoked(ElasticInferenceService service, EnumSet<TaskType> expectedTaskTypes) {
        assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
        assertTrue(service.defaultConfigIds().isEmpty());
        assertThat(service.supportedTaskTypes(), is(expectedTaskTypes));

        var lookup = new PlainActionFuture<UnparsedModel>();
        modelRegistry.getModel(CHAT_ENDPOINT_ID, lookup);

        var notFound = expectThrows(ResourceNotFoundException.class, () -> lookup.actionGet(TIMEOUT));
        assertThat(notFound.getMessage(), is("Inference endpoint not found [.rainbow-sprinkles-elastic]"));
    }

    /**
     * Builds an {@link ElasticInferenceService} whose gateway and authorization handler both point at the
     * mock web server and which shares this test's thread pool and model registry.
     */
    private ElasticInferenceService createElasticInferenceService() {
        var clientManager = HttpClientManager.create(
            Settings.EMPTY,
            threadPool,
            mockClusterServiceEmpty(),
            mock(ThrottlerManager.class)
        );
        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
        var serviceComponents = createWithEmptySettings(threadPool);
        var eisComponents = new ElasticInferenceServiceComponents(gatewayUrl);
        var authorizationHandler = new ElasticInferenceServiceAuthorizationHandler(gatewayUrl, threadPool);

        return new ElasticInferenceService(senderFactory, serviceComponents, eisComponents, modelRegistry, authorizationHandler);
    }
}
Loading