diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java index e9ba9923fdcf8..ecc3bcd508bb6 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java @@ -23,6 +23,15 @@ public class InferenceGetModelsWithElasticInferenceServiceIT extends BaseMockEISAuthServerTest { + /** + * This is done before the class because another class that extends {@link BaseMockEISAuthServerTest} can result + * in an authorization response not being queued up for the new Elasticsearch node in time. When the node starts up, it + * retrieves authorization. If the response isn't queued up when that happens, the tests will fail. Local testing suggests + * that the base class's static functionality to queue a response is only done once and not for each subclass. + * + * A method annotated with @Before appears to run after the node starts up, so it wouldn't be sufficient to handle + * this scenario. That is why this needs to be @BeforeClass. 
+ */ @BeforeClass public static void init() { // Ensure the mock EIS server has an authorized response ready diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index 21e933292a0ed..4bf874674df6d 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -12,6 +12,7 @@ import org.elasticsearch.client.Request; import org.elasticsearch.common.Strings; import org.elasticsearch.inference.TaskType; +import org.junit.Before; import org.junit.BeforeClass; import java.io.IOException; @@ -23,6 +24,23 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest { + @Before + public void setUp() throws Exception { + super.setUp(); + // Ensure the mock EIS server has an authorized response ready before each test because each test will + // use the services API which makes a call to EIS + mockEISServer.enqueueAuthorizeAllModelsResponse(); + } + + /** + * This is done before the class because another class that extends {@link BaseMockEISAuthServerTest} can result + * in an authorization response not being queued up for the new Elasticsearch node in time. When the node starts up, it + * retrieves authorization. If the response isn't queued up when that happens, the tests will fail. Local testing suggests + * that the base class's static functionality to queue a response is only done once and not for each subclass. + * + * A method annotated with @Before appears to run after the node starts up, so it wouldn't be sufficient to handle + * this scenario. 
That is why this needs to be @BeforeClass. + */ @BeforeClass public static void init() { // Ensure the mock EIS server has an authorized response ready diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestCompletionServiceExtension.java index 9c15ac77cc13f..728c39b634bd0 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestCompletionServiceExtension.java @@ -178,6 +178,7 @@ public static InferenceServiceConfiguration get() { ); return new InferenceServiceConfiguration.Builder().setService(NAME) + .setName(NAME) .setTaskTypes(SUPPORTED_TASK_TYPES) .setConfigurations(configurationMap) .build(); diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index 044af0ab1d37d..051b6dbf3e8fa 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -283,6 +283,7 @@ public static InferenceServiceConfiguration get() { ); return new InferenceServiceConfiguration.Builder().setService(NAME) + .setName(NAME) .setTaskTypes(supportedTaskTypes) .setConfigurations(configurationMap) .build(); diff --git 
a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java index c1cf64b9f2ae8..962fc9e1ee818 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java @@ -228,6 +228,7 @@ public static InferenceServiceConfiguration get() { ); return new InferenceServiceConfiguration.Builder().setService(NAME) + .setName(NAME) .setTaskTypes(supportedTaskTypes) .setConfigurations(configurationMap) .build(); diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index 59a53fae137ee..86dcb56fa369d 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -224,6 +224,7 @@ public static InferenceServiceConfiguration get() { ); return new InferenceServiceConfiguration.Builder().setService(NAME) + .setName(NAME) .setTaskTypes(supportedTaskTypes) .setConfigurations(configurationMap) .build(); diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java 
b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index f70e0884879ea..28a191a1bbfac 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -326,6 +326,7 @@ public static InferenceServiceConfiguration get() { ); return new InferenceServiceConfiguration.Builder().setService(NAME) + .setName(NAME) .setTaskTypes(supportedTaskTypes) .setConfigurations(configurationMap) .build(); diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java index 182239359c889..72109e43bb6ac 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java @@ -176,7 +176,7 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty() try (var service = createElasticInferenceService()) { ensureAuthorizationCallFinished(service); - assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class))); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); assertTrue(service.defaultConfigIds().isEmpty()); assertThat(service.supportedTaskTypes(), is(EnumSet.noneOf(TaskType.class))); @@ -299,7 +299,7 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA try (var service = 
createElasticInferenceService()) { ensureAuthorizationCallFinished(service); - assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class))); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); assertThat( service.defaultConfigIds(), containsInAnyOrder( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 70e748c526a53..e4d66b92d5274 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -88,6 +88,7 @@ import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.highlight.SemanticTextHighlighter; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.mapper.OffsetSourceFieldMapper; @@ -382,6 +383,8 @@ public Collection createComponents(PluginServices services) { new TransportGetInferenceDiagnosticsAction.ClientManagers(httpClientManager, elasticInferenceServiceHttpClientManager) ); components.add(inferenceStatsBinding); + components.add(authorizationHandler); + components.add(new PluginComponentBinding<>(Sender.class, elasicInferenceServiceFactory.get().createSender())); components.add( new InferenceEndpointRegistry( services.clusterService(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java index 0d8d7e81019a6..18c83df4067ed 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java @@ -7,18 +7,27 @@ package org.elasticsearch.xpack.inference.action; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.TaskType; import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.GetInferenceServicesAction; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; +import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationModel; +import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; import java.util.ArrayList; import java.util.Comparator; @@ -26,17 +35,27 @@ import java.util.Map; import java.util.stream.Collectors; +import static 
org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; + public class TransportGetInferenceServicesAction extends HandledTransportAction< GetInferenceServicesAction.Request, GetInferenceServicesAction.Response> { + private static final Logger logger = LogManager.getLogger(TransportGetInferenceServicesAction.class); + private final InferenceServiceRegistry serviceRegistry; + private final ElasticInferenceServiceAuthorizationRequestHandler eisAuthorizationRequestHandler; + private final Sender eisSender; + private final ThreadPool threadPool; @Inject public TransportGetInferenceServicesAction( TransportService transportService, ActionFilters actionFilters, - InferenceServiceRegistry serviceRegistry + ThreadPool threadPool, + InferenceServiceRegistry serviceRegistry, + ElasticInferenceServiceAuthorizationRequestHandler eisAuthorizationRequestHandler, + Sender sender ) { super( GetInferenceServicesAction.NAME, @@ -46,6 +65,9 @@ public TransportGetInferenceServicesAction( EsExecutors.DIRECT_EXECUTOR_SERVICE ); this.serviceRegistry = serviceRegistry; + this.eisAuthorizationRequestHandler = eisAuthorizationRequestHandler; + this.eisSender = sender; + this.threadPool = threadPool; } @Override @@ -69,41 +91,86 @@ private void getServiceConfigurationsForTaskType( .entrySet() .stream() .filter( - service -> service.getValue().hideFromConfigurationApi() == false + // Exclude EIS as the EIS specific configurations must be retrieved directly from EIS and merged in later + service -> service.getValue().name().equals(ElasticInferenceService.NAME) == false + && service.getValue().hideFromConfigurationApi() == false && service.getValue().supportedTaskTypes().contains(requestedTaskType) ) .sorted(Comparator.comparing(service -> service.getValue().name())) .collect(Collectors.toCollection(ArrayList::new)); - getServiceConfigurationsForServices(filteredServices, listener.delegateFailureAndWrap((delegate, configurations) -> { - delegate.onResponse(new 
GetInferenceServicesAction.Response(configurations)); - })); + getServiceConfigurationsForServicesAndEis(listener, filteredServices, requestedTaskType); } private void getAllServiceConfigurations(ActionListener listener) { var availableServices = serviceRegistry.getServices() .entrySet() .stream() - .filter(service -> service.getValue().hideFromConfigurationApi() == false) + .filter( + // Exclude EIS as the EIS specific configurations must be retrieved directly from EIS and merged in later + service -> service.getValue().name().equals(ElasticInferenceService.NAME) == false + && service.getValue().hideFromConfigurationApi() == false + ) .sorted(Comparator.comparing(service -> service.getValue().name())) .collect(Collectors.toCollection(ArrayList::new)); - getServiceConfigurationsForServices(availableServices, listener.delegateFailureAndWrap((delegate, configurations) -> { - delegate.onResponse(new GetInferenceServicesAction.Response(configurations)); - })); + + getServiceConfigurationsForServicesAndEis(listener, availableServices, null); } - private void getServiceConfigurationsForServices( - ArrayList> services, - ActionListener> listener + private void getServiceConfigurationsForServicesAndEis( + ActionListener listener, + ArrayList> availableServices, + @Nullable TaskType requestedTaskType ) { - try { - var serviceConfigurations = new ArrayList(); - for (var service : services) { - serviceConfigurations.add(service.getValue().getConfiguration()); + SubscribableListener.newForked(authModelListener -> { + // Executing on a separate thread because there's a chance the authorization call needs to do some initialization for the Sender + threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> getEisAuthorization(authModelListener, eisSender)); + }).>andThen((configurationListener, authorizationModel) -> { + var serviceConfigs = getServiceConfigurationsForServices(availableServices); + + if (authorizationModel.isAuthorized() == false) { + 
configurationListener.onResponse(serviceConfigs); + return; + } + + var config = ElasticInferenceService.createConfiguration(authorizationModel.getAuthorizedTaskTypes()); + if (requestedTaskType != null && authorizationModel.getAuthorizedTaskTypes().contains(requestedTaskType) == false) { + configurationListener.onResponse(serviceConfigs); + return; } - listener.onResponse(serviceConfigurations.stream().toList()); - } catch (Exception e) { - listener.onFailure(e); + + serviceConfigs.add(config); + serviceConfigs.sort(Comparator.comparing(InferenceServiceConfiguration::getService)); + configurationListener.onResponse(serviceConfigs); + }) + .addListener( + listener.delegateFailureAndWrap( + (delegate, configurations) -> delegate.onResponse(new GetInferenceServicesAction.Response(configurations)) + ) + ); + } + + private void getEisAuthorization(ActionListener listener, Sender sender) { + var disabledServiceListener = listener.delegateResponse((delegate, e) -> { + logger.warn( + "Failed to retrieve authorization information from the " + + "Elastic Inference Service while determining service configurations. 
Marking service as disabled.", + e + ); + delegate.onResponse(ElasticInferenceServiceAuthorizationModel.newDisabledService()); + }); + + eisAuthorizationRequestHandler.getAuthorization(disabledServiceListener, sender); + } + + private List getServiceConfigurationsForServices( + ArrayList> services + ) { + var serviceConfigurations = new ArrayList(); + for (var service : services) { + serviceConfigurations.add(service.getValue().getConfiguration()); } + + return serviceConfigurations; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 69ae769e36dc4..a0cb0f7ae1249 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -14,7 +14,6 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; -import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; @@ -284,7 +283,7 @@ public void waitForFirstAuthorizationToComplete(TimeValue waitTime) { @Override public Set supportedStreamingTasks() { - return authorizationHandler.supportedStreamingTasks(); + return EnumSet.of(TaskType.CHAT_COMPLETION); } @Override @@ -460,9 +459,16 @@ public void parseRequestConfig( } } + /** + * This shouldn't be called because the configuration changes based on the authorization. 
+ * Instead, retrieve the authorization directly from the EIS gateway and use the static method + * {@link ElasticInferenceService#createConfiguration(EnumSet)} to create a configuration based on the authorization response. + */ @Override public InferenceServiceConfiguration getConfiguration() { - return authorizationHandler.getConfiguration(); + throw new UnsupportedOperationException( + "The EIS configuration changes depending on authorization, requests should be made directly to EIS instead" + ); } @Override @@ -472,7 +478,11 @@ public EnumSet supportedTaskTypes() { @Override public boolean hideFromConfigurationApi() { - return authorizationHandler.hideFromConfigurationApi(); + // This shouldn't be called because the configuration changes based on the authorization + // Instead, retrieve the authorization directly from the EIS gateway and use the response to determine if EIS is authorized + throw new UnsupportedOperationException( + "The EIS configuration changes depending on authorization, requests should be made directly to EIS instead" + ); } private static ElasticInferenceServiceModel createModel( @@ -656,62 +666,45 @@ private TraceContext getCurrentTraceInfo() { return new TraceContext(traceParent, traceState); } - public static class Configuration { - - private final EnumSet enabledTaskTypes; - private final LazyInitializable configuration; - - public Configuration(EnumSet enabledTaskTypes) { - this.enabledTaskTypes = enabledTaskTypes; - configuration = initConfiguration(); - } - - private LazyInitializable initConfiguration() { - return new LazyInitializable<>(() -> { - var configurationMap = new HashMap(); - - configurationMap.put( - MODEL_ID, - new SettingsConfiguration.Builder( - EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK, TaskType.TEXT_EMBEDDING) - ).setDescription("The name of the model to use for the inference task.") - .setLabel("Model ID") - .setRequired(true) - .setSensitive(false) - .setUpdatable(false) - 
.setType(SettingsConfigurationFieldType.STRING) - .build() - ); - - configurationMap.put( - MAX_INPUT_TOKENS, - new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING)).setDescription( - "Allows you to specify the maximum number of tokens per input." - ) - .setLabel("Maximum Input Tokens") - .setRequired(false) - .setSensitive(false) - .setUpdatable(false) - .setType(SettingsConfigurationFieldType.INTEGER) - .build() - ); + public static InferenceServiceConfiguration createConfiguration(EnumSet enabledTaskTypes) { + var configurationMap = new HashMap(); + + configurationMap.put( + MODEL_ID, + new SettingsConfiguration.Builder( + EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK, TaskType.TEXT_EMBEDDING) + ).setDescription("The name of the model to use for the inference task.") + .setLabel("Model ID") + .setRequired(true) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ); - configurationMap.putAll( - RateLimitSettings.toSettingsConfiguration( - EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK, TaskType.TEXT_EMBEDDING) - ) - ); + configurationMap.put( + MAX_INPUT_TOKENS, + new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING)).setDescription( + "Allows you to specify the maximum number of tokens per input." 
+ ) + .setLabel("Maximum Input Tokens") + .setRequired(false) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.INTEGER) + .build() + ); - return new InferenceServiceConfiguration.Builder().setService(NAME) - .setName(SERVICE_NAME) - .setTaskTypes(enabledTaskTypes) - .setConfigurations(configurationMap) - .build(); - }); - } + configurationMap.putAll( + RateLimitSettings.toSettingsConfiguration( + EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK, TaskType.TEXT_EMBEDDING) + ) + ); - public InferenceServiceConfiguration get() { - return configuration.getOrCompute(); - } + return new InferenceServiceConfiguration.Builder().setService(NAME) + .setName(SERVICE_NAME) + .setTaskTypes(enabledTaskTypes) + .setConfigurations(configurationMap) + .build(); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java index 28e05b24bad64..f83542e7fe740 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java @@ -14,7 +14,6 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; -import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.threadpool.Scheduler; @@ -22,7 +21,6 @@ import org.elasticsearch.xpack.inference.registry.ModelRegistry; import 
org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.DefaultModelConfig; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; import java.io.Closeable; @@ -61,7 +59,6 @@ static AuthorizedContent empty() { private final AtomicReference authorizedContent = new AtomicReference<>(AuthorizedContent.empty()); private final ModelRegistry modelRegistry; private final ElasticInferenceServiceAuthorizationRequestHandler authorizationHandler; - private final AtomicReference configuration; private final Map defaultModelsConfigs; private final CountDownLatch firstAuthorizationCompletedLatch = new CountDownLatch(1); private final EnumSet implementedTaskTypes; @@ -117,10 +114,6 @@ public ElasticInferenceServiceAuthorizationHandler( this.inferenceService = inferenceService; this.sender = Objects.requireNonNull(sender); this.elasticInferenceServiceSettings = Objects.requireNonNull(elasticInferenceServiceSettings); - - configuration = new AtomicReference<>( - new ElasticInferenceService.Configuration(authorizedContent.get().taskTypesAndModels.getAuthorizedTaskTypes()) - ); this.callback = callback; } @@ -168,10 +161,6 @@ public synchronized boolean hideFromConfigurationApi() { return authorizedContent.get().taskTypesAndModels.isAuthorized() == false; } - public synchronized InferenceServiceConfiguration getConfiguration() { - return configuration.get().get(); - } - @Override public void close() throws IOException { shutdown.set(true); @@ -257,8 +246,6 @@ private synchronized void setAuthorizedContent(ElasticInferenceServiceAuthorizat new AuthorizedContent(authorizedTaskTypesAndModels, authorizedDefaultConfigIds, authorizedDefaultModelObjects) ); - configuration.set(new ElasticInferenceService.Configuration(authorizedContent.get().taskTypesAndModels.getAuthorizedTaskTypes())); - 
authorizedContent.get().configIds().forEach(modelRegistry::putDefaultIdIfAbsent); handleRevokedDefaultConfigs(authorizedDefaultModelIds); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index d660f395250dd..d861d5b2bb47b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -922,15 +922,15 @@ public void testChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOExceptio } } - public void testHideFromConfigurationApi_ReturnsTrue_WithNoAvailableModels() throws Exception { + public void testHideFromConfigurationApi_ThrowsUnsupported_WithNoAvailableModels() throws Exception { try (var service = createServiceWithMockSender(ElasticInferenceServiceAuthorizationModel.newDisabledService())) { ensureAuthorizationCallFinished(service); - assertTrue(service.hideFromConfigurationApi()); + expectThrows(UnsupportedOperationException.class, service::hideFromConfigurationApi); } } - public void testHideFromConfigurationApi_ReturnsFalse_WithAvailableModels() throws Exception { + public void testHideFromConfigurationApi_ThrowsUnsupported_WithAvailableModels() throws Exception { try ( var service = createServiceWithMockSender( ElasticInferenceServiceAuthorizationModel.of( @@ -947,11 +947,11 @@ public void testHideFromConfigurationApi_ReturnsFalse_WithAvailableModels() thro ) { ensureAuthorizationCallFinished(service); - assertFalse(service.hideFromConfigurationApi()); + expectThrows(UnsupportedOperationException.class, service::hideFromConfigurationApi); } } - public void testGetConfiguration() throws Exception { + public void testCreateConfiguration() throws 
Exception { try ( var service = createServiceWithMockSender( ElasticInferenceServiceAuthorizationModel.of( @@ -1010,7 +1010,9 @@ public void testGetConfiguration() throws Exception { ); boolean humanReadable = true; BytesReference originalBytes = toShuffledXContent(configuration, XContentType.JSON, ToXContent.EMPTY_PARAMS, humanReadable); - InferenceServiceConfiguration serviceConfiguration = service.getConfiguration(); + InferenceServiceConfiguration serviceConfiguration = ElasticInferenceService.createConfiguration( + EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.TEXT_EMBEDDING) + ); assertToXContentEquivalent( originalBytes, toXContent(serviceConfiguration, XContentType.JSON, humanReadable), @@ -1065,7 +1067,9 @@ public void testGetConfiguration_WithoutSupportedTaskTypes() throws Exception { ); boolean humanReadable = true; BytesReference originalBytes = toShuffledXContent(configuration, XContentType.JSON, ToXContent.EMPTY_PARAMS, humanReadable); - InferenceServiceConfiguration serviceConfiguration = service.getConfiguration(); + InferenceServiceConfiguration serviceConfiguration = ElasticInferenceService.createConfiguration( + EnumSet.noneOf(TaskType.class) + ); assertToXContentEquivalent( originalBytes, toXContent(serviceConfiguration, XContentType.JSON, humanReadable), @@ -1074,7 +1078,7 @@ public void testGetConfiguration_WithoutSupportedTaskTypes() throws Exception { } } - public void testGetConfiguration_WithoutSupportedTaskTypes_WhenModelsReturnTaskOutsideOfImplementation() throws Exception { + public void testGetConfiguration_ThrowsUnsupported() throws Exception { try ( var service = createServiceWithMockSender( // this service doesn't yet support text embedding so we should still have no task types @@ -1092,54 +1096,7 @@ public void testGetConfiguration_WithoutSupportedTaskTypes_WhenModelsReturnTaskO ) { ensureAuthorizationCallFinished(service); - String content = XContentHelper.stripWhitespace(""" - { - "service": "elastic", 
- "name": "Elastic", - "task_types": ["text_embedding"], - "configurations": { - "rate_limit.requests_per_minute": { - "description": "Minimize the number of rate limit errors.", - "label": "Rate Limit", - "required": false, - "sensitive": false, - "updatable": false, - "type": "int", - "supported_task_types": ["text_embedding" , "sparse_embedding", "rerank", "chat_completion"] - }, - "model_id": { - "description": "The name of the model to use for the inference task.", - "label": "Model ID", - "required": true, - "sensitive": false, - "updatable": false, - "type": "str", - "supported_task_types": ["text_embedding" , "sparse_embedding", "rerank", "chat_completion"] - }, - "max_input_tokens": { - "description": "Allows you to specify the maximum number of tokens per input.", - "label": "Maximum Input Tokens", - "required": false, - "sensitive": false, - "updatable": false, - "type": "int", - "supported_task_types": ["text_embedding", "sparse_embedding"] - } - } - } - """); - InferenceServiceConfiguration configuration = InferenceServiceConfiguration.fromXContentBytes( - new BytesArray(content), - XContentType.JSON - ); - boolean humanReadable = true; - BytesReference originalBytes = toShuffledXContent(configuration, XContentType.JSON, ToXContent.EMPTY_PARAMS, humanReadable); - InferenceServiceConfiguration serviceConfiguration = service.getConfiguration(); - assertToXContentEquivalent( - originalBytes, - toXContent(serviceConfiguration, XContentType.JSON, humanReadable), - XContentType.JSON - ); + expectThrows(UnsupportedOperationException.class, service::getConfiguration); } } @@ -1241,7 +1198,7 @@ public void testSupportedStreamingTasks_ReturnsEmpty_WhenAuthRespondsWithoutChat try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { ensureAuthorizationCallFinished(service); - assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class))); + assertThat(service.supportedStreamingTasks(), 
is(EnumSet.of(TaskType.CHAT_COMPLETION))); assertTrue(service.defaultConfigIds().isEmpty()); assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); @@ -1268,7 +1225,7 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsIn var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { ensureAuthorizationCallFinished(service); - assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class))); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); assertThat( service.defaultConfigIds(), is(