Skip to content

Commit f273fdc

Browse files
authored
[8.x] [ML] Stream Bedrock Completion (#114732) (#114781)
* [ML] Stream Bedrock Completion (#114732) Notes: - Adds a new API to the chatCompletionRequest to invoke the Bedrock Stream API - Creates a StreamingChatProcessor that subscribes to streaming results from Bedrock and handles the parsing on another thread. - There was no good way (that I could see) to extend the Provider-based CompletionRequestEntity classes, so they have been flattened into one RequestEntity that can be shared between ConverseRequest and ConverseStreamRequest. * Use JDK 17 API
1 parent ffcf87c commit f273fdc

File tree

33 files changed

+702
-920
lines changed

33 files changed

+702
-920
lines changed

docs/changelog/114732.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 114732
2+
summary: Stream Bedrock Completion
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

x-pack/plugin/inference/src/main/java/module-info.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
requires software.amazon.awssdk.profiles;
3333
requires org.slf4j;
3434
requires software.amazon.awssdk.retries.api;
35+
requires org.reactivestreams;
3536

3637
exports org.elasticsearch.xpack.inference.action;
3738
exports org.elasticsearch.xpack.inference.registry;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockChatCompletionExecutor.java

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.apache.logging.log4j.Logger;
1111
import org.elasticsearch.action.ActionListener;
1212
import org.elasticsearch.inference.InferenceServiceResults;
13+
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
1314
import org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockChatCompletionRequest;
1415
import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseHandler;
1516
import org.elasticsearch.xpack.inference.external.response.amazonbedrock.completion.AmazonBedrockChatCompletionResponseListener;
@@ -33,11 +34,16 @@ protected AmazonBedrockChatCompletionExecutor(
3334

3435
@Override
3536
protected void executeClientRequest(AmazonBedrockBaseClient awsBedrockClient) {
36-
var chatCompletionResponseListener = new AmazonBedrockChatCompletionResponseListener(
37-
chatCompletionRequest,
38-
responseHandler,
39-
inferenceResultsListener
40-
);
41-
chatCompletionRequest.executeChatCompletionRequest(awsBedrockClient, chatCompletionResponseListener);
37+
if (chatCompletionRequest.isStreaming()) {
38+
var publisher = chatCompletionRequest.executeStreamChatCompletionRequest(awsBedrockClient);
39+
inferenceResultsListener.onResponse(new StreamingChatCompletionResults(publisher));
40+
} else {
41+
var chatCompletionResponseListener = new AmazonBedrockChatCompletionResponseListener(
42+
chatCompletionRequest,
43+
responseHandler,
44+
inferenceResultsListener
45+
);
46+
chatCompletionRequest.executeChatCompletionRequest(awsBedrockClient, chatCompletionResponseListener);
47+
}
4248
}
4349
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClient.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,22 @@
99

1010
import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest;
1111
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
12+
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest;
1213
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
1314
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;
1415

1516
import org.elasticsearch.ElasticsearchException;
1617
import org.elasticsearch.action.ActionListener;
18+
import org.elasticsearch.common.xcontent.ChunkedToXContent;
1719

1820
import java.time.Instant;
21+
import java.util.concurrent.Flow;
1922

2023
public interface AmazonBedrockClient {
2124
void converse(ConverseRequest converseRequest, ActionListener<ConverseResponse> responseListener) throws ElasticsearchException;
2225

26+
Flow.Publisher<? extends ChunkedToXContent> converseStream(ConverseStreamRequest converseStreamRequest) throws ElasticsearchException;
27+
2328
void invokeModel(InvokeModelRequest invokeModelRequest, ActionListener<InvokeModelResponse> responseListener)
2429
throws ElasticsearchException;
2530

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClient.java

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,21 @@
1717
import software.amazon.awssdk.services.bedrockruntime.model.BedrockRuntimeException;
1818
import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest;
1919
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
20+
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest;
21+
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler;
2022
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
2123
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;
2224

2325
import org.elasticsearch.ElasticsearchException;
2426
import org.elasticsearch.SpecialPermission;
2527
import org.elasticsearch.action.ActionListener;
28+
import org.elasticsearch.common.xcontent.ChunkedToXContent;
2629
import org.elasticsearch.core.Nullable;
2730
import org.elasticsearch.core.Strings;
2831
import org.elasticsearch.core.TimeValue;
32+
import org.elasticsearch.threadpool.ThreadPool;
2933
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel;
34+
import org.reactivestreams.FlowAdapters;
3035
import org.slf4j.LoggerFactory;
3136

3237
import java.security.AccessController;
@@ -36,6 +41,7 @@
3641
import java.util.Objects;
3742
import java.util.concurrent.CompletionException;
3843
import java.util.concurrent.ExecutionException;
44+
import java.util.concurrent.Flow;
3945

4046
/**
4147
* Not marking this as "final" so we can subclass it for mocking
@@ -53,19 +59,21 @@ public class AmazonBedrockInferenceClient extends AmazonBedrockBaseClient {
5359
private static final Duration DEFAULT_CLIENT_TIMEOUT_MS = Duration.ofMillis(10000);
5460

5561
private final BedrockRuntimeAsyncClient internalClient;
62+
private final ThreadPool threadPool;
5663
private volatile Instant expiryTimestamp;
5764

58-
public static AmazonBedrockBaseClient create(AmazonBedrockModel model, @Nullable TimeValue timeout) {
65+
public static AmazonBedrockBaseClient create(AmazonBedrockModel model, @Nullable TimeValue timeout, ThreadPool threadPool) {
5966
try {
60-
return new AmazonBedrockInferenceClient(model, timeout);
67+
return new AmazonBedrockInferenceClient(model, timeout, threadPool);
6168
} catch (Exception e) {
6269
throw new ElasticsearchException("Failed to create Amazon Bedrock Client", e);
6370
}
6471
}
6572

66-
protected AmazonBedrockInferenceClient(AmazonBedrockModel model, @Nullable TimeValue timeout) {
73+
protected AmazonBedrockInferenceClient(AmazonBedrockModel model, @Nullable TimeValue timeout, ThreadPool threadPool) {
6774
super(model, timeout);
6875
this.internalClient = createAmazonBedrockClient(model, timeout);
76+
this.threadPool = Objects.requireNonNull(threadPool);
6977
setExpiryTimestamp();
7078
}
7179

@@ -79,6 +87,16 @@ public void converse(ConverseRequest converseRequest, ActionListener<ConverseRes
7987
}
8088
}
8189

90+
/**
 * Starts a streaming Converse call and returns the publisher that emits its results.
 * <p>
 * A new {@code AmazonBedrockStreamingChatProcessor} is created per call, adapted to a
 * reactive-streams subscriber with {@code FlowAdapters.toSubscriber}, and registered as the
 * subscriber of the SDK's {@code ConverseStreamResponseHandler}. The same processor is returned
 * to the caller as the {@code Flow.Publisher}.
 * <p>
 * NOTE(review): the value returned by {@code internalClient.converseStream} is discarded here;
 * presumably request failures reach the processor through the subscriber's onError - confirm.
 *
 * @param request the streaming Converse request to send
 * @return the processor, exposed as a publisher of chunked results
 */
@Override
91+
public Flow.Publisher<? extends ChunkedToXContent> converseStream(ConverseStreamRequest request) throws ElasticsearchException {
92+
var awsResponseProcessor = new AmazonBedrockStreamingChatProcessor(threadPool);
93+
internalClient.converseStream(
94+
request,
95+
ConverseStreamResponseHandler.builder().subscriber(() -> FlowAdapters.toSubscriber(awsResponseProcessor)).build()
96+
);
97+
return awsResponseProcessor;
98+
}
99+
82100
private void onFailure(ActionListener<?> listener, Throwable t, String method) {
83101
var unwrappedException = t;
84102
if (t instanceof CompletionException || t instanceof ExecutionException) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClientCache.java

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,9 @@ public final class AmazonBedrockInferenceClientCache implements AmazonBedrockCli
2929
// not final for testing
3030
private Clock clock;
3131

32-
public AmazonBedrockInferenceClientCache(
33-
BiFunction<AmazonBedrockModel, TimeValue, AmazonBedrockBaseClient> creator,
34-
@Nullable Clock clock
35-
) {
32+
public AmazonBedrockInferenceClientCache(BiFunction<AmazonBedrockModel, TimeValue, AmazonBedrockBaseClient> creator, Clock clock) {
3633
this.creator = Objects.requireNonNull(creator);
37-
this.clock = Objects.requireNonNullElse(clock, Clock.systemUTC());
34+
this.clock = Objects.requireNonNull(clock);
3835
}
3936

4037
public AmazonBedrockBaseClient getOrCreateClient(AmazonBedrockModel model, @Nullable TimeValue timeout) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSender.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.elasticsearch.xpack.inference.services.ServiceComponents;
2424

2525
import java.io.IOException;
26+
import java.time.Clock;
2627
import java.util.Objects;
2728
import java.util.concurrent.CountDownLatch;
2829
import java.util.concurrent.TimeUnit;
@@ -42,7 +43,10 @@ public Factory(ServiceComponents serviceComponents, ClusterService clusterServic
4243
}
4344

4445
public Sender createSender() {
45-
var clientCache = new AmazonBedrockInferenceClientCache(AmazonBedrockInferenceClient::create, null);
46+
var clientCache = new AmazonBedrockInferenceClientCache(
47+
(model, timeout) -> AmazonBedrockInferenceClient.create(model, timeout, serviceComponents.threadPool()),
48+
Clock.systemUTC()
49+
);
4650
return createSender(new AmazonBedrockExecuteOnlyRequestSender(clientCache, serviceComponents.throttlerManager()));
4751
}
4852

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.amazonbedrock;
9+
10+
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDeltaEvent;
11+
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput;
12+
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler;
13+
14+
import org.elasticsearch.ElasticsearchException;
15+
import org.elasticsearch.common.util.concurrent.EsExecutors;
16+
import org.elasticsearch.core.Strings;
17+
import org.elasticsearch.threadpool.ThreadPool;
18+
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
19+
20+
import java.util.ArrayDeque;
21+
import java.util.concurrent.CompletableFuture;
22+
import java.util.concurrent.Flow;
23+
import java.util.concurrent.atomic.AtomicBoolean;
24+
import java.util.concurrent.atomic.AtomicLong;
25+
import java.util.concurrent.atomic.AtomicReference;
26+
27+
import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;
28+
29+
class AmazonBedrockStreamingChatProcessor implements Flow.Processor<ConverseStreamOutput, StreamingChatCompletionResults.Results> {
30+
private final AtomicReference<Throwable> error = new AtomicReference<>(null);
31+
private final AtomicLong demand = new AtomicLong(0);
32+
private final AtomicBoolean isDone = new AtomicBoolean(false);
33+
private final AtomicBoolean onCompleteCalled = new AtomicBoolean(false);
34+
private final AtomicBoolean onErrorCalled = new AtomicBoolean(false);
35+
private final ThreadPool threadPool;
36+
private volatile Flow.Subscriber<? super StreamingChatCompletionResults.Results> downstream;
37+
private volatile Flow.Subscription upstream;
38+
39+
AmazonBedrockStreamingChatProcessor(ThreadPool threadPool) {
40+
this.threadPool = threadPool;
41+
}
42+
43+
/**
 * Registers the downstream subscriber and hands it a new {@code StreamSubscription}.
 * Only one subscriber is accepted: if one is already set, the new subscriber is rejected
 * through its own {@code onError} with an {@link IllegalStateException}.
 * <p>
 * NOTE(review): the null check and the assignment of the volatile {@code downstream} field are
 * separate steps, so this is not atomic - confirm subscribe is never called concurrently.
 *
 * @param subscriber the downstream consumer of the streamed results
 */
@Override
44+
public void subscribe(Flow.Subscriber<? super StreamingChatCompletionResults.Results> subscriber) {
45+
if (downstream == null) {
46+
downstream = subscriber;
47+
downstream.onSubscribe(new StreamSubscription());
48+
} else {
49+
subscriber.onError(new IllegalStateException("Subscriber already set."));
50+
}
51+
}
52+
53+
/**
 * Stores the upstream subscription and forwards any demand that was recorded before it arrived.
 * The demand counter is read and zeroed in one atomic update; a positive value is passed to
 * {@code upstream.request}. If an upstream subscription is already set, the new one is cancelled.
 *
 * @param subscription the subscription offered by the upstream publisher
 */
@Override
54+
public void onSubscribe(Flow.Subscription subscription) {
55+
if (upstream == null) {
56+
upstream = subscription;
57+
// drain the demand that StreamSubscription.request accumulated while upstream was still null
var currentRequestCount = demand.getAndUpdate(i -> 0);
58+
if (currentRequestCount > 0) {
59+
upstream.request(currentRequestCount);
60+
}
61+
} else {
62+
subscription.cancel();
63+
}
64+
}
65+
66+
/**
 * Handles one event from the upstream stream.
 * {@code CONTENT_BLOCK_DELTA} events are dispatched through a visitor to
 * {@code sendDownstreamOnAnotherThread}; every other event type is not forwarded downstream,
 * and one more item is requested from upstream instead.
 *
 * @param item the stream event received from upstream
 */
@Override
67+
public void onNext(ConverseStreamOutput item) {
68+
if (item.sdkEventType() == ConverseStreamOutput.EventType.CONTENT_BLOCK_DELTA) {
69+
demand.set(0); // reset demand before we fork to another thread
70+
item.accept(ConverseStreamResponseHandler.Visitor.builder().onContentBlockDelta(this::sendDownstreamOnAnotherThread).build());
71+
} else {
72+
// nothing to emit for this event type, so ask upstream for the next event
upstream.request(1);
73+
}
74+
}
75+
76+
// This is always called from a netty thread maintained by the AWS SDK, so we move the work to our own thread
// (the UTILITY_THREAD_POOL_NAME executor) to process the response.
// The delta's text is wrapped in a single-element Results and passed to downstream.onNext.
// NOTE(review): the CompletableFuture returned by runAsync is discarded, so an exception thrown inside the
// lambda (for example by downstream.onNext) is not observed here - confirm this is acceptable.
77+
private void sendDownstreamOnAnotherThread(ContentBlockDeltaEvent event) {
78+
CompletableFuture.runAsync(() -> {
79+
var text = event.delta().text();
80+
var result = new ArrayDeque<StreamingChatCompletionResults.Result>(1);
81+
result.offer(new StreamingChatCompletionResults.Result(text));
82+
var results = new StreamingChatCompletionResults.Results(result);
83+
downstream.onNext(results);
84+
}, threadPool.executor(UTILITY_THREAD_POOL_NAME));
85+
}
86+
87+
/**
 * Records an upstream failure and, when possible, forwards it downstream right away.
 * The throwable is wrapped in an {@code ElasticsearchException} (keeping it as the cause) and
 * stored in {@code error}. It is forwarded immediately only when this is the first terminal
 * signal ({@code isDone} flips from false to true), there was outstanding demand, and
 * {@code downstream.onError} has not been called yet. Otherwise the stored error is left for
 * {@code StreamSubscription.request} to deliver.
 *
 * @param amazonBedrockRuntimeException the failure reported by upstream
 */
@Override
88+
public void onError(Throwable amazonBedrockRuntimeException) {
89+
error.set(
90+
new ElasticsearchException(
91+
Strings.format("AmazonBedrock StreamingChatProcessor failure: [%s]", amazonBedrockRuntimeException.getMessage()),
92+
amazonBedrockRuntimeException
93+
)
94+
);
95+
// the conditions short-circuit in order: demand is only reset if isDone was flipped by this call
if (isDone.compareAndSet(false, true) && checkAndResetDemand() && onErrorCalled.compareAndSet(false, true)) {
96+
downstream.onError(error.get());
97+
}
98+
}
99+
100+
/**
 * Atomically sets the demand counter to zero.
 *
 * @return {@code true} if the counter was greater than zero before the reset
 */
private boolean checkAndResetDemand() {
101+
return demand.getAndUpdate(i -> 0L) > 0L;
102+
}
103+
104+
/**
 * Handles completion of the upstream stream.
 * Completion is forwarded downstream immediately only when this is the first terminal signal
 * ({@code isDone} flips from false to true), there was outstanding demand, and
 * {@code downstream.onComplete} has not been called yet. When it is not forwarded here,
 * {@code StreamSubscription.request} calls {@code downstream.onComplete} on a later request.
 */
@Override
105+
public void onComplete() {
106+
// the conditions short-circuit in order: demand is only reset if isDone was flipped by this call
if (isDone.compareAndSet(false, true) && checkAndResetDemand() && onCompleteCalled.compareAndSet(false, true)) {
107+
downstream.onComplete();
108+
}
109+
}
110+
111+
/**
 * The subscription handed to the downstream subscriber. It records downstream demand on the
 * enclosing processor, forwards requests and cancellation to upstream, and delivers a terminal
 * signal (stored error or completion) once upstream is no longer running.
 */
private class StreamSubscription implements Flow.Subscription {
112+
/**
 * Records {@code n} more items of demand and acts on it.
 * While no upstream subscription exists the demand is only accumulated. If upstream is still
 * running the request is forwarded via {@code requestOnMlThread}; otherwise the stored error,
 * or else completion, is signalled downstream at most once each. A non-positive {@code n}
 * cancels upstream and reports an {@link IllegalStateException} downstream.
 */
@Override
113+
public void request(long n) {
114+
if (n > 0L) {
115+
demand.updateAndGet(i -> {
116+
var sum = i + n;
117+
// a negative sum means the addition overflowed, so clamp to Long.MAX_VALUE
return sum >= 0 ? sum : Long.MAX_VALUE;
118+
});
119+
if (upstream == null) {
120+
// wait for upstream to subscribe before forwarding request
121+
return;
122+
}
123+
if (upstreamIsRunning()) {
124+
requestOnMlThread(n);
125+
} else if (error.get() != null && onErrorCalled.compareAndSet(false, true)) {
126+
downstream.onError(error.get());
127+
} else if (onCompleteCalled.compareAndSet(false, true)) {
128+
downstream.onComplete();
129+
}
130+
} else {
131+
// NOTE(review): this branch is taken for any n <= 0, including zero, although the message says "negative"
cancel();
132+
downstream.onError(new IllegalStateException("Cannot request a negative number."));
133+
}
134+
}
135+
136+
/** Returns true while neither a terminal signal has been seen ({@code isDone}) nor an error has been stored. */
private boolean upstreamIsRunning() {
137+
return isDone.get() == false && error.get() == null;
138+
}
139+
140+
/**
 * Forwards the request to upstream on the utility thread pool.
 * The executor name is derived from the current thread's name; if it already matches
 * {@code UTILITY_THREAD_POOL_NAME} (ignoring case) the request is made inline, otherwise it is
 * submitted to that pool's executor.
 */
private void requestOnMlThread(long n) {
141+
var currentThreadPool = EsExecutors.executorName(Thread.currentThread().getName());
142+
if (UTILITY_THREAD_POOL_NAME.equalsIgnoreCase(currentThreadPool)) {
143+
upstream.request(n);
144+
} else {
145+
CompletableFuture.runAsync(() -> upstream.request(n), threadPool.executor(UTILITY_THREAD_POOL_NAME));
146+
}
147+
}
148+
149+
/** Cancels the upstream subscription, but only if one exists and upstream is still running. */
@Override
150+
public void cancel() {
151+
if (upstream != null && upstreamIsRunning()) {
152+
upstream.cancel();
153+
}
154+
}
155+
}
156+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import org.elasticsearch.xpack.inference.external.response.amazonbedrock.completion.AmazonBedrockChatCompletionResponseHandler;
2323
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModel;
2424

25-
import java.util.List;
2625
import java.util.function.Supplier;
2726

2827
public class AmazonBedrockChatCompletionRequestManager extends AmazonBedrockRequestManager {
@@ -45,9 +44,11 @@ public void execute(
4544
Supplier<Boolean> hasRequestCompletedFunction,
4645
ActionListener<InferenceServiceResults> listener
4746
) {
48-
List<String> docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs();
47+
var docsOnly = DocumentsOnlyInput.of(inferenceInputs);
48+
var docsInput = docsOnly.getInputs();
49+
var stream = docsOnly.stream();
4950
var requestEntity = AmazonBedrockChatCompletionEntityFactory.createEntity(model, docsInput);
50-
var request = new AmazonBedrockChatCompletionRequest(model, requestEntity, timeout);
51+
var request = new AmazonBedrockChatCompletionRequest(model, requestEntity, timeout, stream);
5152
var responseHandler = new AmazonBedrockChatCompletionResponseHandler();
5253

5354
try {

0 commit comments

Comments
 (0)