Skip to content

Commit 3b1523a

Browse files
[ML] Refactoring streaming error handling (#131316)
* Refactoring google gemini streaming error handling * Updating comments
1 parent fd971e8 commit 3b1523a

11 files changed

+298
-91
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ public void validateResponse(
9595

9696
protected abstract void checkForFailureStatusCode(Request request, HttpResult result);
9797

98-
private void checkForErrorObject(Request request, HttpResult result) {
98+
protected void checkForErrorObject(Request request, HttpResult result) {
9999
var errorEntity = errorParseFunction.apply(result);
100100

101101
if (errorEntity.errorStructureFound()) {
@@ -116,12 +116,12 @@ protected Exception buildError(String message, Request request, HttpResult resul
116116
protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
117117
var responseStatusCode = result.response().getStatusLine().getStatusCode();
118118
return new ElasticsearchStatusException(
119-
errorMessage(message, request, result, errorResponse, responseStatusCode),
119+
constructErrorMessage(message, request, errorResponse, responseStatusCode),
120120
toRestStatus(responseStatusCode)
121121
);
122122
}
123123

124-
protected String errorMessage(String message, Request request, HttpResult result, ErrorResponse errorResponse, int statusCode) {
124+
public static String constructErrorMessage(String message, Request request, ErrorResponse errorResponse, int statusCode) {
125125
return (errorResponse == null
126126
|| errorResponse.errorStructureFound() == false
127127
|| Strings.isNullOrEmpty(errorResponse.getErrorMessage()))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.retry;
9+
10+
import org.elasticsearch.rest.RestStatus;
11+
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
12+
import org.elasticsearch.xpack.inference.external.http.HttpResult;
13+
import org.elasticsearch.xpack.inference.external.request.Request;
14+
15+
import java.util.Locale;
16+
import java.util.Objects;
17+
18+
import static org.elasticsearch.core.Strings.format;
19+
import static org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler.SERVER_ERROR_OBJECT;
20+
import static org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler.toRestStatus;
21+
22+
public class ChatCompletionErrorResponseHandler {
23+
private static final String STREAM_ERROR = "stream_error";
24+
25+
private final UnifiedChatCompletionErrorParser unifiedChatCompletionErrorParser;
26+
27+
public ChatCompletionErrorResponseHandler(UnifiedChatCompletionErrorParser errorParser) {
28+
this.unifiedChatCompletionErrorParser = Objects.requireNonNull(errorParser);
29+
}
30+
31+
public void checkForErrorObject(Request request, HttpResult result) {
32+
var errorEntity = unifiedChatCompletionErrorParser.parse(result);
33+
34+
if (errorEntity.errorStructureFound()) {
35+
// We don't really know what happened because the status code was 200 so we'll return a failure and let the
36+
// client retry if necessary
37+
// If we did want to retry here, we'll need to determine if this was a streaming request, if it was
38+
// we shouldn't retry because that would replay the entire streaming request and the client would get
39+
// duplicate chunks back
40+
throw new RetryException(false, buildChatCompletionErrorInternal(SERVER_ERROR_OBJECT, request, result, errorEntity));
41+
}
42+
}
43+
44+
public UnifiedChatCompletionException buildChatCompletionError(String message, Request request, HttpResult result) {
45+
var errorResponse = unifiedChatCompletionErrorParser.parse(result);
46+
return buildChatCompletionErrorInternal(message, request, result, errorResponse);
47+
}
48+
49+
private UnifiedChatCompletionException buildChatCompletionErrorInternal(
50+
String message,
51+
Request request,
52+
HttpResult result,
53+
UnifiedChatCompletionErrorResponse errorResponse
54+
) {
55+
assert request.isStreaming() : "Only streaming requests support this format";
56+
var statusCode = result.response().getStatusLine().getStatusCode();
57+
var errorMessage = BaseResponseHandler.constructErrorMessage(message, request, errorResponse, statusCode);
58+
var restStatus = toRestStatus(statusCode);
59+
60+
if (errorResponse.errorStructureFound()) {
61+
return new UnifiedChatCompletionException(
62+
restStatus,
63+
errorMessage,
64+
errorResponse.type(),
65+
errorResponse.code(),
66+
errorResponse.param()
67+
);
68+
} else {
69+
return buildDefaultChatCompletionError(errorResponse, errorMessage, restStatus);
70+
}
71+
}
72+
73+
/**
74+
* Builds a default {@link UnifiedChatCompletionException} for a streaming request.
75+
* This method is used when an error response is received we were unable to parse it in the format we were expecting.
76+
* Only streaming requests should use this method.
77+
*
78+
* @param errorResponse the error response extracted from the HTTP result
79+
* @param errorMessage the error message to include in the exception
80+
* @param restStatus the REST status code of the response
81+
* @return an instance of {@link UnifiedChatCompletionException} with details from the error response
82+
*/
83+
private static UnifiedChatCompletionException buildDefaultChatCompletionError(
84+
ErrorResponse errorResponse,
85+
String errorMessage,
86+
RestStatus restStatus
87+
) {
88+
return new UnifiedChatCompletionException(
89+
restStatus,
90+
errorMessage,
91+
createErrorType(errorResponse),
92+
restStatus.name().toLowerCase(Locale.ROOT)
93+
);
94+
}
95+
96+
/**
97+
* Builds a mid-stream error for a streaming request.
98+
* This method is used when an error occurs while processing a streaming response.
99+
* Only streaming requests should use this method.
100+
*
101+
* @param inferenceEntityId the ID of the inference entity
102+
* @param message the error message
103+
* @param e the exception that caused the error, can be null
104+
* @return a {@link UnifiedChatCompletionException} representing the mid-stream error
105+
*/
106+
public UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) {
107+
var error = unifiedChatCompletionErrorParser.parse(message);
108+
109+
if (error.errorStructureFound()) {
110+
return new UnifiedChatCompletionException(
111+
RestStatus.INTERNAL_SERVER_ERROR,
112+
format(
113+
"%s for request from inference entity id [%s]. Error message: [%s]",
114+
SERVER_ERROR_OBJECT,
115+
inferenceEntityId,
116+
error.getErrorMessage()
117+
),
118+
error.type(),
119+
error.code(),
120+
error.param()
121+
);
122+
} else if (e != null) {
123+
// If the error response does not match, we can still return an exception based on the original throwable
124+
return UnifiedChatCompletionException.fromThrowable(e);
125+
} else {
126+
// If no specific error response is found, we return a default mid-stream error
127+
return buildDefaultMidStreamChatCompletionError(inferenceEntityId, error);
128+
}
129+
}
130+
131+
/**
132+
* Builds a default mid-stream error for a streaming request.
133+
* This method is used when no specific error response is found in the message.
134+
* Only streaming requests should use this method.
135+
*
136+
* @param inferenceEntityId the ID of the inference entity
137+
* @param errorResponse the error response extracted from the message
138+
* @return a {@link UnifiedChatCompletionException} representing the default mid-stream error
139+
*/
140+
private static UnifiedChatCompletionException buildDefaultMidStreamChatCompletionError(
141+
String inferenceEntityId,
142+
ErrorResponse errorResponse
143+
) {
144+
return new UnifiedChatCompletionException(
145+
RestStatus.INTERNAL_SERVER_ERROR,
146+
format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, inferenceEntityId),
147+
createErrorType(errorResponse),
148+
STREAM_ERROR
149+
);
150+
}
151+
152+
/**
153+
* Creates a string representation of the error type based on the provided ErrorResponse.
154+
* This method is used to generate a human-readable error type for logging or exception messages.
155+
*
156+
* @param errorResponse the ErrorResponse object
157+
* @return a string representing the error type
158+
*/
159+
private static String createErrorType(ErrorResponse errorResponse) {
160+
return errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown";
161+
}
162+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ErrorResponse.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ public ErrorResponse(String errorMessage) {
2222
this.errorStructureFound = true;
2323
}
2424

25-
private ErrorResponse(boolean errorStructureFound) {
25+
protected ErrorResponse(boolean errorStructureFound) {
2626
this.errorMessage = "";
2727
this.errorStructureFound = errorStructureFound;
2828
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.retry;
9+
10+
import org.elasticsearch.xpack.inference.external.http.HttpResult;
11+
12+
public interface UnifiedChatCompletionErrorParser {
13+
UnifiedChatCompletionErrorResponse parse(HttpResult result);
14+
15+
UnifiedChatCompletionErrorResponse parse(String result);
16+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.retry;
9+
10+
import org.elasticsearch.core.Nullable;
11+
12+
import java.util.Objects;
13+
14+
public class UnifiedChatCompletionErrorResponse extends ErrorResponse {
15+
public static final UnifiedChatCompletionErrorResponse UNDEFINED_ERROR = new UnifiedChatCompletionErrorResponse();
16+
17+
@Nullable
18+
private final String code;
19+
@Nullable
20+
private final String param;
21+
private final String type;
22+
23+
public UnifiedChatCompletionErrorResponse(String errorMessage, String type, @Nullable String code, @Nullable String param) {
24+
super(errorMessage);
25+
this.code = code;
26+
this.param = param;
27+
this.type = Objects.requireNonNull(type);
28+
}
29+
30+
private UnifiedChatCompletionErrorResponse() {
31+
super(false);
32+
this.code = null;
33+
this.param = null;
34+
this.type = "unknown";
35+
}
36+
37+
@Nullable
38+
public String code() {
39+
return code;
40+
}
41+
42+
@Nullable
43+
public String param() {
44+
return param;
45+
}
46+
47+
public String type() {
48+
return type;
49+
}
50+
51+
@Override
52+
public boolean equals(Object o) {
53+
if (o == null || getClass() != o.getClass()) return false;
54+
if (super.equals(o) == false) return false;
55+
UnifiedChatCompletionErrorResponse that = (UnifiedChatCompletionErrorResponse) o;
56+
return Objects.equals(code, that.code) && Objects.equals(param, that.param) && Objects.equals(type, that.type);
57+
}
58+
59+
@Override
60+
public int hashCode() {
61+
return Objects.hash(super.hashCode(), code, param, type);
62+
}
63+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ protected Exception buildError(String message, Request request, HttpResult resul
4949
var restStatus = toRestStatus(responseStatusCode);
5050
return new UnifiedChatCompletionException(
5151
restStatus,
52-
errorMessage(message, request, result, errorResponse, responseStatusCode),
52+
constructErrorMessage(message, request, errorResponse, responseStatusCode),
5353
"error",
5454
restStatus.name().toLowerCase(Locale.ROOT)
5555
);

0 commit comments

Comments
 (0)