Skip to content

Commit bcdc7c5

Browse files
[8.18] [ML] Support delaying EIS authorization revocation until after the node has finished booting (elastic#122644) (elastic#123030)
* [ML] Support delaying EIS authorization revocation until after the node has finished booting (elastic#122644) * Refactoring authorization to happen after the node starts * Adding delay for model registry call * Fixing test (cherry picked from commit 4de8244) # Conflicts: # x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java # x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java # x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java * Fixing getFirst call
1 parent b0018b4 commit bcdc7c5

File tree

11 files changed

+117
-33
lines changed

11 files changed

+117
-33
lines changed

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,4 +241,10 @@ default void defaultConfigs(ActionListener<List<Model>> defaultsListener) {
241241
default void updateModelsWithDynamicFields(List<Model> model, ActionListener<List<Model>> listener) {
242242
listener.onResponse(model);
243243
}
244+
245+
/**
246+
* Called after the Elasticsearch node has completed its start up. This allows the service to perform initialization
247+
* after ensuring the node's internals are set up (for example if this ensures the internal ES client is ready for use).
248+
*/
249+
default void onNodeStarted() {}
244250
}

server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ public void init(Client client) {
4141
services.values().forEach(s -> s.init(client));
4242
}
4343

44+
public void onNodeStarted() {
45+
services.values().forEach(InferenceService::onNodeStarted);
46+
}
47+
4448
public Map<String, InferenceService> getServices() {
4549
return services;
4650
}

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo
9191
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
9292

9393
try (var service = createElasticInferenceService()) {
94-
service.waitForAuthorizationToComplete(TIMEOUT);
94+
ensureAuthorizationCallFinished(service);
9595
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
9696
assertThat(
9797
service.defaultConfigIds(),
@@ -125,7 +125,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty()
125125
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
126126

127127
try (var service = createElasticInferenceService()) {
128-
service.waitForAuthorizationToComplete(TIMEOUT);
128+
ensureAuthorizationCallFinished(service);
129+
129130
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
130131
assertThat(
131132
service.defaultConfigIds(),
@@ -164,7 +165,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty()
164165
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson));
165166

166167
try (var service = createElasticInferenceService()) {
167-
service.waitForAuthorizationToComplete(TIMEOUT);
168+
ensureAuthorizationCallFinished(service);
169+
168170
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
169171
assertTrue(service.defaultConfigIds().isEmpty());
170172
assertThat(service.supportedTaskTypes(), is(EnumSet.noneOf(TaskType.class)));
@@ -198,7 +200,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
198200
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
199201

200202
try (var service = createElasticInferenceService()) {
201-
service.waitForAuthorizationToComplete(TIMEOUT);
203+
ensureAuthorizationCallFinished(service);
204+
202205
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
203206
assertThat(
204207
service.defaultConfigIds(),
@@ -242,7 +245,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
242245
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson));
243246

244247
try (var service = createElasticInferenceService()) {
245-
service.waitForAuthorizationToComplete(TIMEOUT);
248+
ensureAuthorizationCallFinished(service);
249+
246250
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
247251
assertTrue(service.defaultConfigIds().isEmpty());
248252
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING)));
@@ -256,14 +260,19 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
256260
}
257261
}
258262

263+
private void ensureAuthorizationCallFinished(ElasticInferenceService service) {
264+
service.onNodeStarted();
265+
service.waitForAuthorizationToComplete(TIMEOUT);
266+
}
267+
259268
private ElasticInferenceService createElasticInferenceService() {
260269
var httpManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
261270
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, httpManager);
262271

263272
return new ElasticInferenceService(
264273
senderFactory,
265274
createWithEmptySettings(threadPool),
266-
new ElasticInferenceServiceComponents(gatewayUrl),
275+
ElasticInferenceServiceComponents.withNoRevokeDelay(gatewayUrl),
267276
modelRegistry,
268277
new ElasticInferenceServiceAuthorizationHandler(gatewayUrl, threadPool)
269278
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.elasticsearch.license.XPackLicenseState;
3535
import org.elasticsearch.node.PluginComponentBinding;
3636
import org.elasticsearch.plugins.ActionPlugin;
37+
import org.elasticsearch.plugins.ClusterPlugin;
3738
import org.elasticsearch.plugins.ExtensiblePlugin;
3839
import org.elasticsearch.plugins.MapperPlugin;
3940
import org.elasticsearch.plugins.Plugin;
@@ -146,7 +147,8 @@ public class InferencePlugin extends Plugin
146147
SystemIndexPlugin,
147148
MapperPlugin,
148149
SearchPlugin,
149-
InternalSearchPlugin {
150+
InternalSearchPlugin,
151+
ClusterPlugin {
150152

151153
/**
152154
* When this setting is true the verification check that
@@ -274,7 +276,7 @@ public Collection<?> createComponents(PluginServices services) {
274276
ElasticInferenceServiceSettings inferenceServiceSettings = new ElasticInferenceServiceSettings(settings);
275277
String elasticInferenceUrl = inferenceServiceSettings.getElasticInferenceServiceUrl();
276278

277-
var elasticInferenceServiceComponentsInstance = new ElasticInferenceServiceComponents(elasticInferenceUrl);
279+
var elasticInferenceServiceComponentsInstance = ElasticInferenceServiceComponents.withDefaultRevokeDelay(elasticInferenceUrl);
278280
elasticInferenceServiceComponents.set(elasticInferenceServiceComponentsInstance);
279281

280282
var authorizationHandler = new ElasticInferenceServiceAuthorizationHandler(
@@ -516,6 +518,15 @@ private String getElasticInferenceServiceUrl(ElasticInferenceServiceSettings set
516518
return settings.getElasticInferenceServiceUrl();
517519
}
518520

521+
@Override
522+
public void onNodeStarted() {
523+
var registry = inferenceServiceRegistry.get();
524+
525+
if (registry != null) {
526+
registry.onNodeStarted();
527+
}
528+
}
529+
519530
protected SSLService getSslService() {
520531
return XPackPlugin.getSharedSslService();
521532
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,6 @@ public ElasticInferenceService(
122122

123123
configuration = new Configuration(authRef.get().taskTypesAndModels.getAuthorizedTaskTypes());
124124
defaultModelsConfigs = initDefaultEndpoints(elasticInferenceServiceComponents);
125-
126-
getAuthorization();
127125
}
128126

129127
private static Map<String, DefaultModelConfig> initDefaultEndpoints(
@@ -255,9 +253,24 @@ private void handleRevokedDefaultConfigs(Set<String> authorizedDefaultModelIds)
255253
authorizationCompletedLatch.countDown();
256254
});
257255

258-
getServiceComponents().threadPool()
259-
.executor(UTILITY_THREAD_POOL_NAME)
260-
.execute(() -> modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener));
256+
Runnable removeFromRegistry = () -> {
257+
logger.debug("Synchronizing default inference endpoints");
258+
modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener);
259+
};
260+
261+
var delay = elasticInferenceServiceComponents.revokeAuthorizationDelay();
262+
if (delay == null) {
263+
getServiceComponents().threadPool().executor(UTILITY_THREAD_POOL_NAME).execute(removeFromRegistry);
264+
} else {
265+
getServiceComponents().threadPool()
266+
.schedule(removeFromRegistry, delay, getServiceComponents().threadPool().executor(UTILITY_THREAD_POOL_NAME));
267+
}
268+
269+
}
270+
271+
@Override
272+
public void onNodeStarted() {
273+
getServiceComponents().threadPool().executor(UTILITY_THREAD_POOL_NAME).execute(this::getAuthorization);
261274
}
262275

263276
/**

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceComponents.java

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,23 @@
88
package org.elasticsearch.xpack.inference.services.elastic;
99

1010
import org.elasticsearch.core.Nullable;
11+
import org.elasticsearch.core.TimeValue;
1112

12-
public record ElasticInferenceServiceComponents(@Nullable String elasticInferenceServiceUrl) {}
13+
/**
14+
* @param elasticInferenceServiceUrl the upstream Elastic Inference Server's URL
15+
* @param revokeAuthorizationDelay Amount of time to wait before attempting to revoke authorization to certain model ids.
16+
* null indicates that there should be no delay
17+
*/
18+
public record ElasticInferenceServiceComponents(@Nullable String elasticInferenceServiceUrl, @Nullable TimeValue revokeAuthorizationDelay) {
19+
private static final TimeValue DEFAULT_REVOKE_AUTHORIZATION_DELAY = TimeValue.timeValueMinutes(10);
20+
21+
public static final ElasticInferenceServiceComponents EMPTY_INSTANCE = new ElasticInferenceServiceComponents(null, null);
22+
23+
public static ElasticInferenceServiceComponents withNoRevokeDelay(String elasticInferenceServiceUrl) {
24+
return new ElasticInferenceServiceComponents(elasticInferenceServiceUrl, null);
25+
}
26+
27+
public static ElasticInferenceServiceComponents withDefaultRevokeDelay(String elasticInferenceServiceUrl) {
28+
return new ElasticInferenceServiceComponents(elasticInferenceServiceUrl, DEFAULT_REVOKE_AUTHORIZATION_DELAY);
29+
}
30+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ public void getAuthorization(ActionListener<ElasticInferenceServiceAuthorization
7676
logger.debug("Retrieving authorization information from the Elastic Inference Service.");
7777

7878
if (Strings.isNullOrEmpty(baseUrl)) {
79-
logger.warn("The base URL for the authorization service is not valid, rejecting authorization.");
79+
logger.debug("The base URL for the authorization service is not valid, rejecting authorization.");
8080
listener.onResponse(ElasticInferenceServiceAuthorization.newDisabledService());
8181
return;
8282
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ public static ElasticInferenceServiceSparseEmbeddingsModel createModel(String ur
2626
new ElasticInferenceServiceSparseEmbeddingsServiceSettings(modelId, maxInputTokens, null),
2727
EmptyTaskSettings.INSTANCE,
2828
EmptySecretSettings.INSTANCE,
29-
new ElasticInferenceServiceComponents(url)
29+
ElasticInferenceServiceComponents.withNoRevokeDelay(url)
3030
);
3131
}
3232
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,8 @@ public void testChunkedInfer_PassesThrough() throws IOException {
567567

568568
public void testHideFromConfigurationApi_ReturnsTrue_WithNoAvailableModels() throws Exception {
569569
try (var service = createServiceWithMockSender(ElasticInferenceServiceAuthorization.newDisabledService())) {
570+
ensureAuthorizationCallFinished(service);
571+
570572
assertTrue(service.hideFromConfigurationApi());
571573
}
572574
}
@@ -586,6 +588,8 @@ public void testHideFromConfigurationApi_ReturnsTrue_WithModelTaskTypesThatAreNo
586588
)
587589
)
588590
) {
591+
ensureAuthorizationCallFinished(service);
592+
589593
assertTrue(service.hideFromConfigurationApi());
590594
}
591595
}
@@ -605,6 +609,8 @@ public void testHideFromConfigurationApi_ReturnsFalse_WithAvailableModels() thro
605609
)
606610
)
607611
) {
612+
ensureAuthorizationCallFinished(service);
613+
608614
assertFalse(service.hideFromConfigurationApi());
609615
}
610616
}
@@ -624,6 +630,8 @@ public void testGetConfiguration() throws Exception {
624630
)
625631
)
626632
) {
633+
ensureAuthorizationCallFinished(service);
634+
627635
String content = XContentHelper.stripWhitespace("""
628636
{
629637
"service": "elastic",
@@ -677,6 +685,8 @@ public void testGetConfiguration() throws Exception {
677685

678686
public void testGetConfiguration_WithoutSupportedTaskTypes() throws Exception {
679687
try (var service = createServiceWithMockSender(ElasticInferenceServiceAuthorization.newDisabledService())) {
688+
ensureAuthorizationCallFinished(service);
689+
680690
String content = XContentHelper.stripWhitespace("""
681691
{
682692
"service": "elastic",
@@ -744,6 +754,8 @@ public void testGetConfiguration_WithoutSupportedTaskTypes_WhenModelsReturnTaskO
744754
)
745755
)
746756
) {
757+
ensureAuthorizationCallFinished(service);
758+
747759
String content = XContentHelper.stripWhitespace("""
748760
{
749761
"service": "elastic",
@@ -811,7 +823,8 @@ public void testSupportedStreamingTasks_ReturnsChatCompletion_WhenAuthRespondsWi
811823

812824
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
813825
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
814-
service.waitForAuthorizationToComplete(TIMEOUT);
826+
ensureAuthorizationCallFinished(service);
827+
815828
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
816829
assertTrue(service.defaultConfigIds().isEmpty());
817830

@@ -841,7 +854,8 @@ public void testSupportedTaskTypes_Returns_TheAuthorizedTaskTypes_IgnoresUnimple
841854

842855
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
843856
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
844-
service.waitForAuthorizationToComplete(TIMEOUT);
857+
ensureAuthorizationCallFinished(service);
858+
845859
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING)));
846860
}
847861
}
@@ -866,7 +880,8 @@ public void testSupportedTaskTypes_Returns_TheAuthorizedTaskTypes() throws Excep
866880

867881
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
868882
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
869-
service.waitForAuthorizationToComplete(TIMEOUT);
883+
ensureAuthorizationCallFinished(service);
884+
870885
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION)));
871886
}
872887
}
@@ -887,7 +902,8 @@ public void testSupportedStreamingTasks_ReturnsEmpty_WhenAuthRespondsWithoutChat
887902

888903
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
889904
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
890-
service.waitForAuthorizationToComplete(TIMEOUT);
905+
ensureAuthorizationCallFinished(service);
906+
891907
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
892908
assertTrue(service.defaultConfigIds().isEmpty());
893909
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING)));
@@ -914,7 +930,7 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsIn
914930

915931
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
916932
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
917-
service.waitForAuthorizationToComplete(TIMEOUT);
933+
ensureAuthorizationCallFinished(service);
918934
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
919935
assertThat(
920936
service.defaultConfigIds(),
@@ -948,7 +964,7 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo
948964

949965
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
950966
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
951-
service.waitForAuthorizationToComplete(TIMEOUT);
967+
ensureAuthorizationCallFinished(service);
952968
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
953969
assertThat(
954970
service.defaultConfigIds(),
@@ -1019,7 +1035,7 @@ private void testUnifiedStreamError(int responseCode, String responseJson, Strin
10191035
new ElasticInferenceServiceCompletionServiceSettings("model_id", new RateLimitSettings(100)),
10201036
EmptyTaskSettings.INSTANCE,
10211037
EmptySecretSettings.INSTANCE,
1022-
new ElasticInferenceServiceComponents(eisGatewayUrl)
1038+
ElasticInferenceServiceComponents.withNoRevokeDelay(eisGatewayUrl)
10231039
);
10241040
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
10251041
service.unifiedCompletionInfer(
@@ -1052,6 +1068,11 @@ private void testUnifiedStreamError(int responseCode, String responseJson, Strin
10521068
}
10531069
}
10541070

1071+
private void ensureAuthorizationCallFinished(ElasticInferenceService service) {
1072+
service.onNodeStarted();
1073+
service.waitForAuthorizationToComplete(TIMEOUT);
1074+
}
1075+
10551076
private ElasticInferenceService createServiceWithMockSender() {
10561077
return createServiceWithMockSender(ElasticInferenceServiceAuthorizationTests.createEnabledAuth());
10571078
}
@@ -1067,7 +1088,7 @@ private ElasticInferenceService createServiceWithMockSender(ElasticInferenceServ
10671088
return new ElasticInferenceService(
10681089
mock(HttpRequestSender.Factory.class),
10691090
createWithEmptySettings(threadPool),
1070-
new ElasticInferenceServiceComponents(null),
1091+
ElasticInferenceServiceComponents.EMPTY_INSTANCE,
10711092
mockModelRegistry(),
10721093
mockAuthHandler
10731094
);
@@ -1096,7 +1117,7 @@ private ElasticInferenceService createService(
10961117
return new ElasticInferenceService(
10971118
senderFactory,
10981119
createWithEmptySettings(threadPool),
1099-
new ElasticInferenceServiceComponents(gatewayUrl),
1120+
ElasticInferenceServiceComponents.withNoRevokeDelay(gatewayUrl),
11001121
mockModelRegistry(),
11011122
mockAuthHandler
11021123
);
@@ -1106,7 +1127,7 @@ private ElasticInferenceService createServiceWithAuthHandler(HttpRequestSender.F
11061127
return new ElasticInferenceService(
11071128
senderFactory,
11081129
createWithEmptySettings(threadPool),
1109-
new ElasticInferenceServiceComponents(eisGatewayUrl),
1130+
ElasticInferenceServiceComponents.withNoRevokeDelay(eisGatewayUrl),
11101131
mockModelRegistry(),
11111132
new ElasticInferenceServiceAuthorizationHandler(eisGatewayUrl, threadPool)
11121133
);

0 commit comments

Comments
 (0)