Skip to content

Commit 3c5f0eb

Browse files
Add Azure AI Rerank support
1 parent 11ca4f6 commit 3c5f0eb

18 files changed

+1380
-0
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@
5050
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionTaskSettings;
5151
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettings;
5252
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsTaskSettings;
53+
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankServiceSettings;
54+
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankTaskSettings;
5355
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings;
5456
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionServiceSettings;
5557
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionTaskSettings;
@@ -311,6 +313,17 @@ private static void addAzureAiStudioNamedWriteables(List<NamedWriteableRegistry.
311313
AzureAiStudioChatCompletionTaskSettings::new
312314
)
313315
);
316+
317+
namedWriteables.add(
318+
new NamedWriteableRegistry.Entry(
319+
ServiceSettings.class,
320+
AzureAiStudioRerankServiceSettings.NAME,
321+
AzureAiStudioRerankServiceSettings::new
322+
)
323+
);
324+
namedWriteables.add(
325+
new NamedWriteableRegistry.Entry(TaskSettings.class, AzureAiStudioRerankTaskSettings.NAME, AzureAiStudioRerankTaskSettings::new)
326+
);
314327
}
315328

316329
private static void addAzureOpenAiNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioConstants.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
public class AzureAiStudioConstants {
1111
public static final String EMBEDDINGS_URI_PATH = "/v1/embeddings";
1212
public static final String COMPLETIONS_URI_PATH = "/v1/chat/completions";
13+
public static final String RERANK_URI_PATH = "/v1/rerank";
1314

1415
// common service settings fields
1516
public static final String TARGET_FIELD = "target";
@@ -35,5 +36,9 @@ public class AzureAiStudioConstants {
3536
public static final Double MIN_TEMPERATURE_TOP_P = 0.0;
3637
public static final Double MAX_TEMPERATURE_TOP_P = 2.0;
3738

39+
// rerank task settings fields
40+
public static final String RETURN_DOCUMENTS_FIELD = "return_documents";
41+
public static final String TOP_N_FIELD = "top_n";
42+
3843
private AzureAiStudioConstants() {}
3944
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.azureaistudio;
9+
10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.action.ActionListener;
13+
import org.elasticsearch.inference.InferenceServiceResults;
14+
import org.elasticsearch.threadpool.ThreadPool;
15+
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
16+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
17+
import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;
18+
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
19+
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
20+
import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity;
21+
import org.elasticsearch.xpack.inference.services.azureaistudio.request.AzureAiStudioRerankRequest;
22+
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;
23+
import org.elasticsearch.xpack.inference.services.azureaistudio.response.AzureAiStudioRerankResponseEntity;
24+
import org.elasticsearch.xpack.inference.services.azureopenai.response.AzureMistralOpenAiExternalResponseHandler;
25+
26+
import java.util.function.Supplier;
27+
28+
public class AzureAiStudioRerankRequestManager extends AzureAiStudioRequestManager {
29+
private static final Logger logger = LogManager.getLogger(AzureAiStudioRerankRequestManager.class);
30+
31+
private static final ResponseHandler HANDLER = createRerankHandler();
32+
33+
private final AzureAiStudioRerankModel model;
34+
35+
public AzureAiStudioRerankRequestManager(AzureAiStudioRerankModel model, ThreadPool threadPool) {
36+
super(threadPool, model);
37+
this.model = model;
38+
}
39+
40+
@Override
41+
public void execute(
42+
InferenceInputs inferenceInputs,
43+
RequestSender requestSender,
44+
Supplier<Boolean> hasRequestRerankFunction,
45+
ActionListener<InferenceServiceResults> listener
46+
) {
47+
var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
48+
AzureAiStudioRerankRequest request = new AzureAiStudioRerankRequest(
49+
model,
50+
rerankInput.getQuery(),
51+
rerankInput.getChunks(),
52+
rerankInput.getReturnDocuments(),
53+
rerankInput.getTopN()
54+
);
55+
56+
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestRerankFunction, listener));
57+
}
58+
59+
private static ResponseHandler createRerankHandler() {
60+
return new AzureMistralOpenAiExternalResponseHandler(
61+
"azure ai studio rerank",
62+
new AzureAiStudioRerankResponseEntity(),
63+
ErrorMessageResponseEntity::fromResponse,
64+
true
65+
);
66+
}
67+
68+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionTaskSettings;
4545
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModel;
4646
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettings;
47+
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;
4748
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
4849
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
4950

@@ -307,6 +308,24 @@ private static AzureAiStudioModel createModel(
307308
return completionModel;
308309
}
309310

311+
if (taskType == TaskType.RERANK) {
312+
var rerankModel = new AzureAiStudioRerankModel(
313+
inferenceEntityId,
314+
taskType,
315+
NAME,
316+
serviceSettings,
317+
taskSettings,
318+
secretSettings,
319+
context
320+
);
321+
checkProviderAndEndpointTypeForTask(
322+
TaskType.RERANK,
323+
rerankModel.getServiceSettings().provider(),
324+
rerankModel.getServiceSettings().endpointType()
325+
);
326+
return rerankModel;
327+
}
328+
310329
throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
311330
}
312331

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionCreator.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
import org.elasticsearch.xpack.inference.services.ServiceComponents;
1414
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioChatCompletionRequestManager;
1515
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEmbeddingsRequestManager;
16+
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioRerankRequestManager;
1617
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModel;
1718
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModel;
19+
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;
1820

