Skip to content

Commit f4582f3

Browse files
Refactor error handling in streaming response handlers to use functional interfaces for improved flexibility
1 parent 0a09f00 commit f4582f3

File tree

8 files changed

+158
-203
lines changed

8 files changed

+158
-203
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java

Lines changed: 28 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919

2020
import java.util.Locale;
2121
import java.util.Objects;
22+
import java.util.function.BiFunction;
2223
import java.util.function.Function;
24+
import java.util.function.Supplier;
2325

2426
import static org.elasticsearch.core.Strings.format;
2527
import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody;
@@ -124,7 +126,7 @@ protected Exception buildError(String message, Request request, HttpResult resul
124126
protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
125127
var responseStatusCode = result.response().getStatusLine().getStatusCode();
126128
return new ElasticsearchStatusException(
127-
errorMessage(message, request, errorResponse, responseStatusCode),
129+
extractErrorMessage(message, request, errorResponse, responseStatusCode),
128130
toRestStatus(responseStatusCode)
129131
);
130132
}
@@ -138,22 +140,24 @@ protected Exception buildError(String message, Request request, HttpResult resul
138140
* @param request the request that caused the error
139141
* @param result the HTTP result containing the error response
140142
* @param errorResponse the parsed error response from the HTTP result
141-
* @param errorResponseClass the class of the expected error response type
143+
* @param errorResponseClassSupplier the supplier that provides the class of the expected error response type
144+
* @param chatCompletionErrorBuilder the builder for creating provider-specific chat completion errors
142145
* @return an instance of {@link UnifiedChatCompletionException} with details from the error response
143146
*/
144147
protected UnifiedChatCompletionException buildChatCompletionError(
145148
String message,
146149
Request request,
147150
HttpResult result,
148151
ErrorResponse errorResponse,
149-
Class<? extends ErrorResponse> errorResponseClass
152+
Supplier<Class<? extends ErrorResponse>> errorResponseClassSupplier,
153+
ChatCompletionErrorBuilder chatCompletionErrorBuilder
150154
) {
151155
assert request.isStreaming() : "Only streaming requests support this format";
152156
var statusCode = result.response().getStatusLine().getStatusCode();
153-
var errorMessage = errorMessage(message, request, errorResponse, statusCode);
157+
var errorMessage = extractErrorMessage(message, request, errorResponse, statusCode);
154158
var restStatus = toRestStatus(statusCode);
155159

156-
return buildChatCompletionError(errorResponse, errorMessage, restStatus, errorResponseClass);
160+
return buildChatCompletionError(errorResponse, errorMessage, restStatus, errorResponseClassSupplier, chatCompletionErrorBuilder);
157161
}
158162

159163
/**
@@ -164,43 +168,24 @@ protected UnifiedChatCompletionException buildChatCompletionError(
164168
* @param errorResponse the error response parsed from the HTTP result
165169
* @param errorMessage the error message to include in the exception
166170
* @param restStatus the REST status code of the response
167-
* @param errorResponseClass the class of the expected error response type
171+
* @param errorResponseClassSupplier the supplier that provides the class of the expected error response type
172+
* @param chatCompletionErrorBuilder the builder for creating provider-specific chat completion errors
168173
* @return an instance of {@link UnifiedChatCompletionException} with details from the error response
169174
*/
170175
protected UnifiedChatCompletionException buildChatCompletionError(
171176
ErrorResponse errorResponse,
172177
String errorMessage,
173178
RestStatus restStatus,
174-
Class<? extends ErrorResponse> errorResponseClass
179+
Supplier<Class<? extends ErrorResponse>> errorResponseClassSupplier,
180+
ChatCompletionErrorBuilder chatCompletionErrorBuilder
175181
) {
176-
if (errorResponseClass.isInstance(errorResponse)) {
177-
return buildProviderSpecificChatCompletionError(errorResponse, errorMessage, restStatus);
182+
if (errorResponseClassSupplier.get().isInstance(errorResponse)) {
183+
return chatCompletionErrorBuilder.buildProviderSpecificChatCompletionError(errorResponse, errorMessage, restStatus);
178184
} else {
179185
return buildDefaultChatCompletionError(errorResponse, errorMessage, restStatus);
180186
}
181187
}
182188

183-
/**
184-
* Builds a custom {@link UnifiedChatCompletionException} for a streaming request.
185-
* This method is called when a specific error response is found in the HTTP result.
186-
* It must be implemented by subclasses to handle specific error response formats.
187-
* Only streaming requests should use this method.
188-
*
189-
* @param errorResponse the error response parsed from the HTTP result
190-
* @param errorMessage the error message to include in the exception
191-
* @param restStatus the REST status code of the response
192-
* @return an instance of {@link UnifiedChatCompletionException} with details from the error response
193-
*/
194-
protected UnifiedChatCompletionException buildProviderSpecificChatCompletionError(
195-
ErrorResponse errorResponse,
196-
String errorMessage,
197-
RestStatus restStatus
198-
) {
199-
throw new UnsupportedOperationException(
200-
"Custom error handling is not implemented. Please override buildProviderSpecificChatCompletionError method."
201-
);
202-
}
203-
204189
/**
205190
* Builds a default {@link UnifiedChatCompletionException} for a streaming request.
206191
* This method is used when an error response is received but no specific error handling is implemented.
@@ -211,7 +196,7 @@ protected UnifiedChatCompletionException buildProviderSpecificChatCompletionErro
211196
* @param restStatus the REST status code of the response
212197
* @return an instance of {@link UnifiedChatCompletionException} with details from the error response
213198
*/
214-
protected UnifiedChatCompletionException buildDefaultChatCompletionError(
199+
private static UnifiedChatCompletionException buildDefaultChatCompletionError(
215200
ErrorResponse errorResponse,
216201
String errorMessage,
217202
RestStatus restStatus
@@ -232,21 +217,25 @@ protected UnifiedChatCompletionException buildDefaultChatCompletionError(
232217
* @param inferenceEntityId the ID of the inference entity
233218
* @param message the error message
234219
* @param e the exception that caused the error, can be null
235-
* @param errorResponseClass the class of the expected error response type
220+
* @param errorResponseClassSupplier a supplier that provides the class of the expected error response type
221+
* @param specificErrorBuilder a function that builds a specific error based on the inference entity ID and error response
222+
* @param midStreamErrorExtractor a function that extracts the mid-stream error response from the message
236223
* @return a {@link UnifiedChatCompletionException} representing the mid-stream error
237224
*/
238225
protected UnifiedChatCompletionException buildMidStreamChatCompletionError(
239226
String inferenceEntityId,
240227
String message,
241228
Exception e,
242-
Class<? extends ErrorResponse> errorResponseClass
229+
Supplier<Class<? extends ErrorResponse>> errorResponseClassSupplier,
230+
BiFunction<String, ErrorResponse, UnifiedChatCompletionException> specificErrorBuilder,
231+
Function<String, ErrorResponse> midStreamErrorExtractor
243232
) {
244233
// Extract the error response from the message using the provided method
245-
var errorResponse = extractMidStreamChatCompletionErrorResponse(message);
234+
var errorResponse = midStreamErrorExtractor.apply(message);
246235
// Check if the error response matches the expected type
247-
if (errorResponseClass.isInstance(errorResponse)) {
236+
if (errorResponseClassSupplier.get().isInstance(errorResponse)) {
248237
// If it matches, we can build a custom mid-stream error exception
249-
return buildProviderSpecificMidStreamChatCompletionError(inferenceEntityId, errorResponse);
238+
return specificErrorBuilder.apply(inferenceEntityId, errorResponse);
250239
} else if (e != null) {
251240
// If the error response does not match, we can still return an exception based on the original throwable
252241
return UnifiedChatCompletionException.fromThrowable(e);
@@ -256,26 +245,6 @@ protected UnifiedChatCompletionException buildMidStreamChatCompletionError(
256245
}
257246
}
258247

259-
/**
260-
* Builds a custom mid-stream {@link UnifiedChatCompletionException} for a streaming request.
261-
* This method is called when a specific error response is found in the message.
262-
* It must be implemented by subclasses to handle specific error response formats.
263-
* Only streaming requests should use this method.
264-
*
265-
* @param inferenceEntityId the ID of the inference entity
266-
* @param errorResponse the error response parsed from the message
267-
* @return an instance of {@link UnifiedChatCompletionException} with details from the error response
268-
*/
269-
protected UnifiedChatCompletionException buildProviderSpecificMidStreamChatCompletionError(
270-
String inferenceEntityId,
271-
ErrorResponse errorResponse
272-
) {
273-
throw new UnsupportedOperationException(
274-
"Mid-stream error handling is not implemented for this response handler. "
275-
+ "Please override buildProviderSpecificMidStreamChatCompletionError method."
276-
);
277-
}
278-
279248
/**
280249
* Builds a default mid-stream error for a streaming request.
281250
* This method is used when no specific error response is found in the message.
@@ -285,7 +254,7 @@ protected UnifiedChatCompletionException buildProviderSpecificMidStreamChatCompl
285254
* @param errorResponse the error response extracted from the message
286255
* @return a {@link UnifiedChatCompletionException} representing the default mid-stream error
287256
*/
288-
protected UnifiedChatCompletionException buildDefaultMidStreamChatCompletionError(
257+
protected static UnifiedChatCompletionException buildDefaultMidStreamChatCompletionError(
289258
String inferenceEntityId,
290259
ErrorResponse errorResponse
291260
) {
@@ -297,33 +266,18 @@ protected UnifiedChatCompletionException buildDefaultMidStreamChatCompletionErro
297266
);
298267
}
299268

300-
/**
301-
* Extracts the mid-stream error response from the message.
302-
* This method is used to parse the error response from a streaming message.
303-
* It must be implemented by subclasses to handle specific error response formats.
304-
* Only streaming requests should use this method.
305-
*
306-
* @param message the message containing the error response
307-
* @return an {@link ErrorResponse} object representing the mid-stream error
308-
*/
309-
protected ErrorResponse extractMidStreamChatCompletionErrorResponse(String message) {
310-
throw new UnsupportedOperationException(
311-
"Mid-stream error extraction is not implemented. Please override extractMidStreamChatCompletionErrorResponse method."
312-
);
313-
}
314-
315269
/**
316270
* Creates a string representation of the error type based on the provided ErrorResponse.
317271
* This method is used to generate a human-readable error type for logging or exception messages.
318272
*
319273
* @param errorResponse the ErrorResponse object
320274
* @return a string representing the error type
321275
*/
322-
protected static String createErrorType(ErrorResponse errorResponse) {
276+
private static String createErrorType(ErrorResponse errorResponse) {
323277
return errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown";
324278
}
325279

326-
protected String errorMessage(String message, Request request, ErrorResponse errorResponse, int statusCode) {
280+
private static String extractErrorMessage(String message, Request request, ErrorResponse errorResponse, int statusCode) {
327281
return (errorResponse == null
328282
|| errorResponse.errorStructureFound() == false
329283
|| Strings.isNullOrEmpty(errorResponse.getErrorMessage()))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.retry;
9+
10+
import org.elasticsearch.rest.RestStatus;
11+
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
12+
13+
/**
14+
* Functional interface for building provider-specific chat completion errors.
15+
* This interface is used to create exceptions that are specific to the chat completion service being used.
16+
*/
17+
@FunctionalInterface
18+
public interface ChatCompletionErrorBuilder {
19+
20+
/**
21+
* Builds a provider-specific chat completion error based on the given parameters.
22+
*
23+
* @param errorResponse The error response received from the service.
24+
* @param errorMessage A custom error message to include in the exception.
25+
* @param restStatus The HTTP status code associated with the error.
26+
* @return An instance of {@link UnifiedChatCompletionException} representing the error.
27+
*/
28+
UnifiedChatCompletionException buildProviderSpecificChatCompletionError(
29+
ErrorResponse errorResponse,
30+
String errorMessage,
31+
RestStatus restStatus
32+
);
33+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java

Lines changed: 25 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -59,27 +59,22 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpR
5959
*/
6060
@Override
6161
protected UnifiedChatCompletionException buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
62-
return buildChatCompletionError(message, request, result, errorResponse, ErrorResponse.class);
62+
return buildChatCompletionError(
63+
message,
64+
request,
65+
result,
66+
errorResponse,
67+
() -> ErrorResponse.class,
68+
ElasticInferenceServiceUnifiedChatCompletionResponseHandler::buildProviderSpecificChatCompletionError
69+
);
6370
}
6471

65-
/**
66-
* Builds a custom {@link UnifiedChatCompletionException} for the Elastic Inference Service.
67-
* This method is called when an error response is received from the service.
68-
*
69-
* @param errorResponse The error response received from the service.
70-
* @param errorMessage The error message to include in the exception.
71-
* @param restStatus The HTTP status of the error response.
72-
* @param errorResponseClass The class of the error response.
73-
* @return An instance of {@link UnifiedChatCompletionException} with details from the error response.
74-
*/
75-
@Override
76-
protected UnifiedChatCompletionException buildChatCompletionError(
77-
ErrorResponse errorResponse,
78-
String errorMessage,
79-
RestStatus restStatus,
80-
Class<? extends ErrorResponse> errorResponseClass
72+
private static UnifiedChatCompletionException buildProviderSpecificChatCompletionError(
73+
ErrorResponse response,
74+
String message,
75+
RestStatus restStatus
8176
) {
82-
return new UnifiedChatCompletionException(restStatus, errorMessage, ERROR_TYPE, restStatus.name().toLowerCase(Locale.ROOT));
77+
return new UnifiedChatCompletionException(restStatus, message, ERROR_TYPE, restStatus.name().toLowerCase(Locale.ROOT));
8378
}
8479

8580
/**
@@ -92,73 +87,24 @@ protected UnifiedChatCompletionException buildChatCompletionError(
9287
* @return An instance of {@link UnifiedChatCompletionException} representing the mid-stream error.
9388
*/
9489
private UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) {
95-
var errorResponse = extractMidStreamChatCompletionErrorResponse(message);
90+
var errorResponse = ElasticInferenceServiceErrorResponseEntity.fromString(message);
9691
// Check if the error response contains a specific structure
9792
if (errorResponse.errorStructureFound()) {
98-
return buildProviderSpecificMidStreamChatCompletionError(inferenceEntityId, errorResponse);
93+
return new UnifiedChatCompletionException(
94+
RestStatus.INTERNAL_SERVER_ERROR,
95+
format(
96+
"%s for request from inference entity id [%s]. Error message: [%s]",
97+
SERVER_ERROR_OBJECT,
98+
inferenceEntityId,
99+
errorResponse.getErrorMessage()
100+
),
101+
ERROR_TYPE,
102+
STREAM_ERROR
103+
);
99104
} else if (e != null) {
100105
return UnifiedChatCompletionException.fromThrowable(e);
101106
} else {
102107
return buildDefaultMidStreamChatCompletionError(inferenceEntityId, errorResponse);
103108
}
104109
}
105-
106-
/**
107-
* Extracts the error response from the message. This method is specific to the Elastic Inference Service
108-
* and should parse the message according to its error response format.
109-
*
110-
* @param message The message containing the error response.
111-
* @return An instance of {@link ErrorResponse} parsed from the message.
112-
*/
113-
@Override
114-
protected ErrorResponse extractMidStreamChatCompletionErrorResponse(String message) {
115-
return ElasticInferenceServiceErrorResponseEntity.fromString(message);
116-
}
117-
118-
/**
119-
* Builds a custom mid-stream {@link UnifiedChatCompletionException} for the Elastic Inference Service.
120-
* This method is called when a specific error response structure is found in the message.
121-
*
122-
* @param inferenceEntityId The ID of the inference entity.
123-
* @param errorResponse The error response parsed from the message.
124-
* @return An instance of {@link UnifiedChatCompletionException} with details from the error response.
125-
*/
126-
@Override
127-
protected UnifiedChatCompletionException buildProviderSpecificMidStreamChatCompletionError(
128-
String inferenceEntityId,
129-
ErrorResponse errorResponse
130-
) {
131-
return new UnifiedChatCompletionException(
132-
RestStatus.INTERNAL_SERVER_ERROR,
133-
format(
134-
"%s for request from inference entity id [%s]. Error message: [%s]",
135-
SERVER_ERROR_OBJECT,
136-
inferenceEntityId,
137-
errorResponse.getErrorMessage()
138-
),
139-
ERROR_TYPE,
140-
STREAM_ERROR
141-
);
142-
}
143-
144-
/**
145-
* Builds a default mid-stream {@link UnifiedChatCompletionException} for the Elastic Inference Service.
146-
* This method is called when specific error response structure is NOT found in the message.
147-
*
148-
* @param inferenceEntityId The ID of the inference entity.
149-
* @param errorResponse The error response parsed from the message.
150-
* @return An instance of {@link UnifiedChatCompletionException} with a generic error message.
151-
*/
152-
@Override
153-
protected UnifiedChatCompletionException buildDefaultMidStreamChatCompletionError(
154-
String inferenceEntityId,
155-
ErrorResponse errorResponse
156-
) {
157-
return new UnifiedChatCompletionException(
158-
RestStatus.INTERNAL_SERVER_ERROR,
159-
format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, inferenceEntityId),
160-
ERROR_TYPE,
161-
STREAM_ERROR
162-
);
163-
}
164110
}

0 commit comments

Comments
 (0)