Skip to content

Commit aae1ffc

Browse files
jonathan-buttner and elasticsearchmachine authored
[ML] Inference API _services retrieves authorization information directly from EIS (elastic#134398)
* get services querying eis gateway for info
* [CI] Auto commit changes from spotless
* Adding fixes
* Fixing tests
* Address feedback for javadoc
* Addressing feedback

---------

Co-authored-by: elasticsearchmachine <[email protected]>
1 parent 1b49b87 commit aae1ffc

File tree

13 files changed

+189
-150
lines changed

13 files changed

+189
-150
lines changed

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,15 @@
2323

2424
public class InferenceGetModelsWithElasticInferenceServiceIT extends BaseMockEISAuthServerTest {
2525

26+
/**
27+
* This is done before the class because I've run into issues where another class that extends {@link BaseMockEISAuthServerTest}
28+
* results in an authorization response not being queued up for the new Elasticsearch Node in time. When the node starts up, it
29+
* retrieves authorization. If the response isn't queued up when that happens the tests will fail. From my testing locally it seems
30+
* like the base class's static functionality to queue a response is only done once and not for each subclass.
31+
*
32+
* My understanding is that the @Before will be run after the node starts up and wouldn't be sufficient to handle
33+
* this scenario. That is why this needs to be @BeforeClass.
34+
*/
2635
@BeforeClass
2736
public static void init() {
2837
// Ensure the mock EIS server has an authorized response ready

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
import org.elasticsearch.client.Request;
1313
import org.elasticsearch.common.Strings;
1414
import org.elasticsearch.inference.TaskType;
15+
import org.junit.Before;
1516
import org.junit.BeforeClass;
1617

1718
import java.io.IOException;
@@ -23,6 +24,23 @@
2324

2425
public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
2526

27+
@Before
28+
public void setUp() throws Exception {
29+
super.setUp();
30+
// Ensure the mock EIS server has an authorized response ready before each test because each test will
31+
// use the services API which makes a call to EIS
32+
mockEISServer.enqueueAuthorizeAllModelsResponse();
33+
}
34+
35+
/**
36+
* This is done before the class because I've run into issues where another class that extends {@link BaseMockEISAuthServerTest}
37+
* results in an authorization response not being queued up for the new Elasticsearch Node in time. When the node starts up, it
38+
* retrieves authorization. If the response isn't queued up when that happens the tests will fail. From my testing locally it seems
39+
* like the base class's static functionality to queue a response is only done once and not for each subclass.
40+
*
41+
* My understanding is that the @Before will be run after the node starts up and wouldn't be sufficient to handle
42+
* this scenario. That is why this needs to be @BeforeClass.
43+
*/
2644
@BeforeClass
2745
public static void init() {
2846
// Ensure the mock EIS server has an authorized response ready

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestCompletionServiceExtension.java

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -178,6 +178,7 @@ public static InferenceServiceConfiguration get() {
178178
);
179179

180180
return new InferenceServiceConfiguration.Builder().setService(NAME)
181+
.setName(NAME)
181182
.setTaskTypes(SUPPORTED_TASK_TYPES)
182183
.setConfigurations(configurationMap)
183184
.build();

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -283,6 +283,7 @@ public static InferenceServiceConfiguration get() {
283283
);
284284

285285
return new InferenceServiceConfiguration.Builder().setService(NAME)
286+
.setName(NAME)
286287
.setTaskTypes(supportedTaskTypes)
287288
.setConfigurations(configurationMap)
288289
.build();

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -228,6 +228,7 @@ public static InferenceServiceConfiguration get() {
228228
);
229229

230230
return new InferenceServiceConfiguration.Builder().setService(NAME)
231+
.setName(NAME)
231232
.setTaskTypes(supportedTaskTypes)
232233
.setConfigurations(configurationMap)
233234
.build();

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -224,6 +224,7 @@ public static InferenceServiceConfiguration get() {
224224
);
225225

226226
return new InferenceServiceConfiguration.Builder().setService(NAME)
227+
.setName(NAME)
227228
.setTaskTypes(supportedTaskTypes)
228229
.setConfigurations(configurationMap)
229230
.build();

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -326,6 +326,7 @@ public static InferenceServiceConfiguration get() {
326326
);
327327

328328
return new InferenceServiceConfiguration.Builder().setService(NAME)
329+
.setName(NAME)
329330
.setTaskTypes(supportedTaskTypes)
330331
.setConfigurations(configurationMap)
331332
.build();

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -176,7 +176,7 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty()
176176
try (var service = createElasticInferenceService()) {
177177
ensureAuthorizationCallFinished(service);
178178

179-
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
179+
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
180180
assertTrue(service.defaultConfigIds().isEmpty());
181181
assertThat(service.supportedTaskTypes(), is(EnumSet.noneOf(TaskType.class)));
182182

@@ -299,7 +299,7 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
299299
try (var service = createElasticInferenceService()) {
300300
ensureAuthorizationCallFinished(service);
301301

302-
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
302+
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
303303
assertThat(
304304
service.defaultConfigIds(),
305305
containsInAnyOrder(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -88,6 +88,7 @@
8888
import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings;
8989
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
9090
import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings;
91+
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
9192
import org.elasticsearch.xpack.inference.highlight.SemanticTextHighlighter;
9293
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
9394
import org.elasticsearch.xpack.inference.mapper.OffsetSourceFieldMapper;
@@ -382,6 +383,8 @@ public Collection<?> createComponents(PluginServices services) {
382383
new TransportGetInferenceDiagnosticsAction.ClientManagers(httpClientManager, elasticInferenceServiceHttpClientManager)
383384
);
384385
components.add(inferenceStatsBinding);
386+
components.add(authorizationHandler);
387+
components.add(new PluginComponentBinding<>(Sender.class, elasicInferenceServiceFactory.get().createSender()));
385388
components.add(
386389
new InferenceEndpointRegistry(
387390
services.clusterService(),

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java

Lines changed: 86 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -7,36 +7,55 @@
77

88
package org.elasticsearch.xpack.inference.action;
99

10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
1012
import org.elasticsearch.action.ActionListener;
1113
import org.elasticsearch.action.support.ActionFilters;
1214
import org.elasticsearch.action.support.HandledTransportAction;
15+
import org.elasticsearch.action.support.SubscribableListener;
1316
import org.elasticsearch.common.util.concurrent.EsExecutors;
17+
import org.elasticsearch.core.Nullable;
1418
import org.elasticsearch.inference.InferenceService;
1519
import org.elasticsearch.inference.InferenceServiceConfiguration;
1620
import org.elasticsearch.inference.InferenceServiceRegistry;
1721
import org.elasticsearch.inference.TaskType;
1822
import org.elasticsearch.injection.guice.Inject;
1923
import org.elasticsearch.tasks.Task;
24+
import org.elasticsearch.threadpool.ThreadPool;
2025
import org.elasticsearch.transport.TransportService;
2126
import org.elasticsearch.xpack.core.inference.action.GetInferenceServicesAction;
27+
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
28+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
29+
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationModel;
30+
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;
2231

2332
import java.util.ArrayList;
2433
import java.util.Comparator;
2534
import java.util.List;
2635
import java.util.Map;
2736
import java.util.stream.Collectors;
2837

38+
import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;
39+
2940
public class TransportGetInferenceServicesAction extends HandledTransportAction<
3041
GetInferenceServicesAction.Request,
3142
GetInferenceServicesAction.Response> {
3243

44+
private static final Logger logger = LogManager.getLogger(TransportGetInferenceServicesAction.class);
45+
3346
private final InferenceServiceRegistry serviceRegistry;
47+
private final ElasticInferenceServiceAuthorizationRequestHandler eisAuthorizationRequestHandler;
48+
private final Sender eisSender;
49+
private final ThreadPool threadPool;
3450

3551
@Inject
3652
public TransportGetInferenceServicesAction(
3753
TransportService transportService,
3854
ActionFilters actionFilters,
39-
InferenceServiceRegistry serviceRegistry
55+
ThreadPool threadPool,
56+
InferenceServiceRegistry serviceRegistry,
57+
ElasticInferenceServiceAuthorizationRequestHandler eisAuthorizationRequestHandler,
58+
Sender sender
4059
) {
4160
super(
4261
GetInferenceServicesAction.NAME,
@@ -46,6 +65,9 @@ public TransportGetInferenceServicesAction(
4665
EsExecutors.DIRECT_EXECUTOR_SERVICE
4766
);
4867
this.serviceRegistry = serviceRegistry;
68+
this.eisAuthorizationRequestHandler = eisAuthorizationRequestHandler;
69+
this.eisSender = sender;
70+
this.threadPool = threadPool;
4971
}
5072

5173
@Override
@@ -69,41 +91,86 @@ private void getServiceConfigurationsForTaskType(
6991
.entrySet()
7092
.stream()
7193
.filter(
72-
service -> service.getValue().hideFromConfigurationApi() == false
94+
// Exclude EIS as the EIS specific configurations must be retrieved directly from EIS and merged in later
95+
service -> service.getValue().name().equals(ElasticInferenceService.NAME) == false
96+
&& service.getValue().hideFromConfigurationApi() == false
7397
&& service.getValue().supportedTaskTypes().contains(requestedTaskType)
7498
)
7599
.sorted(Comparator.comparing(service -> service.getValue().name()))
76100
.collect(Collectors.toCollection(ArrayList::new));
77101

78-
getServiceConfigurationsForServices(filteredServices, listener.delegateFailureAndWrap((delegate, configurations) -> {
79-
delegate.onResponse(new GetInferenceServicesAction.Response(configurations));
80-
}));
102+
getServiceConfigurationsForServicesAndEis(listener, filteredServices, requestedTaskType);
81103
}
82104

83105
private void getAllServiceConfigurations(ActionListener<GetInferenceServicesAction.Response> listener) {
84106
var availableServices = serviceRegistry.getServices()
85107
.entrySet()
86108
.stream()
87-
.filter(service -> service.getValue().hideFromConfigurationApi() == false)
109+
.filter(
110+
// Exclude EIS as the EIS specific configurations must be retrieved directly from EIS and merged in later
111+
service -> service.getValue().name().equals(ElasticInferenceService.NAME) == false
112+
&& service.getValue().hideFromConfigurationApi() == false
113+
)
88114
.sorted(Comparator.comparing(service -> service.getValue().name()))
89115
.collect(Collectors.toCollection(ArrayList::new));
90-
getServiceConfigurationsForServices(availableServices, listener.delegateFailureAndWrap((delegate, configurations) -> {
91-
delegate.onResponse(new GetInferenceServicesAction.Response(configurations));
92-
}));
116+
117+
getServiceConfigurationsForServicesAndEis(listener, availableServices, null);
93118
}
94119

95-
private void getServiceConfigurationsForServices(
96-
ArrayList<Map.Entry<String, InferenceService>> services,
97-
ActionListener<List<InferenceServiceConfiguration>> listener
120+
private void getServiceConfigurationsForServicesAndEis(
121+
ActionListener<GetInferenceServicesAction.Response> listener,
122+
ArrayList<Map.Entry<String, InferenceService>> availableServices,
123+
@Nullable TaskType requestedTaskType
98124
) {
99-
try {
100-
var serviceConfigurations = new ArrayList<InferenceServiceConfiguration>();
101-
for (var service : services) {
102-
serviceConfigurations.add(service.getValue().getConfiguration());
125+
SubscribableListener.<ElasticInferenceServiceAuthorizationModel>newForked(authModelListener -> {
126+
// Executing on a separate thread because there's a chance the authorization call needs to do some initialization for the Sender
127+
threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> getEisAuthorization(authModelListener, eisSender));
128+
}).<List<InferenceServiceConfiguration>>andThen((configurationListener, authorizationModel) -> {
129+
var serviceConfigs = getServiceConfigurationsForServices(availableServices);
130+
131+
if (authorizationModel.isAuthorized() == false) {
132+
configurationListener.onResponse(serviceConfigs);
133+
return;
134+
}
135+
136+
var config = ElasticInferenceService.createConfiguration(authorizationModel.getAuthorizedTaskTypes());
137+
if (requestedTaskType != null && authorizationModel.getAuthorizedTaskTypes().contains(requestedTaskType) == false) {
138+
configurationListener.onResponse(serviceConfigs);
139+
return;
103140
}
104-
listener.onResponse(serviceConfigurations.stream().toList());
105-
} catch (Exception e) {
106-
listener.onFailure(e);
141+
142+
serviceConfigs.add(config);
143+
serviceConfigs.sort(Comparator.comparing(InferenceServiceConfiguration::getService));
144+
configurationListener.onResponse(serviceConfigs);
145+
})
146+
.addListener(
147+
listener.delegateFailureAndWrap(
148+
(delegate, configurations) -> delegate.onResponse(new GetInferenceServicesAction.Response(configurations))
149+
)
150+
);
151+
}
152+
153+
private void getEisAuthorization(ActionListener<ElasticInferenceServiceAuthorizationModel> listener, Sender sender) {
154+
var disabledServiceListener = listener.delegateResponse((delegate, e) -> {
155+
logger.warn(
156+
"Failed to retrieve authorization information from the "
157+
+ "Elastic Inference Service while determining service configurations. Marking service as disabled.",
158+
e
159+
);
160+
delegate.onResponse(ElasticInferenceServiceAuthorizationModel.newDisabledService());
161+
});
162+
163+
eisAuthorizationRequestHandler.getAuthorization(disabledServiceListener, sender);
164+
}
165+
166+
private List<InferenceServiceConfiguration> getServiceConfigurationsForServices(
167+
ArrayList<Map.Entry<String, InferenceService>> services
168+
) {
169+
var serviceConfigurations = new ArrayList<InferenceServiceConfiguration>();
170+
for (var service : services) {
171+
serviceConfigurations.add(service.getValue().getConfiguration());
107172
}
173+
174+
return serviceConfigurations;
108175
}
109176
}

0 commit comments

Comments
 (0)