1921
import java.util.Map;
2022
import java.util.Objects;
@@ -49,4 +51,12 @@ public ExecutableAction create(AzureAiStudioEmbeddingsModel embeddingsModel, Map
4951
var errorMessage = constructFailedToSendRequestMessage("Azure AI Studio embeddings");
5052
return new SenderExecutableAction(sender, requestManager, errorMessage);
5153
}
54+
55+
@Override
56+
public ExecutableAction create(AzureAiStudioRerankModel rerankModel, Map<String, Object> taskSettings) {
57+
var overriddenModel = AzureAiStudioRerankModel.of(rerankModel, taskSettings);
58+
var requestManager = new AzureAiStudioRerankRequestManager(overriddenModel, serviceComponents.threadPool());
59+
var errorMessage = constructFailedToSendRequestMessage("Azure AI Studio rerank");
60+
return new SenderExecutableAction(sender, requestManager, errorMessage);
61+
}
5262
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionVisitor.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@
1010
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1111
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModel;
1212
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModel;
13+
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;
1314

1415
import java.util.Map;
1516

1617
public interface AzureAiStudioActionVisitor {
1718
ExecutableAction create(AzureAiStudioEmbeddingsModel embeddingsModel, Map<String, Object> taskSettings);
1819

1920
ExecutableAction create(AzureAiStudioChatCompletionModel completionModel, Map<String, Object> taskSettings);
21+
22+
ExecutableAction create(AzureAiStudioRerankModel rerankModel, Map<String, Object> taskSettings);
2023
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.azureaistudio.request;
9+
10+
import org.apache.http.HttpHeaders;
11+
import org.apache.http.client.methods.HttpPost;
12+
import org.apache.http.entity.ByteArrayEntity;
13+
import org.elasticsearch.common.Strings;
14+
import org.elasticsearch.core.Nullable;
15+
import org.elasticsearch.xcontent.XContentType;
16+
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
17+
import org.elasticsearch.xpack.inference.external.request.Request;
18+
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;
19+
20+
import java.nio.charset.StandardCharsets;
21+
import java.util.List;
22+
import java.util.Objects;
23+
24+
public class AzureAiStudioRerankRequest extends AzureAiStudioRequest {
25+
private final String query;
26+
private final List<String> input;
27+
private final Boolean returnDocuments;
28+
private final Integer topN;
29+
private final AzureAiStudioRerankModel rerankModel;
30+
31+
public AzureAiStudioRerankRequest(
32+
AzureAiStudioRerankModel model,
33+
String query,
34+
List<String> input,
35+
@Nullable Boolean returnDocuments,
36+
@Nullable Integer topN
37+
) {
38+
super(model);
39+
this.rerankModel = Objects.requireNonNull(model);
40+
this.query = query;
41+
this.input = Objects.requireNonNull(input);
42+
this.returnDocuments = returnDocuments;
43+
this.topN = topN;
44+
}
45+
46+
@Override
47+
public HttpRequest createHttpRequest() {
48+
HttpPost httpPost = new HttpPost(this.uri);
49+
50+
ByteArrayEntity byteEntity = new ByteArrayEntity(Strings.toString(createRequestEntity()).getBytes(StandardCharsets.UTF_8));
51+
httpPost.setEntity(byteEntity);
52+
53+
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
54+
setAuthHeader(httpPost, rerankModel);
55+
56+
return new HttpRequest(httpPost, getInferenceEntityId());
57+
}
58+
59+
@Override
60+
public Request truncate() {
61+
// no truncation
62+
return this;
63+
}
64+
65+
@Override
66+
public boolean[] getTruncationInfo() {
67+
// no truncation
68+
return null;
69+
}
70+
71+
public Integer getTopN() {
72+
return topN != null ? topN : rerankModel.getTaskSettings().topN();
73+
}
74+
75+
private AzureAiStudioRerankRequestEntity createRequestEntity() {
76+
var taskSettings = rerankModel.getTaskSettings();
77+
return new AzureAiStudioRerankRequestEntity(query, input, returnDocuments, topN, taskSettings);
78+
}
79+
80+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.azureaistudio.request;
9+
10+
import org.elasticsearch.core.Nullable;
11+
import org.elasticsearch.xcontent.ToXContentObject;
12+
import org.elasticsearch.xcontent.XContentBuilder;
13+
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankTaskSettings;
14+
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
15+
16+
import java.io.IOException;
17+
import java.util.List;
18+
import java.util.Objects;
19+
20+
public record AzureAiStudioRerankRequestEntity(
21+
String query,
22+
List<String> input,
23+
@Nullable Boolean returnDocuments,
24+
@Nullable Integer topN,
25+
AzureAiStudioRerankTaskSettings taskSettings
26+
) implements ToXContentObject {
27+
28+
private static final String RETURN_TEXT = "return_text";
29+
private static final String DOCUMENTS_FIELD = "texts";
30+
private static final String QUERY = "query";
31+
32+
public AzureAiStudioRerankRequestEntity {
33+
Objects.requireNonNull(query);
34+
Objects.requireNonNull(input);
35+
Objects.requireNonNull(taskSettings);
36+
}
37+
38+
@Override
39+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
40+
builder.startObject();
41+
42+
builder.field(DOCUMENTS_FIELD, input);
43+
builder.field(QUERY, query);
44+
45+
if (returnDocuments != null) {
46+
builder.field(RETURN_TEXT, returnDocuments);
47+
} else if (taskSettings.returnDocuments() != null) {
48+
builder.field(RETURN_TEXT, taskSettings.returnDocuments());
49+
}
50+
51+
if (topN != null) {
52+
builder.field(HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY, topN);
53+
} else if (taskSettings.topN() != null) {
54+
builder.field(HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY, taskSettings.topN());
55+
}
56+
builder.endObject();
57+
return builder;
58+
}
59+
}

0 commit comments

Comments
 (0)