Skip to content

Commit aacbd60

Browse files
Extend huggingface with rerank
1 parent d1225bc commit aacbd60

File tree

19 files changed

+1392
-3
lines changed

19 files changed

+1392
-3
lines changed

server/src/test/java/org/elasticsearch/inference/SettingsConfigurationTestUtils.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
public class SettingsConfigurationTestUtils {
2121

2222
public static SettingsConfiguration getRandomSettingsConfigurationField() {
23-
return new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)).setDefaultValue(
24-
randomAlphaOfLength(10)
23+
return new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK))
24+
.setDefaultValue(randomAlphaOfLength(10)
2525
)
2626
.setDescription(randomAlphaOfLength(10))
2727
.setLabel(randomAlphaOfLength(10))

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,22 @@ static String mockDenseServiceModelConfig() {
171171
""";
172172
}
173173

174+
static String mockRerankServiceModelConfig() {
175+
return """
176+
{
177+
"task_type": "rerank",
178+
"service": "rerank_test_service",
179+
"service_settings": {
180+
"model": "rerank_model",
181+
"api_key": "abc64"
182+
},
183+
"task_settings": {
184+
"return_documents": true
185+
}
186+
}
187+
""";
188+
}
189+
174190
static void deleteModel(String modelId) throws IOException {
175191
var request = new Request("DELETE", "_inference/" + modelId);
176192
var response = client().performRequest(request);
@@ -484,6 +500,10 @@ private String jsonBody(List<String> input, @Nullable String query) {
484500
@SuppressWarnings("unchecked")
485501
protected void assertNonEmptyInferenceResults(Map<String, Object> resultMap, int expectedNumberOfResults, TaskType taskType) {
486502
switch (taskType) {
503+
case RERANK -> {
504+
var results = (List<Map<String, Object>>) resultMap.get(TaskType.RERANK.toString());
505+
assertThat(results, hasSize(expectedNumberOfResults));
506+
}
487507
case SPARSE_EMBEDDING -> {
488508
var results = (List<Map<String, Object>>) resultMap.get(TaskType.SPARSE_EMBEDDING.toString());
489509
assertThat(results, hasSize(expectedNumberOfResults));

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ public class InferenceCrudIT extends InferenceBaseRestTest {
4747

4848
@SuppressWarnings("unchecked")
4949
public void testCRUD() throws IOException {
50+
for (int i = 0; i < 6; i++) {
51+
putModel("r_model_" + i, mockRerankServiceModelConfig(), TaskType.RERANK);
52+
}
5053
for (int i = 0; i < 5; i++) {
5154
putModel("se_model_" + i, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
5255
}
@@ -55,9 +58,16 @@ public void testCRUD() throws IOException {
5558
}
5659

5760
var getAllModels = getAllModels();
58-
int numModels = 12;
61+
int numModels = 18;
5962
assertThat(getAllModels, hasSize(numModels));
6063

64+
var getRerankModels = getModels("_all", TaskType.RERANK);
65+
int numRerankModels = 6;
66+
assertThat(getRerankModels, hasSize(numRerankModels));
67+
for (var rerankModel : getRerankModels) {
68+
assertEquals("rerank", rerankModel.get("task_type"));
69+
}
70+
6171
var getSparseModels = getModels("_all", TaskType.SPARSE_EMBEDDING);
6272
int numSparseModels = 6;
6373
assertThat(getSparseModels, hasSize(numSparseModels));
@@ -94,6 +104,9 @@ public void testCRUD() throws IOException {
94104
assertNotEquals(oldApiKey, newApiKey);
95105
assertEquals(updatedEndpoint, singleModel.get(0));
96106
}
107+
for (int i = 0; i < 6; i++) {
108+
deleteModel("r_model_" + i, TaskType.RERANK);
109+
}
97110
for (int i = 0; i < 5; i++) {
98111
deleteModel("se_model_" + i, TaskType.SPARSE_EMBEDDING);
99112
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
8080
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
8181
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
82+
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankServiceSettings;
8283
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
8384
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
8485
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;
@@ -353,6 +354,13 @@ private static void addHuggingFaceNamedWriteables(List<NamedWriteableRegistry.En
353354
namedWriteables.add(
354355
new NamedWriteableRegistry.Entry(ServiceSettings.class, HuggingFaceServiceSettings.NAME, HuggingFaceServiceSettings::new)
355356
);
357+
namedWriteables.add(
358+
new NamedWriteableRegistry.Entry(
359+
ServiceSettings.class,
360+
HuggingFaceRerankServiceSettings.NAME,
361+
HuggingFaceRerankServiceSettings::new
362+
)
363+
);
356364
}
357365

358366
private static void addGoogleAiStudioNamedWritables(List<NamedWriteableRegistry.Entry> namedWriteables) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@
128128
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiService;
129129
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService;
130130
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserService;
131+
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankService;
131132
import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService;
132133
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService;
133134
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
@@ -361,6 +362,7 @@ public void loadExtensions(ExtensionLoader loader) {
361362
public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
362363
return List.of(
363364
context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get()),
365+
context -> new HuggingFaceRerankService(httpFactory.get(), serviceComponents.get()),
364366
context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get()),
365367
context -> new OpenAiService(httpFactory.get(), serviceComponents.get()),
366368
context -> new CohereService(httpFactory.get(), serviceComponents.get()),
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.huggingface;
9+
10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.action.ActionListener;
13+
import org.elasticsearch.inference.InferenceServiceResults;
14+
import org.elasticsearch.threadpool.ThreadPool;
15+
import org.elasticsearch.xpack.inference.common.Truncator;
16+
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
17+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
18+
import org.elasticsearch.xpack.inference.external.http.sender.BaseRequestManager;
19+
import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;
20+
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
21+
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
22+
import org.elasticsearch.xpack.inference.services.huggingface.request.HuggingFaceInferenceRerankRequest;
23+
24+
import java.util.List;
25+
import java.util.Objects;
26+
import java.util.function.Supplier;
27+
28+
import static org.elasticsearch.xpack.inference.common.Truncator.truncate;
29+
30+
public class HuggingFaceRequestRerankManager extends BaseRequestManager {
31+
private static final Logger logger = LogManager.getLogger(HuggingFaceRequestRerankManager.class);
32+
33+
public static HuggingFaceRequestRerankManager of(
34+
HuggingFaceModel model,
35+
ResponseHandler responseHandler,
36+
Truncator truncator,
37+
ThreadPool threadPool
38+
) {
39+
return new HuggingFaceRequestRerankManager(
40+
Objects.requireNonNull(model),
41+
Objects.requireNonNull(responseHandler),
42+
Objects.requireNonNull(truncator),
43+
Objects.requireNonNull(threadPool)
44+
);
45+
}
46+
47+
private final HuggingFaceModel model;
48+
private final ResponseHandler responseHandler;
49+
private final Truncator truncator;
50+
51+
private HuggingFaceRequestRerankManager(HuggingFaceModel model, ResponseHandler responseHandler, Truncator truncator, ThreadPool threadPool) {
52+
super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
53+
this.model = model;
54+
this.responseHandler = responseHandler;
55+
this.truncator = truncator;
56+
}
57+
58+
@Override
59+
public void execute(
60+
InferenceInputs inferenceInputs,
61+
RequestSender requestSender,
62+
Supplier<Boolean> hasRequestCompletedFunction,
63+
ActionListener<InferenceServiceResults> listener
64+
) {
65+
List<String> docsInput = QueryAndDocsInputs.of(inferenceInputs).getChunks();
66+
var truncatedInput = truncate(docsInput, model.getTokenLimit());
67+
var request = new HuggingFaceInferenceRerankRequest(truncator, truncatedInput, model);
68+
69+
execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener));
70+
}
71+
72+
record RateLimitGrouping(int accountHash) {
73+
74+
public static RateLimitGrouping of(HuggingFaceModel model) {
75+
return new RateLimitGrouping(new HuggingFaceAccount(model.rateLimitServiceSettings().uri(), model.apiKey()).hashCode());
76+
}
77+
}
78+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,16 @@
1010
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1111
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
1212
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
13+
import org.elasticsearch.xpack.inference.external.response.huggingface.HuggingFaceRerankResponseEntity;
1314
import org.elasticsearch.xpack.inference.services.ServiceComponents;
1415
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceRequestManager;
16+
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceRequestRerankManager;
1517
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceResponseHandler;
1618
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
1719
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
1820
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceElserResponseEntity;
1921
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceEmbeddingsResponseEntity;
22+
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModel;
2023

2124
import java.util.Objects;
2225

@@ -34,6 +37,23 @@ public HuggingFaceActionCreator(Sender sender, ServiceComponents serviceComponen
3437
this.serviceComponents = Objects.requireNonNull(serviceComponents);
3538
}
3639

40+
@Override
41+
public ExecutableAction create(HuggingFaceRerankModel model) {
42+
var responseHandler = new HuggingFaceResponseHandler("hugging face rerank", HuggingFaceRerankResponseEntity::fromResponse);
43+
var requestCreator = HuggingFaceRequestRerankManager.of(
44+
model,
45+
responseHandler,
46+
serviceComponents.truncator(),
47+
serviceComponents.threadPool()
48+
);
49+
var errorMessage = format(
50+
"Failed to send Hugging Face %s request from inference entity id [%s]",
51+
"rerank",
52+
model.getInferenceEntityId()
53+
);
54+
return new SenderExecutableAction(sender, requestCreator, errorMessage);
55+
}
56+
3757
@Override
3858
public ExecutableAction create(HuggingFaceEmbeddingsModel model) {
3959
var responseHandler = new HuggingFaceResponseHandler(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionVisitor.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@
1010
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1111
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
1212
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
13+
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModel;
1314

1415
public interface HuggingFaceActionVisitor {
16+
ExecutableAction create(HuggingFaceRerankModel model);
17+
1518
ExecutableAction create(HuggingFaceEmbeddingsModel model);
1619

1720
ExecutableAction create(HuggingFaceElserModel model);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.huggingface.request;
9+
10+
import org.apache.http.HttpHeaders;
11+
import org.apache.http.client.methods.HttpPost;
12+
import org.apache.http.entity.ByteArrayEntity;
13+
import org.elasticsearch.common.Strings;
14+
import org.elasticsearch.xcontent.XContentType;
15+
import org.elasticsearch.xpack.inference.common.Truncator;
16+
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
17+
import org.elasticsearch.xpack.inference.external.request.Request;
18+
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceAccount;
19+
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel;
20+
21+
import java.net.URI;
22+
import java.nio.charset.StandardCharsets;
23+
import java.util.Objects;
24+
25+
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
26+
27+
public class HuggingFaceInferenceRerankRequest implements Request {
28+
29+
private final Truncator truncator;
30+
private final HuggingFaceAccount account;
31+
private final Truncator.TruncationResult truncationResult;
32+
private final HuggingFaceModel model;
33+
34+
public HuggingFaceInferenceRerankRequest(Truncator truncator, Truncator.TruncationResult input, HuggingFaceModel model) {
35+
this.truncator = Objects.requireNonNull(truncator);
36+
this.account = HuggingFaceAccount.of(model);
37+
this.truncationResult = Objects.requireNonNull(input);
38+
this.model = Objects.requireNonNull(model);
39+
}
40+
41+
public HttpRequest createHttpRequest() {
42+
HttpPost httpPost = new HttpPost(account.uri());
43+
44+
ByteArrayEntity byteEntity = new ByteArrayEntity(
45+
Strings.toString(new HuggingFaceInferenceRerankRequestEntity(truncationResult.input(), "default")).getBytes(StandardCharsets.UTF_8)
46+
);
47+
httpPost.setEntity(byteEntity);
48+
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters());
49+
httpPost.setHeader(createAuthBearerHeader(account.apiKey()));
50+
51+
return new HttpRequest(httpPost, getInferenceEntityId());
52+
}
53+
54+
public URI getURI() {
55+
return account.uri();
56+
}
57+
58+
@Override
59+
public String getInferenceEntityId() {
60+
return model.getInferenceEntityId();
61+
}
62+
63+
@Override
64+
public Request truncate() {
65+
var truncateResult = truncator.truncate(truncationResult.input());
66+
67+
return new HuggingFaceInferenceRerankRequest(truncator, truncateResult, model);
68+
}
69+
70+
@Override
71+
public boolean[] getTruncationInfo() {
72+
return truncationResult.truncated().clone();
73+
}
74+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.huggingface.request;
9+
10+
import org.elasticsearch.xcontent.ToXContentObject;
11+
import org.elasticsearch.xcontent.XContentBuilder;
12+
13+
import java.io.IOException;
14+
import java.util.List;
15+
import java.util.Objects;
16+
17+
public record HuggingFaceInferenceRerankRequestEntity(List<String> documents, String query) implements ToXContentObject {
18+
19+
private static final String DOCUMENTS_FIELD = "documents";
20+
private static final String QUERY_FIELD = "query";
21+
22+
public HuggingFaceInferenceRerankRequestEntity {
23+
Objects.requireNonNull(documents);
24+
Objects.requireNonNull(query);
25+
}
26+
27+
@Override
28+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
29+
builder.startObject();
30+
31+
builder.field(DOCUMENTS_FIELD, documents);
32+
builder.field(QUERY_FIELD, query);
33+
34+
builder.endObject();
35+
return builder;
36+
}
37+
}

0 commit comments

Comments
 (0)