Skip to content

Commit 6b2e904

Browse files
Add Amazon Bedrock Unified Chat Completions support
1 parent bbcbe00 commit 6b2e904

File tree

1 file changed

+60
-1
lines changed

1 file changed

+60
-1
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/completion/AmazonBedrockUnifiedChatCompletionRequest.java

Lines changed: 60 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,11 @@
8 8
package org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion;
9 9

10 10
import software.amazon.awssdk.core.document.Document;
11+
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDeltaEvent;
12+
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStartEvent;
11 13
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest;
14+
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler;
15+
import software.amazon.awssdk.services.bedrockruntime.model.MessageStopEvent;
12 16
import software.amazon.awssdk.services.bedrockruntime.model.SpecificToolChoice;
13 17
import software.amazon.awssdk.services.bedrockruntime.model.Tool;
14 18
import software.amazon.awssdk.services.bedrockruntime.model.ToolChoice;
@@ -30,6 +34,8 @@
30 34
import java.util.HashMap;
31 35
import java.util.Map;
32 36
import java.util.Objects;
37+
import java.util.concurrent.CompletableFuture;
38+
import java.util.concurrent.ExecutionException;
33 39
import java.util.concurrent.Flow;
34 40

35 41
import static org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion.AmazonBedrockConverseUtils.getUnifiedConverseMessageList;
@@ -54,7 +60,7 @@ public AmazonBedrockUnifiedChatCompletionRequest(
54 60

55 61
public Flow.Publisher<StreamingUnifiedChatCompletionResults.Results> executeStreamChatCompletionRequest(
56 62
AmazonBedrockBaseClient awsBedrockClient
57-
) {
63+
) throws ExecutionException, InterruptedException {
58 64
var converseStreamRequest = ConverseStreamRequest.builder()
59 65
.messages(getUnifiedConverseMessageList(requestEntity.messages()))
60 66
.modelId(amazonBedrockModel.model());
@@ -86,6 +92,59 @@ public Flow.Publisher<StreamingUnifiedChatCompletionResults.Results> executeStre
86 92
});
87 93
}
88 94

95+
inferenceConfig(requestEntity).ifPresent(converseStreamRequest::inferenceConfig);
96+
var response = awsBedrockClient.converseUnifiedStream(converseStreamRequest.build());
97+
98+
var toolRequested = new CompletableFuture<Boolean>();
99+
final String[] toolUseIdHolder = new String[1];
100+
final StringBuilder toolJsonArgs = new StringBuilder();
101+
final StringBuilder assistantText = new StringBuilder();
102+
103+
var handler = ConverseStreamResponseHandler.builder()
104+
.onEventStream(es -> es.subscribe(event -> {
105+
switch (event.sdkEventType()) {
106+
case MESSAGE_START:
107+
break;
108+
case CONTENT_BLOCK_START:
109+
var start = ((ContentBlockStartEvent) event).start();
110+
if (start.toolUse() != null) {
111+
toolUseIdHolder[0] = start.toolUse().toolUseId();
112+
}
113+
break;
114+
case CONTENT_BLOCK_DELTA:
115+
var delta = ((ContentBlockDeltaEvent) event).delta();
116+
if (delta.toolUse() != null && delta.toolUse().input() != null) {
117+
toolJsonArgs.append(delta.toolUse().input());
118+
}
119+
if (delta.text() != null) {
120+
assistantText.append(delta.text());
121+
}
122+
break;
123+
case MESSAGE_STOP:
124+
var stop = ((MessageStopEvent) event).stopReason();
125+
if ("tool_use".equalsIgnoreCase(stop.name())) {
126+
toolRequested.complete(true);
127+
} else {
128+
toolRequested.complete(false);
129+
}
130+
break;
131+
default:
132+
}
133+
}))
134+
.onResponse(r -> toolRequested.complete(true))
135+
.onError(toolRequested::completeExceptionally);
136+
137+
handler.subscriber(converseStreamOutput ->
138+
getUnifiedConverseMessageList(requestEntity.messages()).forEach(toolJsonArgs::append));
139+
140+
if (Boolean.TRUE.equals(toolRequested.get())) {
141+
toolJsonArgs.toString().contains("args");
142+
Map<String, Object> result = Map.of("tool_use", toolUseIdHolder[0]);
143+
// var toolResultBlock = ContentBlock
144+
// .fromToolResult(ToolResultContentBlock.builder()
145+
// .document(DocumentBlock.builder().context(result).build()));
146+
147+
}
89 148
inferenceConfig(requestEntity).ifPresent(converseStreamRequest::inferenceConfig);
90 149
return awsBedrockClient.converseUnifiedStream(converseStreamRequest.build());
91 150
}

0 commit comments

Comments
 (0)