2020
2121import static org .elasticsearch .xpack .inference .InferenceBaseRestTest .assertStatusOkOrCreated ;
2222import static org .hamcrest .Matchers .containsInAnyOrder ;
23- import static org .hamcrest .Matchers .equalTo ;
2423
2524public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
2625
@@ -31,13 +30,8 @@ public static void init() {
3130 }
3231
3332 public void testGetServicesWithoutTaskType () throws IOException {
34- List <Object > services = getAllServices ();
35- assertThat (services .size (), equalTo (24 ));
36-
37- var providers = providers (services );
38-
3933 assertThat (
40- providers ,
34+ allProviders () ,
4135 containsInAnyOrder (
4236 List .of (
4337 "alibabacloud-ai-search" ,
@@ -69,6 +63,10 @@ public void testGetServicesWithoutTaskType() throws IOException {
6963 );
7064 }
7165
66+ private Iterable <String > allProviders () throws IOException {
67+ return providers (getAllServices ());
68+ }
69+
7270 @ SuppressWarnings ("unchecked" )
7371 private Iterable <String > providers (List <Object > services ) {
7472 return services .stream ().map (service -> {
@@ -78,13 +76,8 @@ private Iterable<String> providers(List<Object> services) {
7876 }
7977
8078 public void testGetServicesWithTextEmbeddingTaskType () throws IOException {
81- List <Object > services = getServices (TaskType .TEXT_EMBEDDING );
82- assertThat (services .size (), equalTo (17 ));
83-
84- var providers = providers (services );
85-
8679 assertThat (
87- providers ,
80+ providersFor ( TaskType . TEXT_EMBEDDING ) ,
8881 containsInAnyOrder (
8982 List .of (
9083 "alibabacloud-ai-search" ,
@@ -109,14 +102,13 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
109102 );
110103 }
111104
112- public void testGetServicesWithRerankTaskType () throws IOException {
113- List <Object > services = getServices (TaskType .RERANK );
114- assertThat (services .size (), equalTo (9 ));
115-
116- var providers = providers (services );
105+ private Iterable <String > providersFor (TaskType taskType ) throws IOException {
106+ return providers (getServices (taskType ));
107+ }
117108
109+ public void testGetServicesWithRerankTaskType () throws IOException {
118110 assertThat (
119- providers ,
111+ providersFor ( TaskType . RERANK ) ,
120112 containsInAnyOrder (
121113 List .of (
122114 "alibabacloud-ai-search" ,
@@ -127,20 +119,16 @@ public void testGetServicesWithRerankTaskType() throws IOException {
127119 "jinaai" ,
128120 "test_reranking_service" ,
129121 "voyageai" ,
130- "hugging_face"
122+ "hugging_face" ,
123+ "amazon_sagemaker"
131124 ).toArray ()
132125 )
133126 );
134127 }
135128
136129 public void testGetServicesWithCompletionTaskType () throws IOException {
137- List <Object > services = getServices (TaskType .COMPLETION );
138- assertThat (services .size (), equalTo (14 ));
139-
140- var providers = providers (services );
141-
142130 assertThat (
143- providers ,
131+ providersFor ( TaskType . COMPLETION ) ,
144132 containsInAnyOrder (
145133 List .of (
146134 "alibabacloud-ai-search" ,
@@ -164,13 +152,8 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
164152 }
165153
166154 public void testGetServicesWithChatCompletionTaskType () throws IOException {
167- List <Object > services = getServices (TaskType .CHAT_COMPLETION );
168- assertThat (services .size (), equalTo (8 ));
169-
170- var providers = providers (services );
171-
172155 assertThat (
173- providers ,
156+ providersFor ( TaskType . CHAT_COMPLETION ) ,
174157 containsInAnyOrder (
175158 List .of (
176159 "deepseek" ,
@@ -187,13 +170,8 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {
187170 }
188171
189172 public void testGetServicesWithSparseEmbeddingTaskType () throws IOException {
190- List <Object > services = getServices (TaskType .SPARSE_EMBEDDING );
191- assertThat (services .size (), equalTo (7 ));
192-
193- var providers = providers (services );
194-
195173 assertThat (
196- providers ,
174+ providersFor ( TaskType . SPARSE_EMBEDDING ) ,
197175 containsInAnyOrder (
198176 List .of (
199177 "alibabacloud-ai-search" ,
@@ -202,7 +180,8 @@ public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
202180 "elasticsearch" ,
203181 "hugging_face" ,
204182 "streaming_completion_test_service" ,
205- "test_service"
183+ "test_service" ,
184+ "amazon_sagemaker"
206185 ).toArray ()
207186 )
208187 );
0 commit comments