Skip to content

Commit d08ad46

Browse files
authored
[ML] Change format for Unified Chat error responses (#121822)
Unified Chat Completion error responses now forward code, type, and param in the response payload. `reason` has been renamed to `message`. Notes: - `XContentFormattedException` is a `ChunkedToXContent` so that the REST listener can call `toXContentChunked` to format the output structure. By default, the structure forwards to our existing ES exception structure. - `UnifiedChatCompletionException` overrides the structure to match the new unified format. - The Rest, Transport, and Stream handlers all check the exception to verify it is a `UnifiedChatCompletionException`. - The OpenAI response handler now reads all the fields in the error message and forwards them to the user. - In the event that a `Throwable` is an `Error`, we rethrow it on another thread so the JVM can catch and handle it. We also stop surfacing the JVM details to the user in the error message (but they are still logged for debugging purposes).
1 parent 623247e commit d08ad46

19 files changed

+749
-45
lines changed

docs/changelog/121396.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 121396
2+
summary: Change format for Unified Chat
3+
area: Machine Learning
4+
type: bug
5+
issues: []
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.inference.results;
9+
10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.ExceptionsHelper;
13+
import org.elasticsearch.common.collect.Iterators;
14+
import org.elasticsearch.core.Nullable;
15+
import org.elasticsearch.rest.RestStatus;
16+
import org.elasticsearch.xcontent.ToXContent;
17+
18+
import java.util.Iterator;
19+
import java.util.Locale;
20+
import java.util.Objects;
21+
22+
import static java.util.Collections.emptyIterator;
23+
import static org.elasticsearch.ExceptionsHelper.maybeError;
24+
import static org.elasticsearch.common.collect.Iterators.concat;
25+
import static org.elasticsearch.common.xcontent.ChunkedToXContentHelper.endObject;
26+
import static org.elasticsearch.common.xcontent.ChunkedToXContentHelper.startObject;
27+
28+
public class UnifiedChatCompletionException extends XContentFormattedException {
29+
30+
private static final Logger log = LogManager.getLogger(UnifiedChatCompletionException.class);
31+
private final String message;
32+
private final String type;
33+
@Nullable
34+
private final String code;
35+
@Nullable
36+
private final String param;
37+
38+
public UnifiedChatCompletionException(RestStatus status, String message, String type, @Nullable String code) {
39+
this(status, message, type, code, null);
40+
}
41+
42+
public UnifiedChatCompletionException(RestStatus status, String message, String type, @Nullable String code, @Nullable String param) {
43+
super(message, status);
44+
this.message = Objects.requireNonNull(message);
45+
this.type = Objects.requireNonNull(type);
46+
this.code = code;
47+
this.param = param;
48+
}
49+
50+
public UnifiedChatCompletionException(
51+
Throwable cause,
52+
RestStatus status,
53+
String message,
54+
String type,
55+
@Nullable String code,
56+
@Nullable String param
57+
) {
58+
super(message, cause, status);
59+
this.message = Objects.requireNonNull(message);
60+
this.type = Objects.requireNonNull(type);
61+
this.code = code;
62+
this.param = param;
63+
}
64+
65+
@Override
66+
public Iterator<? extends ToXContent> toXContentChunked(Params params) {
67+
return concat(
68+
startObject(),
69+
startObject("error"),
70+
optionalField("code", code),
71+
field("message", message),
72+
optionalField("param", param),
73+
field("type", type),
74+
endObject(),
75+
endObject()
76+
);
77+
}
78+
79+
private static Iterator<ToXContent> field(String key, String value) {
80+
return Iterators.single((b, p) -> b.field(key, value));
81+
}
82+
83+
private static Iterator<ToXContent> optionalField(String key, String value) {
84+
return value != null ? Iterators.single((b, p) -> b.field(key, value)) : emptyIterator();
85+
}
86+
87+
public static UnifiedChatCompletionException fromThrowable(Throwable t) {
88+
if (ExceptionsHelper.unwrapCause(t) instanceof UnifiedChatCompletionException e) {
89+
return e;
90+
} else {
91+
return maybeError(t).map(error -> {
92+
// we should never be throwing Error, but just in case we are, rethrow it on another thread so the JVM can handle it and
93+
// return a vague error to the user so that they at least see something went wrong but don't leak JVM details to users
94+
ExceptionsHelper.maybeDieOnAnotherThread(error);
95+
var e = new RuntimeException("Fatal error while streaming response. Please retry the request.");
96+
log.error(e.getMessage(), t);
97+
return new UnifiedChatCompletionException(
98+
RestStatus.INTERNAL_SERVER_ERROR,
99+
e.getMessage(),
100+
getExceptionName(e),
101+
RestStatus.INTERNAL_SERVER_ERROR.name().toLowerCase(Locale.ROOT)
102+
);
103+
}).orElseGet(() -> {
104+
log.atDebug().withThrowable(t).log("UnifiedChatCompletionException stack trace for debugging purposes.");
105+
var status = ExceptionsHelper.status(t);
106+
return new UnifiedChatCompletionException(
107+
t,
108+
status,
109+
t.getMessage(),
110+
getExceptionName(t),
111+
status.name().toLowerCase(Locale.ROOT),
112+
null
113+
);
114+
});
115+
}
116+
}
117+
}
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.inference.results;
9+
10+
import org.elasticsearch.ElasticsearchException;
11+
import org.elasticsearch.common.collect.Iterators;
12+
import org.elasticsearch.common.xcontent.ChunkedToXContent;
13+
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
14+
import org.elasticsearch.rest.RestStatus;
15+
import org.elasticsearch.xcontent.ToXContent;
16+
import org.elasticsearch.xcontent.XContentBuilder;
17+
18+
import java.util.Iterator;
19+
import java.util.Objects;
20+
21+
/**
22+
* Similar to {@link org.elasticsearch.ElasticsearchWrapperException}, this will wrap an Exception to generate an xContent using
23+
* {@link ElasticsearchException#generateFailureXContent(XContentBuilder, Params, Exception, boolean)}.
24+
* Extends {@link ElasticsearchException} to provide REST handlers the {@link #status()} method in order to set the response header.
25+
*/
26+
public class XContentFormattedException extends ElasticsearchException implements ChunkedToXContent {
27+
28+
public static final String X_CONTENT_PARAM = "detailedErrorsEnabled";
29+
private final RestStatus status;
30+
private final Throwable cause;
31+
32+
public XContentFormattedException(String message, RestStatus status) {
33+
super(message);
34+
this.status = Objects.requireNonNull(status);
35+
this.cause = null;
36+
}
37+
38+
public XContentFormattedException(Throwable cause, RestStatus status) {
39+
super(cause);
40+
this.status = Objects.requireNonNull(status);
41+
this.cause = cause;
42+
}
43+
44+
public XContentFormattedException(String message, Throwable cause, RestStatus status) {
45+
super(message, cause);
46+
this.status = Objects.requireNonNull(status);
47+
this.cause = cause;
48+
}
49+
50+
@Override
51+
public RestStatus status() {
52+
return status;
53+
}
54+
55+
@Override
56+
public Iterator<? extends ToXContent> toXContentChunked(Params params) {
57+
return Iterators.concat(
58+
ChunkedToXContentHelper.startObject(),
59+
Iterators.single(
60+
(b, p) -> ElasticsearchException.generateFailureXContent(
61+
b,
62+
p,
63+
cause instanceof Exception e ? e : this,
64+
params.paramAsBoolean(X_CONTENT_PARAM, false)
65+
)
66+
),
67+
Iterators.single((b, p) -> b.field("status", status.getStatus())),
68+
ChunkedToXContentHelper.endObject()
69+
);
70+
}
71+
72+
@Override
73+
public boolean isFragment() {
74+
return super.isFragment();
75+
}
76+
}

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,12 @@
4343
import org.elasticsearch.rest.RestController;
4444
import org.elasticsearch.rest.RestHandler;
4545
import org.elasticsearch.rest.RestRequest;
46+
import org.elasticsearch.rest.RestStatus;
4647
import org.elasticsearch.test.ESIntegTestCase;
4748
import org.elasticsearch.threadpool.ThreadPool;
4849
import org.elasticsearch.xcontent.ToXContent;
4950
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
51+
import org.elasticsearch.xpack.core.inference.results.XContentFormattedException;
5052
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
5153
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField;
5254
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
@@ -79,6 +81,7 @@ public class ServerSentEventsRestActionListenerTests extends ESIntegTestCase {
7981
private static final String REQUEST_COUNT = "request_count";
8082
private static final String WITH_ERROR = "with_error";
8183
private static final String ERROR_ROUTE = "/_inference_error";
84+
private static final String FORMATTED_ERROR_ROUTE = "/_formatted_inference_error";
8285
private static final String NO_STREAM_ROUTE = "/_inference_no_stream";
8386
private static final Exception expectedException = new IllegalStateException("hello there");
8487
private static final String expectedExceptionAsServerSentEvent = """
@@ -87,6 +90,11 @@ public class ServerSentEventsRestActionListenerTests extends ESIntegTestCase {
8790
"type":"illegal_state_exception","reason":"hello there"},"status":500\
8891
}""";
8992

93+
private static final Exception expectedFormattedException = new XContentFormattedException(
94+
expectedException,
95+
RestStatus.INTERNAL_SERVER_ERROR
96+
);
97+
9098
@Override
9199
protected boolean addMockHttpTransport() {
92100
return false;
@@ -144,6 +152,16 @@ public List<Route> routes() {
144152
public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) {
145153
new ServerSentEventsRestActionListener(channel, threadPool).onFailure(expectedException);
146154
}
155+
}, new RestHandler() {
156+
@Override
157+
public List<Route> routes() {
158+
return List.of(new Route(RestRequest.Method.POST, FORMATTED_ERROR_ROUTE));
159+
}
160+
161+
@Override
162+
public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) {
163+
new ServerSentEventsRestActionListener(channel, threadPool).onFailure(expectedFormattedException);
164+
}
147165
}, new RestHandler() {
148166
@Override
149167
public List<Route> routes() {
@@ -423,6 +441,21 @@ public void testErrorMidStream() {
423441
assertThat(collector.stringsVerified.getLast(), equalTo(expectedExceptionAsServerSentEvent));
424442
}
425443

444+
public void testFormattedError() throws IOException {
445+
var request = new Request(RestRequest.Method.POST.name(), FORMATTED_ERROR_ROUTE);
446+
447+
try {
448+
getRestClient().performRequest(request);
449+
fail("Expected an exception to be thrown from the error route");
450+
} catch (ResponseException e) {
451+
var response = e.getResponse();
452+
assertThat(response.getStatusLine().getStatusCode(), is(HttpStatus.SC_INTERNAL_SERVER_ERROR));
453+
assertThat(EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8), equalTo("""
454+
\uFEFFevent: error
455+
data:\s""" + expectedExceptionAsServerSentEvent + "\n\n"));
456+
}
457+
}
458+
426459
public void testNoStream() {
427460
var collector = new RandomStringCollector();
428461
var expectedTestCount = randomIntBetween(2, 30);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,11 @@
5050
import java.io.IOException;
5151
import java.util.Random;
5252
import java.util.concurrent.Executor;
53+
import java.util.concurrent.Flow;
5354
import java.util.function.Supplier;
5455
import java.util.stream.Collectors;
5556

57+
import static org.elasticsearch.ExceptionsHelper.unwrapCause;
5658
import static org.elasticsearch.core.Strings.format;
5759
import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
5860
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
@@ -280,7 +282,9 @@ private void inferOnServiceWithMetrics(
280282
var instrumentedStream = new PublisherWithMetrics(timer, model);
281283
taskProcessor.subscribe(instrumentedStream);
282284

283-
listener.onResponse(new InferenceAction.Response(inferenceResults, instrumentedStream));
285+
var streamErrorHandler = streamErrorHandler(instrumentedStream);
286+
287+
listener.onResponse(new InferenceAction.Response(inferenceResults, streamErrorHandler));
284288
} else {
285289
recordMetrics(model, timer, null);
286290
listener.onResponse(new InferenceAction.Response(inferenceResults));
@@ -291,9 +295,13 @@ private void inferOnServiceWithMetrics(
291295
}));
292296
}
293297

298+
/**
 * Extension point applied to the instrumented streaming publisher before it is handed to the
 * {@code InferenceAction.Response}. This default implementation returns {@code upstream} unchanged;
 * subclasses may override it to decorate the stream, for example to reformat errors.
 *
 * @param upstream the processor that publishes the streamed results
 * @return the publisher to expose in the response
 */
protected Flow.Publisher<ChunkedToXContent> streamErrorHandler(Flow.Processor<ChunkedToXContent, ChunkedToXContent> upstream) {
299+
return upstream;
300+
}
301+
294302
private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) {
295303
try {
296-
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t));
304+
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, unwrapCause(t)));
297305
} catch (Exception e) {
298306
log.atDebug().withThrowable(e).log("Failed to record metrics with a parsed model, dropping metrics");
299307
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.action.ActionListener;
1212
import org.elasticsearch.action.support.ActionFilters;
1313
import org.elasticsearch.client.internal.node.NodeClient;
14+
import org.elasticsearch.common.xcontent.ChunkedToXContent;
1415
import org.elasticsearch.inference.InferenceService;
1516
import org.elasticsearch.inference.InferenceServiceRegistry;
1617
import org.elasticsearch.inference.InferenceServiceResults;
@@ -20,14 +21,19 @@
2021
import org.elasticsearch.injection.guice.Inject;
2122
import org.elasticsearch.license.XPackLicenseState;
2223
import org.elasticsearch.rest.RestStatus;
24+
import org.elasticsearch.tasks.Task;
2325
import org.elasticsearch.threadpool.ThreadPool;
2426
import org.elasticsearch.transport.TransportService;
27+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
2528
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
29+
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
2630
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
2731
import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
2832
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
2933
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
3034

35+
import java.util.concurrent.Flow;
36+
3137
public class TransportUnifiedCompletionInferenceAction extends BaseTransportInferenceAction<UnifiedCompletionAction.Request> {
3238

3339
@Inject
@@ -86,4 +92,40 @@ protected void doInference(
8692
) {
8793
service.unifiedCompletionInfer(model, request.getUnifiedCompletionRequest(), null, listener);
8894
}
95+
96+
@Override
97+
protected void doExecute(Task task, UnifiedCompletionAction.Request request, ActionListener<InferenceAction.Response> listener) {
98+
super.doExecute(task, request, listener.delegateResponse((l, e) -> l.onFailure(UnifiedChatCompletionException.fromThrowable(e))));
99+
}
100+
101+
/**
102+
* If we get any errors, either in {@link #doExecute} via the listener.onFailure or while streaming, make sure that they are formatted
103+
* as {@link UnifiedChatCompletionException}.
104+
*/
105+
@Override
106+
protected Flow.Publisher<ChunkedToXContent> streamErrorHandler(Flow.Processor<ChunkedToXContent, ChunkedToXContent> upstream) {
107+
return downstream -> {
108+
upstream.subscribe(new Flow.Subscriber<>() {
109+
@Override
110+
public void onSubscribe(Flow.Subscription subscription) {
111+
downstream.onSubscribe(subscription);
112+
}
113+
114+
@Override
115+
public void onNext(ChunkedToXContent item) {
116+
downstream.onNext(item);
117+
}
118+
119+
@Override
120+
public void onError(Throwable throwable) {
121+
downstream.onError(UnifiedChatCompletionException.fromThrowable(throwable));
122+
}
123+
124+
@Override
125+
public void onComplete() {
126+
downstream.onComplete();
127+
}
128+
});
129+
};
130+
}
89131
}

0 commit comments

Comments
 (0)