Skip to content

Commit 95099a7

Browse files
committed
[ML] Change format for Unified Chat
Unified Chat Completion error responses now forward `code`, `type`, and `param` in the response payload. `reason` has been renamed to `message`.
1 parent 2e84950 commit 95099a7

File tree

15 files changed

+703
-24
lines changed

15 files changed

+703
-24
lines changed
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.inference.results;
9+
10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.ElasticsearchWrapperException;
13+
import org.elasticsearch.ExceptionsHelper;
14+
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
15+
import org.elasticsearch.core.Nullable;
16+
import org.elasticsearch.rest.RestStatus;
17+
import org.elasticsearch.xcontent.ToXContent;
18+
19+
import java.util.Iterator;
20+
import java.util.Locale;
21+
22+
import static java.util.Collections.emptyIterator;
23+
import static org.elasticsearch.ExceptionsHelper.maybeError;
24+
import static org.elasticsearch.common.collect.Iterators.concat;
25+
import static org.elasticsearch.common.xcontent.ChunkedToXContentHelper.endObject;
26+
import static org.elasticsearch.common.xcontent.ChunkedToXContentHelper.startObject;
27+
28+
public class UnifiedChatCompletionException extends XContentFormattedException {
29+
30+
private static final Logger log = LogManager.getLogger(UnifiedChatCompletionException.class);
31+
private final String message;
32+
private final String type;
33+
@Nullable
34+
private final String code;
35+
@Nullable
36+
private final String param;
37+
38+
public UnifiedChatCompletionException(RestStatus status, String message, String type, @Nullable String code) {
39+
this(status, message, type, code, null);
40+
}
41+
42+
public UnifiedChatCompletionException(
43+
RestStatus status,
44+
String message,
45+
String type,
46+
@Nullable String code,
47+
@Nullable String param
48+
) {
49+
super(message, status);
50+
this.message = message;
51+
this.type = type;
52+
this.code = code;
53+
this.param = param;
54+
}
55+
56+
public UnifiedChatCompletionException(
57+
Throwable cause,
58+
RestStatus status,
59+
String message,
60+
String type,
61+
@Nullable String code,
62+
@Nullable String param
63+
) {
64+
super(message, cause, status);
65+
this.message = message;
66+
this.type = type;
67+
this.code = code;
68+
this.param = param;
69+
}
70+
71+
@Override
72+
public Iterator<? extends ToXContent> toXContentChunked(Params params) {
73+
return concat(
74+
startObject(),
75+
startObject("error"),
76+
optionalField("code", code),
77+
field("message", message),
78+
optionalField("param", param),
79+
field("type", type),
80+
endObject(),
81+
endObject()
82+
);
83+
}
84+
85+
private static Iterator<ToXContent> field(String key, String value) {
86+
return ChunkedToXContentHelper.chunk((b, p) -> b.field(key, value));
87+
}
88+
89+
private static Iterator<ToXContent> optionalField(String key, String value) {
90+
return value != null ? ChunkedToXContentHelper.chunk((b, p) -> b.field(key, value)) : emptyIterator();
91+
}
92+
93+
public static UnifiedChatCompletionException fromThrowable(Throwable t) {
94+
if (t instanceof UnifiedChatCompletionException e) {
95+
return e;
96+
} else if (unwrapCause(t) instanceof UnifiedChatCompletionException e) {
97+
return e;
98+
} else {
99+
return maybeError(t).map(error -> {
100+
// we should never be throwing Error, but just in case we are, rethrow it on another thread so the JVM can handle it and
101+
// return a vague error to the user so that they at least see something went wrong but don't leak JVM details to users
102+
ExceptionsHelper.maybeDieOnAnotherThread(error);
103+
var e = new RuntimeException("Fatal error while streaming response. Please retry the request.");
104+
log.error(e.getMessage(), t);
105+
return new UnifiedChatCompletionException(
106+
RestStatus.INTERNAL_SERVER_ERROR,
107+
e.getMessage(),
108+
getExceptionName(e),
109+
RestStatus.INTERNAL_SERVER_ERROR.name().toLowerCase(Locale.ROOT)
110+
);
111+
}).orElseGet(() -> {
112+
log.atDebug().withThrowable(t).log("UnifiedChatCompletionException stack trace for debugging purposes.");
113+
var status = ExceptionsHelper.status(t);
114+
return new UnifiedChatCompletionException(
115+
t,
116+
status,
117+
t.getMessage(),
118+
getExceptionName(t),
119+
status.name().toLowerCase(Locale.ROOT),
120+
null
121+
);
122+
});
123+
}
124+
}
125+
126+
private static Throwable unwrapCause(Throwable t) {
127+
int counter = 0;
128+
Throwable result = t;
129+
while (result instanceof ElasticsearchWrapperException) {
130+
if (result instanceof UnifiedChatCompletionException) {
131+
return result;
132+
}
133+
if (result.getCause() == null) {
134+
return result;
135+
}
136+
if (result.getCause() == result) {
137+
return result;
138+
}
139+
if (counter++ > 10) {
140+
return result;
141+
}
142+
result = result.getCause();
143+
}
144+
return result;
145+
}
146+
}
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.inference.results;
9+
10+
import org.elasticsearch.ElasticsearchException;
11+
import org.elasticsearch.ElasticsearchWrapperException;
12+
import org.elasticsearch.common.collect.Iterators;
13+
import org.elasticsearch.common.xcontent.ChunkedToXContent;
14+
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
15+
import org.elasticsearch.core.RestApiVersion;
16+
import org.elasticsearch.rest.RestStatus;
17+
import org.elasticsearch.xcontent.ToXContent;
18+
import org.elasticsearch.xcontent.XContentBuilder;
19+
20+
import java.util.Iterator;
21+
import java.util.Objects;
22+
23+
/**
24+
* Similar to {@link org.elasticsearch.ElasticsearchWrapperException}, this will wrap an Exception to generate an xContent using
25+
* {@link ElasticsearchException#generateFailureXContent(XContentBuilder, Params, Exception, boolean)}.
26+
* Extends {@link ElasticsearchException} to provide REST handlers the {@link #status()} method in order to set the response header.
27+
*/
28+
public class XContentFormattedException extends ElasticsearchException implements ChunkedToXContent {
29+
30+
public static final String X_CONTENT_PARAM = "detailedErrorsEnabled";
31+
private final RestStatus status;
32+
private final Throwable cause;
33+
34+
public XContentFormattedException(String message, RestStatus status) {
35+
super(message);
36+
this.status = Objects.requireNonNull(status);
37+
this.cause = null;
38+
}
39+
40+
public XContentFormattedException(Throwable cause, RestStatus status) {
41+
super(cause);
42+
this.status = Objects.requireNonNull(status);
43+
this.cause = cause;
44+
}
45+
46+
public XContentFormattedException(String message, Throwable cause, RestStatus status) {
47+
super(message, cause);
48+
this.status = Objects.requireNonNull(status);
49+
this.cause = cause;
50+
}
51+
52+
@Override
53+
public RestStatus status() {
54+
return status;
55+
}
56+
57+
@Override
58+
public Iterator<? extends ToXContent> toXContentChunked(Params params) {
59+
return Iterators.concat(
60+
ChunkedToXContentHelper.startObject(),
61+
Iterators.single(
62+
(b, p) -> ElasticsearchException.generateFailureXContent(
63+
b,
64+
p,
65+
cause instanceof Exception e ? e : this,
66+
params.paramAsBoolean(X_CONTENT_PARAM, false)
67+
)
68+
),
69+
Iterators.single((b, p) -> b.field("status", status.getStatus())),
70+
ChunkedToXContentHelper.endObject()
71+
);
72+
}
73+
74+
@Override
75+
public Iterator<? extends ToXContent> toXContentChunked(RestApiVersion restApiVersion, Params params) {
76+
return ChunkedToXContent.super.toXContentChunked(restApiVersion, params);
77+
}
78+
79+
@Override
80+
public Iterator<? extends ToXContent> toXContentChunkedV8(Params params) {
81+
return ChunkedToXContent.super.toXContentChunkedV8(params);
82+
}
83+
84+
@Override
85+
public boolean isFragment() {
86+
return super.isFragment();
87+
}
88+
}

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,12 @@
4444
import org.elasticsearch.rest.RestController;
4545
import org.elasticsearch.rest.RestHandler;
4646
import org.elasticsearch.rest.RestRequest;
47+
import org.elasticsearch.rest.RestStatus;
4748
import org.elasticsearch.test.ESIntegTestCase;
4849
import org.elasticsearch.threadpool.ThreadPool;
4950
import org.elasticsearch.xcontent.ToXContent;
5051
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
52+
import org.elasticsearch.xpack.core.inference.results.XContentFormattedException;
5153
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
5254
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField;
5355
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
@@ -80,6 +82,7 @@ public class ServerSentEventsRestActionListenerTests extends ESIntegTestCase {
8082
private static final String REQUEST_COUNT = "request_count";
8183
private static final String WITH_ERROR = "with_error";
8284
private static final String ERROR_ROUTE = "/_inference_error";
85+
private static final String FORMATTED_ERROR_ROUTE = "/_formatted_inference_error";
8386
private static final String NO_STREAM_ROUTE = "/_inference_no_stream";
8487
private static final Exception expectedException = new IllegalStateException("hello there");
8588
private static final String expectedExceptionAsServerSentEvent = """
@@ -88,6 +91,11 @@ public class ServerSentEventsRestActionListenerTests extends ESIntegTestCase {
8891
"type":"illegal_state_exception","reason":"hello there"},"status":500\
8992
}""";
9093

94+
private static final Exception expectedFormattedException = new XContentFormattedException(
95+
expectedException,
96+
RestStatus.INTERNAL_SERVER_ERROR
97+
);
98+
9199
@Override
92100
protected boolean addMockHttpTransport() {
93101
return false;
@@ -145,6 +153,16 @@ public List<Route> routes() {
145153
public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) {
146154
new ServerSentEventsRestActionListener(channel, threadPool).onFailure(expectedException);
147155
}
156+
}, new RestHandler() {
157+
@Override
158+
public List<Route> routes() {
159+
return List.of(new Route(RestRequest.Method.POST, FORMATTED_ERROR_ROUTE));
160+
}
161+
162+
@Override
163+
public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) {
164+
new ServerSentEventsRestActionListener(channel, threadPool).onFailure(expectedFormattedException);
165+
}
148166
}, new RestHandler() {
149167
@Override
150168
public List<Route> routes() {
@@ -424,6 +442,21 @@ public void testErrorMidStream() {
424442
assertThat(collector.stringsVerified.getLast(), equalTo(expectedExceptionAsServerSentEvent));
425443
}
426444

445+
public void testFormattedError() throws IOException {
446+
var request = new Request(RestRequest.Method.POST.name(), FORMATTED_ERROR_ROUTE);
447+
448+
try {
449+
getRestClient().performRequest(request);
450+
fail("Expected an exception to be thrown from the error route");
451+
} catch (ResponseException e) {
452+
var response = e.getResponse();
453+
assertThat(response.getStatusLine().getStatusCode(), is(HttpStatus.SC_INTERNAL_SERVER_ERROR));
454+
assertThat(EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8), equalTo("""
455+
\uFEFFevent: error
456+
data:\s""" + expectedExceptionAsServerSentEvent + "\n\n"));
457+
}
458+
}
459+
427460
public void testNoStream() {
428461
var collector = new RandomStringCollector();
429462
var expectedTestCount = randomIntBetween(2, 30);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,11 @@
5050
import java.io.IOException;
5151
import java.util.Random;
5252
import java.util.concurrent.Executor;
53+
import java.util.concurrent.Flow;
5354
import java.util.function.Supplier;
5455
import java.util.stream.Collectors;
5556

57+
import static org.elasticsearch.ExceptionsHelper.unwrapCause;
5658
import static org.elasticsearch.core.Strings.format;
5759
import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
5860
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
@@ -280,7 +282,9 @@ private void inferOnServiceWithMetrics(
280282
var instrumentedStream = new PublisherWithMetrics(timer, model);
281283
taskProcessor.subscribe(instrumentedStream);
282284

283-
listener.onResponse(new InferenceAction.Response(inferenceResults, instrumentedStream));
285+
var streamErrorHandler = streamErrorHandler(instrumentedStream);
286+
287+
listener.onResponse(new InferenceAction.Response(inferenceResults, streamErrorHandler));
284288
} else {
285289
recordMetrics(model, timer, null);
286290
listener.onResponse(new InferenceAction.Response(inferenceResults));
@@ -291,9 +295,13 @@ private void inferOnServiceWithMetrics(
291295
}));
292296
}
293297

298+
protected Flow.Publisher<ChunkedToXContent> streamErrorHandler(Flow.Processor<ChunkedToXContent, ChunkedToXContent> upstream) {
299+
return upstream;
300+
}
301+
294302
private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) {
295303
try {
296-
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t));
304+
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, unwrapCause(t)));
297305
} catch (Exception e) {
298306
log.atDebug().withThrowable(e).log("Failed to record metrics with a parsed model, dropping metrics");
299307
}

0 commit comments

Comments
 (0)