Skip to content

Commit 59e2c38

Browse files
add EIS rerank default inference endpoint
1 parent 4275bc7 commit 59e2c38

File tree

6 files changed

+65
-6
lines changed

6 files changed

+65
-6
lines changed

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public void testGetDefaultEndpoints() throws IOException {
3333
var allModels = getAllModels();
3434
var chatCompletionModels = getModels("_all", TaskType.CHAT_COMPLETION);
3535

36-
assertThat(allModels, hasSize(5));
36+
assertThat(allModels, hasSize(6));
3737
assertThat(chatCompletionModels, hasSize(1));
3838

3939
for (var model : chatCompletionModels) {
@@ -42,6 +42,7 @@ public void testGetDefaultEndpoints() throws IOException {
4242

4343
assertInferenceIdTaskType(allModels, ".rainbow-sprinkles-elastic", TaskType.CHAT_COMPLETION);
4444
assertInferenceIdTaskType(allModels, ".elser-v2-elastic", TaskType.SPARSE_EMBEDDING);
45+
assertInferenceIdTaskType(allModels, ".rerank-v1-elastic", TaskType.RERANK);
4546
}
4647

4748
private static void assertInferenceIdTaskType(List<Map<String, Object>> models, String inferenceId, TaskType taskType) {

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,10 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
197197
{
198198
"model_name": "elser-v2",
199199
"task_types": ["embed/text/sparse"]
200+
},
201+
{
202+
"model_name": "rerank-v1",
203+
"task_types": ["rerank/text/text-similarity"]
200204
}
201205
]
202206
}
@@ -221,16 +225,29 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
221225
".rainbow-sprinkles-elastic",
222226
MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME),
223227
service
228+
),
229+
new InferenceService.DefaultConfigId(
230+
".rerank-v1-elastic",
231+
MinimalServiceSettings.rerank(ElasticInferenceService.NAME),
232+
service
224233
)
225234
)
226235
)
227236
);
228-
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING)));
237+
assertThat(
238+
service.supportedTaskTypes(),
239+
is(EnumSet.of(
240+
TaskType.CHAT_COMPLETION,
241+
TaskType.SPARSE_EMBEDDING,
242+
TaskType.RERANK)
243+
)
244+
);
229245

230246
PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
231247
service.defaultConfigs(listener);
232248
assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic"));
233249
assertThat(listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
250+
assertThat(listener.actionGet(TIMEOUT).get(2).getConfigurations().getInferenceEntityId(), is(".rerank-v1-elastic"));
234251

235252
var getModelListener = new PlainActionFuture<UnparsedModel>();
236253
// persists the default endpoints
@@ -248,6 +265,10 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
248265
{
249266
"model_name": "elser-v2",
250267
"task_types": ["embed/text/sparse"]
268+
},
269+
{
270+
"model_name": "rerank-v1",
271+
"task_types": ["rerank/text/text-similarity"]
251272
}
252273
]
253274
}
@@ -267,11 +288,17 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
267288
".elser-v2-elastic",
268289
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
269290
service
291+
),
292+
new InferenceService.DefaultConfigId(
293+
".rerank-v1-elastic",
294+
MinimalServiceSettings.rerank(ElasticInferenceService.NAME),
295+
service
270296
)
271297
)
272298
)
273299
);
274300
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING)));
301+
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.RERANK)));
275302

276303
var getModelListener = new PlainActionFuture<UnparsedModel>();
277304
modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
5353
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
5454
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
55+
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings;
5556
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel;
5657
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
5758
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@@ -95,6 +96,10 @@ public class ElasticInferenceService extends SenderService {
9596
static final String DEFAULT_ELSER_MODEL_ID_V2 = "elser-v2";
9697
static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = defaultEndpointId(DEFAULT_ELSER_MODEL_ID_V2);
9798

99+
// rerank-v1
100+
static final String DEFAULT_RERANK_MODEL_ID_V1 = "rerank-v1";
101+
static final String DEFAULT_RERANK_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_RERANK_MODEL_ID_V1);
102+
98103
/**
99104
* The task types that the {@link InferenceAction.Request} can accept.
100105
*/
@@ -159,6 +164,19 @@ private static Map<String, DefaultModelConfig> initDefaultEndpoints(
159164
elasticInferenceServiceComponents
160165
),
161166
MinimalServiceSettings.sparseEmbedding(NAME)
167+
),
168+
DEFAULT_RERANK_MODEL_ID_V1,
169+
new DefaultModelConfig(
170+
new ElasticInferenceServiceRerankModel(
171+
DEFAULT_RERANK_ENDPOINT_ID_V1,
172+
TaskType.RERANK,
173+
NAME,
174+
new ElasticInferenceServiceRerankServiceSettings(DEFAULT_RERANK_MODEL_ID_V1, null),
175+
EmptyTaskSettings.INSTANCE,
176+
EmptySecretSettings.INSTANCE,
177+
elasticInferenceServiceComponents
178+
),
179+
MinimalServiceSettings.rerank(NAME)
162180
)
163181
);
164182
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ public URI uri() {
8787
private URI createUri() throws ElasticsearchStatusException {
8888
try {
8989
// TODO, consider transforming the base URL into a URI for better error handling.
90-
return new URI(elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/rerank");
90+
return new URI(elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/rerank/text/text-similarity");
9191
} catch (URISyntaxException e) {
9292
throw new ElasticsearchStatusException(
9393
"Failed to create URI for service ["

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntity.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ public class ElasticInferenceServiceAuthorizationResponseEntity implements Infer
4343
"embed/text/sparse",
4444
TaskType.SPARSE_EMBEDDING,
4545
"chat",
46-
TaskType.CHAT_COMPLETION
46+
TaskType.CHAT_COMPLETION,
47+
"rerank/text/text-similarity",
48+
TaskType.RERANK
4749
);
4850

4951
@SuppressWarnings("unchecked")

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1294,6 +1294,10 @@ public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect()
12941294
{
12951295
"model_name": "elser-v2",
12961296
"task_types": ["embed/text/sparse"]
1297+
},
1298+
{
1299+
"model_name": "rerank-v1",
1300+
"task_types": ["rerank/text/text-similarity"]
12971301
}
12981302
]
12991303
}
@@ -1319,18 +1323,25 @@ public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect()
13191323
".rainbow-sprinkles-elastic",
13201324
MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME),
13211325
service
1326+
),
1327+
new InferenceService.DefaultConfigId(
1328+
".rerank-v1-elastic",
1329+
MinimalServiceSettings.rerank(ElasticInferenceService.NAME),
1330+
service
13221331
)
13231332
)
13241333
)
13251334
);
1326-
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING)));
1335+
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK)));
13271336

13281337
PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
13291338
service.defaultConfigs(listener);
13301339
var models = listener.actionGet(TIMEOUT);
1331-
assertThat(models.size(), is(2));
1340+
assertThat(models.size(), is(3));
13321341
assertThat(models.get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic"));
13331342
assertThat(models.get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
1343+
assertThat(models.get(2).getConfigurations().getInferenceEntityId(), is(".rerank-v1-elastic"));
1344+
13341345
}
13351346
}
13361347

0 commit comments

Comments
 (0)