Skip to content

Commit ad00113

Browse files
authored
[ML] Change format for Unified Chat error responses (#121396)
Unified Chat Completion error responses now forward code, type, and param to in the response payload. `reason` has been renamed to `message`. Notes: - `XContentFormattedException` is a `ChunkedToXContent` so that the REST listener can call `toXContentChunked` to format the output structure. By default, the structure forwards to our existing ES exception structure. - `UnifiedChatCompletionException` will override the structure to match the new unified format. - The Rest, Transport, and Stream handlers all check the exception to verify it is a UnifiedChatCompletionException. - OpenAI response handler now reads all the fields in the error message and forwards them to the user. - In the event that a `Throwable` is a `Error`, we rethrow it on another thread so the JVM can catch and handle it. We also stop surfacing the JVM details to the user in the error message (but it's still logged for debugging purposes).
1 parent 62f0fe8 commit ad00113

19 files changed

+760
-45
lines changed

docs/changelog/121396.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 121396
2+
summary: Change format for Unified Chat
3+
area: Machine Learning
4+
type: bug
5+
issues: []
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.inference.results;
9+
10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.ExceptionsHelper;
13+
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
14+
import org.elasticsearch.core.Nullable;
15+
import org.elasticsearch.rest.RestStatus;
16+
import org.elasticsearch.xcontent.ToXContent;
17+
18+
import java.util.Iterator;
19+
import java.util.Locale;
20+
import java.util.Objects;
21+
22+
import static java.util.Collections.emptyIterator;
23+
import static org.elasticsearch.ExceptionsHelper.maybeError;
24+
import static org.elasticsearch.common.collect.Iterators.concat;
25+
import static org.elasticsearch.common.xcontent.ChunkedToXContentHelper.endObject;
26+
import static org.elasticsearch.common.xcontent.ChunkedToXContentHelper.startObject;
27+
28+
public class UnifiedChatCompletionException extends XContentFormattedException {
29+
30+
private static final Logger log = LogManager.getLogger(UnifiedChatCompletionException.class);
31+
private final String message;
32+
private final String type;
33+
@Nullable
34+
private final String code;
35+
@Nullable
36+
private final String param;
37+
38+
public UnifiedChatCompletionException(RestStatus status, String message, String type, @Nullable String code) {
39+
this(status, message, type, code, null);
40+
}
41+
42+
public UnifiedChatCompletionException(RestStatus status, String message, String type, @Nullable String code, @Nullable String param) {
43+
super(message, status);
44+
this.message = Objects.requireNonNull(message);
45+
this.type = Objects.requireNonNull(type);
46+
this.code = code;
47+
this.param = param;
48+
}
49+
50+
public UnifiedChatCompletionException(
51+
Throwable cause,
52+
RestStatus status,
53+
String message,
54+
String type,
55+
@Nullable String code,
56+
@Nullable String param
57+
) {
58+
super(message, cause, status);
59+
this.message = Objects.requireNonNull(message);
60+
this.type = Objects.requireNonNull(type);
61+
this.code = code;
62+
this.param = param;
63+
}
64+
65+
@Override
66+
public Iterator<? extends ToXContent> toXContentChunked(Params params) {
67+
return concat(
68+
startObject(),
69+
startObject("error"),
70+
optionalField("code", code),
71+
field("message", message),
72+
optionalField("param", param),
73+
field("type", type),
74+
endObject(),
75+
endObject()
76+
);
77+
}
78+
79+
private static Iterator<ToXContent> field(String key, String value) {
80+
return ChunkedToXContentHelper.chunk((b, p) -> b.field(key, value));
81+
}
82+
83+
private static Iterator<ToXContent> optionalField(String key, String value) {
84+
return value != null ? ChunkedToXContentHelper.chunk((b, p) -> b.field(key, value)) : emptyIterator();
85+
}
86+
87+
public static UnifiedChatCompletionException fromThrowable(Throwable t) {
88+
if (ExceptionsHelper.unwrapCause(t) instanceof UnifiedChatCompletionException e) {
89+
return e;
90+
} else {
91+
return maybeError(t).map(error -> {
92+
// we should never be throwing Error, but just in case we are, rethrow it on another thread so the JVM can handle it and
93+
// return a vague error to the user so that they at least see something went wrong but don't leak JVM details to users
94+
ExceptionsHelper.maybeDieOnAnotherThread(error);
95+
var e = new RuntimeException("Fatal error while streaming response. Please retry the request.");
96+
log.error(e.getMessage(), t);
97+
return new UnifiedChatCompletionException(
98+
RestStatus.INTERNAL_SERVER_ERROR,
99+
e.getMessage(),
100+
getExceptionName(e),
101+
RestStatus.INTERNAL_SERVER_ERROR.name().toLowerCase(Locale.ROOT)
102+
);
103+
}).orElseGet(() -> {
104+
log.atDebug().withThrowable(t).log("UnifiedChatCompletionException stack trace for debugging purposes.");
105+
var status = ExceptionsHelper.status(t);
106+
return new UnifiedChatCompletionException(
107+
t,
108+
status,
109+
t.getMessage(),
110+
getExceptionName(t),
111+
status.name().toLowerCase(Locale.ROOT),
112+
null
113+
);
114+
});
115+
}
116+
}
117+
}
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.inference.results;
9+
10+
import org.elasticsearch.ElasticsearchException;
11+
import org.elasticsearch.common.collect.Iterators;
12+
import org.elasticsearch.common.xcontent.ChunkedToXContent;
13+
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
14+
import org.elasticsearch.core.RestApiVersion;
15+
import org.elasticsearch.rest.RestStatus;
16+
import org.elasticsearch.xcontent.ToXContent;
17+
import org.elasticsearch.xcontent.XContentBuilder;
18+
19+
import java.util.Iterator;
20+
import java.util.Objects;
21+
22+
/**
23+
* Similar to {@link org.elasticsearch.ElasticsearchWrapperException}, this will wrap an Exception to generate an xContent using
24+
* {@link ElasticsearchException#generateFailureXContent(XContentBuilder, Params, Exception, boolean)}.
25+
* Extends {@link ElasticsearchException} to provide REST handlers the {@link #status()} method in order to set the response header.
26+
*/
27+
public class XContentFormattedException extends ElasticsearchException implements ChunkedToXContent {
28+
29+
public static final String X_CONTENT_PARAM = "detailedErrorsEnabled";
30+
private final RestStatus status;
31+
private final Throwable cause;
32+
33+
public XContentFormattedException(String message, RestStatus status) {
34+
super(message);
35+
this.status = Objects.requireNonNull(status);
36+
this.cause = null;
37+
}
38+
39+
public XContentFormattedException(Throwable cause, RestStatus status) {
40+
super(cause);
41+
this.status = Objects.requireNonNull(status);
42+
this.cause = cause;
43+
}
44+
45+
public XContentFormattedException(String message, Throwable cause, RestStatus status) {
46+
super(message, cause);
47+
this.status = Objects.requireNonNull(status);
48+
this.cause = cause;
49+
}
50+
51+
@Override
52+
public RestStatus status() {
53+
return status;
54+
}
55+
56+
@Override
57+
public Iterator<? extends ToXContent> toXContentChunked(Params params) {
58+
return Iterators.concat(
59+
ChunkedToXContentHelper.startObject(),
60+
Iterators.single(
61+
(b, p) -> ElasticsearchException.generateFailureXContent(
62+
b,
63+
p,
64+
cause instanceof Exception e ? e : this,
65+
params.paramAsBoolean(X_CONTENT_PARAM, false)
66+
)
67+
),
68+
Iterators.single((b, p) -> b.field("status", status.getStatus())),
69+
ChunkedToXContentHelper.endObject()
70+
);
71+
}
72+
73+
@Override
74+
public Iterator<? extends ToXContent> toXContentChunked(RestApiVersion restApiVersion, Params params) {
75+
return ChunkedToXContent.super.toXContentChunked(restApiVersion, params);
76+
}
77+
78+
@Override
79+
public Iterator<? extends ToXContent> toXContentChunkedV8(Params params) {
80+
return ChunkedToXContent.super.toXContentChunkedV8(params);
81+
}
82+
83+
@Override
84+
public boolean isFragment() {
85+
return super.isFragment();
86+
}
87+
}

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,12 @@
4444
import org.elasticsearch.rest.RestController;
4545
import org.elasticsearch.rest.RestHandler;
4646
import org.elasticsearch.rest.RestRequest;
47+
import org.elasticsearch.rest.RestStatus;
4748
import org.elasticsearch.test.ESIntegTestCase;
4849
import org.elasticsearch.threadpool.ThreadPool;
4950
import org.elasticsearch.xcontent.ToXContent;
5051
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
52+
import org.elasticsearch.xpack.core.inference.results.XContentFormattedException;
5153
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
5254
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField;
5355
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
@@ -80,6 +82,7 @@ public class ServerSentEventsRestActionListenerTests extends ESIntegTestCase {
8082
private static final String REQUEST_COUNT = "request_count";
8183
private static final String WITH_ERROR = "with_error";
8284
private static final String ERROR_ROUTE = "/_inference_error";
85+
private static final String FORMATTED_ERROR_ROUTE = "/_formatted_inference_error";
8386
private static final String NO_STREAM_ROUTE = "/_inference_no_stream";
8487
private static final Exception expectedException = new IllegalStateException("hello there");
8588
private static final String expectedExceptionAsServerSentEvent = """
@@ -88,6 +91,11 @@ public class ServerSentEventsRestActionListenerTests extends ESIntegTestCase {
8891
"type":"illegal_state_exception","reason":"hello there"},"status":500\
8992
}""";
9093

94+
private static final Exception expectedFormattedException = new XContentFormattedException(
95+
expectedException,
96+
RestStatus.INTERNAL_SERVER_ERROR
97+
);
98+
9199
@Override
92100
protected boolean addMockHttpTransport() {
93101
return false;
@@ -145,6 +153,16 @@ public List<Route> routes() {
145153
public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) {
146154
new ServerSentEventsRestActionListener(channel, threadPool).onFailure(expectedException);
147155
}
156+
}, new RestHandler() {
157+
@Override
158+
public List<Route> routes() {
159+
return List.of(new Route(RestRequest.Method.POST, FORMATTED_ERROR_ROUTE));
160+
}
161+
162+
@Override
163+
public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) {
164+
new ServerSentEventsRestActionListener(channel, threadPool).onFailure(expectedFormattedException);
165+
}
148166
}, new RestHandler() {
149167
@Override
150168
public List<Route> routes() {
@@ -424,6 +442,21 @@ public void testErrorMidStream() {
424442
assertThat(collector.stringsVerified.getLast(), equalTo(expectedExceptionAsServerSentEvent));
425443
}
426444

445+
public void testFormattedError() throws IOException {
446+
var request = new Request(RestRequest.Method.POST.name(), FORMATTED_ERROR_ROUTE);
447+
448+
try {
449+
getRestClient().performRequest(request);
450+
fail("Expected an exception to be thrown from the error route");
451+
} catch (ResponseException e) {
452+
var response = e.getResponse();
453+
assertThat(response.getStatusLine().getStatusCode(), is(HttpStatus.SC_INTERNAL_SERVER_ERROR));
454+
assertThat(EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8), equalTo("""
455+
\uFEFFevent: error
456+
data:\s""" + expectedExceptionAsServerSentEvent + "\n\n"));
457+
}
458+
}
459+
427460
public void testNoStream() {
428461
var collector = new RandomStringCollector();
429462
var expectedTestCount = randomIntBetween(2, 30);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,11 @@
5050
import java.io.IOException;
5151
import java.util.Random;
5252
import java.util.concurrent.Executor;
53+
import java.util.concurrent.Flow;
5354
import java.util.function.Supplier;
5455
import java.util.stream.Collectors;
5556

57+
import static org.elasticsearch.ExceptionsHelper.unwrapCause;
5658
import static org.elasticsearch.core.Strings.format;
5759
import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
5860
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
@@ -280,7 +282,9 @@ private void inferOnServiceWithMetrics(
280282
var instrumentedStream = new PublisherWithMetrics(timer, model);
281283
taskProcessor.subscribe(instrumentedStream);
282284

283-
listener.onResponse(new InferenceAction.Response(inferenceResults, instrumentedStream));
285+
var streamErrorHandler = streamErrorHandler(instrumentedStream);
286+
287+
listener.onResponse(new InferenceAction.Response(inferenceResults, streamErrorHandler));
284288
} else {
285289
recordMetrics(model, timer, null);
286290
listener.onResponse(new InferenceAction.Response(inferenceResults));
@@ -291,9 +295,13 @@ private void inferOnServiceWithMetrics(
291295
}));
292296
}
293297

298+
protected Flow.Publisher<ChunkedToXContent> streamErrorHandler(Flow.Processor<ChunkedToXContent, ChunkedToXContent> upstream) {
299+
return upstream;
300+
}
301+
294302
private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) {
295303
try {
296-
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t));
304+
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, unwrapCause(t)));
297305
} catch (Exception e) {
298306
log.atDebug().withThrowable(e).log("Failed to record metrics with a parsed model, dropping metrics");
299307
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.action.ActionListener;
1212
import org.elasticsearch.action.support.ActionFilters;
1313
import org.elasticsearch.client.internal.node.NodeClient;
14+
import org.elasticsearch.common.xcontent.ChunkedToXContent;
1415
import org.elasticsearch.inference.InferenceService;
1516
import org.elasticsearch.inference.InferenceServiceRegistry;
1617
import org.elasticsearch.inference.InferenceServiceResults;
@@ -20,14 +21,19 @@
2021
import org.elasticsearch.injection.guice.Inject;
2122
import org.elasticsearch.license.XPackLicenseState;
2223
import org.elasticsearch.rest.RestStatus;
24+
import org.elasticsearch.tasks.Task;
2325
import org.elasticsearch.threadpool.ThreadPool;
2426
import org.elasticsearch.transport.TransportService;
27+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
2528
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
29+
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
2630
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
2731
import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
2832
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
2933
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
3034

35+
import java.util.concurrent.Flow;
36+
3137
public class TransportUnifiedCompletionInferenceAction extends BaseTransportInferenceAction<UnifiedCompletionAction.Request> {
3238

3339
@Inject
@@ -86,4 +92,40 @@ protected void doInference(
8692
) {
8793
service.unifiedCompletionInfer(model, request.getUnifiedCompletionRequest(), null, listener);
8894
}
95+
96+
@Override
97+
protected void doExecute(Task task, UnifiedCompletionAction.Request request, ActionListener<InferenceAction.Response> listener) {
98+
super.doExecute(task, request, listener.delegateResponse((l, e) -> l.onFailure(UnifiedChatCompletionException.fromThrowable(e))));
99+
}
100+
101+
/**
102+
* If we get any errors, either in {@link #doExecute} via the listener.onFailure or while streaming, make sure that they are formatted
103+
* as {@link UnifiedChatCompletionException}.
104+
*/
105+
@Override
106+
protected Flow.Publisher<ChunkedToXContent> streamErrorHandler(Flow.Processor<ChunkedToXContent, ChunkedToXContent> upstream) {
107+
return downstream -> {
108+
upstream.subscribe(new Flow.Subscriber<>() {
109+
@Override
110+
public void onSubscribe(Flow.Subscription subscription) {
111+
downstream.onSubscribe(subscription);
112+
}
113+
114+
@Override
115+
public void onNext(ChunkedToXContent item) {
116+
downstream.onNext(item);
117+
}
118+
119+
@Override
120+
public void onError(Throwable throwable) {
121+
downstream.onError(UnifiedChatCompletionException.fromThrowable(throwable));
122+
}
123+
124+
@Override
125+
public void onComplete() {
126+
downstream.onComplete();
127+
}
128+
});
129+
};
130+
}
89131
}

0 commit comments

Comments
 (0)