Skip to content

Commit 10ad8f2

Browse files
Adding integration test
1 parent b41b1d4 commit 10ad8f2

File tree

3 files changed

+223
-8
lines changed

3 files changed

+223
-8
lines changed

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java

Lines changed: 165 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,49 +7,92 @@
77

88
package org.elasticsearch.xpack.inference.integration;
99

10+
import org.elasticsearch.action.support.PlainActionFuture;
1011
import org.elasticsearch.common.settings.Settings;
12+
import org.elasticsearch.core.TimeValue;
13+
import org.elasticsearch.inference.TaskType;
14+
import org.elasticsearch.inference.UnparsedModel;
1115
import org.elasticsearch.plugins.Plugin;
1216
import org.elasticsearch.reindex.ReindexPlugin;
1317
import org.elasticsearch.test.ESSingleNodeTestCase;
18+
import org.elasticsearch.test.http.MockResponse;
1419
import org.elasticsearch.test.http.MockWebServer;
1520
import org.elasticsearch.threadpool.ThreadPool;
1621
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
1722
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
23+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
1824
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings;
25+
import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints;
26+
import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationTaskExecutor;
1927
import org.junit.After;
28+
import org.junit.AfterClass;
2029
import org.junit.Before;
2130
import org.junit.BeforeClass;
2231

2332
import java.io.IOException;
2433
import java.util.Collection;
34+
import java.util.List;
35+
import java.util.function.Function;
36+
import java.util.stream.Collectors;
2537

2638
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
2739
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
40+
import static org.hamcrest.Matchers.empty;
41+
import static org.hamcrest.Matchers.is;
2842

2943
public class AuthorizationTaskExecutorIT extends ESSingleNodeTestCase {
3044

45+
private static final String EMPTY_AUTH_RESPONSE = """
46+
{
47+
"models": [
48+
]
49+
}
50+
""";
51+
52+
private static final String AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE = """
53+
{
54+
"models": [
55+
{
56+
"model_name": "rainbow-sprinkles",
57+
"task_types": ["chat"]
58+
}
59+
]
60+
}
61+
""";
62+
3163
private static final MockWebServer webServer = new MockWebServer();
3264
private static String gatewayUrl;
3365

3466
private ModelRegistry modelRegistry;
3567
private ThreadPool threadPool;
68+
private AuthorizationTaskExecutor authorizationTaskExecutor;
3669

3770
@BeforeClass
3871
public static void initClass() throws IOException {
3972
webServer.start();
4073
gatewayUrl = getUrl(webServer);
41-
// TODO add response to the web server to return no authorized models
74+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EMPTY_AUTH_RESPONSE));
4275
}
4376

4477
@Before
4578
public void createComponents() {
4679
threadPool = createThreadPool(inferenceUtilityExecutors());
4780
modelRegistry = node().injector().getInstance(ModelRegistry.class);
81+
authorizationTaskExecutor = node().injector().getInstance(AuthorizationTaskExecutor.class);
4882
}
4983

5084
@After
5185
public void shutdown() {
86+
// Delete all the eis preconfigured endpoints
87+
var listener = new PlainActionFuture<Boolean>();
88+
modelRegistry.deleteModels(InternalPreconfiguredEndpoints.EIS_PRECONFIGURED_ENDPOINT_IDS, listener);
89+
listener.actionGet(TimeValue.THIRTY_SECONDS);
90+
5291
terminate(threadPool);
92+
}
93+
94+
@AfterClass
95+
public static void cleanUpClass() {
5396
webServer.close();
5497
}
5598

@@ -66,12 +109,126 @@ protected Collection<Class<? extends Plugin>> getPlugins() {
66109
return pluginList(ReindexPlugin.class, LocalStateInferencePlugin.class);
67110
}
68111

