Skip to content

Commit 0054891

Browse files
Add to inference service and crud IT rerank tests
1 parent 2ea07f0 commit 0054891

File tree

3 files changed

+95
-1
lines changed

3 files changed

+95
-1
lines changed

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,20 @@ static String mockDenseServiceModelConfig() {
171171
""";
172172
}
173173

174+
static String mockRerankServiceModelConfig() {
175+
return """
176+
{
177+
"service": "test_reranking_service",
178+
"service_settings": {
179+
"model_id": "my_model",
180+
"api_key": "abc64"
181+
},
182+
"task_settings": {
183+
}
184+
}
185+
""";
186+
}
187+
174188
static void deleteModel(String modelId) throws IOException {
175189
var request = new Request("DELETE", "_inference/" + modelId);
176190
var response = client().performRequest(request);

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,12 @@ public void testCRUD() throws IOException {
5353
for (int i = 0; i < 4; i++) {
5454
putModel("te_model_" + i, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING);
5555
}
56+
for (int i = 0; i < 3; i++) {
57+
putModel("re-model-" + i, mockRerankServiceModelConfig(), TaskType.RERANK);
58+
}
5659

5760
var getAllModels = getAllModels();
58-
int numModels = 12;
61+
int numModels = 15;
5962
assertThat(getAllModels, hasSize(numModels));
6063

6164
var getSparseModels = getModels("_all", TaskType.SPARSE_EMBEDDING);
@@ -71,6 +74,13 @@ public void testCRUD() throws IOException {
7174
for (var denseModel : getDenseModels) {
7275
assertEquals("text_embedding", denseModel.get("task_type"));
7376
}
77+
78+
var getRerankModels = getModels("_all", TaskType.RERANK);
79+
int numRerankModels = 4;
80+
assertThat(getRerankModels, hasSize(numRerankModels));
81+
for (var denseModel : getRerankModels) {
82+
assertEquals("rerank", denseModel.get("task_type"));
83+
}
7484
String oldApiKey;
7585
{
7686
var singleModel = getModels("se_model_1", TaskType.SPARSE_EMBEDDING);
@@ -100,6 +110,9 @@ public void testCRUD() throws IOException {
100110
for (int i = 0; i < 4; i++) {
101111
deleteModel("te_model_" + i, TaskType.TEXT_EMBEDDING);
102112
}
113+
for (int i = 0; i < 3; i++) {
114+
deleteModel("re-model-" + i, TaskType.RERANK);
115+
}
103116
}
104117

105118
public void testGetModelWithWrongTaskType() throws IOException {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference;
9+
10+
import org.elasticsearch.inference.TaskType;
11+
12+
import java.io.IOException;
13+
import java.util.List;
14+
import java.util.Map;
15+
16+
public class MockRerankInferenceServiceIT extends InferenceBaseRestTest {
17+
18+
@SuppressWarnings("unchecked")
19+
public void testMockService() throws IOException {
20+
String inferenceEntityId = "test-mock";
21+
var putModel = putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK);
22+
var model = getModels(inferenceEntityId, TaskType.RERANK).get(0);
23+
24+
for (var modelMap : List.of(putModel, model)) {
25+
assertEquals(inferenceEntityId, modelMap.get("inference_id"));
26+
assertEquals(TaskType.RERANK, TaskType.fromString((String) modelMap.get("task_type")));
27+
assertEquals("test_reranking_service", modelMap.get("service"));
28+
}
29+
30+
List<String> input = List.of(randomAlphaOfLength(10));
31+
var inference = infer(inferenceEntityId, input);
32+
assertNonEmptyInferenceResults(inference, 1, TaskType.RERANK);
33+
assertEquals(inference, infer(inferenceEntityId, input));
34+
assertNotEquals(inference, infer(inferenceEntityId, randomValueOtherThan(input, () -> List.of(randomAlphaOfLength(10)))));
35+
}
36+
37+
public void testMockServiceWithMultipleInputs() throws IOException {
38+
String inferenceEntityId = "test-mock-with-multi-inputs";
39+
putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK);
40+
var queryParams = Map.of("timeout", "120s");
41+
42+
var inference = infer(
43+
inferenceEntityId,
44+
TaskType.RERANK,
45+
List.of(randomAlphaOfLength(5), randomAlphaOfLength(10)),
46+
"What if?",
47+
queryParams
48+
);
49+
50+
assertNonEmptyInferenceResults(inference, 2, TaskType.RERANK);
51+
}
52+
53+
@SuppressWarnings("unchecked")
54+
public void testMockService_DoesNotReturnSecretsInGetResponse() throws IOException {
55+
String inferenceEntityId = "test-mock";
56+
var putModel = putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK);
57+
var model = getModels(inferenceEntityId, TaskType.RERANK).get(0);
58+
59+
var serviceSettings = (Map<String, Object>) model.get("service_settings");
60+
assertNull(serviceSettings.get("api_key"));
61+
assertNotNull(serviceSettings.get("model_id"));
62+
63+
var putServiceSettings = (Map<String, Object>) putModel.get("service_settings");
64+
assertNull(putServiceSettings.get("api_key"));
65+
assertNotNull(putServiceSettings.get("model_id"));
66+
}
67+
}

0 commit comments

Comments
 (0)