Skip to content

Commit d06b0c8

Browse files
Add Azure AI Rerank support (#129848)
* Add Azure AI Rerank support * address comments * address comments * refactor azure ai studio service * update rerank task settings test * add provider for rerank
1 parent f9eee6c commit d06b0c8

26 files changed

+2147
-81
lines changed

docs/changelog/129848.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 129848
2+
summary: "[ML] Add Azure AI Rerank support to the Inference Plugin"
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,7 @@ static TransportVersion def(int id) {
341341
public static final TransportVersion LOOKUP_JOIN_CCS = def(9_120_0_00);
342342
public static final TransportVersion NODE_USAGE_STATS_FOR_THREAD_POOLS_IN_CLUSTER_INFO = def(9_121_0_00);
343343
public static final TransportVersion ESQL_CATEGORIZE_OPTIONS = def(9_122_0_00);
344+
public static final TransportVersion ML_INFERENCE_AZURE_AI_STUDIO_RERANK_ADDED = def(9_123_0_00);
344345

345346
/*
346347
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
111111
containsInAnyOrder(
112112
List.of(
113113
"alibabacloud-ai-search",
114+
"azureaistudio",
114115
"cohere",
115116
"elasticsearch",
116117
"googlevertexai",

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@
5050
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionTaskSettings;
5151
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettings;
5252
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsTaskSettings;
53+
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankServiceSettings;
54+
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankTaskSettings;
5355
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings;
5456
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionServiceSettings;
5557
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionTaskSettings;
@@ -306,6 +308,17 @@ private static void addAzureAiStudioNamedWriteables(List<NamedWriteableRegistry.
306308
AzureAiStudioChatCompletionTaskSettings::new
307309
)
308310
);
311+
312+
namedWriteables.add(
313+
new NamedWriteableRegistry.Entry(
314+
ServiceSettings.class,
315+
AzureAiStudioRerankServiceSettings.NAME,
316+
AzureAiStudioRerankServiceSettings::new
317+
)
318+
);
319+
namedWriteables.add(
320+
new NamedWriteableRegistry.Entry(TaskSettings.class, AzureAiStudioRerankTaskSettings.NAME, AzureAiStudioRerankTaskSettings::new)
321+
);
309322
}
310323

311324
private static void addAzureOpenAiNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioConstants.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
public class AzureAiStudioConstants {
1111
public static final String EMBEDDINGS_URI_PATH = "/v1/embeddings";
1212
public static final String COMPLETIONS_URI_PATH = "/v1/chat/completions";
13+
public static final String RERANK_URI_PATH = "/v1/rerank";
1314

1415
// common service settings fields
1516
public static final String TARGET_FIELD = "target";
@@ -22,6 +23,10 @@ public class AzureAiStudioConstants {
2223
public static final String DIMENSIONS_FIELD = "dimensions";
2324
public static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user";
2425

26+
// rerank request fields
27+
public static final String DOCUMENTS_FIELD = "documents";
28+
public static final String QUERY_FIELD = "query";
29+
2530
// embeddings task settings fields
2631
public static final String USER_FIELD = "user";
2732

@@ -35,5 +40,9 @@ public class AzureAiStudioConstants {
3540
public static final Double MIN_TEMPERATURE_TOP_P = 0.0;
3641
public static final Double MAX_TEMPERATURE_TOP_P = 2.0;
3742

43+
// rerank task settings fields
44+
public static final String RETURN_DOCUMENTS_FIELD = "return_documents";
45+
public static final String TOP_N_FIELD = "top_n";
46+
3847
private AzureAiStudioConstants() {}
3948
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioProviderCapabilities.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ public final class AzureAiStudioProviderCapabilities {
2222
// these providers have chat completion inference (all providers at the moment)
2323
public static final List<AzureAiStudioProvider> chatCompletionProviders = List.of(AzureAiStudioProvider.values());
2424

25+
// these providers have rerank inference
26+
public static final List<AzureAiStudioProvider> rerankProviders = List.of(AzureAiStudioProvider.COHERE);
27+
2528
// these providers allow token ("pay as you go") embeddings endpoints
2629
public static final List<AzureAiStudioProvider> tokenEmbeddingsProviders = List.of(
2730
AzureAiStudioProvider.OPENAI,
@@ -31,6 +34,9 @@ public final class AzureAiStudioProviderCapabilities {
3134
// these providers allow realtime embeddings endpoints (none at the moment)
3235
public static final List<AzureAiStudioProvider> realtimeEmbeddingsProviders = List.of();
3336

37+
// these providers allow realtime rerank endpoints (none at the moment)
38+
public static final List<AzureAiStudioProvider> realtimeRerankProviders = List.of();
39+
3440
// these providers allow token ("pay as you go") chat completion endpoints
3541
public static final List<AzureAiStudioProvider> tokenChatCompletionProviders = List.of(
3642
AzureAiStudioProvider.OPENAI,
@@ -54,6 +60,9 @@ public static boolean providerAllowsTaskType(AzureAiStudioProvider provider, Tas
5460
case TEXT_EMBEDDING -> {
5561
return embeddingProviders.contains(provider);
5662
}
63+
case RERANK -> {
64+
return rerankProviders.contains(provider);
65+
}
5766
default -> {
5867
return false;
5968
}
@@ -76,6 +85,11 @@ public static boolean providerAllowsEndpointTypeForTask(
7685
? tokenEmbeddingsProviders.contains(provider)
7786
: realtimeEmbeddingsProviders.contains(provider);
7887
}
88+
case RERANK -> {
89+
return (endpointType == AzureAiStudioEndpointType.TOKEN)
90+
? rerankProviders.contains(provider)
91+
: realtimeRerankProviders.contains(provider);
92+
}
7993
default -> {
8094
return false;
8195
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.azureaistudio;
9+
10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.action.ActionListener;
13+
import org.elasticsearch.inference.InferenceServiceResults;
14+
import org.elasticsearch.threadpool.ThreadPool;
15+
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
16+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
17+
import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;
18+
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
19+
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
20+
import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity;
21+
import org.elasticsearch.xpack.inference.services.azureaistudio.request.AzureAiStudioRerankRequest;
22+
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;
23+
import org.elasticsearch.xpack.inference.services.azureaistudio.response.AzureAiStudioRerankResponseEntity;
24+
import org.elasticsearch.xpack.inference.services.azureopenai.response.AzureMistralOpenAiExternalResponseHandler;
25+
26+
import java.util.function.Supplier;
27+
28+
public class AzureAiStudioRerankRequestManager extends AzureAiStudioRequestManager {
29+
private static final Logger logger = LogManager.getLogger(AzureAiStudioRerankRequestManager.class);
30+
31+
private static final ResponseHandler HANDLER = createRerankHandler();
32+
33+
private final AzureAiStudioRerankModel model;
34+
35+
public AzureAiStudioRerankRequestManager(AzureAiStudioRerankModel model, ThreadPool threadPool) {
36+
super(threadPool, model);
37+
this.model = model;
38+
}
39+
40+
@Override
41+
public void execute(
42+
InferenceInputs inferenceInputs,
43+
RequestSender requestSender,
44+
Supplier<Boolean> hasRequestRerankFunction,
45+
ActionListener<InferenceServiceResults> listener
46+
) {
47+
var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
48+
AzureAiStudioRerankRequest request = new AzureAiStudioRerankRequest(
49+
model,
50+
rerankInput.getQuery(),
51+
rerankInput.getChunks(),
52+
rerankInput.getReturnDocuments(),
53+
rerankInput.getTopN()
54+
);
55+
56+
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestRerankFunction, listener));
57+
}
58+
59+
private static ResponseHandler createRerankHandler() {
60+
// This currently covers response handling for Azure AI Studio
61+
return new AzureMistralOpenAiExternalResponseHandler(
62+
"azure ai studio rerank",
63+
new AzureAiStudioRerankResponseEntity(),
64+
ErrorMessageResponseEntity::fromResponse,
65+
true
66+
);
67+
}
68+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionTaskSettings;
4545
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModel;
4646
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettings;
47+
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;
4748
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
4849
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
4950

@@ -71,10 +72,10 @@
7172

7273
public class AzureAiStudioService extends SenderService {
7374

74-
static final String NAME = "azureaistudio";
75+
public static final String NAME = "azureaistudio";
7576

7677
private static final String SERVICE_NAME = "Azure AI Studio";
77-
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);
78+
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION, TaskType.RERANK);
7879

7980
private static final EnumSet<InputType> VALID_INPUT_TYPE_VALUES = EnumSet.of(
8081
InputType.INGEST,
@@ -270,8 +271,9 @@ private static AzureAiStudioModel createModel(
270271
ConfigurationParseContext context
271272
) {
272273

273-
if (taskType == TaskType.TEXT_EMBEDDING) {
274-
var embeddingsModel = new AzureAiStudioEmbeddingsModel(
274+
AzureAiStudioModel model;
275+
switch (taskType) {
276+
case TEXT_EMBEDDING -> model = new AzureAiStudioEmbeddingsModel(
275277
inferenceEntityId,
276278
taskType,
277279
NAME,
@@ -281,16 +283,7 @@ private static AzureAiStudioModel createModel(
281283
secretSettings,
282284
context
283285
);
284-
checkProviderAndEndpointTypeForTask(
285-
TaskType.TEXT_EMBEDDING,
286-
embeddingsModel.getServiceSettings().provider(),
287-
embeddingsModel.getServiceSettings().endpointType()
288-
);
289-
return embeddingsModel;
290-
}
291-
292-
if (taskType == TaskType.COMPLETION) {
293-
var completionModel = new AzureAiStudioChatCompletionModel(
286+
case COMPLETION -> model = new AzureAiStudioChatCompletionModel(
294287
inferenceEntityId,
295288
taskType,
296289
NAME,
@@ -299,15 +292,12 @@ private static AzureAiStudioModel createModel(
299292
secretSettings,
300293
context
301294
);
302-
checkProviderAndEndpointTypeForTask(
303-
TaskType.COMPLETION,
304-
completionModel.getServiceSettings().provider(),
305-
completionModel.getServiceSettings().endpointType()
306-
);
307-
return completionModel;
295+
case RERANK -> model = new AzureAiStudioRerankModel(inferenceEntityId, serviceSettings, taskSettings, secretSettings, context);
296+
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
308297
}
309-
310-
throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
298+
final var azureAiStudioServiceSettings = (AzureAiStudioServiceSettings) model.getServiceSettings();
299+
checkProviderAndEndpointTypeForTask(taskType, azureAiStudioServiceSettings.provider(), azureAiStudioServiceSettings.endpointType());
300+
return model;
311301
}
312302

313303
private AzureAiStudioModel createModelFromPersistent(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionCreator.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
import org.elasticsearch.xpack.inference.services.ServiceComponents;
1414
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioChatCompletionRequestManager;
1515
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEmbeddingsRequestManager;
16+
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioRerankRequestManager;
1617
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModel;
1718
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModel;
19+
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;
1820

1921
import java.util.Map;
2022
import java.util.Objects;
@@ -49,4 +51,12 @@ public ExecutableAction create(AzureAiStudioEmbeddingsModel embeddingsModel, Map
4951
var errorMessage = constructFailedToSendRequestMessage("Azure AI Studio embeddings");
5052
return new SenderExecutableAction(sender, requestManager, errorMessage);
5153
}
54+
55+
@Override
56+
public ExecutableAction create(AzureAiStudioRerankModel rerankModel, Map<String, Object> taskSettings) {
57+
var overriddenModel = AzureAiStudioRerankModel.of(rerankModel, taskSettings);
58+
var requestManager = new AzureAiStudioRerankRequestManager(overriddenModel, serviceComponents.threadPool());
59+
var errorMessage = constructFailedToSendRequestMessage("Azure AI Studio rerank");
60+
return new SenderExecutableAction(sender, requestManager, errorMessage);
61+
}
5262
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionVisitor.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@
1010
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1111
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModel;
1212
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModel;
13+
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;
1314

1415
import java.util.Map;
1516

1617
public interface AzureAiStudioActionVisitor {
1718
ExecutableAction create(AzureAiStudioEmbeddingsModel embeddingsModel, Map<String, Object> taskSettings);
1819

1920
ExecutableAction create(AzureAiStudioChatCompletionModel completionModel, Map<String, Object> taskSettings);
21+
22+
ExecutableAction create(AzureAiStudioRerankModel rerankModel, Map<String, Object> taskSettings);
2023
}

0 commit comments

Comments
 (0)