Skip to content

Commit 4de8244

Browse files
[ML] Support delaying EIS authorization revocation until after the node has finished booting (#122644)
* Refactoring authorization to happen after the node starts * Adding delay for model registry call * Fixing test
1 parent a26b596 commit 4de8244

File tree

11 files changed

+117
-33
lines changed

11 files changed

+117
-33
lines changed

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,4 +241,10 @@ default void defaultConfigs(ActionListener<List<Model>> defaultsListener) {
241241
default void updateModelsWithDynamicFields(List<Model> model, ActionListener<List<Model>> listener) {
242242
listener.onResponse(model);
243243
}
244+
245+
/**
246+
* Called after the Elasticsearch node has completed its start up. This allows the service to perform initialization
247+
* after ensuring the node's internals are set up (for example if this ensures the internal ES client is ready for use).
248+
*/
249+
default void onNodeStarted() {}
244250
}

server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ public void init(Client client) {
4141
services.values().forEach(s -> s.init(client));
4242
}
4343

44+
public void onNodeStarted() {
45+
services.values().forEach(InferenceService::onNodeStarted);
46+
}
47+
4448
public Map<String, InferenceService> getServices() {
4549
return services;
4650
}

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo
9191
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
9292

9393
try (var service = createElasticInferenceService()) {
94-
service.waitForAuthorizationToComplete(TIMEOUT);
94+
ensureAuthorizationCallFinished(service);
9595
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
9696
assertThat(
9797
service.defaultConfigIds(),
@@ -125,7 +125,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty()
125125
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
126126

127127
try (var service = createElasticInferenceService()) {
128-
service.waitForAuthorizationToComplete(TIMEOUT);
128+
ensureAuthorizationCallFinished(service);
129+
129130
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
130131
assertThat(
131132
service.defaultConfigIds(),
@@ -164,7 +165,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty()
164165
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson));
165166

166167
try (var service = createElasticInferenceService()) {
167-
service.waitForAuthorizationToComplete(TIMEOUT);
168+
ensureAuthorizationCallFinished(service);
169+
168170
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
169171
assertTrue(service.defaultConfigIds().isEmpty());
170172
assertThat(service.supportedTaskTypes(), is(EnumSet.noneOf(TaskType.class)));
@@ -198,7 +200,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
198200
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
199201

200202
try (var service = createElasticInferenceService()) {
201-
service.waitForAuthorizationToComplete(TIMEOUT);
203+
ensureAuthorizationCallFinished(service);
204+
202205
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
203206
assertThat(
204207
service.defaultConfigIds(),
@@ -244,7 +247,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
244247
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson));
245248

246249
try (var service = createElasticInferenceService()) {
247-
service.waitForAuthorizationToComplete(TIMEOUT);
250+
ensureAuthorizationCallFinished(service);
251+
248252
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
249253
assertThat(
250254
service.defaultConfigIds(),
@@ -264,14 +268,19 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
264268
}
265269
}
266270

271+
private void ensureAuthorizationCallFinished(ElasticInferenceService service) {
272+
service.onNodeStarted();
273+
service.waitForAuthorizationToComplete(TIMEOUT);
274+
}
275+
267276
private ElasticInferenceService createElasticInferenceService() {
268277
var httpManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
269278
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, httpManager);
270279

271280
return new ElasticInferenceService(
272281
senderFactory,
273282
createWithEmptySettings(threadPool),
274-
new ElasticInferenceServiceComponents(gatewayUrl),
283+
ElasticInferenceServiceComponents.withNoRevokeDelay(gatewayUrl),
275284
modelRegistry,
276285
new ElasticInferenceServiceAuthorizationHandler(gatewayUrl, threadPool)
277286
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.elasticsearch.license.XPackLicenseState;
3535
import org.elasticsearch.node.PluginComponentBinding;
3636
import org.elasticsearch.plugins.ActionPlugin;
37+
import org.elasticsearch.plugins.ClusterPlugin;
3738
import org.elasticsearch.plugins.ExtensiblePlugin;
3839
import org.elasticsearch.plugins.MapperPlugin;
3940
import org.elasticsearch.plugins.Plugin;
@@ -146,7 +147,8 @@ public class InferencePlugin extends Plugin
146147
SystemIndexPlugin,
147148
MapperPlugin,
148149
SearchPlugin,
149-
InternalSearchPlugin {
150+
InternalSearchPlugin,
151+
ClusterPlugin {
150152

151153
/**
152154
* When this setting is true the verification check that
@@ -274,7 +276,7 @@ public Collection<?> createComponents(PluginServices services) {
274276
ElasticInferenceServiceSettings inferenceServiceSettings = new ElasticInferenceServiceSettings(settings);
275277
String elasticInferenceUrl = inferenceServiceSettings.getElasticInferenceServiceUrl();
276278

277-
var elasticInferenceServiceComponentsInstance = new ElasticInferenceServiceComponents(elasticInferenceUrl);
279+
var elasticInferenceServiceComponentsInstance = ElasticInferenceServiceComponents.withDefaultRevokeDelay(elasticInferenceUrl);
278280
elasticInferenceServiceComponents.set(elasticInferenceServiceComponentsInstance);
279281

280282
var authorizationHandler = new ElasticInferenceServiceAuthorizationHandler(
@@ -507,6 +509,15 @@ public Map<String, Highlighter> getHighlighters() {
507509
return Map.of(SemanticTextHighlighter.NAME, new SemanticTextHighlighter());
508510
}
509511

512+
@Override
513+
public void onNodeStarted() {
514+
var registry = inferenceServiceRegistry.get();
515+
516+
if (registry != null) {
517+
registry.onNodeStarted();
518+
}
519+
}
520+
510521
protected SSLService getSslService() {
511522
return XPackPlugin.getSharedSslService();
512523
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,6 @@ public ElasticInferenceService(
135135

136136
configuration = new Configuration(authRef.get().taskTypesAndModels.getAuthorizedTaskTypes());
137137
defaultModelsConfigs = initDefaultEndpoints(elasticInferenceServiceComponents);
138-
139-
getAuthorization();
140138
}
141139

142140
private static Map<String, DefaultModelConfig> initDefaultEndpoints(
@@ -283,9 +281,24 @@ private void handleRevokedDefaultConfigs(Set<String> authorizedDefaultModelIds)
283281
authorizationCompletedLatch.countDown();
284282
});
285283

286-
getServiceComponents().threadPool()
287-
.executor(UTILITY_THREAD_POOL_NAME)
288-
.execute(() -> modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener));
284+
Runnable removeFromRegistry = () -> {
285+
logger.debug("Synchronizing default inference endpoints");
286+
modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener);
287+
};
288+
289+
var delay = elasticInferenceServiceComponents.revokeAuthorizationDelay();
290+
if (delay == null) {
291+
getServiceComponents().threadPool().executor(UTILITY_THREAD_POOL_NAME).execute(removeFromRegistry);
292+
} else {
293+
getServiceComponents().threadPool()
294+
.schedule(removeFromRegistry, delay, getServiceComponents().threadPool().executor(UTILITY_THREAD_POOL_NAME));
295+
}
296+
297+
}
298+
299+
@Override
300+
public void onNodeStarted() {
301+
getServiceComponents().threadPool().executor(UTILITY_THREAD_POOL_NAME).execute(this::getAuthorization);
289302
}
290303

291304
/**

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceComponents.java

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,23 @@
88
package org.elasticsearch.xpack.inference.services.elastic;
99

1010
import org.elasticsearch.core.Nullable;
11+
import org.elasticsearch.core.TimeValue;
1112

12-
public record ElasticInferenceServiceComponents(@Nullable String elasticInferenceServiceUrl) {}
13+
/**
14+
* @param elasticInferenceServiceUrl the upstream Elastic Inference Server's URL
15+
* @param revokeAuthorizationDelay Amount of time to wait before attempting to revoke authorization to certain model ids.
16+
* null indicates that there should be no delay
17+
*/
18+
public record ElasticInferenceServiceComponents(@Nullable String elasticInferenceServiceUrl, @Nullable TimeValue revokeAuthorizationDelay) {
19+
private static final TimeValue DEFAULT_REVOKE_AUTHORIZATION_DELAY = TimeValue.timeValueMinutes(10);
20+
21+
public static final ElasticInferenceServiceComponents EMPTY_INSTANCE = new ElasticInferenceServiceComponents(null, null);
22+
23+
public static ElasticInferenceServiceComponents withNoRevokeDelay(String elasticInferenceServiceUrl) {
24+
return new ElasticInferenceServiceComponents(elasticInferenceServiceUrl, null);
25+
}
26+
27+
public static ElasticInferenceServiceComponents withDefaultRevokeDelay(String elasticInferenceServiceUrl) {
28+
return new ElasticInferenceServiceComponents(elasticInferenceServiceUrl, DEFAULT_REVOKE_AUTHORIZATION_DELAY);
29+
}
30+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ public void getAuthorization(ActionListener<ElasticInferenceServiceAuthorization
7676
logger.debug("Retrieving authorization information from the Elastic Inference Service.");
7777

7878
if (Strings.isNullOrEmpty(baseUrl)) {
79-
logger.warn("The base URL for the authorization service is not valid, rejecting authorization.");
79+
logger.debug("The base URL for the authorization service is not valid, rejecting authorization.");
8080
listener.onResponse(ElasticInferenceServiceAuthorization.newDisabledService());
8181
return;
8282
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ public static ElasticInferenceServiceSparseEmbeddingsModel createModel(String ur
2626
new ElasticInferenceServiceSparseEmbeddingsServiceSettings(modelId, maxInputTokens, null),
2727
EmptyTaskSettings.INSTANCE,
2828
EmptySecretSettings.INSTANCE,
29-
new ElasticInferenceServiceComponents(url)
29+
ElasticInferenceServiceComponents.withNoRevokeDelay(url)
3030
);
3131
}
3232
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,8 @@ public void testChunkedInfer_PassesThrough() throws IOException {
582582

583583
public void testHideFromConfigurationApi_ReturnsTrue_WithNoAvailableModels() throws Exception {
584584
try (var service = createServiceWithMockSender(ElasticInferenceServiceAuthorization.newDisabledService())) {
585+
ensureAuthorizationCallFinished(service);
586+
585587
assertTrue(service.hideFromConfigurationApi());
586588
}
587589
}
@@ -601,6 +603,8 @@ public void testHideFromConfigurationApi_ReturnsTrue_WithModelTaskTypesThatAreNo
601603
)
602604
)
603605
) {
606+
ensureAuthorizationCallFinished(service);
607+
604608
assertTrue(service.hideFromConfigurationApi());
605609
}
606610
}
@@ -620,6 +624,8 @@ public void testHideFromConfigurationApi_ReturnsFalse_WithAvailableModels() thro
620624
)
621625
)
622626
) {
627+
ensureAuthorizationCallFinished(service);
628+
623629
assertFalse(service.hideFromConfigurationApi());
624630
}
625631
}
@@ -639,6 +645,8 @@ public void testGetConfiguration() throws Exception {
639645
)
640646
)
641647
) {
648+
ensureAuthorizationCallFinished(service);
649+
642650
String content = XContentHelper.stripWhitespace("""
643651
{
644652
"service": "elastic",
@@ -692,6 +700,8 @@ public void testGetConfiguration() throws Exception {
692700

693701
public void testGetConfiguration_WithoutSupportedTaskTypes() throws Exception {
694702
try (var service = createServiceWithMockSender(ElasticInferenceServiceAuthorization.newDisabledService())) {
703+
ensureAuthorizationCallFinished(service);
704+
695705
String content = XContentHelper.stripWhitespace("""
696706
{
697707
"service": "elastic",
@@ -759,6 +769,8 @@ public void testGetConfiguration_WithoutSupportedTaskTypes_WhenModelsReturnTaskO
759769
)
760770
)
761771
) {
772+
ensureAuthorizationCallFinished(service);
773+
762774
String content = XContentHelper.stripWhitespace("""
763775
{
764776
"service": "elastic",
@@ -826,7 +838,8 @@ public void testSupportedStreamingTasks_ReturnsChatCompletion_WhenAuthRespondsWi
826838

827839
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
828840
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
829-
service.waitForAuthorizationToComplete(TIMEOUT);
841+
ensureAuthorizationCallFinished(service);
842+
830843
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
831844
assertFalse(service.canStream(TaskType.ANY));
832845
assertTrue(service.defaultConfigIds().isEmpty());
@@ -857,7 +870,8 @@ public void testSupportedTaskTypes_Returns_TheAuthorizedTaskTypes_IgnoresUnimple
857870

858871
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
859872
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
860-
service.waitForAuthorizationToComplete(TIMEOUT);
873+
ensureAuthorizationCallFinished(service);
874+
861875
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING)));
862876
}
863877
}
@@ -882,7 +896,8 @@ public void testSupportedTaskTypes_Returns_TheAuthorizedTaskTypes() throws Excep
882896

883897
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
884898
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
885-
service.waitForAuthorizationToComplete(TIMEOUT);
899+
ensureAuthorizationCallFinished(service);
900+
886901
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION)));
887902
}
888903
}
@@ -903,7 +918,8 @@ public void testSupportedStreamingTasks_ReturnsEmpty_WhenAuthRespondsWithoutChat
903918

904919
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
905920
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
906-
service.waitForAuthorizationToComplete(TIMEOUT);
921+
ensureAuthorizationCallFinished(service);
922+
907923
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
908924
assertTrue(service.defaultConfigIds().isEmpty());
909925
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING)));
@@ -930,7 +946,7 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsIn
930946

931947
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
932948
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
933-
service.waitForAuthorizationToComplete(TIMEOUT);
949+
ensureAuthorizationCallFinished(service);
934950
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
935951
assertThat(
936952
service.defaultConfigIds(),
@@ -968,7 +984,7 @@ public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect()
968984

969985
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
970986
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
971-
service.waitForAuthorizationToComplete(TIMEOUT);
987+
ensureAuthorizationCallFinished(service);
972988
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
973989
assertFalse(service.canStream(TaskType.ANY));
974990
assertThat(
@@ -1044,7 +1060,7 @@ private void testUnifiedStreamError(int responseCode, String responseJson, Strin
10441060
new ElasticInferenceServiceCompletionServiceSettings("model_id", new RateLimitSettings(100)),
10451061
EmptyTaskSettings.INSTANCE,
10461062
EmptySecretSettings.INSTANCE,
1047-
new ElasticInferenceServiceComponents(eisGatewayUrl)
1063+
ElasticInferenceServiceComponents.withNoRevokeDelay(eisGatewayUrl)
10481064
);
10491065
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
10501066
service.unifiedCompletionInfer(
@@ -1077,6 +1093,11 @@ private void testUnifiedStreamError(int responseCode, String responseJson, Strin
10771093
}
10781094
}
10791095

1096+
private void ensureAuthorizationCallFinished(ElasticInferenceService service) {
1097+
service.onNodeStarted();
1098+
service.waitForAuthorizationToComplete(TIMEOUT);
1099+
}
1100+
10801101
private ElasticInferenceService createServiceWithMockSender() {
10811102
return createServiceWithMockSender(ElasticInferenceServiceAuthorizationTests.createEnabledAuth());
10821103
}
@@ -1092,7 +1113,7 @@ private ElasticInferenceService createServiceWithMockSender(ElasticInferenceServ
10921113
return new ElasticInferenceService(
10931114
mock(HttpRequestSender.Factory.class),
10941115
createWithEmptySettings(threadPool),
1095-
new ElasticInferenceServiceComponents(null),
1116+
ElasticInferenceServiceComponents.EMPTY_INSTANCE,
10961117
mockModelRegistry(),
10971118
mockAuthHandler
10981119
);
@@ -1121,7 +1142,7 @@ private ElasticInferenceService createService(
11211142
return new ElasticInferenceService(
11221143
senderFactory,
11231144
createWithEmptySettings(threadPool),
1124-
new ElasticInferenceServiceComponents(gatewayUrl),
1145+
ElasticInferenceServiceComponents.withNoRevokeDelay(gatewayUrl),
11251146
mockModelRegistry(),
11261147
mockAuthHandler
11271148
);
@@ -1131,7 +1152,7 @@ private ElasticInferenceService createServiceWithAuthHandler(HttpRequestSender.F
11311152
return new ElasticInferenceService(
11321153
senderFactory,
11331154
createWithEmptySettings(threadPool),
1134-
new ElasticInferenceServiceComponents(eisGatewayUrl),
1155+
ElasticInferenceServiceComponents.withNoRevokeDelay(eisGatewayUrl),
11351156
mockModelRegistry(),
11361157
new ElasticInferenceServiceAuthorizationHandler(eisGatewayUrl, threadPool)
11371158
);

0 commit comments

Comments
 (0)