Skip to content

Commit b9e20b6

Browse files
More integration tests
1 parent 7b760b8 commit b9e20b6

File tree

1 file changed

+101
-44
lines changed

1 file changed

+101
-44
lines changed

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java

Lines changed: 101 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,13 @@ public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
5050
private ModelRegistry modelRegistry;
5151
private final MockWebServer webServer = new MockWebServer();
5252
private ThreadPool threadPool;
53+
private String gatewayUrl;
5354

5455
/**
 * Per-test setup: creates the thread pool, starts the mock web server, records the
 * server's URL in {@code gatewayUrl}, and builds a {@code ModelRegistry} from the node client.
 *
 * @throws Exception declared by the method; which call may throw is not visible from this block
 */
@Before
5556
public void createComponents() throws Exception {
5657
threadPool = createThreadPool(inferenceUtilityPool());
5758
webServer.start();
59+
// Resolved after the server is started; createElasticInferenceService() reads this field.
gatewayUrl = getUrl(webServer);
5860
modelRegistry = new ModelRegistry(client());
5961
}
6062

@@ -64,16 +66,17 @@ public void shutdown() {
6466
webServer.close();
6567
}
6668

69+
/**
 * Always returns {@code true}.
 *
 * NOTE(review): presumably this asks the single-node test base class to reset the node after
 * each test so that endpoints persisted by one test do not leak into the next — confirm
 * against ESSingleNodeTestCase.
 */
@Override
70+
protected boolean resetNodeAfterTest() {
71+
return true;
72+
}
73+
6774
/**
 * Returns the plugin classes for this test: only {@code ReindexPlugin}, wrapped with
 * {@code pluginList}.
 *
 * NOTE(review): why ReindexPlugin is required is not visible in this block — confirm.
 */
@Override
6875
protected Collection<Class<? extends Plugin>> getPlugins() {
6976
return pluginList(ReindexPlugin.class);
7077
}
7178

