Skip to content

Commit fd7fd59

Browse files
Adding supported streaming tasks tests
1 parent 4f7b964 commit fd7fd59

File tree

3 files changed

+76
-4
lines changed

3 files changed

+76
-4
lines changed

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/RetryRule.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,13 @@
2020

2121
/**
2222
* Provides a way to retry a failed test. To use this functionality add something like the following to your test class:
23-
* </br>
23+
* <p>
2424
* <code>
2525
* {@literal @}Rule
26-
* </br>
26+
* <p>
2727
* public RetryRule retry = new RetryRule(3, TimeValue.timeValueSeconds(1));
2828
* </code>
29-
* </p>
29+
* <p>
3030
* See {@link InferenceGetServicesIT#retry} for an example.
3131
*/
3232
public class RetryRule implements TestRule {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@
5757
import java.util.Map;
5858
import java.util.Objects;
5959
import java.util.Set;
60+
import java.util.concurrent.CountDownLatch;
61+
import java.util.concurrent.TimeUnit;
6062

6163
import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException;
6264
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
@@ -87,6 +89,7 @@ public class ElasticInferenceService extends SenderService {
8789
private EnumSet<TaskType> enabledTaskTypes;
8890
private final ModelRegistry modelRegistry;
8991
private final ElasticInferenceServiceAuthorizationHandler authorizationHandler;
92+
private final CountDownLatch authorizationCompletedLatch = new CountDownLatch(1);
9093

9194
public ElasticInferenceService(
9295
HttpRequestSender.Factory factory,
@@ -108,8 +111,12 @@ public ElasticInferenceService(
108111

109112
private void getAuthorization() {
110113
try {
111-
ActionListener<ElasticInferenceServiceAuthorization> listener = ActionListener.wrap(this::setEnabledTaskTypes, e -> {
114+
ActionListener<ElasticInferenceServiceAuthorization> listener = ActionListener.wrap(result -> {
115+
setEnabledTaskTypes(result);
116+
authorizationCompletedLatch.countDown();
117+
}, e -> {
112118
// we don't need to do anything if there was a failure, everything is disabled by default
119+
authorizationCompletedLatch.countDown();
113120
});
114121

115122
authorizationHandler.getAuthorization(listener, getSender());
@@ -131,6 +138,17 @@ private static EnumSet<TaskType> filterTaskTypesByAuthorization(ElasticInference
131138
return implementedTaskTypes;
132139
}
133140

141+
// Default for testing
142+
void waitForAuthorizationToComplete(TimeValue waitTime) {
143+
try {
144+
if (authorizationCompletedLatch.await(waitTime.getSeconds(), TimeUnit.SECONDS) == false) {
145+
throw new IllegalStateException("The wait time has expired for authorization to complete.");
146+
}
147+
} catch (InterruptedException e) {
148+
throw new IllegalStateException("Waiting for authorization to complete was interrupted");
149+
}
150+
}
151+
134152
@Override
135153
public synchronized Set<TaskType> supportedStreamingTasks() {
136154
var enabledStreamingTaskTypes = EnumSet.of(TaskType.CHAT_COMPLETION);

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,6 +739,50 @@ public void testGetConfiguration_WithoutSupportedTaskTypes_WhenModelsReturnTaskO
739739
}
740740
}
741741

742+
public void testSupportedStreamingTasks_ReturnsChatCompletion_WhenAuthRespondsWithAValidModel() throws Exception {
743+
String responseJson = """
744+
{
745+
"models": [
746+
{
747+
"model_name": "model-a",
748+
"task_types": ["embed/text/sparse", "chat"]
749+
}
750+
]
751+
}
752+
""";
753+
754+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
755+
756+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
757+
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
758+
service.waitForAuthorizationToComplete(TIMEOUT);
759+
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
760+
assertTrue(service.defaultConfigIds().isEmpty());
761+
}
762+
}
763+
764+
public void testSupportedStreamingTasks_ReturnsEmpty_WhenAuthRespondsWithoutChatCompletion() throws Exception {
765+
String responseJson = """
766+
{
767+
"models": [
768+
{
769+
"model_name": "model-a",
770+
"task_types": ["embed/text/sparse"]
771+
}
772+
]
773+
}
774+
""";
775+
776+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
777+
778+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
779+
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
780+
service.waitForAuthorizationToComplete(TIMEOUT);
781+
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
782+
assertTrue(service.defaultConfigIds().isEmpty());
783+
}
784+
}
785+
742786
private ElasticInferenceService createServiceWithMockSender() {
743787
return createServiceWithMockSender(ElasticInferenceServiceAuthorizationTests.createEnabledAuth());
744788
}
@@ -788,4 +832,14 @@ private ElasticInferenceService createService(
788832
mockAuthHandler
789833
);
790834
}
835+
836+
private ElasticInferenceService createServiceWithAuthHandler(HttpRequestSender.Factory senderFactory, String eisGatewayUrl) {
837+
return new ElasticInferenceService(
838+
senderFactory,
839+
createWithEmptySettings(threadPool),
840+
new ElasticInferenceServiceComponents(eisGatewayUrl),
841+
mockModelRegistry(),
842+
new ElasticInferenceServiceAuthorizationHandler(eisGatewayUrl, threadPool)
843+
);
844+
}
791845
}

0 commit comments

Comments
 (0)