Skip to content

Commit 7b760b8

Browse files
Adding integration tests
1 parent 25828f5 commit 7b760b8

File tree

4 files changed

+255
-7
lines changed

4 files changed

+255
-7
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.integration;
9+
10+
import org.elasticsearch.ResourceNotFoundException;
11+
import org.elasticsearch.action.support.PlainActionFuture;
12+
import org.elasticsearch.common.settings.Settings;
13+
import org.elasticsearch.core.TimeValue;
14+
import org.elasticsearch.inference.InferenceService;
15+
import org.elasticsearch.inference.MinimalServiceSettings;
16+
import org.elasticsearch.inference.Model;
17+
import org.elasticsearch.inference.TaskType;
18+
import org.elasticsearch.inference.UnparsedModel;
19+
import org.elasticsearch.plugins.Plugin;
20+
import org.elasticsearch.reindex.ReindexPlugin;
21+
import org.elasticsearch.test.ESSingleNodeTestCase;
22+
import org.elasticsearch.test.http.MockResponse;
23+
import org.elasticsearch.test.http.MockWebServer;
24+
import org.elasticsearch.threadpool.ThreadPool;
25+
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
26+
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
27+
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
28+
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
29+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
30+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
31+
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationHandler;
32+
import org.junit.After;
33+
import org.junit.Before;
34+
35+
import java.util.Collection;
36+
import java.util.EnumSet;
37+
import java.util.List;
38+
import java.util.concurrent.TimeUnit;
39+
40+
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
41+
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
42+
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
43+
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
44+
import static org.hamcrest.CoreMatchers.is;
45+
import static org.mockito.Mockito.mock;
46+
47+
public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
48+
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
49+
50+
private ModelRegistry modelRegistry;
51+
private final MockWebServer webServer = new MockWebServer();
52+
private ThreadPool threadPool;
53+
54+
@Before
55+
public void createComponents() throws Exception {
56+
threadPool = createThreadPool(inferenceUtilityPool());
57+
webServer.start();
58+
modelRegistry = new ModelRegistry(client());
59+
}
60+
61+
@After
62+
public void shutdown() {
63+
terminate(threadPool);
64+
webServer.close();
65+
}
66+
67+
@Override
68+
protected Collection<Class<? extends Plugin>> getPlugins() {
69+
return pluginList(ReindexPlugin.class);
70+
}
71+
72+
public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCorrect() throws Exception {
73+
var clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
74+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
75+
var gatewayUrl = getUrl(webServer);
76+
77+
String responseJson = """
78+
{
79+
"models": [
80+
{
81+
"model_name": "rainbow-sprinkles",
82+
"task_types": ["chat"]
83+
}
84+
]
85+
}
86+
""";
87+
88+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
89+
90+
try (
91+
var service = new ElasticInferenceService(
92+
senderFactory,
93+
createWithEmptySettings(threadPool),
94+
new ElasticInferenceServiceComponents(gatewayUrl),
95+
modelRegistry,
96+
new ElasticInferenceServiceAuthorizationHandler(gatewayUrl, threadPool)
97+
)
98+
) {
99+
service.waitForAuthorizationToComplete(TIMEOUT);
100+
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
101+
assertThat(
102+
service.defaultConfigIds(),
103+
is(
104+
List.of(
105+
new InferenceService.DefaultConfigId(".rainbow-sprinkles-elastic", MinimalServiceSettings.chatCompletion(), service)
106+
)
107+
)
108+
);
109+
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
110+
111+
PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
112+
service.defaultConfigs(listener);
113+
assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
114+
}
115+
}
116+
117+
public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty() throws Exception {
118+
var gatewayUrl = getUrl(webServer);
119+
120+
{
121+
var clientManager = HttpClientManager.create(
122+
Settings.EMPTY,
123+
threadPool,
124+
mockClusterServiceEmpty(),
125+
mock(ThrottlerManager.class)
126+
);
127+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
128+
129+
String responseJson = """
130+
{
131+
"models": [
132+
{
133+
"model_name": "rainbow-sprinkles",
134+
"task_types": ["chat"]
135+
}
136+
]
137+
}
138+
""";
139+
140+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
141+
142+
try (
143+
var service = new ElasticInferenceService(
144+
senderFactory,
145+
createWithEmptySettings(threadPool),
146+
new ElasticInferenceServiceComponents(gatewayUrl),
147+
modelRegistry,
148+
new ElasticInferenceServiceAuthorizationHandler(gatewayUrl, threadPool)
149+
)
150+
) {
151+
service.waitForAuthorizationToComplete(TIMEOUT);
152+
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
153+
assertThat(
154+
service.defaultConfigIds(),
155+
is(
156+
List.of(
157+
new InferenceService.DefaultConfigId(
158+
".rainbow-sprinkles-elastic",
159+
MinimalServiceSettings.chatCompletion(),
160+
service
161+
)
162+
)
163+
)
164+
);
165+
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
166+
167+
PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
168+
service.defaultConfigs(listener);
169+
assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
170+
171+
var getModelListener = new PlainActionFuture<UnparsedModel>();
172+
// persists the default endpoints
173+
modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener);
174+
175+
var inferenceEntity = getModelListener.actionGet(TIMEOUT);
176+
assertThat(inferenceEntity.inferenceEntityId(), is(".rainbow-sprinkles-elastic"));
177+
assertThat(inferenceEntity.taskType(), is(TaskType.CHAT_COMPLETION));
178+
}
179+
}
180+
{
181+
String noAuthorizationResponseJson = """
182+
{
183+
"models": []
184+
}
185+
""";
186+
187+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson));
188+
189+
var httpManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
190+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, httpManager);
191+
192+
try (
193+
var service = new ElasticInferenceService(
194+
senderFactory,
195+
createWithEmptySettings(threadPool),
196+
new ElasticInferenceServiceComponents(gatewayUrl),
197+
modelRegistry,
198+
new ElasticInferenceServiceAuthorizationHandler(gatewayUrl, threadPool)
199+
)
200+
) {
201+
service.waitForAuthorizationToComplete(TIMEOUT);
202+
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
203+
assertTrue(service.defaultConfigIds().isEmpty());
204+
assertThat(service.supportedTaskTypes(), is(EnumSet.noneOf(TaskType.class)));
205+
206+
var getModelListener = new PlainActionFuture<UnparsedModel>();
207+
modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener);
208+
209+
var exception = expectThrows(ResourceNotFoundException.class, () -> getModelListener.actionGet(TIMEOUT));
210+
assertThat(exception.getMessage(), is("Inference endpoint not found [.rainbow-sprinkles-elastic]"));
211+
}
212+
}
213+
}
214+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
import java.util.concurrent.CountDownLatch;
6969
import java.util.concurrent.TimeUnit;
7070
import java.util.concurrent.atomic.AtomicReference;
71+
import java.util.stream.Collectors;
7172

