Skip to content

Commit c3e114a

Browse files
authored
Avoid catch (Throwable t) in AmazonBedrockStreamingChatProcessor (#115715) (#115789)
`CompletableFuture.runAsync` implicitly catches all `Throwable` instances thrown by the task, which includes `Error` instances that no reasonable application should catch. Moreover, discarding the return value from these methods means that any such `Error` will be ignored, allowing the JVM to carry on running in an invalid state. This commit replaces these trappy calls with more appropriate exception handling.
1 parent 22e3307 commit c3e114a

File tree

3 files changed

+21
-4
lines changed

3 files changed

+21
-4
lines changed

docs/changelog/115715.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 115715
2+
summary: Avoid `catch (Throwable t)` in `AmazonBedrockStreamingChatProcessor`
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/inference/src/main/java/module-info.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
requires org.slf4j;
3434
requires software.amazon.awssdk.retries.api;
3535
requires org.reactivestreams;
36+
requires org.elasticsearch.logging;
3637

3738
exports org.elasticsearch.xpack.inference.action;
3839
exports org.elasticsearch.xpack.inference.registry;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockStreamingChatProcessor.java

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@
1414
import org.elasticsearch.ElasticsearchException;
1515
import org.elasticsearch.common.util.concurrent.EsExecutors;
1616
import org.elasticsearch.core.Strings;
17+
import org.elasticsearch.logging.LogManager;
18+
import org.elasticsearch.logging.Logger;
1719
import org.elasticsearch.threadpool.ThreadPool;
1820
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
1921

2022
import java.util.ArrayDeque;
21-
import java.util.concurrent.CompletableFuture;
2223
import java.util.concurrent.Flow;
2324
import java.util.concurrent.atomic.AtomicBoolean;
2425
import java.util.concurrent.atomic.AtomicLong;
@@ -27,6 +28,8 @@
2728
import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;
2829

2930
class AmazonBedrockStreamingChatProcessor implements Flow.Processor<ConverseStreamOutput, StreamingChatCompletionResults.Results> {
31+
private static final Logger logger = LogManager.getLogger(AmazonBedrockStreamingChatProcessor.class);
32+
3033
private final AtomicReference<Throwable> error = new AtomicReference<>(null);
3134
private final AtomicLong demand = new AtomicLong(0);
3235
private final AtomicBoolean isDone = new AtomicBoolean(false);
@@ -75,13 +78,13 @@ public void onNext(ConverseStreamOutput item) {
7578

7679
// this is always called from a netty thread maintained by the AWS SDK, we'll move it to our thread to process the response
7780
private void sendDownstreamOnAnotherThread(ContentBlockDeltaEvent event) {
78-
CompletableFuture.runAsync(() -> {
81+
runOnUtilityThreadPool(() -> {
7982
var text = event.delta().text();
8083
var result = new ArrayDeque<StreamingChatCompletionResults.Result>(1);
8184
result.offer(new StreamingChatCompletionResults.Result(text));
8285
var results = new StreamingChatCompletionResults.Results(result);
8386
downstream.onNext(results);
84-
}, threadPool.executor(UTILITY_THREAD_POOL_NAME));
87+
});
8588
}
8689

8790
@Override
@@ -108,6 +111,14 @@ public void onComplete() {
108111
}
109112
}
110113

114+
private void runOnUtilityThreadPool(Runnable runnable) {
115+
try {
116+
threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(runnable);
117+
} catch (Exception e) {
118+
logger.error(Strings.format("failed to fork [%s] to utility thread pool", runnable), e);
119+
}
120+
}
121+
111122
private class StreamSubscription implements Flow.Subscription {
112123
@Override
113124
public void request(long n) {
@@ -142,7 +153,7 @@ private void requestOnMlThread(long n) {
142153
if (UTILITY_THREAD_POOL_NAME.equalsIgnoreCase(currentThreadPool)) {
143154
upstream.request(n);
144155
} else {
145-
CompletableFuture.runAsync(() -> upstream.request(n), threadPool.executor(UTILITY_THREAD_POOL_NAME));
156+
runOnUtilityThreadPool(() -> upstream.request(n));
146157
}
147158
}
148159

0 commit comments

Comments
 (0)