Skip to content

Commit 41c1208

Browse files
Adding tests
1 parent d5b43b1 commit 41c1208

File tree

2 files changed

+64
-59
lines changed

2 files changed

+64
-59
lines changed

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceAuthorizationResponseEntityTests.java

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,19 @@
1717

1818
import static org.hamcrest.Matchers.is;
1919

20-
2120
public class ElasticInferenceServiceAuthorizationResponseEntityTests extends ESTestCase {
2221

2322
public void testParseAllFields() throws IOException {
2423
String json = """
25-
{
26-
"models": [
27-
{
28-
"model_name": "test_model",
29-
"task_types": ["embedding/text/sparse", "chat/completion"]
30-
}
31-
]
32-
}
33-
""";
24+
{
25+
"models": [
26+
{
27+
"model_name": "test_model",
28+
"task_types": ["embed/text/sparse", "chat"]
29+
}
30+
]
31+
}
32+
""";
3433

3534
try (var parser = createParser(JsonXContent.jsonXContent, json)) {
3635
var entity = ElasticInferenceServiceAuthorizationResponseEntity.PARSER.apply(parser, null);
@@ -47,13 +46,12 @@ public void testParseAllFields() throws IOException {
4746
}
4847
}
4948

50-
5149
public void testParsing_EmptyModels() throws IOException {
5250
String json = """
53-
{
54-
"models": []
55-
}
56-
""";
51+
{
52+
"models": []
53+
}
54+
""";
5755

5856
try (var parser = createParser(JsonXContent.jsonXContent, json)) {
5957
var entity = ElasticInferenceServiceAuthorizationResponseEntity.PARSER.apply(parser, null);

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java

Lines changed: 51 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,25 @@
1212
import org.elasticsearch.action.support.PlainActionFuture;
1313
import org.elasticsearch.common.settings.Settings;
1414
import org.elasticsearch.core.TimeValue;
15+
import org.elasticsearch.inference.InferenceServiceResults;
1516
import org.elasticsearch.inference.TaskType;
1617
import org.elasticsearch.test.ESTestCase;
1718
import org.elasticsearch.test.http.MockResponse;
1819
import org.elasticsearch.test.http.MockWebServer;
1920
import org.elasticsearch.threadpool.ThreadPool;
21+
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
2022
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
23+
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
2124
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
25+
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
2226
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
2327
import org.junit.After;
2428
import org.junit.Before;
2529
import org.mockito.ArgumentCaptor;
2630

2731
import java.io.IOException;
2832
import java.util.EnumSet;
33+
import java.util.List;
2934
import java.util.concurrent.TimeUnit;
3035

3136
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
@@ -34,10 +39,12 @@
3439
import static org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender.MAX_RETIES;
3540
import static org.hamcrest.Matchers.is;
3641
import static org.mockito.ArgumentMatchers.any;
42+
import static org.mockito.Mockito.doAnswer;
3743
import static org.mockito.Mockito.mock;
3844
import static org.mockito.Mockito.times;
3945
import static org.mockito.Mockito.verify;
4046
import static org.mockito.Mockito.verifyNoMoreInteractions;
47+
import static org.mockito.Mockito.when;
4148

4249
public class ElasticInferenceServiceAuthorizationHandlerTests extends ESTestCase {
4350
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
@@ -185,6 +192,7 @@ public void testGetAuthorization_ReturnsAValidResponse() throws IOException {
185192
verifyNoMoreInteractions(logger);
186193
}
187194
}
195+
188196
@SuppressWarnings("unchecked")
189197
public void testGetAuthorization_OnResponseCalledOnce() throws IllegalStateException {
190198
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
@@ -195,64 +203,63 @@ public void testGetAuthorization_OnResponseCalledOnce() throws IllegalStateExcep
195203
ActionListener<ElasticInferenceServiceAuthorization> listener = mock(ActionListener.class);
196204

197205
String responseJson = """
198-
{
199-
"models": [
200-
{
201-
"model_name": "model-a",
202-
"task_types": ["embedding/text/sparse", "chat/completion"]
203-
}
204-
]
205-
}
206-
""";
206+
{
207+
"models": [
208+
{
209+
"model_name": "model-a",
210+
"task_types": ["embed/text/sparse", "chat"]
211+
}
212+
]
213+
}
214+
""";
207215

208216
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
209217

210218
authHandler.getAuthorization(listener, senderFactory.createSender());
211-
authHandler.waitForAuthRequestCompletion(TimeValue.timeValueSeconds(1));
219+
authHandler.waitForAuthRequestCompletion(TIMEOUT);
212220

213221
verify(listener, times(1)).onResponse(any());
222+
var loggerArgsCaptor = ArgumentCaptor.forClass(String.class);
223+
verify(logger, times(1)).debug(loggerArgsCaptor.capture());
224+
225+
var message = loggerArgsCaptor.getValue();
226+
assertThat(message, is("Retrieving authorization information from the Elastic Inference Service."));
214227
verifyNoMoreInteractions(logger);
215228
}
216229

217-
public void testGetAuthorization_InvalidResponse() throws IOException {
218-
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
219-
var eisGatewayUrl = getUrl(webServer);
220-
var logger = mock(Logger.class);
221-
var authHandler = new ElasticInferenceServiceAuthorizationHandler(eisGatewayUrl, threadPool, logger);
230+
public void testGetAuthorization_InvalidResponse() throws IOException {
231+
var senderMock = mock(Sender.class);
232+
var senderFactory = mock(HttpRequestSender.Factory.class);
233+
when(senderFactory.createSender()).thenReturn(senderMock);
222234

223-
try (var sender = senderFactory.createSender()) {
224-
String responseJson = """
225-
{
226-
"completion": [
227-
{
228-
"result": "some result 1"
229-
},
230-
{
231-
"result": "some result 2"
232-
}
233-
]
234-
}
235-
""";
235+
doAnswer(invocationOnMock -> {
236+
ActionListener<InferenceServiceResults> listener = invocationOnMock.getArgument(4);
237+
listener.onResponse(new ChatCompletionResults(List.of(new ChatCompletionResults.Result("awesome"))));
238+
return Void.TYPE;
239+
}).when(senderMock).sendWithoutQueuing(any(), any(), any(), any(), any());
236240

237-
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
241+
var logger = mock(Logger.class);
242+
var authHandler = new ElasticInferenceServiceAuthorizationHandler("abc", threadPool, logger);
243+
244+
try (var sender = senderFactory.createSender()) {
245+
PlainActionFuture<ElasticInferenceServiceAuthorization> listener = new PlainActionFuture<>();
238246

239-
PlainActionFuture<ElasticInferenceServiceAuthorization> listener = new PlainActionFuture<>();
240-
authHandler.getAuthorization(listener, sender);
247+
authHandler.getAuthorization(listener, sender);
248+
var result = listener.actionGet(TIMEOUT);
241249

242-
var authResponse = listener.actionGet(TIMEOUT);
243-
assertTrue(authResponse.enabledTaskTypes().isEmpty());
244-
assertFalse(authResponse.isEnabled());
250+
assertThat(result, is(ElasticInferenceServiceAuthorization.newDisabledService()));
245251

246-
var loggerArgsCaptor = ArgumentCaptor.forClass(String.class);
247-
verify(logger).warn(loggerArgsCaptor.capture());
248-
var message = loggerArgsCaptor.getValue();
249-
assertThat(
250-
message,
251-
is(
252-
"Failed to retrieve the authorization information from the Elastic Inference Service."
253-
+ " Received an invalid response type: InferenceServiceResults"
254-
));
255-
}
252+
var loggerArgsCaptor = ArgumentCaptor.forClass(String.class);
253+
verify(logger).warn(loggerArgsCaptor.capture());
254+
var message = loggerArgsCaptor.getValue();
255+
assertThat(
256+
message,
257+
is(
258+
"Failed to retrieve the authorization information from the Elastic Inference Service."
259+
+ " Received an invalid response type: ChatCompletionResults"
260+
)
261+
);
262+
}
256263

257264
}
258265

0 commit comments

Comments
 (0)