Skip to content

Commit 11e8aad

Browse files
committed
Integ test
1 parent d5ebe19 commit 11e8aad

File tree

15 files changed

+122
-40
lines changed

15 files changed

+122
-40
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetRerankerAction.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ public Response(StreamInput in) throws IOException {
8080
this.windowSize = in.readVInt();
8181
}
8282

83+
public int getWindowSize() {
84+
return windowSize;
85+
}
86+
8387
@Override
8488
public void writeTo(StreamOutput out) throws IOException {
8589
out.writeVInt(windowSize);

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.elasticsearch.inference.Model;
2626
import org.elasticsearch.inference.ModelConfigurations;
2727
import org.elasticsearch.inference.ModelSecrets;
28+
import org.elasticsearch.inference.RerankingInferenceService;
2829
import org.elasticsearch.inference.ServiceSettings;
2930
import org.elasticsearch.inference.SettingsConfiguration;
3031
import org.elasticsearch.inference.TaskSettings;
@@ -62,7 +63,7 @@ public TestRerankingModel(String inferenceEntityId, TestServiceSettings serviceS
6263
}
6364
}
6465

65-
public static class TestInferenceService extends AbstractTestInferenceService {
66+
public static class TestInferenceService extends AbstractTestInferenceService implements RerankingInferenceService {
6667
public static final String NAME = "test_reranking_service";
6768

6869
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.RERANK);
@@ -191,6 +192,11 @@ protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceS
191192
return TestServiceSettings.fromMap(serviceSettingsMap);
192193
}
193194

