-
Notifications
You must be signed in to change notification settings - Fork 25.6k
[ML] Inference API _services retrieves authorization information directly from EIS #134398
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
89ef432
ef6e1c9
5bf5428
cecbc82
1d1adb8
dfba0f9
92e982c
b86a5e1
30cf241
17b54a3
6583ca4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -178,6 +178,7 @@ public static InferenceServiceConfiguration get() { | |
| ); | ||
|
|
||
| return new InferenceServiceConfiguration.Builder().setService(NAME) | ||
| .setName(NAME) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. These aren't really necessary, but I was running into a test failure because this field didn't exist until I switched the sorting field. Figured I'd leave the fix in though. |
||
| .setTaskTypes(SUPPORTED_TASK_TYPES) | ||
| .setConfigurations(configurationMap) | ||
| .build(); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -90,6 +90,7 @@ | |
| import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings; | ||
| import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; | ||
| import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings; | ||
| import org.elasticsearch.xpack.inference.external.http.sender.Sender; | ||
| import org.elasticsearch.xpack.inference.highlight.SemanticTextHighlighter; | ||
| import org.elasticsearch.xpack.inference.logging.ThrottlerManager; | ||
| import org.elasticsearch.xpack.inference.mapper.OffsetSourceFieldMapper; | ||
|
|
@@ -379,6 +380,8 @@ public Collection<?> createComponents(PluginServices services) { | |
| new TransportGetInferenceDiagnosticsAction.ClientManagers(httpClientManager, elasticInferenceServiceHttpClientManager) | ||
| ); | ||
| components.add(inferenceStatsBinding); | ||
| components.add(authorizationHandler); | ||
| components.add(new PluginComponentBinding<>(Sender.class, elasicInferenceServiceFactory.get().createSender())); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Without this I get a binding error when running |
||
|
|
||
| // Only add InferenceServiceNodeLocalRateLimitCalculator (which is a ClusterStateListener) for cluster aware rate limiting, | ||
| // if the rate limiting feature flags are enabled, otherwise provide noop implementation | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -7,36 +7,55 @@ | |||||
|
|
||||||
| package org.elasticsearch.xpack.inference.action; | ||||||
|
|
||||||
| import org.apache.logging.log4j.LogManager; | ||||||
| import org.apache.logging.log4j.Logger; | ||||||
| import org.elasticsearch.action.ActionListener; | ||||||
| import org.elasticsearch.action.support.ActionFilters; | ||||||
| import org.elasticsearch.action.support.HandledTransportAction; | ||||||
| import org.elasticsearch.action.support.SubscribableListener; | ||||||
| import org.elasticsearch.common.util.concurrent.EsExecutors; | ||||||
| import org.elasticsearch.core.Nullable; | ||||||
| import org.elasticsearch.inference.InferenceService; | ||||||
| import org.elasticsearch.inference.InferenceServiceConfiguration; | ||||||
| import org.elasticsearch.inference.InferenceServiceRegistry; | ||||||
| import org.elasticsearch.inference.TaskType; | ||||||
| import org.elasticsearch.injection.guice.Inject; | ||||||
| import org.elasticsearch.tasks.Task; | ||||||
| import org.elasticsearch.threadpool.ThreadPool; | ||||||
| import org.elasticsearch.transport.TransportService; | ||||||
| import org.elasticsearch.xpack.core.inference.action.GetInferenceServicesAction; | ||||||
| import org.elasticsearch.xpack.inference.external.http.sender.Sender; | ||||||
| import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; | ||||||
| import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationModel; | ||||||
| import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; | ||||||
|
|
||||||
| import java.util.ArrayList; | ||||||
| import java.util.Comparator; | ||||||
| import java.util.List; | ||||||
| import java.util.Map; | ||||||
| import java.util.stream.Collectors; | ||||||
|
|
||||||
| import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; | ||||||
|
|
||||||
| public class TransportGetInferenceServicesAction extends HandledTransportAction< | ||||||
| GetInferenceServicesAction.Request, | ||||||
| GetInferenceServicesAction.Response> { | ||||||
|
|
||||||
| private static final Logger logger = LogManager.getLogger(TransportGetInferenceServicesAction.class); | ||||||
|
|
||||||
| private final InferenceServiceRegistry serviceRegistry; | ||||||
| private final ElasticInferenceServiceAuthorizationRequestHandler eisAuthorizationRequestHandler; | ||||||
| private final Sender eisSender; | ||||||
| private final ThreadPool threadPool; | ||||||
|
|
||||||
| @Inject | ||||||
| public TransportGetInferenceServicesAction( | ||||||
| TransportService transportService, | ||||||
| ActionFilters actionFilters, | ||||||
| InferenceServiceRegistry serviceRegistry | ||||||
| ThreadPool threadPool, | ||||||
| InferenceServiceRegistry serviceRegistry, | ||||||
| ElasticInferenceServiceAuthorizationRequestHandler eisAuthorizationRequestHandler, | ||||||
| Sender sender | ||||||
| ) { | ||||||
| super( | ||||||
| GetInferenceServicesAction.NAME, | ||||||
|
|
@@ -46,6 +65,9 @@ public TransportGetInferenceServicesAction( | |||||
| EsExecutors.DIRECT_EXECUTOR_SERVICE | ||||||
| ); | ||||||
| this.serviceRegistry = serviceRegistry; | ||||||
| this.eisAuthorizationRequestHandler = eisAuthorizationRequestHandler; | ||||||
| this.eisSender = sender; | ||||||
| this.threadPool = threadPool; | ||||||
| } | ||||||
|
|
||||||
| @Override | ||||||
|
|
@@ -69,27 +91,81 @@ private void getServiceConfigurationsForTaskType( | |||||
| .entrySet() | ||||||
| .stream() | ||||||
| .filter( | ||||||
| service -> service.getValue().hideFromConfigurationApi() == false | ||||||
| // exclude EIS here because the hideFromConfigurationApi() is not supported | ||||||
|
||||||
| // exclude EIS here because the hideFromConfigurationApi() is not supported | |
| // Exclude EIS as the EIS specific configurations are handled separately |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
getServiceConfigurationsForServices() is a synchronous method and could be written with a return type instead of a listener. I think that would make this code easier to read, as you wouldn't need to define the merge listener:
}).<List<InferenceServiceConfiguration>>andThen((configurationListener, authorizationModel) -> {
var serviceConfigs = getServiceConfigurationsForServices(availableServices);
if (authorizationModel.isAuthorized() == false) {
delegate.onResponse(serviceConfigs);
return;
}
if (requestedTaskType != null && authorizationModel.getAuthorizedTaskTypes().contains(requestedTaskType) == false) {
delegate.onResponse(serviceConfigs);
return;
}
var config = ElasticInferenceService.createConfiguration(authorizationModel.getAuthorizedTaskTypes());
serviceConfigs.add(config);
serviceConfigs.sort(Comparator.comparing(InferenceServiceConfiguration::getService));
delegate.onResponse(serviceConfigs);
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🤦♂️ Thank you, for some reason I thought it needed to use a listener.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `init()` method below also calls `mockEISServer.enqueueAuthorizeAllModelsResponse()`. Is it equivalent to change the annotation on that method to `@Before`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, the original reason for using `@BeforeClass` is that I ran into some weird issues locally. The original PR that added the `@BeforeClass` is here: #128640

For some background, when the node for the test starts up it will reach out to EIS and get the auth response. If that fails (there isn't a response queued in the mock server), then the tests will fail. What I was observing is that the base class's static logic would only be executed once regardless of how many subclasses used the base. This resulted in the first test class succeeding but the second test class that leveraged the base would fail. To get around this I added the `@BeforeClass` and it seemed to fix the issue. The reason we need this in `@BeforeClass` is because we need a response queued before Elasticsearch is started. Elasticsearch is started only once at the beginning before all the tests run.