Skip to content

Commit 96fd807

Browse files
[ML] Support delaying EIS authorization revocation until after the node has finished booting (elastic#122644) (elastic#123029)
* Refactoring authorization to happen after the node starts * Adding delay for model registry call * Fixing test (cherry picked from commit 4de8244) # Conflicts: # x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java # x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java
1 parent ffd7631 commit 96fd807

File tree

11 files changed

+118
-33
lines changed

11 files changed

+118
-33
lines changed

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,4 +241,10 @@ default void defaultConfigs(ActionListener<List<Model>> defaultsListener) {
241241
default void updateModelsWithDynamicFields(List<Model> model, ActionListener<List<Model>> listener) {
242242
listener.onResponse(model);
243243
}
244+
245+
/**
246+
* Called after the Elasticsearch node has completed its start up. This allows the service to perform initialization
247+
* after ensuring the node's internals are set up (for example if this ensures the internal ES client is ready for use).
248+
*/
249+
default void onNodeStarted() {}
244250
}

server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ public void init(Client client) {
4141
services.values().forEach(s -> s.init(client));
4242
}
4343

44+
public void onNodeStarted() {
45+
services.values().forEach(InferenceService::onNodeStarted);
46+
}
47+
4448
public Map<String, InferenceService> getServices() {
4549
return services;
4650
}

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo
9191
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
9292

9393
try (var service = createElasticInferenceService()) {
94-
service.waitForAuthorizationToComplete(TIMEOUT);
94+
ensureAuthorizationCallFinished(service);
9595
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
9696
assertThat(
9797
service.defaultConfigIds(),
@@ -125,7 +125,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty()
125125
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
126126

127127
try (var service = createElasticInferenceService()) {
128-
service.waitForAuthorizationToComplete(TIMEOUT);
128+
ensureAuthorizationCallFinished(service);
129+
129130
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
130131
assertThat(
131132
service.defaultConfigIds(),
@@ -164,7 +165,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty()
164165
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson));
165166

166167
try (var service = createElasticInferenceService()) {
167-
service.waitForAuthorizationToComplete(TIMEOUT);
168+
ensureAuthorizationCallFinished(service);
169+
168170
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
169171
assertTrue(service.defaultConfigIds().isEmpty());
170172
assertThat(service.supportedTaskTypes(), is(EnumSet.noneOf(TaskType.class)));
@@ -198,7 +200,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
198200
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
199201

200202
try (var service = createElasticInferenceService()) {
201-
service.waitForAuthorizationToComplete(TIMEOUT);
203+
ensureAuthorizationCallFinished(service);
204+
202205
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
203206
assertThat(
204207
service.defaultConfigIds(),
@@ -242,7 +245,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
242245
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson));
243246

244247
try (var service = createElasticInferenceService()) {
245-
service.waitForAuthorizationToComplete(TIMEOUT);
248+
ensureAuthorizationCallFinished(service);
249+
246250
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
247251
assertTrue(service.defaultConfigIds().isEmpty());
248252
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING)));
@@ -256,14 +260,19 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
256260
}
257261
}
258262

263+
private void ensureAuthorizationCallFinished(ElasticInferenceService service) {
264+
service.onNodeStarted();
265+
service.waitForAuthorizationToComplete(TIMEOUT);
266+
}
267+
259268
private ElasticInferenceService createElasticInferenceService() {
260269
var httpManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
261270
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, httpManager);
262271