69-
public void testCreateEndpoints() {
70-
// verify that no models are authorized
71-
// add request to return an authorized model
72-
// cancel the task
73-
// ensure the task is recreated?
74-
// verify that the authorized model is present
75-
fail("Not implemented yet");
112+
public void testCreatesEisChatCompletionEndpoint() throws Exception {
113+
assertNoAuthorizedEisEndpoints();
114+
115+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE));
116+
waitForNewAuthorizationResponse();
117+
118+
assertChatCompletionEndpointExists();
119+
}
120+
121+
private void assertNoAuthorizedEisEndpoints() throws Exception {
122+
assertBusy(() -> {
123+
var newPoller = authorizationTaskExecutor.getCurrentPollerTask();
124+
assertNotNull(newPoller);
125+
newPoller.waitForAuthorizationToComplete(TimeValue.THIRTY_SECONDS);
126+
});
127+
128+
var eisEndpoints = getEisEndpoints();
129+
assertThat(eisEndpoints, empty());
130+
}
131+
132+
private List<UnparsedModel> getEisEndpoints() {
133+
var listener = new PlainActionFuture<List<UnparsedModel>>();
134+
modelRegistry.getAllModels(false, listener);
135+
136+
var endpoints = listener.actionGet(TimeValue.THIRTY_SECONDS);
137+
return endpoints.stream().filter(m -> m.service().equals(ElasticInferenceService.NAME)).toList();
138+
}
139+
140+
private void waitForNewAuthorizationResponse() throws Exception {
141+
var taskListener = new PlainActionFuture<Void>();
142+
143+
authorizationTaskExecutor.abortTask(TimeValue.THIRTY_SECONDS, taskListener);
144+
// Ensure that the listener doesn't return a failure
145+
assertNull(taskListener.actionGet(TimeValue.THIRTY_SECONDS));
146+
147+
// wait for the new task to be recreated
148+
assertBusy(() -> {
149+
var newPoller = authorizationTaskExecutor.getCurrentPollerTask();
150+
assertNotNull(newPoller);
151+
newPoller.waitForAuthorizationToComplete(TimeValue.THIRTY_SECONDS);
152+
});
153+
}
154+
155+
public void testCreatesEisChatCompletion_DoesNotRemoveEndpointWhenNoLongerAuthorized() throws Exception {
156+
assertNoAuthorizedEisEndpoints();
157+
158+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE));
159+
waitForNewAuthorizationResponse();
160+
161+
assertChatCompletionEndpointExists();
162+
163+
// Simulate that the model is no longer authorized
164+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EMPTY_AUTH_RESPONSE));
165+
waitForNewAuthorizationResponse();
166+
167+
assertChatCompletionEndpointExists();
168+
}
169+
170+
private void assertChatCompletionEndpointExists() {
171+
var eisEndpoints = getEisEndpoints();
172+
assertThat(eisEndpoints.size(), is(1));
173+
174+
var rainbowSprinklesModel = eisEndpoints.get(0);
175+
assertChatCompletionUnparsedModel(rainbowSprinklesModel);
176+
}
177+
178+
private void assertChatCompletionUnparsedModel(UnparsedModel rainbowSprinklesModel) {
179+
assertThat(rainbowSprinklesModel.taskType(), is(TaskType.CHAT_COMPLETION));
180+
assertThat(rainbowSprinklesModel.service(), is(ElasticInferenceService.NAME));
181+
assertThat(rainbowSprinklesModel.inferenceEntityId(), is(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1));
182+
}
183+
184+
public void testCreatesChatCompletion_AndThenCreatesTextEmbedding() throws Exception {
185+
assertNoAuthorizedEisEndpoints();
186+
187+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE));
188+
waitForNewAuthorizationResponse();
189+
190+
assertChatCompletionEndpointExists();
191+
192+
// Simulate that the model is no longer authorized
193+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EMPTY_AUTH_RESPONSE));
194+
waitForNewAuthorizationResponse();
195+
196+
assertChatCompletionEndpointExists();
197+
198+
// Simulate that a text embedding model is now authorized
199+
var authorizedTextEmbeddingResponse = """
200+
{
201+
"models": [
202+
{
203+
"model_name": "multilingual-embed-v1",
204+
"task_types": ["embed/text/dense"]
205+
}
206+
]
207+
}
208+
""";
209+
210+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(authorizedTextEmbeddingResponse));
211+
waitForNewAuthorizationResponse();
212+
213+
var eisEndpoints = getEisEndpoints().stream().collect(Collectors.toMap(UnparsedModel::inferenceEntityId, Function.identity()));
214+
assertThat(eisEndpoints.size(), is(2));
215+
216+
assertTrue(eisEndpoints.containsKey(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1));
217+
assertChatCompletionUnparsedModel(eisEndpoints.get(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1));
218+
219+
assertTrue(eisEndpoints.containsKey(InternalPreconfiguredEndpoints.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID));
220+
221+
var textEmbeddingEndpoint = eisEndpoints.get(InternalPreconfiguredEndpoints.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID);
222+
assertThat(textEmbeddingEndpoint.taskType(), is(TaskType.TEXT_EMBEDDING));
223+
assertThat(textEmbeddingEndpoint.service(), is(ElasticInferenceService.NAME));
224+
}
225+
226+
public void testRestartsTaskAfterAbort() throws Exception {
227+
// Ensure the task is created and we get an initial authorization response
228+
assertNoAuthorizedEisEndpoints();
229+
230+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EMPTY_AUTH_RESPONSE));
231+
// Abort the task and ensure it is restarted
232+
waitForNewAuthorizationResponse();
76233
}
77234
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
import java.util.Map;
3333
import java.util.Objects;
3434
import java.util.Set;
35+
import java.util.concurrent.CountDownLatch;
36+
import java.util.concurrent.TimeUnit;
3537
import java.util.concurrent.atomic.AtomicBoolean;
3638
import java.util.concurrent.atomic.AtomicReference;
3739
import java.util.stream.Collectors;
@@ -56,6 +58,7 @@ public class AuthorizationPoller extends AllocatedPersistentTask {
5658
private final AtomicBoolean initialized = new AtomicBoolean(false);
5759
private final ElasticInferenceServiceComponents elasticInferenceServiceComponents;
5860
private final Client client;
61+
private final CountDownLatch receivedFirstAuthResponseLatch = new CountDownLatch(1);
5962

6063
public record TaskFields(long id, String type, String action, String description, TaskId parentTask, Map<String, String> headers) {}
6164

@@ -117,9 +120,23 @@ public void start() {
117120
}
118121
}
119122