7273
import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException;
7374
import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;
@@ -159,10 +160,7 @@ static AuthorizedContent empty() {
159160

160161
private void getAuthorization() {
161162
try {
162-
ActionListener<ElasticInferenceServiceAuthorization> listener = ActionListener.wrap(result -> {
163-
setAuthorizedContent(result);
164-
authorizationCompletedLatch.countDown();
165-
}, e -> {
163+
ActionListener<ElasticInferenceServiceAuthorization> listener = ActionListener.wrap(this::setAuthorizedContent, e -> {
166164
// we don't need to do anything if there was a failure, everything is disabled by default
167165
authorizationCompletedLatch.countDown();
168166
});
@@ -240,21 +238,34 @@ private void handleRevokedDefaultConfigs(Set<String> authorizedDefaultModelIds)
240238
var unauthorizedDefaultModelIds = new HashSet<>(defaultModelsConfigs.keySet());
241239
unauthorizedDefaultModelIds.removeAll(authorizedDefaultModelIds);
242240

241+
// get all the default inference endpoint ids for the unauthorized model ids
242+
var unauthorizedDefaultInferenceEndpointIds = unauthorizedDefaultModelIds.stream()
243+
.map(defaultModelsConfigs::get) // get all the model configs
244+
.filter(Objects::nonNull) // limit to only non-null
245+
.map(modelConfig -> modelConfig.model.getInferenceEntityId()) // get the inference ids
246+
.collect(Collectors.toSet());
247+
243248
var deleteInferenceEndpointsListener = ActionListener.<Boolean>wrap(result -> {
244249
logger.trace(Strings.format("Successfully revoked access to default inference endpoint IDs: %s", unauthorizedDefaultModelIds));
250+
authorizationCompletedLatch.countDown();
245251
}, e -> {
246252
logger.warn(
247253
Strings.format("Failed to revoke access to default inference endpoint IDs: %s, error: %s", unauthorizedDefaultModelIds, e)
248254
);
255+
authorizationCompletedLatch.countDown();
249256
});
250257

251258
getServiceComponents().threadPool()
252259
.executor(UTILITY_THREAD_POOL_NAME)
253-
.execute(() -> modelRegistry.removeDefaultConfigs(unauthorizedDefaultModelIds, deleteInferenceEndpointsListener));
260+
.execute(() -> modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener));
254261
}
255262

256-
// Default for testing
257-
void waitForAuthorizationToComplete(TimeValue waitTime) {
263+
/**
264+
* Waits the specified amount of time for the authorization call to complete. This is mainly to make testing easier.
265+
* @param waitTime the max time to wait
266+
* @throws IllegalStateException if the wait time is exceeded or the call receives an {@link InterruptedException}
267+
*/
268+
public void waitForAuthorizationToComplete(TimeValue waitTime) {
258269
try {
259270
if (authorizationCompletedLatch.await(waitTime.getSeconds(), TimeUnit.SECONDS) == false) {
260271
throw new IllegalStateException("The wait time has expired for authorization to complete.");

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import java.nio.ByteBuffer;
4242
import java.util.ArrayList;
4343
import java.util.Map;
44+
import java.util.Set;
4445
import java.util.concurrent.TimeUnit;
4546

4647
import static org.elasticsearch.core.Strings.format;
@@ -52,6 +53,8 @@
5253
import static org.mockito.ArgumentMatchers.any;
5354
import static org.mockito.Mockito.doAnswer;
5455
import static org.mockito.Mockito.mock;
56+
import static org.mockito.Mockito.times;
57+
import static org.mockito.Mockito.verify;
5558
import static org.mockito.Mockito.when;
5659

5760
public class ModelRegistryTests extends ESTestCase {
@@ -295,6 +298,18 @@ public void testStoreModel_ThrowsException_WhenFailureIsNotAVersionConflict() {
295298
);
296299
}
297300

301+
public void testRemoveDefaultConfigs_DoesNotCallClient_WhenPassedAnEmptySet() {
302+
var client = mock(Client.class);
303+
304+
var registry = new ModelRegistry(client);
305+
var listener = new PlainActionFuture<Boolean>();
306+
307+
registry.removeDefaultConfigs(Set.of(), listener);
308+
309+
assertTrue(listener.actionGet(TIMEOUT));
310+
verify(client, times(0)).execute(any(), any(), any());
311+
}
312+
298313
public void testIdMatchedDefault() {
299314
var defaultConfigIds = new ArrayList<InferenceService.DefaultConfigId>();
300315
defaultConfigIds.add(

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,14 @@ public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException
356356
private ModelRegistry mockModelRegistry() {
357357
var client = mock(Client.class);
358358
when(client.threadPool()).thenReturn(threadPool);
359+
360+
doAnswer(invocationOnMock -> {
361+
@SuppressWarnings("unchecked")
362+
var listener = (ActionListener<Boolean>) invocationOnMock.getArgument(2);
363+
listener.onResponse(true);
364+
365+
return Void.TYPE;
366+
}).when(client).execute(any(), any(), any());
359367
return new ModelRegistry(client);
360368
}
361369

0 commit comments

Comments
 (0)