Skip to content

Commit 0a09f00

Browse files
Refactor mid-stream error handling in response handlers
1 parent 569e351 commit 0a09f00

File tree

5 files changed

+12
-29
lines changed

5 files changed

+12
-29
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java

Lines changed: 0 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -224,23 +224,6 @@ protected UnifiedChatCompletionException buildDefaultChatCompletionError(
224224
);
225225
}
226226

227-
/**
228-
* Builds a mid-stream error for a streaming request.
229-
* This method is used when an error occurs while processing a streaming response.
230-
* It must be implemented by subclasses to handle specific error response formats.
231-
* Only streaming requests should use this method.
232-
*
233-
* @param inferenceEntityId the ID of the inference entity
234-
* @param message the error message
235-
* @param e the exception that caused the error, can be null
236-
* @return a {@link UnifiedChatCompletionException} representing the mid-stream error
237-
*/
238-
public UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) {
239-
throw new UnsupportedOperationException(
240-
"Mid-stream error handling is not implemented. Please override buildMidStreamChatCompletionError method."
241-
);
242-
}
243-
244227
/**
245228
* Builds a mid-stream error for a streaming request with a custom error type.
246229
* This method is used when an error occurs while processing a streaming response and allows for custom error handling.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -91,8 +91,7 @@ protected UnifiedChatCompletionException buildChatCompletionError(
9191
* @param e The exception that occurred, if any.
9292
* @return An instance of {@link UnifiedChatCompletionException} representing the mid-stream error.
9393
*/
94-
@Override
95-
public UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) {
94+
private UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) {
9695
var errorResponse = extractMidStreamChatCompletionErrorResponse(message);
9796
// Check if the error response contains a specific structure
9897
if (errorResponse.errorStructureFound()) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -83,8 +83,7 @@ protected UnifiedChatCompletionException buildProviderSpecificChatCompletionErro
8383
);
8484
}
8585

86-
@Override
87-
public UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) {
86+
private UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) {
8887
return buildMidStreamChatCompletionError(inferenceEntityId, message, e, GoogleVertexAiErrorResponse.class);
8988
}
9089

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -91,7 +91,6 @@ protected UnifiedChatCompletionException buildProviderSpecificChatCompletionErro
9191
* @param e the exception that occurred
9292
* @return an instance of {@link UnifiedChatCompletionException} with details from the error response
9393
*/
94-
@Override
9594
public UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) {
9695
// Use the custom type StreamingErrorResponse for mid-stream errors
9796
return buildMidStreamChatCompletionError(inferenceEntityId, message, e, StreamingErrorResponse.class);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticCompletionPayload.java

Lines changed: 10 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,14 @@
4040
* Each chunk should be in a valid JSON format, as that is the format the Elastic API uses.
4141
*/
4242
public class ElasticCompletionPayload implements SageMakerStreamSchemaPayload, ElasticPayload {
43-
private static final XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(
43+
private static final OpenAiUnifiedChatCompletionResponseHandler ERROR_HANDLER = new OpenAiUnifiedChatCompletionResponseHandler(
44+
"sagemaker openai chat completion",
45+
((request, result) -> {
46+
assert false : "do not call this";
47+
throw new UnsupportedOperationException("SageMaker should not call this object's response parser.");
48+
})
49+
);
50+
private static final XContentParserConfiguration PARSER_CONFIGURATION = XContentParserConfiguration.EMPTY.withDeprecationHandler(
4451
LoggingDeprecationHandler.INSTANCE
4552
);
4653

@@ -94,19 +101,15 @@ public SdkBytes chatCompletionRequestBytes(SageMakerModel model, UnifiedCompleti
94101
public StreamingUnifiedChatCompletionResults.Results chatCompletionResponseBody(SageMakerModel model, SdkBytes response) {
95102
var responseData = response.asUtf8String();
96103
try {
97-
var results = OpenAiUnifiedStreamingProcessor.parse(parserConfig, responseData)
104+
var results = OpenAiUnifiedStreamingProcessor.parse(PARSER_CONFIGURATION, responseData)
98105
.collect(
99106
() -> new ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>(),
100107
ArrayDeque::offer,
101108
ArrayDeque::addAll
102109
);
103110
return new StreamingUnifiedChatCompletionResults.Results(results);
104111
} catch (Exception e) {
105-
throw new OpenAiUnifiedChatCompletionResponseHandler(null, null).buildMidStreamChatCompletionError(
106-
model.getInferenceEntityId(),
107-
responseData,
108-
e
109-
);
112+
throw ERROR_HANDLER.buildMidStreamChatCompletionError(model.getInferenceEntityId(), responseData, e);
110113
}
111114
}
112115

0 commit comments

Comments
 (0)