Skip to content

Commit 469874c

Browse files
Refactoring authorization to happen after the node starts
1 parent d9beed1 commit 469874c

File tree

7 files changed

+69
-15
lines changed

7 files changed

+69
-15
lines changed

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,4 +241,10 @@ default void defaultConfigs(ActionListener<List<Model>> defaultsListener) {
241241
default void updateModelsWithDynamicFields(List<Model> model, ActionListener<List<Model>> listener) {
242242
listener.onResponse(model);
243243
}
244+
245+
/**
246+
* Called after the Elasticsearch node has completed its start up. This allows the service to perform initialization
247+
* after ensuring the node's internals are set up (for example if this ensures the internal ES client is ready for use).
248+
*/
249+
default void onNodeStarted() {}
244250
}

server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ public void init(Client client) {
4141
services.values().forEach(s -> s.init(client));
4242
}
4343

44+
public void onNodeStarted() {
45+
services.values().forEach(InferenceService::onNodeStarted);
46+
}
47+
4448
public Map<String, InferenceService> getServices() {
4549
return services;
4650
}

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo
9191
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
9292

9393
try (var service = createElasticInferenceService()) {
94-
service.waitForAuthorizationToComplete(TIMEOUT);
94+
ensureAuthorizationCallFinished(service);
9595
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
9696
assertThat(
9797
service.defaultConfigIds(),
@@ -125,7 +125,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty()
125125
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
126126

127127
try (var service = createElasticInferenceService()) {
128-
service.waitForAuthorizationToComplete(TIMEOUT);
128+
ensureAuthorizationCallFinished(service);
129+
129130
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
130131
assertThat(
131132
service.defaultConfigIds(),
@@ -164,7 +165,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty()
164165
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson));
165166

166167
try (var service = createElasticInferenceService()) {
167-
service.waitForAuthorizationToComplete(TIMEOUT);
168+
ensureAuthorizationCallFinished(service);
169+
168170
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
169171
assertTrue(service.defaultConfigIds().isEmpty());
170172
assertThat(service.supportedTaskTypes(), is(EnumSet.noneOf(TaskType.class)));
@@ -198,7 +200,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
198200
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
199201

200202
try (var service = createElasticInferenceService()) {
201-
service.waitForAuthorizationToComplete(TIMEOUT);
203+
ensureAuthorizationCallFinished(service);
204+
202205
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
203206
assertThat(
204207
service.defaultConfigIds(),
@@ -244,7 +247,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
244247
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson));
245248

246249
try (var service = createElasticInferenceService()) {
247-
service.waitForAuthorizationToComplete(TIMEOUT);
250+
ensureAuthorizationCallFinished(service);
251+
248252
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
249253
assertThat(
250254
service.defaultConfigIds(),
@@ -264,6 +268,11 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
264268
}
265269
}
266270

271+
private void ensureAuthorizationCallFinished(ElasticInferenceService service) {
272+
service.onNodeStarted();
273+
service.waitForAuthorizationToComplete(TIMEOUT);
274+
}
275+
267276
private ElasticInferenceService createElasticInferenceService() {
268277
var httpManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
269278
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, httpManager);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.elasticsearch.license.XPackLicenseState;
3535
import org.elasticsearch.node.PluginComponentBinding;
3636
import org.elasticsearch.plugins.ActionPlugin;
37+
import org.elasticsearch.plugins.ClusterPlugin;
3738
import org.elasticsearch.plugins.ExtensiblePlugin;
3839
import org.elasticsearch.plugins.MapperPlugin;
3940
import org.elasticsearch.plugins.Plugin;
@@ -146,7 +147,8 @@ public class InferencePlugin extends Plugin
146147
SystemIndexPlugin,
147148
MapperPlugin,
148149
SearchPlugin,
149-
InternalSearchPlugin {
150+
InternalSearchPlugin,
151+
ClusterPlugin {
150152

151153
/**
152154
* When this setting is true the verification check that
@@ -507,6 +509,15 @@ public Map<String, Highlighter> getHighlighters() {
507509
return Map.of(SemanticTextHighlighter.NAME, new SemanticTextHighlighter());
508510
}
509511

512+
@Override
513+
public void onNodeStarted() {
514+
var registry = inferenceServiceRegistry.get();
515+
516+
if (registry != null) {
517+
registry.onNodeStarted();
518+
}
519+
}
520+
510521
protected SSLService getSslService() {
511522
return XPackPlugin.getSharedSslService();
512523
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,6 @@ public ElasticInferenceService(
134134

135135
configuration = new Configuration(authRef.get().taskTypesAndModels.getAuthorizedTaskTypes());
136136
defaultModelsConfigs = initDefaultEndpoints(elasticInferenceServiceComponents);
137-
138-
getAuthorization();
139137
}
140138

141139
private static Map<String, DefaultModelConfig> initDefaultEndpoints(
@@ -287,6 +285,11 @@ private void handleRevokedDefaultConfigs(Set<String> authorizedDefaultModelIds)
287285
.execute(() -> modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener));
288286
}
289287

288+
@Override
289+
public void onNodeStarted() {
290+
getServiceComponents().threadPool().executor(UTILITY_THREAD_POOL_NAME).execute(this::getAuthorization);
291+
}
292+
290293
/**
291294
* Waits the specified amount of time for the authorization call to complete. This is mainly to make testing easier.
292295
* @param waitTime the max time to wait

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ public void getAuthorization(ActionListener<ElasticInferenceServiceAuthorization
7676
logger.debug("Retrieving authorization information from the Elastic Inference Service.");
7777

7878
if (Strings.isNullOrEmpty(baseUrl)) {
79-
logger.warn("The base URL for the authorization service is not valid, rejecting authorization.");
79+
logger.debug("The base URL for the authorization service is not valid, rejecting authorization.");
8080
listener.onResponse(ElasticInferenceServiceAuthorization.newDisabledService());
8181
return;
8282
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,8 @@ public void testChunkedInfer_PassesThrough() throws IOException {
568568

569569
public void testHideFromConfigurationApi_ReturnsTrue_WithNoAvailableModels() throws Exception {
570570
try (var service = createServiceWithMockSender(ElasticInferenceServiceAuthorization.newDisabledService())) {
571+
ensureAuthorizationCallFinished(service);
572+
571573
assertTrue(service.hideFromConfigurationApi());
572574
}
573575
}
@@ -587,6 +589,8 @@ public void testHideFromConfigurationApi_ReturnsTrue_WithModelTaskTypesThatAreNo
587589
)
588590
)
589591
) {
592+
ensureAuthorizationCallFinished(service);
593+
590594
assertTrue(service.hideFromConfigurationApi());
591595
}
592596
}
@@ -606,6 +610,8 @@ public void testHideFromConfigurationApi_ReturnsFalse_WithAvailableModels() thro
606610
)
607611
)
608612
) {
613+
ensureAuthorizationCallFinished(service);
614+
609615
assertFalse(service.hideFromConfigurationApi());
610616
}
611617
}
@@ -625,6 +631,8 @@ public void testGetConfiguration() throws Exception {
625631
)
626632
)
627633
) {
634+
ensureAuthorizationCallFinished(service);
635+
628636
String content = XContentHelper.stripWhitespace("""
629637
{
630638
"service": "elastic",
@@ -678,6 +686,8 @@ public void testGetConfiguration() throws Exception {
678686

679687
public void testGetConfiguration_WithoutSupportedTaskTypes() throws Exception {
680688
try (var service = createServiceWithMockSender(ElasticInferenceServiceAuthorization.newDisabledService())) {
689+
ensureAuthorizationCallFinished(service);
690+
681691
String content = XContentHelper.stripWhitespace("""
682692
{
683693
"service": "elastic",
@@ -745,6 +755,8 @@ public void testGetConfiguration_WithoutSupportedTaskTypes_WhenModelsReturnTaskO
745755
)
746756
)
747757
) {
758+
ensureAuthorizationCallFinished(service);
759+
748760
String content = XContentHelper.stripWhitespace("""
749761
{
750762
"service": "elastic",
@@ -812,7 +824,8 @@ public void testSupportedStreamingTasks_ReturnsChatCompletion_WhenAuthRespondsWi
812824

813825
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
814826
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
815-
service.waitForAuthorizationToComplete(TIMEOUT);
827+
ensureAuthorizationCallFinished(service);
828+
816829
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
817830
assertFalse(service.canStream(TaskType.ANY));
818831
assertTrue(service.defaultConfigIds().isEmpty());
@@ -843,7 +856,8 @@ public void testSupportedTaskTypes_Returns_TheAuthorizedTaskTypes_IgnoresUnimple
843856

844857
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
845858
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
846-
service.waitForAuthorizationToComplete(TIMEOUT);
859+
ensureAuthorizationCallFinished(service);
860+
847861
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING)));
848862
}
849863
}
@@ -868,7 +882,8 @@ public void testSupportedTaskTypes_Returns_TheAuthorizedTaskTypes() throws Excep
868882

869883
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
870884
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
871-
service.waitForAuthorizationToComplete(TIMEOUT);
885+
ensureAuthorizationCallFinished(service);
886+
872887
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION)));
873888
}
874889
}
@@ -889,7 +904,8 @@ public void testSupportedStreamingTasks_ReturnsEmpty_WhenAuthRespondsWithoutChat
889904

890905
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
891906
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
892-
service.waitForAuthorizationToComplete(TIMEOUT);
907+
ensureAuthorizationCallFinished(service);
908+
893909
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
894910
assertTrue(service.defaultConfigIds().isEmpty());
895911
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING)));
@@ -916,7 +932,7 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsIn
916932

917933
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
918934
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
919-
service.waitForAuthorizationToComplete(TIMEOUT);
935+
ensureAuthorizationCallFinished(service);
920936
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
921937
assertThat(
922938
service.defaultConfigIds(),
@@ -954,7 +970,7 @@ public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect()
954970

955971
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
956972
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
957-
service.waitForAuthorizationToComplete(TIMEOUT);
973+
ensureAuthorizationCallFinished(service);
958974
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
959975
assertFalse(service.canStream(TaskType.ANY));
960976
assertThat(
@@ -1063,6 +1079,11 @@ private void testUnifiedStreamError(int responseCode, String responseJson, Strin
10631079
}
10641080
}
10651081

1082+
private void ensureAuthorizationCallFinished(ElasticInferenceService service) {
1083+
service.onNodeStarted();
1084+
service.waitForAuthorizationToComplete(TIMEOUT);
1085+
}
1086+
10661087
private ElasticInferenceService createServiceWithMockSender() {
10671088
return createServiceWithMockSender(ElasticInferenceServiceAuthorizationTests.createEnabledAuth());
10681089
}

0 commit comments

Comments
 (0)