diff --git a/x-pack/plugin/build.gradle b/x-pack/plugin/build.gradle
index ea715b0d5c921..f947a845a056d 100644
--- a/x-pack/plugin/build.gradle
+++ b/x-pack/plugin/build.gradle
@@ -155,6 +155,7 @@ tasks.named("yamlRestCompatTestTransform").configure({ task ->
   task.skipTest("esql/46_downsample/Query stats on downsampled index", "Extra function required to enable the field type")
   task.skipTest("esql/46_downsample/Render stats from downsampled index", "Extra function required to enable the field type")
   task.skipTest("esql/46_downsample/Sort from multiple indices one with aggregate metric double", "Extra function required to enable the field type")
+  task.skipTest("inference/inference_crud/Test get missing model", "Error message changed")
 })
 
 tasks.named('yamlRestCompatTest').configure {
diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/SemanticMatchTestCase.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/SemanticMatchTestCase.java
index aada75f151d66..48d421d970bf5 100644
--- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/SemanticMatchTestCase.java
+++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/SemanticMatchTestCase.java
@@ -67,7 +67,7 @@ public void testWithInferenceNotConfigured() {
             """;
 
         ResponseException re = expectThrows(ResponseException.class, () -> runEsqlQuery(query));
-        assertThat(re.getMessage(), containsString("Inference endpoint not found"));
+        assertThat(re.getMessage(), containsString("Inference endpoint [inexistent] not found"));
         assertEquals(404, re.getResponse().getStatusLine().getStatusCode());
     }
 
diff --git a/x-pack/plugin/inference/build.gradle b/x-pack/plugin/inference/build.gradle
index eb9372e675831..7afd9903090b3 100644
--- a/x-pack/plugin/inference/build.gradle
+++ b/x-pack/plugin/inference/build.gradle
@@ -9,6 +9,7 @@ apply plugin: 'elasticsearch.internal-es-plugin'
 apply plugin: 'elasticsearch.internal-cluster-test'
 apply plugin: 'elasticsearch.internal-yaml-rest-test'
 apply plugin: 'elasticsearch.internal-test-artifact'
+apply plugin: 'elasticsearch.yaml-rest-compat-test'
 
 restResources {
   restApi {
@@ -407,6 +408,13 @@ tasks.named('yamlRestTest') {
   usesDefaultDistribution("Uses the inference API")
 }
 
+tasks.named("yamlRestCompatTestTransform").configure({ task ->
+  task.skipTest("inference/40_semantic_text_query/Query a field with an invalid inference ID", "Error message changed")
+  task.skipTest("inference/40_semantic_text_query/Query a field with an invalid search inference ID", "Error message changed")
+  task.skipTest("inference/70_text_similarity_rank_retriever/Text similarity reranking fails if the inference ID does not exist", "Error message changed")
+  task.skipTest("inference/70_text_similarity_rank_retriever/Text similarity reranking fails if the inference ID does not exist and result set is empty", "Error message changed")
+})
+
 artifacts {
   restXpackTests(new File(projectDir, "src/yamlRestTest/resources/rest-api-spec/test"))
 }
diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java
index 09834e6a91210..f809444a2a73d 100644
--- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java
+++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java
@@ -27,11 +27,6 @@ public class BaseMockEISAuthServerTest extends ESRestTestCase {
     protected static final MockElasticInferenceServiceAuthorizationServer mockEISServer =
        new MockElasticInferenceServiceAuthorizationServer();
 
-    static {
-        // Ensure that the mock EIS server has an authorized response prior to the cluster starting
-        mockEISServer.enqueueAuthorizeAllModelsResponse();
-    }
-
     private static ElasticsearchCluster cluster = ElasticsearchCluster.local()
         .distribution(DistributionType.DEFAULT)
         .setting("xpack.license.self_generated.type", "trial")
diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java
index ecc3bcd508bb6..4d98bb2801c94 100644
--- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java
+++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java
@@ -10,7 +10,6 @@
 package org.elasticsearch.xpack.inference;
 
 import org.elasticsearch.inference.TaskType;
-import org.junit.BeforeClass;
 
 import java.io.IOException;
 import java.util.List;
@@ -22,24 +21,11 @@
 import static org.hamcrest.Matchers.is;
 
 public class InferenceGetModelsWithElasticInferenceServiceIT extends BaseMockEISAuthServerTest {
-
-    /**
-     * This is done before the class because I've run into issues where another class that extends {@link BaseMockEISAuthServerTest}
-     * results in an authorization response not being queued up for the new Elasticsearch Node in time. When the node starts up, it
-     * retrieves authorization. If the request isn't queued up when that happens the tests will fail. From my testing locally it seems
-     * like the base class's static functionality to queue a response is only done once and not for each subclass.
-     *
-     * My understanding is that the @Before will be run after the node starts up and wouldn't be sufficient to handle
-     * this scenario. That is why this needs to be @BeforeClass.
-     */
-    @BeforeClass
-    public static void init() {
-        // Ensure the mock EIS server has an authorized response ready
-        mockEISServer.enqueueAuthorizeAllModelsResponse();
-    }
-
     public void testGetDefaultEndpoints() throws IOException {
+        mockEISServer.enqueueAuthorizeAllModelsResponse();
         var allModels = getAllModels();
+
+        mockEISServer.enqueueAuthorizeAllModelsResponse();
         var chatCompletionModels = getModels("_all", TaskType.CHAT_COMPLETION);
 
         assertThat(allModels, hasSize(7));
diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java
index f86c92c02db48..95cd94cb4b6f2 100644
--- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java
+++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java
@@ -13,7 +13,6 @@
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.inference.TaskType;
 import org.junit.Before;
-import org.junit.BeforeClass;
 
 import java.io.IOException;
 import java.util.List;
@@ -32,21 +31,6 @@ public void setUp() throws Exception {
         mockEISServer.enqueueAuthorizeAllModelsResponse();
     }
 
-    /**
-     * This is done before the class because I've run into issues where another class that extends {@link BaseMockEISAuthServerTest}
-     * results in an authorization response not being queued up for the new Elasticsearch Node in time. When the node starts up, it
-     * retrieves authorization. If the request isn't queued up when that happens the tests will fail. From my testing locally it seems
-     * like the base class's static functionality to queue a response is only done once and not for each subclass.
-     *
-     * My understanding is that the @Before will be run after the node starts up and wouldn't be sufficient to handle
-     * this scenario. That is why this needs to be @BeforeClass.
-     */
-    @BeforeClass
-    public static void init() {
-        // Ensure the mock EIS server has an authorized response ready
-        mockEISServer.enqueueAuthorizeAllModelsResponse();
-    }
-
     public void testGetServicesWithoutTaskType() throws IOException {
         assertThat(
             allProviders(),
diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesWithoutEisIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesWithoutEisIT.java
new file mode 100644
index 0000000000000..fea121f161afd
--- /dev/null
+++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesWithoutEisIT.java
@@ -0,0 +1,121 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ * + * this file has been contributed to by a Generative AI + */ + +package org.elasticsearch.xpack.inference; + +import org.elasticsearch.client.Request; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.cluster.ElasticsearchCluster; +import org.elasticsearch.test.cluster.local.distribution.DistributionType; +import org.elasticsearch.test.rest.ESRestTestCase; +import org.junit.ClassRule; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.assertStatusOkOrCreated; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasItem; +import static org.hamcrest.Matchers.not; + +public class InferenceGetServicesWithoutEisIT extends ESRestTestCase { + + @ClassRule + public static ElasticsearchCluster cluster = ElasticsearchCluster.local() + .distribution(DistributionType.DEFAULT) + .setting("xpack.license.self_generated.type", "trial") + .setting("xpack.security.enabled", "true") + // This plugin is located in the inference/qa/test-service-plugin package, look for TestInferenceServicePlugin + .plugin("inference-service-test") + .user("x_pack_rest_user", "x-pack-test-password") + .build(); + + @Override + protected String getTestRestCluster() { + return cluster.getHttpAddresses(); + } + + @Override + protected Settings restClientSettings() { + String token = basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray())); + return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build(); + } + + public void testGetServicesWithoutTaskType() throws IOException { + assertThat(allProviders(), not(hasItem("elastic"))); + } + + private List allProviders() throws IOException { + return providers(getAllServices()); + } + + @SuppressWarnings("unchecked") + private List providers(List services) { + return services.stream().map(service -> { + var serviceConfig = (Map) service; + return (String) serviceConfig.get("service"); + }).toList(); + } + + public void testGetServicesWithTextEmbeddingTaskType() throws IOException { + var providers = providersFor(TaskType.TEXT_EMBEDDING); + assertThat(providers.size(), not(equalTo(0))); + assertThat(providers, not(hasItem("elastic"))); + } + + private List providersFor(TaskType taskType) throws IOException { + return providers(getServices(taskType)); + } + + public void testGetServicesWithRerankTaskType() throws IOException { + var providers = providersFor(TaskType.RERANK); + assertThat(providers.size(), not(equalTo(0))); + assertThat(providersFor(TaskType.RERANK), not(hasItem("elastic"))); + } + + public void testGetServicesWithCompletionTaskType() throws IOException { + var providers = providersFor(TaskType.COMPLETION); + assertThat(providers.size(), not(equalTo(0))); + assertThat(providersFor(TaskType.COMPLETION), not(hasItem("elastic"))); + } + + public void testGetServicesWithChatCompletionTaskType() throws IOException { + var providers = providersFor(TaskType.CHAT_COMPLETION); + assertThat(providers.size(), not(equalTo(0))); + assertThat(providersFor(TaskType.CHAT_COMPLETION), not(hasItem("elastic"))); + } + + public void testGetServicesWithSparseEmbeddingTaskType() throws IOException { + var providers = providersFor(TaskType.SPARSE_EMBEDDING); + 
assertThat(providers.size(), not(equalTo(0))); + assertThat(providersFor(TaskType.SPARSE_EMBEDDING), not(hasItem("elastic"))); + } + + private List getAllServices() throws IOException { + var endpoint = Strings.format("_inference/_services"); + return getInternalAsList(endpoint); + } + + private List getServices(TaskType taskType) throws IOException { + var endpoint = Strings.format("_inference/_services/%s", taskType); + return getInternalAsList(endpoint); + } + + private List getInternalAsList(String endpoint) throws IOException { + var request = new Request("GET", endpoint); + var response = client().performRequest(request); + assertStatusOkOrCreated(response); + return entityAsList(response); + } +} diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceIndicesIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceIndicesIT.java index e59f0617851c3..bdf4f9318ce60 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceIndicesIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceIndicesIT.java @@ -167,7 +167,7 @@ public void testRetrievingInferenceEndpoint_ThrowsException_WhenIndexNodeIsNotAv var proxyResponse = sendInferenceProxyRequest(inferenceId); var exception = expectThrows(ElasticsearchException.class, () -> proxyResponse.actionGet(TEST_REQUEST_TIMEOUT)); - assertThat(exception.toString(), containsString("Failed to load inference endpoint with secrets [test-index-id-2]")); + assertThat(exception.toString(), containsString("Failed to load inference endpoint [test-index-id-2]")); var causeException = exception.getCause(); assertThat(causeException, instanceOf(SearchPhaseExecutionException.class)); @@ -196,7 +196,7 @@ public void testRetrievingInferenceEndpoint_ThrowsException_WhenSecretsIndexNode var proxyResponse = sendInferenceProxyRequest(inferenceId); var exception = expectThrows(ElasticsearchException.class, () -> proxyResponse.actionGet(TEST_REQUEST_TIMEOUT)); - assertThat(exception.toString(), containsString("Failed to load inference endpoint with secrets [test-secrets-index-id]")); + assertThat(exception.toString(), containsString("Failed to load inference endpoint [test-secrets-index-id]")); var causeException = exception.getCause(); diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java deleted file mode 100644 index 72109e43bb6ac..0000000000000 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java +++ /dev/null @@ -1,359 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.integration; - -import org.elasticsearch.ResourceNotFoundException; -import org.elasticsearch.action.support.PlainActionFuture; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.core.TimeValue; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.inference.InferenceService; -import org.elasticsearch.inference.MinimalServiceSettings; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.inference.UnparsedModel; -import org.elasticsearch.plugins.Plugin; -import org.elasticsearch.reindex.ReindexPlugin; -import org.elasticsearch.test.ESSingleNodeTestCase; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.test.http.MockResponse; -import org.elasticsearch.test.http.MockWebServer; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; -import org.elasticsearch.xpack.inference.external.http.HttpClientManager; -import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; -import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettingsTests; -import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; -import org.junit.After; -import org.junit.Before; - -import java.util.Collection; -import java.util.EnumSet; -import java.util.List; -import java.util.concurrent.TimeUnit; - -import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; -import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; -import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; -import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; -import static org.hamcrest.CoreMatchers.is; -import static org.hamcrest.Matchers.containsInAnyOrder; -import static org.mockito.Mockito.mock; - -@ESTestCase.WithoutEntitlements // due to dependency issue ES-12435 -public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase { - private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); - - private ModelRegistry modelRegistry; - private final MockWebServer webServer = new MockWebServer(); - private ThreadPool threadPool; - private String gatewayUrl; - - @Before - public void createComponents() throws Exception { - threadPool = createThreadPool(inferenceUtilityExecutors()); - webServer.start(); - gatewayUrl = getUrl(webServer); - modelRegistry = node().injector().getInstance(ModelRegistry.class); - } - - @After - public void shutdown() { - terminate(threadPool); - webServer.close(); - } - - @Override - protected boolean resetNodeAfterTest() { - return true; - } - - @Override - protected Collection> getPlugins() { - return pluginList(ReindexPlugin.class, LocalStateInferencePlugin.class); - } - - public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCorrect() throws Exception { - String responseJson = """ - { - "models": [ - { - "model_name": "rainbow-sprinkles", - "task_types": ["chat"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - try 
(var service = createElasticInferenceService()) { - ensureAuthorizationCallFinished(service); - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertThat( - service.defaultConfigIds(), - is( - List.of( - new InferenceService.DefaultConfigId( - ".rainbow-sprinkles-elastic", - MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), - service - ) - ) - ) - ); - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - - PlainActionFuture> listener = new PlainActionFuture<>(); - service.defaultConfigs(listener); - assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); - } - } - - public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty() throws Exception { - { - String responseJson = """ - { - "models": [ - { - "model_name": "rainbow-sprinkles", - "task_types": ["chat"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - try (var service = createElasticInferenceService()) { - ensureAuthorizationCallFinished(service); - - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertThat( - service.defaultConfigIds(), - is( - List.of( - new InferenceService.DefaultConfigId( - ".rainbow-sprinkles-elastic", - MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), - service - ) - ) - ) - ); - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - - PlainActionFuture> listener = new PlainActionFuture<>(); - service.defaultConfigs(listener); - assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); - - var getModelListener = new PlainActionFuture(); - // persists the default endpoints - modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener); - - var inferenceEntity = getModelListener.actionGet(TIMEOUT); - assertThat(inferenceEntity.inferenceEntityId(), is(".rainbow-sprinkles-elastic")); - assertThat(inferenceEntity.taskType(), is(TaskType.CHAT_COMPLETION)); - } - } - { - String noAuthorizationResponseJson = """ - { - "models": [] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson)); - - try (var service = createElasticInferenceService()) { - ensureAuthorizationCallFinished(service); - - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertTrue(service.defaultConfigIds().isEmpty()); - assertThat(service.supportedTaskTypes(), is(EnumSet.noneOf(TaskType.class))); - - var getModelListener = new PlainActionFuture(); - modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener); - - var exception = expectThrows(ResourceNotFoundException.class, () -> getModelListener.actionGet(TIMEOUT)); - assertThat(exception.getMessage(), is("Inference endpoint not found [.rainbow-sprinkles-elastic]")); - } - } - } - - public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnAuthForIt() throws Exception { - { - String responseJson = """ - { - "models": [ - { - "model_name": "elser_model_2", - "task_types": ["embed/text/sparse"] - }, - { - "model_name": "rainbow-sprinkles", - "task_types": ["chat"] - }, - { - "model_name": "multilingual-embed-v1", - "task_types": ["embed/text/dense"] - }, - { - "model_name": "rerank-v1", - "task_types": ["rerank/text/text-similarity"] - } - ] - } - """; - - 
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - try (var service = createElasticInferenceService()) { - ensureAuthorizationCallFinished(service); - - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertThat( - service.defaultConfigIds(), - containsInAnyOrder( - new InferenceService.DefaultConfigId( - ".elser-2-elastic", - MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME), - service - ), - new InferenceService.DefaultConfigId( - ".rainbow-sprinkles-elastic", - MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), - service - ), - new InferenceService.DefaultConfigId( - ".multilingual-embed-v1-elastic", - MinimalServiceSettings.textEmbedding( - ElasticInferenceService.NAME, - ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, - ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(), - DenseVectorFieldMapper.ElementType.FLOAT - ), - service - ), - new InferenceService.DefaultConfigId( - ".rerank-v1-elastic", - MinimalServiceSettings.rerank(ElasticInferenceService.NAME), - service - ) - ) - ); - assertThat( - service.supportedTaskTypes(), - is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK, TaskType.TEXT_EMBEDDING)) - ); - - PlainActionFuture> listener = new PlainActionFuture<>(); - service.defaultConfigs(listener); - assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".elser-2-elastic")); - assertThat( - listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(), - is(".multilingual-embed-v1-elastic") - ); - assertThat(listener.actionGet(TIMEOUT).get(2).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); - assertThat(listener.actionGet(TIMEOUT).get(3).getConfigurations().getInferenceEntityId(), is(".rerank-v1-elastic")); - - var getModelListener = new PlainActionFuture(); - // persists the default endpoints - modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener); - - var inferenceEntity = getModelListener.actionGet(TIMEOUT); - assertThat(inferenceEntity.inferenceEntityId(), is(".rainbow-sprinkles-elastic")); - assertThat(inferenceEntity.taskType(), is(TaskType.CHAT_COMPLETION)); - } - } - { - String noAuthorizationResponseJson = """ - { - "models": [ - { - "model_name": "elser_model_2", - "task_types": ["embed/text/sparse"] - }, - { - "model_name": "rerank-v1", - "task_types": ["rerank/text/text-similarity"] - }, - { - "model_name": "multilingual-embed-v1", - "task_types": ["embed/text/dense"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson)); - - try (var service = createElasticInferenceService()) { - ensureAuthorizationCallFinished(service); - - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertThat( - service.defaultConfigIds(), - containsInAnyOrder( - new InferenceService.DefaultConfigId( - ".elser-2-elastic", - MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME), - service - ), - new InferenceService.DefaultConfigId( - ".multilingual-embed-v1-elastic", - MinimalServiceSettings.textEmbedding( - ElasticInferenceService.NAME, - ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, - ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(), - DenseVectorFieldMapper.ElementType.FLOAT - ), - service - ), - new InferenceService.DefaultConfigId( - ".rerank-v1-elastic", - 
MinimalServiceSettings.rerank(ElasticInferenceService.NAME), - service - ) - ) - ); - assertThat( - service.supportedTaskTypes(), - is(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK)) - ); - - var getModelListener = new PlainActionFuture(); - modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener); - var exception = expectThrows(ResourceNotFoundException.class, () -> getModelListener.actionGet(TIMEOUT)); - assertThat(exception.getMessage(), is("Inference endpoint not found [.rainbow-sprinkles-elastic]")); - } - } - } - - private void ensureAuthorizationCallFinished(ElasticInferenceService service) { - service.onNodeStarted(); - service.waitForFirstAuthorizationToComplete(TIMEOUT); - } - - private ElasticInferenceService createElasticInferenceService() { - var httpManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, httpManager); - - return new ElasticInferenceService( - senderFactory, - createWithEmptySettings(threadPool), - ElasticInferenceServiceSettingsTests.create(gatewayUrl), - modelRegistry, - new ElasticInferenceServiceAuthorizationRequestHandler(gatewayUrl, threadPool), - mockClusterServiceEmpty() - ); - } -} diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisBaseIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisBaseIT.java new file mode 100644 index 0000000000000..6d320a00d695c --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisBaseIT.java @@ -0,0 +1,90 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.integration; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.reindex.ReindexPlugin; +import org.elasticsearch.test.ESSingleNodeTestCase; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.registry.ModelRegistryTests; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.integration.ModelRegistryIT.createModel; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings.ELASTIC_INFERENCE_SERVICE_URL; + +public abstract class ModelRegistryEisBaseIT extends ESSingleNodeTestCase { + protected static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + protected static final MockWebServer webServer = new MockWebServer(); + + protected ModelRegistry modelRegistry; + private String eisUrl; + + public ModelRegistryEisBaseIT() {} + + public ModelRegistryEisBaseIT(String eisUrl) { + this.eisUrl = eisUrl; + } + + @BeforeClass + public static void init() throws Exception { + // This must be called prior to retrieving the hostname and port from the mock server. + webServer.start(); + } + + @AfterClass + public static void shutdown() { + webServer.close(); + } + + @Before + public void createComponents() { + modelRegistry = node().injector().getInstance(ModelRegistry.class); + modelRegistry.clearDefaultIds(); + } + + @Override + protected Collection> getPlugins() { + return pluginList(ReindexPlugin.class, LocalStateInferencePlugin.class); + } + + @Override + protected Settings nodeSettings() { + return Settings.builder().put(super.nodeSettings()).put(ELASTIC_INFERENCE_SERVICE_URL.getKey(), getEisUrl()).build(); + } + + private String getEisUrl() { + return eisUrl != null ? 
eisUrl : getUrl(webServer); + } + + protected void initializeModels() { + var service = "foo"; + var sparseAndTextEmbeddingModels = new ArrayList(); + sparseAndTextEmbeddingModels.add(createModel("sparse-1", TaskType.SPARSE_EMBEDDING, service)); + sparseAndTextEmbeddingModels.add(createModel("sparse-2", TaskType.SPARSE_EMBEDDING, service)); + sparseAndTextEmbeddingModels.add(createModel("sparse-3", TaskType.SPARSE_EMBEDDING, service)); + sparseAndTextEmbeddingModels.add(createModel("embedding-1", TaskType.TEXT_EMBEDDING, service)); + sparseAndTextEmbeddingModels.add(createModel("embedding-2", TaskType.TEXT_EMBEDDING, service)); + + for (var model : sparseAndTextEmbeddingModels) { + ModelRegistryTests.assertStoreModel(modelRegistry, model); + } + } +} diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisGetModelIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisGetModelIT.java new file mode 100644 index 0000000000000..33786f29e8c7e --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisGetModelIT.java @@ -0,0 +1,510 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.integration; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings; + +import java.util.Arrays; +import java.util.Map; +import java.util.function.BiConsumer; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.containsString; + +/** + * Parameterized tests for {@link ModelRegistry#getModel} and {@link ModelRegistry#getModelWithSecrets}. 
+ */ +public class ModelRegistryEisGetModelIT extends ModelRegistryEisBaseIT { + private final TestCase testCase; + + public ModelRegistryEisGetModelIT(TestCase testCase) { + super(); + this.testCase = testCase; + } + + public record TestCase( + String description, + BiConsumer> registryCall, + String responseJson, + @Nullable UnparsedModel expectedResult, + @Nullable String failureMessage, + @Nullable RestStatus failureStatus + ) {} + + private static class TestCaseBuilder { + private final String description; + private final BiConsumer> registryCall; + private final String responseJson; + private UnparsedModel expectedResult; + private String failureMessage; + private RestStatus failureStatus; + + TestCaseBuilder(String description, BiConsumer> registryCall, String responseJson) { + this.description = description; + this.registryCall = registryCall; + this.responseJson = responseJson; + } + + public TestCaseBuilder withSuccessfulResult(UnparsedModel expectedResult) { + this.expectedResult = expectedResult; + return this; + } + + public TestCaseBuilder withFailure(String failure, RestStatus status) { + this.failureMessage = failure; + this.failureStatus = status; + return this; + } + + public TestCase build() { + return new TestCase(description, registryCall, responseJson, expectedResult, failureMessage, failureStatus); + } + } + + @ParametersFactory + public static Iterable parameters() { + return Arrays.asList( + new TestCase[][] { + // getModel calls + { + new TestCaseBuilder( + "getModel retrieves eis chat completion preconfigured endpoint", + (modelRegistry, listener) -> modelRegistry.getModel( + ElasticInferenceServiceMinimalSettings.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, + listener + ), + """ + { + "models": [ + { + "model_name": "rainbow-sprinkles", + "task_types": ["chat"] + } + ] + } + """ + ).withSuccessfulResult( + new UnparsedModel( + ElasticInferenceServiceMinimalSettings.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, + TaskType.CHAT_COMPLETION, + ElasticInferenceService.NAME, + Map.of( + ModelConfigurations.SERVICE_SETTINGS, + Map.of(ServiceFields.MODEL_ID, ElasticInferenceServiceMinimalSettings.DEFAULT_CHAT_COMPLETION_MODEL_ID_V1) + ), + Map.of() + ) + ).build() }, + { + new TestCaseBuilder( + "getModel throws an exception when retrieving eis " + + "chat completion preconfigured endpoint and it isn't authorized", + (modelRegistry, listener) -> modelRegistry.getModel( + ElasticInferenceServiceMinimalSettings.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, + listener + ), + """ + { + "models": [ + ] + } + """ + ).withFailure( + "Unable to retrieve the preconfigured inference endpoint " + + "[.rainbow-sprinkles-elastic] from the Elastic Inference Service", + RestStatus.BAD_REQUEST + ).build() }, + { + new TestCaseBuilder( + "getModel retrieves eis elser preconfigured endpoint", + (modelRegistry, listener) -> modelRegistry.getModel( + ElasticInferenceServiceMinimalSettings.DEFAULT_ELSER_ENDPOINT_ID_V2, + listener + ), + """ + { + "models": [ + { + "model_name": "elser_model_2", + "task_types": ["embed/text/sparse"] + } + ] + } + """ + ).withSuccessfulResult( + new UnparsedModel( + ElasticInferenceServiceMinimalSettings.DEFAULT_ELSER_ENDPOINT_ID_V2, + TaskType.SPARSE_EMBEDDING, + ElasticInferenceService.NAME, + Map.of( + ModelConfigurations.SERVICE_SETTINGS, + Map.of(ServiceFields.MODEL_ID, ElasticInferenceServiceMinimalSettings.DEFAULT_ELSER_2_MODEL_ID) + ), + Map.of() + ) + ).build() }, + { + new TestCaseBuilder( + "getModel throws exception when retrieving eis elser preconfigured endpoint 
and not authorized", + (modelRegistry, listener) -> modelRegistry.getModel( + ElasticInferenceServiceMinimalSettings.DEFAULT_ELSER_ENDPOINT_ID_V2, + listener + ), + """ + { + "models": [ + ] + } + """ + ).withFailure( + "Unable to retrieve the preconfigured inference endpoint [.elser-2-elastic] from the Elastic Inference Service", + RestStatus.BAD_REQUEST + ).build() }, + { + new TestCaseBuilder( + "getModel retrieves eis multilingual embed preconfigured endpoint", + (modelRegistry, listener) -> modelRegistry.getModel( + ElasticInferenceServiceMinimalSettings.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID, + listener + ), + """ + { + "models": [ + { + "model_name": "multilingual-embed-v1", + "task_types": ["embed/text/dense"] + } + ] + } + """ + ).withSuccessfulResult( + new UnparsedModel( + ElasticInferenceServiceMinimalSettings.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID, + TaskType.TEXT_EMBEDDING, + ElasticInferenceService.NAME, + Map.of( + ModelConfigurations.SERVICE_SETTINGS, + Map.of( + ServiceFields.MODEL_ID, + ElasticInferenceServiceMinimalSettings.DEFAULT_MULTILINGUAL_EMBED_MODEL_ID, + ServiceFields.SIMILARITY, + SimilarityMeasure.COSINE.toString(), + ServiceFields.DIMENSIONS, + ElasticInferenceServiceMinimalSettings.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, + ServiceFields.ELEMENT_TYPE, + DenseVectorFieldMapper.ElementType.FLOAT.toString() + ) + ), + Map.of() + ) + ).build() }, + { + new TestCaseBuilder( + "getModel throws exception when retrieving eis multilingual embed preconfigured endpoint and not authorized", + (modelRegistry, listener) -> modelRegistry.getModel( + ElasticInferenceServiceMinimalSettings.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID, + listener + ), + """ + { + "models": [ + ] + } + """ + ).withFailure( + "Unable to retrieve the preconfigured inference endpoint " + + "[.multilingual-embed-v1-elastic] from the Elastic Inference Service", + RestStatus.BAD_REQUEST + ).build() }, + { + new TestCaseBuilder( + "getModel retrieves eis rerank preconfigured endpoint", + (modelRegistry, listener) -> modelRegistry.getModel( + ElasticInferenceServiceMinimalSettings.DEFAULT_RERANK_ENDPOINT_ID_V1, + listener + ), + """ + { + "models": [ + { + "model_name": "rerank-v1", + "task_types": ["rerank/text/text-similarity"] + } + ] + } + """ + ).withSuccessfulResult( + new UnparsedModel( + ElasticInferenceServiceMinimalSettings.DEFAULT_RERANK_ENDPOINT_ID_V1, + TaskType.RERANK, + ElasticInferenceService.NAME, + Map.of( + ModelConfigurations.SERVICE_SETTINGS, + Map.of(ServiceFields.MODEL_ID, ElasticInferenceServiceMinimalSettings.DEFAULT_RERANK_MODEL_ID_V1) + ), + Map.of() + ) + ).build() }, + { + new TestCaseBuilder( + "getModel throws exception when retrieving eis rerank preconfigured endpoint and not authorized", + (modelRegistry, listener) -> modelRegistry.getModel( + ElasticInferenceServiceMinimalSettings.DEFAULT_RERANK_ENDPOINT_ID_V1, + listener + ), + """ + { + "models": [ + ] + } + """ + ).withFailure( + "Unable to retrieve the preconfigured inference endpoint [.rerank-v1-elastic] from the Elastic Inference Service", + RestStatus.BAD_REQUEST + ).build() }, + // getModelWithSecrets calls + { + new TestCaseBuilder( + "getModelWithSecrets retrieves eis chat completion preconfigured endpoint", + (modelRegistry, listener) -> modelRegistry.getModelWithSecrets( + ElasticInferenceServiceMinimalSettings.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, + listener + ), + """ + { + "models": [ + { + "model_name": "rainbow-sprinkles", + "task_types": ["chat"] + } + ] + } + """ + ).withSuccessfulResult( + new UnparsedModel( 
+ ElasticInferenceServiceMinimalSettings.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, + TaskType.CHAT_COMPLETION, + ElasticInferenceService.NAME, + Map.of( + ModelConfigurations.SERVICE_SETTINGS, + Map.of(ServiceFields.MODEL_ID, ElasticInferenceServiceMinimalSettings.DEFAULT_CHAT_COMPLETION_MODEL_ID_V1) + ), + Map.of() + ) + ).build() }, + { + new TestCaseBuilder( + "getModelWithSecrets throws an exception when retrieving eis " + + "chat completion preconfigured endpoint and it isn't authorized", + (modelRegistry, listener) -> modelRegistry.getModelWithSecrets( + ElasticInferenceServiceMinimalSettings.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, + listener + ), + """ + { + "models": [ + ] + } + """ + ).withFailure( + "Unable to retrieve the preconfigured inference endpoint " + + "[.rainbow-sprinkles-elastic] from the Elastic Inference Service", + RestStatus.BAD_REQUEST + ).build() }, + { + new TestCaseBuilder( + "getModelWithSecrets retrieves eis elser preconfigured endpoint", + (modelRegistry, listener) -> modelRegistry.getModelWithSecrets( + ElasticInferenceServiceMinimalSettings.DEFAULT_ELSER_ENDPOINT_ID_V2, + listener + ), + """ + { + "models": [ + { + "model_name": "elser_model_2", + "task_types": ["embed/text/sparse"] + } + ] + } + """ + ).withSuccessfulResult( + new UnparsedModel( + ElasticInferenceServiceMinimalSettings.DEFAULT_ELSER_ENDPOINT_ID_V2, + TaskType.SPARSE_EMBEDDING, + ElasticInferenceService.NAME, + Map.of( + ModelConfigurations.SERVICE_SETTINGS, + Map.of(ServiceFields.MODEL_ID, ElasticInferenceServiceMinimalSettings.DEFAULT_ELSER_2_MODEL_ID) + ), + Map.of() + ) + ).build() }, + { + new TestCaseBuilder( + "getModelWithSecrets throws exception when retrieving eis elser preconfigured endpoint and not authorized", + (modelRegistry, listener) -> modelRegistry.getModelWithSecrets( + ElasticInferenceServiceMinimalSettings.DEFAULT_ELSER_ENDPOINT_ID_V2, + listener + ), + """ + { + "models": [ + ] + } + """ + ).withFailure( + "Unable to retrieve the preconfigured inference endpoint [.elser-2-elastic] from the Elastic Inference Service", + RestStatus.BAD_REQUEST + ).build() }, + { + new TestCaseBuilder( + "getModelWithSecrets retrieves eis multilingual embed preconfigured endpoint", + (modelRegistry, listener) -> modelRegistry.getModelWithSecrets( + ElasticInferenceServiceMinimalSettings.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID, + listener + ), + """ + { + "models": [ + { + "model_name": "multilingual-embed-v1", + "task_types": ["embed/text/dense"] + } + ] + } + """ + ).withSuccessfulResult( + new UnparsedModel( + ElasticInferenceServiceMinimalSettings.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID, + TaskType.TEXT_EMBEDDING, + ElasticInferenceService.NAME, + Map.of( + ModelConfigurations.SERVICE_SETTINGS, + Map.of( + ServiceFields.MODEL_ID, + ElasticInferenceServiceMinimalSettings.DEFAULT_MULTILINGUAL_EMBED_MODEL_ID, + ServiceFields.SIMILARITY, + SimilarityMeasure.COSINE.toString(), + ServiceFields.DIMENSIONS, + ElasticInferenceServiceMinimalSettings.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, + ServiceFields.ELEMENT_TYPE, + DenseVectorFieldMapper.ElementType.FLOAT.toString() + ) + ), + Map.of() + ) + ).build() }, + { + new TestCaseBuilder( + "getModelWithSecrets throws exception when retrieving eis " + + "multilingual embed preconfigured endpoint and not authorized", + (modelRegistry, listener) -> modelRegistry.getModelWithSecrets( + ElasticInferenceServiceMinimalSettings.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID, + listener + ), + """ + { + "models": [ + ] + } + """ + ).withFailure( + "Unable to retrieve the 
preconfigured inference endpoint " + + "[.multilingual-embed-v1-elastic] from the Elastic Inference Service", + RestStatus.BAD_REQUEST + ).build() }, + { + new TestCaseBuilder( + "getModelWithSecrets retrieves eis rerank preconfigured endpoint", + (modelRegistry, listener) -> modelRegistry.getModelWithSecrets( + ElasticInferenceServiceMinimalSettings.DEFAULT_RERANK_ENDPOINT_ID_V1, + listener + ), + """ + { + "models": [ + { + "model_name": "rerank-v1", + "task_types": ["rerank/text/text-similarity"] + } + ] + } + """ + ).withSuccessfulResult( + new UnparsedModel( + ElasticInferenceServiceMinimalSettings.DEFAULT_RERANK_ENDPOINT_ID_V1, + TaskType.RERANK, + ElasticInferenceService.NAME, + Map.of( + ModelConfigurations.SERVICE_SETTINGS, + Map.of(ServiceFields.MODEL_ID, ElasticInferenceServiceMinimalSettings.DEFAULT_RERANK_MODEL_ID_V1) + ), + Map.of() + ) + ).build() }, + { + new TestCaseBuilder( + "getModelWithSecrets throws exception when retrieving eis rerank preconfigured endpoint and not authorized", + (modelRegistry, listener) -> modelRegistry.getModelWithSecrets( + ElasticInferenceServiceMinimalSettings.DEFAULT_RERANK_ENDPOINT_ID_V1, + listener + ), + """ + { + "models": [ + ] + } + """ + ).withFailure( + "Unable to retrieve the preconfigured inference endpoint [.rerank-v1-elastic] from the Elastic Inference Service", + RestStatus.BAD_REQUEST + ).build() } } + ); + } + + public void test() { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testCase.responseJson)); + + PlainActionFuture listener = new PlainActionFuture<>(); + testCase.registryCall.accept(modelRegistry, listener); + + if (testCase.expectedResult != null) { + assertSuccessfulTestCase(listener); + } else { + assertFailureTestCase(listener); + } + } + + private void assertSuccessfulTestCase(PlainActionFuture listener) { + var model = listener.actionGet(TIMEOUT); + assertThat(model, is(testCase.expectedResult)); + } + + private void assertFailureTestCase(PlainActionFuture listener) { + var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(exception.getMessage(), containsString(testCase.failureMessage)); + assertThat(exception.status(), is(testCase.failureStatus)); + } +} diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisIT.java new file mode 100644 index 0000000000000..2450d12064832 --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisIT.java @@ -0,0 +1,171 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.integration; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings; +import org.junit.Before; + +import java.util.List; +import java.util.stream.Stream; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class ModelRegistryEisIT extends ModelRegistryEisBaseIT { + + @Before + public void setupTest() { + initializeModels(); + } + + private static final String eisAuthorizedResponse = """ + { + "models": [ + { + "model_name": "rainbow-sprinkles", + "task_types": ["chat"] + }, + { + "model_name": "elser_model_2", + "task_types": ["embed/text/sparse"] + }, + { + "model_name": "multilingual-embed-v1", + "task_types": ["embed/text/dense"] + }, + { + "model_name": "rerank-v1", + "task_types": ["rerank/text/text-similarity"] + } + ] + } + """; + + private static final String eisUnauthorizedResponse = """ + { + "models": [ + ] + } + """; + + public void testGetModelsByTaskType() { + { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(eisAuthorizedResponse)); + PlainActionFuture> listener = new PlainActionFuture<>(); + modelRegistry.getModelsByTaskType(TaskType.COMPLETION, listener); + + assertThat(listener.actionGet(TIMEOUT), is(List.of())); + } + { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(eisAuthorizedResponse)); + PlainActionFuture> listener = new PlainActionFuture<>(); + modelRegistry.getModelsByTaskType(TaskType.SPARSE_EMBEDDING, listener); + + var results = listener.actionGet(TIMEOUT); + var expected = Stream.of("sparse-1", "sparse-2", "sparse-3", ".elser-2-elastic").toArray(String[]::new); + assertThat(results.size(), is(expected.length)); + assertThat(results.stream().map(UnparsedModel::inferenceEntityId).sorted().toList(), containsInAnyOrder(expected)); + } + { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(eisAuthorizedResponse)); + PlainActionFuture> listener = new PlainActionFuture<>(); + modelRegistry.getModelsByTaskType(TaskType.TEXT_EMBEDDING, listener); + + var results = listener.actionGet(TIMEOUT); + var expected = Stream.of("embedding-1", "embedding-2", ".multilingual-embed-v1-elastic").toArray(String[]::new); + assertThat(results.size(), is(expected.length)); + assertThat(results.stream().map(UnparsedModel::inferenceEntityId).sorted().toList(), containsInAnyOrder(expected)); + } + { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(eisAuthorizedResponse)); + PlainActionFuture> listener = new PlainActionFuture<>(); + modelRegistry.getModelsByTaskType(TaskType.CHAT_COMPLETION, listener); + + var results = listener.actionGet(TIMEOUT); + assertThat(results.size(), is(1)); + assertThat( + results.stream().map(UnparsedModel::inferenceEntityId).sorted().toList(), + containsInAnyOrder(".rainbow-sprinkles-elastic") + ); + } + { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(eisAuthorizedResponse)); + PlainActionFuture> listener = new PlainActionFuture<>(); + modelRegistry.getModelsByTaskType(TaskType.RERANK, listener); + + var results = listener.actionGet(TIMEOUT); + assertThat(results.size(), is(1)); + 
assertThat(results.stream().map(UnparsedModel::inferenceEntityId).sorted().toList(), containsInAnyOrder(".rerank-v1-elastic")); + } + } + + public void testGetAllModels() { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(eisAuthorizedResponse)); + PlainActionFuture> listener = new PlainActionFuture<>(); + modelRegistry.getAllModels(false, listener); + + var results = listener.actionGet(TIMEOUT); + var expected = Stream.of( + "sparse-1", + "sparse-2", + "sparse-3", + "embedding-1", + "embedding-2", + ".elser-2-elastic", + ".multilingual-embed-v1-elastic", + ".rainbow-sprinkles-elastic", + ".rerank-v1-elastic" + ).toArray(String[]::new); + assertThat(results.size(), is(expected.length)); + assertThat(results.stream().map(UnparsedModel::inferenceEntityId).sorted().toList(), containsInAnyOrder(expected)); + } + + public void testGetAllModelsNoEisResults() { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(eisUnauthorizedResponse)); + PlainActionFuture> listener = new PlainActionFuture<>(); + modelRegistry.getAllModels(false, listener); + + var results = listener.actionGet(TIMEOUT); + var expected = Stream.of("sparse-1", "sparse-2", "sparse-3", "embedding-1", "embedding-2").toArray(String[]::new); + assertThat(results.size(), is(expected.length)); + assertThat(results.stream().map(UnparsedModel::inferenceEntityId).sorted().toList(), containsInAnyOrder(expected)); + } + + public void testGetModel_WhenNotAuthorizedForEis() { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(eisUnauthorizedResponse)); + PlainActionFuture listener = new PlainActionFuture<>(); + modelRegistry.getModel(ElasticInferenceServiceMinimalSettings.DEFAULT_RERANK_ENDPOINT_ID_V1, listener); + + var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(exception.getMessage(), containsString("Unable to retrieve the preconfigured inference endpoint")); + assertThat( + exception.getCause().getMessage(), + containsString( + "No Elastic Inference Service preconfigured endpoint found for inference ID [.rerank-v1-elastic]. " + + "Either it does not exist, or you are not authorized to access it." + ) + ); + } + + public void testGetAllModelsEisReturnsFailureStatusCode() { + webServer.enqueue(new MockResponse().setResponseCode(500).setBody("{}")); + PlainActionFuture> listener = new PlainActionFuture<>(); + modelRegistry.getAllModels(false, listener); + + var results = listener.actionGet(TIMEOUT); + var expected = Stream.of("sparse-1", "sparse-2", "sparse-3", "embedding-1", "embedding-2").toArray(String[]::new); + assertThat(results.size(), is(expected.length)); + assertThat(results.stream().map(UnparsedModel::inferenceEntityId).sorted().toList(), containsInAnyOrder(expected)); + } +} diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisInvalidUrlIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisInvalidUrlIT.java new file mode 100644 index 0000000000000..0db54784851f2 --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryEisInvalidUrlIT.java @@ -0,0 +1,35 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.integration; + +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.inference.UnparsedModel; + +import java.util.List; +import java.util.stream.Stream; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.is; + +public class ModelRegistryEisInvalidUrlIT extends ModelRegistryEisBaseIT { + public ModelRegistryEisInvalidUrlIT() { + super(""); + } + + public void testGetAllModelsDoesNotReturnEisModels_WhenEisUrlIsEmpty() { + initializeModels(); + + PlainActionFuture> listener = new PlainActionFuture<>(); + modelRegistry.getAllModels(false, listener); + + var results = listener.actionGet(TIMEOUT); + var expected = Stream.of("sparse-1", "sparse-2", "sparse-3", "embedding-1", "embedding-2").toArray(String[]::new); + assertThat(results.size(), is(expected.length)); + assertThat(results.stream().map(UnparsedModel::inferenceEntityId).sorted().toList(), containsInAnyOrder(expected)); + } +} diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index 92eea9599ec5d..9aac2c44fc2e5 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -93,13 +93,13 @@ protected Collection> getPlugins() { return pluginList(ReindexPlugin.class, LocalStateInferencePlugin.class); } - public void testStoreModel() throws Exception { + public void testStoreModel() { String inferenceEntityId = "test-store-model"; Model model = buildElserModelConfig(inferenceEntityId, TaskType.SPARSE_EMBEDDING); ModelRegistryTests.assertStoreModel(modelRegistry, model); } - public void testStoreModelWithUnknownFields() throws Exception { + public void testStoreModelWithUnknownFields() { String inferenceEntityId = "test-store-model-unknown-field"; Model model = buildModelWithUnknownField(inferenceEntityId); ElasticsearchStatusException statusException = expectThrows( @@ -145,7 +145,7 @@ public void testGetModel() throws Exception { assertEquals(model, roundTripModel); } - public void testStoreModelFailsWhenModelExists() throws Exception { + public void testStoreModelFailsWhenModelExists() { String inferenceEntityId = "test-put-trained-model-config-exists"; Model model = buildElserModelConfig(inferenceEntityId, TaskType.SPARSE_EMBEDDING); ModelRegistryTests.assertStoreModel(modelRegistry, model); @@ -175,7 +175,7 @@ public void testDeleteModel() throws Exception { assertThat(exceptionHolder.get(), not(nullValue())); assertFalse(deleteResponseHolder.get()); - assertThat(exceptionHolder.get().getMessage(), containsString("Inference endpoint not found [model1]")); + assertThat(exceptionHolder.get().getMessage(), containsString("Inference endpoint [model1] not found")); } public void testNonExistentDeleteModel_DoesNotThrowAnException() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 60592c5dd1dbd..447f33308284d 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -134,6 +134,7 @@ import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; +import org.elasticsearch.xpack.inference.services.elastic.authorization.PreconfiguredEndpointsRequestHandler; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiService; @@ -288,9 +289,6 @@ public Collection createComponents(PluginServices services) { var amazonBedrockRequestSenderFactory = new AmazonBedrockRequestSender.Factory(serviceComponents.get(), services.clusterService()); amazonBedrockFactory.set(amazonBedrockRequestSenderFactory); - modelRegistry.set(new ModelRegistry(services.clusterService(), services.client())); - services.clusterService().addListener(modelRegistry.get()); - if (inferenceServiceExtensions == null) { inferenceServiceExtensions = new ArrayList<>(); } @@ -322,6 +320,11 @@ public Collection createComponents(PluginServices services) { services.threadPool() ); + var eisSender = elasicInferenceServiceFactory.get().createSender(); + var preconfigEndpointsHandler = new PreconfiguredEndpointsRequestHandler(authorizationHandler, eisSender); + modelRegistry.set(new ModelRegistry(services.clusterService(), services.client(), preconfigEndpointsHandler)); + services.clusterService().addListener(modelRegistry.get()); + var sageMakerSchemas = new SageMakerSchemas(); var sageMakerConfigurations = new LazyInitializable<>(new SageMakerConfiguration(sageMakerSchemas)); inferenceServices.add( @@ -385,7 +388,7 @@ public Collection createComponents(PluginServices services) { ); components.add(inferenceStatsBinding); components.add(authorizationHandler); - components.add(new PluginComponentBinding<>(Sender.class, elasicInferenceServiceFactory.get().createSender())); + components.add(new PluginComponentBinding<>(Sender.class, eisSender)); components.add( new InferenceEndpointRegistry( services.clusterService(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java index 18c83df4067ed..9ebbeb187ad7b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java @@ -21,7 +21,6 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.tasks.Task; -import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.GetInferenceServicesAction; import org.elasticsearch.xpack.inference.external.http.sender.Sender; @@ -35,8 +34,6 @@ import java.util.Map; import java.util.stream.Collectors; -import static 
org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; - public class TransportGetInferenceServicesAction extends HandledTransportAction< GetInferenceServicesAction.Request, GetInferenceServicesAction.Response> { @@ -46,13 +43,11 @@ public class TransportGetInferenceServicesAction extends HandledTransportAction< private final InferenceServiceRegistry serviceRegistry; private final ElasticInferenceServiceAuthorizationRequestHandler eisAuthorizationRequestHandler; private final Sender eisSender; - private final ThreadPool threadPool; @Inject public TransportGetInferenceServicesAction( TransportService transportService, ActionFilters actionFilters, - ThreadPool threadPool, InferenceServiceRegistry serviceRegistry, ElasticInferenceServiceAuthorizationRequestHandler eisAuthorizationRequestHandler, Sender sender @@ -67,7 +62,6 @@ public TransportGetInferenceServicesAction( this.serviceRegistry = serviceRegistry; this.eisAuthorizationRequestHandler = eisAuthorizationRequestHandler; this.eisSender = sender; - this.threadPool = threadPool; } @Override @@ -123,8 +117,7 @@ private void getServiceConfigurationsForServicesAndEis( @Nullable TaskType requestedTaskType ) { SubscribableListener.newForked(authModelListener -> { - // Executing on a separate thread because there's a chance the authorization call needs to do some initialization for the Sender - threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> getEisAuthorization(authModelListener, eisSender)); + getEisAuthorization(authModelListener, eisSender); }).>andThen((configurationListener, authorizationModel) -> { var serviceConfigs = getServiceConfigurationsForServices(availableServices); @@ -152,11 +145,16 @@ private void getServiceConfigurationsForServicesAndEis( private void getEisAuthorization(ActionListener listener, Sender sender) { var disabledServiceListener = listener.delegateResponse((delegate, e) -> { - logger.warn( - "Failed to retrieve authorization information from the " - + "Elastic Inference Service while determining service configurations. Marking service as disabled.", - e - ); + if (eisAuthorizationRequestHandler.isServiceConfigured()) { + logger.warn( + "Failed to retrieve authorization information from the " + + "Elastic Inference Service while determining service configurations. Marking service as disabled.", + e + ); + } else { + logger.debug("The Elastic Inference Service is not configured. 
Marking service as disabled.", e); + } + delegate.onResponse(ElasticInferenceServiceAuthorizationModel.newDisabledService()); }); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index 7cd1cf5999d11..cfe2057a7731b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -49,6 +49,7 @@ import org.elasticsearch.core.Tuple; import org.elasticsearch.gateway.GatewayService; import org.elasticsearch.index.engine.VersionConflictEngineException; +import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.reindex.BulkByScrollResponse; @@ -76,6 +77,8 @@ import org.elasticsearch.xpack.inference.InferenceIndex; import org.elasticsearch.xpack.inference.InferenceSecretsIndex; import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings; +import org.elasticsearch.xpack.inference.services.elastic.authorization.PreconfiguredEndpointsRequestHandler; import java.io.IOException; import java.util.ArrayList; @@ -87,14 +90,18 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; +import java.util.function.Predicate; +import java.util.function.Supplier; import java.util.stream.Collectors; import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings.EIS_PRECONFIGURED_ENDPOINTS; /** * A class responsible for persisting and reading inference endpoint configurations. @@ -147,10 +154,15 @@ public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap) private final MasterServiceTaskQueue metadataTaskQueue; private final AtomicBoolean upgradeMetadataInProgress = new AtomicBoolean(false); private final Set preventDeletionLock = Collections.newSetFromMap(new ConcurrentHashMap<>()); + private final PreconfiguredEndpointsRequestHandler preconfiguredEndpointsRequestHandler; private volatile Metadata lastMetadata; - public ModelRegistry(ClusterService clusterService, Client client) { + public ModelRegistry( + ClusterService clusterService, + Client client, + PreconfiguredEndpointsRequestHandler preconfiguredEndpointsRequestHandler + ) { this.client = new OriginSettingClient(client, ClientHelper.INFERENCE_ORIGIN); this.defaultConfigIds = new ConcurrentHashMap<>(); var executor = new SimpleBatchedAckListenerTaskExecutor() { @@ -163,6 +175,7 @@ public Tuple executeTask(MetadataTask tas } }; this.metadataTaskQueue = clusterService.createTaskQueue("model_registry", Priority.NORMAL, executor); + this.preconfiguredEndpointsRequestHandler = Objects.requireNonNull(preconfiguredEndpointsRequestHandler); } /** @@ -175,14 +188,6 @@ public boolean containsDefaultConfigId(String inferenceEntityId) { return defaultConfigIds.containsKey(inferenceEntityId); } - /** - * Adds the default configuration information if it does not already exist internally. 
- * @param defaultConfigId the default endpoint information - */ - public synchronized void putDefaultIdIfAbsent(InferenceService.DefaultConfigId defaultConfigId) { - defaultConfigIds.putIfAbsent(defaultConfigId.inferenceId(), defaultConfigId); - } - /** * Set the default inference ids provided by the services * @param defaultConfigId The default endpoint information @@ -230,6 +235,14 @@ public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId throw new IllegalStateException("initial cluster state not set yet"); } } + + // this is a temporary solution until the model registry handles polling the EIS authorization endpoint + // to retrieve the preconfigured inference endpoints + var eisConfig = ElasticInferenceServiceMinimalSettings.getWithInferenceId(inferenceEntityId); + if (eisConfig != null) { + return eisConfig.minimalSettings(); + } + var config = defaultConfigIds.get(inferenceEntityId); if (config != null) { return config.settings(); @@ -249,37 +262,16 @@ public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId * @param listener Model listener */ public void getModelWithSecrets(String inferenceEntityId, ActionListener listener) { - ActionListener searchListener = ActionListener.wrap((searchResponse) -> { - // There should be a hit for the configurations - if (searchResponse.getHits().getHits().length == 0) { - var maybeDefault = defaultConfigIds.get(inferenceEntityId); - if (maybeDefault != null) { - getDefaultConfig(true, maybeDefault, listener); - } else { - listener.onFailure(inferenceNotFoundException(inferenceEntityId)); - } - return; - } - - listener.onResponse(unparsedModelFromMap(createModelConfigMap(searchResponse.getHits(), inferenceEntityId))); - }, (e) -> { - logger.warn(format("Failed to load inference endpoint with secrets [%s]", inferenceEntityId), e); - listener.onFailure( - new ElasticsearchException( - format("Failed to load inference endpoint with secrets [%s], error: [%s]", inferenceEntityId, e.getMessage()), - e - ) - ); - }); - - QueryBuilder queryBuilder = documentIdQuery(inferenceEntityId); - SearchRequest modelSearch = client.prepareSearch(InferenceIndex.INDEX_PATTERN, InferenceSecretsIndex.INDEX_PATTERN) - .setQuery(queryBuilder) - .setSize(2) - .setAllowPartialSearchResults(false) - .request(); - - client.search(modelSearch, searchListener); + getModelHelper( + inferenceEntityId, + client.prepareSearch(InferenceIndex.INDEX_PATTERN, InferenceSecretsIndex.INDEX_PATTERN) + .setQuery(documentIdQuery(inferenceEntityId)) + .setSize(2) + .setAllowPartialSearchResults(false) + .request(), + searchResponse -> unparsedModelFromMap(createModelConfigMap(searchResponse.getHits(), inferenceEntityId)), + listener + ); } /** @@ -289,24 +281,47 @@ public void getModelWithSecrets(String inferenceEntityId, ActionListener listener) { - ActionListener searchListener = ActionListener.wrap((searchResponse) -> { - // There should be a hit for the configurations - if (searchResponse.getHits().getHits().length == 0) { - var maybeDefault = defaultConfigIds.get(inferenceEntityId); - if (maybeDefault != null) { - getDefaultConfig(true, maybeDefault, listener); - } else { - listener.onFailure(inferenceNotFoundException(inferenceEntityId)); - } + getModelHelper( + inferenceEntityId, + client.prepareSearch(InferenceIndex.INDEX_PATTERN) + .setQuery(documentIdQuery(inferenceEntityId)) + .setSize(1) + .setTrackTotalHits(false) + .request(), + searchResponse -> { + var modelConfigs = 
parseHitsAsModelsWithoutSecrets(searchResponse.getHits()).stream() + .map(ModelRegistry::unparsedModelFromMap) + .toList(); + assert modelConfigs.size() == 1; + return modelConfigs.get(0); + }, + listener + ); + } + + private void getModelHelper( + String inferenceEntityId, + SearchRequest modelSearch, + Function unparsedModelCreator, + ActionListener listener + ) { + // If we know it's an EIS preconfigured endpoint, skip looking in the index because it could have an outdated version of the + // endpoint and go directly to EIS to retrieve it + if (ElasticInferenceServiceMinimalSettings.isEisPreconfiguredEndpoint(inferenceEntityId)) { + retrievePreconfiguredEndpointFromEisElseEisError(listener, inferenceEntityId); + return; + } + + var failureListener = listener.delegateResponse((delegate, e) -> { + // If the inference endpoint does not exist, we've already created a well-defined exception, so just return it + if (e instanceof ResourceNotFoundException) { + delegate.onFailure(e); return; } - var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList(); - assert modelConfigs.size() == 1; - listener.onResponse(modelConfigs.get(0)); - }, e -> { logger.warn(format("Failed to load inference endpoint [%s]", inferenceEntityId), e); - listener.onFailure( + + delegate.onFailure( new ElasticsearchException( format("Failed to load inference endpoint [%s], error: [%s]", inferenceEntityId, e.getMessage()), e @@ -314,18 +329,77 @@ public void getModel(String inferenceEntityId, ActionListener lis ); }); - QueryBuilder queryBuilder = documentIdQuery(inferenceEntityId); - SearchRequest modelSearch = client.prepareSearch(InferenceIndex.INDEX_PATTERN) - .setQuery(queryBuilder) - .setSize(1) - .setTrackTotalHits(false) - .request(); + ActionListener searchResponseListener = failureListener.delegateFailureAndWrap( + (delegate, searchResponse) -> searchForEndpointInDefaultAndEis( + inferenceEntityId, + searchResponse, + unparsedModelCreator, + delegate + ) + ); - client.search(modelSearch, searchListener); + client.search(modelSearch, searchResponseListener); + } + + private void searchForEndpointInDefaultAndEis( + String inferenceEntityId, + SearchResponse searchResponse, + Function unparsedModelCreator, + ActionListener listener + ) { + // We likely found the configuration, so parse it and return it + if (searchResponse.getHits().getHits().length != 0) { + listener.onResponse(unparsedModelCreator.apply(searchResponse)); + return; + } + + // we didn't find the configuration in the inference index, so check if it is a preconfigured endpoint + var maybeDefault = defaultConfigIds.get(inferenceEntityId); + if (maybeDefault != null) { + getDefaultConfig(true, maybeDefault, listener); + return; + } + + // check if the inference id is a preconfigured endpoint available from EIS + retrievePreconfiguredEndpointFromEisElseNotFound(listener, inferenceEntityId); + } + + private void retrievePreconfiguredEndpointFromEisElseEisError(ActionListener listener, String inferenceEntityId) { + var eisFailureListener = listener.delegateResponse( + (delegate, e) -> delegate.onFailure(eisBadRequestException(inferenceEntityId, e)) + ); + + retrieveEisPreconfiguredEndpoint(eisFailureListener, inferenceEntityId); + } + + private ElasticsearchStatusException eisBadRequestException(String inferenceEntityId, Exception exception) { + return new ElasticsearchStatusException( + "Unable to retrieve the preconfigured inference endpoint [{}] from the Elastic Inference Service", + 
RestStatus.BAD_REQUEST, + exception, + inferenceEntityId + ); + } + + private void retrievePreconfiguredEndpointFromEisElseNotFound(ActionListener listener, String inferenceEntityId) { + var eisFailureListener = listener.delegateResponse( + (delegate, e) -> delegate.onFailure(inferenceNotFoundException(inferenceEntityId)) + ); + + retrieveEisPreconfiguredEndpoint(eisFailureListener, inferenceEntityId); + } + + private void retrieveEisPreconfiguredEndpoint(ActionListener listener, String inferenceEntityId) { + var eisFailureListener = listener.delegateResponse((delegate, e) -> { + logger.debug("Failed to retrieve preconfigured endpoint from EIS", e); + delegate.onFailure(e); + }); + + preconfiguredEndpointsRequestHandler.getPreconfiguredEndpointAsUnparsedModel(inferenceEntityId, eisFailureListener); } private ResourceNotFoundException inferenceNotFoundException(String inferenceEntityId) { - return new ResourceNotFoundException("Inference endpoint not found [{}]", inferenceEntityId); + return new ResourceNotFoundException("Inference endpoint [{}] not found or you are not authorized to access it", inferenceEntityId); } /** @@ -335,22 +409,80 @@ private ResourceNotFoundException inferenceNotFoundException(String inferenceEnt * @param listener Models listener */ public void getModelsByTaskType(TaskType taskType, ActionListener> listener) { - ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { - var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList(); - var defaultConfigsForTaskType = taskTypeMatchedDefaults(taskType, defaultConfigIds.values()); - addAllDefaultConfigsIfMissing(true, modelConfigs, defaultConfigsForTaskType, delegate); - }); + getModelsHelper( + QueryBuilders.boolQuery().filter(QueryBuilders.termsQuery(TASK_TYPE_FIELD, taskType.toString())), + () -> taskTypeMatchedDefaults(taskType, defaultConfigIds.values()), + unparsedModel -> unparsedModel.taskType() == taskType, + true, + listener + ); + } - QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.termsQuery(TASK_TYPE_FIELD, taskType.toString())); + private void getModelsHelper( + BoolQueryBuilder boolQueryBuilder, + Supplier> defaultConfigIdsSupplier, + Predicate eisResponseFilter, + boolean persistDefaultEndpoints, + ActionListener> listener + ) { + ActionListener searchResponseListener = listener.delegateFailureAndWrap( + (delegate, searchResponse) -> includeDefaultAndEisEndpoints( + searchResponse, + defaultConfigIdsSupplier, + eisResponseFilter, + persistDefaultEndpoints, + delegate + ) + ); - SearchRequest modelSearch = client.prepareSearch(InferenceIndex.INDEX_PATTERN) + var eisEndpointIds = EIS_PRECONFIGURED_ENDPOINTS.stream().map(Model::documentId).toArray(String[]::new); + + // exclude the EIS preconfigured endpoints so we can query EIS directly for them + var queryBuilder = boolQueryBuilder.filter(QueryBuilders.boolQuery().mustNot(QueryBuilders.idsQuery().addIds(eisEndpointIds))); + + var modelSearch = client.prepareSearch(InferenceIndex.INDEX_PATTERN) .setQuery(queryBuilder) .setSize(10_000) .setTrackTotalHits(false) .addSort(MODEL_ID_FIELD, SortOrder.ASC) .request(); - client.search(modelSearch, searchListener); + client.search(modelSearch, searchResponseListener); + } + + private void includeDefaultAndEisEndpoints( + SearchResponse searchResponse, + Supplier> defaultConfigIdsSupplier, + Predicate eisResponseFilter, + boolean persistDefaultEndpoints, + ActionListener> listener + ) { + 
SubscribableListener.>newForked(missingDefaultConfigsAddedListener -> { + var modelConfigs = parseHitsAsModelsWithoutSecrets(searchResponse.getHits()).stream() + .map(ModelRegistry::unparsedModelFromMap) + .toList(); + addAllDefaultConfigsIfMissing( + persistDefaultEndpoints, + modelConfigs, + defaultConfigIdsSupplier.get(), + missingDefaultConfigsAddedListener + ); + }).>andThen((eisPreconfiguredEndpointsAddedListener, defaultModelsAndFromIndex) -> { + ActionListener> eisListener = ActionListener.wrap(allEisAuthorizedModels -> { + var filteredEisModels = allEisAuthorizedModels.stream().filter(eisResponseFilter).toList(); + + var allModels = new ArrayList<>(defaultModelsAndFromIndex); + allModels.addAll(filteredEisModels); + allModels.sort(Comparator.comparing(UnparsedModel::inferenceEntityId)); + + eisPreconfiguredEndpointsAddedListener.onResponse(allModels); + }, e -> { + logger.debug("Failed to retrieve preconfigured endpoint from EIS", e); + eisPreconfiguredEndpointsAddedListener.onResponse(defaultModelsAndFromIndex); + }); + + preconfiguredEndpointsRequestHandler.getAllPreconfiguredEndpointsAsUnparsedModels(eisListener); + }).addListener(listener); } /** @@ -366,24 +498,13 @@ public void getModelsByTaskType(TaskType taskType, ActionListener> listener) { - ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { - var foundConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList(); - addAllDefaultConfigsIfMissing(persistDefaultEndpoints, foundConfigs, defaultConfigIds.values(), delegate); - }); - - // In theory the index should only contain model config documents - // and a match all query would be sufficient. But just in case the - // index has been polluted return only docs with a task_type field - QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.existsQuery(TASK_TYPE_FIELD)); - - SearchRequest modelSearch = client.prepareSearch(InferenceIndex.INDEX_PATTERN) - .setQuery(queryBuilder) - .setSize(10_000) - .setTrackTotalHits(false) - .addSort(MODEL_ID_FIELD, SortOrder.ASC) - .request(); - - client.search(modelSearch, searchListener); + getModelsHelper( + QueryBuilders.boolQuery().filter(QueryBuilders.constantScoreQuery(QueryBuilders.existsQuery(TASK_TYPE_FIELD))), + () -> new ArrayList<>(defaultConfigIds.values()), + eisResponse -> true, + persistDefaultEndpoints, + listener + ); } private void addAllDefaultConfigsIfMissing( @@ -458,7 +579,7 @@ private void storeDefaultEndpoint(Model preconfigured, Runnable runAfter) { storeModel(preconfigured, false, ActionListener.runAfter(responseListener, runAfter), AcknowledgedRequest.DEFAULT_ACK_TIMEOUT); } - private ArrayList parseHitsAsModels(SearchHits hits) { + private ArrayList parseHitsAsModelsWithoutSecrets(SearchHits hits) { var modelConfigs = new ArrayList(); for (var hit : hits) { modelConfigs.add(new ModelConfigMap(hit.getSourceAsMap(), Map.of())); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 9fada2a66c95d..4f9a7d680de00 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -16,7 +16,6 @@ import 
org.elasticsearch.common.ValidationException; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; @@ -26,7 +25,6 @@ import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; @@ -54,7 +52,6 @@ import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.elastic.action.ElasticInferenceServiceActionCreator; -import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationHandler; import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; @@ -84,6 +81,18 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.useChatCompletionUrlMessage; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings.CHAT_COMPLETION_V1_MINIMAL_SETTINGS; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings.DEFAULT_CHAT_COMPLETION_MODEL_ID_V1; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings.DEFAULT_ELSER_2_MODEL_ID; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings.DEFAULT_ELSER_ENDPOINT_ID_V2; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings.DEFAULT_MULTILINGUAL_EMBED_MODEL_ID; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings.DEFAULT_RERANK_ENDPOINT_ID_V1; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings.DEFAULT_RERANK_MODEL_ID_V1; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings.ELSER_V2_MINIMAL_SETTINGS; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings.MULTILINGUAL_EMBED_MINIMAL_SETTINGS; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings.RERANK_V1_MINIMAL_SETTINGS; public class ElasticInferenceService extends SenderService { @@ -95,12 +104,6 @@ public class 
ElasticInferenceService extends SenderService { // A batch size of 16 provides optimal throughput and stability, especially on lower-tier instance types. public static final Integer SPARSE_TEXT_EMBEDDING_MAX_BATCH_SIZE = 16; - private static final EnumSet IMPLEMENTED_TASK_TYPES = EnumSet.of( - TaskType.SPARSE_EMBEDDING, - TaskType.CHAT_COMPLETION, - TaskType.RERANK, - TaskType.TEXT_EMBEDDING - ); private static final String SERVICE_NAME = "Elastic"; // TODO: revisit this value once EIS supports dense models @@ -108,22 +111,6 @@ public class ElasticInferenceService extends SenderService { // This mirrors the memory constraints observed with sparse embeddings private static final Integer DENSE_TEXT_EMBEDDINGS_MAX_BATCH_SIZE = 16; - // rainbow-sprinkles - static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles"; - static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1); - - // elser-2 - static final String DEFAULT_ELSER_2_MODEL_ID = "elser_model_2"; - static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = defaultEndpointId("elser-2"); - - // multilingual-text-embed - static final String DEFAULT_MULTILINGUAL_EMBED_MODEL_ID = "multilingual-embed-v1"; - static final String DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID = defaultEndpointId(DEFAULT_MULTILINGUAL_EMBED_MODEL_ID); - - // rerank-v1 - static final String DEFAULT_RERANK_MODEL_ID_V1 = "rerank-v1"; - static final String DEFAULT_RERANK_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_RERANK_MODEL_ID_V1); - /** * The task types that the {@link InferenceAction.Request} can accept. */ @@ -133,12 +120,7 @@ public class ElasticInferenceService extends SenderService { TaskType.TEXT_EMBEDDING ); - public static String defaultEndpointId(String modelId) { - return Strings.format(".%s-elastic", modelId); - } - private final ElasticInferenceServiceComponents elasticInferenceServiceComponents; - private final ElasticInferenceServiceAuthorizationHandler authorizationHandler; public ElasticInferenceService( HttpRequestSender.Factory factory, @@ -170,16 +152,6 @@ public ElasticInferenceService( this.elasticInferenceServiceComponents = new ElasticInferenceServiceComponents( elasticInferenceServiceSettings.getElasticInferenceServiceUrl() ); - authorizationHandler = new ElasticInferenceServiceAuthorizationHandler( - serviceComponents, - modelRegistry, - authorizationRequestHandler, - initDefaultEndpoints(elasticInferenceServiceComponents), - IMPLEMENTED_TASK_TYPES, - this, - getSender(), - elasticInferenceServiceSettings - ); } private static Map initDefaultEndpoints( @@ -197,7 +169,7 @@ private static Map initDefaultEndpoints( EmptySecretSettings.INSTANCE, elasticInferenceServiceComponents ), - MinimalServiceSettings.chatCompletion(NAME) + CHAT_COMPLETION_V1_MINIMAL_SETTINGS ), DEFAULT_ELSER_2_MODEL_ID, new DefaultModelConfig( @@ -211,7 +183,7 @@ private static Map initDefaultEndpoints( elasticInferenceServiceComponents, ChunkingSettingsBuilder.DEFAULT_SETTINGS ), - MinimalServiceSettings.sparseEmbedding(NAME) + ELSER_V2_MINIMAL_SETTINGS ), DEFAULT_MULTILINGUAL_EMBED_MODEL_ID, new DefaultModelConfig( @@ -230,12 +202,7 @@ private static Map initDefaultEndpoints( elasticInferenceServiceComponents, ChunkingSettingsBuilder.DEFAULT_SETTINGS ), - MinimalServiceSettings.textEmbedding( - NAME, - DENSE_TEXT_EMBEDDINGS_DIMENSIONS, - defaultDenseTextEmbeddingsSimilarity(), - DenseVectorFieldMapper.ElementType.FLOAT - ) + MULTILINGUAL_EMBED_MINIMAL_SETTINGS ), DEFAULT_RERANK_MODEL_ID_V1, new 
DefaultModelConfig( @@ -248,16 +215,11 @@ private static Map initDefaultEndpoints( EmptySecretSettings.INSTANCE, elasticInferenceServiceComponents ), - MinimalServiceSettings.rerank(NAME) + RERANK_V1_MINIMAL_SETTINGS ) ); } - @Override - public void onNodeStarted() { - authorizationHandler.init(); - } - @Override protected void validateRerankParameters(Boolean returnDocuments, Integer topN, ValidationException validationException) { if (returnDocuments != null) { @@ -270,32 +232,11 @@ protected void validateRerankParameters(Boolean returnDocuments, Integer topN, V } } - /** - * Only use this in tests. - * - * Waits the specified amount of time for the authorization call to complete. This is mainly to make testing easier. - * @param waitTime the max time to wait - * @throws IllegalStateException if the wait time is exceeded or the call receives an {@link InterruptedException} - */ - public void waitForFirstAuthorizationToComplete(TimeValue waitTime) { - authorizationHandler.waitForAuthorizationToComplete(waitTime); - } - @Override public Set supportedStreamingTasks() { return EnumSet.of(TaskType.CHAT_COMPLETION); } - @Override - public List defaultConfigIds() { - return authorizationHandler.defaultConfigIds(); - } - - @Override - public void defaultConfigs(ActionListener> defaultsListener) { - authorizationHandler.defaultConfigs(defaultsListener); - } - @Override protected void doUnifiedCompletionInfer( Model model, @@ -472,7 +413,9 @@ public InferenceServiceConfiguration getConfiguration() { @Override public EnumSet supportedTaskTypes() { - return authorizationHandler.supportedTaskTypes(); + throw new UnsupportedOperationException( + "The EIS supported task types change depending on authorization, requests should be made directly to EIS instead" + ); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceMinimalSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceMinimalSettings.java new file mode 100644 index 0000000000000..52f6f6fa074aa --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceMinimalSettings.java @@ -0,0 +1,108 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
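Before the new ElasticInferenceServiceMinimalSettings class below, a small illustrative usage of the ".{model}-elastic" naming convention it centralizes; the values come from this file and from the ModelRegistry tests earlier in the patch, but the snippet itself is not part of the change.

import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings;

class EisEndpointIdSketch {
    static void demo() {
        // defaultEndpointId wraps a model id as ".<model>-elastic"; these are the ids the tests assert on.
        String rerankId = ElasticInferenceServiceMinimalSettings.defaultEndpointId("rerank-v1"); // ".rerank-v1-elastic"
        String chatId = ElasticInferenceServiceMinimalSettings.defaultEndpointId("rainbow-sprinkles"); // ".rainbow-sprinkles-elastic"
        // Note: the ELSER endpoint id is derived from "elser-2" even though the model id constant is "elser_model_2".
        assert ElasticInferenceServiceMinimalSettings.isEisPreconfiguredEndpoint(rerankId);
        assert ElasticInferenceServiceMinimalSettings.isEisPreconfiguredEndpoint(chatId);
    }
}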
+ */
+
+package org.elasticsearch.xpack.inference.services.elastic;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.inference.MinimalServiceSettings;
+import org.elasticsearch.inference.SimilarityMeasure;
+
+import java.util.Map;
+import java.util.Set;
+
+import static java.util.stream.Collectors.toMap;
+
+public class ElasticInferenceServiceMinimalSettings {
+
+    // rainbow-sprinkles
+    public static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles";
+    public static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1);
+
+    // elser-2
+    public static final String DEFAULT_ELSER_2_MODEL_ID = "elser_model_2";
+    public static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = defaultEndpointId("elser-2");
+
+    // multilingual-text-embed
+    public static final Integer DENSE_TEXT_EMBEDDINGS_DIMENSIONS = 1024;
+    public static final String DEFAULT_MULTILINGUAL_EMBED_MODEL_ID = "multilingual-embed-v1";
+    public static final String DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID = defaultEndpointId(DEFAULT_MULTILINGUAL_EMBED_MODEL_ID);
+
+    // rerank-v1
+    public static final String DEFAULT_RERANK_MODEL_ID_V1 = "rerank-v1";
+    public static final String DEFAULT_RERANK_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_RERANK_MODEL_ID_V1);
+
+    public static final Set<String> EIS_PRECONFIGURED_ENDPOINTS = Set.of(
+        DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1,
+        DEFAULT_ELSER_ENDPOINT_ID_V2,
+        DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID,
+        DEFAULT_RERANK_ENDPOINT_ID_V1
+    );
+
+    static final MinimalServiceSettings CHAT_COMPLETION_V1_MINIMAL_SETTINGS = MinimalServiceSettings.chatCompletion(
+        ElasticInferenceService.NAME
+    );
+    static final MinimalServiceSettings ELSER_V2_MINIMAL_SETTINGS = MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME);
+    static final MinimalServiceSettings MULTILINGUAL_EMBED_MINIMAL_SETTINGS = MinimalServiceSettings.textEmbedding(
+        ElasticInferenceService.NAME,
+        DENSE_TEXT_EMBEDDINGS_DIMENSIONS,
+        defaultDenseTextEmbeddingsSimilarity(),
+        DenseVectorFieldMapper.ElementType.FLOAT
+    );
+    static final MinimalServiceSettings RERANK_V1_MINIMAL_SETTINGS = MinimalServiceSettings.rerank(ElasticInferenceService.NAME);
+
+    public record SettingsWithEndpointInfo(String inferenceId, String modelId, MinimalServiceSettings minimalSettings) {}
+
+    private static final Map<String, SettingsWithEndpointInfo> MODEL_NAME_TO_MINIMAL_SETTINGS = Map.of(
+        DEFAULT_CHAT_COMPLETION_MODEL_ID_V1,
+        new SettingsWithEndpointInfo(
+            DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1,
+            DEFAULT_CHAT_COMPLETION_MODEL_ID_V1,
+            CHAT_COMPLETION_V1_MINIMAL_SETTINGS
+        ),
+        DEFAULT_ELSER_2_MODEL_ID,
+        new SettingsWithEndpointInfo(DEFAULT_ELSER_ENDPOINT_ID_V2, DEFAULT_ELSER_2_MODEL_ID, ELSER_V2_MINIMAL_SETTINGS),
+        DEFAULT_MULTILINGUAL_EMBED_MODEL_ID,
+        new SettingsWithEndpointInfo(
+            DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID,
+            DEFAULT_MULTILINGUAL_EMBED_MODEL_ID,
+            MULTILINGUAL_EMBED_MINIMAL_SETTINGS
+        ),
+        DEFAULT_RERANK_MODEL_ID_V1,
+        new SettingsWithEndpointInfo(DEFAULT_RERANK_ENDPOINT_ID_V1, DEFAULT_RERANK_MODEL_ID_V1, RERANK_V1_MINIMAL_SETTINGS)
+    );
+
+    private static final Map<String, SettingsWithEndpointInfo> INFERENCE_ID_TO_MINIMAL_SETTINGS = MODEL_NAME_TO_MINIMAL_SETTINGS.entrySet()
+        .stream()
+        .collect(toMap(e -> e.getValue().inferenceId(), Map.Entry::getValue));
+
+    public static SimilarityMeasure defaultDenseTextEmbeddingsSimilarity() {
+        return SimilarityMeasure.COSINE;
+    }
+
+    public static String defaultEndpointId(String modelId) {
+        return Strings.format(".%s-elastic", modelId);
+    }
+
+    public static boolean isEisPreconfiguredEndpoint(String inferenceEntityId) {
+        return EIS_PRECONFIGURED_ENDPOINTS.contains(inferenceEntityId);
+    }
+
+    public static boolean containsModelName(String modelName) {
+        return MODEL_NAME_TO_MINIMAL_SETTINGS.containsKey(modelName);
+    }
+
+    public static SettingsWithEndpointInfo getWithModelName(String modelName) {
+        return MODEL_NAME_TO_MINIMAL_SETTINGS.get(modelName);
+    }
+
+    public static SettingsWithEndpointInfo getWithInferenceId(String inferenceId) {
+        return INFERENCE_ID_TO_MINIMAL_SETTINGS.get(inferenceId);
+    }
+
+    private ElasticInferenceServiceMinimalSettings() {}
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java
index 0d8bef246b35d..a2144bc9013ca 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java
@@ -28,7 +28,8 @@ public class ElasticInferenceServiceSettings {
     @Deprecated
     static final Setting<String> EIS_GATEWAY_URL = Setting.simpleString("xpack.inference.eis.gateway.url", Setting.Property.NodeScope);
 
-    static final Setting<String> ELASTIC_INFERENCE_SERVICE_URL = Setting.simpleString(
+    // public so tests can access it
+    public static final Setting<String> ELASTIC_INFERENCE_SERVICE_URL = Setting.simpleString(
         "xpack.inference.elastic.url",
         Setting.Property.NodeScope
     );
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java
deleted file mode 100644
index f83542e7fe740..0000000000000
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java
+++ /dev/null
@@ -1,336 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
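A short sketch of how the lookup helpers above are consumed, mirroring the temporary branch this patch adds to ModelRegistry#getMinimalServiceSettings; the wrapper method here is hypothetical and only shows the resolution order.

import org.elasticsearch.inference.MinimalServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings;

class MinimalSettingsLookupSketch {
    // EIS preconfigured ids are answered from the static table first; the real method then falls back
    // to defaultConfigIds and finally to the parsed cluster metadata.
    static MinimalServiceSettings resolve(String inferenceEntityId) {
        var eisConfig = ElasticInferenceServiceMinimalSettings.getWithInferenceId(inferenceEntityId);
        if (eisConfig != null) {
            return eisConfig.minimalSettings(); // e.g. sparse_embedding for ".elser-2-elastic"
        }
        return null; // fall back as ModelRegistry#getMinimalServiceSettings does
    }
}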
- */ - -package org.elasticsearch.xpack.inference.services.elastic.authorization; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.common.Randomness; -import org.elasticsearch.common.Strings; -import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.InferenceService; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.threadpool.Scheduler; -import org.elasticsearch.xpack.inference.external.http.sender.Sender; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.services.ServiceComponents; -import org.elasticsearch.xpack.inference.services.elastic.DefaultModelConfig; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; - -import java.io.Closeable; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Comparator; -import java.util.EnumSet; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; -import java.util.TreeSet; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; -import java.util.stream.Collectors; - -import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; - -public class ElasticInferenceServiceAuthorizationHandler implements Closeable { - private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceAuthorizationHandler.class); - - private record AuthorizedContent( - ElasticInferenceServiceAuthorizationModel taskTypesAndModels, - List configIds, - List defaultModelConfigs - ) { - static AuthorizedContent empty() { - return new AuthorizedContent(ElasticInferenceServiceAuthorizationModel.newDisabledService(), List.of(), List.of()); - } - } - - private final ServiceComponents serviceComponents; - private final AtomicReference authorizedContent = new AtomicReference<>(AuthorizedContent.empty()); - private final ModelRegistry modelRegistry; - private final ElasticInferenceServiceAuthorizationRequestHandler authorizationHandler; - private final Map defaultModelsConfigs; - private final CountDownLatch firstAuthorizationCompletedLatch = new CountDownLatch(1); - private final EnumSet implementedTaskTypes; - private final InferenceService inferenceService; - private final Sender sender; - private final Runnable callback; - private final AtomicReference lastAuthTask = new AtomicReference<>(null); - private final AtomicBoolean shutdown = new AtomicBoolean(false); - private final ElasticInferenceServiceSettings elasticInferenceServiceSettings; - - public ElasticInferenceServiceAuthorizationHandler( - ServiceComponents serviceComponents, - ModelRegistry modelRegistry, - ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, - Map defaultModelsConfigs, - EnumSet implementedTaskTypes, - InferenceService inferenceService, - Sender sender, - ElasticInferenceServiceSettings elasticInferenceServiceSettings - ) { - this( - serviceComponents, - modelRegistry, - authorizationRequestHandler, - defaultModelsConfigs, - implementedTaskTypes, - Objects.requireNonNull(inferenceService), - sender, - elasticInferenceServiceSettings, - null - ); - } - - // default for testing - ElasticInferenceServiceAuthorizationHandler( - 
ServiceComponents serviceComponents, - ModelRegistry modelRegistry, - ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, - Map defaultModelsConfigs, - EnumSet implementedTaskTypes, - InferenceService inferenceService, - Sender sender, - ElasticInferenceServiceSettings elasticInferenceServiceSettings, - // this is a hack to facilitate testing - Runnable callback - ) { - this.serviceComponents = Objects.requireNonNull(serviceComponents); - this.modelRegistry = Objects.requireNonNull(modelRegistry); - this.authorizationHandler = Objects.requireNonNull(authorizationRequestHandler); - this.defaultModelsConfigs = Objects.requireNonNull(defaultModelsConfigs); - this.implementedTaskTypes = Objects.requireNonNull(implementedTaskTypes); - // allow the service to be null for testing - this.inferenceService = inferenceService; - this.sender = Objects.requireNonNull(sender); - this.elasticInferenceServiceSettings = Objects.requireNonNull(elasticInferenceServiceSettings); - this.callback = callback; - } - - public void init() { - logger.debug("Initializing authorization logic"); - serviceComponents.threadPool().executor(UTILITY_THREAD_POOL_NAME).execute(this::scheduleAndSendAuthorizationRequest); - } - - /** - * Waits the specified amount of time for the first authorization call to complete. This is mainly to make testing easier. - * @param waitTime the max time to wait - * @throws IllegalStateException if the wait time is exceeded or the call receives an {@link InterruptedException} - */ - public void waitForAuthorizationToComplete(TimeValue waitTime) { - try { - if (firstAuthorizationCompletedLatch.await(waitTime.getSeconds(), TimeUnit.SECONDS) == false) { - throw new IllegalStateException("The wait time has expired for authorization to complete."); - } - } catch (InterruptedException e) { - throw new IllegalStateException("Waiting for authorization to complete was interrupted"); - } - } - - public synchronized Set supportedStreamingTasks() { - var authorizedStreamingTaskTypes = EnumSet.of(TaskType.CHAT_COMPLETION); - authorizedStreamingTaskTypes.retainAll(authorizedContent.get().taskTypesAndModels.getAuthorizedTaskTypes()); - - return authorizedStreamingTaskTypes; - } - - public synchronized List defaultConfigIds() { - return authorizedContent.get().configIds; - } - - public synchronized void defaultConfigs(ActionListener> defaultsListener) { - var models = authorizedContent.get().defaultModelConfigs.stream().map(DefaultModelConfig::model).toList(); - defaultsListener.onResponse(models); - } - - public synchronized EnumSet supportedTaskTypes() { - return authorizedContent.get().taskTypesAndModels.getAuthorizedTaskTypes(); - } - - public synchronized boolean hideFromConfigurationApi() { - return authorizedContent.get().taskTypesAndModels.isAuthorized() == false; - } - - @Override - public void close() throws IOException { - shutdown.set(true); - if (lastAuthTask.get() != null) { - lastAuthTask.get().cancel(); - } - } - - private void scheduleAuthorizationRequest() { - try { - if (elasticInferenceServiceSettings.isPeriodicAuthorizationEnabled() == false) { - return; - } - - // this call has to be on the individual thread otherwise we get an exception - var random = Randomness.get(); - var jitter = (long) (elasticInferenceServiceSettings.getMaxAuthorizationRequestJitter().millis() * random.nextDouble()); - var waitTime = TimeValue.timeValueMillis(elasticInferenceServiceSettings.getAuthRequestInterval().millis() + jitter); - - logger.debug( - () -> Strings.format( - 
"Scheduling the next authorization call with request interval: %s ms, jitter: %d ms", - elasticInferenceServiceSettings.getAuthRequestInterval().millis(), - jitter - ) - ); - logger.debug(() -> Strings.format("Next authorization call in %d minutes", waitTime.getMinutes())); - - lastAuthTask.set( - serviceComponents.threadPool() - .schedule( - this::scheduleAndSendAuthorizationRequest, - waitTime, - serviceComponents.threadPool().executor(UTILITY_THREAD_POOL_NAME) - ) - ); - } catch (Exception e) { - logger.warn("Failed scheduling authorization request", e); - } - } - - private void scheduleAndSendAuthorizationRequest() { - if (shutdown.get()) { - return; - } - - scheduleAuthorizationRequest(); - sendAuthorizationRequest(); - } - - private void sendAuthorizationRequest() { - try { - ActionListener listener = ActionListener.wrap((model) -> { - setAuthorizedContent(model); - if (callback != null) { - callback.run(); - } - }, e -> { - // we don't need to do anything if there was a failure, everything is disabled by default - firstAuthorizationCompletedLatch.countDown(); - }); - - authorizationHandler.getAuthorization(listener, sender); - } catch (Exception e) { - logger.warn("Failure while sending the request to retrieve authorization", e); - // we don't need to do anything if there was a failure, everything is disabled by default - firstAuthorizationCompletedLatch.countDown(); - } - } - - private synchronized void setAuthorizedContent(ElasticInferenceServiceAuthorizationModel auth) { - logger.debug(() -> Strings.format("Received authorization response, %s", auth)); - - var authorizedTaskTypesAndModels = auth.newLimitedToTaskTypes(EnumSet.copyOf(implementedTaskTypes)); - logger.debug(() -> Strings.format("Authorization entity limited to service task types, %s", authorizedTaskTypesAndModels)); - - // recalculate which default config ids and models are authorized now - var authorizedDefaultModelIds = getAuthorizedDefaultModelIds(authorizedTaskTypesAndModels); - - var authorizedDefaultConfigIds = getAuthorizedDefaultConfigIds(authorizedDefaultModelIds, authorizedTaskTypesAndModels); - var authorizedDefaultModelObjects = getAuthorizedDefaultModelsObjects(authorizedDefaultModelIds); - authorizedContent.set( - new AuthorizedContent(authorizedTaskTypesAndModels, authorizedDefaultConfigIds, authorizedDefaultModelObjects) - ); - - authorizedContent.get().configIds().forEach(modelRegistry::putDefaultIdIfAbsent); - handleRevokedDefaultConfigs(authorizedDefaultModelIds); - } - - private Set getAuthorizedDefaultModelIds(ElasticInferenceServiceAuthorizationModel auth) { - var authorizedModels = auth.getAuthorizedModelIds(); - var authorizedDefaultModelIds = new TreeSet<>(defaultModelsConfigs.keySet()); - authorizedDefaultModelIds.retainAll(authorizedModels); - - return authorizedDefaultModelIds; - } - - private List getAuthorizedDefaultConfigIds( - Set authorizedDefaultModelIds, - ElasticInferenceServiceAuthorizationModel auth - ) { - var authorizedConfigIds = new ArrayList(); - for (var id : authorizedDefaultModelIds) { - var modelConfig = defaultModelsConfigs.get(id); - if (modelConfig != null) { - if (auth.getAuthorizedTaskTypes().contains(modelConfig.model().getTaskType()) == false) { - logger.warn( - org.elasticsearch.common.Strings.format( - "The authorization response included the default model: %s, " - + "but did not authorize the assumed task type of the model: %s. 
Enabling model.", - id, - modelConfig.model().getTaskType() - ) - ); - } - authorizedConfigIds.add( - new InferenceService.DefaultConfigId( - modelConfig.model().getInferenceEntityId(), - modelConfig.settings(), - inferenceService - ) - ); - } - } - - authorizedConfigIds.sort(Comparator.comparing(InferenceService.DefaultConfigId::inferenceId)); - return authorizedConfigIds; - } - - private List getAuthorizedDefaultModelsObjects(Set authorizedDefaultModelIds) { - var authorizedModels = new ArrayList(); - for (var id : authorizedDefaultModelIds) { - var modelConfig = defaultModelsConfigs.get(id); - if (modelConfig != null) { - authorizedModels.add(modelConfig); - } - } - - authorizedModels.sort(Comparator.comparing(modelConfig -> modelConfig.model().getInferenceEntityId())); - return authorizedModels; - } - - private void handleRevokedDefaultConfigs(Set authorizedDefaultModelIds) { - // if a model was initially returned in the authorization response but is absent, then we'll assume authorization was revoked - var unauthorizedDefaultModelIds = new HashSet<>(defaultModelsConfigs.keySet()); - unauthorizedDefaultModelIds.removeAll(authorizedDefaultModelIds); - - // get all the default inference endpoint ids for the unauthorized model ids - var unauthorizedDefaultInferenceEndpointIds = unauthorizedDefaultModelIds.stream() - .map(defaultModelsConfigs::get) // get all the model configs - .filter(Objects::nonNull) // limit to only non-null - .map(modelConfig -> modelConfig.model().getInferenceEntityId()) // get the inference ids - .collect(Collectors.toSet()); - - var deleteInferenceEndpointsListener = ActionListener.wrap(result -> { - logger.debug(Strings.format("Successfully revoked access to default inference endpoint IDs: %s", unauthorizedDefaultModelIds)); - firstAuthorizationCompletedLatch.countDown(); - }, e -> { - logger.warn( - Strings.format("Failed to revoke access to default inference endpoint IDs: %s, error: %s", unauthorizedDefaultModelIds, e) - ); - firstAuthorizationCompletedLatch.countDown(); - }); - - logger.debug( - () -> Strings.format( - "Synchronizing default inference endpoints, attempting to remove ids: %s", - unauthorizedDefaultInferenceEndpointIds - ) - ); - modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java index 02800105ef83d..8f39f1261ee3e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java @@ -77,9 +77,9 @@ public void getAuthorization(ActionListener preconfiguredEndpoints) { + private static final Logger logger = LogManager.getLogger(PreconfiguredEndpointsModel.class); + + public static PreconfiguredEndpointsModel of(ElasticInferenceServiceAuthorizationModel authModel) { + var endpoints = authModel.getAuthorizedModelIds() + .stream() + .filter(ElasticInferenceServiceMinimalSettings::containsModelName) + .map((modelId) -> of(ElasticInferenceServiceMinimalSettings.getWithModelName(modelId))) + 
.filter(Objects::nonNull) + .collect(Collectors.toMap(PreconfiguredEndpoint::inferenceEntityId, Function.identity())); + + return new PreconfiguredEndpointsModel(endpoints); + } + + private static PreconfiguredEndpoint of(ElasticInferenceServiceMinimalSettings.SettingsWithEndpointInfo settings) { + return switch (settings.minimalSettings().taskType()) { + case TEXT_EMBEDDING -> { + if (settings.minimalSettings().dimensions() == null + || settings.minimalSettings().similarity() == null + || settings.minimalSettings().elementType() == null) { + logger.warn( + "Skipping embedding endpoint [{}] as it is missing required settings. " + + "Dimensions: [{}], Similarity: [{}], Element Type: [{}]", + settings.inferenceId(), + settings.minimalSettings().dimensions(), + settings.minimalSettings().similarity(), + settings.minimalSettings().elementType() + ); + yield null; + } + + yield new EmbeddingPreConfiguredEndpoint( + settings.inferenceId(), + settings.minimalSettings().taskType(), + settings.modelId(), + settings.minimalSettings().similarity(), + settings.minimalSettings().dimensions(), + settings.minimalSettings().elementType() + ); + } + case SPARSE_EMBEDDING, RERANK, COMPLETION, CHAT_COMPLETION -> new BasePreconfiguredEndpoint( + settings.inferenceId(), + settings.minimalSettings().taskType(), + settings.modelId() + ); + case ANY -> null; + }; + } + + public sealed interface PreconfiguredEndpoint permits BasePreconfiguredEndpoint, EmbeddingPreConfiguredEndpoint { + String inferenceEntityId(); + + TaskType taskType(); + + String modelId(); + + UnparsedModel toUnparsedModel(); + } + + private record EmbeddingPreConfiguredEndpoint( + String inferenceEntityId, + TaskType taskType, + String modelId, + SimilarityMeasure similarity, + int dimension, + DenseVectorFieldMapper.ElementType elementType + ) implements PreconfiguredEndpoint { + + @Override + public UnparsedModel toUnparsedModel() { + return new UnparsedModel( + inferenceEntityId, + taskType, + ElasticInferenceService.NAME, + embeddingSettings(modelId, similarity, dimension, elementType), + Map.of() + ); + } + } + + private static Map embeddingSettings( + String modelId, + SimilarityMeasure similarityMeasure, + int dimension, + DenseVectorFieldMapper.ElementType elementType + ) { + return wrapWithServiceSettings( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + modelId, + ServiceFields.SIMILARITY, + similarityMeasure.toString(), + ServiceFields.DIMENSIONS, + dimension, + ServiceFields.ELEMENT_TYPE, + elementType.toString() + ) + ) + ); + } + + private static Map wrapWithServiceSettings(Map settings) { + return new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, settings)); + } + + private record BasePreconfiguredEndpoint(String inferenceEntityId, TaskType taskType, String modelId) implements PreconfiguredEndpoint { + @Override + public UnparsedModel toUnparsedModel() { + return new UnparsedModel(inferenceEntityId, taskType, ElasticInferenceService.NAME, settingsWithModelId(modelId), Map.of()); + } + } + + private static Map settingsWithModelId(String modelId) { + return wrapWithServiceSettings(new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId))); + } + + public UnparsedModel toUnparsedModel(String inferenceId) { + PreconfiguredEndpoint endpoint = preconfiguredEndpoints.get(inferenceId); + if (endpoint == null) { + throw new IllegalArgumentException( + Strings.format( + "No Elastic Inference Service preconfigured endpoint found for inference ID [%s]. 
" + + "Either it does not exist, or you are not authorized to access it.", + inferenceId + ) + ); + } + + return endpoint.toUnparsedModel(); + } + + public List toUnparsedModels() { + return preconfiguredEndpoints.values() + .stream() + .map(PreconfiguredEndpoint::toUnparsedModel) + .sorted(Comparator.comparing(UnparsedModel::inferenceEntityId)) + .toList(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsRequestHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsRequestHandler.java new file mode 100644 index 0000000000000..1e347049d630b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointsRequestHandler.java @@ -0,0 +1,48 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.authorization; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.SubscribableListener; +import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; + +import java.util.List; +import java.util.Objects; + +/** + * This class is responsible for converting the current EIS authorization response structure + * into {@link UnparsedModel}. + */ +public class PreconfiguredEndpointsRequestHandler { + private final ElasticInferenceServiceAuthorizationRequestHandler eisAuthorizationRequestHandler; + private final Sender sender; + + public PreconfiguredEndpointsRequestHandler( + ElasticInferenceServiceAuthorizationRequestHandler eisAuthorizationRequestHandler, + Sender sender + ) { + this.eisAuthorizationRequestHandler = Objects.requireNonNull(eisAuthorizationRequestHandler); + this.sender = Objects.requireNonNull(sender); + } + + public void getPreconfiguredEndpointAsUnparsedModel(String inferenceId, ActionListener listener) { + SubscribableListener.newForked(authListener -> { + eisAuthorizationRequestHandler.getAuthorization(authListener, sender); + }) + .andThenApply(PreconfiguredEndpointsModel::of) + .andThenApply(preconfiguredEndpointsModel -> preconfiguredEndpointsModel.toUnparsedModel(inferenceId)) + .addListener(listener); + } + + public void getAllPreconfiguredEndpointsAsUnparsedModels(ActionListener> listener) { + SubscribableListener.newForked(authListener -> { + eisAuthorizationRequestHandler.getAuthorization(authListener, sender); + }).andThenApply(PreconfiguredEndpointsModel::of).andThenApply(PreconfiguredEndpointsModel::toUnparsedModels).addListener(listener); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index f47d7c4c37261..8d0c1b0533b98 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -83,6 +83,7 @@ import org.elasticsearch.xpack.inference.InferencePlugin; import 
org.elasticsearch.xpack.inference.model.TestModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.services.elastic.authorization.PreconfiguredEndpointsRequestHandler; import org.junit.After; import org.junit.AssumptionViolatedException; import org.junit.Before; @@ -120,6 +121,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; @@ -138,7 +140,7 @@ public SemanticTextFieldMapperTests(boolean useLegacyFormat) { private void startThreadPool() { threadPool = createThreadPool(); var clusterService = ClusterServiceUtils.createClusterService(threadPool); - var modelRegistry = new ModelRegistry(clusterService, new NoOpClient(threadPool)); + var modelRegistry = new ModelRegistry(clusterService, new NoOpClient(threadPool), mock(PreconfiguredEndpointsRequestHandler.class)); globalModelRegistry = spy(modelRegistry); globalModelRegistry.clusterChanged(new ClusterChangedEvent("init", clusterService.state(), clusterService.state()) { @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/AbstractInterceptedInferenceQueryBuilderTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/AbstractInterceptedInferenceQueryBuilderTestCase.java index 169ae6767303d..4c370ba409736 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/AbstractInterceptedInferenceQueryBuilderTestCase.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/AbstractInterceptedInferenceQueryBuilderTestCase.java @@ -56,6 +56,7 @@ import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.services.elastic.authorization.PreconfiguredEndpointsRequestHandler; import org.junit.AfterClass; import org.junit.BeforeClass; @@ -75,6 +76,7 @@ import static org.hamcrest.Matchers.instanceOf; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; public abstract class AbstractInterceptedInferenceQueryBuilderTestCase> extends MapperServiceTestCase { @@ -605,7 +607,9 @@ protected static void disableQueryInterception(QueryRewriteContext queryRewriteC private static ModelRegistry createModelRegistry(ThreadPool threadPool) { ClusterService clusterService = ClusterServiceUtils.createClusterService(threadPool); - ModelRegistry modelRegistry = spy(new ModelRegistry(clusterService, new NoOpClient(threadPool))); + ModelRegistry modelRegistry = spy( + new ModelRegistry(clusterService, new NoOpClient(threadPool), mock(PreconfiguredEndpointsRequestHandler.class)) + ); modelRegistry.clusterChanged(new ClusterChangedEvent("init", clusterService.state(), clusterService.state()) { @Override public boolean localNodeMaster() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticMultiMatchQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticMultiMatchQueryBuilderTests.java index b54ca946e6179..cce0f8070f7c5 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticMultiMatchQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticMultiMatchQueryBuilderTests.java @@ -25,6 +25,7 @@ import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.services.elastic.authorization.PreconfiguredEndpointsRequestHandler; import org.junit.AfterClass; import org.junit.BeforeClass; @@ -32,6 +33,8 @@ import java.util.List; import java.util.function.Supplier; +import static org.mockito.Mockito.mock; + public class SemanticMultiMatchQueryBuilderTests extends MapperServiceTestCase { private static TestThreadPool threadPool; private static ModelRegistry modelRegistry; @@ -51,7 +54,7 @@ protected Supplier getModelRegistry() { public static void startModelRegistry() { threadPool = new TestThreadPool(SemanticMultiMatchQueryBuilderTests.class.getName()); var clusterService = ClusterServiceUtils.createClusterService(threadPool); - modelRegistry = new ModelRegistry(clusterService, new NoOpClient(threadPool)); + modelRegistry = new ModelRegistry(clusterService, new NoOpClient(threadPool), mock(PreconfiguredEndpointsRequestHandler.class)); modelRegistry.clusterChanged(new ClusterChangedEvent("init", clusterService.state(), clusterService.state()) { @Override public boolean localNodeMaster() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java index b2d7218720a57..4e9ccf7607c81 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java @@ -75,6 +75,7 @@ import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.services.elastic.authorization.PreconfiguredEndpointsRequestHandler; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; @@ -100,6 +101,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.notNullValue; +import static org.mockito.Mockito.mock; public class SemanticQueryBuilderTests extends AbstractQueryTestCase { private static final String SEMANTIC_TEXT_FIELD = "semantic"; @@ -144,7 +146,7 @@ public static void setInferenceResultType() { public static void startModelRegistry() { threadPool = new TestThreadPool(SemanticQueryBuilderTests.class.getName()); var clusterService = ClusterServiceUtils.createClusterService(threadPool); - modelRegistry = new ModelRegistry(clusterService, new NoOpClient(threadPool)); + modelRegistry = new ModelRegistry(clusterService, new NoOpClient(threadPool), mock(PreconfiguredEndpointsRequestHandler.class)); modelRegistry.clusterChanged(new ClusterChangedEvent("init", clusterService.state(), clusterService.state()) { @Override public boolean localNodeMaster() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistryTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistryTests.java index b172f0e264c79..ad2036a40c06d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistryTests.java @@ -24,8 +24,8 @@ import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.inference.registry.ModelRegistryTests.assertStoreModel; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.sameInstance; public class InferenceEndpointRegistryTests extends ESSingleNodeTestCase { @@ -48,7 +48,7 @@ public void createComponents() { public void testGetThrowsResourceNotFoundWhenNoHitsReturned() { assertThat( getEndpointException("this is not found", ResourceNotFoundException.class).getMessage(), - is("Inference endpoint not found [this is not found]") + containsString("Inference endpoint [this is not found] not found ") ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java index eee8550ec6524..5af7a3fd8634c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java @@ -25,6 +25,8 @@ import org.elasticsearch.test.ESSingleNodeTestCase; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.model.TestModel; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceMinimalSettings; import org.junit.Before; import java.util.ArrayList; @@ -65,7 +67,7 @@ public void testGetUnparsedModelMap_ThrowsResourceNotFound_WhenNoHitsReturned() registry.getModelWithSecrets("1", listener); ResourceNotFoundException exception = expectThrows(ResourceNotFoundException.class, () -> listener.actionGet(TIMEOUT)); - assertThat(exception.getMessage(), is("Inference endpoint not found [1]")); + assertThat(exception.getMessage(), containsString("Inference endpoint [1] not found")); } public void testGetModelWithSecrets() { @@ -237,6 +239,42 @@ public void testDuplicateDefaultIds() { ); } + public void testGetMinimalServiceSettings_ThrowsResourceNotFound_WhenInferenceIdDoesNotExist() { + var exception = expectThrows(ResourceNotFoundException.class, () -> registry.getMinimalServiceSettings("non_existent_id")); + assertThat(exception.getMessage(), containsString("non_existent_id does not exist in this cluster.")); + } + + public void testGetMinimalServiceSettings_ReturnsEisPreconfiguredEndpoint() { + { + var minimalSettings = registry.getMinimalServiceSettings( + ElasticInferenceServiceMinimalSettings.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 + ); + + assertThat(minimalSettings.service(), is(ElasticInferenceService.NAME)); + assertThat(minimalSettings.taskType(), is(TaskType.CHAT_COMPLETION)); + } + { + var minimalSettings = registry.getMinimalServiceSettings(ElasticInferenceServiceMinimalSettings.DEFAULT_ELSER_ENDPOINT_ID_V2); + + assertThat(minimalSettings.service(), 
is(ElasticInferenceService.NAME)); + assertThat(minimalSettings.taskType(), is(TaskType.SPARSE_EMBEDDING)); + } + { + var minimalSettings = registry.getMinimalServiceSettings( + ElasticInferenceServiceMinimalSettings.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID + ); + + assertThat(minimalSettings.service(), is(ElasticInferenceService.NAME)); + assertThat(minimalSettings.taskType(), is(TaskType.TEXT_EMBEDDING)); + } + { + var minimalSettings = registry.getMinimalServiceSettings(ElasticInferenceServiceMinimalSettings.DEFAULT_RERANK_ENDPOINT_ID_V1); + + assertThat(minimalSettings.service(), is(ElasticInferenceService.NAME)); + assertThat(minimalSettings.taskType(), is(TaskType.RERANK)); + } + } + public static void assertStoreModel(ModelRegistry registry, Model model) { PlainActionFuture storeListener = new PlainActionFuture<>(); registry.storeModel(model, storeListener, TimeValue.THIRTY_SECONDS); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index b8a82d6a7a29c..7a91e694eb9ce 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -17,16 +17,13 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; -import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; @@ -921,8 +918,6 @@ public void testChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOExceptio public void testHideFromConfigurationApi_ThrowsUnsupported_WithNoAvailableModels() throws Exception { try (var service = createServiceWithMockSender(ElasticInferenceServiceAuthorizationModel.newDisabledService())) { - ensureAuthorizationCallFinished(service); - expectThrows(UnsupportedOperationException.class, service::hideFromConfigurationApi); } } @@ -942,119 +937,86 @@ public void testHideFromConfigurationApi_ThrowsUnsupported_WithAvailableModels() ) ) ) { - ensureAuthorizationCallFinished(service); - expectThrows(UnsupportedOperationException.class, service::hideFromConfigurationApi); } } public void testCreateConfiguration() throws Exception { - try ( - var service = createServiceWithMockSender( - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-1", - EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.TEXT_EMBEDDING) - ) - ) - ) - ) - ) - ) { - ensureAuthorizationCallFinished(service); - - String content = 
XContentHelper.stripWhitespace(""" - { - "service": "elastic", - "name": "Elastic", - "task_types": ["sparse_embedding", "chat_completion", "text_embedding"], - "configurations": { - "model_id": { - "description": "The name of the model to use for the inference task.", - "label": "Model ID", - "required": true, - "sensitive": false, - "updatable": false, - "type": "str", - "supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"] - }, - "max_input_tokens": { - "description": "Allows you to specify the maximum number of tokens per input.", - "label": "Maximum Input Tokens", - "required": false, - "sensitive": false, - "updatable": false, - "type": "int", - "supported_task_types": ["text_embedding", "sparse_embedding"] - } + String content = XContentHelper.stripWhitespace(""" + { + "service": "elastic", + "name": "Elastic", + "task_types": ["sparse_embedding", "chat_completion", "text_embedding"], + "configurations": { + "model_id": { + "description": "The name of the model to use for the inference task.", + "label": "Model ID", + "required": true, + "sensitive": false, + "updatable": false, + "type": "str", + "supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"] + }, + "max_input_tokens": { + "description": "Allows you to specify the maximum number of tokens per input.", + "label": "Maximum Input Tokens", + "required": false, + "sensitive": false, + "updatable": false, + "type": "int", + "supported_task_types": ["text_embedding", "sparse_embedding"] } } - """); - InferenceServiceConfiguration configuration = InferenceServiceConfiguration.fromXContentBytes( - new BytesArray(content), - XContentType.JSON - ); - boolean humanReadable = true; - BytesReference originalBytes = toShuffledXContent(configuration, XContentType.JSON, ToXContent.EMPTY_PARAMS, humanReadable); - InferenceServiceConfiguration serviceConfiguration = ElasticInferenceService.createConfiguration( - EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.TEXT_EMBEDDING) - ); - assertToXContentEquivalent( - originalBytes, - toXContent(serviceConfiguration, XContentType.JSON, humanReadable), - XContentType.JSON - ); - } + } + """); + InferenceServiceConfiguration configuration = InferenceServiceConfiguration.fromXContentBytes( + new BytesArray(content), + XContentType.JSON + ); + boolean humanReadable = true; + BytesReference originalBytes = toShuffledXContent(configuration, XContentType.JSON, ToXContent.EMPTY_PARAMS, humanReadable); + InferenceServiceConfiguration serviceConfiguration = ElasticInferenceService.createConfiguration( + EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.TEXT_EMBEDDING) + ); + assertToXContentEquivalent(originalBytes, toXContent(serviceConfiguration, XContentType.JSON, humanReadable), XContentType.JSON); } public void testGetConfiguration_WithoutSupportedTaskTypes() throws Exception { - try (var service = createServiceWithMockSender(ElasticInferenceServiceAuthorizationModel.newDisabledService())) { - ensureAuthorizationCallFinished(service); - - String content = XContentHelper.stripWhitespace(""" - { - "service": "elastic", - "name": "Elastic", - "task_types": [], - "configurations": { - "model_id": { - "description": "The name of the model to use for the inference task.", - "label": "Model ID", - "required": true, - "sensitive": false, - "updatable": false, - "type": "str", - "supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"] - }, - 
"max_input_tokens": { - "description": "Allows you to specify the maximum number of tokens per input.", - "label": "Maximum Input Tokens", - "required": false, - "sensitive": false, - "updatable": false, - "type": "int", - "supported_task_types": ["text_embedding", "sparse_embedding"] - } + String content = XContentHelper.stripWhitespace(""" + { + "service": "elastic", + "name": "Elastic", + "task_types": [], + "configurations": { + "model_id": { + "description": "The name of the model to use for the inference task.", + "label": "Model ID", + "required": true, + "sensitive": false, + "updatable": false, + "type": "str", + "supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"] + }, + "max_input_tokens": { + "description": "Allows you to specify the maximum number of tokens per input.", + "label": "Maximum Input Tokens", + "required": false, + "sensitive": false, + "updatable": false, + "type": "int", + "supported_task_types": ["text_embedding", "sparse_embedding"] } } - """); - InferenceServiceConfiguration configuration = InferenceServiceConfiguration.fromXContentBytes( - new BytesArray(content), - XContentType.JSON - ); - boolean humanReadable = true; - BytesReference originalBytes = toShuffledXContent(configuration, XContentType.JSON, ToXContent.EMPTY_PARAMS, humanReadable); - InferenceServiceConfiguration serviceConfiguration = ElasticInferenceService.createConfiguration( - EnumSet.noneOf(TaskType.class) - ); - assertToXContentEquivalent( - originalBytes, - toXContent(serviceConfiguration, XContentType.JSON, humanReadable), - XContentType.JSON - ); - } + } + """); + InferenceServiceConfiguration configuration = InferenceServiceConfiguration.fromXContentBytes( + new BytesArray(content), + XContentType.JSON + ); + boolean humanReadable = true; + BytesReference originalBytes = toShuffledXContent(configuration, XContentType.JSON, ToXContent.EMPTY_PARAMS, humanReadable); + InferenceServiceConfiguration serviceConfiguration = ElasticInferenceService.createConfiguration(EnumSet.noneOf(TaskType.class)); + assertToXContentEquivalent(originalBytes, toXContent(serviceConfiguration, XContentType.JSON, humanReadable), XContentType.JSON); } public void testGetConfiguration_ThrowsUnsupported() throws Exception { @@ -1073,8 +1035,6 @@ public void testGetConfiguration_ThrowsUnsupported() throws Exception { ) ) ) { - ensureAuthorizationCallFinished(service); - expectThrows(UnsupportedOperationException.class, service::getConfiguration); } } @@ -1095,8 +1055,6 @@ public void testSupportedStreamingTasks_ReturnsChatCompletion_WhenAuthRespondsWi var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { - ensureAuthorizationCallFinished(service); - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); assertTrue(service.defaultConfigIds().isEmpty()); @@ -1107,200 +1065,36 @@ public void testSupportedStreamingTasks_ReturnsChatCompletion_WhenAuthRespondsWi } } - public void testSupportedTaskTypes_Returns_TheAuthorizedTaskTypes_IgnoresUnimplementedTaskTypes() throws Exception { - String responseJson = """ - { - "models": [ - { - "model_name": "model-a", - "task_types": ["embed/text/sparse"] - }, - { - "model_name": "model-b", - "task_types": ["embed"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - + public void 
testSupportedTaskTypes_ThrowsUnsupportedException() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { - ensureAuthorizationCallFinished(service); - - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); + expectThrows(UnsupportedOperationException.class, service::supportedTaskTypes); } } - public void testSupportedTaskTypes_Returns_TheAuthorizedTaskTypes() throws Exception { - String responseJson = """ - { - "models": [ - { - "model_name": "model-a", - "task_types": ["embed/text/sparse"] - }, - { - "model_name": "model-b", - "task_types": ["chat"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - + public void testSupportedStreamingTasks_ReturnsChatCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { - ensureAuthorizationCallFinished(service); - - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))); - } - } - - public void testSupportedStreamingTasks_ReturnsEmpty_WhenAuthRespondsWithoutChatCompletion() throws Exception { - String responseJson = """ - { - "models": [ - { - "model_name": "model-a", - "task_types": ["embed/text/sparse"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { - ensureAuthorizationCallFinished(service); - - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertTrue(service.defaultConfigIds().isEmpty()); - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); - - PlainActionFuture> listener = new PlainActionFuture<>(); - service.defaultConfigs(listener); - assertTrue(listener.actionGet(TIMEOUT).isEmpty()); - } - } - - public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsIncorrect() throws Exception { - String responseJson = """ - { - "models": [ - { - "model_name": "rainbow-sprinkles", - "task_types": ["embed/text/sparse"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { - ensureAuthorizationCallFinished(service); assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertThat( - service.defaultConfigIds(), - is( - List.of( - new InferenceService.DefaultConfigId( - ".rainbow-sprinkles-elastic", - MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), - service - ) - ) - ) - ); - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); + assertThat(service.defaultConfigIds(), is(List.of())); + expectThrows(UnsupportedOperationException.class, service::supportedTaskTypes); PlainActionFuture> listener = new PlainActionFuture<>(); service.defaultConfigs(listener); - assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); + 
assertThat(listener.actionGet(TIMEOUT), is(List.of())); } } - public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect() throws Exception { - String responseJson = """ - { - "models": [ - { - "model_name": "rainbow-sprinkles", - "task_types": ["chat"] - }, - { - "model_name": "elser_model_2", - "task_types": ["embed/text/sparse"] - }, - { - "model_name": "multilingual-embed-v1", - "task_types": ["embed/text/dense"] - }, - { - "model_name": "rerank-v1", - "task_types": ["rerank/text/text-similarity"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - + public void testDefaultConfigs_ReturnsEmptyList() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { - ensureAuthorizationCallFinished(service); assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertFalse(service.canStream(TaskType.ANY)); - assertThat( - service.defaultConfigIds(), - is( - List.of( - new InferenceService.DefaultConfigId( - ".elser-2-elastic", - MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME), - service - ), - new InferenceService.DefaultConfigId( - ".multilingual-embed-v1-elastic", - MinimalServiceSettings.textEmbedding( - ElasticInferenceService.NAME, - ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, - ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(), - DenseVectorFieldMapper.ElementType.FLOAT - ), - service - ), - new InferenceService.DefaultConfigId( - ".rainbow-sprinkles-elastic", - MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), - service - ), - new InferenceService.DefaultConfigId( - ".rerank-v1-elastic", - MinimalServiceSettings.rerank(ElasticInferenceService.NAME), - service - ) - ) - ) - ); - assertThat( - service.supportedTaskTypes(), - is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK, TaskType.TEXT_EMBEDDING)) - ); + assertThat(service.defaultConfigIds(), is(List.of())); + expectThrows(UnsupportedOperationException.class, service::supportedTaskTypes); PlainActionFuture> listener = new PlainActionFuture<>(); service.defaultConfigs(listener); - var models = listener.actionGet(TIMEOUT); - assertThat(models.size(), is(4)); - assertThat(models.get(0).getConfigurations().getInferenceEntityId(), is(".elser-2-elastic")); - assertThat(models.get(1).getConfigurations().getInferenceEntityId(), is(".multilingual-embed-v1-elastic")); - assertThat(models.get(2).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); - assertThat(models.get(3).getConfigurations().getInferenceEntityId(), is(".rerank-v1-elastic")); + assertThat(listener.actionGet(TIMEOUT), is(List.of())); } } @@ -1392,11 +1186,6 @@ private InferenceEventsAssertion testUnifiedStream(int responseCode, String resp } } - private void ensureAuthorizationCallFinished(ElasticInferenceService service) { - service.onNodeStarted(); - service.waitForFirstAuthorizationToComplete(TIMEOUT); - } - private ElasticInferenceService createServiceWithMockSender() { return createServiceWithMockSender(ElasticInferenceServiceAuthorizationModelTests.createEnabledAuth()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java deleted file mode 100644 index e42430b6512f5..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java +++ /dev/null @@ -1,283 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.elastic.authorization; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.support.PlainActionFuture; -import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue; -import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.EmptySecretSettings; -import org.elasticsearch.inference.EmptyTaskSettings; -import org.elasticsearch.inference.InferenceService; -import org.elasticsearch.inference.MinimalServiceSettings; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.plugins.Plugin; -import org.elasticsearch.test.ESSingleNodeTestCase; -import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; -import org.elasticsearch.xpack.inference.Utils; -import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; -import org.elasticsearch.xpack.inference.external.http.sender.Sender; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.services.elastic.DefaultModelConfig; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettingsTests; -import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; -import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; -import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity; -import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; -import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; -import org.junit.Before; - -import java.io.IOException; -import java.util.Collection; -import java.util.EnumSet; -import java.util.List; -import java.util.Map; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; - -import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; -import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.defaultEndpointId; -import static org.hamcrest.CoreMatchers.is; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; - -public class ElasticInferenceServiceAuthorizationHandlerTests extends ESSingleNodeTestCase { - private DeterministicTaskQueue 
taskQueue; - private ModelRegistry modelRegistry; - - @Override - protected Collection> getPlugins() { - return List.of(LocalStateInferencePlugin.class); - } - - @Before - public void init() throws Exception { - taskQueue = new DeterministicTaskQueue(); - modelRegistry = getInstanceFromNode(ModelRegistry.class); - } - - public void testSecondAuthResultRevokesAuthorization() throws Exception { - var callbackCount = new AtomicInteger(0); - // we're only interested in two authorization calls which is why I'm using a value of 2 here - var latch = new CountDownLatch(2); - final AtomicReference handlerRef = new AtomicReference<>(); - - Runnable callback = () -> { - // the first authorization response contains a streaming task so we're expecting to support streaming here - if (callbackCount.incrementAndGet() == 1) { - assertThat(handlerRef.get().supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - } - latch.countDown(); - - // we only want to run the tasks twice, so advance the time on the queue - // which flags the scheduled authorization request to be ready to run - if (callbackCount.get() == 1) { - taskQueue.advanceTime(); - } else { - try { - handlerRef.get().close(); - } catch (IOException e) { - // ignore - } - } - }; - - var requestHandler = mockAuthorizationRequestHandler( - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "rainbow-sprinkles", - EnumSet.of(TaskType.CHAT_COMPLETION) - ) - ) - ) - ), - ElasticInferenceServiceAuthorizationModel.of(new ElasticInferenceServiceAuthorizationResponseEntity(List.of())) - ); - - handlerRef.set( - new ElasticInferenceServiceAuthorizationHandler( - createWithEmptySettings(taskQueue.getThreadPool()), - modelRegistry, - requestHandler, - initDefaultEndpoints(), - EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION), - null, - mock(Sender.class), - ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), - callback - ) - ); - - var handler = handlerRef.get(); - handler.init(); - taskQueue.runAllRunnableTasks(); - latch.await(Utils.TIMEOUT.getSeconds(), TimeUnit.SECONDS); - - // this should be after we've received both authorization responses, the second response will revoke authorization - - assertThat(handler.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class))); - assertThat(handler.defaultConfigIds(), is(List.of())); - assertThat(handler.supportedTaskTypes(), is(EnumSet.noneOf(TaskType.class))); - - PlainActionFuture> listener = new PlainActionFuture<>(); - handler.defaultConfigs(listener); - - var configs = listener.actionGet(); - assertThat(configs.size(), is(0)); - } - - public void testSendsAnAuthorizationRequestTwice() throws Exception { - var callbackCount = new AtomicInteger(0); - // we're only interested in two authorization calls which is why I'm using a value of 2 here - var latch = new CountDownLatch(2); - final AtomicReference handlerRef = new AtomicReference<>(); - - Runnable callback = () -> { - // the first authorization response does not contain a streaming task so we're expecting to not support streaming here - if (callbackCount.incrementAndGet() == 1) { - assertThat(handlerRef.get().supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class))); - } - latch.countDown(); - - // we only want to run the tasks twice, so advance the time on the queue - // which flags the scheduled authorization request to 
be ready to run - if (callbackCount.get() == 1) { - taskQueue.advanceTime(); - } else { - try { - handlerRef.get().close(); - } catch (IOException e) { - // ignore - } - } - }; - - var requestHandler = mockAuthorizationRequestHandler( - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel("abc", EnumSet.of(TaskType.SPARSE_EMBEDDING)) - ) - ) - ), - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "abc", - EnumSet.of(TaskType.SPARSE_EMBEDDING) - ), - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "rainbow-sprinkles", - EnumSet.of(TaskType.CHAT_COMPLETION) - ) - ) - ) - ) - ); - - handlerRef.set( - new ElasticInferenceServiceAuthorizationHandler( - createWithEmptySettings(taskQueue.getThreadPool()), - modelRegistry, - requestHandler, - initDefaultEndpoints(), - EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION), - null, - mock(Sender.class), - ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), - callback - ) - ); - - var handler = handlerRef.get(); - handler.init(); - taskQueue.runAllRunnableTasks(); - latch.await(Utils.TIMEOUT.getSeconds(), TimeUnit.SECONDS); - // this should be after we've received both authorization responses - - assertThat(handler.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertThat( - handler.defaultConfigIds(), - is( - List.of( - new InferenceService.DefaultConfigId( - ".rainbow-sprinkles-elastic", - MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), - null - ) - ) - ) - ); - assertThat(handler.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))); - - PlainActionFuture> listener = new PlainActionFuture<>(); - handler.defaultConfigs(listener); - - var configs = listener.actionGet(); - assertThat(configs.get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); - } - - private static ElasticInferenceServiceAuthorizationRequestHandler mockAuthorizationRequestHandler( - ElasticInferenceServiceAuthorizationModel firstAuthResponse, - ElasticInferenceServiceAuthorizationModel secondAuthResponse - ) { - var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onResponse(firstAuthResponse); - return Void.TYPE; - }).doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onResponse(secondAuthResponse); - return Void.TYPE; - }).when(mockAuthHandler).getAuthorization(any(), any()); - - return mockAuthHandler; - } - - private static Map initDefaultEndpoints() { - return Map.of( - "rainbow-sprinkles", - new DefaultModelConfig( - new ElasticInferenceServiceCompletionModel( - defaultEndpointId("rainbow-sprinkles"), - TaskType.CHAT_COMPLETION, - "test", - new ElasticInferenceServiceCompletionServiceSettings("rainbow-sprinkles"), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - ElasticInferenceServiceComponents.EMPTY_INSTANCE - ), - MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME) - ), - "elser-2", - new DefaultModelConfig( - new ElasticInferenceServiceSparseEmbeddingsModel( - defaultEndpointId("elser-2"), - 
TaskType.SPARSE_EMBEDDING, - "test", - new ElasticInferenceServiceSparseEmbeddingsServiceSettings("elser-2", null), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - ElasticInferenceServiceComponents.EMPTY_INSTANCE, - ChunkingSettingsBuilder.DEFAULT_SETTINGS - ), - MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME) - ) - ); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java index e3d24ea2ec8f7..6657bb71f6848 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java @@ -80,10 +80,8 @@ public void testDoesNotAttempt_ToRetrieveAuthorization_IfBaseUrlIsNull() throws PlainActionFuture listener = new PlainActionFuture<>(); authHandler.getAuthorization(listener, sender); - var authResponse = listener.actionGet(TIMEOUT); - assertTrue(authResponse.getAuthorizedTaskTypes().isEmpty()); - assertTrue(authResponse.getAuthorizedModelIds().isEmpty()); - assertFalse(authResponse.isAuthorized()); + var exception = expectThrows(IllegalStateException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(exception.getMessage(), containsString("The Elastic Inference Service URL is not configured.")); var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); verify(logger, times(2)).debug(loggerArgsCaptor.capture()); @@ -102,10 +100,8 @@ public void testDoesNotAttempt_ToRetrieveAuthorization_IfBaseUrlIsEmpty() throws PlainActionFuture listener = new PlainActionFuture<>(); authHandler.getAuthorization(listener, sender); - var authResponse = listener.actionGet(TIMEOUT); - assertTrue(authResponse.getAuthorizedTaskTypes().isEmpty()); - assertTrue(authResponse.getAuthorizedModelIds().isEmpty()); - assertFalse(authResponse.isAuthorized()); + var exception = expectThrows(IllegalStateException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(exception.getMessage(), containsString("The Elastic Inference Service URL is not configured.")); var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); verify(logger, times(2)).debug(loggerArgsCaptor.capture()); diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/40_semantic_text_query.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/40_semantic_text_query.yml index 0b1a611bcdf72..a528ac9090168 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/40_semantic_text_query.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/40_semantic_text_query.yml @@ -781,7 +781,10 @@ setup: - match: { hits.total.value: 0 } --- -"Query a field with an invalid inference ID": +"Query a field with an invalid inference ID v2": + - skip: + reason: "contains is a newly added assertion" + features: contains - do: indices.create: index: test-index-with-invalid-inference-id @@ -803,7 +806,7 @@ setup: query: "inference test" - match: { error.type: "resource_not_found_exception" } - - match: { error.reason: "Inference endpoint not 
found [invalid-inference-id]" } + - contains: { error.reason: "Inference endpoint [invalid-inference-id] not found" } --- "Query a field with a search inference ID that uses the wrong task type": @@ -896,7 +899,10 @@ setup: compatible with the inference endpoint [dense-inference-id]?" } --- -"Query a field with an invalid search inference ID": +"Query a field with an invalid search inference ID v2": + - skip: + reason: "contains is a newly added assertion" + features: contains - do: indices.put_mapping: index: test-dense-index @@ -927,7 +933,7 @@ setup: query: "inference test" - match: { error.type: "resource_not_found_exception" } - - match: { error.reason: "Inference endpoint not found [invalid-inference-id]" } + - contains: { error.reason: "Inference endpoint [invalid-inference-id] not found" } --- "Query a field that uses the default ELSER 2 endpoint": diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml index d971aad2bbc4b..261cf92c3553f 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml @@ -181,9 +181,9 @@ setup: - match: { hits.hits.0._id: "doc_1" } --- -"Text similarity reranking fails if the inference ID does not exist": +"Text similarity reranking fails if the inference ID does not exist v2": - do: - catch: /Inference endpoint not found/ + catch: /Inference endpoint \[.*?\] not found/ search: index: test-index body: @@ -206,13 +206,13 @@ setup: size: 10 --- -"Text similarity reranking fails if the inference ID does not exist and result set is empty": +"Text similarity reranking fails if the inference ID does not exist and result set is empty v2": - requires: cluster_features: "gte_v8.15.1" reason: bug fixed in 8.15.1 - do: - catch: /Inference endpoint not found/ + catch: /Inference endpoint \[.*?\] not found/ search: index: test-index body: diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml index 62a49422079b8..5444d70539a7b 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml @@ -1,11 +1,14 @@ --- -"Test get missing model": +"Test get missing model v2": + - skip: + reason: "contains is a newly added assertion" + features: contains - do: catch: missing inference.get: inference_id: inference_to_get - match: { error.type: "resource_not_found_exception" } - - match: { error.reason: "Inference endpoint not found [inference_to_get]" } + - contains: { error.reason: "Inference endpoint [inference_to_get] not found" } --- "Test put inference with bad task type":
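For readers following the authorization changes above, here is a minimal, self-contained sketch of the bookkeeping performed by handleRevokedDefaultConfigs: any default model that was previously known but is absent from the latest authorization response is treated as revoked, and its default endpoint ID is collected for removal. The ModelConfig record and the literal IDs below are simplified stand-ins for the plugin's DefaultModelConfig and registry types, not the production classes.

import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

public class RevokedDefaultsSketch {
    // Illustrative stand-in for DefaultModelConfig: a model name and its default endpoint ID.
    record ModelConfig(String modelName, String inferenceEntityId) {}

    public static void main(String[] args) {
        // Defaults that were authorized at some earlier point.
        Map<String, ModelConfig> defaultModelsConfigs = Map.of(
            "rainbow-sprinkles", new ModelConfig("rainbow-sprinkles", ".rainbow-sprinkles-elastic"),
            "elser-2", new ModelConfig("elser-2", ".elser-2-elastic")
        );

        // The latest authorization response only lists elser-2, so rainbow-sprinkles is treated as revoked.
        Set<String> authorizedDefaultModelIds = Set.of("elser-2");

        var unauthorizedDefaultModelIds = new HashSet<>(defaultModelsConfigs.keySet());
        unauthorizedDefaultModelIds.removeAll(authorizedDefaultModelIds);

        // Map the revoked model names to the default inference endpoint IDs that should be removed.
        var endpointIdsToRemove = unauthorizedDefaultModelIds.stream()
            .map(defaultModelsConfigs::get)
            .filter(Objects::nonNull)
            .map(ModelConfig::inferenceEntityId)
            .collect(Collectors.toSet());

        System.out.println(endpointIdsToRemove); // prints [.rainbow-sprinkles-elastic]
    }
}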
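The new PreconfiguredEndpointsModel.of method follows a filter-then-map pattern over the authorized model names: drop names with no preconfigured endpoint, build an endpoint for each remaining name (which can be null when required embedding settings are missing), and key the result by inference entity ID. A simplified sketch of that flow, using a plain map in place of ElasticInferenceServiceMinimalSettings and an illustrative Endpoint record in place of PreconfiguredEndpoint (both are stand-ins, not the production types):

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;

public class PreconfiguredEndpointsSketch {
    // Illustrative stand-in for PreconfiguredEndpoint: just an endpoint ID and a task type label.
    record Endpoint(String inferenceEntityId, String taskType) {}

    // Known model name -> endpoint mapping, standing in for the minimal-settings lookup.
    private static final Map<String, Endpoint> KNOWN = Map.of(
        "elser_model_2", new Endpoint(".elser-2-elastic", "sparse_embedding"),
        "rainbow-sprinkles", new Endpoint(".rainbow-sprinkles-elastic", "chat_completion")
    );

    static Map<String, Endpoint> of(List<String> authorizedModelIds) {
        return authorizedModelIds.stream()
            .filter(KNOWN::containsKey)      // ignore model names with no preconfigured endpoint
            .map(KNOWN::get)                 // in the real code this step may yield null for incomplete settings
            .filter(Objects::nonNull)
            .collect(Collectors.toMap(Endpoint::inferenceEntityId, Function.identity()));
    }

    public static void main(String[] args) {
        // "some-unknown-model" is dropped because no preconfigured endpoint exists for it.
        System.out.println(of(List.of("elser_model_2", "some-unknown-model")).keySet()); // prints [.elser-2-elastic]
    }
}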
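Most of the test edits in this patch replace exact-match assertions on "Inference endpoint not found [id]" with substring checks against the new "Inference endpoint [id] not found" wording. A small Hamcrest sketch of that assertion pattern; the message suffix below is invented purely for illustration and the real service message may differ:

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;

public class NotFoundMessageSketch {
    public static void main(String[] args) {
        // Hypothetical message: the real error reason starts with this prefix but may append more detail.
        String message = "Inference endpoint [invalid-inference-id] not found (illustrative suffix)";

        // containsString keeps the assertion stable if extra context is appended to the error reason.
        assertThat(message, containsString("Inference endpoint [invalid-inference-id] not found"));
    }
}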