Skip to content

Commit 58ec62d

Browse files
Adding more tests and a mock gateway
1 parent 574eaf9 commit 58ec62d

File tree

15 files changed

+607
-223
lines changed

15 files changed

+607
-223
lines changed

x-pack/plugin/inference/qa/inference-service-tests/build.gradle

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ dependencies {
44
javaRestTestImplementation project(path: xpackModule('core'))
55
javaRestTestImplementation project(path: xpackModule('inference'))
66
clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin')
7+
// Added this to have access to MockWebServer within the tests
8+
javaRestTestImplementation(testArtifact(project(xpackModule('core'))))
79
}
810

911
tasks.named("javaRestTest").configure {

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
import static org.hamcrest.Matchers.hasSize;
4343

4444
public class InferenceBaseRestTest extends ESRestTestCase {
45-
4645
@ClassRule
4746
public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
4847
.distribution(DistributionType.DEFAULT)
@@ -330,31 +329,13 @@ protected List<Map<String, Object>> getAllModels() throws IOException {
330329
return (List<Map<String, Object>>) getInternalAsMap("_inference/_all").get("endpoints");
331330
}
332331

333-
protected List<Object> getAllServices() throws IOException {
334-
var endpoint = Strings.format("_inference/_services");
335-
return getInternalAsList(endpoint);
336-
}
337-
338-
@SuppressWarnings("unchecked")
339-
protected List<Object> getServices(TaskType taskType) throws IOException {
340-
var endpoint = Strings.format("_inference/_services/%s", taskType);
341-
return getInternalAsList(endpoint);
342-
}
343-
344332
private Map<String, Object> getInternalAsMap(String endpoint) throws IOException {
345333
var request = new Request("GET", endpoint);
346334
var response = client().performRequest(request);
347335
assertOkOrCreated(response);
348336
return entityAsMap(response);
349337
}
350338

351-
private List<Object> getInternalAsList(String endpoint) throws IOException {
352-
var request = new Request("GET", endpoint);
353-
var response = client().performRequest(request);
354-
assertOkOrCreated(response);
355-
return entityAsList(response);
356-
}
357-
358339
protected Map<String, Object> infer(String modelId, List<String> input) throws IOException {
359340
var endpoint = Strings.format("_inference/%s", modelId);
360341
return inferInternal(endpoint, input, null, Map.of());
@@ -511,7 +492,7 @@ protected void assertNonEmptyInferenceResults(Map<String, Object> resultMap, int
511492
}
512493
}
513494

514-
protected static void assertOkOrCreated(Response response) throws IOException {
495+
public static void assertOkOrCreated(Response response) throws IOException {
515496
int statusCode = response.getStatusLine().getStatusCode();
516497
// Once EntityUtils.toString(entity) is called the entity cannot be reused.
517498
// Avoid that call with check here.

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

Lines changed: 0 additions & 176 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,8 @@
1818
import org.elasticsearch.xcontent.XContentBuilder;
1919
import org.elasticsearch.xcontent.XContentFactory;
2020
import org.elasticsearch.xcontent.XContentType;
21-
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceFeature;
2221

2322
import java.io.IOException;
24-
import java.util.ArrayList;
25-
import java.util.Arrays;
2623
import java.util.Iterator;
2724
import java.util.List;
2825
import java.util.Locale;
@@ -145,179 +142,6 @@ public void testApisWithoutTaskType() throws IOException {
145142
deleteModel(modelId);
146143
}
147144

148-
@SuppressWarnings("unchecked")
149-
public void testGetServicesWithoutTaskType() throws IOException {
150-
List<Object> services = getAllServices();
151-
if ((ElasticInferenceServiceFeature.DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()
152-
|| ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled())) {
153-
assertThat(services.size(), equalTo(19));
154-
} else {
155-
assertThat(services.size(), equalTo(18));
156-
}
157-
158-
String[] providers = new String[services.size()];
159-
for (int i = 0; i < services.size(); i++) {
160-
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
161-
providers[i] = (String) serviceConfig.get("service");
162-
}
163-
164-
var providerList = new ArrayList<>(
165-
Arrays.asList(
166-
"alibabacloud-ai-search",
167-
"amazonbedrock",
168-
"anthropic",
169-
"azureaistudio",
170-
"azureopenai",
171-
"cohere",
172-
"elasticsearch",
173-
"googleaistudio",
174-
"googlevertexai",
175-
"hugging_face",
176-
"jinaai",
177-
"mistral",
178-
"openai",
179-
"streaming_completion_test_service",
180-
"test_reranking_service",
181-
"test_service",
182-
"text_embedding_test_service",
183-
"watsonxai"
184-
)
185-
);
186-
if ((ElasticInferenceServiceFeature.DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()
187-
|| ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled())) {
188-
providerList.add(6, "elastic");
189-
}
190-
assertArrayEquals(providerList.toArray(), providers);
191-
}
192-
193-
@SuppressWarnings("unchecked")
194-
public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
195-
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
196-
assertThat(services.size(), equalTo(14));
197-
198-
String[] providers = new String[services.size()];
199-
for (int i = 0; i < services.size(); i++) {
200-
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
201-
providers[i] = (String) serviceConfig.get("service");
202-
}
203-
204-
assertArrayEquals(
205-
List.of(
206-
"alibabacloud-ai-search",
207-
"amazonbedrock",
208-
"azureaistudio",
209-
"azureopenai",
210-
"cohere",
211-
"elasticsearch",
212-
"googleaistudio",
213-
"googlevertexai",
214-
"hugging_face",
215-
"jinaai",
216-
"mistral",
217-
"openai",
218-
"text_embedding_test_service",
219-
"watsonxai"
220-
).toArray(),
221-
providers
222-
);
223-
}
224-
225-
@SuppressWarnings("unchecked")
226-
public void testGetServicesWithRerankTaskType() throws IOException {
227-
List<Object> services = getServices(TaskType.RERANK);
228-
assertThat(services.size(), equalTo(6));
229-
230-
String[] providers = new String[services.size()];
231-
for (int i = 0; i < services.size(); i++) {
232-
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
233-
providers[i] = (String) serviceConfig.get("service");
234-
}
235-
236-
assertArrayEquals(
237-
List.of("alibabacloud-ai-search", "cohere", "elasticsearch", "googlevertexai", "jinaai", "test_reranking_service").toArray(),
238-
providers
239-
);
240-
}
241-
242-
@SuppressWarnings("unchecked")
243-
public void testGetServicesWithCompletionTaskType() throws IOException {
244-
List<Object> services = getServices(TaskType.COMPLETION);
245-
assertThat(services.size(), equalTo(9));
246-
247-
String[] providers = new String[services.size()];
248-
for (int i = 0; i < services.size(); i++) {
249-
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
250-
providers[i] = (String) serviceConfig.get("service");
251-
}
252-
253-
var providerList = new ArrayList<>(
254-
List.of(
255-
"alibabacloud-ai-search",
256-
"amazonbedrock",
257-
"anthropic",
258-
"azureaistudio",
259-
"azureopenai",
260-
"cohere",
261-
"googleaistudio",
262-
"openai",
263-
"streaming_completion_test_service"
264-
)
265-
);
266-
267-
assertArrayEquals(providers, providerList.toArray());
268-
}
269-
270-
@SuppressWarnings("unchecked")
271-
public void testGetServicesWithChatCompletionTaskType() throws IOException {
272-
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
273-
if ((ElasticInferenceServiceFeature.DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()
274-
|| ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled())) {
275-
assertThat(services.size(), equalTo(3));
276-
} else {
277-
assertThat(services.size(), equalTo(2));
278-
}
279-
280-
String[] providers = new String[services.size()];
281-
for (int i = 0; i < services.size(); i++) {
282-
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
283-
providers[i] = (String) serviceConfig.get("service");
284-
}
285-
286-
var providerList = new ArrayList<>(List.of("openai", "streaming_completion_test_service"));
287-
288-
if ((ElasticInferenceServiceFeature.DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()
289-
|| ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled())) {
290-
providerList.addFirst("elastic");
291-
}
292-
293-
assertArrayEquals(providers, providerList.toArray());
294-
}
295-
296-
@SuppressWarnings("unchecked")
297-
public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
298-
List<Object> services = getServices(TaskType.SPARSE_EMBEDDING);
299-
300-
if ((ElasticInferenceServiceFeature.DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()
301-
|| ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled())) {
302-
assertThat(services.size(), equalTo(5));
303-
} else {
304-
assertThat(services.size(), equalTo(4));
305-
}
306-
307-
String[] providers = new String[services.size()];
308-
for (int i = 0; i < services.size(); i++) {
309-
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
310-
providers[i] = (String) serviceConfig.get("service");
311-
}
312-
313-
var providerList = new ArrayList<>(Arrays.asList("alibabacloud-ai-search", "elasticsearch", "hugging_face", "test_service"));
314-
if ((ElasticInferenceServiceFeature.DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()
315-
|| ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled())) {
316-
providerList.add(1, "elastic");
317-
}
318-
assertArrayEquals(providers, providerList.toArray());
319-
}
320-
321145
public void testSkipValidationAndStart() throws IOException {
322146
String openAiConfigWithBadApiKey = """
323147
{

0 commit comments

Comments
 (0)