Skip to content

Commit e065a37

Browse files
[ML] Stream Cohere Completion (#114080)
Implement and enable streaming for Cohere chat completions (v1). Includes processor for ND JSON streaming responses. Co-authored-by: Elastic Machine <[email protected]>
1 parent ce73a90 commit e065a37

File tree

17 files changed

+585
-26
lines changed

17 files changed

+585
-26
lines changed

docs/changelog/114080.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 114080
2+
summary: Stream Cohere Completion
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
public abstract class DelegatingProcessor<T, R> implements Flow.Processor<T, R> {
2222
private static final Logger log = LogManager.getLogger(DelegatingProcessor.class);
2323
private final AtomicLong pendingRequests = new AtomicLong();
24-
private final AtomicBoolean isClosed = new AtomicBoolean(false);
24+
protected final AtomicBoolean isClosed = new AtomicBoolean(false);
2525
private Flow.Subscriber<? super R> downstream;
2626
private Flow.Subscription upstream;
2727

@@ -49,7 +49,7 @@ private Flow.Subscription forwardingSubscription() {
4949
@Override
5050
public void request(long n) {
5151
if (isClosed.get()) {
52-
downstream.onComplete(); // shouldn't happen, but reinforce that we're no longer listening
52+
downstream.onComplete();
5353
} else if (upstream != null) {
5454
upstream.request(n);
5555
} else {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,19 @@
88
package org.elasticsearch.xpack.inference.external.cohere;
99

1010
import org.apache.logging.log4j.Logger;
11+
import org.elasticsearch.inference.InferenceServiceResults;
12+
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
1113
import org.elasticsearch.xpack.inference.external.http.HttpResult;
1214
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
1315
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
1416
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
1517
import org.elasticsearch.xpack.inference.external.request.Request;
1618
import org.elasticsearch.xpack.inference.external.response.cohere.CohereErrorResponseEntity;
19+
import org.elasticsearch.xpack.inference.external.response.streaming.NewlineDelimitedByteProcessor;
1720
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
1821

22+
import java.util.concurrent.Flow;
23+
1924
import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody;
2025

2126
/**
@@ -33,9 +38,11 @@
3338
public class CohereResponseHandler extends BaseResponseHandler {
3439
static final String TEXTS_ARRAY_TOO_LARGE_MESSAGE_MATCHER = "invalid request: total number of texts must be at most";
3540
static final String TEXTS_ARRAY_ERROR_MESSAGE = "Received a texts array too large response";
41+
private final boolean canHandleStreamingResponse;
3642

37-
public CohereResponseHandler(String requestType, ResponseParser parseFunction) {
43+
public CohereResponseHandler(String requestType, ResponseParser parseFunction, boolean canHandleStreamingResponse) {
3844
super(requestType, parseFunction, CohereErrorResponseEntity::fromResponse);
45+
this.canHandleStreamingResponse = canHandleStreamingResponse;
3946
}
4047

4148
@Override
@@ -45,6 +52,20 @@ public void validateResponse(ThrottlerManager throttlerManager, Logger logger, R
4552
checkForEmptyBody(throttlerManager, logger, request, result);
4653
}
4754

55+
/**
 * Reports whether this handler accepts streaming responses.
 *
 * @return the {@code canHandleStreamingResponse} flag that was passed to the constructor
 */
@Override
56+
public boolean canHandleStreamingResponses() {
57+
return canHandleStreamingResponse;
58+
}
59+
60+
@Override
61+
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
62+
var ndProcessor = new NewlineDelimitedByteProcessor();
63+
var cohereProcessor = new CohereStreamingProcessor();
64+
flow.subscribe(ndProcessor);
65+
ndProcessor.subscribe(cohereProcessor);
66+
return new StreamingChatCompletionResults(cohereProcessor);
67+
}
68+
4869
/**
4970
* Validates the status code and throws a RetryException if it is not in the range [200, 300).
5071
*
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.cohere;
9+
10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.ElasticsearchStatusException;
13+
import org.elasticsearch.rest.RestStatus;
14+
import org.elasticsearch.xcontent.XContentFactory;
15+
import org.elasticsearch.xcontent.XContentParser;
16+
import org.elasticsearch.xcontent.XContentParserConfiguration;
17+
import org.elasticsearch.xcontent.XContentType;
18+
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
19+
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
20+
21+
import java.io.IOException;
22+
import java.util.ArrayDeque;
23+
import java.util.Deque;
24+
import java.util.Map;
25+
import java.util.Optional;
26+
27+
class CohereStreamingProcessor extends DelegatingProcessor<Deque<String>, StreamingChatCompletionResults.Results> {
28+
private static final Logger log = LogManager.getLogger(CohereStreamingProcessor.class);
29+
30+
@Override
31+
protected void next(Deque<String> item) throws Exception {
32+
if (item.isEmpty()) {
33+
// discard empty result and go to the next
34+
upstream().request(1);
35+
return;
36+
}
37+
38+
var results = new ArrayDeque<StreamingChatCompletionResults.Result>(item.size());
39+
for (String json : item) {
40+
try (var jsonParser = jsonParser(json)) {
41+
var responseMap = jsonParser.map();
42+
var eventType = (String) responseMap.get("event_type");
43+
switch (eventType) {
44+
case "text-generation" -> parseText(responseMap).ifPresent(results::offer);
45+
case "stream-end" -> validateResponse(responseMap);
46+
case "stream-start", "search-queries-generation", "search-results", "citation-generation", "tool-calls-generation",
47+
"tool-calls-chunk" -> {
48+
log.debug("Skipping event type [{}] for line [{}].", eventType, item);
49+
}
50+
default -> throw new IOException("Unknown eventType found: " + eventType);
51+
}
52+
} catch (ElasticsearchStatusException e) {
53+
throw e;
54+
} catch (Exception e) {
55+
log.warn("Failed to parse json from cohere: {}", json);
56+
throw e;
57+
}
58+
}
59+
60+
if (results.isEmpty()) {
61+
upstream().request(1);
62+
} else {
63+
downstream().onNext(new StreamingChatCompletionResults.Results(results));
64+
}
65+
}
66+
67+
private static XContentParser jsonParser(String line) throws IOException {
68+
return XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, line);
69+
}
70+
71+
private Optional<StreamingChatCompletionResults.Result> parseText(Map<String, Object> responseMap) throws IOException {
72+
var text = (String) responseMap.get("text");
73+
if (text != null) {
74+
return Optional.of(new StreamingChatCompletionResults.Result(text));
75+
} else {
76+
throw new IOException("Null text found in text-generation cohere event");
77+
}
78+
}
79+
80+
private void validateResponse(Map<String, Object> responseMap) {
81+
var finishReason = (String) responseMap.get("finish_reason");
82+
switch (finishReason) {
83+
case "ERROR", "ERROR_TOXIC" -> throw new ElasticsearchStatusException(
84+
"Cohere stopped the stream due to an error: {}",
85+
RestStatus.INTERNAL_SERVER_ERROR,
86+
parseErrorMessage(responseMap)
87+
);
88+
case "ERROR_LIMIT" -> throw new ElasticsearchStatusException(
89+
"Cohere stopped the stream due to an error: {}",
90+
RestStatus.TOO_MANY_REQUESTS,
91+
parseErrorMessage(responseMap)
92+
);
93+
}
94+
}
95+
96+
@SuppressWarnings("unchecked")
97+
private String parseErrorMessage(Map<String, Object> responseMap) {
98+
var innerResponseMap = (Map<String, Object>) responseMap.get("response");
99+
return (String) innerResponseMap.get("text");
100+
}
101+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import org.elasticsearch.xpack.inference.external.response.cohere.CohereCompletionResponseEntity;
2020
import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModel;
2121

22-
import java.util.List;
2322
import java.util.Objects;
2423
import java.util.function.Supplier;
2524

@@ -30,7 +29,7 @@ public class CohereCompletionRequestManager extends CohereRequestManager {
3029
private static final ResponseHandler HANDLER = createCompletionHandler();
3130

3231
private static ResponseHandler createCompletionHandler() {
33-
return new CohereResponseHandler("cohere completion", CohereCompletionResponseEntity::fromResponse);
32+
return new CohereResponseHandler("cohere completion", CohereCompletionResponseEntity::fromResponse, true);
3433
}
3534

3635
public static CohereCompletionRequestManager of(CohereCompletionModel model, ThreadPool threadPool) {
@@ -51,8 +50,10 @@ public void execute(
5150
Supplier<Boolean> hasRequestCompletedFunction,
5251
ActionListener<InferenceServiceResults> listener
5352
) {
54-
List<String> docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs();
55-
CohereCompletionRequest request = new CohereCompletionRequest(docsInput, model);
53+
var docsOnly = DocumentsOnlyInput.of(inferenceInputs);
54+
var docsInput = docsOnly.getInputs();
55+
var stream = docsOnly.stream();
56+
CohereCompletionRequest request = new CohereCompletionRequest(docsInput, model, stream);
5657

5758
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
5859
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ public class CohereEmbeddingsRequestManager extends CohereRequestManager {
2828
private static final ResponseHandler HANDLER = createEmbeddingsHandler();
2929

3030
private static ResponseHandler createEmbeddingsHandler() {
31-
return new CohereResponseHandler("cohere text embedding", CohereEmbeddingsResponseEntity::fromResponse);
31+
return new CohereResponseHandler("cohere text embedding", CohereEmbeddingsResponseEntity::fromResponse, false);
3232
}
3333

3434
public static CohereEmbeddingsRequestManager of(CohereEmbeddingsModel model, ThreadPool threadPool) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ public class CohereRerankRequestManager extends CohereRequestManager {
2727
private static final ResponseHandler HANDLER = createCohereResponseHandler();
2828

2929
private static ResponseHandler createCohereResponseHandler() {
30-
return new CohereResponseHandler("cohere rerank", (request, response) -> CohereRankedResponseEntity.fromResponse(response));
30+
return new CohereResponseHandler("cohere rerank", (request, response) -> CohereRankedResponseEntity.fromResponse(response), false);
3131
}
3232

3333
public static CohereRerankRequestManager of(CohereRerankModel model, ThreadPool threadPool) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/completion/CohereCompletionRequest.java

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,30 +25,28 @@
2525
import java.util.Objects;
2626

2727
public class CohereCompletionRequest extends CohereRequest {
28-
2928
private final CohereAccount account;
30-
3129
private final List<String> input;
32-
3330
private final String modelId;
34-
3531
private final String inferenceEntityId;
32+
private final boolean stream;
3633

37-
public CohereCompletionRequest(List<String> input, CohereCompletionModel model) {
34+
public CohereCompletionRequest(List<String> input, CohereCompletionModel model, boolean stream) {
3835
Objects.requireNonNull(model);
3936

4037
this.account = CohereAccount.of(model, CohereCompletionRequest::buildDefaultUri);
4138
this.input = Objects.requireNonNull(input);
4239
this.modelId = model.getServiceSettings().modelId();
4340
this.inferenceEntityId = model.getInferenceEntityId();
41+
this.stream = stream;
4442
}
4543

4644
@Override
4745
public HttpRequest createHttpRequest() {
4846
HttpPost httpPost = new HttpPost(account.uri());
4947

5048
ByteArrayEntity byteEntity = new ByteArrayEntity(
51-
Strings.toString(new CohereCompletionRequestEntity(input, modelId)).getBytes(StandardCharsets.UTF_8)
49+
Strings.toString(new CohereCompletionRequestEntity(input, modelId, isStreaming())).getBytes(StandardCharsets.UTF_8)
5250
);
5351
httpPost.setEntity(byteEntity);
5452

@@ -62,6 +60,11 @@ public String getInferenceEntityId() {
6260
return inferenceEntityId;
6361
}
6462

63+
/**
 * Reports whether this request asks for a streamed response.
 *
 * @return the {@code stream} flag that was passed to the constructor
 */
@Override
64+
public boolean isStreaming() {
65+
return stream;
66+
}
67+
6568
@Override
6669
public URI getURI() {
6770
return account.uri();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/completion/CohereCompletionRequestEntity.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
import java.util.List;
1616
import java.util.Objects;
1717

18-
public record CohereCompletionRequestEntity(List<String> input, @Nullable String model) implements ToXContentObject {
18+
public record CohereCompletionRequestEntity(List<String> input, @Nullable String model, boolean stream) implements ToXContentObject {
1919

2020
private static final String MESSAGE_FIELD = "message";
21-
2221
private static final String MODEL = "model";
22+
private static final String STREAM = "stream";
2323

2424
public CohereCompletionRequestEntity {
2525
Objects.requireNonNull(input);
@@ -36,6 +36,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
3636
builder.field(MODEL, model);
3737
}
3838

39+
if (stream) {
40+
builder.field(STREAM, true);
41+
}
42+
3943
builder.endObject();
4044

4145
return builder;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.response.streaming;
9+
10+
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
11+
import org.elasticsearch.xpack.inference.external.http.HttpResult;
12+
13+
import java.nio.charset.StandardCharsets;
14+
import java.util.ArrayDeque;
15+
import java.util.Deque;
16+
import java.util.regex.Pattern;
17+
18+
/**
19+
* Processes HttpResult bytes into lines separated by newlines, delimited by either line-feed or carriage-return line-feed.
20+
* Downstream is responsible for validating the structure of the lines after they have been separated.
21+
* Because Upstream (Apache) can send us a single line split between two HttpResults, this processor will aggregate bytes from the last
22+
* HttpResult and append them to the front of the next HttpResult.
23+
* When onComplete is called, the last batch is always flushed to the downstream onNext.
24+
*/
25+
public class NewlineDelimitedByteProcessor extends DelegatingProcessor<HttpResult, Deque<String>> {
26+
private static final Pattern END_OF_LINE_REGEX = Pattern.compile("\\n|\\r\\n");
27+
private volatile String previousTokens = "";
28+
29+
@Override
30+
protected void next(HttpResult item) {
31+
// discard empty result and go to the next
32+
if (item.isBodyEmpty()) {
33+
upstream().request(1);
34+
return;
35+
}
36+
37+
var body = previousTokens + new String(item.body(), StandardCharsets.UTF_8);
38+
var lines = END_OF_LINE_REGEX.split(body, -1); // -1 because we actually want trailing empty strings
39+
40+
var results = new ArrayDeque<String>(lines.length);
41+
for (var i = 0; i < lines.length - 1; i++) {
42+
var line = lines[i].trim();
43+
if (line.isBlank() == false) {
44+
results.offer(line);
45+
}
46+
}
47+
48+
previousTokens = lines[lines.length - 1].trim();
49+
50+
if (results.isEmpty()) {
51+
upstream().request(1);
52+
} else {
53+
downstream().onNext(results);
54+
}
55+
}
56+
57+
@Override
58+
public void onComplete() {
59+
if (previousTokens.isBlank()) {
60+
super.onComplete();
61+
} else if (isClosed.compareAndSet(false, true)) {
62+
var results = new ArrayDeque<String>(1);
63+
results.offer(previousTokens);
64+
downstream().onNext(results);
65+
}
66+
}
67+
}

0 commit comments

Comments
 (0)