263272
return new ElasticInferenceService(
264273
senderFactory,
265274
createWithEmptySettings(threadPool),
266-
new ElasticInferenceServiceComponents(gatewayUrl),
275+
ElasticInferenceServiceComponents.withNoRevokeDelay(gatewayUrl),
267276
modelRegistry,
268277
new ElasticInferenceServiceAuthorizationHandler(gatewayUrl, threadPool)
269278
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.elasticsearch.license.XPackLicenseState;
3535
import org.elasticsearch.node.PluginComponentBinding;
3636
import org.elasticsearch.plugins.ActionPlugin;
37+
import org.elasticsearch.plugins.ClusterPlugin;
3738
import org.elasticsearch.plugins.ExtensiblePlugin;
3839
import org.elasticsearch.plugins.MapperPlugin;
3940
import org.elasticsearch.plugins.Plugin;
@@ -146,7 +147,8 @@ public class InferencePlugin extends Plugin
146147
SystemIndexPlugin,
147148
MapperPlugin,
148149
SearchPlugin,
149-
InternalSearchPlugin {
150+
InternalSearchPlugin,
151+
ClusterPlugin {
150152

151153
/**
152154
* When this setting is true the verification check that
@@ -274,7 +276,7 @@ public Collection<?> createComponents(PluginServices services) {
274276
ElasticInferenceServiceSettings inferenceServiceSettings = new ElasticInferenceServiceSettings(settings);
275277
String elasticInferenceUrl = inferenceServiceSettings.getElasticInferenceServiceUrl();
276278

277-
var elasticInferenceServiceComponentsInstance = new ElasticInferenceServiceComponents(elasticInferenceUrl);
279+
var elasticInferenceServiceComponentsInstance = ElasticInferenceServiceComponents.withDefaultRevokeDelay(elasticInferenceUrl);
278280
elasticInferenceServiceComponents.set(elasticInferenceServiceComponentsInstance);
279281

280282
var authorizationHandler = new ElasticInferenceServiceAuthorizationHandler(
@@ -507,6 +509,15 @@ public Map<String, Highlighter> getHighlighters() {
507509
return Map.of(SemanticTextHighlighter.NAME, new SemanticTextHighlighter());
508510
}
509511

512+
@Override
513+
public void onNodeStarted() {
514+
var registry = inferenceServiceRegistry.get();
515+
516+
if (registry != null) {
517+
registry.onNodeStarted();
518+
}
519+
}
520+
510521
protected SSLService getSslService() {
511522
return XPackPlugin.getSharedSslService();
512523
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,6 @@ public ElasticInferenceService(
122122

123123
configuration = new Configuration(authRef.get().taskTypesAndModels.getAuthorizedTaskTypes());
124124
defaultModelsConfigs = initDefaultEndpoints(elasticInferenceServiceComponents);
125-
126-
getAuthorization();
127125
}
128126

129127
private static Map<String, DefaultModelConfig> initDefaultEndpoints(
@@ -255,9 +253,24 @@ private void handleRevokedDefaultConfigs(Set<String> authorizedDefaultModelIds)
255253
authorizationCompletedLatch.countDown();
256254
});
257255

258-
getServiceComponents().threadPool()
259-
.executor(UTILITY_THREAD_POOL_NAME)
260-
.execute(() -> modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener));
256+
Runnable removeFromRegistry = () -> {
257+
logger.debug("Synchronizing default inference endpoints");
258+
modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener);
259+
};
260+
261+
var delay = elasticInferenceServiceComponents.revokeAuthorizationDelay();
262+
if (delay == null) {
263+
getServiceComponents().threadPool().executor(UTILITY_THREAD_POOL_NAME).execute(removeFromRegistry);
264+
} else {
265+
getServiceComponents().threadPool()
266+
.schedule(removeFromRegistry, delay, getServiceComponents().threadPool().executor(UTILITY_THREAD_POOL_NAME));
267+
}
268+
269+
}
270+
271+
@Override
272+
public void onNodeStarted() {
273+
getServiceComponents().threadPool().executor(UTILITY_THREAD_POOL_NAME).execute(this::getAuthorization);
261274
}
262275

263276
/**

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceComponents.java

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,23 @@
88
package org.elasticsearch.xpack.inference.services.elastic;
99

1010
import org.elasticsearch.core.Nullable;
11+
import org.elasticsearch.core.TimeValue;
1112

12-
public record ElasticInferenceServiceComponents(@Nullable String elasticInferenceServiceUrl) {}
13+
/**
14+
* @param elasticInferenceServiceUrl the upstream Elastic Inference Server's URL
15+
* @param revokeAuthorizationDelay Amount of time to wait before attempting to revoke authorization to certain model ids.
16+
* null indicates that there should be no delay
17+
*/
18+
public record ElasticInferenceServiceComponents(@Nullable String elasticInferenceServiceUrl, @Nullable TimeValue revokeAuthorizationDelay) {
19+
private static final TimeValue DEFAULT_REVOKE_AUTHORIZATION_DELAY = TimeValue.timeValueMinutes(10);
20+
21+
public static final ElasticInferenceServiceComponents EMPTY_INSTANCE = new ElasticInferenceServiceComponents(null, null);
22+
23+
public static ElasticInferenceServiceComponents withNoRevokeDelay(String elasticInferenceServiceUrl) {
24+
return new ElasticInferenceServiceComponents(elasticInferenceServiceUrl, null);
25+
}
26+
27+
public static ElasticInferenceServiceComponents withDefaultRevokeDelay(String elasticInferenceServiceUrl) {
28+
return new ElasticInferenceServiceComponents(elasticInferenceServiceUrl, DEFAULT_REVOKE_AUTHORIZATION_DELAY);
29+
}
30+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ public void getAuthorization(ActionListener<ElasticInferenceServiceAuthorization
7676
logger.debug("Retrieving authorization information from the Elastic Inference Service.");
7777

7878
if (Strings.isNullOrEmpty(baseUrl)) {
79-
logger.warn("The base URL for the authorization service is not valid, rejecting authorization.");
79+
logger.debug("The base URL for the authorization service is not valid, rejecting authorization.");
8080
listener.onResponse(ElasticInferenceServiceAuthorization.newDisabledService());
8181
return;
8282
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ public static ElasticInferenceServiceSparseEmbeddingsModel createModel(String ur
2626
new ElasticInferenceServiceSparseEmbeddingsServiceSettings(modelId, maxInputTokens, null),
2727
EmptyTaskSettings.INSTANCE,
2828
EmptySecretSettings.INSTANCE,
29-
new ElasticInferenceServiceComponents(url)
29+
ElasticInferenceServiceComponents.withNoRevokeDelay(url)
3030
);
3131
}
3232
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,8 @@ public void testChunkedInfer_PassesThrough() throws IOException {
568568

569569
public void testHideFromConfigurationApi_ReturnsTrue_WithNoAvailableModels() throws Exception {
570570
try (var service = createServiceWithMockSender(ElasticInferenceServiceAuthorization.newDisabledService())) {
571+
ensureAuthorizationCallFinished(service);
572+
571573
assertTrue(service.hideFromConfigurationApi());
572574
}
573575
}
@@ -587,6 +589,8 @@ public void testHideFromConfigurationApi_ReturnsTrue_WithModelTaskTypesThatAreNo
587589
)
588590
)
589591
) {
592+
ensureAuthorizationCallFinished(service);
593+
590594
assertTrue(service.hideFromConfigurationApi());
591595
}
592596
}
@@ -606,6 +610,8 @@ public void testHideFromConfigurationApi_ReturnsFalse_WithAvailableModels() thro
606610
)
607611
)
608612
) {
613+
ensureAuthorizationCallFinished(service);
614+
609615
assertFalse(service.hideFromConfigurationApi());
610616
}
611617
}
@@ -625,6 +631,8 @@ public void testGetConfiguration() throws Exception {
625631
)
626632
)
627633
) {
634+
ensureAuthorizationCallFinished(service);
635+
628636
String content = XContentHelper.stripWhitespace("""
629637
{
630638
"service": "elastic",
@@ -678,6 +686,8 @@ public void testGetConfiguration() throws Exception {
678686

679687
public void testGetConfiguration_WithoutSupportedTaskTypes() throws Exception {
680688
try (var service = createServiceWithMockSender(ElasticInferenceServiceAuthorization.newDisabledService())) {
689+
ensureAuthorizationCallFinished(service);
690+
681691
String content = XContentHelper.stripWhitespace("""
682692
{
683693
"service": "elastic",
@@ -745,6 +755,8 @@ public void testGetConfiguration_WithoutSupportedTaskTypes_WhenModelsReturnTaskO
745755
)
746756
)
747757
) {
758+
ensureAuthorizationCallFinished(service);
759+
748760
String content = XContentHelper.stripWhitespace("""
749761
{
750762
"service": "elastic",
@@ -812,7 +824,8 @@ public void testSupportedStreamingTasks_ReturnsChatCompletion_WhenAuthRespondsWi
812824

813825
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
814826
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
815-
service.waitForAuthorizationToComplete(TIMEOUT);
827+
ensureAuthorizationCallFinished(service);
828+
816829
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
817830
assertTrue(service.defaultConfigIds().isEmpty());
818831

@@ -842,7 +855,8 @@ public void testSupportedTaskTypes_Returns_TheAuthorizedTaskTypes_IgnoresUnimple
842855

843856
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
844857
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
845-
service.waitForAuthorizationToComplete(TIMEOUT);
858+
ensureAuthorizationCallFinished(service);
859+
846860
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING)));
847861
}
848862
}
@@ -867,7 +881,8 @@ public void testSupportedTaskTypes_Returns_TheAuthorizedTaskTypes() throws Excep
867881

868882
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
869883
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
870-
service.waitForAuthorizationToComplete(TIMEOUT);
884+
ensureAuthorizationCallFinished(service);
885+
871886
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION)));
872887
}
873888
}
@@ -888,7 +903,8 @@ public void testSupportedStreamingTasks_ReturnsEmpty_WhenAuthRespondsWithoutChat
888903

889904
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
890905
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
891-
service.waitForAuthorizationToComplete(TIMEOUT);
906+
ensureAuthorizationCallFinished(service);
907+
892908
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
893909
assertTrue(service.defaultConfigIds().isEmpty());
894910
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING)));
@@ -915,7 +931,7 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsIn
915931

916932
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
917933
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
918-
service.waitForAuthorizationToComplete(TIMEOUT);
934+
ensureAuthorizationCallFinished(service);
919935
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
920936
assertThat(
921937
service.defaultConfigIds(),
@@ -949,8 +965,9 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo
949965

950966
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
951967
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
952-
service.waitForAuthorizationToComplete(TIMEOUT);
968+
ensureAuthorizationCallFinished(service);
953969
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
970+
954971
assertThat(
955972
service.defaultConfigIds(),
956973
is(
@@ -1020,7 +1037,7 @@ private void testUnifiedStreamError(int responseCode, String responseJson, Strin
10201037
new ElasticInferenceServiceCompletionServiceSettings("model_id", new RateLimitSettings(100)),
10211038
EmptyTaskSettings.INSTANCE,
10221039
EmptySecretSettings.INSTANCE,
1023-
new ElasticInferenceServiceComponents(eisGatewayUrl)
1040+
ElasticInferenceServiceComponents.withNoRevokeDelay(eisGatewayUrl)
10241041
);
10251042
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
10261043
service.unifiedCompletionInfer(
@@ -1053,6 +1070,11 @@ private void testUnifiedStreamError(int responseCode, String responseJson, Strin
10531070
}
10541071
}
10551072

1073+
private void ensureAuthorizationCallFinished(ElasticInferenceService service) {
1074+
service.onNodeStarted();
1075+
service.waitForAuthorizationToComplete(TIMEOUT);
1076+
}
1077+
10561078
private ElasticInferenceService createServiceWithMockSender() {
10571079
return createServiceWithMockSender(ElasticInferenceServiceAuthorizationTests.createEnabledAuth());
10581080
}
@@ -1068,7 +1090,7 @@ private ElasticInferenceService createServiceWithMockSender(ElasticInferenceServ
10681090
return new ElasticInferenceService(
10691091
mock(HttpRequestSender.Factory.class),
10701092
createWithEmptySettings(threadPool),
1071-
new ElasticInferenceServiceComponents(null),
1093+
ElasticInferenceServiceComponents.EMPTY_INSTANCE,
10721094
mockModelRegistry(),
10731095
mockAuthHandler
10741096
);
@@ -1097,7 +1119,7 @@ private ElasticInferenceService createService(
10971119
return new ElasticInferenceService(
10981120
senderFactory,
10991121
createWithEmptySettings(threadPool),
1100-
new ElasticInferenceServiceComponents(gatewayUrl),
1122+
ElasticInferenceServiceComponents.withNoRevokeDelay(gatewayUrl),
11011123
mockModelRegistry(),
11021124
mockAuthHandler
11031125
);
@@ -1107,7 +1129,7 @@ private ElasticInferenceService createServiceWithAuthHandler(HttpRequestSender.F
11071129
return new ElasticInferenceService(
11081130
senderFactory,
11091131
createWithEmptySettings(threadPool),
1110-
new ElasticInferenceServiceComponents(eisGatewayUrl),
1132+
ElasticInferenceServiceComponents.withNoRevokeDelay(eisGatewayUrl),
11111133
mockModelRegistry(),
11121134
new ElasticInferenceServiceAuthorizationHandler(eisGatewayUrl, threadPool)
11131135
);

0 commit comments

Comments
 (0)