Skip to content

Commit bd9aa11

Browse files
committed
Fix task types
1 parent 1aea84c commit bd9aa11

File tree

3 files changed

+13
-13
lines changed

3 files changed

+13
-13
lines changed

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
123123

124124
public void testGetServicesWithCompletionTaskType() throws IOException {
125125
List<Object> services = getServices(TaskType.COMPLETION);
126-
assertThat(services.size(), equalTo(10));
126+
assertThat(services.size(), equalTo(11));
127127

128128
var providers = providers(services);
129129

@@ -140,19 +140,23 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
140140
"deepseek",
141141
"googleaistudio",
142142
"openai",
143-
"streaming_completion_test_service"
143+
"streaming_completion_test_service",
144+
"sagemaker"
144145
).toArray()
145146
)
146147
);
147148
}
148149

149150
public void testGetServicesWithChatCompletionTaskType() throws IOException {
150151
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
151-
assertThat(services.size(), equalTo(4));
152+
assertThat(services.size(), equalTo(5));
152153

153154
var providers = providers(services);
154155

155-
assertThat(providers, containsInAnyOrder(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray()));
156+
assertThat(
157+
providers,
158+
containsInAnyOrder(List.of("deepseek", "elastic", "openai", "sagemaker", "streaming_completion_test_service").toArray())
159+
);
156160
}
157161

158162
public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemas.java

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import java.util.List;
2222
import java.util.Map;
2323
import java.util.Set;
24-
import java.util.function.Predicate;
2524
import java.util.stream.Collectors;
2625
import java.util.stream.Stream;
2726

@@ -57,13 +56,7 @@ public class SageMakerSchemas {
5756
.collect(Collectors.groupingBy(TaskAndApi::api, Collectors.mapping(TaskAndApi::taskType, Collectors.toSet())));
5857

5958
supportedStreamingTasks = streamSchemas.keySet().stream().map(TaskAndApi::taskType).collect(Collectors.toSet());
60-
supportedTaskTypes = EnumSet.copyOf(
61-
schemas.keySet()
62-
.stream()
63-
.map(TaskAndApi::taskType)
64-
.filter(Predicate.not(TaskType.CHAT_COMPLETION::equals)) // chat_completion is currently never supported for non-streaming
65-
.collect(Collectors.toSet())
66-
);
59+
supportedTaskTypes = EnumSet.copyOf(schemas.keySet().stream().map(TaskAndApi::taskType).collect(Collectors.toSet()));
6760
}
6861

6962
private static Map<TaskAndApi, SageMakerSchema> register(SageMakerSchemaPayload... payloads) {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemasTests.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,10 @@ public static SageMakerSchema mockSchema() {
4141
private static final SageMakerSchemas schemas = new SageMakerSchemas();
4242

4343
public void testSupportedTaskTypes() {
44-
assertThat(schemas.supportedTaskTypes(), containsInAnyOrder(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION));
44+
assertThat(
45+
schemas.supportedTaskTypes(),
46+
containsInAnyOrder(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION, TaskType.CHAT_COMPLETION)
47+
);
4548
}
4649

4750
public void testSupportedStreamingTasks() {

0 commit comments

Comments
 (0)