
Commit 2d82bce

[ML] Parse mid-stream errors from OpenAI and EIS
When we are already parsing events, we can receive errors as the next event.

OpenAI formats these as:

```
event: error
data: <payload>
```

Elastic formats these as:

```
data: <payload>
```

Unified will consolidate them into the new error structure.
1 parent 5302589 commit 2d82bce
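
For example (taken from the EIS tests added in this commit), a mid-stream error event such as

```
data: { "error": "some error" }
```

is consolidated into the unified error structure returned to the client:

```
{"error":{"code":"stream_error","message":"Received an error response for request from inference entity id [id]. Error message: [some error]","type":"error"}}
```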

Showing 6 files changed with 214 additions and 42 deletions.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java

Lines changed: 32 additions & 1 deletion
@@ -8,19 +8,23 @@
 package org.elasticsearch.xpack.inference.external.elastic;
 
 import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
 import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
 import org.elasticsearch.xpack.inference.external.http.HttpResult;
 import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
 import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
 import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedStreamingProcessor;
 import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceErrorResponseEntity;
 import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
 import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;
 
 import java.util.Locale;
 import java.util.concurrent.Flow;
 
+import static org.elasticsearch.core.Strings.format;
+
 public class ElasticInferenceServiceUnifiedChatCompletionResponseHandler extends ElasticInferenceServiceResponseHandler {
     public ElasticInferenceServiceUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
         super(requestType, parseFunction, true);
@@ -29,7 +33,8 @@ public ElasticInferenceServiceUnifiedChatCompletionResponseHandler(String reques
     @Override
     public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
         var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
-        var openAiProcessor = new OpenAiUnifiedStreamingProcessor(); // EIS uses the unified API spec
+        // EIS uses the unified API spec
+        var openAiProcessor = new OpenAiUnifiedStreamingProcessor((m, e) -> buildMidStreamError(request, m, e));
 
         flow.subscribe(serverSentEventProcessor);
         serverSentEventProcessor.subscribe(openAiProcessor);
@@ -52,4 +57,30 @@ protected Exception buildError(String message, Request request, HttpResult resul
             return super.buildError(message, request, result, errorResponse);
         }
     }
+
+    private static Exception buildMidStreamError(Request request, String message, Exception e) {
+        var errorResponse = ElasticInferenceServiceErrorResponseEntity.fromString(message);
+        if (errorResponse.errorStructureFound()) {
+            return new UnifiedChatCompletionException(
+                RestStatus.INTERNAL_SERVER_ERROR,
+                format(
+                    "%s for request from inference entity id [%s]. Error message: [%s]",
+                    SERVER_ERROR_OBJECT,
+                    request.getInferenceEntityId(),
+                    errorResponse.getErrorMessage()
+                ),
+                "error",
+                "stream_error"
+            );
+        } else if (e != null) {
+            return UnifiedChatCompletionException.fromThrowable(e);
+        } else {
+            return new UnifiedChatCompletionException(
+                RestStatus.INTERNAL_SERVER_ERROR,
+                format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()),
+                "error",
+                "stream_error"
+            );
+        }
+    }
 }
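
A minimal sketch of the exception this handler builds for a well-formed mid-stream EIS error (illustrative only; it reuses the four-argument `UnifiedChatCompletionException` constructor and the message format from `buildMidStreamError` above, with `SERVER_ERROR_OBJECT` inlined as "Received an error response" per the test expectations later in this commit):

```java
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;

import static org.elasticsearch.core.Strings.format;

public class EisMidStreamErrorSketch {
    public static void main(String[] args) {
        // Status 500, type "error", code "stream_error": what the handler returns for {"error": "some error"}.
        var exception = new UnifiedChatCompletionException(
            RestStatus.INTERNAL_SERVER_ERROR,
            format("%s for request from inference entity id [%s]. Error message: [%s]", "Received an error response", "id", "some error"),
            "error",
            "stream_error"
        );
        System.out.println(exception.getMessage());
    }
}
```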

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedChatCompletionResponseHandler.java

Lines changed: 45 additions & 1 deletion
@@ -9,6 +9,7 @@
 
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xcontent.ConstructingObjectParser;
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.XContentFactory;
@@ -29,6 +30,8 @@
 import java.util.Optional;
 import java.util.concurrent.Flow;
 
+import static org.elasticsearch.core.Strings.format;
+
 public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatCompletionResponseHandler {
     public OpenAiUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
         super(requestType, parseFunction, OpenAiErrorResponse::fromResponse);
@@ -37,7 +40,7 @@ public OpenAiUnifiedChatCompletionResponseHandler(String requestType, ResponsePa
     @Override
     public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
         var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
-        var openAiProcessor = new OpenAiUnifiedStreamingProcessor();
+        var openAiProcessor = new OpenAiUnifiedStreamingProcessor((m, e) -> buildMidStreamError(request, m, e));
 
         flow.subscribe(serverSentEventProcessor);
         serverSentEventProcessor.subscribe(openAiProcessor);
@@ -64,6 +67,34 @@ protected Exception buildError(String message, Request request, HttpResult resul
         }
     }
 
+    private static Exception buildMidStreamError(Request request, String message, Exception e) {
+        var restStatus = RestStatus.OK;
+        var errorResponse = OpenAiErrorResponse.fromString(message);
+        if (errorResponse instanceof OpenAiErrorResponse oer) {
+            return new UnifiedChatCompletionException(
+                restStatus,
+                format(
+                    "%s for request from inference entity id [%s]. Error message: [%s]",
+                    SERVER_ERROR_OBJECT,
+                    request.getInferenceEntityId(),
+                    errorResponse.getErrorMessage()
+                ),
+                oer.type(),
+                oer.code(),
+                oer.param()
+            );
+        } else if (e != null) {
+            return UnifiedChatCompletionException.fromThrowable(e);
+        } else {
+            return new UnifiedChatCompletionException(
+                restStatus,
+                format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()),
+                errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown",
+                "streaming_error"
+            );
+        }
+    }
+
     private static class OpenAiErrorResponse extends ErrorResponse {
         private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>(
             "open_ai_error",
@@ -103,6 +134,19 @@ private static ErrorResponse fromResponse(HttpResult response) {
             return ErrorResponse.UNDEFINED_ERROR;
         }
 
+        private static ErrorResponse fromString(String response) {
+            try (
+                XContentParser parser = XContentFactory.xContent(XContentType.JSON)
+                    .createParser(XContentParserConfiguration.EMPTY, response)
+            ) {
+                return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
+            } catch (Exception e) {
+                // swallow the error
+            }
+
+            return ErrorResponse.UNDEFINED_ERROR;
+        }
+
         @Nullable
         private final String code;
         @Nullable
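
For context (not shown in this diff), the `ERROR_PARSER` above targets OpenAI's usual error envelope, so a mid-stream error from OpenAI would arrive along these lines (illustrative payload; the field names line up with the `message`/`type`/`param`/`code` fields the parser extracts):

```
event: error
data: {"error": {"message": "The server had an error while processing your request.", "type": "server_error", "param": null, "code": null}}
```

When the payload parses, `buildMidStreamError` carries `type`, `code`, and `param` through to the `UnifiedChatCompletionException`; when it does not, the handler falls back to the generic `streaming_error` code.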

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java

Lines changed: 27 additions & 1 deletion
@@ -20,6 +20,7 @@
 import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
 import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
 import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
+import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField;
 
 import java.io.IOException;
 import java.util.ArrayDeque;
@@ -28,6 +29,7 @@
 import java.util.Iterator;
 import java.util.List;
 import java.util.concurrent.LinkedBlockingDeque;
+import java.util.function.BiFunction;
 
 import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
 import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
@@ -57,7 +59,13 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<Deque<S
     public static final String PROMPT_TOKENS_FIELD = "prompt_tokens";
     public static final String TOTAL_TOKENS_FIELD = "total_tokens";
 
+    private final BiFunction<String, Exception, Exception> errorParser;
     private final Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> buffer = new LinkedBlockingDeque<>();
+    private volatile boolean previousEventWasNotError = true;
+
+    public OpenAiUnifiedStreamingProcessor(BiFunction<String, Exception, Exception> errorParser) {
+        this.errorParser = errorParser;
+    }
 
     @Override
     protected void upstreamRequest(long n) {
@@ -71,7 +79,25 @@ protected void upstreamRequest(long n) {
     @Override
     protected void next(Deque<ServerSentEvent> item) throws Exception {
         var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
-        var results = parseEvent(item, OpenAiUnifiedStreamingProcessor::parse, parserConfig, logger);
+
+        var results = new ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>(item.size());
+        for (var event : item) {
+            if (ServerSentEventField.EVENT == event.name() && "error".equals(event.value())) {
+                previousEventWasNotError = false;
+            } else if (ServerSentEventField.DATA == event.name() && event.hasValue()) {
+                if (previousEventWasNotError) {
+                    try {
+                        var delta = parse(parserConfig, event);
+                        delta.forEachRemaining(results::offer);
+                    } catch (Exception e) {
+                        logger.warn("Failed to parse event from inference provider: {}", event);
+                        throw errorParser.apply(event.value(), e);
+                    }
+                } else {
+                    throw errorParser.apply(event.value(), null);
+                }
+            }
+        }
 
         if (results.isEmpty()) {
             upstream().request(1);
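
In short, the processor now walks the server-sent events itself: an `event: error` line flips a flag so the following `data:` payload is routed to the error parser (OpenAI's shape), while a `data:` payload that fails to parse as a completion chunk is also handed to the error parser (EIS sends errors as a bare `data:` event). A simplified, self-contained sketch of that decision flow, using hypothetical stand-in types rather than the real `ServerSentEvent` classes:

```java
import java.util.List;
import java.util.function.BiFunction;

// Illustrative sketch only; Field and Event stand in for ServerSentEventField and ServerSentEvent.
public class MidStreamErrorFlowSketch {

    enum Field { EVENT, DATA }
    record Event(Field name, String value) {}

    private final BiFunction<String, Exception, Exception> errorParser;
    private boolean previousEventWasNotError = true;

    MidStreamErrorFlowSketch(BiFunction<String, Exception, Exception> errorParser) {
        this.errorParser = errorParser;
    }

    // Mirrors the shape of the new next(...) loop above.
    void next(List<Event> events) throws Exception {
        for (var event : events) {
            if (event.name() == Field.EVENT && "error".equals(event.value())) {
                previousEventWasNotError = false;                  // OpenAI announces the error first
            } else if (event.name() == Field.DATA && event.value() != null) {
                if (previousEventWasNotError) {
                    try {
                        parseChunk(event.value());                 // normal completion chunk
                    } catch (Exception e) {
                        throw errorParser.apply(event.value(), e); // EIS: error arrives as unparseable data
                    }
                } else {
                    throw errorParser.apply(event.value(), null);  // payload following "event: error"
                }
            }
        }
    }

    private void parseChunk(String data) {
        if (data.contains("\"error\"")) {                          // placeholder for the real chunk parsing
            throw new IllegalStateException("not a completion chunk: " + data);
        }
    }

    public static void main(String[] args) {
        var sketch = new MidStreamErrorFlowSketch((payload, cause) -> new RuntimeException("mid-stream error: " + payload, cause));
        try {
            sketch.next(List.of(new Event(Field.EVENT, "error"), new Event(Field.DATA, "{\"error\": \"boom\"}")));
        } catch (Exception e) {
            System.out.println(e.getMessage()); // mid-stream error: {"error": "boom"}
        }
    }
}
```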

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceErrorResponseEntity.java

Lines changed: 24 additions & 18 deletions
@@ -9,13 +9,26 @@
 
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
+import org.elasticsearch.common.CheckedSupplier;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xcontent.XContentParserConfiguration;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.inference.external.http.HttpResult;
 import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
 
+import java.io.IOException;
+
+/**
+ * An example error response would look like
+ *
+ * <code>
+ * {
+ *     "error": "some error"
+ * }
+ * </code>
+ *
+ */
 public class ElasticInferenceServiceErrorResponseEntity extends ErrorResponse {
 
     private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceErrorResponseEntity.class);
@@ -24,24 +37,18 @@ private ElasticInferenceServiceErrorResponseEntity(String errorMessage) {
         super(errorMessage);
     }
 
-    /**
-     * An example error response would look like
-     *
-     * <code>
-     * {
-     *     "error": "some error"
-     * }
-     * </code>
-     *
-     * @param response The error response
-     * @return An error entity if the response is JSON with the above structure
-     * or {@link ErrorResponse#UNDEFINED_ERROR} if the error field wasn't found
-     */
     public static ErrorResponse fromResponse(HttpResult response) {
-        try (
-            XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON)
-                .createParser(XContentParserConfiguration.EMPTY, response.body())
-        ) {
+        return fromParser(
+            () -> XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())
+        );
+    }
+
+    public static ErrorResponse fromString(String response) {
+        return fromParser(() -> XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response));
+    }
+
+    private static ErrorResponse fromParser(CheckedSupplier<XContentParser, IOException> jsonParserFactory) {
+        try (XContentParser jsonParser = jsonParserFactory.get()) {
             var responseMap = jsonParser.map();
            var error = (String) responseMap.get("error");
             if (error != null) {
@@ -50,7 +57,6 @@ public static ErrorResponse fromResponse(HttpResult response) {
         } catch (Exception e) {
             logger.debug("Failed to parse error response", e);
         }
-
         return ErrorResponse.UNDEFINED_ERROR;
     }
 }
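
A short usage sketch of the new `fromString` entry point (illustrative; class and method names are taken from the diffs in this commit, and the expected outputs follow from the handler logic and tests):

```java
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceErrorResponseEntity;

public class ErrorEntityFromStringSketch {
    public static void main(String[] args) {
        // A well-formed EIS error payload exposes its message through the ErrorResponse accessors.
        ErrorResponse parsed = ElasticInferenceServiceErrorResponseEntity.fromString("{ \"error\": \"some error\" }");
        System.out.println(parsed.errorStructureFound());    // expected: true
        System.out.println(parsed.getErrorMessage());        // expected: some error

        // Anything that is not JSON with an "error" field falls back to ErrorResponse.UNDEFINED_ERROR.
        ErrorResponse malformed = ElasticInferenceServiceErrorResponseEntity.fromString("{ i am not json }");
        System.out.println(malformed.errorStructureFound()); // expected: false
    }
}
```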

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java

Lines changed: 43 additions & 13 deletions
@@ -962,14 +962,51 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo
     }
 
     public void testUnifiedCompletionError() throws Exception {
+        testUnifiedStreamError(404, """
+            {
+                "error": "The model `rainbow-sprinkles` does not exist or you do not have access to it."
+            }""", """
+            {\
+            "error":{\
+            "code":"not_found",\
+            "message":"Received an unsuccessful status code for request from inference entity id [id] status \
+            [404]. Error message: [The model `rainbow-sprinkles` does not exist or you do not have access to it.]",\
+            "type":"error"\
+            }}""");
+    }
+
+    public void testUnifiedCompletionErrorMidStream() throws Exception {
+        testUnifiedStreamError(200, """
+            data: { "error": "some error" }
+
+            """, """
+            {\
+            "error":{\
+            "code":"stream_error",\
+            "message":"Received an error response for request from inference entity id [id]. Error message: [some error]",\
+            "type":"error"\
+            }}""");
+    }
+
+    public void testUnifiedCompletionMalformedError() throws Exception {
+        testUnifiedStreamError(200, """
+            data: { i am not json }
+
+            """, """
+            {\
+            "error":{\
+            "code":"bad_request",\
+            "message":"[1:3] Unexpected character ('i' (code 105)): was expecting double-quote to start field name\\n\
+            at [Source: (String)\\"{ i am not json }\\"; line: 1, column: 3]",\
+            "type":"x_content_parse_exception"\
+            }}""");
+    }
+
+    private void testUnifiedStreamError(int responseCode, String responseJson, String expectedJson) throws Exception {
         var eisGatewayUrl = getUrl(webServer);
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
         try (var service = createService(senderFactory, eisGatewayUrl)) {
-            var responseJson = """
-                {
-                    "error": "The model `rainbow-sprinkles` does not exist or you do not have access to it."
-                }""";
-            webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson));
+            webServer.enqueue(new MockResponse().setResponseCode(responseCode).setBody(responseJson));
             var model = new ElasticInferenceServiceCompletionModel(
                 "id",
                 TaskType.COMPLETION,
@@ -1004,14 +1041,7 @@ public void testUnifiedCompletionError() throws Exception {
                });
                var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
 
-                assertThat(json, is("""
-                    {\
-                    "error":{\
-                    "code":"not_found",\
-                    "message":"Received an unsuccessful status code for request from inference entity id [id] status \
-                    [404]. Error message: [The model `rainbow-sprinkles` does not exist or you do not have access to it.]",\
-                    "type":"error"\
-                    }}"""));
+                assertThat(json, is(expectedJson));
            }
        });
    }
