Skip to content

Commit 67683cb

Browse files
[ML] Stream Anthropic Completion (#114321)
Enable chat completion streaming responses for Anthropic's server sent events. Co-authored-by: Elastic Machine <[email protected]>
1 parent 965061a commit 67683cb

File tree

12 files changed

+423
-13
lines changed

12 files changed

+423
-13
lines changed

docs/changelog/114321.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 114321
2+
summary: Stream Anthropic Completion
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicResponseHandler.java

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,20 @@
99

1010
import org.apache.logging.log4j.Logger;
1111
import org.elasticsearch.common.Strings;
12+
import org.elasticsearch.inference.InferenceServiceResults;
13+
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
1214
import org.elasticsearch.xpack.inference.external.http.HttpResult;
1315
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
1416
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
1517
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
1618
import org.elasticsearch.xpack.inference.external.request.Request;
1719
import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity;
20+
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
21+
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;
1822
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
1923

24+
import java.util.concurrent.Flow;
25+
2026
import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody;
2127
import static org.elasticsearch.xpack.inference.external.http.retry.ResponseHandlerUtils.getFirstHeaderOrUnknown;
2228

@@ -41,8 +47,11 @@ public class AnthropicResponseHandler extends BaseResponseHandler {
4147

4248
static final String SERVER_BUSY = "Received an Anthropic server is temporarily overloaded status code";
4349

44-
public AnthropicResponseHandler(String requestType, ResponseParser parseFunction) {
50+
private final boolean canHandleStreamingResponses;
51+
52+
/**
 * Creates a response handler for Anthropic requests.
 *
 * @param requestType                 forwarded unchanged to the {@code BaseResponseHandler} constructor
 * @param parseFunction               forwarded unchanged to the {@code BaseResponseHandler} constructor
 * @param canHandleStreamingResponses value stored on this handler and reported by {@code canHandleStreamingResponses()}
 */
public AnthropicResponseHandler(String requestType, ResponseParser parseFunction, boolean canHandleStreamingResponses) {
4553
// Error bodies are always parsed with ErrorMessageResponseEntity::fromResponse.
super(requestType, parseFunction, ErrorMessageResponseEntity::fromResponse);
54+
this.canHandleStreamingResponses = canHandleStreamingResponses;
4655
}
4756

4857
@Override
@@ -52,6 +61,20 @@ public void validateResponse(ThrottlerManager throttlerManager, Logger logger, R
5261
checkForEmptyBody(throttlerManager, logger, request, result);
5362
}
5463

64+
/**
 * Reports whether this handler was configured to process streaming responses.
 *
 * @return the {@code canHandleStreamingResponses} flag supplied to the constructor
 */
@Override
65+
public boolean canHandleStreamingResponses() {
66+
return canHandleStreamingResponses;
67+
}
68+
69+
@Override
70+
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
71+
var sseProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
72+
var anthropicProcessor = new AnthropicStreamingProcessor();
73+
sseProcessor.subscribe(anthropicProcessor);
74+
flow.subscribe(sseProcessor);
75+
return new StreamingChatCompletionResults(anthropicProcessor);
76+
}
77+
5578
/**
5679
* Validates the status code and throws a RetryException if it is not in the range [200, 300).
5780
*
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.anthropic;
9+
10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.ElasticsearchStatusException;
13+
import org.elasticsearch.rest.RestStatus;
14+
import org.elasticsearch.xcontent.XContentFactory;
15+
import org.elasticsearch.xcontent.XContentParser;
16+
import org.elasticsearch.xcontent.XContentParserConfiguration;
17+
import org.elasticsearch.xcontent.XContentType;
18+
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
19+
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
20+
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
21+
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField;
22+
23+
import java.io.IOException;
24+
import java.util.ArrayDeque;
25+
import java.util.Deque;
26+
import java.util.Optional;
27+
28+
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
29+
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
30+
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField;
31+
32+
public class AnthropicStreamingProcessor extends DelegatingProcessor<Deque<ServerSentEvent>, StreamingChatCompletionResults.Results> {
33+
private static final Logger log = LogManager.getLogger(AnthropicStreamingProcessor.class);
34+
private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in Anthropic chat completions response";
35+
36+
@Override
37+
protected void next(Deque<ServerSentEvent> item) throws Exception {
38+
if (item.isEmpty()) {
39+
upstream().request(1);
40+
return;
41+
}
42+
43+
var results = new ArrayDeque<StreamingChatCompletionResults.Result>(item.size());
44+
for (var event : item) {
45+
if (event.name() == ServerSentEventField.DATA && event.hasValue()) {
46+
try (var parser = parser(event.value())) {
47+
var eventType = eventType(parser);
48+
switch (eventType) {
49+
case "error" -> {
50+
onError(parseError(parser));
51+
return;
52+
}
53+
case "content_block_start" -> {
54+
parseStartBlock(parser).ifPresent(results::offer);
55+
}
56+
case "content_block_delta" -> {
57+
parseMessage(parser).ifPresent(results::offer);
58+
}
59+
case "message_start", "message_stop", "message_delta", "content_block_stop", "ping" -> {
60+
log.debug("Skipping event type [{}] for line [{}].", eventType, item);
61+
}
62+
default -> {
63+
// "handle unknown events gracefully" https://docs.anthropic.com/en/api/messages-streaming#other-events
64+
// we'll ignore unknown events
65+
log.debug("Unknown event type [{}] for line [{}].", eventType, item);
66+
}
67+
}
68+
} catch (Exception e) {
69+
log.warn("Failed to parse line {}", event);
70+
throw e;
71+
}
72+
}
73+
}
74+
75+
if (results.isEmpty()) {
76+
upstream().request(1);
77+
} else {
78+
downstream().onNext(new StreamingChatCompletionResults.Results(results));
79+
}
80+
}
81+
82+
private Throwable parseError(XContentParser parser) throws IOException {
83+
positionParserAtTokenAfterField(parser, "error", FAILED_TO_FIND_FIELD_TEMPLATE);
84+
var type = parseString(parser, "type");
85+
var message = parseString(parser, "message");
86+
var statusCode = switch (type) {
87+
case "invalid_request_error" -> RestStatus.BAD_REQUEST;
88+
case "authentication_error" -> RestStatus.UNAUTHORIZED;
89+
case "permission_error" -> RestStatus.FORBIDDEN;
90+
case "not_found_error" -> RestStatus.NOT_FOUND;
91+
case "request_too_large" -> RestStatus.REQUEST_ENTITY_TOO_LARGE;
92+
case "rate_limit_error" -> RestStatus.TOO_MANY_REQUESTS;
93+
default -> RestStatus.INTERNAL_SERVER_ERROR;
94+
};
95+
return new ElasticsearchStatusException(message, statusCode);
96+
}
97+
98+
private Optional<StreamingChatCompletionResults.Result> parseStartBlock(XContentParser parser) throws IOException {
99+
positionParserAtTokenAfterField(parser, "content_block", FAILED_TO_FIND_FIELD_TEMPLATE);
100+
var text = parseString(parser, "text");
101+
return text.isBlank() ? Optional.empty() : Optional.of(new StreamingChatCompletionResults.Result(text));
102+
}
103+
104+
private Optional<StreamingChatCompletionResults.Result> parseMessage(XContentParser parser) throws IOException {
105+
positionParserAtTokenAfterField(parser, "delta", FAILED_TO_FIND_FIELD_TEMPLATE);
106+
var text = parseString(parser, "text");
107+
return text.isBlank() ? Optional.empty() : Optional.of(new StreamingChatCompletionResults.Result(text));
108+
}
109+
110+
private static XContentParser parser(String line) throws IOException {
111+
return XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, line);
112+
}
113+
114+
private static String eventType(XContentParser parser) throws IOException {
115+
moveToFirstToken(parser);
116+
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
117+
return parseString(parser, "type");
118+
}
119+
120+
private static String parseString(XContentParser parser, String fieldName) throws IOException {
121+
positionParserAtTokenAfterField(parser, fieldName, FAILED_TO_FIND_FIELD_TEMPLATE);
122+
ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser);
123+
return parser.text();
124+
}
125+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicCompletionRequestManager.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import org.elasticsearch.xpack.inference.external.response.anthropic.AnthropicChatCompletionResponseEntity;
2020
import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionModel;
2121

22-
import java.util.List;
2322
import java.util.Objects;
2423
import java.util.function.Supplier;
2524

@@ -47,13 +46,15 @@ public void execute(
4746
Supplier<Boolean> hasRequestCompletedFunction,
4847
ActionListener<InferenceServiceResults> listener
4948
) {
50-
List<String> docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs();
51-
AnthropicChatCompletionRequest request = new AnthropicChatCompletionRequest(docsInput, model);
49+
var docsOnly = DocumentsOnlyInput.of(inferenceInputs);
50+
var docsInput = docsOnly.getInputs();
51+
var stream = docsOnly.stream();
52+
AnthropicChatCompletionRequest request = new AnthropicChatCompletionRequest(docsInput, model, stream);
5253

5354
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
5455
}
5556

5657
private static ResponseHandler createCompletionHandler() {
57-
return new AnthropicResponseHandler("anthropic completions", AnthropicChatCompletionResponseEntity::fromResponse);
58+
return new AnthropicResponseHandler("anthropic completions", AnthropicChatCompletionResponseEntity::fromResponse, true);
5859
}
5960
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/anthropic/AnthropicChatCompletionRequest.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,21 @@ public class AnthropicChatCompletionRequest implements Request {
2929
private final AnthropicAccount account;
3030
private final List<String> input;
3131
private final AnthropicChatCompletionModel model;
32+
private final boolean stream;
3233

33-
public AnthropicChatCompletionRequest(List<String> input, AnthropicChatCompletionModel model) {
34+
/**
 * Creates a chat completion request for Anthropic.
 *
 * @param input  the input strings for the request; must not be {@code null}
 * @param model  the model supplying the account, service settings and task settings; must not be {@code null}
 * @param stream whether the request should ask for a streamed response
 * @throws NullPointerException if {@code input} or {@code model} is {@code null}
 */
public AnthropicChatCompletionRequest(List<String> input, AnthropicChatCompletionModel model, boolean stream) {
    // Check the arguments before using them, so a null model is rejected here
    // instead of being handed to AnthropicAccount.of first.
    this.input = Objects.requireNonNull(input);
    this.model = Objects.requireNonNull(model);
    this.account = AnthropicAccount.of(this.model);
    this.stream = stream;
}
3840

3941
@Override
4042
public HttpRequest createHttpRequest() {
4143
HttpPost httpPost = new HttpPost(account.uri());
4244

4345
ByteArrayEntity byteEntity = new ByteArrayEntity(
44-
Strings.toString(new AnthropicChatCompletionRequestEntity(input, model.getServiceSettings(), model.getTaskSettings()))
46+
Strings.toString(new AnthropicChatCompletionRequestEntity(input, model.getServiceSettings(), model.getTaskSettings(), stream))
4547
.getBytes(StandardCharsets.UTF_8)
4648
);
4749
httpPost.setEntity(byteEntity);
@@ -75,4 +77,9 @@ public String getInferenceEntityId() {
7577
return model.getInferenceEntityId();
7678
}
7779

80+
/**
 * Reports whether this request was created with streaming enabled.
 *
 * @return the {@code stream} flag supplied to the constructor
 */
@Override
81+
public boolean isStreaming() {
82+
return stream;
83+
}
84+
7885
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/anthropic/AnthropicChatCompletionRequestEntity.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,19 +28,23 @@ public class AnthropicChatCompletionRequestEntity implements ToXContentObject {
2828
private static final String TEMPERATURE_FIELD = "temperature";
2929
private static final String TOP_P_FIELD = "top_p";
3030
private static final String TOP_K_FIELD = "top_k";
31+
private static final String STREAM = "stream";
3132

3233
private final List<String> messages;
3334
private final AnthropicChatCompletionServiceSettings serviceSettings;
3435
private final AnthropicChatCompletionTaskSettings taskSettings;
36+
private final boolean stream;
3537

3638
public AnthropicChatCompletionRequestEntity(
3739
List<String> messages,
3840
AnthropicChatCompletionServiceSettings serviceSettings,
39-
AnthropicChatCompletionTaskSettings taskSettings
41+
AnthropicChatCompletionTaskSettings taskSettings,
42+
boolean stream
4043
) {
4144
this.messages = Objects.requireNonNull(messages);
4245
this.serviceSettings = Objects.requireNonNull(serviceSettings);
4346
this.taskSettings = Objects.requireNonNull(taskSettings);
47+
this.stream = stream;
4448
}
4549

4650
@Override
@@ -77,6 +81,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
7781
builder.field(TOP_K_FIELD, taskSettings.topK());
7882
}
7983

84+
if (stream) {
85+
builder.field(STREAM, true);
86+
}
87+
8088
builder.endObject();
8189

8290
return builder;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
import java.util.List;
3535
import java.util.Map;
36+
import java.util.Set;
3637

3738
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
3839
import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
@@ -199,4 +200,9 @@ protected void doChunkedInfer(
199200
public TransportVersion getMinimalSupportedVersion() {
200201
return TransportVersions.ML_ANTHROPIC_INTEGRATION_ADDED;
201202
}
203+
204+
/**
 * Returns the task types for which this service supports streaming.
 *
 * @return the {@code COMPLETION_ONLY} set; presumably it contains only the completion task type -
 *         confirm against its definition, which is not visible here
 */
@Override
205+
public Set<TaskType> supportedStreamingTasks() {
206+
return COMPLETION_ONLY;
207+
}
202208
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicResponseHandlerTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ private static void callCheckForFailureStatusCode(int statusCode, String inferen
160160
var mockRequest = mock(Request.class);
161161
when(mockRequest.getInferenceEntityId()).thenReturn(inferenceEntityId);
162162
var httpResult = new HttpResult(httpResponse, new byte[] {});
163-
var handler = new AnthropicResponseHandler("", (request, result) -> null);
163+
var handler = new AnthropicResponseHandler("", (request, result) -> null, false);
164164

165165
handler.checkForFailureStatusCode(mockRequest, httpResult);
166166
}

0 commit comments

Comments
 (0)