Skip to content

Commit a7320f4

Browse files
Refactor error handling in response classes to use ChatCompletionErrorResponse for improved consistency and maintainability
1 parent 8ff0bb6 commit a7320f4

13 files changed

+188
-301
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java

Lines changed: 11 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,7 @@
1919

2020
import java.util.Locale;
2121
import java.util.Objects;
22-
import java.util.function.BiFunction;
2322
import java.util.function.Function;
24-
import java.util.function.Supplier;
2523

2624
import static org.elasticsearch.core.Strings.format;
2725
import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody;
@@ -38,7 +36,6 @@ public abstract class BaseResponseHandler implements ResponseHandler {
3836
public static final String SERVER_ERROR_OBJECT = "Received an error response";
3937
public static final String BAD_REQUEST = "Received a bad request status code";
4038
public static final String METHOD_NOT_ALLOWED = "Received a method not allowed status code";
41-
protected static final String ERROR_TYPE = "error";
4239
protected static final String STREAM_ERROR = "stream_error";
4340

4441
protected final String requestType;
@@ -140,47 +137,22 @@ protected Exception buildError(String message, Request request, HttpResult resul
140137
* @param request the request that caused the error
141138
* @param result the HTTP result containing the error response
142139
* @param errorResponse the parsed error response from the HTTP result
143-
* @param errorResponseClassSupplier the supplier that provides the class of the expected error response type
144-
* @param chatCompletionErrorBuilder the builder for creating provider-specific chat completion errors
145140
* @return an instance of {@link UnifiedChatCompletionException} with details from the error response
146141
*/
147142
protected UnifiedChatCompletionException buildChatCompletionError(
148143
String message,
149144
Request request,
150145
HttpResult result,
151-
ErrorResponse errorResponse,
152-
Supplier<Class<? extends ErrorResponse>> errorResponseClassSupplier,
153-
ChatCompletionErrorBuilder chatCompletionErrorBuilder
146+
ErrorResponse errorResponse
154147
) {
155148
assert request.isStreaming() : "Only streaming requests support this format";
156149
var statusCode = result.response().getStatusLine().getStatusCode();
157150
var errorMessage = extractErrorMessage(message, request, errorResponse, statusCode);
158151
var restStatus = toRestStatus(statusCode);
159152

160-
return buildChatCompletionError(errorResponse, errorMessage, restStatus, errorResponseClassSupplier, chatCompletionErrorBuilder);
161-
}
162-
163-
/**
164-
* Builds a {@link UnifiedChatCompletionException} for a streaming request.
165-
* This method is used when an error response is received from the external service.
166-
* Only streaming requests should use this method.
167-
*
168-
* @param errorResponse the error response parsed from the HTTP result
169-
* @param errorMessage the error message to include in the exception
170-
* @param restStatus the REST status code of the response
171-
* @param errorResponseClassSupplier the supplier that provides the class of the expected error response type
172-
* @param chatCompletionErrorBuilder the builder for creating provider-specific chat completion errors
173-
* @return an instance of {@link UnifiedChatCompletionException} with details from the error response
174-
*/
175-
protected UnifiedChatCompletionException buildChatCompletionError(
176-
ErrorResponse errorResponse,
177-
String errorMessage,
178-
RestStatus restStatus,
179-
Supplier<Class<? extends ErrorResponse>> errorResponseClassSupplier,
180-
ChatCompletionErrorBuilder chatCompletionErrorBuilder
181-
) {
182-
if (errorResponse.errorStructureFound() && errorResponseClassSupplier.get().isInstance(errorResponse)) {
183-
return chatCompletionErrorBuilder.buildProviderSpecificChatCompletionError(errorResponse, errorMessage, restStatus);
153+
if (errorResponse.errorStructureFound()
154+
&& errorResponse instanceof UnifiedChatCompletionExceptionConvertible chatCompletionExceptionConvertible) {
155+
return chatCompletionExceptionConvertible.toUnifiedChatCompletionException(errorMessage, restStatus);
184156
} else {
185157
return buildDefaultChatCompletionError(errorResponse, errorMessage, restStatus);
186158
}
@@ -196,7 +168,7 @@ protected UnifiedChatCompletionException buildChatCompletionError(
196168
* @param restStatus the REST status code of the response
197169
* @return an instance of {@link UnifiedChatCompletionException} with details from the error response
198170
*/
199-
private static UnifiedChatCompletionException buildDefaultChatCompletionError(
171+
protected static UnifiedChatCompletionException buildDefaultChatCompletionError(
200172
ErrorResponse errorResponse,
201173
String errorMessage,
202174
RestStatus restStatus
@@ -217,31 +189,27 @@ private static UnifiedChatCompletionException buildDefaultChatCompletionError(
217189
* @param inferenceEntityId the ID of the inference entity
218190
* @param message the error message
219191
* @param e the exception that caused the error, can be null
220-
* @param errorResponseClassSupplier a supplier that provides the class of the expected error response type
221-
* @param specificErrorBuilder a function that builds a specific error based on the inference entity ID and error response
222192
* @param midStreamErrorExtractor a function that extracts the mid-stream error response from the message
223193
* @return a {@link UnifiedChatCompletionException} representing the mid-stream error
224194
*/
225195
protected UnifiedChatCompletionException buildMidStreamChatCompletionError(
226196
String inferenceEntityId,
227197
String message,
228198
Exception e,
229-
Supplier<Class<? extends ErrorResponse>> errorResponseClassSupplier,
230-
BiFunction<String, ErrorResponse, UnifiedChatCompletionException> specificErrorBuilder,
231199
Function<String, ErrorResponse> midStreamErrorExtractor
232200
) {
233201
// Extract the error response from the message using the provided extractor function
234-
var errorResponse = midStreamErrorExtractor.apply(message);
202+
var error = midStreamErrorExtractor.apply(message);
235203
// Check if the error response matches the expected type
236-
if (errorResponse.errorStructureFound() && errorResponseClassSupplier.get().isInstance(errorResponse)) {
204+
if (error.errorStructureFound() && error instanceof MidStreamUnifiedChatCompletionExceptionConvertible midStreamError) {
237205
// If it matches, we can build a custom mid-stream error exception
238-
return specificErrorBuilder.apply(inferenceEntityId, errorResponse);
206+
return midStreamError.toUnifiedChatCompletionException(inferenceEntityId);
239207
} else if (e != null) {
240208
// If the error response does not match but an original exception is available, return an exception based on that throwable
241209
return UnifiedChatCompletionException.fromThrowable(e);
242210
} else {
243211
// If no specific error response is found, we return a default mid-stream error
244-
return buildDefaultMidStreamChatCompletionError(inferenceEntityId, errorResponse);
212+
return buildDefaultMidStreamChatCompletionError(inferenceEntityId, error);
245213
}
246214
}
247215

@@ -277,7 +245,7 @@ private static String createErrorType(ErrorResponse errorResponse) {
277245
return errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown";
278246
}
279247

280-
private static String extractErrorMessage(String message, Request request, ErrorResponse errorResponse, int statusCode) {
248+
protected static String extractErrorMessage(String message, Request request, ErrorResponse errorResponse, int statusCode) {
281249
return (errorResponse == null
282250
|| errorResponse.errorStructureFound() == false
283251
|| Strings.isNullOrEmpty(errorResponse.getErrorMessage()))
@@ -291,7 +259,7 @@ private static String extractErrorMessage(String message, Request request, Error
291259
);
292260
}
293261

294-
public static RestStatus toRestStatus(int statusCode) {
262+
protected static RestStatus toRestStatus(int statusCode) {
295263
RestStatus code = null;
296264
if (statusCode < 500) {
297265
code = RestStatus.fromCode(statusCode);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.retry;
9+
10+
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
11+
12+
public interface MidStreamUnifiedChatCompletionExceptionConvertible {
13+
14+
UnifiedChatCompletionException toUnifiedChatCompletionException(String inferenceEntityId);
15+
16+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.retry;
9+
10+
import org.elasticsearch.rest.RestStatus;
11+
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
12+
13+
public interface UnifiedChatCompletionExceptionConvertible {
14+
15+
UnifiedChatCompletionException toUnifiedChatCompletionException(String errorMessage, RestStatus restStatus);
16+
17+
}
Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,26 @@
88
package org.elasticsearch.xpack.inference.external.response.streaming;
99

1010
import org.elasticsearch.core.Nullable;
11+
import org.elasticsearch.rest.RestStatus;
1112
import org.elasticsearch.xcontent.ConstructingObjectParser;
1213
import org.elasticsearch.xcontent.ParseField;
1314
import org.elasticsearch.xcontent.XContentFactory;
1415
import org.elasticsearch.xcontent.XContentParser;
1516
import org.elasticsearch.xcontent.XContentParserConfiguration;
1617
import org.elasticsearch.xcontent.XContentType;
18+
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
1719
import org.elasticsearch.xpack.inference.external.http.HttpResult;
1820
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
21+
import org.elasticsearch.xpack.inference.external.http.retry.MidStreamUnifiedChatCompletionExceptionConvertible;
22+
import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionExceptionConvertible;
1923
import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity;
2024

2125
import java.util.Objects;
2226
import java.util.Optional;
2327

28+
import static org.elasticsearch.core.Strings.format;
29+
import static org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler.SERVER_ERROR_OBJECT;
30+
2431
/**
2532
* Represents an error response from a streaming inference service.
2633
* This class extends {@link ErrorResponse} and provides additional fields
@@ -38,17 +45,21 @@
3845
* </code></pre>
3946
* TODO: {@link ErrorMessageResponseEntity} is nearly identical to this, but doesn't parse as many fields. We must remove the duplication.
4047
*/
41-
public class StreamingErrorResponse extends ErrorResponse {
48+
public class OpenAiStreamingChatCompletionErrorResponse extends ErrorResponse
49+
implements
50+
UnifiedChatCompletionExceptionConvertible,
51+
MidStreamUnifiedChatCompletionExceptionConvertible {
4252
private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>(
4353
"streaming_error",
4454
true,
45-
args -> Optional.ofNullable((StreamingErrorResponse) args[0])
46-
);
47-
private static final ConstructingObjectParser<StreamingErrorResponse, Void> ERROR_BODY_PARSER = new ConstructingObjectParser<>(
48-
"streaming_error",
49-
true,
50-
args -> new StreamingErrorResponse((String) args[0], (String) args[1], (String) args[2], (String) args[3])
55+
args -> Optional.ofNullable((OpenAiStreamingChatCompletionErrorResponse) args[0])
5156
);
57+
private static final ConstructingObjectParser<OpenAiStreamingChatCompletionErrorResponse, Void> ERROR_BODY_PARSER =
58+
new ConstructingObjectParser<>(
59+
"streaming_error",
60+
true,
61+
args -> new OpenAiStreamingChatCompletionErrorResponse((String) args[0], (String) args[1], (String) args[2], (String) args[3])
62+
);
5263

5364
static {
5465
ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("message"));
@@ -105,13 +116,34 @@ public static ErrorResponse fromString(String response) {
105116
private final String param;
106117
private final String type;
107118

108-
StreamingErrorResponse(String errorMessage, @Nullable String code, @Nullable String param, String type) {
119+
OpenAiStreamingChatCompletionErrorResponse(String errorMessage, @Nullable String code, @Nullable String param, String type) {
109120
super(errorMessage);
110121
this.code = code;
111122
this.param = param;
112123
this.type = Objects.requireNonNull(type);
113124
}
114125

126+
@Override
127+
public UnifiedChatCompletionException toUnifiedChatCompletionException(String errorMessage, RestStatus restStatus) {
128+
return new UnifiedChatCompletionException(restStatus, errorMessage, this.type(), this.code(), this.param());
129+
}
130+
131+
@Override
132+
public UnifiedChatCompletionException toUnifiedChatCompletionException(String inferenceEntityId) {
133+
return new UnifiedChatCompletionException(
134+
RestStatus.INTERNAL_SERVER_ERROR,
135+
format(
136+
"%s for request from inference entity id [%s]. Error message: [%s]",
137+
SERVER_ERROR_OBJECT,
138+
inferenceEntityId,
139+
this.getErrorMessage()
140+
),
141+
this.type(),
142+
this.code(),
143+
this.param()
144+
);
145+
}
146+
115147
@Nullable
116148
public String code() {
117149
return code;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
* This handler is designed to work with the unified Elastic Inference Service chat completion API.
3131
*/
3232
public class ElasticInferenceServiceUnifiedChatCompletionResponseHandler extends ElasticInferenceServiceResponseHandler {
33+
private static final String ERROR_TYPE = "error";
34+
3335
public ElasticInferenceServiceUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
3436
super(requestType, parseFunction, true);
3537
}
@@ -59,22 +61,16 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpR
5961
*/
6062
@Override
6163
protected UnifiedChatCompletionException buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
62-
return buildChatCompletionError(
63-
message,
64-
request,
65-
result,
66-
errorResponse,
67-
() -> ErrorResponse.class,
68-
ElasticInferenceServiceUnifiedChatCompletionResponseHandler::buildProviderSpecificChatCompletionError
69-
);
70-
}
64+
assert request.isStreaming() : "Only streaming requests support this format";
65+
var statusCode = result.response().getStatusLine().getStatusCode();
66+
var errorMessage = extractErrorMessage(message, request, errorResponse, statusCode);
67+
var restStatus = toRestStatus(statusCode);
7168

72-
private static UnifiedChatCompletionException buildProviderSpecificChatCompletionError(
73-
ErrorResponse response,
74-
String message,
75-
RestStatus restStatus
76-
) {
77-
return new UnifiedChatCompletionException(restStatus, message, ERROR_TYPE, restStatus.name().toLowerCase(Locale.ROOT));
69+
if (errorResponse.errorStructureFound()) {
70+
return new UnifiedChatCompletionException(restStatus, errorMessage, ERROR_TYPE, restStatus.name().toLowerCase(Locale.ROOT));
71+
} else {
72+
return buildDefaultChatCompletionError(errorResponse, errorMessage, restStatus);
73+
}
7874
}
7975

8076
/**

0 commit comments

Comments
 (0)