Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -241,4 +241,10 @@ default void defaultConfigs(ActionListener<List<Model>> defaultsListener) {
default void updateModelsWithDynamicFields(List<Model> model, ActionListener<List<Model>> listener) {
// Default implementation: hands the given models straight back to the listener without modifying them.
listener.onResponse(model);
}

/**
* Called after the Elasticsearch node has completed its startup. This allows the service to defer initialization
* until the node's internals are set up (for example, until the internal ES client is ready for use).
* The default implementation does nothing.
*/
default void onNodeStarted() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ public void init(Client client) {
services.values().forEach(s -> s.init(client));
}

/**
 * Forwards the node-started notification to each service held by this registry.
 */
public void onNodeStarted() {
    for (InferenceService service : services.values()) {
        service.onNodeStarted();
    }
}

/**
* Returns the map of services held by this registry. The map is returned as-is, not copied.
*/
public Map<String, InferenceService> getServices() {
return services;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,9 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

try (var service = createElasticInferenceService()) {
service.waitForAuthorizationToComplete(TIMEOUT);
ensureAuthorizationCallFinished(service);
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));

assertThat(
service.defaultConfigIds(),
is(
Expand Down Expand Up @@ -125,7 +126,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty()
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

try (var service = createElasticInferenceService()) {
service.waitForAuthorizationToComplete(TIMEOUT);
ensureAuthorizationCallFinished(service);

assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
assertThat(
service.defaultConfigIds(),
Expand Down Expand Up @@ -164,7 +166,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty()
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson));

try (var service = createElasticInferenceService()) {
service.waitForAuthorizationToComplete(TIMEOUT);
ensureAuthorizationCallFinished(service);

assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
assertTrue(service.defaultConfigIds().isEmpty());
assertThat(service.supportedTaskTypes(), is(EnumSet.noneOf(TaskType.class)));
Expand Down Expand Up @@ -198,7 +201,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

try (var service = createElasticInferenceService()) {
service.waitForAuthorizationToComplete(TIMEOUT);
ensureAuthorizationCallFinished(service);

assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
assertThat(
service.defaultConfigIds(),
Expand Down Expand Up @@ -244,7 +248,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson));

try (var service = createElasticInferenceService()) {
service.waitForAuthorizationToComplete(TIMEOUT);
ensureAuthorizationCallFinished(service);

assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
assertThat(
service.defaultConfigIds(),
Expand All @@ -264,14 +269,19 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
}
}

/**
* Signals node startup to the service and then waits for its authorization call to complete.
*
* @param service the service under test
*/
private void ensureAuthorizationCallFinished(ElasticInferenceService service) {
// The node-started callback must come first; the wait below is for the authorization it is expected to start.
service.onNodeStarted();
// Bounded by TIMEOUT so the assertions that follow run against the post-authorization state.
service.waitForAuthorizationToComplete(TIMEOUT);
}

private ElasticInferenceService createElasticInferenceService() {
var httpManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, httpManager);

return new ElasticInferenceService(
senderFactory,
createWithEmptySettings(threadPool),
new ElasticInferenceServiceComponents(gatewayUrl),
ElasticInferenceServiceComponents.withNoRevokeDelay(gatewayUrl),
modelRegistry,
new ElasticInferenceServiceAuthorizationHandler(gatewayUrl, threadPool)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.node.PluginComponentBinding;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.ClusterPlugin;
import org.elasticsearch.plugins.ExtensiblePlugin;
import org.elasticsearch.plugins.MapperPlugin;
import org.elasticsearch.plugins.Plugin;
Expand Down Expand Up @@ -146,7 +147,8 @@ public class InferencePlugin extends Plugin
SystemIndexPlugin,
MapperPlugin,
SearchPlugin,
InternalSearchPlugin {
InternalSearchPlugin,
ClusterPlugin {

/**
* When this setting is true the verification check that
Expand Down Expand Up @@ -274,7 +276,7 @@ public Collection<?> createComponents(PluginServices services) {
ElasticInferenceServiceSettings inferenceServiceSettings = new ElasticInferenceServiceSettings(settings);
String elasticInferenceUrl = inferenceServiceSettings.getElasticInferenceServiceUrl();

var elasticInferenceServiceComponentsInstance = new ElasticInferenceServiceComponents(elasticInferenceUrl);
var elasticInferenceServiceComponentsInstance = ElasticInferenceServiceComponents.withDefaultRevokeDelay(elasticInferenceUrl);
elasticInferenceServiceComponents.set(elasticInferenceServiceComponentsInstance);

var authorizationHandler = new ElasticInferenceServiceAuthorizationHandler(
Expand Down Expand Up @@ -516,6 +518,15 @@ private String getElasticInferenceServiceUrl(ElasticInferenceServiceSettings set
return settings.getElasticInferenceServiceUrl();
}

/**
 * Relays the node-started event to the inference service registry.
 * Does nothing when the registry reference does not hold a value.
 */
@Override
public void onNodeStarted() {
    var serviceRegistry = inferenceServiceRegistry.get();
    if (serviceRegistry == null) {
        return;
    }

    serviceRegistry.onNodeStarted();
}

/**
* Returns the SSL service obtained from {@link XPackPlugin#getSharedSslService()}.
*/
protected SSLService getSslService() {
return XPackPlugin.getSharedSslService();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,6 @@ public ElasticInferenceService(

configuration = new Configuration(authRef.get().taskTypesAndModels.getAuthorizedTaskTypes());
defaultModelsConfigs = initDefaultEndpoints(elasticInferenceServiceComponents);

getAuthorization();
}

private static Map<String, DefaultModelConfig> initDefaultEndpoints(
Expand Down Expand Up @@ -282,9 +280,24 @@ private void handleRevokedDefaultConfigs(Set<String> authorizedDefaultModelIds)
authorizationCompletedLatch.countDown();
});

getServiceComponents().threadPool()
.executor(UTILITY_THREAD_POOL_NAME)
.execute(() -> modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener));
Runnable removeFromRegistry = () -> {
logger.debug("Synchronizing default inference endpoints");
modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener);
};

var delay = elasticInferenceServiceComponents.revokeAuthorizationDelay();
if (delay == null) {
getServiceComponents().threadPool().executor(UTILITY_THREAD_POOL_NAME).execute(removeFromRegistry);
} else {
getServiceComponents().threadPool()
.schedule(removeFromRegistry, delay, getServiceComponents().threadPool().executor(UTILITY_THREAD_POOL_NAME));
}

}

/**
 * Starts the authorization retrieval once the node has started. The call is handed to the
 * {@code UTILITY_THREAD_POOL_NAME} executor of the service's thread pool.
 */
@Override
public void onNodeStarted() {
    var utilityExecutor = getServiceComponents().threadPool().executor(UTILITY_THREAD_POOL_NAME);
    utilityExecutor.execute(this::getAuthorization);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,23 @@
package org.elasticsearch.xpack.inference.services.elastic;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;

public record ElasticInferenceServiceComponents(@Nullable String elasticInferenceServiceUrl) {}
/**
* Holds the URL and the revoke-authorization delay used by the Elastic Inference Service integration.
*
* @param elasticInferenceServiceUrl the upstream Elastic Inference Service's URL
* @param revokeAuthorizationDelay Amount of time to wait before attempting to revoke authorization to certain model ids.
* null indicates that there should be no delay
*/
public record ElasticInferenceServiceComponents(@Nullable String elasticInferenceServiceUrl, @Nullable TimeValue revokeAuthorizationDelay) {
// Delay used by withDefaultRevokeDelay: 10 minutes.
private static final TimeValue DEFAULT_REVOKE_AUTHORIZATION_DELAY = TimeValue.timeValueMinutes(10);

/** Instance whose URL and revoke delay are both null. */
public static final ElasticInferenceServiceComponents EMPTY_INSTANCE = new ElasticInferenceServiceComponents(null, null);

/**
* Creates components for the given URL with a null revoke delay, meaning revocation is not delayed.
*
* @param elasticInferenceServiceUrl the upstream Elastic Inference Service's URL
*/
public static ElasticInferenceServiceComponents withNoRevokeDelay(String elasticInferenceServiceUrl) {
return new ElasticInferenceServiceComponents(elasticInferenceServiceUrl, null);
}

/**
* Creates components for the given URL using the default revoke delay of 10 minutes.
*
* @param elasticInferenceServiceUrl the upstream Elastic Inference Service's URL
*/
public static ElasticInferenceServiceComponents withDefaultRevokeDelay(String elasticInferenceServiceUrl) {
return new ElasticInferenceServiceComponents(elasticInferenceServiceUrl, DEFAULT_REVOKE_AUTHORIZATION_DELAY);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public void getAuthorization(ActionListener<ElasticInferenceServiceAuthorization
logger.debug("Retrieving authorization information from the Elastic Inference Service.");

if (Strings.isNullOrEmpty(baseUrl)) {
logger.warn("The base URL for the authorization service is not valid, rejecting authorization.");
logger.debug("The base URL for the authorization service is not valid, rejecting authorization.");
listener.onResponse(ElasticInferenceServiceAuthorization.newDisabledService());
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public static ElasticInferenceServiceSparseEmbeddingsModel createModel(String ur
new ElasticInferenceServiceSparseEmbeddingsServiceSettings(modelId, maxInputTokens, null),
EmptyTaskSettings.INSTANCE,
EmptySecretSettings.INSTANCE,
new ElasticInferenceServiceComponents(url)
ElasticInferenceServiceComponents.withNoRevokeDelay(url)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,8 @@ public void testChunkedInfer_PassesThrough() throws IOException {

// Verifies that the service is hidden from the configuration API when it is created with a disabled authorization.
public void testHideFromConfigurationApi_ReturnsTrue_WithNoAvailableModels() throws Exception {
try (var service = createServiceWithMockSender(ElasticInferenceServiceAuthorization.newDisabledService())) {
// Trigger the authorization call via the node-started callback and wait for it before asserting.
ensureAuthorizationCallFinished(service);

assertTrue(service.hideFromConfigurationApi());
}
}
Expand All @@ -586,6 +588,8 @@ public void testHideFromConfigurationApi_ReturnsTrue_WithModelTaskTypesThatAreNo
)
)
) {
ensureAuthorizationCallFinished(service);

assertTrue(service.hideFromConfigurationApi());
}
}
Expand All @@ -605,6 +609,8 @@ public void testHideFromConfigurationApi_ReturnsFalse_WithAvailableModels() thro
)
)
) {
ensureAuthorizationCallFinished(service);

assertFalse(service.hideFromConfigurationApi());
}
}
Expand All @@ -624,6 +630,8 @@ public void testGetConfiguration() throws Exception {
)
)
) {
ensureAuthorizationCallFinished(service);

String content = XContentHelper.stripWhitespace("""
{
"service": "elastic",
Expand Down Expand Up @@ -677,6 +685,8 @@ public void testGetConfiguration() throws Exception {

public void testGetConfiguration_WithoutSupportedTaskTypes() throws Exception {
try (var service = createServiceWithMockSender(ElasticInferenceServiceAuthorization.newDisabledService())) {
ensureAuthorizationCallFinished(service);

String content = XContentHelper.stripWhitespace("""
{
"service": "elastic",
Expand Down Expand Up @@ -744,6 +754,8 @@ public void testGetConfiguration_WithoutSupportedTaskTypes_WhenModelsReturnTaskO
)
)
) {
ensureAuthorizationCallFinished(service);

String content = XContentHelper.stripWhitespace("""
{
"service": "elastic",
Expand Down Expand Up @@ -811,7 +823,8 @@ public void testSupportedStreamingTasks_ReturnsChatCompletion_WhenAuthRespondsWi

var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
service.waitForAuthorizationToComplete(TIMEOUT);
ensureAuthorizationCallFinished(service);

assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
assertTrue(service.defaultConfigIds().isEmpty());

Expand Down Expand Up @@ -841,7 +854,8 @@ public void testSupportedTaskTypes_Returns_TheAuthorizedTaskTypes_IgnoresUnimple

var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
service.waitForAuthorizationToComplete(TIMEOUT);
ensureAuthorizationCallFinished(service);

assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING)));
}
}
Expand All @@ -866,7 +880,8 @@ public void testSupportedTaskTypes_Returns_TheAuthorizedTaskTypes() throws Excep

var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
service.waitForAuthorizationToComplete(TIMEOUT);
ensureAuthorizationCallFinished(service);

assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION)));
}
}
Expand All @@ -887,7 +902,8 @@ public void testSupportedStreamingTasks_ReturnsEmpty_WhenAuthRespondsWithoutChat

var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
service.waitForAuthorizationToComplete(TIMEOUT);
ensureAuthorizationCallFinished(service);

assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
assertTrue(service.defaultConfigIds().isEmpty());
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING)));
Expand All @@ -914,7 +930,7 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsIn

var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
service.waitForAuthorizationToComplete(TIMEOUT);
ensureAuthorizationCallFinished(service);
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
assertThat(
service.defaultConfigIds(),
Expand Down Expand Up @@ -952,7 +968,7 @@ public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect()

var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
service.waitForAuthorizationToComplete(TIMEOUT);
ensureAuthorizationCallFinished(service);
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
assertThat(
service.defaultConfigIds(),
Expand Down Expand Up @@ -1027,7 +1043,7 @@ private void testUnifiedStreamError(int responseCode, String responseJson, Strin
new ElasticInferenceServiceCompletionServiceSettings("model_id", new RateLimitSettings(100)),
EmptyTaskSettings.INSTANCE,
EmptySecretSettings.INSTANCE,
new ElasticInferenceServiceComponents(eisGatewayUrl)
ElasticInferenceServiceComponents.withNoRevokeDelay(eisGatewayUrl)
);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
service.unifiedCompletionInfer(
Expand Down Expand Up @@ -1060,6 +1076,11 @@ private void testUnifiedStreamError(int responseCode, String responseJson, Strin
}
}

/**
* Signals node startup to the service and then waits for its authorization call to complete.
*
* @param service the service under test
*/
private void ensureAuthorizationCallFinished(ElasticInferenceService service) {
// The node-started callback must come first; the wait below is for the authorization it is expected to start.
service.onNodeStarted();
// Bounded by TIMEOUT so the assertions that follow run against the post-authorization state.
service.waitForAuthorizationToComplete(TIMEOUT);
}

/**
* Convenience overload that builds the service with the authorization returned by
* {@link ElasticInferenceServiceAuthorizationTests#createEnabledAuth()}.
*/
private ElasticInferenceService createServiceWithMockSender() {
return createServiceWithMockSender(ElasticInferenceServiceAuthorizationTests.createEnabledAuth());
}
Expand All @@ -1075,7 +1096,7 @@ private ElasticInferenceService createServiceWithMockSender(ElasticInferenceServ
return new ElasticInferenceService(
mock(HttpRequestSender.Factory.class),
createWithEmptySettings(threadPool),
new ElasticInferenceServiceComponents(null),
ElasticInferenceServiceComponents.EMPTY_INSTANCE,
mockModelRegistry(),
mockAuthHandler
);
Expand Down Expand Up @@ -1104,7 +1125,7 @@ private ElasticInferenceService createService(
return new ElasticInferenceService(
senderFactory,
createWithEmptySettings(threadPool),
new ElasticInferenceServiceComponents(gatewayUrl),
ElasticInferenceServiceComponents.withNoRevokeDelay(gatewayUrl),
mockModelRegistry(),
mockAuthHandler
);
Expand All @@ -1114,7 +1135,7 @@ private ElasticInferenceService createServiceWithAuthHandler(HttpRequestSender.F
return new ElasticInferenceService(
senderFactory,
createWithEmptySettings(threadPool),
new ElasticInferenceServiceComponents(eisGatewayUrl),
ElasticInferenceServiceComponents.withNoRevokeDelay(eisGatewayUrl),
mockModelRegistry(),
new ElasticInferenceServiceAuthorizationHandler(eisGatewayUrl, threadPool)
);
Expand Down
Loading