From 469874c85905fd7774205a1159c606eb9b50678a Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Fri, 14 Feb 2025 11:59:57 -0500 Subject: [PATCH 1/3] Refactoring authorization to happen after the node starts --- .../inference/InferenceService.java | 6 ++++ .../inference/InferenceServiceRegistry.java | 4 +++ .../InferenceRevokeDefaultEndpointsIT.java | 19 ++++++++--- .../xpack/inference/InferencePlugin.java | 13 +++++++- .../elastic/ElasticInferenceService.java | 7 ++-- ...cInferenceServiceAuthorizationHandler.java | 2 +- .../elastic/ElasticInferenceServiceTests.java | 33 +++++++++++++++---- 7 files changed, 69 insertions(+), 15 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index e1ebd8bb81ff4..de1925cb641e9 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -241,4 +241,10 @@ default void defaultConfigs(ActionListener> defaultsListener) { default void updateModelsWithDynamicFields(List model, ActionListener> listener) { listener.onResponse(model); } + + /** + * Called after the Elasticsearch node has completed its startup. This allows the service to perform initialization + * after ensuring the node's internals are set up (for example, it ensures the internal ES client is ready for use).
+ */ + default void onNodeStarted() {} } diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java index f1ce94173a550..d4d18fed6dd02 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java @@ -41,6 +41,10 @@ public void init(Client client) { services.values().forEach(s -> s.init(client)); } + public void onNodeStarted() { + services.values().forEach(InferenceService::onNodeStarted); + } + public Map getServices() { return services; } diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java index 5205ce07a0676..181f4c0a18f46 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java @@ -91,7 +91,7 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); try (var service = createElasticInferenceService()) { - service.waitForAuthorizationToComplete(TIMEOUT); + ensureAuthorizationCallFinished(service); assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); assertThat( service.defaultConfigIds(), @@ -125,7 +125,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty() webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); try (var service = createElasticInferenceService()) { - 
service.waitForAuthorizationToComplete(TIMEOUT); + ensureAuthorizationCallFinished(service); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); assertThat( service.defaultConfigIds(), @@ -164,7 +165,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty() webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson)); try (var service = createElasticInferenceService()) { - service.waitForAuthorizationToComplete(TIMEOUT); + ensureAuthorizationCallFinished(service); + assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class))); assertTrue(service.defaultConfigIds().isEmpty()); assertThat(service.supportedTaskTypes(), is(EnumSet.noneOf(TaskType.class))); @@ -198,7 +200,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); try (var service = createElasticInferenceService()) { - service.waitForAuthorizationToComplete(TIMEOUT); + ensureAuthorizationCallFinished(service); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); assertThat( service.defaultConfigIds(), @@ -244,7 +247,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson)); try (var service = createElasticInferenceService()) { - service.waitForAuthorizationToComplete(TIMEOUT); + ensureAuthorizationCallFinished(service); + assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class))); assertThat( service.defaultConfigIds(), @@ -264,6 +268,11 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA } } + private void ensureAuthorizationCallFinished(ElasticInferenceService service) { + service.onNodeStarted(); + service.waitForAuthorizationToComplete(TIMEOUT); + } + private 
ElasticInferenceService createElasticInferenceService() { var httpManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, httpManager); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index e3604351c1937..a61f03d50031c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -34,6 +34,7 @@ import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.node.PluginComponentBinding; import org.elasticsearch.plugins.ActionPlugin; +import org.elasticsearch.plugins.ClusterPlugin; import org.elasticsearch.plugins.ExtensiblePlugin; import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.plugins.Plugin; @@ -146,7 +147,8 @@ public class InferencePlugin extends Plugin SystemIndexPlugin, MapperPlugin, SearchPlugin, - InternalSearchPlugin { + InternalSearchPlugin, + ClusterPlugin { /** * When this setting is true the verification check that @@ -507,6 +509,15 @@ public Map getHighlighters() { return Map.of(SemanticTextHighlighter.NAME, new SemanticTextHighlighter()); } + @Override + public void onNodeStarted() { + var registry = inferenceServiceRegistry.get(); + + if (registry != null) { + registry.onNodeStarted(); + } + } + protected SSLService getSslService() { return XPackPlugin.getSharedSslService(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index fee66a9f84ac9..a856769fbd04d 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -134,8 +134,6 @@ public ElasticInferenceService( configuration = new Configuration(authRef.get().taskTypesAndModels.getAuthorizedTaskTypes()); defaultModelsConfigs = initDefaultEndpoints(elasticInferenceServiceComponents); - - getAuthorization(); } private static Map initDefaultEndpoints( @@ -287,6 +285,11 @@ private void handleRevokedDefaultConfigs(Set authorizedDefaultModelIds) .execute(() -> modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener)); } + @Override + public void onNodeStarted() { + getServiceComponents().threadPool().executor(UTILITY_THREAD_POOL_NAME).execute(this::getAuthorization); + } + /** * Waits the specified amount of time for the authorization call to complete. This is mainly to make testing easier. 
* @param waitTime the max time to wait diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java index f78b5357caeb3..4061e78c31dc4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java @@ -76,7 +76,7 @@ public void getAuthorization(ActionListener Date: Fri, 14 Feb 2025 14:41:54 -0500 Subject: [PATCH 2/3] Adding delay for model registry call --- .../InferenceRevokeDefaultEndpointsIT.java | 2 +- .../xpack/inference/InferencePlugin.java | 2 +- .../elastic/ElasticInferenceService.java | 16 ++++++++++++--- .../ElasticInferenceServiceComponents.java | 20 ++++++++++++++++++- ...enceServiceSparseEmbeddingsModelTests.java | 2 +- .../elastic/ElasticInferenceServiceTests.java | 8 ++++---- ...cInferenceServiceCompletionModelTests.java | 2 +- 7 files changed, 40 insertions(+), 12 deletions(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java index 181f4c0a18f46..2070b0f1f1574 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java @@ -280,7 +280,7 @@ private ElasticInferenceService 
createElasticInferenceService() { return new ElasticInferenceService( senderFactory, createWithEmptySettings(threadPool), - new ElasticInferenceServiceComponents(gatewayUrl), + ElasticInferenceServiceComponents.withNoRevokeDelay(gatewayUrl), modelRegistry, new ElasticInferenceServiceAuthorizationHandler(gatewayUrl, threadPool) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index a61f03d50031c..05b8944138a60 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -276,7 +276,7 @@ public Collection createComponents(PluginServices services) { ElasticInferenceServiceSettings inferenceServiceSettings = new ElasticInferenceServiceSettings(settings); String elasticInferenceUrl = inferenceServiceSettings.getElasticInferenceServiceUrl(); - var elasticInferenceServiceComponentsInstance = new ElasticInferenceServiceComponents(elasticInferenceUrl); + var elasticInferenceServiceComponentsInstance = ElasticInferenceServiceComponents.withDefaultRevokeDelay(elasticInferenceUrl); elasticInferenceServiceComponents.set(elasticInferenceServiceComponentsInstance); var authorizationHandler = new ElasticInferenceServiceAuthorizationHandler( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index a856769fbd04d..e8e51c1b4b7eb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -280,9 
+280,19 @@ private void handleRevokedDefaultConfigs(Set authorizedDefaultModelIds) authorizationCompletedLatch.countDown(); }); - getServiceComponents().threadPool() - .executor(UTILITY_THREAD_POOL_NAME) - .execute(() -> modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener)); + Runnable removeFromRegistry = () -> { + logger.debug("Synchronizing default inference endpoints"); + modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener); + }; + + var delay = elasticInferenceServiceComponents.revokeAuthorizationDelay(); + if (delay == null) { + getServiceComponents().threadPool().executor(UTILITY_THREAD_POOL_NAME).execute(removeFromRegistry); + } else { + getServiceComponents().threadPool() + .schedule(removeFromRegistry, delay, getServiceComponents().threadPool().executor(UTILITY_THREAD_POOL_NAME)); + } + } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceComponents.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceComponents.java index 837581667882d..f79de437fcaf2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceComponents.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceComponents.java @@ -8,5 +8,23 @@ package org.elasticsearch.xpack.inference.services.elastic; import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; -public record ElasticInferenceServiceComponents(@Nullable String elasticInferenceServiceUrl) {} +/** + * @param elasticInferenceServiceUrl the upstream Elastic Inference Server's URL + * @param revokeAuthorizationDelay Amount of time to wait before attempting to revoke authorization to certain model ids. 
+ * null indicates that there should be no delay + */ +public record ElasticInferenceServiceComponents(@Nullable String elasticInferenceServiceUrl, @Nullable TimeValue revokeAuthorizationDelay) { + private static final TimeValue DEFAULT_REVOKE_AUTHORIZATION_DELAY = TimeValue.timeValueMinutes(10); + + public static final ElasticInferenceServiceComponents EMPTY_INSTANCE = new ElasticInferenceServiceComponents(null, null); + + public static ElasticInferenceServiceComponents withNoRevokeDelay(String elasticInferenceServiceUrl) { + return new ElasticInferenceServiceComponents(elasticInferenceServiceUrl, null); + } + + public static ElasticInferenceServiceComponents withDefaultRevokeDelay(String elasticInferenceServiceUrl) { + return new ElasticInferenceServiceComponents(elasticInferenceServiceUrl, DEFAULT_REVOKE_AUTHORIZATION_DELAY); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java index 02bbbb844c04f..1b4cd026b816f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java @@ -26,7 +26,7 @@ public static ElasticInferenceServiceSparseEmbeddingsModel createModel(String ur new ElasticInferenceServiceSparseEmbeddingsServiceSettings(modelId, maxInputTokens, null), EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, - new ElasticInferenceServiceComponents(url) + ElasticInferenceServiceComponents.withNoRevokeDelay(url) ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index aef482f9ae49a..2fbb9498cd202 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -1046,7 +1046,7 @@ private void testUnifiedStreamError(int responseCode, String responseJson, Strin new ElasticInferenceServiceCompletionServiceSettings("model_id", new RateLimitSettings(100)), EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, - new ElasticInferenceServiceComponents(eisGatewayUrl) + ElasticInferenceServiceComponents.withNoRevokeDelay(eisGatewayUrl) ); PlainActionFuture listener = new PlainActionFuture<>(); service.unifiedCompletionInfer( @@ -1099,7 +1099,7 @@ private ElasticInferenceService createServiceWithMockSender(ElasticInferenceServ return new ElasticInferenceService( mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), - new ElasticInferenceServiceComponents(null), + ElasticInferenceServiceComponents.EMPTY_INSTANCE, mockModelRegistry(), mockAuthHandler ); @@ -1128,7 +1128,7 @@ private ElasticInferenceService createService( return new ElasticInferenceService( senderFactory, createWithEmptySettings(threadPool), - new ElasticInferenceServiceComponents(gatewayUrl), + ElasticInferenceServiceComponents.withNoRevokeDelay(gatewayUrl), mockModelRegistry(), mockAuthHandler ); @@ -1138,7 +1138,7 @@ private ElasticInferenceService createServiceWithAuthHandler(HttpRequestSender.F return new ElasticInferenceService( senderFactory, createWithEmptySettings(threadPool), - new ElasticInferenceServiceComponents(eisGatewayUrl), + ElasticInferenceServiceComponents.withNoRevokeDelay(eisGatewayUrl), mockModelRegistry(), new ElasticInferenceServiceAuthorizationHandler(eisGatewayUrl, threadPool) ); diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java index 07da96cb32273..5778fadd30a83 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java @@ -29,7 +29,7 @@ public void testOverridingModelId() { new ElasticInferenceServiceCompletionServiceSettings("model_id", new RateLimitSettings(100)), EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, - new ElasticInferenceServiceComponents("url") + ElasticInferenceServiceComponents.withNoRevokeDelay("url") ); var request = new UnifiedCompletionRequest( From 1d5e36a5d8086d9d26effdcf3e4b8cc783f63f53 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 18 Feb 2025 16:57:18 -0500 Subject: [PATCH 3/3] Fixing test --- ...cInferenceServiceAuthorizationHandlerTests.java | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java index a819bf1b4a513..a87c3f814b7e1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java @@ -83,9 +83,10 @@ public void testDoesNotAttempt_ToRetrieveAuthorization_IfBaseUrlIsNull() throws assertFalse(authResponse.isAuthorized()); var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); - verify(logger).warn(loggerArgsCaptor.capture()); - var message = loggerArgsCaptor.getValue(); - assertThat(message, is("The base URL for the authorization service is not valid, rejecting authorization.")); + verify(logger, times(2)).debug(loggerArgsCaptor.capture()); + var messages = loggerArgsCaptor.getAllValues(); + assertThat(messages.getFirst(), is("Retrieving authorization information from the Elastic Inference Service.")); + assertThat(messages.get(1), is("The base URL for the authorization service is not valid, rejecting authorization.")); } } @@ -104,9 +105,10 @@ public void testDoesNotAttempt_ToRetrieveAuthorization_IfBaseUrlIsEmpty() throws assertFalse(authResponse.isAuthorized()); var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); - verify(logger).warn(loggerArgsCaptor.capture()); - var message = loggerArgsCaptor.getValue(); - assertThat(message, is("The base URL for the authorization service is not valid, rejecting authorization.")); + verify(logger, times(2)).debug(loggerArgsCaptor.capture()); + var messages = loggerArgsCaptor.getAllValues(); + assertThat(messages.getFirst(), is("Retrieving authorization information from the Elastic Inference Service.")); + assertThat(messages.get(1), is("The base URL for the authorization service is not valid, rejecting authorization.")); } }