[ML] Inference API _services retrieves authorization information directly from EIS #134398
```diff
@@ -178,6 +178,7 @@ public static InferenceServiceConfiguration get() {
         );

         return new InferenceServiceConfiguration.Builder().setService(NAME)
+            .setName(NAME)
             .setTaskTypes(SUPPORTED_TASK_TYPES)
             .setConfigurations(configurationMap)
             .build();
```

Review comment on `.setName(NAME)`: These aren't really necessary, but I was running into a test failure because this field didn't exist until I switched the sorting field. Figured I'd leave the fix in though.
```diff
@@ -90,6 +90,7 @@
 import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
 import org.elasticsearch.xpack.inference.highlight.SemanticTextHighlighter;
 import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
 import org.elasticsearch.xpack.inference.mapper.OffsetSourceFieldMapper;
```

```diff
@@ -376,6 +377,8 @@ public Collection<?> createComponents(PluginServices services) {
         components.add(modelRegistry.get());
         components.add(httpClientManager);
         components.add(inferenceStatsBinding);
+        components.add(authorizationHandler);
+        components.add(new PluginComponentBinding<>(Sender.class, elasicInferenceServiceFactory.get().createSender()));

         // Only add InferenceServiceNodeLocalRateLimitCalculator (which is a ClusterStateListener) for cluster aware rate limiting,
         // if the rate limiting feature flags are enabled, otherwise provide noop implementation
```

Review comment on the `PluginComponentBinding` line: Without this I get a binding error when running
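For context on that binding, here is a minimal sketch of the wiring the change above assumes: the transport action shown further below takes `Sender` as an `@Inject` constructor dependency, so the plugin must register a concrete `Sender` instance with the injector. The class name here is hypothetical; only the `PluginComponentBinding` line and the injected types are taken from this PR.

```java
// Sketch only: illustrates why the PluginComponentBinding<>(Sender.class, ...) entry is needed.
// Any component with an @Inject constructor that asks for Sender.class can only be resolved
// if the plugin's createComponents() returns a binding for it.
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;

public class ExampleEisAwareComponent { // hypothetical consumer, not part of the PR

    private final Sender eisSender;

    // Without the Sender binding registered in createComponents(), resolving this
    // parameter fails with a binding error at node startup.
    @Inject
    public ExampleEisAwareComponent(Sender sender) {
        this.eisSender = sender;
    }
}
```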
```diff
@@ -7,36 +7,55 @@

 package org.elasticsearch.xpack.inference.action;

 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.HandledTransportAction;
 import org.elasticsearch.action.support.SubscribableListener;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
 import org.elasticsearch.inference.InferenceServiceRegistry;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.injection.guice.Inject;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.inference.action.GetInferenceServicesAction;
 import org.elasticsearch.xpack.inference.external.http.sender.Sender;
 import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
 import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationModel;
 import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;

 import java.util.ArrayList;
 import java.util.Comparator;
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;

 import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;

 public class TransportGetInferenceServicesAction extends HandledTransportAction<
     GetInferenceServicesAction.Request,
     GetInferenceServicesAction.Response> {

     private static final Logger logger = LogManager.getLogger(TransportGetInferenceServicesAction.class);

     private final InferenceServiceRegistry serviceRegistry;
+    private final ElasticInferenceServiceAuthorizationRequestHandler eisAuthorizationRequestHandler;
+    private final Sender eisSender;
+    private final ThreadPool threadPool;

     @Inject
     public TransportGetInferenceServicesAction(
         TransportService transportService,
         ActionFilters actionFilters,
-        InferenceServiceRegistry serviceRegistry
+        ThreadPool threadPool,
+        InferenceServiceRegistry serviceRegistry,
+        ElasticInferenceServiceAuthorizationRequestHandler eisAuthorizationRequestHandler,
+        Sender sender
     ) {
         super(
             GetInferenceServicesAction.NAME,
```

```diff
@@ -46,6 +65,9 @@ public TransportGetInferenceServicesAction(
             EsExecutors.DIRECT_EXECUTOR_SERVICE
         );
         this.serviceRegistry = serviceRegistry;
+        this.eisAuthorizationRequestHandler = eisAuthorizationRequestHandler;
+        this.eisSender = sender;
+        this.threadPool = threadPool;
     }

     @Override
```

```diff
@@ -69,27 +91,78 @@ private void getServiceConfigurationsForTaskType(
             .entrySet()
             .stream()
             .filter(
                 service -> service.getValue().hideFromConfigurationApi() == false
+                    // exclude EIS here because the hideFromConfigurationApi() is not supported
```
Suggested change on that comment line (marked outdated):

```diff
-                    // exclude EIS here because the hideFromConfigurationApi() is not supported
+                    // Exclude EIS as the EIS specific configurations are handled separately
```

Review comment: `getServiceConfigurationsForServices()` is a synchronous method and could be written with a return type instead of a listener. I think that would make this code easier to read, as you wouldn't need to define the merge listener:

```java
}).<List<InferenceServiceConfiguration>>andThen((configurationListener, authorizationModel) -> {
    var serviceConfigs = getServiceConfigurationsForServices(availableServices);

    if (authorizationModel.isAuthorized() == false) {
        delegate.onResponse(serviceConfigs);
        return;
    }

    if (requestedTaskType != null && authorizationModel.getAuthorizedTaskTypes().contains(requestedTaskType) == false) {
        delegate.onResponse(serviceConfigs);
        return;
    }

    var config = ElasticInferenceService.createConfiguration(authorizationModel.getAuthorizedTaskTypes());
    serviceConfigs.add(config);
    serviceConfigs.sort(Comparator.comparing(InferenceServiceConfiguration::getService));
    delegate.onResponse(serviceConfigs);
}
```

Reply: 🤦♂️ Thank you, for some reason I thought it needed to use a listener.
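To make the new flow easier to follow, here is a rough sketch of how the pieces above appear to fit together: fetch the authorization model from EIS off the transport thread (the `SubscribableListener` and `UTILITY_THREAD_POOL_NAME` imports point at this pattern), then merge the EIS configuration into the synchronously built list, along the lines of the reviewer's snippet. The `getAuthorization` call and the exact chaining are assumptions, not the merged implementation; `availableServices`, `requestedTaskType`, and `delegate` are taken from the review snippet above.

```java
// Sketch only: an approximation of the authorization-then-merge flow in this action.
SubscribableListener
    // 1. Ask EIS for the authorization model on the utility thread pool (assumed method signature).
    .<ElasticInferenceServiceAuthorizationModel>newForked(
        authListener -> threadPool.executor(UTILITY_THREAD_POOL_NAME)
            .execute(() -> eisAuthorizationRequestHandler.getAuthorization(authListener, eisSender))
    )
    // 2. Build the non-EIS configurations synchronously and append the EIS one only when authorized
    //    (and only when it covers the requested task type, if one was given).
    .<List<InferenceServiceConfiguration>>andThen((configurationListener, authorizationModel) -> {
        var serviceConfigs = getServiceConfigurationsForServices(availableServices);

        if (authorizationModel.isAuthorized()
            && (requestedTaskType == null || authorizationModel.getAuthorizedTaskTypes().contains(requestedTaskType))) {
            serviceConfigs.add(ElasticInferenceService.createConfiguration(authorizationModel.getAuthorizedTaskTypes()));
        }

        serviceConfigs.sort(Comparator.comparing(InferenceServiceConfiguration::getService));
        configurationListener.onResponse(serviceConfigs);
    })
    // 3. Hand the merged, sorted list back to the caller.
    .addListener(delegate);
```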
```diff
@@ -14,7 +14,6 @@
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.ValidationException;
-import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
```
```diff
@@ -284,7 +283,7 @@ public void waitForFirstAuthorizationToComplete(TimeValue waitTime) {

     @Override
     public Set<TaskType> supportedStreamingTasks() {
-        return authorizationHandler.supportedStreamingTasks();
+        return EnumSet.of(TaskType.CHAT_COMPLETION);
     }

     @Override
```

Review comment on the new return: Previously, if the cluster wasn't authorized for chat completion, we'd return a vague error about not being able to stream. With this change we'll allow the request to be sent to EIS, and if it isn't authorized, EIS will return a failure.
```diff
@@ -462,7 +461,12 @@ public void parseRequestConfig(

     @Override
     public InferenceServiceConfiguration getConfiguration() {
-        return authorizationHandler.getConfiguration();
+        // This shouldn't be called because the configuration changes based on the authorization
+        // Instead, retrieve the authorization directly from the EIS gateway and use the static method
+        // ElasticInferenceService.Configuration#createConfiguration() to create a configuration based on the authorization response
+        throw new UnsupportedOperationException(
+            "The EIS configuration changes depending on authorization, requests should be made directly to EIS instead"
+        );
     }

     @Override
```
```diff
@@ -472,7 +476,11 @@ public EnumSet<TaskType> supportedTaskTypes() {

     @Override
     public boolean hideFromConfigurationApi() {
-        return authorizationHandler.hideFromConfigurationApi();
+        // This shouldn't be called because the configuration changes based on the authorization
+        // Instead, retrieve the authorization directly from the EIS gateway and use the response to determine if EIS is authorized
+        throw new UnsupportedOperationException(
+            "The EIS configuration changes depending on authorization, requests should be made directly to EIS instead"
+        );
     }

     private static ElasticInferenceServiceModel createModel(
```
```diff
@@ -656,62 +664,45 @@ private TraceContext getCurrentTraceInfo() {
         return new TraceContext(traceParent, traceState);
     }

-    public static class Configuration {
-
-        private final EnumSet<TaskType> enabledTaskTypes;
-        private final LazyInitializable<InferenceServiceConfiguration, RuntimeException> configuration;
-
-        public Configuration(EnumSet<TaskType> enabledTaskTypes) {
-            this.enabledTaskTypes = enabledTaskTypes;
-            configuration = initConfiguration();
-        }
-
-        private LazyInitializable<InferenceServiceConfiguration, RuntimeException> initConfiguration() {
-            return new LazyInitializable<>(() -> {
-                var configurationMap = new HashMap<String, SettingsConfiguration>();
-
-                configurationMap.put(
-                    MODEL_ID,
-                    new SettingsConfiguration.Builder(
-                        EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK, TaskType.TEXT_EMBEDDING)
-                    ).setDescription("The name of the model to use for the inference task.")
-                        .setLabel("Model ID")
-                        .setRequired(true)
-                        .setSensitive(false)
-                        .setUpdatable(false)
-                        .setType(SettingsConfigurationFieldType.STRING)
-                        .build()
-                );
-
-                configurationMap.put(
-                    MAX_INPUT_TOKENS,
-                    new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING)).setDescription(
-                        "Allows you to specify the maximum number of tokens per input."
-                    )
-                        .setLabel("Maximum Input Tokens")
-                        .setRequired(false)
-                        .setSensitive(false)
-                        .setUpdatable(false)
-                        .setType(SettingsConfigurationFieldType.INTEGER)
-                        .build()
-                );
-
-                configurationMap.putAll(
-                    RateLimitSettings.toSettingsConfiguration(
-                        EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK, TaskType.TEXT_EMBEDDING)
-                    )
-                );
-
-                return new InferenceServiceConfiguration.Builder().setService(NAME)
-                    .setName(SERVICE_NAME)
-                    .setTaskTypes(enabledTaskTypes)
-                    .setConfigurations(configurationMap)
-                    .build();
-            });
-        }
-
-        public InferenceServiceConfiguration get() {
-            return configuration.getOrCompute();
-        }
-    }
+    public static InferenceServiceConfiguration createConfiguration(EnumSet<TaskType> enabledTaskTypes) {
+        var configurationMap = new HashMap<String, SettingsConfiguration>();
+
+        configurationMap.put(
+            MODEL_ID,
+            new SettingsConfiguration.Builder(
+                EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK, TaskType.TEXT_EMBEDDING)
+            ).setDescription("The name of the model to use for the inference task.")
+                .setLabel("Model ID")
+                .setRequired(true)
+                .setSensitive(false)
+                .setUpdatable(false)
+                .setType(SettingsConfigurationFieldType.STRING)
+                .build()
+        );
+
+        configurationMap.put(
+            MAX_INPUT_TOKENS,
+            new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING)).setDescription(
+                "Allows you to specify the maximum number of tokens per input."
+            )
+                .setLabel("Maximum Input Tokens")
+                .setRequired(false)
+                .setSensitive(false)
+                .setUpdatable(false)
+                .setType(SettingsConfigurationFieldType.INTEGER)
+                .build()
+        );
+
+        configurationMap.putAll(
+            RateLimitSettings.toSettingsConfiguration(
+                EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK, TaskType.TEXT_EMBEDDING)
+            )
+        );
+
+        return new InferenceServiceConfiguration.Builder().setService(NAME)
+            .setName(SERVICE_NAME)
+            .setTaskTypes(enabledTaskTypes)
+            .setConfigurations(configurationMap)
+            .build();
+    }
```

Review comment on the task types passed to the `MODEL_ID` builder: Should this be the intersection of

Reply: The list of task types here tells the UI which task types this field should be configurable for. That should stay the same regardless of whether the user is authorized for a specific task type. There's a top-level field for task types that indicates which ones are authorized, and that's set here:
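To make that reply concrete, here is a small sketch of what the two task-type sets mean for a caller of `createConfiguration`. Only the `createConfiguration` call is taken from the diff; the wrapper class is illustrative.

```java
import java.util.EnumSet;

import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;

final class AuthorizedTaskTypesExample { // illustrative wrapper, not part of the PR

    static InferenceServiceConfiguration chatCompletionOnly() {
        // The top-level task types on the returned configuration come from the EIS authorization
        // response (here: chat completion only), while the per-field task types on MODEL_ID stay
        // the full static set so the UI knows which task types that field can be configured for.
        return ElasticInferenceService.createConfiguration(EnumSet.of(TaskType.CHAT_COMPLETION));
    }
}
```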
Review comment: The `init()` method below also calls `mockEISServer.enqueueAuthorizeAllModelsResponse()`. Is it equivalent to change the annotation on that method to `@Before`?

Reply: Yeah, the original issue for using `@BeforeClass` is that I ran into some weird issues locally. The original PR that added the `@BeforeClass` is here: #128640. For some background, when the node for the test starts up it will reach out to EIS and get the auth response. If that fails (there isn't a response queued in the mock server), then the tests will fail. What I was observing is that the base class's static logic would only be executed once regardless of how many subclasses used the base. This resulted in the first test class succeeding but the second test class that leveraged the base failing. To get around this I added the `@BeforeClass` and it seemed to fix the issue. The reason we need this in `@BeforeClass` is that we need a response queued before Elasticsearch is started. Elasticsearch is started only once at the beginning, before all the tests run.
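For illustration, a rough sketch of the test base class shape this thread describes. Class and type names other than `mockEISServer` and `enqueueAuthorizeAllModelsResponse()` are placeholders, and the base type is an assumption; the actual test classes in the PR may differ.

```java
import org.elasticsearch.test.rest.ESRestTestCase;
import org.junit.Before;
import org.junit.BeforeClass;

// Sketch of the pattern discussed above: queue an authorization response before the shared
// Elasticsearch test node starts, because the node fetches EIS authorization on startup and
// the whole class fails if nothing is queued in the mock server at that point.
public abstract class BaseMockEisAuthServerTest extends ESRestTestCase { // placeholder name

    // Placeholder stand-in for the mock EIS authorization server used by the real tests.
    static final class MockEisAuthorizationServer {
        void enqueueAuthorizeAllModelsResponse() {
            // In the real tests this queues a canned "authorized for all models" response.
        }
    }

    protected static final MockEisAuthorizationServer mockEISServer = new MockEisAuthorizationServer();

    @BeforeClass
    public static void queueAuthResponseBeforeClusterStarts() {
        // Runs once per concrete test class (unlike static initialization of the base class,
        // which runs only once per JVM), so each class has a response queued before its node starts.
        mockEISServer.enqueueAuthorizeAllModelsResponse();
    }

    @Before
    public void init() {
        // Per-test responses can still be queued here for requests made during the tests themselves.
        mockEISServer.enqueueAuthorizeAllModelsResponse();
    }
}
```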