Merged
@@ -12,6 +12,7 @@
import org.elasticsearch.client.Request;
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.TaskType;
import org.junit.Before;
import org.junit.BeforeClass;

import java.io.IOException;
@@ -23,6 +24,13 @@

public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {

@Before
public void setUp() throws Exception {
super.setUp();
// Ensure the mock EIS server has an authorized response ready
mockEISServer.enqueueAuthorizeAllModelsResponse();
Member:

The init() method below also calls mockEISServer.enqueueAuthorizeAllModelsResponse(). Is it equivalent to change the annotation on that method to @Before?

Contributor Author:

Yeah, the reason for originally using @BeforeClass is that I ran into some weird issues locally. The original PR that added the @BeforeClass is here: #128640

For some background: when the node for the test starts up, it reaches out to EIS and gets the auth response. If that fails (there isn't a response queued in the mock server), the tests fail. What I observed was that the base class's static logic is executed only once, regardless of how many subclasses use the base. This resulted in the first test class succeeding, but the second test class that leveraged the base failing. Adding the @BeforeClass got around that. The reason we need this in @BeforeClass is that a response must be queued before Elasticsearch is started, and Elasticsearch is started only once, before all the tests run.
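The once-per-class behavior described above can be illustrated with a plain-Java sketch (not from the PR; the class names are illustrative). The static initializer stands in for @BeforeClass-style setup, the constructor for per-test @Before setup:

```java
// Sketch: a base class's static initializer runs only once, no matter how many
// subclasses use the base -- mirroring the once-only static setup described above.
class Base {
    static int staticSetupRuns = 0;   // stands in for @BeforeClass-style setup
    static int instanceSetupRuns = 0; // stands in for @Before-style setup

    static {
        staticSetupRuns++;
    }

    Base() {
        instanceSetupRuns++;
    }
}

class FirstSuite extends Base {}

class SecondSuite extends Base {}

public class LifecycleSketch {
    public static void main(String[] args) {
        new FirstSuite();
        new SecondSuite();
        // Static setup ran once; per-instance setup ran once per subclass instance.
        System.out.println(Base.staticSetupRuns + " " + Base.instanceSetupRuns);
    }
}
```

This is why static setup in the shared base helped the first suite but left nothing queued for the second: the JVM initializes `Base` only once.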

}

@BeforeClass
public static void init() {
// Ensure the mock EIS server has an authorized response ready
@@ -130,8 +138,9 @@ public void testGetServicesWithRerankTaskType() throws IOException {
}

public void testGetServicesWithCompletionTaskType() throws IOException {
var a = providersFor(TaskType.COMPLETION);
assertThat(
providersFor(TaskType.COMPLETION),
a,
containsInAnyOrder(
List.of(
"ai21",
@@ -178,6 +178,7 @@ public static InferenceServiceConfiguration get() {
);

return new InferenceServiceConfiguration.Builder().setService(NAME)
.setName(NAME)
Contributor Author:

These aren't really necessary, but I was running into a test failure because this field didn't exist until I switched the sorting field. Figured I'd leave the fix in though.
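One plausible failure mode here (an assumption, not confirmed by the PR) is that sorting by a field that was never set fails with a NullPointerException, since `Comparator.comparing` falls back to the key's natural ordering and null keys have none. A self-contained sketch with a stand-in `Config` record:

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Sketch (assumption, not from the PR): Config is a stand-in for the service
// configuration; a never-populated name field breaks a name-based sort.
public class SortKeySketch {
    record Config(String service, String name) {}

    // Null-tolerant variant: sorts by name, placing missing names first.
    static List<Config> sortByNameNullSafe(List<Config> configs) {
        var copy = new ArrayList<>(configs);
        copy.sort(Comparator.comparing(Config::name, Comparator.nullsFirst(Comparator.naturalOrder())));
        return copy;
    }

    public static void main(String[] args) {
        var configs = List.of(new Config("svc-b", "B"), new Config("svc-a", null));
        try {
            // Natural ordering: null.compareTo(...) throws during the sort.
            new ArrayList<>(configs).sort(Comparator.comparing(Config::name));
        } catch (NullPointerException e) {
            System.out.println("sort failed: null sort key");
        }
        System.out.println(sortByNameNullSafe(configs).get(0).service());
    }
}
```

Setting the field (as this diff does) or using a null-tolerant comparator both avoid the failure.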

.setTaskTypes(SUPPORTED_TASK_TYPES)
.setConfigurations(configurationMap)
.build();
@@ -283,6 +283,7 @@ public static InferenceServiceConfiguration get() {
);

return new InferenceServiceConfiguration.Builder().setService(NAME)
.setName(NAME)
.setTaskTypes(supportedTaskTypes)
.setConfigurations(configurationMap)
.build();
@@ -228,6 +228,7 @@ public static InferenceServiceConfiguration get() {
);

return new InferenceServiceConfiguration.Builder().setService(NAME)
.setName(NAME)
.setTaskTypes(supportedTaskTypes)
.setConfigurations(configurationMap)
.build();
@@ -223,7 +223,7 @@ public static InferenceServiceConfiguration get() {
.build()
);

return new InferenceServiceConfiguration.Builder().setService(NAME)
return new InferenceServiceConfiguration.Builder().setService(NAME).setName(NAME)
.setTaskTypes(supportedTaskTypes)
.setConfigurations(configurationMap)
.build();
@@ -326,6 +326,7 @@ public static InferenceServiceConfiguration get() {
);

return new InferenceServiceConfiguration.Builder().setService(NAME)
.setName(NAME)
.setTaskTypes(supportedTaskTypes)
.setConfigurations(configurationMap)
.build();
@@ -90,6 +90,7 @@
import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.highlight.SemanticTextHighlighter;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.mapper.OffsetSourceFieldMapper;
@@ -376,6 +377,8 @@ public Collection<?> createComponents(PluginServices services) {
components.add(modelRegistry.get());
components.add(httpClientManager);
components.add(inferenceStatsBinding);
components.add(authorizationHandler);
components.add(new PluginComponentBinding<>(Sender.class, elasicInferenceServiceFactory.get().createSender()));
Contributor Author:

Without this I get a binding error when running.


// Only add InferenceServiceNodeLocalRateLimitCalculator (which is a ClusterStateListener) for cluster aware rate limiting,
// if the rate limiting feature flags are enabled, otherwise provide noop implementation
@@ -7,36 +7,55 @@

package org.elasticsearch.xpack.inference.action;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.inference.action.GetInferenceServicesAction;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationModel;
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;

public class TransportGetInferenceServicesAction extends HandledTransportAction<
GetInferenceServicesAction.Request,
GetInferenceServicesAction.Response> {

private static final Logger logger = LogManager.getLogger(TransportGetInferenceServicesAction.class);

private final InferenceServiceRegistry serviceRegistry;
private final ElasticInferenceServiceAuthorizationRequestHandler eisAuthorizationRequestHandler;
private final Sender eisSender;
private final ThreadPool threadPool;

@Inject
public TransportGetInferenceServicesAction(
TransportService transportService,
ActionFilters actionFilters,
InferenceServiceRegistry serviceRegistry
ThreadPool threadPool,
InferenceServiceRegistry serviceRegistry,
ElasticInferenceServiceAuthorizationRequestHandler eisAuthorizationRequestHandler,
Sender sender
) {
super(
GetInferenceServicesAction.NAME,
@@ -46,6 +65,9 @@ public TransportGetInferenceServicesAction(
EsExecutors.DIRECT_EXECUTOR_SERVICE
);
this.serviceRegistry = serviceRegistry;
this.eisAuthorizationRequestHandler = eisAuthorizationRequestHandler;
this.eisSender = sender;
this.threadPool = threadPool;
}

@Override
@@ -69,27 +91,78 @@ private void getServiceConfigurationsForTaskType(
.entrySet()
.stream()
.filter(
service -> service.getValue().hideFromConfigurationApi() == false
// exclude EIS here because the hideFromConfigurationApi() is not supported
Member:

I was slightly confused by "hideFromConfigurationApi() is not supported" in this comment.

Suggested change:
- // exclude EIS here because the hideFromConfigurationApi() is not supported
+ // Exclude EIS as the EIS specific configurations are handled separately

service -> service.getValue().name().equals(ElasticInferenceService.NAME) == false
&& service.getValue().hideFromConfigurationApi() == false
&& service.getValue().supportedTaskTypes().contains(requestedTaskType)
)
.sorted(Comparator.comparing(service -> service.getValue().name()))
.collect(Collectors.toCollection(ArrayList::new));

getServiceConfigurationsForServices(filteredServices, listener.delegateFailureAndWrap((delegate, configurations) -> {
delegate.onResponse(new GetInferenceServicesAction.Response(configurations));
}));
getServiceConfigurationsForServicesAndEis(listener, filteredServices, requestedTaskType);
}

private void getAllServiceConfigurations(ActionListener<GetInferenceServicesAction.Response> listener) {
var availableServices = serviceRegistry.getServices()
.entrySet()
.stream()
.filter(service -> service.getValue().hideFromConfigurationApi() == false)
.filter(
// exclude EIS here because the hideFromConfigurationApi() is not supported
service -> service.getValue().name().equals(ElasticInferenceService.NAME) == false
&& service.getValue().hideFromConfigurationApi() == false
)
.sorted(Comparator.comparing(service -> service.getValue().name()))
.collect(Collectors.toCollection(ArrayList::new));
getServiceConfigurationsForServices(availableServices, listener.delegateFailureAndWrap((delegate, configurations) -> {
delegate.onResponse(new GetInferenceServicesAction.Response(configurations));
}));

getServiceConfigurationsForServicesAndEis(listener, availableServices, null);
}

private void getServiceConfigurationsForServicesAndEis(
ActionListener<GetInferenceServicesAction.Response> listener,
ArrayList<Map.Entry<String, InferenceService>> availableServices,
@Nullable TaskType requestedTaskType
) {
SubscribableListener.<ElasticInferenceServiceAuthorizationModel>newForked(authModelListener -> {
// Executing on a separate thread because there's a chance the authorization call needs to do some initialization for the Sender
threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> getEisAuthorization(authModelListener, eisSender));
}).<List<InferenceServiceConfiguration>>andThen((configurationListener, authorizationModel) -> {

ActionListener<List<InferenceServiceConfiguration>> mergeEisConfigListener = configurationListener.delegateFailureAndWrap(
(delegate, serviceConfigs) -> {
if (authorizationModel.isAuthorized() == false) {
delegate.onResponse(serviceConfigs);
return;
}

var config = ElasticInferenceService.createConfiguration(authorizationModel.getAuthorizedTaskTypes());
if (requestedTaskType != null && authorizationModel.getAuthorizedTaskTypes().contains(requestedTaskType)) {
serviceConfigs.add(config);
}
serviceConfigs.sort(Comparator.comparing(InferenceServiceConfiguration::getService));
delegate.onResponse(serviceConfigs);
}
);

getServiceConfigurationsForServices(availableServices, mergeEisConfigListener);
Member:

getServiceConfigurationsForServices() is a synchronous method and could be written with a return type instead of a listener. I think that would make this code easier to read, as you wouldn't need to define the merge listener.

}).<List<InferenceServiceConfiguration>>andThen((configurationListener, authorizationModel) -> {
    var serviceConfigs = getServiceConfigurationsForServices(availableServices);

    if (authorizationModel.isAuthorized() == false) {
        configurationListener.onResponse(serviceConfigs);
        return;
    }

    if (requestedTaskType != null && authorizationModel.getAuthorizedTaskTypes().contains(requestedTaskType) == false) {
        configurationListener.onResponse(serviceConfigs);
        return;
    }

    var config = ElasticInferenceService.createConfiguration(authorizationModel.getAuthorizedTaskTypes());
    serviceConfigs.add(config);
    serviceConfigs.sort(Comparator.comparing(InferenceServiceConfiguration::getService));
    configurationListener.onResponse(serviceConfigs);
}

Contributor Author:

🤦‍♂️ thank you, for some reason I thought it needed to use a listener.

})
.addListener(
listener.delegateFailureAndWrap(
(delegate, configurations) -> delegate.onResponse(new GetInferenceServicesAction.Response(configurations))
)
);
}

private void getEisAuthorization(ActionListener<ElasticInferenceServiceAuthorizationModel> listener, Sender sender) {
var disabledServiceListener = listener.delegateResponse((delegate, e) -> {
logger.warn(
"Failed to retrieve authorization information from the "
+ "Elastic Inference Service while determining service configurations. Marking service as disabled.",
e
);
delegate.onResponse(ElasticInferenceServiceAuthorizationModel.newDisabledService());
});

eisAuthorizationRequestHandler.getAuthorization(disabledServiceListener, sender);
}

private void getServiceConfigurationsForServices(
@@ -101,7 +174,7 @@ private void getServiceConfigurationsForServices(
for (var service : services) {
serviceConfigurations.add(service.getValue().getConfiguration());
}
listener.onResponse(serviceConfigurations.stream().toList());
listener.onResponse(serviceConfigurations);
} catch (Exception e) {
listener.onFailure(e);
}
@@ -14,7 +14,6 @@
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
@@ -284,7 +283,7 @@ public void waitForFirstAuthorizationToComplete(TimeValue waitTime) {

@Override
public Set<TaskType> supportedStreamingTasks() {
return authorizationHandler.supportedStreamingTasks();
return EnumSet.of(TaskType.CHAT_COMPLETION);
Contributor Author:

Previously, if the cluster wasn't authorized for chat completion, we'd return a vague error about not being able to stream. With this change we allow the request to be sent to EIS, and if the cluster isn't authorized, EIS will return a failure.

}

@Override
@@ -462,7 +461,12 @@ public void parseRequestConfig(

@Override
public InferenceServiceConfiguration getConfiguration() {
return authorizationHandler.getConfiguration();
// This shouldn't be called because the configuration changes based on the authorization
// Instead, retrieve the authorization directly from the EIS gateway and use the static method
// ElasticInferenceService.Configuration#createConfiguration() to create a configuration based on the authorization response
Contributor:

Could this be a javadoc comment instead, so that the method you're referencing stays correct even if it's modified?
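One possible shape for that suggestion (a sketch; `JavadocSketch` and its methods are hypothetical stand-ins, and the real `{@link}` target would be the actual `ElasticInferenceService` method):

```java
// Sketch of the suggested javadoc form: a {@link} reference is validated by the
// javadoc tool and tracked by IDE renames, unlike a plain-text method name in a
// line comment.
public class JavadocSketch {
    /**
     * Not supported: this service's configuration depends on the authorization
     * response. Retrieve the authorization directly from the EIS gateway and
     * build the configuration with {@link #createConfiguration} instead.
     */
    Object getConfiguration() {
        throw new UnsupportedOperationException("configuration depends on authorization");
    }

    /** Hypothetical stand-in for ElasticInferenceService.createConfiguration. */
    static Object createConfiguration() {
        return new Object();
    }

    public static void main(String[] args) {
        try {
            new JavadocSketch().getConfiguration();
        } catch (UnsupportedOperationException e) {
            System.out.println(e.getMessage());
        }
    }
}
```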

throw new UnsupportedOperationException(
"The EIS configuration changes depending on authorization, requests should be made directly to EIS instead"
);
}

@Override
@@ -472,7 +476,11 @@ public EnumSet<TaskType> supportedTaskTypes() {

@Override
public boolean hideFromConfigurationApi() {
return authorizationHandler.hideFromConfigurationApi();
// This shouldn't be called because the configuration changes based on the authorization
// Instead, retrieve the authorization directly from the EIS gateway and use the response to determine if EIS is authorized
throw new UnsupportedOperationException(
"The EIS configuration changes depending on authorization, requests should be made directly to EIS instead"
);
}

private static ElasticInferenceServiceModel createModel(
@@ -656,62 +664,45 @@ private TraceContext getCurrentTraceInfo() {
return new TraceContext(traceParent, traceState);
}

public static class Configuration {

private final EnumSet<TaskType> enabledTaskTypes;
private final LazyInitializable<InferenceServiceConfiguration, RuntimeException> configuration;

public Configuration(EnumSet<TaskType> enabledTaskTypes) {
this.enabledTaskTypes = enabledTaskTypes;
configuration = initConfiguration();
}

private LazyInitializable<InferenceServiceConfiguration, RuntimeException> initConfiguration() {
return new LazyInitializable<>(() -> {
var configurationMap = new HashMap<String, SettingsConfiguration>();

configurationMap.put(
MODEL_ID,
new SettingsConfiguration.Builder(
EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK, TaskType.TEXT_EMBEDDING)
).setDescription("The name of the model to use for the inference task.")
.setLabel("Model ID")
.setRequired(true)
.setSensitive(false)
.setUpdatable(false)
.setType(SettingsConfigurationFieldType.STRING)
.build()
);

configurationMap.put(
MAX_INPUT_TOKENS,
new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING)).setDescription(
"Allows you to specify the maximum number of tokens per input."
)
.setLabel("Maximum Input Tokens")
.setRequired(false)
.setSensitive(false)
.setUpdatable(false)
.setType(SettingsConfigurationFieldType.INTEGER)
.build()
);
public static InferenceServiceConfiguration createConfiguration(EnumSet<TaskType> enabledTaskTypes) {
var configurationMap = new HashMap<String, SettingsConfiguration>();

configurationMap.put(
MODEL_ID,
new SettingsConfiguration.Builder(
EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK, TaskType.TEXT_EMBEDDING)
Member:

Should this be the intersection of enabledTaskTypes and the full set?

EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK, TaskType.TEXT_EMBEDDING).retainAll(enabledTaskTypes);
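As an aside on the snippet above: `retainAll` mutates the receiver and returns a boolean, so an intersection would need the set held in a variable first. A sketch with a stand-in `TaskType` enum (not the Elasticsearch class):

```java
import java.util.EnumSet;

// Sketch: intersecting the configurable task types with the authorized ones.
// TaskType here is a local stand-in enum, not org.elasticsearch.inference.TaskType.
public class IntersectSketch {
    enum TaskType { SPARSE_EMBEDDING, CHAT_COMPLETION, RERANK, TEXT_EMBEDDING }

    static EnumSet<TaskType> intersect(EnumSet<TaskType> fields, EnumSet<TaskType> enabled) {
        var result = EnumSet.copyOf(fields); // copy first: retainAll mutates in place
        result.retainAll(enabled);           // returns a boolean, not the set
        return result;
    }

    public static void main(String[] args) {
        var fields = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION,
                                TaskType.RERANK, TaskType.TEXT_EMBEDDING);
        var enabled = EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.RERANK);
        System.out.println(intersect(fields, enabled));
    }
}
```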

Contributor Author:

The list of task types here tells the UI which task types this field should be configurable for. That should stay the same regardless of whether the user is authorized for a specific task type. There's a top-level field for task types that indicates which ones are authorized, and that's set here:

        return new InferenceServiceConfiguration.Builder().setService(NAME)
            .setName(SERVICE_NAME)
            .setTaskTypes(enabledTaskTypes) <--------
            .setConfigurations(configurationMap)
            .build();

).setDescription("The name of the model to use for the inference task.")
.setLabel("Model ID")
.setRequired(true)
.setSensitive(false)
.setUpdatable(false)
.setType(SettingsConfigurationFieldType.STRING)
.build()
);

configurationMap.putAll(
RateLimitSettings.toSettingsConfiguration(
EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK, TaskType.TEXT_EMBEDDING)
)
);
configurationMap.put(
MAX_INPUT_TOKENS,
new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING)).setDescription(
"Allows you to specify the maximum number of tokens per input."
)
.setLabel("Maximum Input Tokens")
.setRequired(false)
.setSensitive(false)
.setUpdatable(false)
.setType(SettingsConfigurationFieldType.INTEGER)
.build()
);

return new InferenceServiceConfiguration.Builder().setService(NAME)
.setName(SERVICE_NAME)
.setTaskTypes(enabledTaskTypes)
.setConfigurations(configurationMap)
.build();
});
}
configurationMap.putAll(
RateLimitSettings.toSettingsConfiguration(
EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK, TaskType.TEXT_EMBEDDING)
)
);

public InferenceServiceConfiguration get() {
return configuration.getOrCompute();
}
return new InferenceServiceConfiguration.Builder().setService(NAME)
.setName(SERVICE_NAME)
.setTaskTypes(enabledTaskTypes)
.setConfigurations(configurationMap)
.build();
}
}