Skip to content

Commit 6b110dc

Browse files
Adding revocation functionality back
1 parent 14ece26 commit 6b110dc

File tree

4 files changed

+321
-3
lines changed

4 files changed

+321
-3
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,288 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.integration;
9+
10+
import org.elasticsearch.ResourceNotFoundException;
11+
import org.elasticsearch.action.support.PlainActionFuture;
12+
import org.elasticsearch.common.settings.Settings;
13+
import org.elasticsearch.core.TimeValue;
14+
import org.elasticsearch.inference.InferenceService;
15+
import org.elasticsearch.inference.MinimalServiceSettings;
16+
import org.elasticsearch.inference.Model;
17+
import org.elasticsearch.inference.TaskType;
18+
import org.elasticsearch.inference.UnparsedModel;
19+
import org.elasticsearch.plugins.Plugin;
20+
import org.elasticsearch.reindex.ReindexPlugin;
21+
import org.elasticsearch.test.ESSingleNodeTestCase;
22+
import org.elasticsearch.test.http.MockResponse;
23+
import org.elasticsearch.test.http.MockWebServer;
24+
import org.elasticsearch.threadpool.ThreadPool;
25+
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
26+
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
27+
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
28+
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
29+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
30+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettingsTests;
31+
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;
32+
import org.junit.After;
33+
import org.junit.Before;
34+
35+
import java.util.Collection;
36+
import java.util.EnumSet;
37+
import java.util.List;
38+
import java.util.concurrent.TimeUnit;
39+
40+
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
41+
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
42+
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
43+
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
44+
import static org.hamcrest.CoreMatchers.is;
45+
import static org.mockito.Mockito.mock;
46+
47+
public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
48+
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
49+
50+
private ModelRegistry modelRegistry;
51+
private final MockWebServer webServer = new MockWebServer();
52+
private ThreadPool threadPool;
53+
private String gatewayUrl;
54+
55+
@Before
56+
public void createComponents() throws Exception {
57+
threadPool = createThreadPool(inferenceUtilityPool());
58+
webServer.start();
59+
gatewayUrl = getUrl(webServer);
60+
modelRegistry = new ModelRegistry(client());
61+
}
62+
63+
@After
64+
public void shutdown() {
65+
terminate(threadPool);
66+
webServer.close();
67+
}
68+
69+
@Override
70+
protected boolean resetNodeAfterTest() {
71+
return true;
72+
}
73+
74+
@Override
75+
protected Collection<Class<? extends Plugin>> getPlugins() {
76+
return pluginList(ReindexPlugin.class);
77+
}
78+
79+
public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCorrect() throws Exception {
80+
String responseJson = """
81+
{
82+
"models": [
83+
{
84+
"model_name": "rainbow-sprinkles",
85+
"task_types": ["chat"]
86+
}
87+
]
88+
}
89+
""";
90+
91+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
92+
93+
try (var service = createElasticInferenceService()) {
94+
ensureAuthorizationCallFinished(service);
95+
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
96+
assertThat(
97+
service.defaultConfigIds(),
98+
is(
99+
List.of(
100+
new InferenceService.DefaultConfigId(".rainbow-sprinkles-elastic", MinimalServiceSettings.chatCompletion(), service)
101+
)
102+
)
103+
);
104+
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
105+
106+
PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
107+
service.defaultConfigs(listener);
108+
assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
109+
}
110+
}
111+
112+
public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty() throws Exception {
113+
{
114+
String responseJson = """
115+
{
116+
"models": [
117+
{
118+
"model_name": "rainbow-sprinkles",
119+
"task_types": ["chat"]
120+
}
121+
]
122+
}
123+
""";
124+
125+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
126+
127+
try (var service = createElasticInferenceService()) {
128+
ensureAuthorizationCallFinished(service);
129+
130+
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
131+
assertThat(
132+
service.defaultConfigIds(),
133+
is(
134+
List.of(
135+
new InferenceService.DefaultConfigId(
136+
".rainbow-sprinkles-elastic",
137+
MinimalServiceSettings.chatCompletion(),
138+
service
139+
)
140+
)
141+
)
142+
);
143+
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
144+
145+
PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
146+
service.defaultConfigs(listener);
147+
assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
148+
149+
var getModelListener = new PlainActionFuture<UnparsedModel>();
150+
// persists the default endpoints
151+
modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener);
152+
153+
var inferenceEntity = getModelListener.actionGet(TIMEOUT);
154+
assertThat(inferenceEntity.inferenceEntityId(), is(".rainbow-sprinkles-elastic"));
155+
assertThat(inferenceEntity.taskType(), is(TaskType.CHAT_COMPLETION));
156+
}
157+
}
158+
{
159+
String noAuthorizationResponseJson = """
160+
{
161+
"models": []
162+
}
163+
""";
164+
165+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson));
166+
167+
try (var service = createElasticInferenceService()) {
168+
ensureAuthorizationCallFinished(service);
169+
170+
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
171+
assertTrue(service.defaultConfigIds().isEmpty());
172+
assertThat(service.supportedTaskTypes(), is(EnumSet.noneOf(TaskType.class)));
173+
174+
var getModelListener = new PlainActionFuture<UnparsedModel>();
175+
modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener);
176+
177+
var exception = expectThrows(ResourceNotFoundException.class, () -> getModelListener.actionGet(TIMEOUT));
178+
assertThat(exception.getMessage(), is("Inference endpoint not found [.rainbow-sprinkles-elastic]"));
179+
}
180+
}
181+
}
182+
183+
public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnAuthForIt() throws Exception {
184+
{
185+
String responseJson = """
186+
{
187+
"models": [
188+
{
189+
"model_name": "rainbow-sprinkles",
190+
"task_types": ["chat"]
191+
},
192+
{
193+
"model_name": "elser-v2",
194+
"task_types": ["embed/text/sparse"]
195+
}
196+
]
197+
}
198+
""";
199+
200+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
201+
202+
try (var service = createElasticInferenceService()) {
203+
ensureAuthorizationCallFinished(service);
204+
205+
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
206+
assertThat(
207+
service.defaultConfigIds(),
208+
is(
209+
List.of(
210+
new InferenceService.DefaultConfigId(".elser-v2-elastic", MinimalServiceSettings.sparseEmbedding(), service),
211+
new InferenceService.DefaultConfigId(
212+
".rainbow-sprinkles-elastic",
213+
MinimalServiceSettings.chatCompletion(),
214+
service
215+
)
216+
)
217+
)
218+
);
219+
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING)));
220+
221+
PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
222+
service.defaultConfigs(listener);
223+
assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic"));
224+
assertThat(listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
225+
226+
var getModelListener = new PlainActionFuture<UnparsedModel>();
227+
// persists the default endpoints
228+
modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener);
229+
230+
var inferenceEntity = getModelListener.actionGet(TIMEOUT);
231+
assertThat(inferenceEntity.inferenceEntityId(), is(".rainbow-sprinkles-elastic"));
232+
assertThat(inferenceEntity.taskType(), is(TaskType.CHAT_COMPLETION));
233+
}
234+
}
235+
{
236+
String noAuthorizationResponseJson = """
237+
{
238+
"models": [
239+
{
240+
"model_name": "elser-v2",
241+
"task_types": ["embed/text/sparse"]
242+
}
243+
]
244+
}
245+
""";
246+
247+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson));
248+
249+
try (var service = createElasticInferenceService()) {
250+
ensureAuthorizationCallFinished(service);
251+
252+
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
253+
assertThat(
254+
service.defaultConfigIds(),
255+
is(
256+
List.of(
257+
new InferenceService.DefaultConfigId(".elser-v2-elastic", MinimalServiceSettings.sparseEmbedding(), service)
258+
)
259+
)
260+
);
261+
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING)));
262+
263+
var getModelListener = new PlainActionFuture<UnparsedModel>();
264+
modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener);
265+
var exception = expectThrows(ResourceNotFoundException.class, () -> getModelListener.actionGet(TIMEOUT));
266+
assertThat(exception.getMessage(), is("Inference endpoint not found [.rainbow-sprinkles-elastic]"));
267+
}
268+
}
269+
}
270+
271+
private void ensureAuthorizationCallFinished(ElasticInferenceService service) {
272+
service.onNodeStarted();
273+
service.waitForFirstAuthorizationToComplete(TIMEOUT);
274+
}
275+
276+
private ElasticInferenceService createElasticInferenceService() {
277+
var httpManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
278+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, httpManager);
279+
280+
return new ElasticInferenceService(
281+
senderFactory,
282+
createWithEmptySettings(threadPool),
283+
ElasticInferenceServiceSettingsTests.create(gatewayUrl),
284+
modelRegistry,
285+
new ElasticInferenceServiceAuthorizationRequestHandler(gatewayUrl, threadPool)
286+
);
287+
}
288+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,11 +163,13 @@ public void onNodeStarted() {
163163
}
164164

