Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@

public class InferenceGetModelsWithElasticInferenceServiceIT extends BaseMockEISAuthServerTest {

/**
* This is done before the class because I've run into issues where another class that extends {@link BaseMockEISAuthServerTest}
* results in an authorization response not being queued up for the new Elasticsearch Node in time. When the node starts up, it
* retrieves authorization. If the request isn't queued up when that happens the tests will fail. From my testing locally it seems
* like the base class's static functionality to queue a response is only done once and not for each subclass.
*
* My understanding is that the @Before will be run after the node starts up and wouldn't be sufficient to handle
* this scenario. That is why this needs to be @BeforeClass.
*/
@BeforeClass
public static void init() {
// Ensure the mock EIS server has an authorized response ready
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.client.Request;
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.TaskType;
import org.junit.Before;
import org.junit.BeforeClass;

import java.io.IOException;
Expand All @@ -23,6 +24,23 @@

public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {

@Before
public void setUp() throws Exception {
super.setUp();
// Ensure the mock EIS server has an authorized response ready before each test because each test will
// use the services API which makes a call to EIS
mockEISServer.enqueueAuthorizeAllModelsResponse();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The init() method below also calls mockEISServer.enqueueAuthorizeAllModelsResponse(). Is it equivalent to change the annotation on that method to @Before?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah the original issue for using @BeforeClass is because I ran into some weird issues locally. The original PR that added the @BeforeClass is here: #128640

For some background, when the node for the test starts up it will reach out to EIS and get the auth response. If that fails (there isn't a response queued in the mock server), then the tests will fail. What I was observing is that the base classes static logic would only be executed once regardless of how many subclasses used the base. This resulted in the first test class succeeding but the second test class that leveraged the base would fail. To get around this I added the @BeforeClass and it seemed to fix the issue. The reason we need this in @BeforeClass is because we need a response queued before Elasticsearch is started. Elasticsearch is started only once at the beginning before all the tests run.

}

/**
* This is done before the class because I've run into issues where another class that extends {@link BaseMockEISAuthServerTest}
* results in an authorization response not being queued up for the new Elasticsearch Node in time. When the node starts up, it
* retrieves authorization. If the request isn't queued up when that happens the tests will fail. From my testing locally it seems
* like the base class's static functionality to queue a response is only done once and not for each subclass.
*
* My understanding is that the @Before will be run after the node starts up and wouldn't be sufficient to handle
* this scenario. That is why this needs to be @BeforeClass.
*/
@BeforeClass
public static void init() {
// Ensure the mock EIS server has an authorized response ready
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ public static InferenceServiceConfiguration get() {
);

return new InferenceServiceConfiguration.Builder().setService(NAME)
.setName(NAME)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These aren't really necessary but I was running into a test failure because this field didn't exist until I switch the sorting field. Figured I'd leave the fix in though.

.setTaskTypes(SUPPORTED_TASK_TYPES)
.setConfigurations(configurationMap)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ public static InferenceServiceConfiguration get() {
);

return new InferenceServiceConfiguration.Builder().setService(NAME)
.setName(NAME)
.setTaskTypes(supportedTaskTypes)
.setConfigurations(configurationMap)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ public static InferenceServiceConfiguration get() {
);

return new InferenceServiceConfiguration.Builder().setService(NAME)
.setName(NAME)
.setTaskTypes(supportedTaskTypes)
.setConfigurations(configurationMap)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ public static InferenceServiceConfiguration get() {
);

return new InferenceServiceConfiguration.Builder().setService(NAME)
.setName(NAME)
.setTaskTypes(supportedTaskTypes)
.setConfigurations(configurationMap)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ public static InferenceServiceConfiguration get() {
);

return new InferenceServiceConfiguration.Builder().setService(NAME)
.setName(NAME)
.setTaskTypes(supportedTaskTypes)
.setConfigurations(configurationMap)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty()
try (var service = createElasticInferenceService()) {
ensureAuthorizationCallFinished(service);

assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
assertTrue(service.defaultConfigIds().isEmpty());
assertThat(service.supportedTaskTypes(), is(EnumSet.noneOf(TaskType.class)));

Expand Down Expand Up @@ -299,7 +299,7 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
try (var service = createElasticInferenceService()) {
ensureAuthorizationCallFinished(service);

assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
assertThat(
service.defaultConfigIds(),
containsInAnyOrder(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.highlight.SemanticTextHighlighter;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.mapper.OffsetSourceFieldMapper;
Expand Down Expand Up @@ -382,6 +383,8 @@ public Collection<?> createComponents(PluginServices services) {
new TransportGetInferenceDiagnosticsAction.ClientManagers(httpClientManager, elasticInferenceServiceHttpClientManager)
);
components.add(inferenceStatsBinding);
components.add(authorizationHandler);
components.add(new PluginComponentBinding<>(Sender.class, elasicInferenceServiceFactory.get().createSender()));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without this I get a binding error when running

components.add(
new InferenceEndpointRegistry(
services.clusterService(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,55 @@

package org.elasticsearch.xpack.inference.action;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.inference.action.GetInferenceServicesAction;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationModel;
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;

public class TransportGetInferenceServicesAction extends HandledTransportAction<
GetInferenceServicesAction.Request,
GetInferenceServicesAction.Response> {

private static final Logger logger = LogManager.getLogger(TransportGetInferenceServicesAction.class);

private final InferenceServiceRegistry serviceRegistry;
private final ElasticInferenceServiceAuthorizationRequestHandler eisAuthorizationRequestHandler;
private final Sender eisSender;
private final ThreadPool threadPool;

@Inject
public TransportGetInferenceServicesAction(
TransportService transportService,
ActionFilters actionFilters,
InferenceServiceRegistry serviceRegistry
ThreadPool threadPool,
InferenceServiceRegistry serviceRegistry,
ElasticInferenceServiceAuthorizationRequestHandler eisAuthorizationRequestHandler,
Sender sender
) {
super(
GetInferenceServicesAction.NAME,
Expand All @@ -46,6 +65,9 @@ public TransportGetInferenceServicesAction(
EsExecutors.DIRECT_EXECUTOR_SERVICE
);
this.serviceRegistry = serviceRegistry;
this.eisAuthorizationRequestHandler = eisAuthorizationRequestHandler;
this.eisSender = sender;
this.threadPool = threadPool;
}

@Override
Expand All @@ -69,41 +91,86 @@ private void getServiceConfigurationsForTaskType(
.entrySet()
.stream()
.filter(
service -> service.getValue().hideFromConfigurationApi() == false
// Exclude EIS as the EIS specific configurations must be retrieved directly from EIS and merged in later
service -> service.getValue().name().equals(ElasticInferenceService.NAME) == false
&& service.getValue().hideFromConfigurationApi() == false
&& service.getValue().supportedTaskTypes().contains(requestedTaskType)
)
.sorted(Comparator.comparing(service -> service.getValue().name()))
.collect(Collectors.toCollection(ArrayList::new));

getServiceConfigurationsForServices(filteredServices, listener.delegateFailureAndWrap((delegate, configurations) -> {
delegate.onResponse(new GetInferenceServicesAction.Response(configurations));
}));
getServiceConfigurationsForServicesAndEis(listener, filteredServices, requestedTaskType);
}

private void getAllServiceConfigurations(ActionListener<GetInferenceServicesAction.Response> listener) {
var availableServices = serviceRegistry.getServices()
.entrySet()
.stream()
.filter(service -> service.getValue().hideFromConfigurationApi() == false)
.filter(
// Exclude EIS as the EIS specific configurations must be retrieved directly from EIS and merged in later
service -> service.getValue().name().equals(ElasticInferenceService.NAME) == false
&& service.getValue().hideFromConfigurationApi() == false
)
.sorted(Comparator.comparing(service -> service.getValue().name()))
.collect(Collectors.toCollection(ArrayList::new));
getServiceConfigurationsForServices(availableServices, listener.delegateFailureAndWrap((delegate, configurations) -> {
delegate.onResponse(new GetInferenceServicesAction.Response(configurations));
}));

getServiceConfigurationsForServicesAndEis(listener, availableServices, null);
}

private void getServiceConfigurationsForServices(
ArrayList<Map.Entry<String, InferenceService>> services,
ActionListener<List<InferenceServiceConfiguration>> listener
private void getServiceConfigurationsForServicesAndEis(
ActionListener<GetInferenceServicesAction.Response> listener,
ArrayList<Map.Entry<String, InferenceService>> availableServices,
@Nullable TaskType requestedTaskType
) {
try {
var serviceConfigurations = new ArrayList<InferenceServiceConfiguration>();
for (var service : services) {
serviceConfigurations.add(service.getValue().getConfiguration());
SubscribableListener.<ElasticInferenceServiceAuthorizationModel>newForked(authModelListener -> {
// Executing on a separate thread because there's a chance the authorization call needs to do some initialization for the Sender
threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> getEisAuthorization(authModelListener, eisSender));
}).<List<InferenceServiceConfiguration>>andThen((configurationListener, authorizationModel) -> {
var serviceConfigs = getServiceConfigurationsForServices(availableServices);

if (authorizationModel.isAuthorized() == false) {
configurationListener.onResponse(serviceConfigs);
return;
}

var config = ElasticInferenceService.createConfiguration(authorizationModel.getAuthorizedTaskTypes());
if (requestedTaskType != null && authorizationModel.getAuthorizedTaskTypes().contains(requestedTaskType) == false) {
configurationListener.onResponse(serviceConfigs);
return;
}
listener.onResponse(serviceConfigurations.stream().toList());
} catch (Exception e) {
listener.onFailure(e);

serviceConfigs.add(config);
serviceConfigs.sort(Comparator.comparing(InferenceServiceConfiguration::getService));
configurationListener.onResponse(serviceConfigs);
})
.addListener(
listener.delegateFailureAndWrap(
(delegate, configurations) -> delegate.onResponse(new GetInferenceServicesAction.Response(configurations))
)
);
}

private void getEisAuthorization(ActionListener<ElasticInferenceServiceAuthorizationModel> listener, Sender sender) {
var disabledServiceListener = listener.delegateResponse((delegate, e) -> {
logger.warn(
"Failed to retrieve authorization information from the "
+ "Elastic Inference Service while determining service configurations. Marking service as disabled.",
e
);
delegate.onResponse(ElasticInferenceServiceAuthorizationModel.newDisabledService());
});

eisAuthorizationRequestHandler.getAuthorization(disabledServiceListener, sender);
}

private List<InferenceServiceConfiguration> getServiceConfigurationsForServices(
ArrayList<Map.Entry<String, InferenceService>> services
) {
var serviceConfigurations = new ArrayList<InferenceServiceConfiguration>();
for (var service : services) {
serviceConfigurations.add(service.getValue().getConfiguration());
}

return serviceConfigurations;
}
}
Loading