
Commit 98c1c89

[ML] Support revoking inference default endpoint authorization (#121326)
* Starting revoke
* Adding integration tests
* More integration tests
* Adding test for deleting default inference endpoint via rest call
* Removing task type any
* Addressing feedback and adding test
1 parent cbcdd0a commit 98c1c89
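The diff below covers only the new integration test file. As a rough illustration of the REST-level test mentioned in the commit message, deleting a default inference endpoint might look like the following sketch. This is not part of the commit's diff: the endpoint id is taken from the integration test below, and the exact request path and client wiring are assumptions.

    import java.io.IOException;

    import org.elasticsearch.client.Request;
    import org.elasticsearch.client.Response;
    import org.elasticsearch.client.RestClient;

    // Hypothetical sketch: issue a DELETE against the inference API for a default
    // endpoint id. Whether default endpoints accept this request, and under which
    // conditions, is an assumption based on the commit message, not this diff.
    final class DeleteDefaultEndpointSketch {
        static Response deleteDefaultEndpoint(RestClient client) throws IOException {
            return client.performRequest(new Request("DELETE", "/_inference/.rainbow-sprinkles-elastic"));
        }
    }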


8 files changed: +555 −35 lines

@@ -0,0 +1,271 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.integration;

import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.MinimalServiceSettings;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.reindex.ReindexPlugin;
import org.elasticsearch.test.ESSingleNodeTestCase;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationHandler;
import org.junit.After;
import org.junit.Before;

import java.util.Collection;
import java.util.EnumSet;
import java.util.List;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.hamcrest.CoreMatchers.is;
import static org.mockito.Mockito.mock;

public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
    private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);

    private ModelRegistry modelRegistry;
    private final MockWebServer webServer = new MockWebServer();
    private ThreadPool threadPool;
    private String gatewayUrl;

    @Before
    public void createComponents() throws Exception {
        threadPool = createThreadPool(inferenceUtilityPool());
        webServer.start();
        gatewayUrl = getUrl(webServer);
        modelRegistry = new ModelRegistry(client());
    }

    @After
    public void shutdown() {
        terminate(threadPool);
        webServer.close();
    }

    @Override
    protected boolean resetNodeAfterTest() {
        return true;
    }

    @Override
    protected Collection<Class<? extends Plugin>> getPlugins() {
        return pluginList(ReindexPlugin.class);
    }

    public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCorrect() throws Exception {
        String responseJson = """
            {
                "models": [
                    {
                        "model_name": "rainbow-sprinkles",
                        "task_types": ["chat"]
                    }
                ]
            }
            """;

        webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

        try (var service = createElasticInferenceService()) {
            service.waitForAuthorizationToComplete(TIMEOUT);
            assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
            assertThat(
                service.defaultConfigIds(),
                is(
                    List.of(
                        new InferenceService.DefaultConfigId(".rainbow-sprinkles-elastic", MinimalServiceSettings.chatCompletion(), service)
                    )
                )
            );
            assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));

            PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
            service.defaultConfigs(listener);
            assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
        }
    }

    public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty() throws Exception {
        // First authorization response grants the chat completion model, so the default endpoint is created and persisted.
        {
            String responseJson = """
                {
                    "models": [
                        {
                            "model_name": "rainbow-sprinkles",
                            "task_types": ["chat"]
                        }
                    ]
                }
                """;

            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

            try (var service = createElasticInferenceService()) {
                service.waitForAuthorizationToComplete(TIMEOUT);
                assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
                assertThat(
                    service.defaultConfigIds(),
                    is(
                        List.of(
                            new InferenceService.DefaultConfigId(
                                ".rainbow-sprinkles-elastic",
                                MinimalServiceSettings.chatCompletion(),
                                service
                            )
                        )
                    )
                );
                assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));

                PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
                service.defaultConfigs(listener);
                assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));

                var getModelListener = new PlainActionFuture<UnparsedModel>();
                // persists the default endpoints
                modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener);

                var inferenceEntity = getModelListener.actionGet(TIMEOUT);
                assertThat(inferenceEntity.inferenceEntityId(), is(".rainbow-sprinkles-elastic"));
                assertThat(inferenceEntity.taskType(), is(TaskType.CHAT_COMPLETION));
            }
        }
        // Second authorization response is empty, so the previously persisted default endpoint is removed.
        {
            String noAuthorizationResponseJson = """
                {
                    "models": []
                }
                """;

            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson));

            try (var service = createElasticInferenceService()) {
                service.waitForAuthorizationToComplete(TIMEOUT);
                assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
                assertTrue(service.defaultConfigIds().isEmpty());
                assertThat(service.supportedTaskTypes(), is(EnumSet.noneOf(TaskType.class)));

                var getModelListener = new PlainActionFuture<UnparsedModel>();
                modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener);

                var exception = expectThrows(ResourceNotFoundException.class, () -> getModelListener.actionGet(TIMEOUT));
                assertThat(exception.getMessage(), is("Inference endpoint not found [.rainbow-sprinkles-elastic]"));
            }
        }
    }

    public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnAuthForIt() throws Exception {
        // First authorization response grants both models, so the chat completion default endpoint is created and persisted.
        {
            String responseJson = """
                {
                    "models": [
                        {
                            "model_name": "rainbow-sprinkles",
                            "task_types": ["chat"]
                        },
                        {
                            "model_name": "elser-v2",
                            "task_types": ["embed/text/sparse"]
                        }
                    ]
                }
                """;

            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

            try (var service = createElasticInferenceService()) {
                service.waitForAuthorizationToComplete(TIMEOUT);
                assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
                assertThat(
                    service.defaultConfigIds(),
                    is(
                        List.of(
                            new InferenceService.DefaultConfigId(
                                ".rainbow-sprinkles-elastic",
                                MinimalServiceSettings.chatCompletion(),
                                service
                            )
                        )
                    )
                );
                assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING)));

                PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
                service.defaultConfigs(listener);
                assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));

                var getModelListener = new PlainActionFuture<UnparsedModel>();
                // persists the default endpoints
                modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener);

                var inferenceEntity = getModelListener.actionGet(TIMEOUT);
                assertThat(inferenceEntity.inferenceEntityId(), is(".rainbow-sprinkles-elastic"));
                assertThat(inferenceEntity.taskType(), is(TaskType.CHAT_COMPLETION));
            }
        }
        // Second authorization response no longer includes the chat completion model, so its default endpoint is removed.
        {
            String noAuthorizationResponseJson = """
                {
                    "models": [
                        {
                            "model_name": "elser-v2",
                            "task_types": ["embed/text/sparse"]
                        }
                    ]
                }
                """;

            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson));

            try (var service = createElasticInferenceService()) {
                service.waitForAuthorizationToComplete(TIMEOUT);
                assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
                assertTrue(service.defaultConfigIds().isEmpty());
                assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING)));

                var getModelListener = new PlainActionFuture<UnparsedModel>();
                modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener);

                var exception = expectThrows(ResourceNotFoundException.class, () -> getModelListener.actionGet(TIMEOUT));
                assertThat(exception.getMessage(), is("Inference endpoint not found [.rainbow-sprinkles-elastic]"));
            }
        }
    }

    private ElasticInferenceService createElasticInferenceService() {
        var httpManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, httpManager);

        return new ElasticInferenceService(
            senderFactory,
            createWithEmptySettings(threadPool),
            new ElasticInferenceServiceComponents(gatewayUrl),
            modelRegistry,
            new ElasticInferenceServiceAuthorizationHandler(gatewayUrl, threadPool)
        );
    }
}
