
Commit b736749

[ML] Parse mid-stream errors from OpenAI and EIS (elastic#121806) (elastic#121961)
When we are already parsing events, we can receive an error as the next event. OpenAI formats these as:

```
event: error
data: <payload>
```

Elastic formats these as:

```
data: <payload>
```

The unified chat completion handlers consolidate both into the new error structure.
1 parent 50d9778 commit b736749
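
For illustration, this is roughly what a mid-stream error exchange looks like and the unified error body the client ends up with. The EIS event and the resulting JSON are taken from the new testUnifiedCompletionErrorMidStream test below; the OpenAI payload shape is only a sketch based on the fields the OpenAiErrorResponse parser extracts (message, type, code, param).

```
# OpenAI-style mid-stream error (sketch; field names per the OpenAiErrorResponse parser)
event: error
data: {"error": {"message": "...", "type": "...", "code": "...", "param": "..."}}

# Elastic Inference Service mid-stream error (from testUnifiedCompletionErrorMidStream)
data: { "error": "some error" }

# Unified error structure surfaced to the client (expected JSON from the same test)
{"error":{"code":"stream_error","message":"Received an error response for request from inference entity id [id]. Error message: [some error]","type":"error"}}
```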

File tree

6 files changed (+213, −42 lines)

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java

Lines changed: 32 additions & 1 deletion

@@ -8,19 +8,23 @@
 package org.elasticsearch.xpack.inference.external.elastic;

 import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
 import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
 import org.elasticsearch.xpack.inference.external.http.HttpResult;
 import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
 import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
 import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedStreamingProcessor;
 import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceErrorResponseEntity;
 import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
 import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;

 import java.util.Locale;
 import java.util.concurrent.Flow;

+import static org.elasticsearch.core.Strings.format;
+
 public class ElasticInferenceServiceUnifiedChatCompletionResponseHandler extends ElasticInferenceServiceResponseHandler {
     public ElasticInferenceServiceUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
         super(requestType, parseFunction);
@@ -34,7 +38,8 @@ public boolean canHandleStreamingResponses() {
     @Override
     public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
         var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
-        var openAiProcessor = new OpenAiUnifiedStreamingProcessor(); // EIS uses the unified API spec
+        // EIS uses the unified API spec
+        var openAiProcessor = new OpenAiUnifiedStreamingProcessor((m, e) -> buildMidStreamError(request, m, e));

         flow.subscribe(serverSentEventProcessor);
         serverSentEventProcessor.subscribe(openAiProcessor);
@@ -57,4 +62,30 @@ protected Exception buildError(String message, Request request, HttpResult resul
             return super.buildError(message, request, result, errorResponse);
         }
     }
+
+    private static Exception buildMidStreamError(Request request, String message, Exception e) {
+        var errorResponse = ElasticInferenceServiceErrorResponseEntity.fromString(message);
+        if (errorResponse.errorStructureFound()) {
+            return new UnifiedChatCompletionException(
+                RestStatus.INTERNAL_SERVER_ERROR,
+                format(
+                    "%s for request from inference entity id [%s]. Error message: [%s]",
+                    SERVER_ERROR_OBJECT,
+                    request.getInferenceEntityId(),
+                    errorResponse.getErrorMessage()
+                ),
+                "error",
+                "stream_error"
+            );
+        } else if (e != null) {
+            return UnifiedChatCompletionException.fromThrowable(e);
+        } else {
+            return new UnifiedChatCompletionException(
+                RestStatus.INTERNAL_SERVER_ERROR,
+                format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()),
+                "error",
+                "stream_error"
+            );
+        }
+    }
 }

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedChatCompletionResponseHandler.java

Lines changed: 44 additions & 1 deletion

@@ -9,6 +9,7 @@

 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xcontent.ConstructingObjectParser;
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.XContentFactory;
@@ -29,6 +30,8 @@
 import java.util.Optional;
 import java.util.concurrent.Flow;

+import static org.elasticsearch.core.Strings.format;
+
 public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatCompletionResponseHandler {
     public OpenAiUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
         super(requestType, parseFunction, OpenAiErrorResponse::fromResponse);
@@ -37,7 +40,7 @@ public OpenAiUnifiedChatCompletionResponseHandler(String requestType, ResponsePa
     @Override
     public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
         var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
-        var openAiProcessor = new OpenAiUnifiedStreamingProcessor();
+        var openAiProcessor = new OpenAiUnifiedStreamingProcessor((m, e) -> buildMidStreamError(request, m, e));

         flow.subscribe(serverSentEventProcessor);
         serverSentEventProcessor.subscribe(openAiProcessor);
@@ -64,6 +67,33 @@ protected Exception buildError(String message, Request request, HttpResult resul
         }
     }

+    private static Exception buildMidStreamError(Request request, String message, Exception e) {
+        var errorResponse = OpenAiErrorResponse.fromString(message);
+        if (errorResponse instanceof OpenAiErrorResponse oer) {
+            return new UnifiedChatCompletionException(
+                RestStatus.INTERNAL_SERVER_ERROR,
+                format(
+                    "%s for request from inference entity id [%s]. Error message: [%s]",
+                    SERVER_ERROR_OBJECT,
+                    request.getInferenceEntityId(),
+                    errorResponse.getErrorMessage()
+                ),
+                oer.type(),
+                oer.code(),
+                oer.param()
+            );
+        } else if (e != null) {
+            return UnifiedChatCompletionException.fromThrowable(e);
+        } else {
+            return new UnifiedChatCompletionException(
+                RestStatus.INTERNAL_SERVER_ERROR,
+                format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()),
+                errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown",
+                "stream_error"
+            );
+        }
+    }
+
     private static class OpenAiErrorResponse extends ErrorResponse {
         private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>(
             "open_ai_error",
@@ -103,6 +133,19 @@ private static ErrorResponse fromResponse(HttpResult response) {
             return ErrorResponse.UNDEFINED_ERROR;
         }

+        private static ErrorResponse fromString(String response) {
+            try (
+                XContentParser parser = XContentFactory.xContent(XContentType.JSON)
+                    .createParser(XContentParserConfiguration.EMPTY, response)
+            ) {
+                return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
+            } catch (Exception e) {
+                // swallow the error
+            }
+
+            return ErrorResponse.UNDEFINED_ERROR;
+        }
+
         @Nullable
         private final String code;
         @Nullable

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java

Lines changed: 27 additions & 1 deletion

@@ -20,6 +20,7 @@
 import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
 import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
 import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
+import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField;

 import java.io.IOException;
 import java.util.ArrayDeque;
@@ -28,6 +29,7 @@
 import java.util.Iterator;
 import java.util.List;
 import java.util.concurrent.LinkedBlockingDeque;
+import java.util.function.BiFunction;

 import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
 import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
@@ -57,7 +59,13 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<Deque<S
     public static final String PROMPT_TOKENS_FIELD = "prompt_tokens";
     public static final String TOTAL_TOKENS_FIELD = "total_tokens";

+    private final BiFunction<String, Exception, Exception> errorParser;
     private final Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> buffer = new LinkedBlockingDeque<>();
+    private volatile boolean previousEventWasError = false;
+
+    public OpenAiUnifiedStreamingProcessor(BiFunction<String, Exception, Exception> errorParser) {
+        this.errorParser = errorParser;
+    }

     @Override
     protected void upstreamRequest(long n) {
@@ -71,7 +79,25 @@ protected void upstreamRequest(long n) {
     @Override
     protected void next(Deque<ServerSentEvent> item) throws Exception {
         var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
-        var results = parseEvent(item, OpenAiUnifiedStreamingProcessor::parse, parserConfig, logger);
+
+        var results = new ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>(item.size());
+        for (var event : item) {
+            if (ServerSentEventField.EVENT == event.name() && "error".equals(event.value())) {
+                previousEventWasError = true;
+            } else if (ServerSentEventField.DATA == event.name() && event.hasValue()) {
+                if (previousEventWasError) {
+                    throw errorParser.apply(event.value(), null);
+                }
+
+                try {
+                    var delta = parse(parserConfig, event);
+                    delta.forEachRemaining(results::offer);
+                } catch (Exception e) {
+                    logger.warn("Failed to parse event from inference provider: {}", event);
+                    throw errorParser.apply(event.value(), e);
+                }
+            }
+        }

         if (results.isEmpty()) {
             upstream().request(1);
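
The new next(...) loop above is effectively a two-state machine: an `event: error` field arms previousEventWasError, and the following `data:` payload is handed to the injected errorParser instead of the chunk parser; a `data:` payload that fails to parse is routed through the same errorParser with the parse exception attached. Below is a minimal, self-contained sketch of that flow in plain Java; the Field/Event types, parseChunk, and the exception used are simplified stand-ins for illustration, not the plugin's actual classes.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.List;
import java.util.function.BiFunction;

// Simplified stand-ins for the plugin's ServerSentEventField / ServerSentEvent types.
enum Field { EVENT, DATA }

record Event(Field name, String value) {
    boolean hasValue() {
        return value != null && value.isBlank() == false;
    }
}

public class MidStreamErrorSketch {
    // Mirrors the injected errorParser: (raw payload, optional parse failure) -> exception to surface.
    private final BiFunction<String, Exception, Exception> errorParser;
    private boolean previousEventWasError = false;

    MidStreamErrorSketch(BiFunction<String, Exception, Exception> errorParser) {
        this.errorParser = errorParser;
    }

    // Classify one batch of events: throw for mid-stream errors, collect parsed chunks otherwise.
    List<String> next(Deque<Event> events) throws Exception {
        var results = new ArrayDeque<String>(events.size());
        for (var event : events) {
            if (Field.EVENT == event.name() && "error".equals(event.value())) {
                previousEventWasError = true; // the next data payload is an error body, not a chunk
            } else if (Field.DATA == event.name() && event.hasValue()) {
                if (previousEventWasError) {
                    throw errorParser.apply(event.value(), null);
                }
                try {
                    results.offer(parseChunk(event.value()));
                } catch (Exception e) {
                    throw errorParser.apply(event.value(), e); // malformed data is surfaced as an error too
                }
            }
        }
        return List.copyOf(results);
    }

    private static String parseChunk(String data) {
        return data; // placeholder; the real processor parses a ChatCompletionChunk here
    }

    public static void main(String[] args) throws Exception {
        var processor = new MidStreamErrorSketch((payload, e) -> new RuntimeException("stream_error: " + payload, e));
        var events = new ArrayDeque<>(
            List.of(new Event(Field.EVENT, "error"), new Event(Field.DATA, "{ \"error\": \"some error\" }"))
        );
        try {
            processor.next(events);
        } catch (Exception ex) {
            System.out.println(ex.getMessage()); // prints: stream_error: { "error": "some error" }
        }
    }
}
```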

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceErrorResponseEntity.java

Lines changed: 24 additions & 18 deletions

@@ -9,13 +9,26 @@

 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
+import org.elasticsearch.common.CheckedSupplier;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xcontent.XContentParserConfiguration;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.inference.external.http.HttpResult;
 import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;

+import java.io.IOException;
+
+/**
+ * An example error response would look like
+ *
+ * <code>
+ *     {
+ *         "error": "some error"
+ *     }
+ * </code>
+ *
+ */
 public class ElasticInferenceServiceErrorResponseEntity extends ErrorResponse {

     private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceErrorResponseEntity.class);
@@ -24,24 +37,18 @@ private ElasticInferenceServiceErrorResponseEntity(String errorMessage) {
         super(errorMessage);
     }

-    /**
-     * An example error response would look like
-     *
-     * <code>
-     *     {
-     *         "error": "some error"
-     *     }
-     * </code>
-     *
-     * @param response The error response
-     * @return An error entity if the response is JSON with the above structure
-     *     or {@link ErrorResponse#UNDEFINED_ERROR} if the error field wasn't found
-     */
     public static ErrorResponse fromResponse(HttpResult response) {
-        try (
-            XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON)
-                .createParser(XContentParserConfiguration.EMPTY, response.body())
-        ) {
+        return fromParser(
+            () -> XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())
+        );
+    }
+
+    public static ErrorResponse fromString(String response) {
+        return fromParser(() -> XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response));
+    }
+
+    private static ErrorResponse fromParser(CheckedSupplier<XContentParser, IOException> jsonParserFactory) {
+        try (XContentParser jsonParser = jsonParserFactory.get()) {
             var responseMap = jsonParser.map();
             var error = (String) responseMap.get("error");
             if (error != null) {
@@ -50,7 +57,6 @@ public static ErrorResponse fromResponse(HttpResult response) {
         } catch (Exception e) {
             logger.debug("Failed to parse error response", e);
         }
-
         return ErrorResponse.UNDEFINED_ERROR;
     }
 }

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java

Lines changed: 43 additions & 13 deletions

@@ -968,14 +968,51 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo
     }

     public void testUnifiedCompletionError() throws Exception {
+        testUnifiedStreamError(404, """
+            {
+                "error": "The model `rainbow-sprinkles` does not exist or you do not have access to it."
+            }""", """
+            {\
+            "error":{\
+            "code":"not_found",\
+            "message":"Received an unsuccessful status code for request from inference entity id [id] status \
+            [404]. Error message: [The model `rainbow-sprinkles` does not exist or you do not have access to it.]",\
+            "type":"error"\
+            }}""");
+    }
+
+    public void testUnifiedCompletionErrorMidStream() throws Exception {
+        testUnifiedStreamError(200, """
+            data: { "error": "some error" }
+
+            """, """
+            {\
+            "error":{\
+            "code":"stream_error",\
+            "message":"Received an error response for request from inference entity id [id]. Error message: [some error]",\
+            "type":"error"\
+            }}""");
+    }
+
+    public void testUnifiedCompletionMalformedError() throws Exception {
+        testUnifiedStreamError(200, """
+            data: { i am not json }
+
+            """, """
+            {\
+            "error":{\
+            "code":"bad_request",\
+            "message":"[1:3] Unexpected character ('i' (code 105)): was expecting double-quote to start field name\\n\
+            at [Source: (String)\\"{ i am not json }\\"; line: 1, column: 3]",\
+            "type":"x_content_parse_exception"\
+            }}""");
+    }
+
+    private void testUnifiedStreamError(int responseCode, String responseJson, String expectedJson) throws Exception {
         var eisGatewayUrl = getUrl(webServer);
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
         try (var service = createService(senderFactory, eisGatewayUrl)) {
-            var responseJson = """
-                {
-                    "error": "The model `rainbow-sprinkles` does not exist or you do not have access to it."
-                }""";
-            webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson));
+            webServer.enqueue(new MockResponse().setResponseCode(responseCode).setBody(responseJson));
             var model = new ElasticInferenceServiceCompletionModel(
                 "id",
                 TaskType.COMPLETION,
@@ -1010,14 +1047,7 @@ public void testUnifiedCompletionError() throws Exception {
                });
                var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());

-                assertThat(json, is("""
-                    {\
-                    "error":{\
-                    "code":"not_found",\
-                    "message":"Received an unsuccessful status code for request from inference entity id [id] status \
-                    [404]. Error message: [The model `rainbow-sprinkles` does not exist or you do not have access to it.]",\
-                    "type":"error"\
-                    }}"""));
+                assertThat(json, is(expectedJson));
            }
        });
    }
