Skip to content

Commit 0b70308

Browse files
authored
[ML] Add internal action to return the Rerank window size (#132169)
1 parent ab4bd14 commit 0b70308

File tree

47 files changed

+808
-83
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+808
-83
lines changed
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.inference;
11+
12+
public interface RerankingInferenceService {
13+
14+
/**
15+
* The default window size for small reranking models (512 input tokens).
16+
*/
17+
int CONSERVATIVE_DEFAULT_WINDOW_SIZE = 300;
18+
19+
/**
20+
* The reranking model's max window or an approximation of
21+
* measured in the number of words.
22+
* @param modelId The model ID
23+
* @return Window size in words
24+
*/
25+
int rerankerWindowSize(String modelId);
26+
}
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.inference.action;
9+
10+
import org.elasticsearch.action.ActionRequest;
11+
import org.elasticsearch.action.ActionRequestValidationException;
12+
import org.elasticsearch.action.ActionResponse;
13+
import org.elasticsearch.action.ActionType;
14+
import org.elasticsearch.common.io.stream.StreamInput;
15+
import org.elasticsearch.common.io.stream.StreamOutput;
16+
17+
import java.io.IOException;
18+
import java.util.Objects;
19+
20+
public class GetRerankerWindowSizeAction extends ActionType<GetRerankerWindowSizeAction.Response> {
21+
22+
public static final GetRerankerWindowSizeAction INSTANCE = new GetRerankerWindowSizeAction();
23+
public static final String NAME = "cluster:internal/xpack/inference/rerankwindowsize/get";
24+
25+
public GetRerankerWindowSizeAction() {
26+
super(NAME);
27+
}
28+
29+
public static class Request extends ActionRequest {
30+
31+
private final String inferenceEntityId;
32+
33+
public Request(String inferenceEntityId) {
34+
this.inferenceEntityId = inferenceEntityId;
35+
}
36+
37+
public Request(StreamInput in) throws IOException {
38+
super(in);
39+
this.inferenceEntityId = in.readString();
40+
}
41+
42+
public String getInferenceEntityId() {
43+
return inferenceEntityId;
44+
}
45+
46+
@Override
47+
public void writeTo(StreamOutput out) throws IOException {
48+
super.writeTo(out);
49+
out.writeString(inferenceEntityId);
50+
}
51+
52+
@Override
53+
public ActionRequestValidationException validate() {
54+
return null;
55+
}
56+
57+
@Override
58+
public boolean equals(Object o) {
59+
if (o == null || getClass() != o.getClass()) return false;
60+
Request request = (Request) o;
61+
return Objects.equals(inferenceEntityId, request.inferenceEntityId);
62+
}
63+
64+
@Override
65+
public int hashCode() {
66+
return Objects.hashCode(inferenceEntityId);
67+
}
68+
}
69+
70+
public static class Response extends ActionResponse {
71+
72+
private final int windowSize;
73+
74+
public Response(int windowSize) {
75+
this.windowSize = windowSize;
76+
}
77+
78+
public Response(StreamInput in) throws IOException {
79+
this.windowSize = in.readVInt();
80+
}
81+
82+
public int getWindowSize() {
83+
return windowSize;
84+
}
85+
86+
@Override
87+
public void writeTo(StreamOutput out) throws IOException {
88+
out.writeVInt(windowSize);
89+
}
90+
91+
@Override
92+
public boolean equals(Object o) {
93+
if (o == null || getClass() != o.getClass()) return false;
94+
Response response = (Response) o;
95+
return windowSize == response.windowSize;
96+
}
97+
98+
@Override
99+
public int hashCode() {
100+
return Objects.hashCode(windowSize);
101+
}
102+
}
103+
}

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.elasticsearch.inference.Model;
2626
import org.elasticsearch.inference.ModelConfigurations;
2727
import org.elasticsearch.inference.ModelSecrets;
28+
import org.elasticsearch.inference.RerankingInferenceService;
2829
import org.elasticsearch.inference.ServiceSettings;
2930
import org.elasticsearch.inference.SettingsConfiguration;
3031
import org.elasticsearch.inference.TaskSettings;
@@ -48,6 +49,8 @@
4849

4950
public class TestRerankingServiceExtension implements InferenceServiceExtension {
5051

52+
public static final int RERANK_WINDOW_SIZE = 333;
53+
5154
@Override
5255
public List<Factory> getInferenceServiceFactories() {
5356
return List.of(TestInferenceService::new);
@@ -62,7 +65,7 @@ public TestRerankingModel(String inferenceEntityId, TestServiceSettings serviceS
6265
}
6366
}
6467

65-
public static class TestInferenceService extends AbstractTestInferenceService {
68+
public static class TestInferenceService extends AbstractTestInferenceService implements RerankingInferenceService {
6669
public static final String NAME = "test_reranking_service";
6770

6871
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.RERANK);
@@ -200,6 +203,11 @@ protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceS
200203
return TestServiceSettings.fromMap(serviceSettingsMap);
201204
}
202205

206+
@Override
207+
public int rerankerWindowSize(String modelId) {
208+
return RERANK_WINDOW_SIZE;
209+
}
210+
203211
public static class Configuration {
204212
public static InferenceServiceConfiguration get() {
205213
return configuration.getOrCompute();

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,9 @@ public static Iterable<Object[]> parameters() {
6161
@Before
6262
public void setup() throws Exception {
6363
ModelRegistry modelRegistry = internalCluster().getCurrentMasterNodeInstance(ModelRegistry.class);
64-
Utils.storeSparseModel(modelRegistry);
64+
Utils.storeSparseModel("sparse-endpoint", modelRegistry);
6565
Utils.storeDenseModel(
66+
"dense-endpoint",
6667
modelRegistry,
6768
randomIntBetween(1, 100),
6869
// dot product means that we need normalized vectors; it's not worth doing that in this test

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ public void setup() throws Exception {
9090
() -> randomFrom(DenseVectorFieldMapperTestUtils.getSupportedSimilarities(elementType))
9191
);
9292
int dimensions = DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, 100);
93-
Utils.storeSparseModel(modelRegistry);
94-
Utils.storeDenseModel(modelRegistry, dimensions, similarity, elementType);
93+
Utils.storeSparseModel("sparse-endpoint", modelRegistry);
94+
Utils.storeDenseModel("dense-endpoint", modelRegistry, dimensions, similarity, elementType);
9595
}
9696

9797
@Override
@@ -122,27 +122,20 @@ public Settings indexSettings() {
122122
}
123123

124124
public void testBulkOperations() throws Exception {
125-
prepareCreate(INDEX_NAME).setMapping(
126-
String.format(
127-
Locale.ROOT,
128-
"""
129-
{
130-
"properties": {
131-
"sparse_field": {
132-
"type": "semantic_text",
133-
"inference_id": "%s"
134-
},
135-
"dense_field": {
136-
"type": "semantic_text",
137-
"inference_id": "%s"
138-
}
139-
}
125+
prepareCreate(INDEX_NAME).setMapping(String.format(Locale.ROOT, """
126+
{
127+
"properties": {
128+
"sparse_field": {
129+
"type": "semantic_text",
130+
"inference_id": "%s"
131+
},
132+
"dense_field": {
133+
"type": "semantic_text",
134+
"inference_id": "%s"
140135
}
141-
""",
142-
TestSparseInferenceServiceExtension.TestInferenceService.NAME,
143-
TestDenseInferenceServiceExtension.TestInferenceService.NAME
144-
)
145-
).get();
136+
}
137+
}
138+
""", "sparse-endpoint", "dense-endpoint")).get();
146139
assertRandomBulkOperations(INDEX_NAME, isIndexRequest -> {
147140
Map<String, Object> map = new HashMap<>();
148141
map.put("sparse_field", isIndexRequest && rarely() ? null : randomSemanticTextInput());
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.integration;
9+
10+
import org.elasticsearch.ElasticsearchStatusException;
11+
import org.elasticsearch.plugins.Plugin;
12+
import org.elasticsearch.test.ESIntegTestCase;
13+
import org.elasticsearch.test.ESTestCase;
14+
import org.elasticsearch.xpack.core.inference.action.GetRerankerWindowSizeAction;
15+
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
16+
import org.elasticsearch.xpack.inference.Utils;
17+
import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin;
18+
import org.elasticsearch.xpack.inference.mock.TestRerankingServiceExtension;
19+
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
20+
import org.junit.Before;
21+
22+
import java.util.Collection;
23+
import java.util.List;
24+
25+
import static org.hamcrest.Matchers.containsString;
26+
27+
@ESTestCase.WithoutEntitlements // due to dependency issue ES-12435
28+
public class RerankWindowSizeIT extends ESIntegTestCase {
29+
30+
@Before
31+
public void setup() throws Exception {
32+
ModelRegistry modelRegistry = internalCluster().getCurrentMasterNodeInstance(ModelRegistry.class);
33+
Utils.storeRerankModel("rerank-endpoint", modelRegistry);
34+
Utils.storeSparseModel("sparse-endpoint", modelRegistry);
35+
}
36+
37+
@Override
38+
protected Collection<Class<? extends Plugin>> nodePlugins() {
39+
return List.of(LocalStateInferencePlugin.class, TestInferenceServicePlugin.class);
40+
}
41+
42+
public void testRerankWindowSizeAction() {
43+
var response = client().execute(GetRerankerWindowSizeAction.INSTANCE, new GetRerankerWindowSizeAction.Request("rerank-endpoint"))
44+
.actionGet();
45+
assertEquals(TestRerankingServiceExtension.RERANK_WINDOW_SIZE, response.getWindowSize());
46+
}
47+
48+
public void testActionNotAReranker() {
49+
var e = expectThrows(
50+
ElasticsearchStatusException.class,
51+
() -> client().execute(GetRerankerWindowSizeAction.INSTANCE, new GetRerankerWindowSizeAction.Request("sparse-endpoint"))
52+
.actionGet()
53+
);
54+
assertThat(e.getMessage(), containsString("Inference endpoint [sparse-endpoint] does not have the rerank task type"));
55+
}
56+
}

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexVersionIT.java

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@
3131
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
3232
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
3333
import org.elasticsearch.xpack.inference.Utils;
34-
import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
35-
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
3634
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
3735
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
3836
import org.junit.Before;
@@ -68,8 +66,8 @@ public void setup() throws Exception {
6866
() -> randomFrom(DenseVectorFieldMapperTestUtils.getSupportedSimilarities(elementType))
6967
);
7068
int dimensions = DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, 100);
71-
Utils.storeSparseModel(modelRegistry);
72-
Utils.storeDenseModel(modelRegistry, dimensions, similarity, elementType);
69+
Utils.storeSparseModel("sparse-endpoint", modelRegistry);
70+
Utils.storeDenseModel("dense-endpoint", modelRegistry, dimensions, similarity, elementType);
7371

7472
Set<IndexVersion> availableVersions = IndexVersionUtils.allReleasedVersions()
7573
.stream()
@@ -113,11 +111,11 @@ public void testSemanticText() throws Exception {
113111
.startObject("properties")
114112
.startObject(SPARSE_SEMANTIC_FIELD)
115113
.field("type", "semantic_text")
116-
.field("inference_id", TestSparseInferenceServiceExtension.TestInferenceService.NAME)
114+
.field("inference_id", "sparse-endpoint")
117115
.endObject()
118116
.startObject(DENSE_SEMANTIC_FIELD)
119117
.field("type", "semantic_text")
120-
.field("inference_id", TestDenseInferenceServiceExtension.TestInferenceService.NAME)
118+
.field("inference_id", "dense-endpoint")
121119
.endObject()
122120
.endObject()
123121
.endObject();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
import org.elasticsearch.xpack.core.inference.action.GetInferenceDiagnosticsAction;
6363
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
6464
import org.elasticsearch.xpack.core.inference.action.GetInferenceServicesAction;
65+
import org.elasticsearch.xpack.core.inference.action.GetRerankerWindowSizeAction;
6566
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
6667
import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy;
6768
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
@@ -72,6 +73,7 @@
7273
import org.elasticsearch.xpack.inference.action.TransportGetInferenceDiagnosticsAction;
7374
import org.elasticsearch.xpack.inference.action.TransportGetInferenceModelAction;
7475
import org.elasticsearch.xpack.inference.action.TransportGetInferenceServicesAction;
76+
import org.elasticsearch.xpack.inference.action.TransportGetRerankerWindowSizeAction;
7577
import org.elasticsearch.xpack.inference.action.TransportInferenceAction;
7678
import org.elasticsearch.xpack.inference.action.TransportInferenceActionProxy;
7779
import org.elasticsearch.xpack.inference.action.TransportInferenceUsageAction;
@@ -234,7 +236,8 @@ public List<ActionHandler> getActions() {
234236
new ActionHandler(XPackUsageFeatureAction.INFERENCE, TransportInferenceUsageAction.class),
235237
new ActionHandler(GetInferenceDiagnosticsAction.INSTANCE, TransportGetInferenceDiagnosticsAction.class),
236238
new ActionHandler(GetInferenceServicesAction.INSTANCE, TransportGetInferenceServicesAction.class),
237-
new ActionHandler(UnifiedCompletionAction.INSTANCE, TransportUnifiedCompletionInferenceAction.class)
239+
new ActionHandler(UnifiedCompletionAction.INSTANCE, TransportUnifiedCompletionInferenceAction.class),
240+
new ActionHandler(GetRerankerWindowSizeAction.INSTANCE, TransportGetRerankerWindowSizeAction.class)
238241
);
239242
}
240243

0 commit comments

Comments
 (0)