165165
/**
166+
* Only use this in tests.
167+
*
166168
* Waits the specified amount of time for the authorization call to complete. This is mainly to make testing easier.
167169
* @param waitTime the max time to wait
168170
* @throws IllegalStateException if the wait time is exceeded or the call receives an {@link InterruptedException}
169171
*/
170-
void waitForAuthorizationToComplete(TimeValue waitTime) {
172+
public void waitForFirstAuthorizationToComplete(TimeValue waitTime) {
171173
authorizationHandler.waitForAuthorizationToComplete(waitTime);
172174
}
173175

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import java.util.ArrayList;
3131
import java.util.Comparator;
3232
import java.util.EnumSet;
33+
import java.util.HashSet;
3334
import java.util.List;
3435
import java.util.Map;
3536
import java.util.Objects;
@@ -39,6 +40,7 @@
3940
import java.util.concurrent.TimeUnit;
4041
import java.util.concurrent.atomic.AtomicBoolean;
4142
import java.util.concurrent.atomic.AtomicReference;
43+
import java.util.stream.Collectors;
4244

4345
import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;
4446

@@ -227,7 +229,6 @@ private void sendAuthorizationRequest() {
227229
if (callback != null) {
228230
callback.run();
229231
}
230-
firstAuthorizationCompletedLatch.countDown();
231232
}, e -> {
232233
// we don't need to do anything if there was a failure, everything is disabled by default
233234
firstAuthorizationCompletedLatch.countDown();
@@ -258,6 +259,7 @@ private synchronized void setAuthorizedContent(ElasticInferenceServiceAuthorizat
258259
configuration.set(new ElasticInferenceService.Configuration(authorizedContent.get().taskTypesAndModels.getAuthorizedTaskTypes()));
259260

260261
authorizedContent.get().configIds().forEach(modelRegistry::putDefaultIdIfAbsent);
262+
handleRevokedDefaultConfigs(authorizedDefaultModelIds);
261263
}
262264

263265
private Set<String> getAuthorizedDefaultModelIds(ElasticInferenceServiceAuthorizationModel auth) {
@@ -312,4 +314,30 @@ private List<DefaultModelConfig> getAuthorizedDefaultModelsObjects(Set<String> a
312314
authorizedModels.sort(Comparator.comparing(modelConfig -> modelConfig.model().getInferenceEntityId()));
313315
return authorizedModels;
314316
}
317+
318+
private void handleRevokedDefaultConfigs(Set<String> authorizedDefaultModelIds) {
319+
// if a model was initially returned in the authorization response but is absent, then we'll assume authorization was revoked
320+
var unauthorizedDefaultModelIds = new HashSet<>(defaultModelsConfigs.keySet());
321+
unauthorizedDefaultModelIds.removeAll(authorizedDefaultModelIds);
322+
323+
// get all the default inference endpoint ids for the unauthorized model ids
324+
var unauthorizedDefaultInferenceEndpointIds = unauthorizedDefaultModelIds.stream()
325+
.map(defaultModelsConfigs::get) // get all the model configs
326+
.filter(Objects::nonNull) // limit to only non-null
327+
.map(modelConfig -> modelConfig.model().getInferenceEntityId()) // get the inference ids
328+
.collect(Collectors.toSet());
329+
330+
var deleteInferenceEndpointsListener = ActionListener.<Boolean>wrap(result -> {
331+
logger.debug(Strings.format("Successfully revoked access to default inference endpoint IDs: %s", unauthorizedDefaultModelIds));
332+
firstAuthorizationCompletedLatch.countDown();
333+
}, e -> {
334+
logger.warn(
335+
Strings.format("Failed to revoke access to default inference endpoint IDs: %s, error: %s", unauthorizedDefaultModelIds, e)
336+
);
337+
firstAuthorizationCompletedLatch.countDown();
338+
});
339+
340+
logger.debug("Synchronizing default inference endpoints");
341+
modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener);
342+
}
315343
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1100,7 +1100,7 @@ private void testUnifiedStreamError(int responseCode, String responseJson, Strin
11001100

11011101
private void ensureAuthorizationCallFinished(ElasticInferenceService service) {
11021102
service.onNodeStarted();
1103-
service.waitForAuthorizationToComplete(TIMEOUT);
1103+
service.waitForFirstAuthorizationToComplete(TIMEOUT);
11041104
}
11051105

11061106
private ElasticInferenceService createServiceWithMockSender() {

0 commit comments

Comments
 (0)