Skip to content

Commit 052f46a

Browse files
authored
[ML] Reinstate default endpoint for Elastic Rerank behind a feature flag (#117939) (#118253)
* Revert "Revert "Adding default endpoint for Elastic Rerank (#117939)" (#118221)"
1 parent bcba5bf commit 052f46a

File tree

11 files changed

+185
-96
lines changed

11 files changed

+185
-96
lines changed

docs/changelog/117939.yaml

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,5 @@
1+
pr: 117939
2+
summary: Adding default endpoint for Elastic Rerank
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java

Lines changed: 40 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -57,6 +57,9 @@ public void testGet() throws IOException {
5757

5858
var e5Model = getModel(ElasticsearchInternalService.DEFAULT_E5_ID);
5959
assertDefaultE5Config(e5Model);
60+
61+
var rerankModel = getModel(ElasticsearchInternalService.DEFAULT_RERANK_ID);
62+
assertDefaultRerankConfig(rerankModel);
6063
}
6164

6265
@SuppressWarnings("unchecked")
@@ -125,6 +128,42 @@ private static void assertDefaultE5Config(Map<String, Object> modelConfig) {
125128
assertDefaultChunkingSettings(modelConfig);
126129
}
127130

131+
@SuppressWarnings("unchecked")
132+
public void testInferDeploysDefaultRerank() throws IOException {
133+
var model = getModel(ElasticsearchInternalService.DEFAULT_RERANK_ID);
134+
assertDefaultRerankConfig(model);
135+
136+
var inputs = List.of("Hello World", "Goodnight moon");
137+
var query = "but why";
138+
var queryParams = Map.of("timeout", "120s");
139+
var results = infer(ElasticsearchInternalService.DEFAULT_RERANK_ID, TaskType.RERANK, inputs, query, queryParams);
140+
var embeddings = (List<Map<String, Object>>) results.get("rerank");
141+
assertThat(results.toString(), embeddings, hasSize(2));
142+
}
143+
144+
@SuppressWarnings("unchecked")
145+
private static void assertDefaultRerankConfig(Map<String, Object> modelConfig) {
146+
assertEquals(modelConfig.toString(), ElasticsearchInternalService.DEFAULT_RERANK_ID, modelConfig.get("inference_id"));
147+
assertEquals(modelConfig.toString(), ElasticsearchInternalService.NAME, modelConfig.get("service"));
148+
assertEquals(modelConfig.toString(), TaskType.RERANK.toString(), modelConfig.get("task_type"));
149+
150+
var serviceSettings = (Map<String, Object>) modelConfig.get("service_settings");
151+
assertThat(modelConfig.toString(), serviceSettings.get("model_id"), is(".rerank-v1"));
152+
assertEquals(modelConfig.toString(), 1, serviceSettings.get("num_threads"));
153+
154+
var adaptiveAllocations = (Map<String, Object>) serviceSettings.get("adaptive_allocations");
155+
assertThat(
156+
modelConfig.toString(),
157+
adaptiveAllocations,
158+
Matchers.is(Map.of("enabled", true, "min_number_of_allocations", 0, "max_number_of_allocations", 32))
159+
);
160+
161+
var chunkingSettings = (Map<String, Object>) modelConfig.get("chunking_settings");
162+
assertNull(chunkingSettings);
163+
var taskSettings = (Map<String, Object>) modelConfig.get("task_settings");
164+
assertThat(modelConfig.toString(), taskSettings, Matchers.is(Map.of("return_documents", true)));
165+
}
166+
128167
@SuppressWarnings("unchecked")
129168
private static void assertDefaultChunkingSettings(Map<String, Object> modelConfig) {
130169
var chunkingSettings = (Map<String, Object>) modelConfig.get("chunking_settings");
@@ -159,6 +198,7 @@ public void onFailure(Exception exception) {
159198
var request = createInferenceRequest(
160199
Strings.format("_inference/%s", ElasticsearchInternalService.DEFAULT_ELSER_ID),
161200
inputs,
201+
null,
162202
queryParams
163203
);
164204
client().performRequestAsync(request, listener);

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 37 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -338,7 +338,7 @@ private List<Object> getInternalAsList(String endpoint) throws IOException {
338338

339339
protected Map<String, Object> infer(String modelId, List<String> input) throws IOException {
340340
var endpoint = Strings.format("_inference/%s", modelId);
341-
return inferInternal(endpoint, input, Map.of());
341+
return inferInternal(endpoint, input, null, Map.of());
342342
}
343343

344344
protected Deque<ServerSentEvent> streamInferOnMockService(String modelId, TaskType taskType, List<String> input) throws Exception {
@@ -354,7 +354,7 @@ protected Deque<ServerSentEvent> unifiedCompletionInferOnMockService(String mode
354354

355355
private Deque<ServerSentEvent> callAsync(String endpoint, List<String> input) throws Exception {
356356
var request = new Request("POST", endpoint);
357-
request.setJsonEntity(jsonBody(input));
357+
request.setJsonEntity(jsonBody(input, null));
358358

359359
return execAsyncCall(request);
360360
}
@@ -396,33 +396,60 @@ private String createUnifiedJsonBody(List<String> input, String role) throws IOE
396396

397397
protected Map<String, Object> infer(String modelId, TaskType taskType, List<String> input) throws IOException {
398398
var endpoint = Strings.format("_inference/%s/%s", taskType, modelId);
399-
return inferInternal(endpoint, input, Map.of());
399+
return inferInternal(endpoint, input, null, Map.of());
400400
}
401401

402402
protected Map<String, Object> infer(String modelId, TaskType taskType, List<String> input, Map<String, String> queryParameters)
403403
throws IOException {
404404
var endpoint = Strings.format("_inference/%s/%s?error_trace", taskType, modelId);
405-
return inferInternal(endpoint, input, queryParameters);
405+
return inferInternal(endpoint, input, null, queryParameters);
406406
}
407407

408-
protected Request createInferenceRequest(String endpoint, List<String> input, Map<String, String> queryParameters) {
408+
protected Map<String, Object> infer(
409+
String modelId,
410+
TaskType taskType,
411+
List<String> input,
412+
String query,
413+
Map<String, String> queryParameters
414+
) throws IOException {
415+
var endpoint = Strings.format("_inference/%s/%s?error_trace", taskType, modelId);
416+
return inferInternal(endpoint, input, query, queryParameters);
417+
}
418+
419+
protected Request createInferenceRequest(
420+
String endpoint,
421+
List<String> input,
422+
@Nullable String query,
423+
Map<String, String> queryParameters
424+
) {
409425
var request = new Request("POST", endpoint);
410-
request.setJsonEntity(jsonBody(input));
426+
request.setJsonEntity(jsonBody(input, query));
411427
if (queryParameters.isEmpty() == false) {
412428
request.addParameters(queryParameters);
413429
}
414430
return request;
415431
}
416432

417-
private Map<String, Object> inferInternal(String endpoint, List<String> input, Map<String, String> queryParameters) throws IOException {
418-
var request = createInferenceRequest(endpoint, input, queryParameters);
433+
private Map<String, Object> inferInternal(
434+
String endpoint,
435+
List<String> input,
436+
@Nullable String query,
437+
Map<String, String> queryParameters
438+
) throws IOException {
439+
var request = createInferenceRequest(endpoint, input, query, queryParameters);
419440
var response = client().performRequest(request);
420441
assertOkOrCreated(response);
421442
return entityAsMap(response);
422443
}
423444

424-
private String jsonBody(List<String> input) {
425-
var bodyBuilder = new StringBuilder("{\"input\": [");
445+
private String jsonBody(List<String> input, @Nullable String query) {
446+
final StringBuilder bodyBuilder = new StringBuilder("{");
447+
448+
if (query != null) {
449+
bodyBuilder.append("\"query\":\"").append(query).append("\",");
450+
}
451+
452+
bodyBuilder.append("\"input\": [");
426453
for (var in : input) {
427454
bodyBuilder.append('"').append(in).append('"').append(',');
428455
}

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,7 @@ public void testCRUD() throws IOException {
4949
}
5050

5151
var getAllModels = getAllModels();
52-
int numModels = 11;
52+
int numModels = 12;
5353
assertThat(getAllModels, hasSize(numModels));
5454

5555
var getSparseModels = getModels("_all", TaskType.SPARSE_EMBEDDING);
@@ -537,7 +537,7 @@ private static String expectedResult(String input) {
537537
}
538538

539539
public void testGetZeroModels() throws IOException {
540-
var models = getModels("_all", TaskType.RERANK);
540+
var models = getModels("_all", TaskType.COMPLETION);
541541
assertThat(models, empty());
542542
}
543543
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -63,12 +63,12 @@
6363
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
6464
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings;
6565
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings;
66-
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandRerankTaskSettings;
6766
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticRerankerServiceSettings;
6867
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings;
6968
import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettings;
7069
import org.elasticsearch.xpack.inference.services.elasticsearch.ElserMlNodeTaskSettings;
7170
import org.elasticsearch.xpack.inference.services.elasticsearch.MultilingualE5SmallInternalServiceSettings;
71+
import org.elasticsearch.xpack.inference.services.elasticsearch.RerankTaskSettings;
7272
import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionServiceSettings;
7373
import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsServiceSettings;
7474
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings;
@@ -518,9 +518,7 @@ private static void addCustomElandWriteables(final List<NamedWriteableRegistry.E
518518
CustomElandInternalTextEmbeddingServiceSettings::new
519519
)
520520
);
521-
namedWriteables.add(
522-
new NamedWriteableRegistry.Entry(TaskSettings.class, CustomElandRerankTaskSettings.NAME, CustomElandRerankTaskSettings::new)
523-
);
521+
namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, RerankTaskSettings.NAME, RerankTaskSettings::new));
524522
}
525523

526524
private static void addAnthropicNamedWritables(List<NamedWriteableRegistry.Entry> namedWriteables) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankModel.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717
import java.util.HashMap;
1818
import java.util.Map;
1919

20-
import static org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandRerankTaskSettings.RETURN_DOCUMENTS;
20+
import static org.elasticsearch.xpack.inference.services.elasticsearch.RerankTaskSettings.RETURN_DOCUMENTS;
2121

2222
public class CustomElandRerankModel extends CustomElandModel {
2323

@@ -26,7 +26,7 @@ public CustomElandRerankModel(
2626
TaskType taskType,
2727
String service,
2828
CustomElandInternalServiceSettings serviceSettings,
29-
CustomElandRerankTaskSettings taskSettings
29+
RerankTaskSettings taskSettings
3030
) {
3131
super(inferenceEntityId, taskType, service, serviceSettings, taskSettings);
3232
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerModel.java

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,6 @@
99

1010
import org.elasticsearch.ResourceNotFoundException;
1111
import org.elasticsearch.action.ActionListener;
12-
import org.elasticsearch.inference.ChunkingSettings;
1312
import org.elasticsearch.inference.Model;
1413
import org.elasticsearch.inference.TaskType;
1514
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
@@ -22,9 +21,9 @@ public ElasticRerankerModel(
2221
TaskType taskType,
2322
String service,
2423
ElasticRerankerServiceSettings serviceSettings,
25-
ChunkingSettings chunkingSettings
24+
RerankTaskSettings taskSettings
2625
) {
27-
super(inferenceEntityId, taskType, service, serviceSettings, chunkingSettings);
26+
super(inferenceEntityId, taskType, service, serviceSettings, taskSettings);
2827
}
2928

3029
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 37 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -103,6 +103,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
103103
public static final int EMBEDDING_MAX_BATCH_SIZE = 10;
104104
public static final String DEFAULT_ELSER_ID = ".elser-2-elasticsearch";
105105
public static final String DEFAULT_E5_ID = ".multilingual-e5-small-elasticsearch";
106+
public static final String DEFAULT_RERANK_ID = ".rerank-v1-elasticsearch";
106107

107108
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
108109
TaskType.RERANK,
@@ -227,7 +228,7 @@ public void parseRequestConfig(
227228
)
228229
);
229230
} else if (RERANKER_ID.equals(modelId)) {
230-
rerankerCase(inferenceEntityId, taskType, config, serviceSettingsMap, chunkingSettings, modelListener);
231+
rerankerCase(inferenceEntityId, taskType, config, serviceSettingsMap, taskSettingsMap, modelListener);
231232
} else {
232233
customElandCase(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, modelListener);
233234
}
@@ -310,7 +311,7 @@ private static CustomElandModel createCustomElandModel(
310311
taskType,
311312
NAME,
312313
elandServiceSettings(serviceSettings, context),
313-
CustomElandRerankTaskSettings.fromMap(taskSettings)
314+
RerankTaskSettings.fromMap(taskSettings)
314315
);
315316
default -> throw new ElasticsearchStatusException(TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST);
316317
};
@@ -333,7 +334,7 @@ private void rerankerCase(
333334
TaskType taskType,
334335
Map<String, Object> config,
335336
Map<String, Object> serviceSettingsMap,
336-
ChunkingSettings chunkingSettings,
337+
Map<String, Object> taskSettingsMap,
337338
ActionListener<Model> modelListener
338339
) {
339340

@@ -348,7 +349,7 @@ private void rerankerCase(
348349
taskType,
349350
NAME,
350351
new ElasticRerankerServiceSettings(esServiceSettingsBuilder.build()),
351-
chunkingSettings
352+
RerankTaskSettings.fromMap(taskSettingsMap)
352353
)
353354
);
354355
}
@@ -514,6 +515,14 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
514515
ElserMlNodeTaskSettings.DEFAULT,
515516
chunkingSettings
516517
);
518+
} else if (modelId.equals(RERANKER_ID)) {
519+
return new ElasticRerankerModel(
520+
inferenceEntityId,
521+
taskType,
522+
NAME,
523+
new ElasticRerankerServiceSettings(ElasticsearchInternalServiceSettings.fromPersistedMap(serviceSettingsMap)),
524+
RerankTaskSettings.fromMap(taskSettingsMap)
525+
);
517526
} else {
518527
return createCustomElandModel(
519528
inferenceEntityId,
@@ -665,21 +674,23 @@ public void inferRerank(
665674
) {
666675
var request = buildInferenceRequest(model.mlNodeDeploymentId(), new TextSimilarityConfigUpdate(query), inputs, inputType, timeout);
667676

668-
var modelSettings = (CustomElandRerankTaskSettings) model.getTaskSettings();
669-
var requestSettings = CustomElandRerankTaskSettings.fromMap(requestTaskSettings);
670-
Boolean returnDocs = CustomElandRerankTaskSettings.of(modelSettings, requestSettings).returnDocuments();
677+
var returnDocs = Boolean.TRUE;
678+
if (model.getTaskSettings() instanceof RerankTaskSettings modelSettings) {
679+
var requestSettings = RerankTaskSettings.fromMap(requestTaskSettings);
680+
returnDocs = RerankTaskSettings.of(modelSettings, requestSettings).returnDocuments();
681+
}
671682

672683
Function<Integer, String> inputSupplier = returnDocs == Boolean.TRUE ? inputs::get : i -> null;
673684

674-
client.execute(
675-
InferModelAction.INSTANCE,
676-
request,
677-
listener.delegateFailureAndWrap(
678-
(l, inferenceResult) -> l.onResponse(
679-
textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier)
680-
)
681-
)
685+
ActionListener<InferModelAction.Response> mlResultsListener = listener.delegateFailureAndWrap(
686+
(l, inferenceResult) -> l.onResponse(textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier))
687+
);
688+
689+
var maybeDeployListener = mlResultsListener.delegateResponse(
690+
(l, exception) -> maybeStartDeployment(model, exception, request, mlResultsListener)
682691
);
692+
693+
client.execute(InferModelAction.INSTANCE, request, maybeDeployListener);
683694
}
684695

685696
public void chunkedInfer(
@@ -823,7 +834,8 @@ private RankedDocsResults textSimilarityResultsToRankedDocs(
823834
public List<DefaultConfigId> defaultConfigIds() {
824835
return List.of(
825836
new DefaultConfigId(DEFAULT_ELSER_ID, TaskType.SPARSE_EMBEDDING, this),
826-
new DefaultConfigId(DEFAULT_E5_ID, TaskType.TEXT_EMBEDDING, this)
837+
new DefaultConfigId(DEFAULT_E5_ID, TaskType.TEXT_EMBEDDING, this),
838+
new DefaultConfigId(DEFAULT_RERANK_ID, TaskType.RERANK, this)
827839
);
828840
}
829841

@@ -916,12 +928,19 @@ private List<Model> defaultConfigs(boolean useLinuxOptimizedModel) {
916928
),
917929
ChunkingSettingsBuilder.DEFAULT_SETTINGS
918930
);
919-
return List.of(defaultElser, defaultE5);
931+
var defaultRerank = new ElasticRerankerModel(
932+
DEFAULT_RERANK_ID,
933+
TaskType.RERANK,
934+
NAME,
935+
new ElasticRerankerServiceSettings(null, 1, RERANKER_ID, new AdaptiveAllocationsSettings(Boolean.TRUE, 0, 32)),
936+
RerankTaskSettings.DEFAULT_SETTINGS
937+
);
938+
return List.of(defaultElser, defaultE5, defaultRerank);
920939
}
921940

922941
@Override
923942
boolean isDefaultId(String inferenceId) {
924-
return DEFAULT_ELSER_ID.equals(inferenceId) || DEFAULT_E5_ID.equals(inferenceId);
943+
return DEFAULT_ELSER_ID.equals(inferenceId) || DEFAULT_E5_ID.equals(inferenceId) || DEFAULT_RERANK_ID.equals(inferenceId);
925944
}
926945

927946
static EmbeddingRequestChunker.EmbeddingType embeddingTypeFromTaskTypeAndSettings(

0 commit comments

Comments
 (0)