Skip to content

Commit 59da06b

Browse files
authored
Adding default endpoint for Elastic Rerank (#117939) (#118153)
* Adding default endpoint for Elastic Rerank
* CustomElandRerankTaskSettings -> RerankTaskSettings
* Update docs/changelog/117939.yaml
1 parent ca99157 commit 59da06b

File tree

11 files changed

+185
-96
lines changed

11 files changed

+185
-96
lines changed

docs/changelog/117939.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 117939
2+
summary: Adding default endpoint for Elastic Rerank
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ public void testGet() throws IOException {
4949

5050
var e5Model = getModel(ElasticsearchInternalService.DEFAULT_E5_ID);
5151
assertDefaultE5Config(e5Model);
52+
53+
var rerankModel = getModel(ElasticsearchInternalService.DEFAULT_RERANK_ID);
54+
assertDefaultRerankConfig(rerankModel);
5255
}
5356

5457
@SuppressWarnings("unchecked")
@@ -117,6 +120,42 @@ private static void assertDefaultE5Config(Map<String, Object> modelConfig) {
117120
assertDefaultChunkingSettings(modelConfig);
118121
}
119122

123+
/**
 * Exercises the default rerank endpoint end to end.
 *
 * Fetches the config stored under {@code DEFAULT_RERANK_ID}, checks it against the
 * expected defaults, then sends a rerank request (two inputs, one query) and expects
 * one result per input under the {@code "rerank"} key of the response.
 *
 * The test never creates the endpoint itself; judging by the method name, the first
 * inference call is presumably what deploys the model -- TODO confirm.
 */
@SuppressWarnings("unchecked")
124+
public void testInferDeploysDefaultRerank() throws IOException {
125+
var model = getModel(ElasticsearchInternalService.DEFAULT_RERANK_ID);
126+
assertDefaultRerankConfig(model);
127+
128+
var inputs = List.of("Hello World", "Goodnight moon");
129+
var query = "but why";
130+
// Sent as a request parameter; presumably generous to allow for a first-time model deployment -- TODO confirm.
var queryParams = Map.of("timeout", "120s");
131+
var results = infer(ElasticsearchInternalService.DEFAULT_RERANK_ID, TaskType.RERANK, inputs, query, queryParams);
132+
// NOTE(review): despite the name, this list holds the entries of the "rerank" response key, not embeddings.
var embeddings = (List<Map<String, Object>>) results.get("rerank");
133+
// One result entry is expected for each of the two inputs.
assertThat(results.toString(), embeddings, hasSize(2));
134+
}
135+
136+
/**
 * Asserts that {@code modelConfig} describes the default rerank endpoint.
 *
 * Checks the top-level identity fields (inference id, service name, task type), the
 * service settings (model id, thread count, adaptive allocations), the absence of
 * chunking settings, and the task settings. Every assertion uses the full config's
 * string form as its failure message.
 *
 * @param modelConfig the endpoint configuration as a parsed map
 */
@SuppressWarnings("unchecked")
137+
private static void assertDefaultRerankConfig(Map<String, Object> modelConfig) {
138+
assertEquals(modelConfig.toString(), ElasticsearchInternalService.DEFAULT_RERANK_ID, modelConfig.get("inference_id"));
139+
assertEquals(modelConfig.toString(), ElasticsearchInternalService.NAME, modelConfig.get("service"));
140+
assertEquals(modelConfig.toString(), TaskType.RERANK.toString(), modelConfig.get("task_type"));
141+
142+
var serviceSettings = (Map<String, Object>) modelConfig.get("service_settings");
143+
// NOTE(review): hard-coded model id; presumably mirrors the service's RERANKER_ID constant -- confirm.
assertThat(modelConfig.toString(), serviceSettings.get("model_id"), is(".rerank-v1"));
144+
assertEquals(modelConfig.toString(), 1, serviceSettings.get("num_threads"));
145+
146+
var adaptiveAllocations = (Map<String, Object>) serviceSettings.get("adaptive_allocations");
147+
// Expected: adaptive allocations enabled with a minimum of 0 and a maximum of 32 allocations.
assertThat(
148+
modelConfig.toString(),
149+
adaptiveAllocations,
150+
Matchers.is(Map.of("enabled", true, "min_number_of_allocations", 0, "max_number_of_allocations", 32))
151+
);
152+
153+
// The rerank endpoint config is expected to carry no chunking settings at all.
var chunkingSettings = (Map<String, Object>) modelConfig.get("chunking_settings");
154+
assertNull(chunkingSettings);
155+
// The only expected task setting is return_documents=true.
var taskSettings = (Map<String, Object>) modelConfig.get("task_settings");
156+
assertThat(modelConfig.toString(), taskSettings, Matchers.is(Map.of("return_documents", true)));
157+
}
158+
120159
@SuppressWarnings("unchecked")
121160
private static void assertDefaultChunkingSettings(Map<String, Object> modelConfig) {
122161
var chunkingSettings = (Map<String, Object>) modelConfig.get("chunking_settings");
@@ -151,6 +190,7 @@ public void onFailure(Exception exception) {
151190
var request = createInferenceRequest(
152191
Strings.format("_inference/%s", ElasticsearchInternalService.DEFAULT_ELSER_ID),
153192
inputs,
193+
null,
154194
queryParams
155195
);
156196
client().performRequestAsync(request, listener);

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ private List<Object> getInternalAsList(String endpoint) throws IOException {
333333

334334
/**
 * Runs inference against {@code _inference/{modelId}} with no query and no request parameters.
 *
 * In this diff the removed line is the old three-argument call; the added line passes
 * {@code null} for the new query argument of {@code inferInternal}.
 *
 * @param modelId the inference endpoint id
 * @param input   the input strings to send
 * @return the response body parsed as a map
 * @throws IOException as declared; propagated from {@code inferInternal}
 */
protected Map<String, Object> infer(String modelId, List<String> input) throws IOException {
335335
var endpoint = Strings.format("_inference/%s", modelId);
336-
return inferInternal(endpoint, input, Map.of());
336+
return inferInternal(endpoint, input, null, Map.of());
337337
}
338338

339339
protected Deque<ServerSentEvent> streamInferOnMockService(String modelId, TaskType taskType, List<String> input) throws Exception {
@@ -344,7 +344,7 @@ protected Deque<ServerSentEvent> streamInferOnMockService(String modelId, TaskTy
344344
private Deque<ServerSentEvent> callAsync(String endpoint, List<String> input) throws Exception {
345345
var responseConsumer = new AsyncInferenceResponseConsumer();
346346
var request = new Request("POST", endpoint);
347-
request.setJsonEntity(jsonBody(input));
347+
request.setJsonEntity(jsonBody(input, null));
348348
request.setOptions(RequestOptions.DEFAULT.toBuilder().setHttpAsyncResponseConsumerFactory(() -> responseConsumer).build());
349349
var latch = new CountDownLatch(1);
350350
client().performRequestAsync(request, new ResponseListener() {
@@ -364,33 +364,60 @@ public void onFailure(Exception exception) {
364364

365365
/**
 * Runs inference against {@code _inference/{taskType}/{modelId}} with no query and no
 * request parameters.
 *
 * In this diff the removed line is the old three-argument call; the added line passes
 * {@code null} for the new query argument of {@code inferInternal}.
 *
 * @param modelId  the inference endpoint id
 * @param taskType the task type placed in the URL path
 * @param input    the input strings to send
 * @return the response body parsed as a map
 * @throws IOException as declared; propagated from {@code inferInternal}
 */
protected Map<String, Object> infer(String modelId, TaskType taskType, List<String> input) throws IOException {
366366
var endpoint = Strings.format("_inference/%s/%s", taskType, modelId);
367-
return inferInternal(endpoint, input, Map.of());
367+
return inferInternal(endpoint, input, null, Map.of());
368368
}
369369

370370
/**
 * Runs inference against {@code _inference/{taskType}/{modelId}?error_trace} with the
 * given request parameters and no query.
 *
 * In this diff the removed line is the old three-argument call; the added line passes
 * {@code null} for the new query argument of {@code inferInternal}.
 *
 * @param modelId         the inference endpoint id
 * @param taskType        the task type placed in the URL path
 * @param input           the input strings to send
 * @param queryParameters request parameters forwarded to {@code inferInternal}
 * @return the response body parsed as a map
 * @throws IOException as declared; propagated from {@code inferInternal}
 */
protected Map<String, Object> infer(String modelId, TaskType taskType, List<String> input, Map<String, String> queryParameters)
371371
throws IOException {
372372
// error_trace is written into the endpoint string itself rather than passed via queryParameters.
var endpoint = Strings.format("_inference/%s/%s?error_trace", taskType, modelId);
373-
return inferInternal(endpoint, input, queryParameters);
373+
return inferInternal(endpoint, input, null, queryParameters);
374374
}
375375

376-
protected Request createInferenceRequest(String endpoint, List<String> input, Map<String, String> queryParameters) {
376+
/**
 * Runs inference against {@code _inference/{taskType}/{modelId}?error_trace}, sending a
 * query alongside the inputs. The test in this diff uses it for rerank requests.
 *
 * @param modelId         the inference endpoint id
 * @param taskType        the task type placed in the URL path
 * @param input           the input strings to send
 * @param query           the query forwarded to {@code inferInternal}, whose matching
 *                        parameter is declared {@code @Nullable}
 * @param queryParameters request parameters forwarded to {@code inferInternal}
 * @return the response body parsed as a map
 * @throws IOException as declared; propagated from {@code inferInternal}
 */
protected Map<String, Object> infer(
377+
String modelId,
378+
TaskType taskType,
379+
List<String> input,
380+
String query,
381+
Map<String, String> queryParameters
382+
) throws IOException {
383+
// error_trace is written into the endpoint string itself rather than passed via queryParameters.
var endpoint = Strings.format("_inference/%s/%s?error_trace", taskType, modelId);
384+
return inferInternal(endpoint, input, query, queryParameters);
385+
}
386+
387+
/**
 * Builds a POST request for {@code endpoint} whose JSON body is produced by
 * {@code jsonBody(input, query)}.
 *
 * NOTE(review): the visible part of {@code jsonBody} appends {@code query} verbatim
 * between double quotes, so a query containing quotes or backslashes would yield
 * invalid JSON. Callers should pass only plain text until that is escaped.
 *
 * @param endpoint        the request path
 * @param input           the input strings for the body
 * @param query           optional query; {@code jsonBody} only emits a "query" field when it is non-null
 * @param queryParameters request parameters; added only when the map is non-empty
 * @return the assembled request (not yet executed)
 */
protected Request createInferenceRequest(
388+
String endpoint,
389+
List<String> input,
390+
@Nullable String query,
391+
Map<String, String> queryParameters
392+
) {
377393
var request = new Request("POST", endpoint);
378-
request.setJsonEntity(jsonBody(input));
394+
request.setJsonEntity(jsonBody(input, query));
379395
// Skip addParameters entirely for an empty map.
if (queryParameters.isEmpty() == false) {
380396
request.addParameters(queryParameters);
381397
}
382398
return request;
383399
}
384400

385-
private Map<String, Object> inferInternal(String endpoint, List<String> input, Map<String, String> queryParameters) throws IOException {
386-
var request = createInferenceRequest(endpoint, input, queryParameters);
401+
/**
 * Shared implementation behind the {@code infer} overloads: builds the request via
 * {@code createInferenceRequest}, executes it synchronously, asserts the response via
 * {@code assertOkOrCreated}, and returns the body parsed as a map.
 *
 * @param endpoint        the request path
 * @param input           the input strings for the body
 * @param query           optional query, passed through to {@code createInferenceRequest}
 * @param queryParameters request parameters, passed through to {@code createInferenceRequest}
 * @return the response body parsed as a map
 * @throws IOException as declared by this method
 */
private Map<String, Object> inferInternal(
402+
String endpoint,
403+
List<String> input,
404+
@Nullable String query,
405+
Map<String, String> queryParameters
406+
) throws IOException {
407+
var request = createInferenceRequest(endpoint, input, query, queryParameters);
387408
var response = client().performRequest(request);
388409
assertOkOrCreated(response);
389410
return entityAsMap(response);
390411
}
391412

392-
private String jsonBody(List<String> input) {
393-
var bodyBuilder = new StringBuilder("{\"input\": [");
413+
private String jsonBody(List<String> input, @Nullable String query) {
414+
final StringBuilder bodyBuilder = new StringBuilder("{");
415+
416+
if (query != null) {
417+
bodyBuilder.append("\"query\":\"").append(query).append("\",");
418+
}
419+
420+
bodyBuilder.append("\"input\": [");
394421
for (var in : input) {
395422
bodyBuilder.append('"').append(in).append('"').append(',');
396423
}

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ public void testCRUD() throws IOException {
4141
}
4242

4343
var getAllModels = getAllModels();
44-
int numModels = 11;
44+
int numModels = 12;
4545
assertThat(getAllModels, hasSize(numModels));
4646

4747
var getSparseModels = getModels("_all", TaskType.SPARSE_EMBEDDING);
@@ -328,7 +328,7 @@ public void testSupportedStream() throws Exception {
328328
}
329329

330330
/**
 * Verifies that listing endpoints for a task type with no endpoints returns an empty list.
 *
 * The queried task type changes from RERANK to COMPLETION in this diff. The same commit
 * registers a default RERANK endpoint, so a RERANK listing would presumably no longer be
 * empty -- confirm against the service's default configs.
 */
public void testGetZeroModels() throws IOException {
331-
var models = getModels("_all", TaskType.RERANK);
331+
var models = getModels("_all", TaskType.COMPLETION);
332332
assertThat(models, empty());
333333
}
334334
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,12 @@
6262
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
6363
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings;
6464
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings;
65-
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandRerankTaskSettings;
6665
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticRerankerServiceSettings;
6766
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings;
6867
import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettings;
6968
import org.elasticsearch.xpack.inference.services.elasticsearch.ElserMlNodeTaskSettings;
7069
import org.elasticsearch.xpack.inference.services.elasticsearch.MultilingualE5SmallInternalServiceSettings;
70+
import org.elasticsearch.xpack.inference.services.elasticsearch.RerankTaskSettings;
7171
import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionServiceSettings;
7272
import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsServiceSettings;
7373
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings;
@@ -510,9 +510,7 @@ private static void addCustomElandWriteables(final List<NamedWriteableRegistry.E
510510
CustomElandInternalTextEmbeddingServiceSettings::new
511511
)
512512
);
513-
namedWriteables.add(
514-
new NamedWriteableRegistry.Entry(TaskSettings.class, CustomElandRerankTaskSettings.NAME, CustomElandRerankTaskSettings::new)
515-
);
513+
namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, RerankTaskSettings.NAME, RerankTaskSettings::new));
516514
}
517515

518516
private static void addAnthropicNamedWritables(List<NamedWriteableRegistry.Entry> namedWriteables) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankModel.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import java.util.HashMap;
1818
import java.util.Map;
1919

20-
import static org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandRerankTaskSettings.RETURN_DOCUMENTS;
20+
import static org.elasticsearch.xpack.inference.services.elasticsearch.RerankTaskSettings.RETURN_DOCUMENTS;
2121

2222
public class CustomElandRerankModel extends CustomElandModel {
2323

@@ -26,7 +26,7 @@ public CustomElandRerankModel(
2626
TaskType taskType,
2727
String service,
2828
CustomElandInternalServiceSettings serviceSettings,
29-
CustomElandRerankTaskSettings taskSettings
29+
RerankTaskSettings taskSettings
3030
) {
3131
super(inferenceEntityId, taskType, service, serviceSettings, taskSettings);
3232
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerModel.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
import org.elasticsearch.ResourceNotFoundException;
1111
import org.elasticsearch.action.ActionListener;
12-
import org.elasticsearch.inference.ChunkingSettings;
1312
import org.elasticsearch.inference.Model;
1413
import org.elasticsearch.inference.TaskType;
1514
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
@@ -22,9 +21,9 @@ public ElasticRerankerModel(
2221
TaskType taskType,
2322
String service,
2423
ElasticRerankerServiceSettings serviceSettings,
25-
ChunkingSettings chunkingSettings
24+
RerankTaskSettings taskSettings
2625
) {
27-
super(inferenceEntityId, taskType, service, serviceSettings, chunkingSettings);
26+
super(inferenceEntityId, taskType, service, serviceSettings, taskSettings);
2827
}
2928

3029
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
101101
public static final int EMBEDDING_MAX_BATCH_SIZE = 10;
102102
public static final String DEFAULT_ELSER_ID = ".elser-2-elasticsearch";
103103
public static final String DEFAULT_E5_ID = ".multilingual-e5-small-elasticsearch";
104+
public static final String DEFAULT_RERANK_ID = ".rerank-v1-elasticsearch";
104105

105106
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
106107
TaskType.RERANK,
@@ -225,7 +226,7 @@ public void parseRequestConfig(
225226
)
226227
);
227228
} else if (RERANKER_ID.equals(modelId)) {
228-
rerankerCase(inferenceEntityId, taskType, config, serviceSettingsMap, chunkingSettings, modelListener);
229+
rerankerCase(inferenceEntityId, taskType, config, serviceSettingsMap, taskSettingsMap, modelListener);
229230
} else {
230231
customElandCase(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, modelListener);
231232
}
@@ -308,7 +309,7 @@ private static CustomElandModel createCustomElandModel(
308309
taskType,
309310
NAME,
310311
elandServiceSettings(serviceSettings, context),
311-
CustomElandRerankTaskSettings.fromMap(taskSettings)
312+
RerankTaskSettings.fromMap(taskSettings)
312313
);
313314
default -> throw new ElasticsearchStatusException(TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST);
314315
};
@@ -331,7 +332,7 @@ private void rerankerCase(
331332
TaskType taskType,
332333
Map<String, Object> config,
333334
Map<String, Object> serviceSettingsMap,
334-
ChunkingSettings chunkingSettings,
335+
Map<String, Object> taskSettingsMap,
335336
ActionListener<Model> modelListener
336337
) {
337338

@@ -346,7 +347,7 @@ private void rerankerCase(
346347
taskType,
347348
NAME,
348349
new ElasticRerankerServiceSettings(esServiceSettingsBuilder.build()),
349-
chunkingSettings
350+
RerankTaskSettings.fromMap(taskSettingsMap)
350351
)
351352
);
352353
}
@@ -512,6 +513,14 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
512513
ElserMlNodeTaskSettings.DEFAULT,
513514
chunkingSettings
514515
);
516+
} else if (modelId.equals(RERANKER_ID)) {
517+
return new ElasticRerankerModel(
518+
inferenceEntityId,
519+
taskType,
520+
NAME,
521+
new ElasticRerankerServiceSettings(ElasticsearchInternalServiceSettings.fromPersistedMap(serviceSettingsMap)),
522+
RerankTaskSettings.fromMap(taskSettingsMap)
523+
);
515524
} else {
516525
return createCustomElandModel(
517526
inferenceEntityId,
@@ -653,21 +662,23 @@ public void inferRerank(
653662
) {
654663
var request = buildInferenceRequest(model.mlNodeDeploymentId(), new TextSimilarityConfigUpdate(query), inputs, inputType, timeout);
655664

656-
var modelSettings = (CustomElandRerankTaskSettings) model.getTaskSettings();
657-
var requestSettings = CustomElandRerankTaskSettings.fromMap(requestTaskSettings);
658-
Boolean returnDocs = CustomElandRerankTaskSettings.of(modelSettings, requestSettings).returnDocuments();
665+
var returnDocs = Boolean.TRUE;
666+
if (model.getTaskSettings() instanceof RerankTaskSettings modelSettings) {
667+
var requestSettings = RerankTaskSettings.fromMap(requestTaskSettings);
668+
returnDocs = RerankTaskSettings.of(modelSettings, requestSettings).returnDocuments();
669+
}
659670

660671
Function<Integer, String> inputSupplier = returnDocs == Boolean.TRUE ? inputs::get : i -> null;
661672

662-
client.execute(
663-
InferModelAction.INSTANCE,
664-
request,
665-
listener.delegateFailureAndWrap(
666-
(l, inferenceResult) -> l.onResponse(
667-
textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier)
668-
)
669-
)
673+
ActionListener<InferModelAction.Response> mlResultsListener = listener.delegateFailureAndWrap(
674+
(l, inferenceResult) -> l.onResponse(textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier))
675+
);
676+
677+
var maybeDeployListener = mlResultsListener.delegateResponse(
678+
(l, exception) -> maybeStartDeployment(model, exception, request, mlResultsListener)
670679
);
680+
681+
client.execute(InferModelAction.INSTANCE, request, maybeDeployListener);
671682
}
672683

673684
public void chunkedInfer(
@@ -811,7 +822,8 @@ private RankedDocsResults textSimilarityResultsToRankedDocs(
811822
/**
 * Lists the default endpoint ids this service declares, each paired with its task type
 * and this service instance: ELSER (sparse embedding), E5 (text embedding) and, added
 * by this diff, rerank.
 *
 * @return an immutable list of the default config ids
 */
public List<DefaultConfigId> defaultConfigIds() {
812823
return List.of(
813824
new DefaultConfigId(DEFAULT_ELSER_ID, TaskType.SPARSE_EMBEDDING, this),
814-
new DefaultConfigId(DEFAULT_E5_ID, TaskType.TEXT_EMBEDDING, this)
825+
new DefaultConfigId(DEFAULT_E5_ID, TaskType.TEXT_EMBEDDING, this),
826+
new DefaultConfigId(DEFAULT_RERANK_ID, TaskType.RERANK, this)
815827
);
816828
}
817829

@@ -903,12 +915,19 @@ private List<Model> defaultConfigs(boolean useLinuxOptimizedModel) {
903915
),
904916
ChunkingSettingsBuilder.DEFAULT_SETTINGS
905917
);
906-
return List.of(defaultElser, defaultE5);
918+
var defaultRerank = new ElasticRerankerModel(
919+
DEFAULT_RERANK_ID,
920+
TaskType.RERANK,
921+
NAME,
922+
new ElasticRerankerServiceSettings(null, 1, RERANKER_ID, new AdaptiveAllocationsSettings(Boolean.TRUE, 0, 32)),
923+
RerankTaskSettings.DEFAULT_SETTINGS
924+
);
925+
return List.of(defaultElser, defaultE5, defaultRerank);
907926
}
908927

909928
/**
 * Reports whether {@code inferenceId} is one of this service's default endpoint ids
 * (ELSER, E5 or, added by this diff, rerank). A {@code null} id yields {@code false}
 * because each comparison is made from the constant side.
 *
 * @param inferenceId the endpoint id to test
 * @return {@code true} if the id equals one of the default ids
 */
@Override
910929
boolean isDefaultId(String inferenceId) {
911-
return DEFAULT_ELSER_ID.equals(inferenceId) || DEFAULT_E5_ID.equals(inferenceId);
930+
return DEFAULT_ELSER_ID.equals(inferenceId) || DEFAULT_E5_ID.equals(inferenceId) || DEFAULT_RERANK_ID.equals(inferenceId);
912931
}
913932

914933
static EmbeddingRequestChunker.EmbeddingType embeddingTypeFromTaskTypeAndSettings(

0 commit comments

Comments
 (0)