Skip to content

Commit 6faa60c

Browse files
Merge branch 'main' into remove-yaml-tests
2 parents 19f15b3 + 3b1523a commit 6faa60c

File tree

41 files changed

+2537
-172
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+2537
-172
lines changed

docs/changelog/129848.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 129848
2+
summary: "[ML] Add Azure AI Rerank support to the Inference Plugin"
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

modules/repository-s3/src/main/resources/org/elasticsearch/repositories/s3/regions_by_endpoint.txt

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ ap-east-1 s3-fips.ap-east-1.amazonaws.com
66
ap-east-1 s3-fips.dualstack.ap-east-1.amazonaws.com
77
ap-east-1 s3.ap-east-1.amazonaws.com
88
ap-east-1 s3.dualstack.ap-east-1.amazonaws.com
9+
ap-east-2 s3-fips.ap-east-2.amazonaws.com
10+
ap-east-2 s3-fips.dualstack.ap-east-2.amazonaws.com
11+
ap-east-2 s3.ap-east-2.amazonaws.com
12+
ap-east-2 s3.dualstack.ap-east-2.amazonaws.com
913
ap-northeast-1 s3-fips.ap-northeast-1.amazonaws.com
1014
ap-northeast-1 s3-fips.dualstack.ap-northeast-1.amazonaws.com
1115
ap-northeast-1 s3.ap-northeast-1.amazonaws.com
@@ -56,6 +60,14 @@ aws-iso-b-global s3-fips.aws-iso-b-global.sc2s.sgov.gov
5660
aws-iso-b-global s3-fips.dualstack.aws-iso-b-global.sc2s.sgov.gov
5761
aws-iso-b-global s3.aws-iso-b-global.sc2s.sgov.gov
5862
aws-iso-b-global s3.dualstack.aws-iso-b-global.sc2s.sgov.gov
63+
aws-iso-e-global s3-fips.aws-iso-e-global.cloud.adc-e.uk
64+
aws-iso-e-global s3-fips.dualstack.aws-iso-e-global.cloud.adc-e.uk
65+
aws-iso-e-global s3.aws-iso-e-global.cloud.adc-e.uk
66+
aws-iso-e-global s3.dualstack.aws-iso-e-global.cloud.adc-e.uk
67+
aws-iso-f-global s3-fips.aws-iso-f-global.csp.hci.ic.gov
68+
aws-iso-f-global s3-fips.dualstack.aws-iso-f-global.csp.hci.ic.gov
69+
aws-iso-f-global s3.aws-iso-f-global.csp.hci.ic.gov
70+
aws-iso-f-global s3.dualstack.aws-iso-f-global.csp.hci.ic.gov
5971
aws-iso-global s3-fips.aws-iso-global.c2s.ic.gov
6072
aws-iso-global s3-fips.dualstack.aws-iso-global.c2s.ic.gov
6173
aws-iso-global s3.aws-iso-global.c2s.ic.gov
@@ -76,6 +88,10 @@ cn-north-1 s3.cn-north-1.amazonaws.com.cn
7688
cn-north-1 s3.dualstack.cn-north-1.amazonaws.com.cn
7789
cn-northwest-1 s3.cn-northwest-1.amazonaws.com.cn
7890
cn-northwest-1 s3.dualstack.cn-northwest-1.amazonaws.com.cn
91+
eusc-de-east-1 s3-fips.eusc-de-east-1.amazonaws.eu
92+
eusc-de-east-1 s3-fips.dualstack.eusc-de-east-1.amazonaws.eu
93+
eusc-de-east-1 s3.eusc-de-east-1.amazonaws.eu
94+
eusc-de-east-1 s3.dualstack.eusc-de-east-1.amazonaws.eu
7995
eu-central-1 s3-fips.dualstack.eu-central-1.amazonaws.com
8096
eu-central-1 s3-fips.eu-central-1.amazonaws.com
8197
eu-central-1 s3.dualstack.eu-central-1.amazonaws.com

modules/repository-s3/src/test/java/org/elasticsearch/repositories/s3/RegionFromEndpointGuesserTests.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@
99

1010
package org.elasticsearch.repositories.s3;
1111

12+
import software.amazon.awssdk.endpoints.Endpoint;
13+
import software.amazon.awssdk.regions.Region;
14+
import software.amazon.awssdk.services.s3.endpoints.S3EndpointParams;
15+
import software.amazon.awssdk.services.s3.endpoints.internal.DefaultS3EndpointProvider;
16+
1217
import org.elasticsearch.core.Nullable;
1318
import org.elasticsearch.test.ESTestCase;
1419

