Skip to content

Commit c7cf850

Browse files
Evgenii-Kazannikelasticsearchmachine
andauthored
Add Hugging Face Rerank support (#127966)
* Add Hugging Face Rerank support * Address comments * Add transport version * Add transport version * Add to inference service and crud IT rerank tests * Refactor slightly / error message * correct 'testGetConfiguration' test case * apply suggestions * fix tests * apply suggestions * [CI] Auto commit changes from spotless * add changelog information --------- Co-authored-by: elasticsearchmachine <[email protected]>
1 parent f6e4a26 commit c7cf850

File tree

31 files changed

+1485
-85
lines changed

31 files changed

+1485
-85
lines changed

docs/changelog/127966.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 127966
2+
summary: "[ML] Add Rerank support to the Inference Plugin"
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ static TransportVersion def(int id) {
179179
public static final TransportVersion V_8_19_FIELD_CAPS_ADD_CLUSTER_ALIAS = def(8_841_0_32);
180180
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME_8_19 = def(8_841_0_34);
181181
public static final TransportVersion RERANKER_FAILURES_ALLOWED_8_19 = def(8_841_0_35);
182+
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED_8_19 = def(8_841_0_36);
182183
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
183184
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
184185
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
@@ -261,7 +262,7 @@ static TransportVersion def(int id) {
261262
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME = def(9_077_0_00);
262263
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_CHAT_COMPLETION_ADDED = def(9_078_0_00);
263264
public static final TransportVersion NODES_STATS_SUPPORTS_MULTI_PROJECT = def(9_079_0_00);
264-
265+
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED = def(9_080_0_00);
265266
/*
266267
* STOP! READ THIS FIRST! No, really,
267268
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _

server/src/test/java/org/elasticsearch/inference/SettingsConfigurationTestUtils.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@
2020
public class SettingsConfigurationTestUtils {
2121

2222
public static SettingsConfiguration getRandomSettingsConfigurationField() {
23-
return new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)).setDefaultValue(
24-
randomAlphaOfLength(10)
25-
)
23+
return new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK))
24+
.setDefaultValue(randomAlphaOfLength(10))
2625
.setDescription(randomAlphaOfLength(10))
2726
.setLabel(randomAlphaOfLength(10))
2827
.setRequired(randomBoolean())

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/RankedDocsResultsTests.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,13 @@ private List<RankedDocsResults.RankedDoc> rankedDocsNullStringToEmpty(List<Ranke
8282
protected RankedDocsResults doParseInstance(XContentParser parser) throws IOException {
8383
return RankedDocsResults.createParser(true).apply(parser, null);
8484
}
85+
86+
public record RerankExpectation(Map<String, Object> rankedDocFields) {}
87+
88+
public static Map<String, Object> buildExpectationRerank(List<RerankExpectation> rerank) {
89+
return Map.of(
90+
RankedDocsResults.RERANK,
91+
rerank.stream().map(rerankExpectation -> Map.of(RankedDocsResults.RankedDoc.NAME, rerankExpectation.rankedDocFields)).toList()
92+
);
93+
}
8594
}

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,20 @@ static String mockDenseServiceModelConfig() {
171171
""";
172172
}
173173

174+
static String mockRerankServiceModelConfig() {
175+
return """
176+
{
177+
"service": "test_reranking_service",
178+
"service_settings": {
179+
"model_id": "my_model",
180+
"api_key": "abc64"
181+
},
182+
"task_settings": {
183+
}
184+
}
185+
""";
186+
}
187+
174188
static void deleteModel(String modelId) throws IOException {
175189
var request = new Request("DELETE", "_inference/" + modelId);
176190
var response = client().performRequest(request);
@@ -484,6 +498,10 @@ private String jsonBody(List<String> input, @Nullable String query) {
484498
@SuppressWarnings("unchecked")
485499
protected void assertNonEmptyInferenceResults(Map<String, Object> resultMap, int expectedNumberOfResults, TaskType taskType) {
486500
switch (taskType) {
501+
case RERANK -> {
502+
var results = (List<Map<String, Object>>) resultMap.get(TaskType.RERANK.toString());
503+
assertThat(results, hasSize(expectedNumberOfResults));
504+
}
487505
case SPARSE_EMBEDDING -> {
488506
var results = (List<Map<String, Object>>) resultMap.get(TaskType.SPARSE_EMBEDDING.toString());
489507
assertThat(results, hasSize(expectedNumberOfResults));

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,12 @@ public void testCRUD() throws IOException {
5353
for (int i = 0; i < 4; i++) {
5454
putModel("te_model_" + i, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING);
5555
}
56+
for (int i = 0; i < 3; i++) {
57+
putModel("re-model-" + i, mockRerankServiceModelConfig(), TaskType.RERANK);
58+
}
5659

5760
var getAllModels = getAllModels();
58-
int numModels = 12;
61+
int numModels = 15;
5962
assertThat(getAllModels, hasSize(numModels));
6063

6164
var getSparseModels = getModels("_all", TaskType.SPARSE_EMBEDDING);
@@ -71,6 +74,13 @@ public void testCRUD() throws IOException {
7174
for (var denseModel : getDenseModels) {
7275
assertEquals("text_embedding", denseModel.get("task_type"));
7376
}
77+
78+
var getRerankModels = getModels("_all", TaskType.RERANK);
79+
int numRerankModels = 4;
80+
assertThat(getRerankModels, hasSize(numRerankModels));
81+
for (var denseModel : getRerankModels) {
82+
assertEquals("rerank", denseModel.get("task_type"));
83+
}
7484
String oldApiKey;
7585
{
7686
var singleModel = getModels("se_model_1", TaskType.SPARSE_EMBEDDING);
@@ -100,6 +110,9 @@ public void testCRUD() throws IOException {
100110
for (int i = 0; i < 4; i++) {
101111
deleteModel("te_model_" + i, TaskType.TEXT_EMBEDDING);
102112
}
113+
for (int i = 0; i < 3; i++) {
114+
deleteModel("re-model-" + i, TaskType.RERANK);
115+
}
103116
}
104117

105118
public void testGetModelWithWrongTaskType() throws IOException {

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
101101

102102
public void testGetServicesWithRerankTaskType() throws IOException {
103103
List<Object> services = getServices(TaskType.RERANK);
104-
assertThat(services.size(), equalTo(7));
104+
assertThat(services.size(), equalTo(8));
105105

106106
var providers = providers(services);
107107

@@ -115,7 +115,8 @@ public void testGetServicesWithRerankTaskType() throws IOException {
115115
"googlevertexai",
116116
"jinaai",
117117
"test_reranking_service",
118-
"voyageai"
118+
"voyageai",
119+
"hugging_face"
119120
).toArray()
120121
)
121122
);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference;
9+
10+
import org.elasticsearch.inference.TaskType;
11+
12+
import java.io.IOException;
13+
import java.util.List;
14+
import java.util.Map;
15+
16+
public class MockRerankInferenceServiceIT extends InferenceBaseRestTest {
17+
18+
@SuppressWarnings("unchecked")
19+
public void testMockService() throws IOException {
20+
String inferenceEntityId = "test-mock";
21+
var putModel = putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK);
22+
var model = getModels(inferenceEntityId, TaskType.RERANK).get(0);
23+
24+
for (var modelMap : List.of(putModel, model)) {
25+
assertEquals(inferenceEntityId, modelMap.get("inference_id"));
26+
assertEquals(TaskType.RERANK, TaskType.fromString((String) modelMap.get("task_type")));
27+
assertEquals("test_reranking_service", modelMap.get("service"));
28+
}
29+
30+
List<String> input = List.of(randomAlphaOfLength(10));
31+
var inference = infer(inferenceEntityId, input);
32+
assertNonEmptyInferenceResults(inference, 1, TaskType.RERANK);
33+
assertEquals(inference, infer(inferenceEntityId, input));
34+
assertNotEquals(inference, infer(inferenceEntityId, randomValueOtherThan(input, () -> List.of(randomAlphaOfLength(10)))));
35+
}
36+
37+
public void testMockServiceWithMultipleInputs() throws IOException {
38+
String inferenceEntityId = "test-mock-with-multi-inputs";
39+
putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK);
40+
var queryParams = Map.of("timeout", "120s");
41+
42+
var inference = infer(
43+
inferenceEntityId,
44+
TaskType.RERANK,
45+
List.of(randomAlphaOfLength(5), randomAlphaOfLength(10)),
46+
"What if?",
47+
queryParams
48+
);
49+
50+
assertNonEmptyInferenceResults(inference, 2, TaskType.RERANK);
51+
}
52+
53+
@SuppressWarnings("unchecked")
54+
public void testMockService_DoesNotReturnSecretsInGetResponse() throws IOException {
55+
String inferenceEntityId = "test-mock";
56+
var putModel = putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK);
57+
var model = getModels(inferenceEntityId, TaskType.RERANK).get(0);
58+
59+
var serviceSettings = (Map<String, Object>) model.get("service_settings");
60+
assertNull(serviceSettings.get("api_key"));
61+
assertNotNull(serviceSettings.get("model_id"));
62+
63+
var putServiceSettings = (Map<String, Object>) putModel.get("service_settings");
64+
assertNull(putServiceSettings.get("api_key"));
65+
assertNotNull(putServiceSettings.get("model_id"));
66+
}
67+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@
8080
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
8181
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionServiceSettings;
8282
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
83+
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankServiceSettings;
84+
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
8385
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
8486
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
8587
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;
@@ -365,6 +367,16 @@ private static void addHuggingFaceNamedWriteables(List<NamedWriteableRegistry.En
365367
HuggingFaceChatCompletionServiceSettings::new
366368
)
367369
);
370+
namedWriteables.add(
371+
new NamedWriteableRegistry.Entry(TaskSettings.class, HuggingFaceRerankTaskSettings.NAME, HuggingFaceRerankTaskSettings::new)
372+
);
373+
namedWriteables.add(
374+
new NamedWriteableRegistry.Entry(
375+
ServiceSettings.class,
376+
HuggingFaceRerankServiceSettings.NAME,
377+
HuggingFaceRerankServiceSettings::new
378+
)
379+
);
368380
}
369381

370382
private static void addGoogleAiStudioNamedWritables(List<NamedWriteableRegistry.Entry> namedWriteables) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
2020
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
2121
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
22+
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
2223

2324
import java.util.ArrayList;
2425
import java.util.Arrays;
@@ -94,7 +95,10 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[
9495
} else if (r.getEndpoints().isEmpty() == false
9596
&& r.getEndpoints().get(0).getTaskSettings() instanceof GoogleVertexAiRerankTaskSettings googleVertexAiTaskSettings) {
9697
configuredTopN = googleVertexAiTaskSettings.topN();
97-
}
98+
} else if (r.getEndpoints().isEmpty() == false
99+
&& r.getEndpoints().get(0).getTaskSettings() instanceof HuggingFaceRerankTaskSettings huggingFaceRerankTaskSettings) {
100+
configuredTopN = huggingFaceRerankTaskSettings.getTopNDocumentsOnly();
101+
}
98102
if (configuredTopN != null && configuredTopN < rankWindowSize) {
99103
l.onFailure(
100104
new IllegalArgumentException(

0 commit comments

Comments
 (0)