diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java index d4d18fed6dd02..8ef9b59f5545a 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java @@ -42,7 +42,13 @@ public void init(Client client) { } public void onNodeStarted() { - services.values().forEach(InferenceService::onNodeStarted); + for (var service : services.values()) { + try { + service.onNodeStarted(); + } catch (Exception e) { + // ignore + } + } } public Map getServices() { diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java index d0f797e9f8fab..81c1a8dc7a5ba 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java @@ -39,6 +39,10 @@ public class BaseMockEISAuthServerTest extends ESRestTestCase { .setting("xpack.security.enabled", "true") // Adding both settings unless one feature flag is disabled in a particular environment .setting("xpack.inference.elastic.url", mockEISServer::getUrl) + // If we don't disable this there's a very small chance that the authorization code could attempt to make two + // calls which would result in a test failure because the webserver is only expecting a single request + // So to ensure we avoid that all together, this flag indicates that we'll only perform a single authorization request + .setting("xpack.inference.elastic.periodic_authorization_enabled", "false") // This plugin is located in the inference/qa/test-service-plugin package, look for TestInferenceServicePlugin .plugin("inference-service-test") .user("x_pack_rest_user", "x-pack-test-password") diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java index 2070b0f1f1574..2c133440100aa 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java @@ -27,8 +27,8 @@ import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; -import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationHandler; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; import org.junit.After; import org.junit.Before; @@ -270,7 +270,7 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA private void ensureAuthorizationCallFinished(ElasticInferenceService service) { service.onNodeStarted(); - service.waitForAuthorizationToComplete(TIMEOUT); + service.waitForFirstAuthorizationToComplete(TIMEOUT); } private ElasticInferenceService createElasticInferenceService() { @@ -280,9 +280,9 @@ private ElasticInferenceService createElasticInferenceService() { return new ElasticInferenceService( senderFactory, createWithEmptySettings(threadPool), - ElasticInferenceServiceComponents.withNoRevokeDelay(gatewayUrl), + ElasticInferenceServiceSettingsTests.create(gatewayUrl), modelRegistry, - new ElasticInferenceServiceAuthorizationHandler(gatewayUrl, threadPool) + new ElasticInferenceServiceAuthorizationRequestHandler(gatewayUrl, threadPool) ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 0b01ad5e3c66f..714829c08b041 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -118,7 +118,7 @@ import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; -import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationHandler; +import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiService; @@ -274,14 +274,11 @@ public Collection createComponents(PluginServices services) { ); elasicInferenceServiceFactory.set(elasticInferenceServiceRequestSenderFactory); - ElasticInferenceServiceSettings inferenceServiceSettings = new ElasticInferenceServiceSettings(settings); - String elasticInferenceUrl = inferenceServiceSettings.getElasticInferenceServiceUrl(); + var inferenceServiceSettings = new ElasticInferenceServiceSettings(settings); + inferenceServiceSettings.init(services.clusterService()); - var elasticInferenceServiceComponentsInstance = ElasticInferenceServiceComponents.withDefaultRevokeDelay(elasticInferenceUrl); - elasticInferenceServiceComponents.set(elasticInferenceServiceComponentsInstance); - - var authorizationHandler = new ElasticInferenceServiceAuthorizationHandler( - elasticInferenceServiceComponentsInstance.elasticInferenceServiceUrl(), + var authorizationHandler = new ElasticInferenceServiceAuthorizationRequestHandler( + inferenceServiceSettings.getElasticInferenceServiceUrl(), services.threadPool() ); @@ -290,7 +287,7 @@ public Collection createComponents(PluginServices services) { context -> new ElasticInferenceService( elasicInferenceServiceFactory.get(), serviceComponents.get(), - elasticInferenceServiceComponentsInstance, + inferenceServiceSettings, modelRegistry, authorizationHandler ) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/DefaultModelConfig.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/DefaultModelConfig.java new file mode 100644 index 0000000000000..dcdf5bce1fbb4 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/DefaultModelConfig.java @@ -0,0 +1,15 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic; + +import org.elasticsearch.inference.MinimalServiceSettings; +import org.elasticsearch.inference.Model; + +public record DefaultModelConfig(Model model, MinimalServiceSettings settings) { + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 32a6c3ea274e3..9fc94c598f4b3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -7,8 +7,6 @@ package org.elasticsearch.xpack.inference.services.elastic; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; @@ -49,32 +47,22 @@ import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ServiceUtils; -import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorization; import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationHandler; +import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import org.elasticsearch.xpack.inference.telemetry.TraceContext; -import java.util.ArrayList; -import java.util.Comparator; import java.util.EnumSet; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.Objects; import java.util.Set; -import java.util.TreeSet; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; -import java.util.stream.Collectors; import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException; -import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; @@ -90,7 +78,6 @@ public class ElasticInferenceService extends SenderService { public static final String NAME = "elastic"; public static final String ELASTIC_INFERENCE_SERVICE_IDENTIFIER = "Elastic Inference Service"; - private static final Logger logger = LogManager.getLogger(ElasticInferenceService.class); private static final EnumSet IMPLEMENTED_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION); private static final String SERVICE_NAME = "Elastic"; @@ -107,34 +94,34 @@ public class ElasticInferenceService extends SenderService { */ private static final EnumSet SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING); - private static String defaultEndpointId(String modelId) { + public static String defaultEndpointId(String modelId) { return Strings.format(".%s-elastic", modelId); } private final ElasticInferenceServiceComponents elasticInferenceServiceComponents; - private Configuration configuration; - private final AtomicReference authRef = new AtomicReference<>(AuthorizedContent.empty()); - private final ModelRegistry modelRegistry; private final ElasticInferenceServiceAuthorizationHandler authorizationHandler; - private final CountDownLatch authorizationCompletedLatch = new CountDownLatch(1); - // model ids to model information, used for the default config methods to return the list of models and default - // configs - private final Map defaultModelsConfigs; public ElasticInferenceService( HttpRequestSender.Factory factory, ServiceComponents serviceComponents, - ElasticInferenceServiceComponents elasticInferenceServiceComponents, + ElasticInferenceServiceSettings elasticInferenceServiceSettings, ModelRegistry modelRegistry, - ElasticInferenceServiceAuthorizationHandler authorizationHandler + ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler ) { super(factory, serviceComponents); - this.elasticInferenceServiceComponents = Objects.requireNonNull(elasticInferenceServiceComponents); - this.modelRegistry = Objects.requireNonNull(modelRegistry); - this.authorizationHandler = Objects.requireNonNull(authorizationHandler); - - configuration = new Configuration(authRef.get().taskTypesAndModels.getAuthorizedTaskTypes()); - defaultModelsConfigs = initDefaultEndpoints(elasticInferenceServiceComponents); + this.elasticInferenceServiceComponents = new ElasticInferenceServiceComponents( + elasticInferenceServiceSettings.getElasticInferenceServiceUrl() + ); + authorizationHandler = new ElasticInferenceServiceAuthorizationHandler( + serviceComponents, + modelRegistry, + authorizationRequestHandler, + initDefaultEndpoints(elasticInferenceServiceComponents), + IMPLEMENTED_TASK_TYPES, + this, + getSender(), + elasticInferenceServiceSettings + ); } private static Map initDefaultEndpoints( @@ -170,169 +157,35 @@ private static Map initDefaultEndpoints( ); } - private record DefaultModelConfig(Model model, MinimalServiceSettings settings) {} - - private record AuthorizedContent( - ElasticInferenceServiceAuthorization taskTypesAndModels, - List configIds, - List defaultModelConfigs - ) { - static AuthorizedContent empty() { - return new AuthorizedContent(ElasticInferenceServiceAuthorization.newDisabledService(), List.of(), List.of()); - } - } - - private void getAuthorization() { - try { - ActionListener listener = ActionListener.wrap(this::setAuthorizedContent, e -> { - // we don't need to do anything if there was a failure, everything is disabled by default - authorizationCompletedLatch.countDown(); - }); - - authorizationHandler.getAuthorization(listener, getSender()); - } catch (Exception e) { - // we don't need to do anything if there was a failure, everything is disabled by default - authorizationCompletedLatch.countDown(); - } - } - - private synchronized void setAuthorizedContent(ElasticInferenceServiceAuthorization auth) { - var authorizedTaskTypesAndModels = auth.newLimitedToTaskTypes(EnumSet.copyOf(IMPLEMENTED_TASK_TYPES)); - - // recalculate which default config ids and models are authorized now - var authorizedDefaultModelIds = getAuthorizedDefaultModelIds(auth); - - var authorizedDefaultConfigIds = getAuthorizedDefaultConfigIds(authorizedDefaultModelIds, auth); - var authorizedDefaultModelObjects = getAuthorizedDefaultModelsObjects(authorizedDefaultModelIds); - authRef.set(new AuthorizedContent(authorizedTaskTypesAndModels, authorizedDefaultConfigIds, authorizedDefaultModelObjects)); - - configuration = new Configuration(authRef.get().taskTypesAndModels.getAuthorizedTaskTypes()); - - defaultConfigIds().forEach(modelRegistry::putDefaultIdIfAbsent); - handleRevokedDefaultConfigs(authorizedDefaultModelIds); - } - - private Set getAuthorizedDefaultModelIds(ElasticInferenceServiceAuthorization auth) { - var authorizedModels = auth.getAuthorizedModelIds(); - var authorizedDefaultModelIds = new TreeSet<>(defaultModelsConfigs.keySet()); - authorizedDefaultModelIds.retainAll(authorizedModels); - - return authorizedDefaultModelIds; - } - - private List getAuthorizedDefaultConfigIds( - Set authorizedDefaultModelIds, - ElasticInferenceServiceAuthorization auth - ) { - var authorizedConfigIds = new ArrayList(); - for (var id : authorizedDefaultModelIds) { - var modelConfig = defaultModelsConfigs.get(id); - if (modelConfig != null) { - if (auth.getAuthorizedTaskTypes().contains(modelConfig.model.getTaskType()) == false) { - logger.warn( - Strings.format( - "The authorization response included the default model: %s, " - + "but did not authorize the assumed task type of the model: %s. Enabling model.", - id, - modelConfig.model.getTaskType() - ) - ); - } - authorizedConfigIds.add(new DefaultConfigId(modelConfig.model.getInferenceEntityId(), modelConfig.settings(), this)); - } - } - - authorizedConfigIds.sort(Comparator.comparing(DefaultConfigId::inferenceId)); - return authorizedConfigIds; - } - - private List getAuthorizedDefaultModelsObjects(Set authorizedDefaultModelIds) { - var authorizedModels = new ArrayList(); - for (var id : authorizedDefaultModelIds) { - var modelConfig = defaultModelsConfigs.get(id); - if (modelConfig != null) { - authorizedModels.add(modelConfig); - } - } - - authorizedModels.sort(Comparator.comparing(modelConfig -> modelConfig.model.getInferenceEntityId())); - return authorizedModels; - } - - private void handleRevokedDefaultConfigs(Set authorizedDefaultModelIds) { - // if a model was initially returned in the authorization response but is absent, then we'll assume authorization was revoked - var unauthorizedDefaultModelIds = new HashSet<>(defaultModelsConfigs.keySet()); - unauthorizedDefaultModelIds.removeAll(authorizedDefaultModelIds); - - // get all the default inference endpoint ids for the unauthorized model ids - var unauthorizedDefaultInferenceEndpointIds = unauthorizedDefaultModelIds.stream() - .map(defaultModelsConfigs::get) // get all the model configs - .filter(Objects::nonNull) // limit to only non-null - .map(modelConfig -> modelConfig.model.getInferenceEntityId()) // get the inference ids - .collect(Collectors.toSet()); - - var deleteInferenceEndpointsListener = ActionListener.wrap(result -> { - logger.trace(Strings.format("Successfully revoked access to default inference endpoint IDs: %s", unauthorizedDefaultModelIds)); - authorizationCompletedLatch.countDown(); - }, e -> { - logger.warn( - Strings.format("Failed to revoke access to default inference endpoint IDs: %s, error: %s", unauthorizedDefaultModelIds, e) - ); - authorizationCompletedLatch.countDown(); - }); - - Runnable removeFromRegistry = () -> { - logger.debug("Synchronizing default inference endpoints"); - modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener); - }; - - var delay = elasticInferenceServiceComponents.revokeAuthorizationDelay(); - if (delay == null) { - getServiceComponents().threadPool().executor(UTILITY_THREAD_POOL_NAME).execute(removeFromRegistry); - } else { - getServiceComponents().threadPool() - .schedule(removeFromRegistry, delay, getServiceComponents().threadPool().executor(UTILITY_THREAD_POOL_NAME)); - } - - } - @Override public void onNodeStarted() { - getServiceComponents().threadPool().executor(UTILITY_THREAD_POOL_NAME).execute(this::getAuthorization); + authorizationHandler.init(); } /** + * Only use this in tests. + * * Waits the specified amount of time for the authorization call to complete. This is mainly to make testing easier. * @param waitTime the max time to wait * @throws IllegalStateException if the wait time is exceeded or the call receives an {@link InterruptedException} */ - public void waitForAuthorizationToComplete(TimeValue waitTime) { - try { - if (authorizationCompletedLatch.await(waitTime.getSeconds(), TimeUnit.SECONDS) == false) { - throw new IllegalStateException("The wait time has expired for authorization to complete."); - } - } catch (InterruptedException e) { - throw new IllegalStateException("Waiting for authorization to complete was interrupted"); - } + public void waitForFirstAuthorizationToComplete(TimeValue waitTime) { + authorizationHandler.waitForAuthorizationToComplete(waitTime); } @Override - public synchronized Set supportedStreamingTasks() { - var authorizedStreamingTaskTypes = EnumSet.of(TaskType.CHAT_COMPLETION); - authorizedStreamingTaskTypes.retainAll(authRef.get().taskTypesAndModels.getAuthorizedTaskTypes()); - - return authorizedStreamingTaskTypes; + public Set supportedStreamingTasks() { + return authorizationHandler.supportedStreamingTasks(); } @Override - public synchronized List defaultConfigIds() { - return authRef.get().configIds; + public List defaultConfigIds() { + return authorizationHandler.defaultConfigIds(); } @Override - public synchronized void defaultConfigs(ActionListener> defaultsListener) { - var models = authRef.get().defaultModelConfigs.stream().map(config -> config.model).toList(); - defaultsListener.onResponse(models); + public void defaultConfigs(ActionListener> defaultsListener) { + authorizationHandler.defaultConfigs(defaultsListener); } @Override @@ -384,6 +237,7 @@ protected void doInfer( responseString = responseString + " " + useChatCompletionUrlMessage(model); } listener.onFailure(new ElasticsearchStatusException(responseString, RestStatus.BAD_REQUEST)); + return; } if (model instanceof ElasticInferenceServiceExecutableActionModel == false) { @@ -458,18 +312,18 @@ public void parseRequestConfig( } @Override - public synchronized InferenceServiceConfiguration getConfiguration() { - return configuration.get(); + public InferenceServiceConfiguration getConfiguration() { + return authorizationHandler.getConfiguration(); } @Override - public synchronized EnumSet supportedTaskTypes() { - return authRef.get().taskTypesAndModels.getAuthorizedTaskTypes(); + public EnumSet supportedTaskTypes() { + return authorizationHandler.supportedTaskTypes(); } @Override - public synchronized boolean hideFromConfigurationApi() { - return authRef.get().taskTypesAndModels.isAuthorized() == false; + public boolean hideFromConfigurationApi() { + return authorizationHandler.hideFromConfigurationApi(); } private static ElasticInferenceServiceModel createModel( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceComponents.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceComponents.java index f79de437fcaf2..83fd957f9005d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceComponents.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceComponents.java @@ -8,23 +8,14 @@ package org.elasticsearch.xpack.inference.services.elastic; import org.elasticsearch.core.Nullable; -import org.elasticsearch.core.TimeValue; /** * @param elasticInferenceServiceUrl the upstream Elastic Inference Server's URL - * @param revokeAuthorizationDelay Amount of time to wait before attempting to revoke authorization to certain model ids. - * null indicates that there should be no delay */ -public record ElasticInferenceServiceComponents(@Nullable String elasticInferenceServiceUrl, @Nullable TimeValue revokeAuthorizationDelay) { - private static final TimeValue DEFAULT_REVOKE_AUTHORIZATION_DELAY = TimeValue.timeValueMinutes(10); +public record ElasticInferenceServiceComponents(@Nullable String elasticInferenceServiceUrl) { + public static final ElasticInferenceServiceComponents EMPTY_INSTANCE = ElasticInferenceServiceComponents.of(null); - public static final ElasticInferenceServiceComponents EMPTY_INSTANCE = new ElasticInferenceServiceComponents(null, null); - - public static ElasticInferenceServiceComponents withNoRevokeDelay(String elasticInferenceServiceUrl) { - return new ElasticInferenceServiceComponents(elasticInferenceServiceUrl, null); - } - - public static ElasticInferenceServiceComponents withDefaultRevokeDelay(String elasticInferenceServiceUrl) { - return new ElasticInferenceServiceComponents(elasticInferenceServiceUrl, DEFAULT_REVOKE_AUTHORIZATION_DELAY); + public static ElasticInferenceServiceComponents of(String elasticInferenceServiceUrl) { + return new ElasticInferenceServiceComponents(elasticInferenceServiceUrl); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java index 98d55fd799598..fe6ebb6cfb625 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java @@ -7,9 +7,11 @@ package org.elasticsearch.xpack.inference.services.elastic; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.xpack.core.ssl.SSLConfigurationSettings; import java.util.ArrayList; @@ -31,15 +33,31 @@ public class ElasticInferenceServiceSettings { Setting.Property.NodeScope ); - @Deprecated - private final String eisGatewayUrl; + /** + * This setting is for testing only. It controls whether authorization is only performed once at bootup. If set to true, an + * authorization request will be made repeatedly on an interval. + */ + static final Setting PERIODIC_AUTHORIZATION_ENABLED = Setting.boolSetting( + "xpack.inference.elastic.periodic_authorization_enabled", + true, + Setting.Property.NodeScope + ); - private final String elasticInferenceServiceUrl; + private static final TimeValue DEFAULT_AUTH_REQUEST_INTERVAL = TimeValue.timeValueMinutes(10); + static final Setting AUTHORIZATION_REQUEST_INTERVAL = Setting.timeSetting( + "xpack.inference.elastic.authorization_request_interval", + DEFAULT_AUTH_REQUEST_INTERVAL, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); - public ElasticInferenceServiceSettings(Settings settings) { - eisGatewayUrl = EIS_GATEWAY_URL.get(settings); - elasticInferenceServiceUrl = ELASTIC_INFERENCE_SERVICE_URL.get(settings); - } + private static final TimeValue DEFAULT_AUTH_REQUEST_JITTER = TimeValue.timeValueMinutes(5); + static final Setting MAX_AUTHORIZATION_REQUEST_JITTER = Setting.timeSetting( + "xpack.inference.elastic.max_authorization_request_jitter", + DEFAULT_AUTH_REQUEST_JITTER, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); public static final SSLConfigurationSettings ELASTIC_INFERENCE_SERVICE_SSL_CONFIGURATION_SETTINGS = SSLConfigurationSettings.withPrefix( ELASTIC_INFERENCE_SERVICE_SSL_CONFIGURATION_PREFIX, @@ -52,13 +70,60 @@ public ElasticInferenceServiceSettings(Settings settings) { Setting.Property.NodeScope ); + @Deprecated + private final String eisGatewayUrl; + + private final String elasticInferenceServiceUrl; + private final boolean periodicAuthorizationEnabled; + private volatile TimeValue authRequestInterval; + private volatile TimeValue maxAuthorizationRequestJitter; + + public ElasticInferenceServiceSettings(Settings settings) { + eisGatewayUrl = EIS_GATEWAY_URL.get(settings); + elasticInferenceServiceUrl = ELASTIC_INFERENCE_SERVICE_URL.get(settings); + periodicAuthorizationEnabled = PERIODIC_AUTHORIZATION_ENABLED.get(settings); + authRequestInterval = AUTHORIZATION_REQUEST_INTERVAL.get(settings); + maxAuthorizationRequestJitter = MAX_AUTHORIZATION_REQUEST_JITTER.get(settings); + } + + /** + * This must be called after the object is constructed to avoid leaking the this reference before the constructor + * finishes. + * + * Handles initializing the settings changes listener. + */ + public final void init(ClusterService clusterService) { + clusterService.getClusterSettings() + .addSettingsUpdateConsumer(AUTHORIZATION_REQUEST_INTERVAL, this::setAuthorizationRequestInterval); + clusterService.getClusterSettings() + .addSettingsUpdateConsumer(MAX_AUTHORIZATION_REQUEST_JITTER, this::setMaxAuthorizationRequestJitter); + } + + private void setAuthorizationRequestInterval(TimeValue interval) { + authRequestInterval = interval; + } + + private void setMaxAuthorizationRequestJitter(TimeValue jitter) { + maxAuthorizationRequestJitter = jitter; + } + + public TimeValue getAuthRequestInterval() { + return authRequestInterval; + } + + public TimeValue getMaxAuthorizationRequestJitter() { + return maxAuthorizationRequestJitter; + } + public static List> getSettingsDefinitions() { ArrayList> settings = new ArrayList<>(); settings.add(EIS_GATEWAY_URL); settings.add(ELASTIC_INFERENCE_SERVICE_URL); settings.add(ELASTIC_INFERENCE_SERVICE_SSL_ENABLED); settings.addAll(ELASTIC_INFERENCE_SERVICE_SSL_CONFIGURATION_SETTINGS.getEnabledSettings()); - + settings.add(PERIODIC_AUTHORIZATION_ENABLED); + settings.add(AUTHORIZATION_REQUEST_INTERVAL); + settings.add(MAX_AUTHORIZATION_REQUEST_JITTER); return settings; } @@ -66,4 +131,7 @@ public String getElasticInferenceServiceUrl() { return Strings.isEmpty(elasticInferenceServiceUrl) ? eisGatewayUrl : elasticInferenceServiceUrl; } + public boolean isPeriodicAuthorizationEnabled() { + return periodicAuthorizationEnabled; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModel.java index ac6a389914a10..fd38d63f7f74e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModel.java @@ -58,7 +58,7 @@ public ElasticInferenceServiceSparseEmbeddingsModel( this.uri = createUri(); } - ElasticInferenceServiceSparseEmbeddingsModel( + public ElasticInferenceServiceSparseEmbeddingsModel( String inferenceEntityId, TaskType taskType, String service, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java index 414753123ca27..c0addad455222 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java @@ -9,130 +9,335 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.elasticsearch.ElasticsearchWrapperException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.Randomness; import org.elasticsearch.common.Strings; -import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.tasks.Task; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.inference.external.elastic.ElasticInferenceServiceResponseHandler; -import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.threadpool.Scheduler; import org.elasticsearch.xpack.inference.external.http.sender.Sender; -import org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceAuthorizationRequest; -import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceAuthorizationResponseEntity; -import org.elasticsearch.xpack.inference.telemetry.TraceContext; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.DefaultModelConfig; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; -import java.util.Locale; +import java.io.Closeable; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.EnumSet; +import java.util.HashSet; +import java.util.List; +import java.util.Map; import java.util.Objects; +import java.util.Set; +import java.util.TreeSet; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; -import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER; +import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; -/** - * Handles retrieving the authorization information from Elastic Inference Service. - */ -public class ElasticInferenceServiceAuthorizationHandler { +public class ElasticInferenceServiceAuthorizationHandler implements Closeable { + private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceAuthorizationHandler.class); - private static final String FAILED_TO_RETRIEVE_MESSAGE = - "Failed to retrieve the authorization information from the Elastic Inference Service."; - private static final TimeValue DEFAULT_AUTH_TIMEOUT = TimeValue.timeValueMinutes(1); - private static final ResponseHandler AUTH_RESPONSE_HANDLER = createAuthResponseHandler(); + private record AuthorizedContent( + ElasticInferenceServiceAuthorizationModel taskTypesAndModels, + List configIds, + List defaultModelConfigs + ) { + static AuthorizedContent empty() { + return new AuthorizedContent(ElasticInferenceServiceAuthorizationModel.newDisabledService(), List.of(), List.of()); + } + } - private static ResponseHandler createAuthResponseHandler() { - return new ElasticInferenceServiceResponseHandler( - String.format(Locale.ROOT, "%s sparse embeddings", ELASTIC_INFERENCE_SERVICE_IDENTIFIER), - ElasticInferenceServiceAuthorizationResponseEntity::fromResponse + private final ServiceComponents serviceComponents; + private final AtomicReference authorizedContent = new AtomicReference<>(AuthorizedContent.empty()); + private final ModelRegistry modelRegistry; + private final ElasticInferenceServiceAuthorizationRequestHandler authorizationHandler; + private final AtomicReference configuration; + private final Map defaultModelsConfigs; + private final CountDownLatch firstAuthorizationCompletedLatch = new CountDownLatch(1); + private final EnumSet implementedTaskTypes; + private final InferenceService inferenceService; + private final Sender sender; + private final Runnable callback; + private final AtomicReference lastAuthTask = new AtomicReference<>(null); + private final AtomicBoolean shutdown = new AtomicBoolean(false); + private final ElasticInferenceServiceSettings elasticInferenceServiceSettings; + + public ElasticInferenceServiceAuthorizationHandler( + ServiceComponents serviceComponents, + ModelRegistry modelRegistry, + ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, + Map defaultModelsConfigs, + EnumSet implementedTaskTypes, + InferenceService inferenceService, + Sender sender, + ElasticInferenceServiceSettings elasticInferenceServiceSettings + ) { + this( + serviceComponents, + modelRegistry, + authorizationRequestHandler, + defaultModelsConfigs, + implementedTaskTypes, + Objects.requireNonNull(inferenceService), + sender, + elasticInferenceServiceSettings, + null ); } - private final String baseUrl; - private final ThreadPool threadPool; - private final Logger logger; - private final CountDownLatch requestCompleteLatch = new CountDownLatch(1); + // default for testing + ElasticInferenceServiceAuthorizationHandler( + ServiceComponents serviceComponents, + ModelRegistry modelRegistry, + ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, + Map defaultModelsConfigs, + EnumSet implementedTaskTypes, + InferenceService inferenceService, + Sender sender, + ElasticInferenceServiceSettings elasticInferenceServiceSettings, + // this is a hack to facilitate testing + Runnable callback + ) { + this.serviceComponents = Objects.requireNonNull(serviceComponents); + this.modelRegistry = Objects.requireNonNull(modelRegistry); + this.authorizationHandler = Objects.requireNonNull(authorizationRequestHandler); + this.defaultModelsConfigs = Objects.requireNonNull(defaultModelsConfigs); + this.implementedTaskTypes = Objects.requireNonNull(implementedTaskTypes); + // allow the service to be null for testing + this.inferenceService = inferenceService; + this.sender = Objects.requireNonNull(sender); + this.elasticInferenceServiceSettings = Objects.requireNonNull(elasticInferenceServiceSettings); - public ElasticInferenceServiceAuthorizationHandler(@Nullable String baseUrl, ThreadPool threadPool) { - this.baseUrl = baseUrl; - this.threadPool = Objects.requireNonNull(threadPool); - logger = LogManager.getLogger(ElasticInferenceServiceAuthorizationHandler.class); + configuration = new AtomicReference<>( + new ElasticInferenceService.Configuration(authorizedContent.get().taskTypesAndModels.getAuthorizedTaskTypes()) + ); + this.callback = callback; } - // only use for testing - ElasticInferenceServiceAuthorizationHandler(@Nullable String baseUrl, ThreadPool threadPool, Logger logger) { - this.baseUrl = baseUrl; - this.threadPool = Objects.requireNonNull(threadPool); - this.logger = Objects.requireNonNull(logger); + public void init() { + logger.debug("Initializing authorization logic"); + serviceComponents.threadPool().executor(UTILITY_THREAD_POOL_NAME).execute(this::scheduleAndSendAuthorizationRequest); } /** - * Retrieve the authorization information from Elastic Inference Service - * @param listener a listener to receive the response - * @param sender a {@link Sender} for making the request to the Elastic Inference Service + * Waits the specified amount of time for the first authorization call to complete. This is mainly to make testing easier. + * @param waitTime the max time to wait + * @throws IllegalStateException if the wait time is exceeded or the call receives an {@link InterruptedException} */ - public void getAuthorization(ActionListener listener, Sender sender) { + public void waitForAuthorizationToComplete(TimeValue waitTime) { try { - logger.debug("Retrieving authorization information from the Elastic Inference Service."); + if (firstAuthorizationCompletedLatch.await(waitTime.getSeconds(), TimeUnit.SECONDS) == false) { + throw new IllegalStateException("The wait time has expired for authorization to complete."); + } + } catch (InterruptedException e) { + throw new IllegalStateException("Waiting for authorization to complete was interrupted"); + } + } + + public synchronized Set supportedStreamingTasks() { + var authorizedStreamingTaskTypes = EnumSet.of(TaskType.CHAT_COMPLETION); + authorizedStreamingTaskTypes.retainAll(authorizedContent.get().taskTypesAndModels.getAuthorizedTaskTypes()); + + return authorizedStreamingTaskTypes; + } + + public synchronized List defaultConfigIds() { + return authorizedContent.get().configIds; + } + + public synchronized void defaultConfigs(ActionListener> defaultsListener) { + var models = authorizedContent.get().defaultModelConfigs.stream().map(DefaultModelConfig::model).toList(); + defaultsListener.onResponse(models); + } - if (Strings.isNullOrEmpty(baseUrl)) { - logger.debug("The base URL for the authorization service is not valid, rejecting authorization."); - listener.onResponse(ElasticInferenceServiceAuthorization.newDisabledService()); + public synchronized EnumSet supportedTaskTypes() { + return authorizedContent.get().taskTypesAndModels.getAuthorizedTaskTypes(); + } + + public synchronized boolean hideFromConfigurationApi() { + return authorizedContent.get().taskTypesAndModels.isAuthorized() == false; + } + + public synchronized InferenceServiceConfiguration getConfiguration() { + return configuration.get().get(); + } + + @Override + public void close() throws IOException { + shutdown.set(true); + if (lastAuthTask.get() != null) { + lastAuthTask.get().cancel(); + } + } + + private void scheduleAuthorizationRequest() { + try { + if (elasticInferenceServiceSettings.isPeriodicAuthorizationEnabled() == false) { return; } - // ensure that the sender is initialized - sender.start(); + // this call has to be on the individual thread otherwise we get an exception + var random = Randomness.get(); + var jitter = (long) (elasticInferenceServiceSettings.getMaxAuthorizationRequestJitter().millis() * random.nextDouble()); + var waitTime = TimeValue.timeValueMillis(elasticInferenceServiceSettings.getAuthRequestInterval().millis() + jitter); - ActionListener newListener = ActionListener.wrap(results -> { - if (results instanceof ElasticInferenceServiceAuthorizationResponseEntity authResponseEntity) { - listener.onResponse(ElasticInferenceServiceAuthorization.of(authResponseEntity)); - } else { - logger.warn( - Strings.format( - FAILED_TO_RETRIEVE_MESSAGE + " Received an invalid response type: %s", - results.getClass().getSimpleName() - ) - ); - listener.onResponse(ElasticInferenceServiceAuthorization.newDisabledService()); + logger.debug( + () -> Strings.format( + "Scheduling the next authorization call with request interval: %s ms, jitter: %d ms", + elasticInferenceServiceSettings.getAuthRequestInterval().millis(), + jitter + ) + ); + logger.debug(() -> Strings.format("Next authorization call in %d minutes", waitTime.getMinutes())); + + lastAuthTask.set( + serviceComponents.threadPool() + .schedule( + this::scheduleAndSendAuthorizationRequest, + waitTime, + serviceComponents.threadPool().executor(UTILITY_THREAD_POOL_NAME) + ) + ); + } catch (Exception e) { + logger.warn("Failed scheduling authorization request", e); + } + } + + private void scheduleAndSendAuthorizationRequest() { + if (shutdown.get()) { + return; + } + + scheduleAuthorizationRequest(); + sendAuthorizationRequest(); + } + + private void sendAuthorizationRequest() { + try { + ActionListener listener = ActionListener.wrap((model) -> { + setAuthorizedContent(model); + if (callback != null) { + callback.run(); } - requestCompleteLatch.countDown(); }, e -> { - Throwable exception = e; - if (e instanceof ElasticsearchWrapperException wrapperException) { - exception = wrapperException.getCause(); - } - - logger.warn(Strings.format(FAILED_TO_RETRIEVE_MESSAGE + " Encountered an exception: %s", exception)); - listener.onResponse(ElasticInferenceServiceAuthorization.newDisabledService()); - requestCompleteLatch.countDown(); + // we don't need to do anything if there was a failure, everything is disabled by default + firstAuthorizationCompletedLatch.countDown(); }); - var productOrigin = threadPool.getThreadContext().getHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER); - var request = new ElasticInferenceServiceAuthorizationRequest(baseUrl, getCurrentTraceInfo(), productOrigin); - - sender.sendWithoutQueuing(logger, request, AUTH_RESPONSE_HANDLER, DEFAULT_AUTH_TIMEOUT, newListener); + authorizationHandler.getAuthorization(listener, sender); } catch (Exception e) { - logger.warn(Strings.format("Retrieving the authorization information encountered an exception: %s", e)); - requestCompleteLatch.countDown(); + logger.warn("Failure while sending the request to retrieve authorization", e); + // we don't need to do anything if there was a failure, everything is disabled by default + firstAuthorizationCompletedLatch.countDown(); } } - private TraceContext getCurrentTraceInfo() { - var traceParent = threadPool.getThreadContext().getHeader(Task.TRACE_PARENT); - var traceState = threadPool.getThreadContext().getHeader(Task.TRACE_STATE); + private synchronized void setAuthorizedContent(ElasticInferenceServiceAuthorizationModel auth) { + logger.debug("Received authorization response"); + var authorizedTaskTypesAndModels = authorizedContent.get().taskTypesAndModels.merge(auth) + .newLimitedToTaskTypes(EnumSet.copyOf(implementedTaskTypes)); - return new TraceContext(traceParent, traceState); + // recalculate which default config ids and models are authorized now + var authorizedDefaultModelIds = getAuthorizedDefaultModelIds(auth); + + var authorizedDefaultConfigIds = getAuthorizedDefaultConfigIds(authorizedDefaultModelIds, auth); + var authorizedDefaultModelObjects = getAuthorizedDefaultModelsObjects(authorizedDefaultModelIds); + authorizedContent.set( + new AuthorizedContent(authorizedTaskTypesAndModels, authorizedDefaultConfigIds, authorizedDefaultModelObjects) + ); + + configuration.set(new ElasticInferenceService.Configuration(authorizedContent.get().taskTypesAndModels.getAuthorizedTaskTypes())); + + authorizedContent.get().configIds().forEach(modelRegistry::putDefaultIdIfAbsent); + handleRevokedDefaultConfigs(authorizedDefaultModelIds); } - // Default because should only be used for testing - void waitForAuthRequestCompletion(TimeValue timeValue) throws IllegalStateException { - try { - if (requestCompleteLatch.await(timeValue.getMillis(), TimeUnit.MILLISECONDS) == false) { - throw new IllegalStateException("The wait time has expired for authorization to complete."); + private Set getAuthorizedDefaultModelIds(ElasticInferenceServiceAuthorizationModel auth) { + var authorizedModels = auth.getAuthorizedModelIds(); + var authorizedDefaultModelIds = new TreeSet<>(defaultModelsConfigs.keySet()); + authorizedDefaultModelIds.retainAll(authorizedModels); + + return authorizedDefaultModelIds; + } + + private List getAuthorizedDefaultConfigIds( + Set authorizedDefaultModelIds, + ElasticInferenceServiceAuthorizationModel auth + ) { + var authorizedConfigIds = new ArrayList(); + for (var id : authorizedDefaultModelIds) { + var modelConfig = defaultModelsConfigs.get(id); + if (modelConfig != null) { + if (auth.getAuthorizedTaskTypes().contains(modelConfig.model().getTaskType()) == false) { + logger.warn( + org.elasticsearch.common.Strings.format( + "The authorization response included the default model: %s, " + + "but did not authorize the assumed task type of the model: %s. Enabling model.", + id, + modelConfig.model().getTaskType() + ) + ); + } + authorizedConfigIds.add( + new InferenceService.DefaultConfigId( + modelConfig.model().getInferenceEntityId(), + modelConfig.settings(), + inferenceService + ) + ); } - } catch (InterruptedException e) { - throw new IllegalStateException("Waiting for authorization to complete was interrupted"); } + + authorizedConfigIds.sort(Comparator.comparing(InferenceService.DefaultConfigId::inferenceId)); + return authorizedConfigIds; + } + + private List getAuthorizedDefaultModelsObjects(Set authorizedDefaultModelIds) { + var authorizedModels = new ArrayList(); + for (var id : authorizedDefaultModelIds) { + var modelConfig = defaultModelsConfigs.get(id); + if (modelConfig != null) { + authorizedModels.add(modelConfig); + } + } + + authorizedModels.sort(Comparator.comparing(modelConfig -> modelConfig.model().getInferenceEntityId())); + return authorizedModels; + } + + private void handleRevokedDefaultConfigs(Set authorizedDefaultModelIds) { + // if a model was initially returned in the authorization response but is absent, then we'll assume authorization was revoked + var unauthorizedDefaultModelIds = new HashSet<>(defaultModelsConfigs.keySet()); + unauthorizedDefaultModelIds.removeAll(authorizedDefaultModelIds); + + // get all the default inference endpoint ids for the unauthorized model ids + var unauthorizedDefaultInferenceEndpointIds = unauthorizedDefaultModelIds.stream() + .map(defaultModelsConfigs::get) // get all the model configs + .filter(Objects::nonNull) // limit to only non-null + .map(modelConfig -> modelConfig.model().getInferenceEntityId()) // get the inference ids + .collect(Collectors.toSet()); + + var deleteInferenceEndpointsListener = ActionListener.wrap(result -> { + logger.debug(Strings.format("Successfully revoked access to default inference endpoint IDs: %s", unauthorizedDefaultModelIds)); + firstAuthorizationCompletedLatch.countDown(); + }, e -> { + logger.warn( + Strings.format("Failed to revoke access to default inference endpoint IDs: %s, error: %s", unauthorizedDefaultModelIds, e) + ); + firstAuthorizationCompletedLatch.countDown(); + }); + + logger.debug("Synchronizing default inference endpoints"); + modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorization.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModel.java similarity index 63% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorization.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModel.java index 76721bb6dcd7b..6ff3cb950151e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorization.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModel.java @@ -19,21 +19,21 @@ import java.util.stream.Collectors; /** - * This is a helper class for managing the response from {@link ElasticInferenceServiceAuthorizationHandler}. + * Transforms the response from {@link ElasticInferenceServiceAuthorizationRequestHandler} into a format for consumption by the service. */ -public class ElasticInferenceServiceAuthorization { +public class ElasticInferenceServiceAuthorizationModel { private final Map> taskTypeToModels; private final EnumSet authorizedTaskTypes; private final Set authorizedModelIds; /** - * Converts an authorization response from Elastic Inference Service into the {@link ElasticInferenceServiceAuthorization} format. + * Converts an authorization response from Elastic Inference Service into the {@link ElasticInferenceServiceAuthorizationModel} format. * * @param responseEntity the {@link ElasticInferenceServiceAuthorizationResponseEntity} response from the upstream gateway. - * @return a new {@link ElasticInferenceServiceAuthorization} + * @return a new {@link ElasticInferenceServiceAuthorizationModel} */ - public static ElasticInferenceServiceAuthorization of(ElasticInferenceServiceAuthorizationResponseEntity responseEntity) { + public static ElasticInferenceServiceAuthorizationModel of(ElasticInferenceServiceAuthorizationResponseEntity responseEntity) { var taskTypeToModelsMap = new HashMap>(); var enabledTaskTypesSet = EnumSet.noneOf(TaskType.class); var enabledModelsSet = new HashSet(); @@ -54,17 +54,17 @@ public static ElasticInferenceServiceAuthorization of(ElasticInferenceServiceAut } } - return new ElasticInferenceServiceAuthorization(taskTypeToModelsMap, enabledModelsSet, enabledTaskTypesSet); + return new ElasticInferenceServiceAuthorizationModel(taskTypeToModelsMap, enabledModelsSet, enabledTaskTypesSet); } /** * Returns an object indicating that the cluster has no access to Elastic Inference Service. */ - public static ElasticInferenceServiceAuthorization newDisabledService() { - return new ElasticInferenceServiceAuthorization(Map.of(), Set.of(), EnumSet.noneOf(TaskType.class)); + public static ElasticInferenceServiceAuthorizationModel newDisabledService() { + return new ElasticInferenceServiceAuthorizationModel(Map.of(), Set.of(), EnumSet.noneOf(TaskType.class)); } - private ElasticInferenceServiceAuthorization( + private ElasticInferenceServiceAuthorizationModel( Map> taskTypeToModels, Set authorizedModelIds, EnumSet authorizedTaskTypes @@ -91,13 +91,13 @@ public EnumSet getAuthorizedTaskTypes() { } /** - * Returns a new {@link ElasticInferenceServiceAuthorization} object retaining only the specified task types + * Returns a new {@link ElasticInferenceServiceAuthorizationModel} object retaining only the specified task types * and applicable models that leverage those task types. Any task types not specified in the passed in set will be * excluded from the returned object. This is essentially an intersection. * @param taskTypes the task types to retain in the newly created object * @return a new object containing models and task types limited to the specified set. */ - public ElasticInferenceServiceAuthorization newLimitedToTaskTypes(EnumSet taskTypes) { + public ElasticInferenceServiceAuthorizationModel newLimitedToTaskTypes(EnumSet taskTypes) { var newTaskTypeToModels = new HashMap>(); var taskTypesThatHaveModels = EnumSet.noneOf(TaskType.class); @@ -110,15 +110,48 @@ public ElasticInferenceServiceAuthorization newLimitedToTaskTypes(EnumSet newEnabledModels = newTaskTypeToModels.values().stream().flatMap(Set::stream).collect(Collectors.toSet()); + return new ElasticInferenceServiceAuthorizationModel( + newTaskTypeToModels, + enabledModels(newTaskTypeToModels), + taskTypesThatHaveModels + ); + } + + private static Set enabledModels(Map> taskTypeToModels) { + return taskTypeToModels.values().stream().flatMap(Set::stream).collect(Collectors.toSet()); + } + + /** + * Returns a new {@link ElasticInferenceServiceAuthorizationModel} that combines the current model and the passed in one. + * @param other model to merge into this one + * @return a new model + */ + public ElasticInferenceServiceAuthorizationModel merge(ElasticInferenceServiceAuthorizationModel other) { + Map> newTaskTypeToModels = taskTypeToModels.entrySet() + .stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> new HashSet<>(e.getValue()))); + + for (var entry : other.taskTypeToModels.entrySet()) { + newTaskTypeToModels.merge(entry.getKey(), new HashSet<>(entry.getValue()), (existingModelIds, newModelIds) -> { + existingModelIds.addAll(newModelIds); + return existingModelIds; + }); + } + + var newAuthorizedTaskTypes = authorizedTaskTypes.isEmpty() ? EnumSet.noneOf(TaskType.class) : EnumSet.copyOf(authorizedTaskTypes); + newAuthorizedTaskTypes.addAll(other.authorizedTaskTypes); - return new ElasticInferenceServiceAuthorization(newTaskTypeToModels, newEnabledModels, taskTypesThatHaveModels); + return new ElasticInferenceServiceAuthorizationModel( + newTaskTypeToModels, + enabledModels(newTaskTypeToModels), + newAuthorizedTaskTypes + ); } @Override public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; - ElasticInferenceServiceAuthorization that = (ElasticInferenceServiceAuthorization) o; + ElasticInferenceServiceAuthorizationModel that = (ElasticInferenceServiceAuthorizationModel) o; return Objects.equals(taskTypeToModels, that.taskTypeToModels) && Objects.equals(authorizedTaskTypes, that.authorizedTaskTypes) && Objects.equals(authorizedModelIds, that.authorizedModelIds); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java new file mode 100644 index 0000000000000..eb92d1b48f8a7 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java @@ -0,0 +1,138 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.authorization; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchWrapperException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.elastic.ElasticInferenceServiceResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceAuthorizationRequest; +import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceAuthorizationResponseEntity; +import org.elasticsearch.xpack.inference.telemetry.TraceContext; + +import java.util.Locale; +import java.util.Objects; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER; + +/** + * Handles retrieving the authorization information from Elastic Inference Service. + */ +public class ElasticInferenceServiceAuthorizationRequestHandler { + + private static final String FAILED_TO_RETRIEVE_MESSAGE = + "Failed to retrieve the authorization information from the Elastic Inference Service."; + private static final TimeValue DEFAULT_AUTH_TIMEOUT = TimeValue.timeValueMinutes(1); + private static final ResponseHandler AUTH_RESPONSE_HANDLER = createAuthResponseHandler(); + + private static ResponseHandler createAuthResponseHandler() { + return new ElasticInferenceServiceResponseHandler( + String.format(Locale.ROOT, "%s sparse embeddings", ELASTIC_INFERENCE_SERVICE_IDENTIFIER), + ElasticInferenceServiceAuthorizationResponseEntity::fromResponse + ); + } + + private final String baseUrl; + private final ThreadPool threadPool; + private final Logger logger; + private final CountDownLatch requestCompleteLatch = new CountDownLatch(1); + + public ElasticInferenceServiceAuthorizationRequestHandler(@Nullable String baseUrl, ThreadPool threadPool) { + this.baseUrl = baseUrl; + this.threadPool = Objects.requireNonNull(threadPool); + logger = LogManager.getLogger(ElasticInferenceServiceAuthorizationRequestHandler.class); + } + + // only use for testing + ElasticInferenceServiceAuthorizationRequestHandler(@Nullable String baseUrl, ThreadPool threadPool, Logger logger) { + this.baseUrl = baseUrl; + this.threadPool = Objects.requireNonNull(threadPool); + this.logger = Objects.requireNonNull(logger); + } + + /** + * Retrieve the authorization information from Elastic Inference Service + * @param listener a listener to receive the response + * @param sender a {@link Sender} for making the request to the Elastic Inference Service + */ + public void getAuthorization(ActionListener listener, Sender sender) { + try { + logger.debug("Retrieving authorization information from the Elastic Inference Service."); + + if (Strings.isNullOrEmpty(baseUrl)) { + logger.debug("The base URL for the authorization service is not valid, rejecting authorization."); + listener.onResponse(ElasticInferenceServiceAuthorizationModel.newDisabledService()); + return; + } + + // ensure that the sender is initialized + sender.start(); + + ActionListener newListener = ActionListener.wrap(results -> { + if (results instanceof ElasticInferenceServiceAuthorizationResponseEntity authResponseEntity) { + listener.onResponse(ElasticInferenceServiceAuthorizationModel.of(authResponseEntity)); + } else { + logger.warn( + Strings.format( + FAILED_TO_RETRIEVE_MESSAGE + " Received an invalid response type: %s", + results.getClass().getSimpleName() + ) + ); + listener.onResponse(ElasticInferenceServiceAuthorizationModel.newDisabledService()); + } + requestCompleteLatch.countDown(); + }, e -> { + Throwable exception = e; + if (e instanceof ElasticsearchWrapperException wrapperException) { + exception = wrapperException.getCause(); + } + + logger.warn(Strings.format(FAILED_TO_RETRIEVE_MESSAGE + " Encountered an exception: %s", exception)); + listener.onResponse(ElasticInferenceServiceAuthorizationModel.newDisabledService()); + requestCompleteLatch.countDown(); + }); + + var productOrigin = threadPool.getThreadContext().getHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER); + var request = new ElasticInferenceServiceAuthorizationRequest(baseUrl, getCurrentTraceInfo(), productOrigin); + + sender.sendWithoutQueuing(logger, request, AUTH_RESPONSE_HANDLER, DEFAULT_AUTH_TIMEOUT, newListener); + } catch (Exception e) { + logger.warn(Strings.format("Retrieving the authorization information encountered an exception: %s", e)); + requestCompleteLatch.countDown(); + } + } + + private TraceContext getCurrentTraceInfo() { + var traceParent = threadPool.getThreadContext().getHeader(Task.TRACE_PARENT); + var traceState = threadPool.getThreadContext().getHeader(Task.TRACE_STATE); + + return new TraceContext(traceParent, traceState); + } + + // Default because should only be used for testing + void waitForAuthRequestCompletion(TimeValue timeValue) throws IllegalStateException { + try { + if (requestCompleteLatch.await(timeValue.getMillis(), TimeUnit.MILLISECONDS) == false) { + throw new IllegalStateException("The wait time has expired for authorization to complete."); + } + } catch (InterruptedException e) { + throw new IllegalStateException("Waiting for authorization to complete was interrupted"); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java index 1260b89034e6b..85300f24deea4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java @@ -34,6 +34,7 @@ import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension; import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; import org.hamcrest.Matchers; import java.io.IOException; @@ -59,6 +60,8 @@ public final class Utils { + public static final TimeValue TIMEOUT = TimeValue.timeValueSeconds(30); + private Utils() { throw new UnsupportedOperationException("Utils is a utility class and should not be instantiated"); } @@ -76,7 +79,8 @@ public static ClusterService mockClusterService(Settings settings) { ThrottlerManager.getSettingsDefinitions(), RetrySettings.getSettingsDefinitions(), Truncator.getSettingsDefinitions(), - RequestExecutorServiceSettings.getSettingsDefinitions() + RequestExecutorServiceSettings.getSettingsDefinitions(), + ElasticInferenceServiceSettings.getSettingsDefinitions() ).flatMap(Collection::stream).collect(Collectors.toSet()); var cSettings = new ClusterSettings(settings, registeredSettings); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettingsTests.java index e477ffb10def0..2616393ac8442 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettingsTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.elastic; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.test.ESTestCase; import static org.hamcrest.Matchers.equalTo; @@ -17,6 +18,30 @@ public class ElasticInferenceServiceSettingsTests extends ESTestCase { private static final String ELASTIC_INFERENCE_SERVICE_URL = "http://elastic-inference-service"; private static final String ELASTIC_INFERENCE_SERVICE_LEGACY_URL = "http://elastic-inference-service-legacy"; + public static ElasticInferenceServiceSettings create(String elasticInferenceServiceUrl) { + var settings = Settings.builder() + .put(ElasticInferenceServiceSettings.ELASTIC_INFERENCE_SERVICE_URL.getKey(), elasticInferenceServiceUrl) + .build(); + + return new ElasticInferenceServiceSettings(settings); + } + + public static ElasticInferenceServiceSettings create( + String elasticInferenceServiceUrl, + TimeValue authorizationRequestInterval, + TimeValue maxJitter, + boolean periodicAuthorizationEnabled + ) { + var settings = Settings.builder() + .put(ElasticInferenceServiceSettings.ELASTIC_INFERENCE_SERVICE_URL.getKey(), elasticInferenceServiceUrl) + .put(ElasticInferenceServiceSettings.AUTHORIZATION_REQUEST_INTERVAL.getKey(), authorizationRequestInterval) + .put(ElasticInferenceServiceSettings.MAX_AUTHORIZATION_REQUEST_JITTER.getKey(), maxJitter) + .put(ElasticInferenceServiceSettings.PERIODIC_AUTHORIZATION_ENABLED.getKey(), periodicAuthorizationEnabled) + .build(); + + return new ElasticInferenceServiceSettings(settings); + } + public void testGetElasticInferenceServiceUrl_WithUrlSetting() { var settings = Settings.builder() .put(ElasticInferenceServiceSettings.ELASTIC_INFERENCE_SERVICE_URL.getKey(), ELASTIC_INFERENCE_SERVICE_URL) @@ -53,5 +78,4 @@ public void testGetElasticInferenceServiceUrl_WithoutUrlSetting() { assertThat(eisSettings.getElasticInferenceServiceUrl(), equalTo("")); } - } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java index 1b4cd026b816f..4bd673e856123 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java @@ -26,7 +26,7 @@ public static ElasticInferenceServiceSparseEmbeddingsModel createModel(String ur new ElasticInferenceServiceSparseEmbeddingsServiceSettings(modelId, maxInputTokens, null), EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, - ElasticInferenceServiceComponents.withNoRevokeDelay(url) + ElasticInferenceServiceComponents.of(url) ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index b3f8579903885..b2ff028750e24 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -50,9 +50,9 @@ import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests; import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; import org.elasticsearch.xpack.inference.services.ServiceFields; -import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorization; -import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationHandler; -import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationTests; +import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationModel; +import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationModelTests; +import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels; @@ -379,6 +379,10 @@ public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException } private ModelRegistry mockModelRegistry() { + return mockModelRegistry(threadPool); + } + + public static ModelRegistry mockModelRegistry(ThreadPool threadPool) { var client = mock(Client.class); when(client.threadPool()).thenReturn(threadPool); @@ -582,7 +586,7 @@ public void testChunkedInfer_PassesThrough() throws IOException { } public void testHideFromConfigurationApi_ReturnsTrue_WithNoAvailableModels() throws Exception { - try (var service = createServiceWithMockSender(ElasticInferenceServiceAuthorization.newDisabledService())) { + try (var service = createServiceWithMockSender(ElasticInferenceServiceAuthorizationModel.newDisabledService())) { ensureAuthorizationCallFinished(service); assertTrue(service.hideFromConfigurationApi()); @@ -592,7 +596,7 @@ public void testHideFromConfigurationApi_ReturnsTrue_WithNoAvailableModels() thr public void testHideFromConfigurationApi_ReturnsTrue_WithModelTaskTypesThatAreNotImplemented() throws Exception { try ( var service = createServiceWithMockSender( - ElasticInferenceServiceAuthorization.of( + ElasticInferenceServiceAuthorizationModel.of( new ElasticInferenceServiceAuthorizationResponseEntity( List.of( new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( @@ -613,7 +617,7 @@ public void testHideFromConfigurationApi_ReturnsTrue_WithModelTaskTypesThatAreNo public void testHideFromConfigurationApi_ReturnsFalse_WithAvailableModels() throws Exception { try ( var service = createServiceWithMockSender( - ElasticInferenceServiceAuthorization.of( + ElasticInferenceServiceAuthorizationModel.of( new ElasticInferenceServiceAuthorizationResponseEntity( List.of( new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( @@ -634,7 +638,7 @@ public void testHideFromConfigurationApi_ReturnsFalse_WithAvailableModels() thro public void testGetConfiguration() throws Exception { try ( var service = createServiceWithMockSender( - ElasticInferenceServiceAuthorization.of( + ElasticInferenceServiceAuthorizationModel.of( new ElasticInferenceServiceAuthorizationResponseEntity( List.of( new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( @@ -700,7 +704,7 @@ public void testGetConfiguration() throws Exception { } public void testGetConfiguration_WithoutSupportedTaskTypes() throws Exception { - try (var service = createServiceWithMockSender(ElasticInferenceServiceAuthorization.newDisabledService())) { + try (var service = createServiceWithMockSender(ElasticInferenceServiceAuthorizationModel.newDisabledService())) { ensureAuthorizationCallFinished(service); String content = XContentHelper.stripWhitespace(""" @@ -758,7 +762,7 @@ public void testGetConfiguration_WithoutSupportedTaskTypes_WhenModelsReturnTaskO try ( var service = createServiceWithMockSender( // this service doesn't yet support text embedding so we should still have no task types - ElasticInferenceServiceAuthorization.of( + ElasticInferenceServiceAuthorizationModel.of( new ElasticInferenceServiceAuthorizationResponseEntity( List.of( new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( @@ -1061,7 +1065,7 @@ private void testUnifiedStreamError(int responseCode, String responseJson, Strin new ElasticInferenceServiceCompletionServiceSettings("model_id", new RateLimitSettings(100)), EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, - ElasticInferenceServiceComponents.withNoRevokeDelay(eisGatewayUrl) + ElasticInferenceServiceComponents.of(eisGatewayUrl) ); PlainActionFuture listener = new PlainActionFuture<>(); service.unifiedCompletionInfer( @@ -1096,46 +1100,50 @@ private void testUnifiedStreamError(int responseCode, String responseJson, Strin private void ensureAuthorizationCallFinished(ElasticInferenceService service) { service.onNodeStarted(); - service.waitForAuthorizationToComplete(TIMEOUT); + service.waitForFirstAuthorizationToComplete(TIMEOUT); } private ElasticInferenceService createServiceWithMockSender() { - return createServiceWithMockSender(ElasticInferenceServiceAuthorizationTests.createEnabledAuth()); + return createServiceWithMockSender(ElasticInferenceServiceAuthorizationModelTests.createEnabledAuth()); } - private ElasticInferenceService createServiceWithMockSender(ElasticInferenceServiceAuthorization auth) { - var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationHandler.class); + private ElasticInferenceService createServiceWithMockSender(ElasticInferenceServiceAuthorizationModel auth) { + var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); + ActionListener listener = invocation.getArgument(0); listener.onResponse(auth); return Void.TYPE; }).when(mockAuthHandler).getAuthorization(any(), any()); + var sender = mock(Sender.class); + + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); return new ElasticInferenceService( - mock(HttpRequestSender.Factory.class), + factory, createWithEmptySettings(threadPool), - ElasticInferenceServiceComponents.EMPTY_INSTANCE, + new ElasticInferenceServiceSettings(Settings.EMPTY), mockModelRegistry(), mockAuthHandler ); } private ElasticInferenceService createService(HttpRequestSender.Factory senderFactory) { - return createService(senderFactory, ElasticInferenceServiceAuthorizationTests.createEnabledAuth(), null); + return createService(senderFactory, ElasticInferenceServiceAuthorizationModelTests.createEnabledAuth(), null); } private ElasticInferenceService createService(HttpRequestSender.Factory senderFactory, String gatewayUrl) { - return createService(senderFactory, ElasticInferenceServiceAuthorizationTests.createEnabledAuth(), gatewayUrl); + return createService(senderFactory, ElasticInferenceServiceAuthorizationModelTests.createEnabledAuth(), gatewayUrl); } private ElasticInferenceService createService( HttpRequestSender.Factory senderFactory, - ElasticInferenceServiceAuthorization auth, + ElasticInferenceServiceAuthorizationModel auth, String gatewayUrl ) { - var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationHandler.class); + var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); + ActionListener listener = invocation.getArgument(0); listener.onResponse(auth); return Void.TYPE; }).when(mockAuthHandler).getAuthorization(any(), any()); @@ -1143,7 +1151,7 @@ private ElasticInferenceService createService( return new ElasticInferenceService( senderFactory, createWithEmptySettings(threadPool), - ElasticInferenceServiceComponents.withNoRevokeDelay(gatewayUrl), + ElasticInferenceServiceSettingsTests.create(gatewayUrl), mockModelRegistry(), mockAuthHandler ); @@ -1153,9 +1161,23 @@ private ElasticInferenceService createServiceWithAuthHandler(HttpRequestSender.F return new ElasticInferenceService( senderFactory, createWithEmptySettings(threadPool), - ElasticInferenceServiceComponents.withNoRevokeDelay(eisGatewayUrl), + ElasticInferenceServiceSettingsTests.create(eisGatewayUrl), mockModelRegistry(), - new ElasticInferenceServiceAuthorizationHandler(eisGatewayUrl, threadPool) + new ElasticInferenceServiceAuthorizationRequestHandler(eisGatewayUrl, threadPool) + ); + } + + public static ElasticInferenceService createServiceWithAuthHandler( + HttpRequestSender.Factory senderFactory, + String eisGatewayUrl, + ThreadPool threadPool + ) { + return new ElasticInferenceService( + senderFactory, + createWithEmptySettings(threadPool), + ElasticInferenceServiceSettingsTests.create(eisGatewayUrl), + mockModelRegistry(threadPool), + new ElasticInferenceServiceAuthorizationRequestHandler(eisGatewayUrl, threadPool) ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java index a87c3f814b7e1..94a42d2e46fa9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java @@ -7,266 +7,180 @@ package org.elasticsearch.xpack.inference.services.elastic.authorization; -import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; -import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.MinimalServiceSettings; +import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.test.http.MockResponse; -import org.elasticsearch.test.http.MockWebServer; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; -import org.elasticsearch.xpack.inference.external.http.HttpClientManager; -import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; -import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.Utils; import org.elasticsearch.xpack.inference.external.http.sender.Sender; -import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.junit.After; +import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceAuthorizationResponseEntity; +import org.elasticsearch.xpack.inference.services.elastic.DefaultModelConfig; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; import org.junit.Before; -import org.mockito.ArgumentCaptor; import java.io.IOException; import java.util.EnumSet; import java.util.List; -import java.util.Set; +import java.util.Map; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; -import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; -import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; -import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; -import static org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender.MAX_RETIES; -import static org.hamcrest.Matchers.is; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.defaultEndpointId; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceTests.mockModelRegistry; +import static org.hamcrest.CoreMatchers.is; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.when; public class ElasticInferenceServiceAuthorizationHandlerTests extends ESTestCase { - private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); - private final MockWebServer webServer = new MockWebServer(); - private ThreadPool threadPool; - - private HttpClientManager clientManager; + private DeterministicTaskQueue taskQueue; @Before public void init() throws Exception { - webServer.start(); - threadPool = createThreadPool(inferenceUtilityPool()); - clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); - } - - @After - public void shutdown() throws IOException { - clientManager.close(); - terminate(threadPool); - webServer.close(); - } - - public void testDoesNotAttempt_ToRetrieveAuthorization_IfBaseUrlIsNull() throws Exception { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var logger = mock(Logger.class); - var authHandler = new ElasticInferenceServiceAuthorizationHandler(null, threadPool, logger); - - try (var sender = senderFactory.createSender()) { - PlainActionFuture listener = new PlainActionFuture<>(); - authHandler.getAuthorization(listener, sender); - - var authResponse = listener.actionGet(TIMEOUT); - assertTrue(authResponse.getAuthorizedTaskTypes().isEmpty()); - assertTrue(authResponse.getAuthorizedModelIds().isEmpty()); - assertFalse(authResponse.isAuthorized()); - - var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); - verify(logger, times(2)).debug(loggerArgsCaptor.capture()); - var messages = loggerArgsCaptor.getAllValues(); - assertThat(messages.getFirst(), is("Retrieving authorization information from the Elastic Inference Service.")); - assertThat(messages.get(1), is("The base URL for the authorization service is not valid, rejecting authorization.")); - } + taskQueue = new DeterministicTaskQueue(); } - public void testDoesNotAttempt_ToRetrieveAuthorization_IfBaseUrlIsEmpty() throws Exception { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var logger = mock(Logger.class); - var authHandler = new ElasticInferenceServiceAuthorizationHandler("", threadPool, logger); - - try (var sender = senderFactory.createSender()) { - PlainActionFuture listener = new PlainActionFuture<>(); - authHandler.getAuthorization(listener, sender); - - var authResponse = listener.actionGet(TIMEOUT); - assertTrue(authResponse.getAuthorizedTaskTypes().isEmpty()); - assertTrue(authResponse.getAuthorizedModelIds().isEmpty()); - assertFalse(authResponse.isAuthorized()); - - var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); - verify(logger, times(2)).debug(loggerArgsCaptor.capture()); - var messages = loggerArgsCaptor.getAllValues(); - assertThat(messages.getFirst(), is("Retrieving authorization information from the Elastic Inference Service.")); - assertThat(messages.get(1), is("The base URL for the authorization service is not valid, rejecting authorization.")); - } - } - - public void testGetAuthorization_FailsWhenAnInvalidFieldIsFound() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var eisGatewayUrl = getUrl(webServer); - var logger = mock(Logger.class); - var authHandler = new ElasticInferenceServiceAuthorizationHandler(eisGatewayUrl, threadPool, logger); - - try (var sender = senderFactory.createSender()) { - String responseJson = """ - { - "models": [ - { - "invalid-field": "model-a", - "task-types": ["embed/text/sparse", "chat"] - } - ] + public void testSendsAnAuthorizationRequestTwice() throws Exception { + var callbackCount = new AtomicInteger(0); + // we're only interested in two authorization calls which is why I'm using a value of 2 here + var latch = new CountDownLatch(2); + final AtomicReference handlerRef = new AtomicReference<>(); + + Runnable callback = () -> { + // the first authorization response does not contain a streaming task so we're expecting to not support streaming here + if (callbackCount.incrementAndGet() == 1) { + assertThat(handlerRef.get().supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class))); + } + latch.countDown(); + + // we only want to run the tasks twice, so advance the time on the queue + // which flags the scheduled authorization request to be ready to run + if (callbackCount.get() == 1) { + taskQueue.advanceTime(); + } else { + try { + handlerRef.get().close(); + } catch (IOException e) { + // ignore } - """; - - queueWebServerResponsesForRetries(responseJson); - - PlainActionFuture listener = new PlainActionFuture<>(); - authHandler.getAuthorization(listener, sender); - - var authResponse = listener.actionGet(TIMEOUT); - assertTrue(authResponse.getAuthorizedTaskTypes().isEmpty()); - assertTrue(authResponse.getAuthorizedModelIds().isEmpty()); - assertFalse(authResponse.isAuthorized()); - - var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); - verify(logger).warn(loggerArgsCaptor.capture()); - var message = loggerArgsCaptor.getValue(); - assertThat( - message, - is( - "Failed to retrieve the authorization information from the Elastic Inference Service." - + " Encountered an exception: org.elasticsearch.xcontent.XContentParseException: [4:28] " - + "[ElasticInferenceServiceAuthorizationResponseEntity] failed to parse field [models]" + } + }; + + var requestHandler = mockAuthorizationRequestHandler( + ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel("abc", EnumSet.of(TaskType.SPARSE_EMBEDDING)) + ) ) - ); - } - } - - /** - * Queues the required number of responses to handle the retries of the internal sender. - */ - private void queueWebServerResponsesForRetries(String responseJson) { - for (int i = 0; i < MAX_RETIES; i++) { - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - } - } - - public void testGetAuthorization_ReturnsAValidResponse() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var eisGatewayUrl = getUrl(webServer); - var logger = mock(Logger.class); - var authHandler = new ElasticInferenceServiceAuthorizationHandler(eisGatewayUrl, threadPool, logger); - - try (var sender = senderFactory.createSender()) { - String responseJson = """ - { - "models": [ - { - "model_name": "model-a", - "task_types": ["embed/text/sparse", "chat"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - PlainActionFuture listener = new PlainActionFuture<>(); - authHandler.getAuthorization(listener, sender); - - var authResponse = listener.actionGet(TIMEOUT); - assertThat(authResponse.getAuthorizedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))); - assertThat(authResponse.getAuthorizedModelIds(), is(Set.of("model-a"))); - assertTrue(authResponse.isAuthorized()); - - var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); - verify(logger, times(1)).debug(loggerArgsCaptor.capture()); - - var message = loggerArgsCaptor.getValue(); - assertThat(message, is("Retrieving authorization information from the Elastic Inference Service.")); - verifyNoMoreInteractions(logger); - } - } - - @SuppressWarnings("unchecked") - public void testGetAuthorization_OnResponseCalledOnce() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var eisGatewayUrl = getUrl(webServer); - var logger = mock(Logger.class); - var authHandler = new ElasticInferenceServiceAuthorizationHandler(eisGatewayUrl, threadPool, logger); - - ActionListener listener = mock(ActionListener.class); - String responseJson = """ - { - "models": [ - { - "model_name": "model-a", - "task_types": ["embed/text/sparse", "chat"] - } - ] - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - try (var sender = senderFactory.createSender()) { - authHandler.getAuthorization(listener, sender); - authHandler.waitForAuthRequestCompletion(TIMEOUT); - - verify(listener, times(1)).onResponse(any()); - var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); - verify(logger, times(1)).debug(loggerArgsCaptor.capture()); - - var message = loggerArgsCaptor.getValue(); - assertThat(message, is("Retrieving authorization information from the Elastic Inference Service.")); - verifyNoMoreInteractions(logger); - } + ), + ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "rainbow-sprinkles", + EnumSet.of(TaskType.CHAT_COMPLETION) + ) + ) + ) + ) + ); + + handlerRef.set( + new ElasticInferenceServiceAuthorizationHandler( + createWithEmptySettings(taskQueue.getThreadPool()), + mockModelRegistry(taskQueue.getThreadPool()), + requestHandler, + initDefaultEndpoints(), + EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION), + null, + mock(Sender.class), + ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), + callback + ) + ); + + var handler = handlerRef.get(); + handler.init(); + taskQueue.runAllRunnableTasks(); + latch.await(Utils.TIMEOUT.getSeconds(), TimeUnit.SECONDS); + // this should be after we've received both authorization responses + + assertThat(handler.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); + assertThat( + handler.defaultConfigIds(), + is(List.of(new InferenceService.DefaultConfigId(".rainbow-sprinkles-elastic", MinimalServiceSettings.chatCompletion(), null))) + ); + assertThat(handler.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))); + + PlainActionFuture> listener = new PlainActionFuture<>(); + handler.defaultConfigs(listener); + + var configs = listener.actionGet(); + assertThat(configs.get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); } - public void testGetAuthorization_InvalidResponse() throws IOException { - var senderMock = mock(Sender.class); - var senderFactory = mock(HttpRequestSender.Factory.class); - when(senderFactory.createSender()).thenReturn(senderMock); - - doAnswer(invocationOnMock -> { - ActionListener listener = invocationOnMock.getArgument(4); - listener.onResponse(new ChatCompletionResults(List.of(new ChatCompletionResults.Result("awesome")))); + private static ElasticInferenceServiceAuthorizationRequestHandler mockAuthorizationRequestHandler( + ElasticInferenceServiceAuthorizationModel firstAuthResponse, + ElasticInferenceServiceAuthorizationModel secondAuthResponse + ) { + var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(firstAuthResponse); return Void.TYPE; - }).when(senderMock).sendWithoutQueuing(any(), any(), any(), any(), any()); - - var logger = mock(Logger.class); - var authHandler = new ElasticInferenceServiceAuthorizationHandler("abc", threadPool, logger); - - try (var sender = senderFactory.createSender()) { - PlainActionFuture listener = new PlainActionFuture<>(); - - authHandler.getAuthorization(listener, sender); - var result = listener.actionGet(TIMEOUT); - - assertThat(result, is(ElasticInferenceServiceAuthorization.newDisabledService())); + }).doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(secondAuthResponse); + return Void.TYPE; + }).when(mockAuthHandler).getAuthorization(any(), any()); - var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); - verify(logger).warn(loggerArgsCaptor.capture()); - var message = loggerArgsCaptor.getValue(); - assertThat( - message, - is( - "Failed to retrieve the authorization information from the Elastic Inference Service." - + " Received an invalid response type: ChatCompletionResults" - ) - ); - } + return mockAuthHandler; + } + private static Map initDefaultEndpoints() { + return Map.of( + "rainbow-sprinkles", + new DefaultModelConfig( + new ElasticInferenceServiceCompletionModel( + defaultEndpointId("rainbow-sprinkles"), + TaskType.CHAT_COMPLETION, + "test", + new ElasticInferenceServiceCompletionServiceSettings("rainbow-sprinkles", null), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + ElasticInferenceServiceComponents.EMPTY_INSTANCE + ), + MinimalServiceSettings.chatCompletion() + ), + "elser-v2", + new DefaultModelConfig( + new ElasticInferenceServiceSparseEmbeddingsModel( + defaultEndpointId("elser-v2"), + TaskType.SPARSE_EMBEDDING, + "test", + new ElasticInferenceServiceSparseEmbeddingsServiceSettings("elser-v2", null, null), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + ElasticInferenceServiceComponents.EMPTY_INSTANCE + ), + MinimalServiceSettings.sparseEmbedding() + ) + ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModelTests.java similarity index 53% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModelTests.java index 559de47232a7b..6db9238ab65a4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModelTests.java @@ -17,9 +17,9 @@ import static org.hamcrest.Matchers.is; -public class ElasticInferenceServiceAuthorizationTests extends ESTestCase { - public static ElasticInferenceServiceAuthorization createEnabledAuth() { - return ElasticInferenceServiceAuthorization.of( +public class ElasticInferenceServiceAuthorizationModelTests extends ESTestCase { + public static ElasticInferenceServiceAuthorizationModel createEnabledAuth() { + return ElasticInferenceServiceAuthorizationModel.of( new ElasticInferenceServiceAuthorizationResponseEntity( List.of( new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel("model-1", EnumSet.of(TaskType.TEXT_EMBEDDING)) @@ -29,20 +29,20 @@ public static ElasticInferenceServiceAuthorization createEnabledAuth() { } public void testIsAuthorized_ReturnsFalse_WithEmptyMap() { - assertFalse(ElasticInferenceServiceAuthorization.newDisabledService().isAuthorized()); + assertFalse(ElasticInferenceServiceAuthorizationModel.newDisabledService().isAuthorized()); } public void testExcludes_ModelsWithoutTaskTypes() { var response = new ElasticInferenceServiceAuthorizationResponseEntity( List.of(new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel("model-1", EnumSet.noneOf(TaskType.class))) ); - var auth = ElasticInferenceServiceAuthorization.of(response); + var auth = ElasticInferenceServiceAuthorizationModel.of(response); assertTrue(auth.getAuthorizedTaskTypes().isEmpty()); assertFalse(auth.isAuthorized()); } public void testEnabledTaskTypes_MergesFromSeparateModels() { - var auth = ElasticInferenceServiceAuthorization.of( + var auth = ElasticInferenceServiceAuthorizationModel.of( new ElasticInferenceServiceAuthorizationResponseEntity( List.of( new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel("model-1", EnumSet.of(TaskType.TEXT_EMBEDDING)), @@ -55,7 +55,7 @@ public void testEnabledTaskTypes_MergesFromSeparateModels() { } public void testEnabledTaskTypes_FromSingleEntry() { - var auth = ElasticInferenceServiceAuthorization.of( + var auth = ElasticInferenceServiceAuthorizationModel.of( new ElasticInferenceServiceAuthorizationResponseEntity( List.of( new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( @@ -71,7 +71,7 @@ public void testEnabledTaskTypes_FromSingleEntry() { } public void testNewLimitToTaskTypes_SingleModel() { - var auth = ElasticInferenceServiceAuthorization.of( + var auth = ElasticInferenceServiceAuthorizationModel.of( new ElasticInferenceServiceAuthorizationResponseEntity( List.of( new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( @@ -86,7 +86,7 @@ public void testNewLimitToTaskTypes_SingleModel() { assertThat( auth.newLimitedToTaskTypes(EnumSet.of(TaskType.TEXT_EMBEDDING)), is( - ElasticInferenceServiceAuthorization.of( + ElasticInferenceServiceAuthorizationModel.of( new ElasticInferenceServiceAuthorizationResponseEntity( List.of( new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( @@ -101,7 +101,7 @@ public void testNewLimitToTaskTypes_SingleModel() { } public void testNewLimitToTaskTypes_MultipleModels_OnlyTextEmbedding() { - var auth = ElasticInferenceServiceAuthorization.of( + var auth = ElasticInferenceServiceAuthorizationModel.of( new ElasticInferenceServiceAuthorizationResponseEntity( List.of( new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( @@ -116,7 +116,7 @@ public void testNewLimitToTaskTypes_MultipleModels_OnlyTextEmbedding() { assertThat( auth.newLimitedToTaskTypes(EnumSet.of(TaskType.TEXT_EMBEDDING)), is( - ElasticInferenceServiceAuthorization.of( + ElasticInferenceServiceAuthorizationModel.of( new ElasticInferenceServiceAuthorizationResponseEntity( List.of( new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( @@ -135,7 +135,7 @@ public void testNewLimitToTaskTypes_MultipleModels_OnlyTextEmbedding() { } public void testNewLimitToTaskTypes_MultipleModels_MultipleTaskTypes() { - var auth = ElasticInferenceServiceAuthorization.of( + var auth = ElasticInferenceServiceAuthorizationModel.of( new ElasticInferenceServiceAuthorizationResponseEntity( List.of( new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( @@ -154,11 +154,11 @@ public void testNewLimitToTaskTypes_MultipleModels_MultipleTaskTypes() { ) ); - var a = auth.newLimitedToTaskTypes(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.CHAT_COMPLETION)); + var limitedAuth = auth.newLimitedToTaskTypes(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.CHAT_COMPLETION)); assertThat( - a, + limitedAuth, is( - ElasticInferenceServiceAuthorization.of( + ElasticInferenceServiceAuthorizationModel.of( new ElasticInferenceServiceAuthorizationResponseEntity( List.of( new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( @@ -177,7 +177,7 @@ public void testNewLimitToTaskTypes_MultipleModels_MultipleTaskTypes() { } public void testNewLimitToTaskTypes_DuplicateModelNames() { - var auth = ElasticInferenceServiceAuthorization.of( + var auth = ElasticInferenceServiceAuthorizationModel.of( new ElasticInferenceServiceAuthorizationResponseEntity( List.of( new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( @@ -192,11 +192,11 @@ public void testNewLimitToTaskTypes_DuplicateModelNames() { ) ); - var a = auth.newLimitedToTaskTypes(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK)); + var limitedAuth = auth.newLimitedToTaskTypes(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK)); assertThat( - a, + limitedAuth, is( - ElasticInferenceServiceAuthorization.of( + ElasticInferenceServiceAuthorizationModel.of( new ElasticInferenceServiceAuthorizationResponseEntity( List.of( new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( @@ -211,7 +211,7 @@ public void testNewLimitToTaskTypes_DuplicateModelNames() { } public void testNewLimitToTaskTypes_ReturnsDisabled_WhenNoOverlapForTaskTypes() { - var auth = ElasticInferenceServiceAuthorization.of( + var auth = ElasticInferenceServiceAuthorizationModel.of( new ElasticInferenceServiceAuthorizationResponseEntity( List.of( new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( @@ -226,7 +226,160 @@ public void testNewLimitToTaskTypes_ReturnsDisabled_WhenNoOverlapForTaskTypes() ) ); - var a = auth.newLimitedToTaskTypes(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.RERANK)); - assertThat(a, is(ElasticInferenceServiceAuthorization.newDisabledService())); + var limitedAuth = auth.newLimitedToTaskTypes(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.RERANK)); + assertThat(limitedAuth, is(ElasticInferenceServiceAuthorizationModel.newDisabledService())); + } + + public void testMerge_CombinesCorrectly() { + var auth1 = ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-1", + EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) + ) + ) + ) + ); + + var auth2 = ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel("model-2", EnumSet.of(TaskType.SPARSE_EMBEDDING)) + ) + ) + ); + + assertThat( + auth1.merge(auth2), + is( + ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-1", + EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) + ), + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-2", + EnumSet.of(TaskType.SPARSE_EMBEDDING) + ) + ) + ) + ) + ) + ); + } + + public void testMerge_AddsNewTaskType() { + var auth1 = ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-1", + EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) + ) + ) + ) + ); + + var auth2 = ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel("model-2", EnumSet.of(TaskType.CHAT_COMPLETION)) + ) + ) + ); + + assertThat( + auth1.merge(auth2), + is( + ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-1", + EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) + ), + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-2", + EnumSet.of(TaskType.CHAT_COMPLETION) + ) + ) + ) + ) + ) + ); + } + + public void testMerge_IgnoresDuplicates() { + var auth1 = ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-1", + EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) + ) + ) + ) + ); + + var auth2 = ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-1", + EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) + ) + ) + ) + ); + + assertThat( + auth1.merge(auth2), + is( + ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-1", + EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) + ) + ) + ) + ) + ) + ); + } + + public void testMerge_CombinesCorrectlyWithEmptyModel() { + var auth1 = ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-1", + EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) + ) + ) + ) + ); + + var auth2 = ElasticInferenceServiceAuthorizationModel.newDisabledService(); + + assertThat( + auth1.merge(auth2), + is( + ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "model-1", + EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) + ) + ) + ) + ) + ) + ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java new file mode 100644 index 0000000000000..1c19285ea8bf1 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java @@ -0,0 +1,272 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.authorization; + +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.junit.After; +import org.junit.Before; +import org.mockito.ArgumentCaptor; + +import java.io.IOException; +import java.util.EnumSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender.MAX_RETIES; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class ElasticInferenceServiceAuthorizationRequestHandlerTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testDoesNotAttempt_ToRetrieveAuthorization_IfBaseUrlIsNull() throws Exception { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + var logger = mock(Logger.class); + var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler(null, threadPool, logger); + + try (var sender = senderFactory.createSender()) { + PlainActionFuture listener = new PlainActionFuture<>(); + authHandler.getAuthorization(listener, sender); + + var authResponse = listener.actionGet(TIMEOUT); + assertTrue(authResponse.getAuthorizedTaskTypes().isEmpty()); + assertTrue(authResponse.getAuthorizedModelIds().isEmpty()); + assertFalse(authResponse.isAuthorized()); + + var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); + verify(logger, times(2)).debug(loggerArgsCaptor.capture()); + var messages = loggerArgsCaptor.getAllValues(); + assertThat(messages.getFirst(), is("Retrieving authorization information from the Elastic Inference Service.")); + assertThat(messages.get(1), is("The base URL for the authorization service is not valid, rejecting authorization.")); + } + } + + public void testDoesNotAttempt_ToRetrieveAuthorization_IfBaseUrlIsEmpty() throws Exception { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + var logger = mock(Logger.class); + var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler("", threadPool, logger); + + try (var sender = senderFactory.createSender()) { + PlainActionFuture listener = new PlainActionFuture<>(); + authHandler.getAuthorization(listener, sender); + + var authResponse = listener.actionGet(TIMEOUT); + assertTrue(authResponse.getAuthorizedTaskTypes().isEmpty()); + assertTrue(authResponse.getAuthorizedModelIds().isEmpty()); + assertFalse(authResponse.isAuthorized()); + + var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); + verify(logger, times(2)).debug(loggerArgsCaptor.capture()); + var messages = loggerArgsCaptor.getAllValues(); + assertThat(messages.getFirst(), is("Retrieving authorization information from the Elastic Inference Service.")); + assertThat(messages.get(1), is("The base URL for the authorization service is not valid, rejecting authorization.")); + } + } + + public void testGetAuthorization_FailsWhenAnInvalidFieldIsFound() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + var eisGatewayUrl = getUrl(webServer); + var logger = mock(Logger.class); + var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler(eisGatewayUrl, threadPool, logger); + + try (var sender = senderFactory.createSender()) { + String responseJson = """ + { + "models": [ + { + "invalid-field": "model-a", + "task-types": ["embed/text/sparse", "chat"] + } + ] + } + """; + + queueWebServerResponsesForRetries(responseJson); + + PlainActionFuture listener = new PlainActionFuture<>(); + authHandler.getAuthorization(listener, sender); + + var authResponse = listener.actionGet(TIMEOUT); + assertTrue(authResponse.getAuthorizedTaskTypes().isEmpty()); + assertTrue(authResponse.getAuthorizedModelIds().isEmpty()); + assertFalse(authResponse.isAuthorized()); + + var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); + verify(logger).warn(loggerArgsCaptor.capture()); + var message = loggerArgsCaptor.getValue(); + assertThat( + message, + is( + "Failed to retrieve the authorization information from the Elastic Inference Service." + + " Encountered an exception: org.elasticsearch.xcontent.XContentParseException: [4:28] " + + "[ElasticInferenceServiceAuthorizationResponseEntity] failed to parse field [models]" + ) + ); + } + } + + /** + * Queues the required number of responses to handle the retries of the internal sender. + */ + private void queueWebServerResponsesForRetries(String responseJson) { + for (int i = 0; i < MAX_RETIES; i++) { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + } + } + + public void testGetAuthorization_ReturnsAValidResponse() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + var eisGatewayUrl = getUrl(webServer); + var logger = mock(Logger.class); + var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler(eisGatewayUrl, threadPool, logger); + + try (var sender = senderFactory.createSender()) { + String responseJson = """ + { + "models": [ + { + "model_name": "model-a", + "task_types": ["embed/text/sparse", "chat"] + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture listener = new PlainActionFuture<>(); + authHandler.getAuthorization(listener, sender); + + var authResponse = listener.actionGet(TIMEOUT); + assertThat(authResponse.getAuthorizedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))); + assertThat(authResponse.getAuthorizedModelIds(), is(Set.of("model-a"))); + assertTrue(authResponse.isAuthorized()); + + var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); + verify(logger, times(1)).debug(loggerArgsCaptor.capture()); + + var message = loggerArgsCaptor.getValue(); + assertThat(message, is("Retrieving authorization information from the Elastic Inference Service.")); + verifyNoMoreInteractions(logger); + } + } + + @SuppressWarnings("unchecked") + public void testGetAuthorization_OnResponseCalledOnce() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + var eisGatewayUrl = getUrl(webServer); + var logger = mock(Logger.class); + var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler(eisGatewayUrl, threadPool, logger); + + ActionListener listener = mock(ActionListener.class); + String responseJson = """ + { + "models": [ + { + "model_name": "model-a", + "task_types": ["embed/text/sparse", "chat"] + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + try (var sender = senderFactory.createSender()) { + authHandler.getAuthorization(listener, sender); + authHandler.waitForAuthRequestCompletion(TIMEOUT); + + verify(listener, times(1)).onResponse(any()); + var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); + verify(logger, times(1)).debug(loggerArgsCaptor.capture()); + + var message = loggerArgsCaptor.getValue(); + assertThat(message, is("Retrieving authorization information from the Elastic Inference Service.")); + verifyNoMoreInteractions(logger); + } + } + + public void testGetAuthorization_InvalidResponse() throws IOException { + var senderMock = mock(Sender.class); + var senderFactory = mock(HttpRequestSender.Factory.class); + when(senderFactory.createSender()).thenReturn(senderMock); + + doAnswer(invocationOnMock -> { + ActionListener listener = invocationOnMock.getArgument(4); + listener.onResponse(new ChatCompletionResults(List.of(new ChatCompletionResults.Result("awesome")))); + return Void.TYPE; + }).when(senderMock).sendWithoutQueuing(any(), any(), any(), any(), any()); + + var logger = mock(Logger.class); + var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler("abc", threadPool, logger); + + try (var sender = senderFactory.createSender()) { + PlainActionFuture listener = new PlainActionFuture<>(); + + authHandler.getAuthorization(listener, sender); + var result = listener.actionGet(TIMEOUT); + + assertThat(result, is(ElasticInferenceServiceAuthorizationModel.newDisabledService())); + + var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); + verify(logger).warn(loggerArgsCaptor.capture()); + var message = loggerArgsCaptor.getValue(); + assertThat( + message, + is( + "Failed to retrieve the authorization information from the Elastic Inference Service." + + " Received an invalid response type: ChatCompletionResults" + ) + ); + } + + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java index 5778fadd30a83..51945776b4f9e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java @@ -29,7 +29,7 @@ public void testOverridingModelId() { new ElasticInferenceServiceCompletionServiceSettings("model_id", new RateLimitSettings(100)), EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, - ElasticInferenceServiceComponents.withNoRevokeDelay("url") + ElasticInferenceServiceComponents.of("url") ); var request = new UnifiedCompletionRequest(