195+
@Override
196+
public int rerankerWindowSize(String modelId) {
197+
return 333;
198+
}
199+
194200
public static class Configuration {
195201
public static InferenceServiceConfiguration get() {
196202
return configuration.getOrCompute();

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ public static Iterable<Object[]> parameters() {
6161
@Before
6262
public void setup() throws Exception {
6363
ModelRegistry modelRegistry = internalCluster().getCurrentMasterNodeInstance(ModelRegistry.class);
64-
Utils.storeSparseModel(modelRegistry);
64+
Utils.storeSparseModel("sparse-endpoint", modelRegistry);
6565
Utils.storeDenseModel(
6666
modelRegistry,
6767
randomIntBetween(1, 100),

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ public void setup() throws Exception {
9090
() -> randomFrom(DenseVectorFieldMapperTestUtils.getSupportedSimilarities(elementType))
9191
);
9292
int dimensions = DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, 100);
93-
Utils.storeSparseModel(modelRegistry);
93+
Utils.storeSparseModel("sparse-endpoint", modelRegistry);
9494
Utils.storeDenseModel(modelRegistry, dimensions, similarity, elementType);
9595
}
9696

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.integration;
9+
10+
import org.elasticsearch.ElasticsearchStatusException;
11+
import org.elasticsearch.plugins.Plugin;
12+
import org.elasticsearch.test.ESIntegTestCase;
13+
import org.elasticsearch.xpack.core.inference.action.GetRerankerAction;
14+
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
15+
import org.elasticsearch.xpack.inference.Utils;
16+
import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin;
17+
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
18+
import org.junit.Before;
19+
20+
import java.util.Collection;
21+
import java.util.List;
22+
23+
import static org.hamcrest.Matchers.containsString;
24+
25+
public class RerankWindowSizeIT extends ESIntegTestCase {
26+
27+
@Before
28+
public void setup() throws Exception {
29+
ModelRegistry modelRegistry = internalCluster().getCurrentMasterNodeInstance(ModelRegistry.class);
30+
Utils.storeRerankModel("rerank-endpoint", modelRegistry);
31+
Utils.storeSparseModel("sparse-endpoint", modelRegistry);
32+
}
33+
34+
@Override
35+
protected Collection<Class<? extends Plugin>> nodePlugins() {
36+
return List.of(LocalStateInferencePlugin.class, TestInferenceServicePlugin.class);
37+
}
38+
39+
public void testRerankWindowSizeAction() {
40+
var response = client().execute(GetRerankerAction.INSTANCE, new GetRerankerAction.Request("rerank-endpoint")).actionGet();
41+
assertEquals(333, response.getWindowSize());
42+
}
43+
44+
public void testActionNotARerankder() {
45+
var e = expectThrows(
46+
ElasticsearchStatusException.class,
47+
() -> client().execute(GetRerankerAction.INSTANCE, new GetRerankerAction.Request("sparse-endpoint")).actionGet()
48+
);
49+
assertThat(e.getMessage(), containsString("Inference endpoint [sparse-endpoint] is not a reranker"));
50+
}
51+
}

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexVersionIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ public void setup() throws Exception {
6868
() -> randomFrom(DenseVectorFieldMapperTestUtils.getSupportedSimilarities(elementType))
6969
);
7070
int dimensions = DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, 100);
71-
Utils.storeSparseModel(modelRegistry);
71+
Utils.storeSparseModel("sparse-endpoint", modelRegistry);
7272
Utils.storeDenseModel(modelRegistry, dimensions, similarity, elementType);
7373

7474
Set<IndexVersion> availableVersions = IndexVersionUtils.allReleasedVersions()

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetRerankerAction.java

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
import org.elasticsearch.action.support.HandledTransportAction;
1414
import org.elasticsearch.action.support.SubscribableListener;
1515
import org.elasticsearch.common.util.concurrent.EsExecutors;
16-
import org.elasticsearch.inference.InferenceService;
1716
import org.elasticsearch.inference.InferenceServiceRegistry;
17+
import org.elasticsearch.inference.RerankingInferenceService;
1818
import org.elasticsearch.inference.TaskType;
1919
import org.elasticsearch.inference.UnparsedModel;
2020
import org.elasticsearch.injection.guice.Inject;
@@ -51,26 +51,46 @@ public TransportGetRerankerAction(
5151
@Override
5252
protected void doExecute(Task task, GetRerankerAction.Request request, ActionListener<GetRerankerAction.Response> listener) {
5353

54-
SubscribableListener.<UnparsedModel>newForked(l -> modelRegistry.getModel(request.getInferenceEntityId(), l))
55-
.andThen((l2, model) -> {
56-
if (model.taskType() != TaskType.RERANK) {
57-
l2.onFailure(
58-
new ElasticsearchStatusException(
59-
"Inference endpoint [{}] is not a reranker",
60-
RestStatus.BAD_REQUEST,
61-
request.getInferenceEntityId()
62-
)
54+
SubscribableListener.<UnparsedModel>newForked(l -> modelRegistry.getModel(request.getInferenceEntityId(), l)).<
55+
GetRerankerAction.Response>andThen((l, unparsedModel) -> {
56+
if (unparsedModel.taskType() != TaskType.RERANK) {
57+
throw new ElasticsearchStatusException(
58+
"Inference endpoint [{}] is not a reranker",
59+
RestStatus.BAD_REQUEST,
60+
request.getInferenceEntityId()
6361
);
64-
return;
6562
}
6663

67-
var service = serviceRegistry.getService(model.service());
68-
l2.onResponse(new GetRerankerAction.Response(rerankWindowSize(service.get())));
69-
});
70-
}
64+
var service = serviceRegistry.getService(unparsedModel.service());
65+
if (service.isEmpty()) {
66+
throw new ElasticsearchStatusException(
67+
"Unknown service [{}] for inference endpoint [{}]",
68+
RestStatus.BAD_REQUEST,
69+
unparsedModel.service(),
70+
request.getInferenceEntityId()
71+
);
72+
}
73+
74+
var model = service.get()
75+
.parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
7176

72-
public int rerankWindowSize(InferenceService service) {
73-
return 0;
77+
if (service.get() instanceof RerankingInferenceService rerankingInferenceService) {
78+
l.onResponse(
79+
new GetRerankerAction.Response(rerankWindowSize(rerankingInferenceService, model.getServiceSettings().modelId()))
80+
);
81+
} else {
82+
throw new IllegalStateException(
83+
"Inference endpoint ["
84+
+ request.getInferenceEntityId()
85+
+ "] is a reranker but the service ["
86+
+ service.get().name()
87+
+ "] does not support reranking"
88+
);
89+
}
90+
}).addListener(listener);
7491
}
7592

93+
public int rerankWindowSize(RerankingInferenceService service, String modelId) {
94+
return service.rerankerWindowSize(modelId);
95+
}
7696
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/LocalStateInferencePlugin.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.elasticsearch.xpack.core.LocalStateCompositeXPackPlugin;
1818
import org.elasticsearch.xpack.core.ssl.SSLService;
1919
import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
20+
import org.elasticsearch.xpack.inference.mock.TestRerankingServiceExtension;
2021
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
2122

2223
import java.nio.file.Path;
@@ -47,7 +48,8 @@ protected XPackLicenseState getLicenseState() {
4748
public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
4849
return List.of(
4950
TestSparseInferenceServiceExtension.TestInferenceService::new,
50-
TestDenseInferenceServiceExtension.TestInferenceService::new
51+
TestDenseInferenceServiceExtension.TestInferenceService::new,
52+
TestRerankingServiceExtension.TestInferenceService::new
5153
);
5254
}
5355
};

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.elasticsearch.xcontent.XContentType;
2828
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
2929
import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
30+
import org.elasticsearch.xpack.inference.mock.TestRerankingServiceExtension;
3031
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
3132
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
3233
import org.hamcrest.Matchers;
@@ -81,9 +82,9 @@ public static ScalingExecutorBuilder inferenceUtilityPool() {
8182
);
8283
}
8384

84-
public static void storeSparseModel(ModelRegistry modelRegistry) throws Exception {
85+
public static void storeSparseModel(String inferenceId, ModelRegistry modelRegistry) throws Exception {
8586
Model model = new TestSparseInferenceServiceExtension.TestSparseModel(
86-
TestSparseInferenceServiceExtension.TestInferenceService.NAME,
87+
inferenceId,
8788
new TestSparseInferenceServiceExtension.TestServiceSettings("sparse_model", null, false)
8889
);
8990
storeModel(modelRegistry, model);
@@ -102,6 +103,14 @@ public static void storeDenseModel(
102103
storeModel(modelRegistry, model);
103104
}
104105

106+
public static void storeRerankModel(String inferenceId, ModelRegistry modelRegistry) throws Exception {
107+
Model model = new TestRerankingServiceExtension.TestRerankingModel(
108+
inferenceId,
109+
new TestRerankingServiceExtension.TestServiceSettings("rerank-model")
110+
);
111+
storeModel(modelRegistry, model);
112+
}
113+
105114
public static void storeModel(ModelRegistry modelRegistry, Model model) throws Exception {
106115
PlainActionFuture<Boolean> listener = new PlainActionFuture<>();
107116
modelRegistry.storeModel(model, listener, AcknowledgedRequest.DEFAULT_ACK_TIMEOUT);

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextNonDynamicFieldMapperTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ public class SemanticTextNonDynamicFieldMapperTests extends NonDynamicFieldMappe
2626
@Before
2727
public void setup() throws Exception {
2828
ModelRegistry modelRegistry = node().injector().getInstance(ModelRegistry.class);
29-
Utils.storeSparseModel(modelRegistry);
29+
Utils.storeSparseModel("sparse-endpoint", modelRegistry);
3030
}
3131

3232
@Override

0 commit comments

Comments
 (0)