Skip to content

Commit 4d9ec59

Browse files
Add Hugging Face Rerank support
1 parent ca18a86 commit 4d9ec59

25 files changed

+1384
-71
lines changed

server/src/test/java/org/elasticsearch/inference/SettingsConfigurationTestUtils.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@
2020
public class SettingsConfigurationTestUtils {
2121

2222
public static SettingsConfiguration getRandomSettingsConfigurationField() {
23-
return new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)).setDefaultValue(
24-
randomAlphaOfLength(10)
25-
)
23+
return new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK))
24+
.setDefaultValue(randomAlphaOfLength(10))
2625
.setDescription(randomAlphaOfLength(10))
2726
.setLabel(randomAlphaOfLength(10))
2827
.setRequired(randomBoolean())

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,22 @@ static String mockDenseServiceModelConfig() {
171171
""";
172172
}
173173

174+
/**
 * Returns a JSON model configuration for a rerank endpoint.
 *
 * The configuration declares task type {@code rerank}, service {@code rerank_test_service},
 * service settings with model {@code rerank_model} and API key {@code abc64}, and task
 * settings with {@code return_documents} set to {@code true}.
 *
 * @return the configuration as a JSON string
 */
static String mockRerankServiceModelConfig() {
175+
return """
176+
{
177+
"task_type": "rerank",
178+
"service": "rerank_test_service",
179+
"service_settings": {
180+
"model": "rerank_model",
181+
"api_key": "abc64"
182+
},
183+
"task_settings": {
184+
"return_documents": true
185+
}
186+
}
187+
""";
188+
}
189+
174190
static void deleteModel(String modelId) throws IOException {
175191
var request = new Request("DELETE", "_inference/" + modelId);
176192
var response = client().performRequest(request);
@@ -484,6 +500,10 @@ private String jsonBody(List<String> input, @Nullable String query) {
484500
@SuppressWarnings("unchecked")
485501
protected void assertNonEmptyInferenceResults(Map<String, Object> resultMap, int expectedNumberOfResults, TaskType taskType) {
486502
switch (taskType) {
503+
case RERANK -> {
504+
var results = (List<Map<String, Object>>) resultMap.get(TaskType.RERANK.toString());
505+
assertThat(results, hasSize(expectedNumberOfResults));
506+
}
487507
case SPARSE_EMBEDDING -> {
488508
var results = (List<Map<String, Object>>) resultMap.get(TaskType.SPARSE_EMBEDDING.toString());
489509
assertThat(results, hasSize(expectedNumberOfResults));

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
101101

102102
public void testGetServicesWithRerankTaskType() throws IOException {
103103
List<Object> services = getServices(TaskType.RERANK);
104-
assertThat(services.size(), equalTo(7));
104+
assertThat(services.size(), equalTo(8));
105105

106106
var providers = providers(services);
107107

@@ -115,7 +115,8 @@ public void testGetServicesWithRerankTaskType() throws IOException {
115115
"googlevertexai",
116116
"jinaai",
117117
"test_reranking_service",
118-
"voyageai"
118+
"voyageai",
119+
"hugging_face"
119120
).toArray()
120121
)
121122
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@
7979
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
8080
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
8181
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
82+
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankServiceSettings;
83+
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
8284
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
8385
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
8486
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;
@@ -357,6 +359,16 @@ private static void addHuggingFaceNamedWriteables(List<NamedWriteableRegistry.En
357359
namedWriteables.add(
358360
new NamedWriteableRegistry.Entry(ServiceSettings.class, HuggingFaceServiceSettings.NAME, HuggingFaceServiceSettings::new)
359361
);
362+
namedWriteables.add(
363+
new NamedWriteableRegistry.Entry(TaskSettings.class, HuggingFaceRerankTaskSettings.NAME, HuggingFaceRerankTaskSettings::new)
364+
);
365+
namedWriteables.add(
366+
new NamedWriteableRegistry.Entry(
367+
ServiceSettings.class,
368+
HuggingFaceRerankServiceSettings.NAME,
369+
HuggingFaceRerankServiceSettings::new
370+
)
371+
);
360372
}
361373

362374
private static void addGoogleAiStudioNamedWritables(List<NamedWriteableRegistry.Entry> namedWriteables) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
2020
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
2121
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
22+
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
2223

2324
import java.util.ArrayList;
2425
import java.util.Arrays;
@@ -91,7 +92,10 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[
9192
} else if (r.getEndpoints().isEmpty() == false
9293
&& r.getEndpoints().get(0).getTaskSettings() instanceof GoogleVertexAiRerankTaskSettings googleVertexAiTaskSettings) {
9394
configuredTopN = googleVertexAiTaskSettings.topN();
94-
}
95+
} else if (r.getEndpoints().isEmpty() == false
96+
&& r.getEndpoints().get(0).getTaskSettings() instanceof HuggingFaceRerankTaskSettings huggingFaceRerankTaskSettings) {
97+
configuredTopN = huggingFaceRerankTaskSettings.getTopNDocumentsOnly();
98+
}
9599
if (configuredTopN != null && configuredTopN < rankWindowSize) {
96100
l.onFailure(
97101
new IllegalArgumentException(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ public void parseRequestConfig(
5757
) {
5858
try {
5959
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
60+
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
6061

6162
ChunkingSettings chunkingSettings = null;
6263
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
@@ -65,7 +66,7 @@ public void parseRequestConfig(
6566
);
6667
}
6768

68-
var model = createModel(
69+
var modelBuilder = new HuggingFaceModelInput.Builder(
6970
inferenceEntityId,
7071
taskType,
7172
serviceSettingsMap,
@@ -75,8 +76,13 @@ public void parseRequestConfig(
7576
ConfigurationParseContext.REQUEST
7677
);
7778

79+
var model = createModel(
80+
TaskType.RERANK.equals(taskType) ? modelBuilder.withTaskSettings(taskSettingsMap).build() : modelBuilder.build()
81+
);
82+
7883
throwIfNotEmptyMap(config, name());
7984
throwIfNotEmptyMap(serviceSettingsMap, name());
85+
throwIfNotEmptyMap(taskSettingsMap, name());
8086

8187
parsedModelListener.onResponse(model);
8288
} catch (Exception e) {
@@ -92,14 +98,15 @@ public HuggingFaceModel parsePersistedConfigWithSecrets(
9298
Map<String, Object> secrets
9399
) {
94100
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
101+
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
95102
Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);
96103

97104
ChunkingSettings chunkingSettings = null;
98105
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
99106
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
100107
}
101108

102-
return createModel(
109+
var modelBuilder = new HuggingFaceModelInput.Builder(
103110
inferenceEntityId,
104111
taskType,
105112
serviceSettingsMap,
@@ -108,18 +115,23 @@ public HuggingFaceModel parsePersistedConfigWithSecrets(
108115
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
109116
ConfigurationParseContext.PERSISTENT
110117
);
118+
119+
return createModel(
120+
TaskType.RERANK.equals(taskType) ? modelBuilder.withTaskSettings(taskSettingsMap).build() : modelBuilder.build()
121+
);
111122
}
112123

113124
@Override
114125
public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
115126
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
127+
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
116128

117129
ChunkingSettings chunkingSettings = null;
118130
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
119131
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
120132
}
121133

122-
return createModel(
134+
var modelBuilder = new HuggingFaceModelInput.Builder(
123135
inferenceEntityId,
124136
taskType,
125137
serviceSettingsMap,
@@ -128,17 +140,13 @@ public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType
128140
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
129141
ConfigurationParseContext.PERSISTENT
130142
);
143+
144+
return createModel(
145+
TaskType.RERANK.equals(taskType) ? modelBuilder.withTaskSettings(taskSettingsMap).build() : modelBuilder.build()
146+
);
131147
}
132148

133-
protected abstract HuggingFaceModel createModel(
134-
String inferenceEntityId,
135-
TaskType taskType,
136-
Map<String, Object> serviceSettings,
137-
ChunkingSettings chunkingSettings,
138-
Map<String, Object> secretSettings,
139-
String failureMessage,
140-
ConfigurationParseContext context
141-
);
149+
protected abstract HuggingFaceModel createModel(HuggingFaceModelInput input);
142150

143151
@Override
144152
public void doInfer(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModel.java

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,19 @@
99

1010
import org.elasticsearch.common.settings.SecureString;
1111
import org.elasticsearch.core.Nullable;
12-
import org.elasticsearch.inference.Model;
1312
import org.elasticsearch.inference.ModelConfigurations;
1413
import org.elasticsearch.inference.ModelSecrets;
14+
import org.elasticsearch.inference.TaskSettings;
1515
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
16+
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
1617
import org.elasticsearch.xpack.inference.services.ServiceUtils;
1718
import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionVisitor;
1819
import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets;
20+
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
1921

2022
import java.util.Objects;
2123

22-
public abstract class HuggingFaceModel extends Model {
24+
public abstract class HuggingFaceModel extends RateLimitGroupingModel {
2325
private final HuggingFaceRateLimitServiceSettings rateLimitServiceSettings;
2426
private final SecureString apiKey;
2527

@@ -34,10 +36,27 @@ public HuggingFaceModel(
3436
apiKey = ServiceUtils.apiKey(apiKeySecrets);
3537
}
3638

39+
/**
 * Creates a model from an existing one together with the supplied task settings.
 *
 * Delegates to the superclass constructor with {@code model} and {@code taskSettings},
 * then reuses the source model's rate-limit service settings and API key unchanged.
 * NOTE(review): presumably the superclass replaces only the task settings of the copied
 * configuration - confirm against RateLimitGroupingModel.
 *
 * @param model        the model whose rate-limit service settings and API key are reused
 * @param taskSettings the task settings passed on to the superclass constructor
 */
protected HuggingFaceModel(HuggingFaceModel model, TaskSettings taskSettings) {
40+
super(model, taskSettings);
41+
42+
rateLimitServiceSettings = model.rateLimitServiceSettings();
43+
apiKey = model.apiKey();
44+
}
45+
3746
/**
 * Returns the rate-limit service settings held by this model.
 *
 * @return the value of the {@code rateLimitServiceSettings} field
 */
public HuggingFaceRateLimitServiceSettings rateLimitServiceSettings() {
3847
return rateLimitServiceSettings;
3948
}
4049

50+
/**
 * Computes the hash used to group this model for rate limiting.
 *
 * The hash combines the URI reported by the rate-limit service settings with the API key.
 *
 * @return {@code Objects.hash} of the service URI and the API key
 */
@Override
51+
public int rateLimitGroupingHash() {
52+
return Objects.hash(rateLimitServiceSettings.uri(), apiKey);
53+
}
54+
55+
/**
 * Returns the rate limit settings exposed by this model's rate-limit service settings.
 *
 * @return the result of {@code rateLimitServiceSettings.rateLimitSettings()}
 */
@Override
56+
public RateLimitSettings rateLimitSettings() {
57+
return rateLimitServiceSettings.rateLimitSettings();
58+
}
59+
4160
/**
 * Returns the API key held by this model.
 * NOTE(review): may be null when the model was built without secrets - confirm against
 * ServiceUtils.apiKey.
 *
 * @return the value of the {@code apiKey} field
 */
public SecureString apiKey() {
4261
return apiKey;
4362
}
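x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModelInput.java

Lines changed: 113 additions & 0 deletions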
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.huggingface;
9+
10+
import org.elasticsearch.core.Nullable;
11+
import org.elasticsearch.inference.ChunkingSettings;
12+
import org.elasticsearch.inference.TaskType;
13+
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
14+
15+
import java.util.Map;
16+
17+
public class HuggingFaceModelInput {
18+
private final String inferenceEntityId;
19+
private final TaskType taskType;
20+
private final Map<String, Object> serviceSettings;
21+
@Nullable
22+
private final Map<String, Object> taskSettings;
23+
private final ChunkingSettings chunkingSettings;
24+
@Nullable
25+
private final Map<String, Object> secretSettings;
26+
private final String failureMessage;
27+
private final ConfigurationParseContext context;
28+
29+
public HuggingFaceModelInput(Builder builder) {
30+
this.inferenceEntityId = builder.inferenceEntityId;
31+
this.taskType = builder.taskType;
32+
this.serviceSettings = builder.serviceSettings;
33+
this.taskSettings = builder.taskSettings;
34+
this.chunkingSettings = builder.chunkingSettings;
35+
this.secretSettings = builder.secretSettings;
36+
this.failureMessage = builder.failureMessage;
37+
this.context = builder.context;
38+
}
39+
40+
public String getInferenceEntityId() {
41+
return inferenceEntityId;
42+
}
43+
44+
public TaskType getTaskType() {
45+
return taskType;
46+
}
47+
48+
public Map<String, Object> getServiceSettings() {
49+
return serviceSettings;
50+
}
51+
52+
@Nullable
53+
public Map<String, Object> getTaskSettings() {
54+
return taskSettings;
55+
}
56+
57+
public ChunkingSettings getChunkingSettings() {
58+
return chunkingSettings;
59+
}
60+
61+
@Nullable
62+
public Map<String, Object> getSecretSettings() {
63+
return secretSettings;
64+
}
65+
66+
public String getFailureMessage() {
67+
return failureMessage;
68+
}
69+
70+
public ConfigurationParseContext getContext() {
71+
return context;
72+
}
73+
74+
public static class Builder {
75+
private String inferenceEntityId;
76+
private TaskType taskType;
77+
private Map<String, Object> serviceSettings;
78+
@Nullable
79+
private Map<String, Object> taskSettings;
80+
private ChunkingSettings chunkingSettings;
81+
@Nullable
82+
Map<String, Object> secretSettings;
83+
private String failureMessage;
84+
private ConfigurationParseContext context;
85+
86+
public Builder(
87+
String inferenceEntityId,
88+
TaskType taskType,
89+
Map<String, Object> serviceSettings,
90+
ChunkingSettings chunkingSettings,
91+
@Nullable Map<String, Object> secretSettings,
92+
String failureMessage,
93+
ConfigurationParseContext context
94+
) {
95+
this.inferenceEntityId = inferenceEntityId;
96+
this.taskType = taskType;
97+
this.serviceSettings = serviceSettings;
98+
this.chunkingSettings = chunkingSettings;
99+
this.secretSettings = secretSettings;
100+
this.failureMessage = failureMessage;
101+
this.context = context;
102+
}
103+
104+
public Builder withTaskSettings(Map<String, Object> taskSettings) {
105+
this.taskSettings = taskSettings;
106+
return this;
107+
}
108+
109+
public HuggingFaceModelInput build() {
110+
return new HuggingFaceModelInput(this);
111+
}
112+
}
113+
}

0 commit comments

Comments
 (0)