Skip to content

Commit fa1909c

Browse files
committed
Add sagemaker to IT
1 parent b4b1f3f commit fa1909c

File tree

1 file changed

+18
-39
lines changed

1 file changed

+18
-39
lines changed

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 18 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.assertStatusOkOrCreated;
2222
import static org.hamcrest.Matchers.containsInAnyOrder;
23-
import static org.hamcrest.Matchers.equalTo;
2423

2524
public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
2625

@@ -31,13 +30,8 @@ public static void init() {
3130
}
3231

3332
public void testGetServicesWithoutTaskType() throws IOException {
34-
List<Object> services = getAllServices();
35-
assertThat(services.size(), equalTo(24));
36-
37-
var providers = providers(services);
38-
3933
assertThat(
40-
providers,
34+
allProviders(),
4135
containsInAnyOrder(
4236
List.of(
4337
"alibabacloud-ai-search",
@@ -69,6 +63,10 @@ public void testGetServicesWithoutTaskType() throws IOException {
6963
);
7064
}
7165

66+
private Iterable<String> allProviders() throws IOException {
67+
return providers(getAllServices());
68+
}
69+
7270
@SuppressWarnings("unchecked")
7371
private Iterable<String> providers(List<Object> services) {
7472
return services.stream().map(service -> {
@@ -78,13 +76,8 @@ private Iterable<String> providers(List<Object> services) {
7876
}
7977

8078
public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
81-
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
82-
assertThat(services.size(), equalTo(17));
83-
84-
var providers = providers(services);
85-
8679
assertThat(
87-
providers,
80+
providersFor(TaskType.TEXT_EMBEDDING),
8881
containsInAnyOrder(
8982
List.of(
9083
"alibabacloud-ai-search",
@@ -109,14 +102,13 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
109102
);
110103
}
111104

112-
public void testGetServicesWithRerankTaskType() throws IOException {
113-
List<Object> services = getServices(TaskType.RERANK);
114-
assertThat(services.size(), equalTo(9));
115-
116-
var providers = providers(services);
105+
private Iterable<String> providersFor(TaskType taskType) throws IOException {
106+
return providers(getServices(taskType));
107+
}
117108

109+
public void testGetServicesWithRerankTaskType() throws IOException {
118110
assertThat(
119-
providers,
111+
providersFor(TaskType.RERANK),
120112
containsInAnyOrder(
121113
List.of(
122114
"alibabacloud-ai-search",
@@ -127,20 +119,16 @@ public void testGetServicesWithRerankTaskType() throws IOException {
127119
"jinaai",
128120
"test_reranking_service",
129121
"voyageai",
130-
"hugging_face"
122+
"hugging_face",
123+
"amazon_sagemaker"
131124
).toArray()
132125
)
133126
);
134127
}
135128

136129
public void testGetServicesWithCompletionTaskType() throws IOException {
137-
List<Object> services = getServices(TaskType.COMPLETION);
138-
assertThat(services.size(), equalTo(14));
139-
140-
var providers = providers(services);
141-
142130
assertThat(
143-
providers,
131+
providersFor(TaskType.COMPLETION),
144132
containsInAnyOrder(
145133
List.of(
146134
"alibabacloud-ai-search",
@@ -164,13 +152,8 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
164152
}
165153

166154
public void testGetServicesWithChatCompletionTaskType() throws IOException {
167-
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
168-
assertThat(services.size(), equalTo(8));
169-
170-
var providers = providers(services);
171-
172155
assertThat(
173-
providers,
156+
providersFor(TaskType.CHAT_COMPLETION),
174157
containsInAnyOrder(
175158
List.of(
176159
"deepseek",
@@ -187,13 +170,8 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {
187170
}
188171

189172
public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
190-
List<Object> services = getServices(TaskType.SPARSE_EMBEDDING);
191-
assertThat(services.size(), equalTo(7));
192-
193-
var providers = providers(services);
194-
195173
assertThat(
196-
providers,
174+
providersFor(TaskType.SPARSE_EMBEDDING),
197175
containsInAnyOrder(
198176
List.of(
199177
"alibabacloud-ai-search",
@@ -202,7 +180,8 @@ public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
202180
"elasticsearch",
203181
"hugging_face",
204182
"streaming_completion_test_service",
205-
"test_service"
183+
"test_service",
184+
"amazon_sagemaker"
206185
).toArray()
207186
)
208187
);

0 commit comments

Comments
 (0)