Skip to content

Commit d5b43b1

Browse files
add AuthHandler, AuthRequest, and AuthResponseEntity tests
1 parent b3d2687 commit d5b43b1

File tree

5 files changed

+207
-0
lines changed

5 files changed

+207
-0
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceAuthorizationResponseEntity.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,4 +163,17 @@ public List<? extends InferenceResults> transformToLegacyFormat() {
163163
public Map<String, Object> asMap() {
164164
throw new UnsupportedOperationException("Not implemented");
165165
}
166+
167+
@Override
168+
public boolean equals(Object o) {
169+
if (this == o) return true;
170+
if (o == null || getClass() != o.getClass()) return false;
171+
ElasticInferenceServiceAuthorizationResponseEntity that = (ElasticInferenceServiceAuthorizationResponseEntity) o;
172+
return Objects.equals(authorizedModels, that.authorizedModels);
173+
}
174+
175+
@Override
176+
public int hashCode() {
177+
return Objects.hash(authorizedModels);
178+
}
166179
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.elasticsearch.action.ActionListener;
1414
import org.elasticsearch.common.Strings;
1515
import org.elasticsearch.core.Nullable;
16+
import org.elasticsearch.core.TimeValue;
1617
import org.elasticsearch.inference.InferenceServiceResults;
1718
import org.elasticsearch.tasks.Task;
1819
import org.elasticsearch.threadpool.ThreadPool;
@@ -25,6 +26,8 @@
2526

2627
import java.util.Locale;
2728
import java.util.Objects;
29+
import java.util.concurrent.CountDownLatch;
30+
import java.util.concurrent.TimeUnit;
2831

2932
import static org.elasticsearch.xpack.core.inference.action.InferenceAction.Request.DEFAULT_TIMEOUT;
3033
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER;
@@ -48,6 +51,7 @@ private static ResponseHandler createAuthResponseHandler() {
4851
private final String baseUrl;
4952
private final ThreadPool threadPool;
5053
private final Logger logger;
54+
private final CountDownLatch requestCompleteLatch = new CountDownLatch(1);
5155

5256
public ElasticInferenceServiceAuthorizationHandler(@Nullable String baseUrl, ThreadPool threadPool) {
5357
this.baseUrl = baseUrl;
@@ -92,6 +96,7 @@ public void getAuthorization(ActionListener<ElasticInferenceServiceAuthorization
9296
);
9397
listener.onResponse(ElasticInferenceServiceAuthorization.newDisabledService());
9498
}
99+
requestCompleteLatch.countDown();
95100
}, e -> {
96101
Throwable exception = e;
97102
if (e instanceof ElasticsearchWrapperException wrapperException) {
@@ -100,6 +105,7 @@ public void getAuthorization(ActionListener<ElasticInferenceServiceAuthorization
100105

101106
logger.warn(Strings.format(FAILED_TO_RETRIEVE_MESSAGE + " Encountered an exception: %s", exception));
102107
listener.onResponse(ElasticInferenceServiceAuthorization.newDisabledService());
108+
requestCompleteLatch.countDown();
103109
});
104110

105111
var request = new ElasticInferenceServiceAuthorizationRequest(baseUrl, getCurrentTraceInfo());
@@ -117,4 +123,14 @@ private TraceContext getCurrentTraceInfo() {
117123

118124
return new TraceContext(traceParent, traceState);
119125
}
126+
127+
void waitForAuthRequestCompletion(TimeValue timeValue) throws IllegalStateException {
128+
try {
129+
if (requestCompleteLatch.await(timeValue.getMillis(), TimeUnit.MILLISECONDS) == false) {
130+
throw new IllegalStateException("The authorization request did not complete.");
131+
}
132+
} catch (IllegalStateException | InterruptedException e) {
133+
logger.warn("Interrupted while waiting for the authorization request to complete", e);
134+
}
135+
}
120136
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.request.elastic;
9+
10+
import org.elasticsearch.ElasticsearchStatusException;
11+
import org.elasticsearch.rest.RestStatus;
12+
import org.elasticsearch.test.ESTestCase;
13+
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
14+
import org.junit.Before;
15+
16+
import static org.hamcrest.Matchers.containsString;
17+
import static org.hamcrest.Matchers.is;
18+
19+
public class ElasticInferenceServiceAuthorizationRequestTests extends ESTestCase {
20+
21+
private TraceContext traceContext;
22+
23+
@Before
24+
public void init() {
25+
traceContext = new TraceContext("dummyTraceParent", "dummyTraceState");
26+
}
27+
28+
public void testCreateUriThrowsForInvalidBaseUrl() {
29+
String invalidUrl = "http://invalid-url^";
30+
31+
ElasticsearchStatusException exception = assertThrows(
32+
ElasticsearchStatusException.class,
33+
() -> new ElasticInferenceServiceAuthorizationRequest(invalidUrl, traceContext)
34+
);
35+
36+
assertThat(exception.status(), is(RestStatus.BAD_REQUEST));
37+
assertThat(exception.getMessage(), containsString("Failed to create URI for service"));
38+
}
39+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.response.elastic;
9+
10+
import org.elasticsearch.inference.TaskType;
11+
import org.elasticsearch.test.ESTestCase;
12+
import org.elasticsearch.xcontent.json.JsonXContent;
13+
14+
import java.io.IOException;
15+
import java.util.EnumSet;
16+
import java.util.List;
17+
18+
import static org.hamcrest.Matchers.is;
19+
20+
21+
public class ElasticInferenceServiceAuthorizationResponseEntityTests extends ESTestCase {
22+
23+
public void testParseAllFields() throws IOException {
24+
String json = """
25+
{
26+
"models": [
27+
{
28+
"model_name": "test_model",
29+
"task_types": ["embedding/text/sparse", "chat/completion"]
30+
}
31+
]
32+
}
33+
""";
34+
35+
try (var parser = createParser(JsonXContent.jsonXContent, json)) {
36+
var entity = ElasticInferenceServiceAuthorizationResponseEntity.PARSER.apply(parser, null);
37+
var expected = new ElasticInferenceServiceAuthorizationResponseEntity(
38+
List.of(
39+
new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel(
40+
"test_model",
41+
EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION)
42+
)
43+
)
44+
);
45+
46+
assertThat(entity, is(expected));
47+
}
48+
}
49+
50+
51+
public void testParsing_EmptyModels() throws IOException {
52+
String json = """
53+
{
54+
"models": []
55+
}
56+
""";
57+
58+
try (var parser = createParser(JsonXContent.jsonXContent, json)) {
59+
var entity = ElasticInferenceServiceAuthorizationResponseEntity.PARSER.apply(parser, null);
60+
var expected = new ElasticInferenceServiceAuthorizationResponseEntity(List.of());
61+
62+
assertThat(entity, is(expected));
63+
}
64+
}
65+
66+
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.inference.services.elastic.authorization;
99

1010
import org.apache.logging.log4j.Logger;
11+
import org.elasticsearch.action.ActionListener;
1112
import org.elasticsearch.action.support.PlainActionFuture;
1213
import org.elasticsearch.common.settings.Settings;
1314
import org.elasticsearch.core.TimeValue;
@@ -32,6 +33,7 @@
3233
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
3334
import static org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender.MAX_RETIES;
3435
import static org.hamcrest.Matchers.is;
36+
import static org.mockito.ArgumentMatchers.any;
3537
import static org.mockito.Mockito.mock;
3638
import static org.mockito.Mockito.times;
3739
import static org.mockito.Mockito.verify;
@@ -183,4 +185,75 @@ public void testGetAuthorization_ReturnsAValidResponse() throws IOException {
183185
verifyNoMoreInteractions(logger);
184186
}
185187
}
188+
@SuppressWarnings("unchecked")
189+
public void testGetAuthorization_OnResponseCalledOnce() throws IllegalStateException {
190+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
191+
var eisGatewayUrl = getUrl(webServer);
192+
var logger = mock(Logger.class);
193+
var authHandler = new ElasticInferenceServiceAuthorizationHandler(eisGatewayUrl, threadPool, logger);
194+
195+
ActionListener<ElasticInferenceServiceAuthorization> listener = mock(ActionListener.class);
196+
197+
String responseJson = """
198+
{
199+
"models": [
200+
{
201+
"model_name": "model-a",
202+
"task_types": ["embedding/text/sparse", "chat/completion"]
203+
}
204+
]
205+
}
206+
""";
207+
208+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
209+
210+
authHandler.getAuthorization(listener, senderFactory.createSender());
211+
authHandler.waitForAuthRequestCompletion(TimeValue.timeValueSeconds(1));
212+
213+
verify(listener, times(1)).onResponse(any());
214+
verifyNoMoreInteractions(logger);
215+
}
216+
217+
public void testGetAuthorization_InvalidResponse() throws IOException {
218+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
219+
var eisGatewayUrl = getUrl(webServer);
220+
var logger = mock(Logger.class);
221+
var authHandler = new ElasticInferenceServiceAuthorizationHandler(eisGatewayUrl, threadPool, logger);
222+
223+
try (var sender = senderFactory.createSender()) {
224+
String responseJson = """
225+
{
226+
"completion": [
227+
{
228+
"result": "some result 1"
229+
},
230+
{
231+
"result": "some result 2"
232+
}
233+
]
234+
}
235+
""";
236+
237+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
238+
239+
PlainActionFuture<ElasticInferenceServiceAuthorization> listener = new PlainActionFuture<>();
240+
authHandler.getAuthorization(listener, sender);
241+
242+
var authResponse = listener.actionGet(TIMEOUT);
243+
assertTrue(authResponse.enabledTaskTypes().isEmpty());
244+
assertFalse(authResponse.isEnabled());
245+
246+
var loggerArgsCaptor = ArgumentCaptor.forClass(String.class);
247+
verify(logger).warn(loggerArgsCaptor.capture());
248+
var message = loggerArgsCaptor.getValue();
249+
assertThat(
250+
message,
251+
is(
252+
"Failed to retrieve the authorization information from the Elastic Inference Service."
253+
+ " Received an invalid response type: InferenceServiceResults"
254+
));
255+
}
256+
257+
}
258+
186259
}

0 commit comments

Comments
 (0)