From a77e5f5c21520768fd6b3850485ef31c78f63810 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner <56361221+jonathan-buttner@users.noreply.github.com> Date: Thu, 20 Feb 2025 08:39:45 -0500 Subject: [PATCH] [ML] Support delaying EIS authorization revocation until after the node has finished booting (#122644) * Refactoring authorization to happen after the node starts * Adding delay for model registry call * Fixing test (cherry picked from commit 4de82448c857c2c02080c26c4ec1fc42b12e47cc) # Conflicts: # x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java # x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java --- .../inference/InferenceService.java | 6 +++ .../inference/InferenceServiceRegistry.java | 4 ++ .../InferenceRevokeDefaultEndpointsIT.java | 21 +++++++--- .../xpack/inference/InferencePlugin.java | 15 ++++++- .../elastic/ElasticInferenceService.java | 23 +++++++--- .../ElasticInferenceServiceComponents.java | 20 ++++++++- ...cInferenceServiceAuthorizationHandler.java | 2 +- ...enceServiceSparseEmbeddingsModelTests.java | 2 +- .../elastic/ElasticInferenceServiceTests.java | 42 ++++++++++++++----- ...renceServiceAuthorizationHandlerTests.java | 14 ++++--- ...cInferenceServiceCompletionModelTests.java | 2 +- 11 files changed, 118 insertions(+), 33 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index e1ebd8bb81ff4..de1925cb641e9 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -241,4 +241,10 @@ default void defaultConfigs(ActionListener> defaultsListener) { default void updateModelsWithDynamicFields(List model, ActionListener> listener) { listener.onResponse(model); } + + /** + * Called after the Elasticsearch node has completed its start up. This allows the service to perform initialization + * after ensuring the node's internals are set up (for example if this ensures the internal ES client is ready for use). + */ + default void onNodeStarted() {} } diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java index f1ce94173a550..d4d18fed6dd02 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java @@ -41,6 +41,10 @@ public void init(Client client) { services.values().forEach(s -> s.init(client)); } + public void onNodeStarted() { + services.values().forEach(InferenceService::onNodeStarted); + } + public Map getServices() { return services; } diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java index 201f1250427a8..faef212bd3844 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java @@ -91,7 +91,7 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); try (var service = createElasticInferenceService()) { - service.waitForAuthorizationToComplete(TIMEOUT); + ensureAuthorizationCallFinished(service); assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY))); assertThat( service.defaultConfigIds(), @@ -125,7 +125,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty() webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); try (var service = createElasticInferenceService()) { - service.waitForAuthorizationToComplete(TIMEOUT); + ensureAuthorizationCallFinished(service); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY))); assertThat( service.defaultConfigIds(), @@ -164,7 +165,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty() webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson)); try (var service = createElasticInferenceService()) { - service.waitForAuthorizationToComplete(TIMEOUT); + ensureAuthorizationCallFinished(service); + assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class))); assertTrue(service.defaultConfigIds().isEmpty()); assertThat(service.supportedTaskTypes(), is(EnumSet.noneOf(TaskType.class))); @@ -198,7 +200,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); try (var service = createElasticInferenceService()) { - service.waitForAuthorizationToComplete(TIMEOUT); + ensureAuthorizationCallFinished(service); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY))); assertThat( service.defaultConfigIds(), @@ -242,7 +245,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson)); try (var service = createElasticInferenceService()) { - service.waitForAuthorizationToComplete(TIMEOUT); + ensureAuthorizationCallFinished(service); + assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class))); assertTrue(service.defaultConfigIds().isEmpty()); assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); @@ -256,6 +260,11 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA } } + private void ensureAuthorizationCallFinished(ElasticInferenceService service) { + service.onNodeStarted(); + service.waitForAuthorizationToComplete(TIMEOUT); + } + private ElasticInferenceService createElasticInferenceService() { var httpManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, httpManager); @@ -263,7 +272,7 @@ private ElasticInferenceService createElasticInferenceService() { return new ElasticInferenceService( senderFactory, createWithEmptySettings(threadPool), - new ElasticInferenceServiceComponents(gatewayUrl), + ElasticInferenceServiceComponents.withNoRevokeDelay(gatewayUrl), modelRegistry, new ElasticInferenceServiceAuthorizationHandler(gatewayUrl, threadPool) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index e3604351c1937..05b8944138a60 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -34,6 +34,7 @@ import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.node.PluginComponentBinding; import org.elasticsearch.plugins.ActionPlugin; +import org.elasticsearch.plugins.ClusterPlugin; import org.elasticsearch.plugins.ExtensiblePlugin; import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.plugins.Plugin; @@ -146,7 +147,8 @@ public class InferencePlugin extends Plugin SystemIndexPlugin, MapperPlugin, SearchPlugin, - InternalSearchPlugin { + InternalSearchPlugin, + ClusterPlugin { /** * When this setting is true the verification check that @@ -274,7 +276,7 @@ public Collection createComponents(PluginServices services) { ElasticInferenceServiceSettings inferenceServiceSettings = new ElasticInferenceServiceSettings(settings); String elasticInferenceUrl = inferenceServiceSettings.getElasticInferenceServiceUrl(); - var elasticInferenceServiceComponentsInstance = new ElasticInferenceServiceComponents(elasticInferenceUrl); + var elasticInferenceServiceComponentsInstance = ElasticInferenceServiceComponents.withDefaultRevokeDelay(elasticInferenceUrl); elasticInferenceServiceComponents.set(elasticInferenceServiceComponentsInstance); var authorizationHandler = new ElasticInferenceServiceAuthorizationHandler( @@ -507,6 +509,15 @@ public Map getHighlighters() { return Map.of(SemanticTextHighlighter.NAME, new SemanticTextHighlighter()); } + @Override + public void onNodeStarted() { + var registry = inferenceServiceRegistry.get(); + + if (registry != null) { + registry.onNodeStarted(); + } + } + protected SSLService getSslService() { return XPackPlugin.getSharedSslService(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 4b53353b95d9e..e6149057cb1eb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -122,8 +122,6 @@ public ElasticInferenceService( configuration = new Configuration(authRef.get().taskTypesAndModels.getAuthorizedTaskTypes()); defaultModelsConfigs = initDefaultEndpoints(elasticInferenceServiceComponents); - - getAuthorization(); } private static Map initDefaultEndpoints( @@ -255,9 +253,24 @@ private void handleRevokedDefaultConfigs(Set authorizedDefaultModelIds) authorizationCompletedLatch.countDown(); }); - getServiceComponents().threadPool() - .executor(UTILITY_THREAD_POOL_NAME) - .execute(() -> modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener)); + Runnable removeFromRegistry = () -> { + logger.debug("Synchronizing default inference endpoints"); + modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener); + }; + + var delay = elasticInferenceServiceComponents.revokeAuthorizationDelay(); + if (delay == null) { + getServiceComponents().threadPool().executor(UTILITY_THREAD_POOL_NAME).execute(removeFromRegistry); + } else { + getServiceComponents().threadPool() + .schedule(removeFromRegistry, delay, getServiceComponents().threadPool().executor(UTILITY_THREAD_POOL_NAME)); + } + + } + + @Override + public void onNodeStarted() { + getServiceComponents().threadPool().executor(UTILITY_THREAD_POOL_NAME).execute(this::getAuthorization); } /** diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceComponents.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceComponents.java index 837581667882d..f79de437fcaf2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceComponents.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceComponents.java @@ -8,5 +8,23 @@ package org.elasticsearch.xpack.inference.services.elastic; import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; -public record ElasticInferenceServiceComponents(@Nullable String elasticInferenceServiceUrl) {} +/** + * @param elasticInferenceServiceUrl the upstream Elastic Inference Server's URL + * @param revokeAuthorizationDelay Amount of time to wait before attempting to revoke authorization to certain model ids. + * null indicates that there should be no delay + */ +public record ElasticInferenceServiceComponents(@Nullable String elasticInferenceServiceUrl, @Nullable TimeValue revokeAuthorizationDelay) { + private static final TimeValue DEFAULT_REVOKE_AUTHORIZATION_DELAY = TimeValue.timeValueMinutes(10); + + public static final ElasticInferenceServiceComponents EMPTY_INSTANCE = new ElasticInferenceServiceComponents(null, null); + + public static ElasticInferenceServiceComponents withNoRevokeDelay(String elasticInferenceServiceUrl) { + return new ElasticInferenceServiceComponents(elasticInferenceServiceUrl, null); + } + + public static ElasticInferenceServiceComponents withDefaultRevokeDelay(String elasticInferenceServiceUrl) { + return new ElasticInferenceServiceComponents(elasticInferenceServiceUrl, DEFAULT_REVOKE_AUTHORIZATION_DELAY); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java index f78b5357caeb3..4061e78c31dc4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java @@ -76,7 +76,7 @@ public void getAuthorization(ActionListener listener = new PlainActionFuture<>(); service.unifiedCompletionInfer( @@ -1053,6 +1070,11 @@ private void testUnifiedStreamError(int responseCode, String responseJson, Strin } } + private void ensureAuthorizationCallFinished(ElasticInferenceService service) { + service.onNodeStarted(); + service.waitForAuthorizationToComplete(TIMEOUT); + } + private ElasticInferenceService createServiceWithMockSender() { return createServiceWithMockSender(ElasticInferenceServiceAuthorizationTests.createEnabledAuth()); } @@ -1068,7 +1090,7 @@ private ElasticInferenceService createServiceWithMockSender(ElasticInferenceServ return new ElasticInferenceService( mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), - new ElasticInferenceServiceComponents(null), + ElasticInferenceServiceComponents.EMPTY_INSTANCE, mockModelRegistry(), mockAuthHandler ); @@ -1097,7 +1119,7 @@ private ElasticInferenceService createService( return new ElasticInferenceService( senderFactory, createWithEmptySettings(threadPool), - new ElasticInferenceServiceComponents(gatewayUrl), + ElasticInferenceServiceComponents.withNoRevokeDelay(gatewayUrl), mockModelRegistry(), mockAuthHandler ); @@ -1107,7 +1129,7 @@ private ElasticInferenceService createServiceWithAuthHandler(HttpRequestSender.F return new ElasticInferenceService( senderFactory, createWithEmptySettings(threadPool), - new ElasticInferenceServiceComponents(eisGatewayUrl), + ElasticInferenceServiceComponents.withNoRevokeDelay(eisGatewayUrl), mockModelRegistry(), new ElasticInferenceServiceAuthorizationHandler(eisGatewayUrl, threadPool) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java index a819bf1b4a513..a87c3f814b7e1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java @@ -83,9 +83,10 @@ public void testDoesNotAttempt_ToRetrieveAuthorization_IfBaseUrlIsNull() throws assertFalse(authResponse.isAuthorized()); var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); - verify(logger).warn(loggerArgsCaptor.capture()); - var message = loggerArgsCaptor.getValue(); - assertThat(message, is("The base URL for the authorization service is not valid, rejecting authorization.")); + verify(logger, times(2)).debug(loggerArgsCaptor.capture()); + var messages = loggerArgsCaptor.getAllValues(); + assertThat(messages.getFirst(), is("Retrieving authorization information from the Elastic Inference Service.")); + assertThat(messages.get(1), is("The base URL for the authorization service is not valid, rejecting authorization.")); } } @@ -104,9 +105,10 @@ public void testDoesNotAttempt_ToRetrieveAuthorization_IfBaseUrlIsEmpty() throws assertFalse(authResponse.isAuthorized()); var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); - verify(logger).warn(loggerArgsCaptor.capture()); - var message = loggerArgsCaptor.getValue(); - assertThat(message, is("The base URL for the authorization service is not valid, rejecting authorization.")); + verify(logger, times(2)).debug(loggerArgsCaptor.capture()); + var messages = loggerArgsCaptor.getAllValues(); + assertThat(messages.getFirst(), is("Retrieving authorization information from the Elastic Inference Service.")); + assertThat(messages.get(1), is("The base URL for the authorization service is not valid, rejecting authorization.")); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java index 07da96cb32273..5778fadd30a83 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java @@ -29,7 +29,7 @@ public void testOverridingModelId() { new ElasticInferenceServiceCompletionServiceSettings("model_id", new RateLimitSettings(100)), EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, - new ElasticInferenceServiceComponents("url") + ElasticInferenceServiceComponents.withNoRevokeDelay("url") ); var request = new UnifiedCompletionRequest(