Skip to content

Commit 09250aa

Browse files
[8.x] [ML] Support delaying EIS authorization revocation until after the node has finished booting (elastic#122644) (elastic#123027)
* [ML] Support delaying EIS authorization revocation until after the node has finished booting (elastic#122644) * Refactoring authorization to happen after the node starts * Adding delay for model registry call * Fixing test (cherry picked from commit 4de8244) # Conflicts: # x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java # x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java # x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java * Fixing getFirst call
1 parent d440d72 commit 09250aa

File tree

11 files changed

+118
-33
lines changed

11 files changed

+118
-33
lines changed

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,4 +241,10 @@ default void defaultConfigs(ActionListener<List<Model>> defaultsListener) {
241241
default void updateModelsWithDynamicFields(List<Model> model, ActionListener<List<Model>> listener) {
242242
listener.onResponse(model);
243243
}
244+
245+
/**
246+
* Called after the Elasticsearch node has completed its start up. This allows the service to perform initialization
247+
* after ensuring the node's internals are set up (for example if this ensures the internal ES client is ready for use).
248+
*/
249+
default void onNodeStarted() {}
244250
}

server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ public void init(Client client) {
4141
services.values().forEach(s -> s.init(client));
4242
}
4343

44+
public void onNodeStarted() {
45+
services.values().forEach(InferenceService::onNodeStarted);
46+
}
47+
4448
public Map<String, InferenceService> getServices() {
4549
return services;
4650
}

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,9 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo
9191
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
9292

9393
try (var service = createElasticInferenceService()) {
94-
service.waitForAuthorizationToComplete(TIMEOUT);
94+
ensureAuthorizationCallFinished(service);
9595
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
96+
9697
assertThat(
9798
service.defaultConfigIds(),
9899
is(
@@ -125,7 +126,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty()
125126
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
126127

127128
try (var service = createElasticInferenceService()) {
128-
service.waitForAuthorizationToComplete(TIMEOUT);
129+
ensureAuthorizationCallFinished(service);
130+
129131
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
130132
assertThat(
131133
service.defaultConfigIds(),
@@ -164,7 +166,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty()
164166
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson));
165167

166168
try (var service = createElasticInferenceService()) {
167-
service.waitForAuthorizationToComplete(TIMEOUT);
169+
ensureAuthorizationCallFinished(service);
170+
168171
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
169172
assertTrue(service.defaultConfigIds().isEmpty());
170173
assertThat(service.supportedTaskTypes(), is(EnumSet.noneOf(TaskType.class)));
@@ -198,7 +201,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
198201
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
199202

200203
try (var service = createElasticInferenceService()) {
201-
service.waitForAuthorizationToComplete(TIMEOUT);
204+
ensureAuthorizationCallFinished(service);
205+
202206
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
203207
assertThat(
204208
service.defaultConfigIds(),
@@ -244,7 +248,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
244248
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson));
245249

246250
try (var service = createElasticInferenceService()) {
247-
service.waitForAuthorizationToComplete(TIMEOUT);
251+
ensureAuthorizationCallFinished(service);
252+
248253
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
249254
assertThat(
250255
service.defaultConfigIds(),
@@ -264,14 +269,19 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
264269
}
265270
}
266271

272+
private void ensureAuthorizationCallFinished(ElasticInferenceService service) {
273+
service.onNodeStarted();
274+
service.waitForAuthorizationToComplete(TIMEOUT);
275+
}
276+
267277
private ElasticInferenceService createElasticInferenceService() {
268278
var httpManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
269279
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, httpManager);
270280

271281
return new ElasticInferenceService(
272282
senderFactory,
273283
createWithEmptySettings(threadPool),
274-
new ElasticInferenceServiceComponents(gatewayUrl),
284+
ElasticInferenceServiceComponents.withNoRevokeDelay(gatewayUrl),
275285
modelRegistry,
276286
new ElasticInferenceServiceAuthorizationHandler(gatewayUrl, threadPool)
277287
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.elasticsearch.license.XPackLicenseState;
3535
import org.elasticsearch.node.PluginComponentBinding;
3636
import org.elasticsearch.plugins.ActionPlugin;
37+
import org.elasticsearch.plugins.ClusterPlugin;
3738
import org.elasticsearch.plugins.ExtensiblePlugin;
3839
import org.elasticsearch.plugins.MapperPlugin;
3940
import org.elasticsearch.plugins.Plugin;
@@ -146,7 +147,8 @@ public class InferencePlugin extends Plugin
146147
SystemIndexPlugin,
147148
MapperPlugin,
148149
SearchPlugin,
149-
InternalSearchPlugin {
150+
InternalSearchPlugin,
151+
ClusterPlugin {
150152

151153
/**
152154
* When this setting is true the verification check that
@@ -274,7 +276,7 @@ public Collection<?> createComponents(PluginServices services) {
274276
ElasticInferenceServiceSettings inferenceServiceSettings = new ElasticInferenceServiceSettings(settings);
275277
String elasticInferenceUrl = inferenceServiceSettings.getElasticInferenceServiceUrl();
276278

277-
var elasticInferenceServiceComponentsInstance = new ElasticInferenceServiceComponents(elasticInferenceUrl);
279+
var elasticInferenceServiceComponentsInstance = ElasticInferenceServiceComponents.withDefaultRevokeDelay(elasticInferenceUrl);
278280
elasticInferenceServiceComponents.set(elasticInferenceServiceComponentsInstance);
279281

280282
var authorizationHandler = new ElasticInferenceServiceAuthorizationHandler(
@@ -516,6 +518,15 @@ private String getElasticInferenceServiceUrl(ElasticInferenceServiceSettings set
516518
return settings.getElasticInferenceServiceUrl();
517519
}
518520

521+
@Override
522+
public void onNodeStarted() {
523+
var registry = inferenceServiceRegistry.get();
524+
525+
if (registry != null) {
526+
registry.onNodeStarted();
527+
}
528+
}
529+
519530
protected SSLService getSslService() {
520531
return XPackPlugin.getSharedSslService();
521532
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,6 @@ public ElasticInferenceService(
134134

135135
configuration = new Configuration(authRef.get().taskTypesAndModels.getAuthorizedTaskTypes());
136136
defaultModelsConfigs = initDefaultEndpoints(elasticInferenceServiceComponents);
137-
138-
getAuthorization();
139137
}
140138

141139
private static Map<String, DefaultModelConfig> initDefaultEndpoints(
@@ -282,9 +280,24 @@ private void handleRevokedDefaultConfigs(Set<String> authorizedDefaultModelIds)
282280
authorizationCompletedLatch.countDown();
283281
});
284282

285-
getServiceComponents().threadPool()
286-
.executor(UTILITY_THREAD_POOL_NAME)
287-
.execute(() -> modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener));
283+
Runnable removeFromRegistry = () -> {
284+
logger.debug("Synchronizing default inference endpoints");
285+
modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener);
286+
};
287+
288+
var delay = elasticInferenceServiceComponents.revokeAuthorizationDelay();
289+
if (delay == null) {
290+
getServiceComponents().threadPool().executor(UTILITY_THREAD_POOL_NAME).execute(removeFromRegistry);
291+
} else {
292+
getServiceComponents().threadPool()
293+
.schedule(removeFromRegistry, delay, getServiceComponents().threadPool().executor(UTILITY_THREAD_POOL_NAME));
294+
}
295+
296+
}
297+
298+
@Override
299+
public void onNodeStarted() {
300+
getServiceComponents().threadPool().executor(UTILITY_THREAD_POOL_NAME).execute(this::getAuthorization);
288301
}
289302

290303
/**

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceComponents.java

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,23 @@
88
package org.elasticsearch.xpack.inference.services.elastic;
99

1010
import org.elasticsearch.core.Nullable;
11+
import org.elasticsearch.core.TimeValue;
1112

12-
public record ElasticInferenceServiceComponents(@Nullable String elasticInferenceServiceUrl) {}
13+
/**
14+
* @param elasticInferenceServiceUrl the upstream Elastic Inference Server's URL
15+
* @param revokeAuthorizationDelay Amount of time to wait before attempting to revoke authorization to certain model ids.
16+
* null indicates that there should be no delay
17+
*/
18+
public record ElasticInferenceServiceComponents(@Nullable String elasticInferenceServiceUrl, @Nullable TimeValue revokeAuthorizationDelay) {
19+
private static final TimeValue DEFAULT_REVOKE_AUTHORIZATION_DELAY = TimeValue.timeValueMinutes(10);
20+
21+
public static final ElasticInferenceServiceComponents EMPTY_INSTANCE = new ElasticInferenceServiceComponents(null, null);
22+
23+
public static ElasticInferenceServiceComponents withNoRevokeDelay(String elasticInferenceServiceUrl) {
24+
return new ElasticInferenceServiceComponents(elasticInferenceServiceUrl, null);
25+
}
26+
27+
public static ElasticInferenceServiceComponents withDefaultRevokeDelay(String elasticInferenceServiceUrl) {
28+
return new ElasticInferenceServiceComponents(elasticInferenceServiceUrl, DEFAULT_REVOKE_AUTHORIZATION_DELAY);
29+
}
30+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ public void getAuthorization(ActionListener<ElasticInferenceServiceAuthorization
7676
logger.debug("Retrieving authorization information from the Elastic Inference Service.");
7777

7878
if (Strings.isNullOrEmpty(baseUrl)) {
79-
logger.warn("The base URL for the authorization service is not valid, rejecting authorization.");
79+
logger.debug("The base URL for the authorization service is not valid, rejecting authorization.");
8080
listener.onResponse(ElasticInferenceServiceAuthorization.newDisabledService());
8181
return;
8282
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ public static ElasticInferenceServiceSparseEmbeddingsModel createModel(String ur
2626
new ElasticInferenceServiceSparseEmbeddingsServiceSettings(modelId, maxInputTokens, null),
2727
EmptyTaskSettings.INSTANCE,
2828
EmptySecretSettings.INSTANCE,
29-
new ElasticInferenceServiceComponents(url)
29+
ElasticInferenceServiceComponents.withNoRevokeDelay(url)
3030
);
3131
}
3232
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,8 @@ public void testChunkedInfer_PassesThrough() throws IOException {
567567

568568
public void testHideFromConfigurationApi_ReturnsTrue_WithNoAvailableModels() throws Exception {
569569
try (var service = createServiceWithMockSender(ElasticInferenceServiceAuthorization.newDisabledService())) {
570+
ensureAuthorizationCallFinished(service);
571+
570572
assertTrue(service.hideFromConfigurationApi());
571573
}
572574
}
@@ -586,6 +588,8 @@ public void testHideFromConfigurationApi_ReturnsTrue_WithModelTaskTypesThatAreNo
586588
)
587589
)
588590
) {
591+
ensureAuthorizationCallFinished(service);
592+
589593
assertTrue(service.hideFromConfigurationApi());
590594
}
591595
}
@@ -605,6 +609,8 @@ public void testHideFromConfigurationApi_ReturnsFalse_WithAvailableModels() thro
605609
)
606610
)
607611
) {
612+
ensureAuthorizationCallFinished(service);
613+
608614
assertFalse(service.hideFromConfigurationApi());
609615
}
610616
}
@@ -624,6 +630,8 @@ public void testGetConfiguration() throws Exception {
624630
)
625631
)
626632
) {
633+
ensureAuthorizationCallFinished(service);
634+
627635
String content = XContentHelper.stripWhitespace("""
628636
{
629637
"service": "elastic",
@@ -677,6 +685,8 @@ public void testGetConfiguration() throws Exception {
677685

678686
public void testGetConfiguration_WithoutSupportedTaskTypes() throws Exception {
679687
try (var service = createServiceWithMockSender(ElasticInferenceServiceAuthorization.newDisabledService())) {
688+
ensureAuthorizationCallFinished(service);
689+
680690
String content = XContentHelper.stripWhitespace("""
681691
{
682692
"service": "elastic",
@@ -744,6 +754,8 @@ public void testGetConfiguration_WithoutSupportedTaskTypes_WhenModelsReturnTaskO
744754
)
745755
)
746756
) {
757+
ensureAuthorizationCallFinished(service);
758+
747759
String content = XContentHelper.stripWhitespace("""
748760
{
749761
"service": "elastic",
@@ -811,7 +823,8 @@ public void testSupportedStreamingTasks_ReturnsChatCompletion_WhenAuthRespondsWi
811823

812824
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
813825
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
814-
service.waitForAuthorizationToComplete(TIMEOUT);
826+
ensureAuthorizationCallFinished(service);
827+
815828
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
816829
assertTrue(service.defaultConfigIds().isEmpty());
817830

@@ -841,7 +854,8 @@ public void testSupportedTaskTypes_Returns_TheAuthorizedTaskTypes_IgnoresUnimple
841854

842855
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
843856
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
844-
service.waitForAuthorizationToComplete(TIMEOUT);
857+
ensureAuthorizationCallFinished(service);
858+
845859
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING)));
846860
}
847861
}
@@ -866,7 +880,8 @@ public void testSupportedTaskTypes_Returns_TheAuthorizedTaskTypes() throws Excep
866880

867881
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
868882
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
869-
service.waitForAuthorizationToComplete(TIMEOUT);
883+
ensureAuthorizationCallFinished(service);
884+
870885
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION)));
871886
}
872887
}
@@ -887,7 +902,8 @@ public void testSupportedStreamingTasks_ReturnsEmpty_WhenAuthRespondsWithoutChat
887902

888903
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
889904
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
890-
service.waitForAuthorizationToComplete(TIMEOUT);
905+
ensureAuthorizationCallFinished(service);
906+
891907
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
892908
assertTrue(service.defaultConfigIds().isEmpty());
893909
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING)));
@@ -914,7 +930,7 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsIn
914930

915931
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
916932
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
917-
service.waitForAuthorizationToComplete(TIMEOUT);
933+
ensureAuthorizationCallFinished(service);
918934
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
919935
assertThat(
920936
service.defaultConfigIds(),
@@ -952,7 +968,7 @@ public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect()
952968

953969
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
954970
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
955-
service.waitForAuthorizationToComplete(TIMEOUT);
971+
ensureAuthorizationCallFinished(service);
956972
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
957973
assertThat(
958974
service.defaultConfigIds(),
@@ -1027,7 +1043,7 @@ private void testUnifiedStreamError(int responseCode, String responseJson, Strin
10271043
new ElasticInferenceServiceCompletionServiceSettings("model_id", new RateLimitSettings(100)),
10281044
EmptyTaskSettings.INSTANCE,
10291045
EmptySecretSettings.INSTANCE,
1030-
new ElasticInferenceServiceComponents(eisGatewayUrl)
1046+
ElasticInferenceServiceComponents.withNoRevokeDelay(eisGatewayUrl)
10311047
);
10321048
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
10331049
service.unifiedCompletionInfer(
@@ -1060,6 +1076,11 @@ private void testUnifiedStreamError(int responseCode, String responseJson, Strin
10601076
}
10611077
}
10621078

1079+
private void ensureAuthorizationCallFinished(ElasticInferenceService service) {
1080+
service.onNodeStarted();
1081+
service.waitForAuthorizationToComplete(TIMEOUT);
1082+
}
1083+
10631084
private ElasticInferenceService createServiceWithMockSender() {
10641085
return createServiceWithMockSender(ElasticInferenceServiceAuthorizationTests.createEnabledAuth());
10651086
}
@@ -1075,7 +1096,7 @@ private ElasticInferenceService createServiceWithMockSender(ElasticInferenceServ
10751096
return new ElasticInferenceService(
10761097
mock(HttpRequestSender.Factory.class),
10771098
createWithEmptySettings(threadPool),
1078-
new ElasticInferenceServiceComponents(null),
1099+
ElasticInferenceServiceComponents.EMPTY_INSTANCE,
10791100
mockModelRegistry(),
10801101
mockAuthHandler
10811102
);
@@ -1104,7 +1125,7 @@ private ElasticInferenceService createService(
11041125
return new ElasticInferenceService(
11051126
senderFactory,
11061127
createWithEmptySettings(threadPool),
1107-
new ElasticInferenceServiceComponents(gatewayUrl),
1128+
ElasticInferenceServiceComponents.withNoRevokeDelay(gatewayUrl),
11081129
mockModelRegistry(),
11091130
mockAuthHandler
11101131
);
@@ -1114,7 +1135,7 @@ private ElasticInferenceService createServiceWithAuthHandler(HttpRequestSender.F
11141135
return new ElasticInferenceService(
11151136
senderFactory,
11161137
createWithEmptySettings(threadPool),
1117-
new ElasticInferenceServiceComponents(eisGatewayUrl),
1138+
ElasticInferenceServiceComponents.withNoRevokeDelay(eisGatewayUrl),
11181139
mockModelRegistry(),
11191140
new ElasticInferenceServiceAuthorizationHandler(eisGatewayUrl, threadPool)
11201141
);

0 commit comments

Comments
 (0)