@@ -23,6 +28,14 @@ public void testRegionGuessing() {
2328
assertRegionGuess("random.endpoint.internal.net", null);
2429
}
2530

31+
public void testHasEntryForEachRegion() {
32+
final var defaultS3EndpointProvider = new DefaultS3EndpointProvider();
33+
for (var region : Region.regions()) {
34+
final Endpoint endpoint = safeGet(defaultS3EndpointProvider.resolveEndpoint(S3EndpointParams.builder().region(region).build()));
35+
assertNotNull(region.id(), RegionFromEndpointGuesser.guessRegion(endpoint.url().toString()));
36+
}
37+
}
38+
2639
private static void assertRegionGuess(String endpoint, @Nullable String expectedRegion) {
2740
assertEquals(endpoint, expectedRegion, RegionFromEndpointGuesser.guessRegion(endpoint));
2841
}

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,7 @@ static TransportVersion def(int id) {
341341
public static final TransportVersion LOOKUP_JOIN_CCS = def(9_120_0_00);
342342
public static final TransportVersion NODE_USAGE_STATS_FOR_THREAD_POOLS_IN_CLUSTER_INFO = def(9_121_0_00);
343343
public static final TransportVersion ESQL_CATEGORIZE_OPTIONS = def(9_122_0_00);
344+
public static final TransportVersion ML_INFERENCE_AZURE_AI_STUDIO_RERANK_ADDED = def(9_123_0_00);
344345

345346
/*
346347
* STOP! READ THIS FIRST! No, really,

server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplainRequest.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ public String toString() {
236236
}
237237
}
238238
sb.append(",").append(INCLUDE_YES_DECISIONS_PARAMETER_NAME).append("?=").append(includeYesDecisions);
239+
sb.append(",").append(INCLUDE_DISK_INFO_PARAMETER_NAME).append("?=").append(includeDiskInfo);
239240
return sb.toString();
240241
}
241242

server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplainRequestTests.java

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,66 @@ public void testSerialization() throws Exception {
3737
assertEquals(request.getCurrentNode(), actual.getCurrentNode());
3838
}
3939

40+
public void testToStringWithEmptyBody() {
41+
ClusterAllocationExplainRequest clusterAllocationExplainRequest = new ClusterAllocationExplainRequest(randomTimeValue());
42+
clusterAllocationExplainRequest.includeYesDecisions(true);
43+
clusterAllocationExplainRequest.includeDiskInfo(false);
44+
45+
String expected = "ClusterAllocationExplainRequest[useAnyUnassignedShard=true,"
46+
+ "include_yes_decisions?=true,include_disk_info?=false";
47+
assertEquals(expected, clusterAllocationExplainRequest.toString());
48+
}
49+
50+
public void testToStringWithValidBodyButCurrentNodeIsNull() {
51+
String index = "test-index";
52+
int shard = randomInt();
53+
boolean primary = randomBoolean();
54+
ClusterAllocationExplainRequest clusterAllocationExplainRequest = new ClusterAllocationExplainRequest(
55+
randomTimeValue(),
56+
index,
57+
shard,
58+
primary,
59+
null
60+
);
61+
clusterAllocationExplainRequest.includeYesDecisions(false);
62+
clusterAllocationExplainRequest.includeDiskInfo(true);
63+
64+
String expected = "ClusterAllocationExplainRequest[index="
65+
+ index
66+
+ ",shard="
67+
+ shard
68+
+ ",primary?="
69+
+ primary
70+
+ ",include_yes_decisions?=false"
71+
+ ",include_disk_info?=true";
72+
assertEquals(expected, clusterAllocationExplainRequest.toString());
73+
}
74+
75+
public void testToStringWithAllBodyParameters() {
76+
String index = "test-index";
77+
int shard = randomInt();
78+
boolean primary = randomBoolean();
79+
String currentNode = "current_node";
80+
ClusterAllocationExplainRequest clusterAllocationExplainRequest = new ClusterAllocationExplainRequest(
81+
randomTimeValue(),
82+
index,
83+
shard,
84+
primary,
85+
currentNode
86+
);
87+
clusterAllocationExplainRequest.includeYesDecisions(false);
88+
clusterAllocationExplainRequest.includeDiskInfo(true);
89+
90+
String expected = "ClusterAllocationExplainRequest[index="
91+
+ index
92+
+ ",shard="
93+
+ shard
94+
+ ",primary?="
95+
+ primary
96+
+ ",current_node="
97+
+ currentNode
98+
+ ",include_yes_decisions?=false"
99+
+ ",include_disk_info?=true";
100+
assertEquals(expected, clusterAllocationExplainRequest.toString());
101+
}
40102
}

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
111111
containsInAnyOrder(
112112
List.of(
113113
"alibabacloud-ai-search",
114+
"azureaistudio",
114115
"cohere",
115116
"elasticsearch",
116117
"googlevertexai",

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@
5050
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionTaskSettings;
5151
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettings;
5252
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsTaskSettings;
53+
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankServiceSettings;
54+
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankTaskSettings;
5355
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings;
5456
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionServiceSettings;
5557
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionTaskSettings;
@@ -306,6 +308,17 @@ private static void addAzureAiStudioNamedWriteables(List<NamedWriteableRegistry.
306308
AzureAiStudioChatCompletionTaskSettings::new
307309
)
308310
);
311+
312+
namedWriteables.add(
313+
new NamedWriteableRegistry.Entry(
314+
ServiceSettings.class,
315+
AzureAiStudioRerankServiceSettings.NAME,
316+
AzureAiStudioRerankServiceSettings::new
317+
)
318+
);
319+
namedWriteables.add(
320+
new NamedWriteableRegistry.Entry(TaskSettings.class, AzureAiStudioRerankTaskSettings.NAME, AzureAiStudioRerankTaskSettings::new)
321+
);
309322
}
310323

311324
private static void addAzureOpenAiNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ public void validateResponse(
9595

9696
protected abstract void checkForFailureStatusCode(Request request, HttpResult result);
9797

98-
private void checkForErrorObject(Request request, HttpResult result) {
98+
protected void checkForErrorObject(Request request, HttpResult result) {
9999
var errorEntity = errorParseFunction.apply(result);
100100

101101
if (errorEntity.errorStructureFound()) {
@@ -116,12 +116,12 @@ protected Exception buildError(String message, Request request, HttpResult resul
116116
protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
117117
var responseStatusCode = result.response().getStatusLine().getStatusCode();
118118
return new ElasticsearchStatusException(
119-
errorMessage(message, request, result, errorResponse, responseStatusCode),
119+
constructErrorMessage(message, request, errorResponse, responseStatusCode),
120120
toRestStatus(responseStatusCode)
121121
);
122122
}
123123

124-
protected String errorMessage(String message, Request request, HttpResult result, ErrorResponse errorResponse, int statusCode) {
124+
public static String constructErrorMessage(String message, Request request, ErrorResponse errorResponse, int statusCode) {
125125
return (errorResponse == null
126126
|| errorResponse.errorStructureFound() == false
127127
|| Strings.isNullOrEmpty(errorResponse.getErrorMessage()))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.retry;
9+
10+
import org.elasticsearch.rest.RestStatus;
11+
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
12+
import org.elasticsearch.xpack.inference.external.http.HttpResult;
13+
import org.elasticsearch.xpack.inference.external.request.Request;
14+
15+
import java.util.Locale;
16+
import java.util.Objects;
17+
18+
import static org.elasticsearch.core.Strings.format;
19+
import static org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler.SERVER_ERROR_OBJECT;
20+
import static org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler.toRestStatus;
21+
22+
public class ChatCompletionErrorResponseHandler {
23+
private static final String STREAM_ERROR = "stream_error";
24+
25+
private final UnifiedChatCompletionErrorParser unifiedChatCompletionErrorParser;
26+
27+
public ChatCompletionErrorResponseHandler(UnifiedChatCompletionErrorParser errorParser) {
28+
this.unifiedChatCompletionErrorParser = Objects.requireNonNull(errorParser);
29+
}
30+
31+
public void checkForErrorObject(Request request, HttpResult result) {
32+
var errorEntity = unifiedChatCompletionErrorParser.parse(result);
33+
34+
if (errorEntity.errorStructureFound()) {
35+
// We don't really know what happened because the status code was 200 so we'll return a failure and let the
36+
// client retry if necessary
37+
// If we did want to retry here, we'll need to determine if this was a streaming request, if it was
38+
// we shouldn't retry because that would replay the entire streaming request and the client would get
39+
// duplicate chunks back
40+
throw new RetryException(false, buildChatCompletionErrorInternal(SERVER_ERROR_OBJECT, request, result, errorEntity));
41+
}
42+
}
43+
44+
public UnifiedChatCompletionException buildChatCompletionError(String message, Request request, HttpResult result) {
45+
var errorResponse = unifiedChatCompletionErrorParser.parse(result);
46+
return buildChatCompletionErrorInternal(message, request, result, errorResponse);
47+
}
48+
49+
private UnifiedChatCompletionException buildChatCompletionErrorInternal(
50+
String message,
51+
Request request,
52+
HttpResult result,
53+
UnifiedChatCompletionErrorResponse errorResponse
54+
) {
55+
assert request.isStreaming() : "Only streaming requests support this format";
56+
var statusCode = result.response().getStatusLine().getStatusCode();
57+
var errorMessage = BaseResponseHandler.constructErrorMessage(message, request, errorResponse, statusCode);
58+
var restStatus = toRestStatus(statusCode);
59+
60+
if (errorResponse.errorStructureFound()) {
61+
return new UnifiedChatCompletionException(
62+
restStatus,
63+
errorMessage,
64+
errorResponse.type(),
65+
errorResponse.code(),
66+
errorResponse.param()
67+
);
68+
} else {
69+
return buildDefaultChatCompletionError(errorResponse, errorMessage, restStatus);
70+
}
71+
}
72+
73+
/**
74+
* Builds a default {@link UnifiedChatCompletionException} for a streaming request.
75+
* This method is used when an error response is received we were unable to parse it in the format we were expecting.
76+
* Only streaming requests should use this method.
77+
*
78+
* @param errorResponse the error response extracted from the HTTP result
79+
* @param errorMessage the error message to include in the exception
80+
* @param restStatus the REST status code of the response
81+
* @return an instance of {@link UnifiedChatCompletionException} with details from the error response
82+
*/
83+
private static UnifiedChatCompletionException buildDefaultChatCompletionError(
84+
ErrorResponse errorResponse,
85+
String errorMessage,
86+
RestStatus restStatus
87+
) {
88+
return new UnifiedChatCompletionException(
89+
restStatus,
90+
errorMessage,
91+
createErrorType(errorResponse),
92+
restStatus.name().toLowerCase(Locale.ROOT)
93+
);
94+
}
95+
96+
/**
97+
* Builds a mid-stream error for a streaming request.
98+
* This method is used when an error occurs while processing a streaming response.
99+
* Only streaming requests should use this method.
100+
*
101+
* @param inferenceEntityId the ID of the inference entity
102+
* @param message the error message
103+
* @param e the exception that caused the error, can be null
104+
* @return a {@link UnifiedChatCompletionException} representing the mid-stream error
105+
*/
106+
public UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) {
107+
var error = unifiedChatCompletionErrorParser.parse(message);
108+
109+
if (error.errorStructureFound()) {
110+
return new UnifiedChatCompletionException(
111+
RestStatus.INTERNAL_SERVER_ERROR,
112+
format(
113+
"%s for request from inference entity id [%s]. Error message: [%s]",
114+
SERVER_ERROR_OBJECT,
115+
inferenceEntityId,
116+
error.getErrorMessage()
117+
),
118+
error.type(),
119+
error.code(),
120+
error.param()
121+
);
122+
} else if (e != null) {
123+
// If the error response does not match, we can still return an exception based on the original throwable
124+
return UnifiedChatCompletionException.fromThrowable(e);
125+
} else {
126+
// If no specific error response is found, we return a default mid-stream error
127+
return buildDefaultMidStreamChatCompletionError(inferenceEntityId, error);
128+
}
129+
}
130+
131+
/**
132+
* Builds a default mid-stream error for a streaming request.
133+
* This method is used when no specific error response is found in the message.
134+
* Only streaming requests should use this method.
135+
*
136+
* @param inferenceEntityId the ID of the inference entity
137+
* @param errorResponse the error response extracted from the message
138+
* @return a {@link UnifiedChatCompletionException} representing the default mid-stream error
139+
*/
140+
private static UnifiedChatCompletionException buildDefaultMidStreamChatCompletionError(
141+
String inferenceEntityId,
142+
ErrorResponse errorResponse
143+
) {
144+
return new UnifiedChatCompletionException(
145+
RestStatus.INTERNAL_SERVER_ERROR,
146+
format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, inferenceEntityId),
147+
createErrorType(errorResponse),
148+
STREAM_ERROR
149+
);
150+
}
151+
152+
/**
153+
* Creates a string representation of the error type based on the provided ErrorResponse.
154+
* This method is used to generate a human-readable error type for logging or exception messages.
155+
*
156+
* @param errorResponse the ErrorResponse object
157+
* @return a string representing the error type
158+
*/
159+
private static String createErrorType(ErrorResponse errorResponse) {
160+
return errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown";
161+
}
162+
}

0 commit comments

Comments
 (0)