123+
/**
124+
* This should only be used for testing to wait for the first authorization response to be received.
125+
*/
126+
public void waitForAuthorizationToComplete(TimeValue waitTime) {
127+
try {
128+
if (receivedFirstAuthResponseLatch.await(waitTime.getSeconds(), TimeUnit.SECONDS) == false) {
129+
throw new IllegalStateException("The wait time has expired for first authorization response to be received.");
130+
}
131+
} catch (InterruptedException e) {
132+
throw new IllegalStateException("Waiting for first authorization response to complete was interrupted");
133+
}
134+
}
135+
120136
@Override
121137
protected void onCancelled() {
122138
shutdown();
139+
markAsCompleted();
123140
}
124141

125142
// default for testing
@@ -182,6 +199,7 @@ void sendAuthorizationRequest() {
182199
if (callback != null) {
183200
callback.run();
184201
}
202+
receivedFirstAuthResponseLatch.countDown();
185203
}).delegateResponse((delegate, e) -> {
186204
logger.atWarn().withThrowable(e).log("Failed processing EIS preconfigured endpoints");
187205
delegate.onResponse(null);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,49 @@ void init() {
7979
}
8080
}
8181

82+
/**
83+
* This method should only be used for testing purposes to simulate a task being recreated.
84+
*/
85+
public void abortTask(TimeValue timeout, ActionListener<Void> listener) {
86+
var task = currentTask.get();
87+
if (task != null && task.isCancelled() == false) {
88+
task.markAsLocallyAborted("testing task cancellation");
89+
currentTask.set(null);
90+
waitForNullTask(task, timeout, listener);
91+
} else {
92+
listener.onFailure(new IllegalStateException("Authorization poller task was not created yet, or was already aborted"));
93+
}
94+
}
95+
96+
private void waitForNullTask(AllocatedPersistentTask task, TimeValue timeout, ActionListener<Void> listener) {
97+
task.waitForPersistentTask(
98+
Objects::isNull,
99+
timeout,
100+
new PersistentTasksService.WaitForPersistentTaskListener<AuthorizationTaskParams>() {
101+
@Override
102+
public void onResponse(PersistentTasksCustomMetadata.PersistentTask<AuthorizationTaskParams> persistentTask) {
103+
listener.onResponse(null);
104+
}
105+
106+
@Override
107+
public void onFailure(Exception e) {
108+
listener.onFailure(e);
109+
}
110+
}
111+
);
112+
}
113+
114+
/**
115+
* This method should only be used for testing purposes to get the current running task.
116+
*/
117+
public AuthorizationPoller getCurrentPollerTask() {
118+
return currentTask.get();
119+
}
120+
82121
@Override
83122
protected void nodeOperation(AllocatedPersistentTask task, AuthorizationTaskParams params, PersistentTaskState state) {
84123
var authPoller = (AuthorizationPoller) task;
124+
currentTask.set(authPoller);
85125
authPoller.start();
86126
}
87127

0 commit comments

Comments
 (0)