Skip to content

Commit 8d8c5fc

Browse files
author
Max Hniebergall
committed
Add reranker special case to inference API
1 parent 04a8044 commit 8d8c5fc

File tree

4 files changed

+166
-0
lines changed

4 files changed

+166
-0
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@ protected void putModel(Model model, ActionListener<Boolean> listener) {
157157
putBuiltInModel(e5Model.getServiceSettings().modelId(), listener);
158158
} else if (model instanceof ElserInternalModel elserModel) {
159159
putBuiltInModel(elserModel.getServiceSettings().modelId(), listener);
160+
} else if (model instanceof ElasticRerankerModel elasticRerankerModel) {
161+
putBuiltInModel(elasticRerankerModel.getServiceSettings().modelId(), listener);
160162
} else if (model instanceof CustomElandModel) {
161163
logger.info("Custom eland model detected, model must have been already loaded into the cluster with eland.");
162164
listener.onResponse(Boolean.TRUE);
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.elasticsearch;
9+
10+
import org.elasticsearch.ResourceNotFoundException;
11+
import org.elasticsearch.action.ActionListener;
12+
import org.elasticsearch.inference.ChunkingSettings;
13+
import org.elasticsearch.inference.Model;
14+
import org.elasticsearch.inference.TaskType;
15+
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
16+
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
17+
18+
public class ElasticRerankerModel extends ElasticsearchInternalModel {
19+
20+
public ElasticRerankerModel(
21+
String inferenceEntityId,
22+
TaskType taskType,
23+
String service,
24+
ElasticRerankerServiceSettings serviceSettings,
25+
ChunkingSettings chunkingSettings
26+
) {
27+
super(inferenceEntityId, taskType, service, serviceSettings, chunkingSettings);
28+
}
29+
30+
@Override
31+
public ElasticRerankerServiceSettings getServiceSettings() {
32+
return (ElasticRerankerServiceSettings) super.getServiceSettings();
33+
}
34+
35+
@Override
36+
public ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
37+
Model model,
38+
ActionListener<Boolean> listener
39+
) {
40+
41+
return new ActionListener<>() {
42+
@Override
43+
public void onResponse(CreateTrainedModelAssignmentAction.Response response) {
44+
listener.onResponse(Boolean.TRUE);
45+
}
46+
47+
@Override
48+
public void onFailure(Exception e) {
49+
if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
50+
listener.onFailure(
51+
new ResourceNotFoundException(
52+
"Could not start the Elastic Reranker Endpoint due to [{}]",
53+
e,
54+
e.getMessage()
55+
)
56+
);
57+
return;
58+
}
59+
listener.onFailure(e);
60+
}
61+
};
62+
}
63+
64+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.elasticsearch;
9+
10+
import org.elasticsearch.common.ValidationException;
11+
import org.elasticsearch.common.io.stream.StreamInput;
12+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
13+
import org.elasticsearch.inference.SimilarityMeasure;
14+
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
15+
16+
import java.io.IOException;
17+
import java.util.Arrays;
18+
import java.util.Map;
19+
20+
public class ElasticRerankerServiceSettings extends ElasticsearchInternalServiceSettings {
21+
22+
public static final String NAME = "elastic_reranker_service_settings";
23+
24+
25+
public ElasticRerankerServiceSettings(ElasticsearchInternalServiceSettings other) {
26+
super(other);
27+
}
28+
29+
public ElasticRerankerServiceSettings(
30+
Integer numAllocations,
31+
int numThreads,
32+
String modelId,
33+
AdaptiveAllocationsSettings adaptiveAllocationsSettings
34+
) {
35+
super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings);
36+
}
37+
38+
public ElasticRerankerServiceSettings(StreamInput in) throws IOException {
39+
super(in);
40+
}
41+
42+
/**
43+
* Parse the MultilingualE5SmallServiceSettings from map and validate the setting values.
44+
*
45+
* If required setting are missing or the values are invalid an
46+
* {@link ValidationException} is thrown.
47+
*
48+
* @param map Source map containing the config
49+
* @return The builder
50+
*/
51+
public static Builder fromRequestMap(Map<String, Object> map) {
52+
ValidationException validationException = new ValidationException();
53+
var baseSettings = ElasticsearchInternalServiceSettings.fromMap(map, validationException);
54+
55+
if (validationException.validationErrors().isEmpty() == false) {
56+
throw validationException;
57+
}
58+
59+
return baseSettings;
60+
}
61+
62+
@Override
63+
public String getWriteableName() {
64+
return ElasticRerankerServiceSettings.NAME;
65+
}
66+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
9797
MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86
9898
);
9999

100+
public static final String RERANKER_ID = ".rerank-v1";
101+
100102
public static final int EMBEDDING_MAX_BATCH_SIZE = 10;
101103
public static final String DEFAULT_ELSER_ID = ".elser-2-elasticsearch";
102104
public static final String DEFAULT_E5_ID = ".multilingual-e5-small-elasticsearch";
@@ -223,6 +225,13 @@ public void parseRequestConfig(
223225
)
224226
)
225227
);
228+
} else if (RERANKER_ID.equals(modelId)) {
229+
rerankerCase( inferenceEntityId,
230+
taskType,
231+
config,
232+
serviceSettingsMap,
233+
chunkingSettings,
234+
modelListener);
226235
} else {
227236
customElandCase(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, modelListener);
228237
}
@@ -323,6 +332,31 @@ private static CustomElandInternalServiceSettings elandServiceSettings(
323332
};
324333
}
325334

335+
private void rerankerCase(
336+
String inferenceEntityId,
337+
TaskType taskType,
338+
Map<String, Object> config,
339+
Map<String, Object> serviceSettingsMap,
340+
ChunkingSettings chunkingSettings,
341+
ActionListener<Model> modelListener
342+
) {
343+
344+
var esServiceSettingsBuilder = ElasticsearchInternalServiceSettings.fromRequestMap(serviceSettingsMap);
345+
346+
throwIfNotEmptyMap(config, name());
347+
throwIfNotEmptyMap(serviceSettingsMap, name());
348+
349+
modelListener.onResponse(
350+
new ElasticRerankerModel(
351+
inferenceEntityId,
352+
taskType,
353+
NAME,
354+
new ElasticRerankerServiceSettings(esServiceSettingsBuilder.build()),
355+
chunkingSettings
356+
)
357+
);
358+
}
359+
326360
private void e5Case(
327361
String inferenceEntityId,
328362
TaskType taskType,

0 commit comments

Comments
 (0)