7279
public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCorrect() throws Exception {
73-
var clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
74-
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
75-
var gatewayUrl = getUrl(webServer);
76-
7780
String responseJson = """
7881
{
7982
"models": [
@@ -87,15 +90,7 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo
8790

8891
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
8992

90-
try (
91-
var service = new ElasticInferenceService(
92-
senderFactory,
93-
createWithEmptySettings(threadPool),
94-
new ElasticInferenceServiceComponents(gatewayUrl),
95-
modelRegistry,
96-
new ElasticInferenceServiceAuthorizationHandler(gatewayUrl, threadPool)
97-
)
98-
) {
93+
try (var service = createElasticInferenceService()) {
9994
service.waitForAuthorizationToComplete(TIMEOUT);
10095
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
10196
assertThat(
@@ -115,17 +110,7 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo
115110
}
116111

117112
public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty() throws Exception {
118-
var gatewayUrl = getUrl(webServer);
119-
120113
{
121-
var clientManager = HttpClientManager.create(
122-
Settings.EMPTY,
123-
threadPool,
124-
mockClusterServiceEmpty(),
125-
mock(ThrottlerManager.class)
126-
);
127-
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
128-
129114
String responseJson = """
130115
{
131116
"models": [
@@ -139,15 +124,7 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty()
139124

140125
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
141126

142-
try (
143-
var service = new ElasticInferenceService(
144-
senderFactory,
145-
createWithEmptySettings(threadPool),
146-
new ElasticInferenceServiceComponents(gatewayUrl),
147-
modelRegistry,
148-
new ElasticInferenceServiceAuthorizationHandler(gatewayUrl, threadPool)
149-
)
150-
) {
127+
try (var service = createElasticInferenceService()) {
151128
service.waitForAuthorizationToComplete(TIMEOUT);
152129
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
153130
assertThat(
@@ -186,18 +163,7 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty()
186163

187164
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson));
188165

189-
var httpManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
190-
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, httpManager);
191-
192-
try (
193-
var service = new ElasticInferenceService(
194-
senderFactory,
195-
createWithEmptySettings(threadPool),
196-
new ElasticInferenceServiceComponents(gatewayUrl),
197-
modelRegistry,
198-
new ElasticInferenceServiceAuthorizationHandler(gatewayUrl, threadPool)
199-
)
200-
) {
166+
try (var service = createElasticInferenceService()) {
201167
service.waitForAuthorizationToComplete(TIMEOUT);
202168
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
203169
assertTrue(service.defaultConfigIds().isEmpty());
@@ -211,4 +177,95 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty()
211177
}
212178
}
213179
}
180+
181+
/**
 * Two-phase scenario against the mock gateway.
 *
 * Phase 1: the authorization response lists both "rainbow-sprinkles" (chat) and "elser-v2"
 * (sparse embedding). The test expects ".rainbow-sprinkles-elastic" to be the only default
 * config id and to be retrievable from the model registry.
 *
 * Phase 2: a fresh service receives an authorization response that lists only "elser-v2".
 * The test expects no streaming tasks, no default config ids, and a
 * ResourceNotFoundException when looking up ".rainbow-sprinkles-elastic" in the registry.
 */
public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnAuthForIt() throws Exception {
182+
// Phase 1: both models are authorized.
{
183+
String responseJson = """
184+
{
185+
"models": [
186+
{
187+
"model_name": "rainbow-sprinkles",
188+
"task_types": ["chat"]
189+
},
190+
{
191+
"model_name": "elser-v2",
192+
"task_types": ["embed/text/sparse"]
193+
}
194+
]
195+
}
196+
""";
197+
198+
// Queue the response before the service is created so it is available for the authorization request.
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
199+
200+
try (var service = createElasticInferenceService()) {
201+
service.waitForAuthorizationToComplete(TIMEOUT);
202+
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
203+
// Although elser-v2 is authorized too, only the chat completion endpoint is expected as a default config id.
assertThat(
204+
service.defaultConfigIds(),
205+
is(
206+
List.of(
207+
new InferenceService.DefaultConfigId(
208+
".rainbow-sprinkles-elastic",
209+
MinimalServiceSettings.chatCompletion(),
210+
service
211+
)
212+
)
213+
)
214+
);
215+
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING)));
216+
217+
PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
218+
service.defaultConfigs(listener);
219+
assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
220+
221+
var getModelListener = new PlainActionFuture<UnparsedModel>();
222+
// persists the default endpoints
223+
modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener);
224+
225+
var inferenceEntity = getModelListener.actionGet(TIMEOUT);
226+
assertThat(inferenceEntity.inferenceEntityId(), is(".rainbow-sprinkles-elastic"));
227+
assertThat(inferenceEntity.taskType(), is(TaskType.CHAT_COMPLETION));
228+
}
229+
}
230+
// Phase 2: the authorization response no longer lists rainbow-sprinkles.
{
231+
String noAuthorizationResponseJson = """
232+
{
233+
"models": [
234+
{
235+
"model_name": "elser-v2",
236+
"task_types": ["embed/text/sparse"]
237+
}
238+
]
239+
}
240+
""";
241+
242+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson));
243+
244+
// A new service instance is created; it shares the same modelRegistry field as phase 1.
try (var service = createElasticInferenceService()) {
245+
service.waitForAuthorizationToComplete(TIMEOUT);
246+
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
247+
assertTrue(service.defaultConfigIds().isEmpty());
248+
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING)));
249+
250+
var getModelListener = new PlainActionFuture<UnparsedModel>();
251+
modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener);
252+
253+
// The endpoint that phase 1 retrieved from the registry must no longer be found.
var exception = expectThrows(ResourceNotFoundException.class, () -> getModelListener.actionGet(TIMEOUT));
254+
assertThat(exception.getMessage(), is("Inference endpoint not found [.rainbow-sprinkles-elastic]"));
255+
}
256+
}
257+
}
258+
259+
/**
 * Builds an {@code ElasticInferenceService} wired to the mock gateway: a new HTTP client
 * manager and sender factory on the test thread pool, components and an authorization
 * handler that both point at {@code gatewayUrl}, and the shared {@code modelRegistry}.
 *
 * Callers in this file use the result in try-with-resources.
 *
 * NOTE(review): the HttpClientManager created here is not closed by this method — confirm
 * whether closing the returned service releases it.
 *
 * @return a new service instance; each call creates its own HTTP client manager
 */
private ElasticInferenceService createElasticInferenceService() {
260+
var httpManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
261+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, httpManager);
262+
263+
return new ElasticInferenceService(
264+
senderFactory,
265+
createWithEmptySettings(threadPool),
266+
new ElasticInferenceServiceComponents(gatewayUrl),
267+
modelRegistry,
268+
new ElasticInferenceServiceAuthorizationHandler(gatewayUrl, threadPool)
269+
);
270+
}
214271
}

0 commit comments

Comments
 (0)