Skip to content

Commit 4254002

Browse files
Add Amazon Bedrock Unified Chat Completions support
1 parent 77664b7 commit 4254002

File tree

3 files changed

+181
-1
lines changed

3 files changed

+181
-1
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/completion/AmazonBedrockUnifiedChatCompletionRequest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ public Flow.Publisher<StreamingUnifiedChatCompletionResults.Results> executeStre
8787
}
8888

8989
inferenceConfig(requestEntity).ifPresent(converseStreamRequest::inferenceConfig);
90-
return awsBedrockClient.converseUnifiedStream(converseStreamRequest.build());
90+
return new ToolAwareUnifiedPublisher(awsBedrockClient, converseStreamRequest.build());
9191
}
9292

9393
private Document toDocument(Object value) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion;
9+
10+
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
11+
import org.elasticsearch.xpack.inference.services.amazonbedrock.client.AmazonBedrockBaseClient;
12+
13+
import software.amazon.awssdk.core.document.Document;
14+
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock;
15+
import software.amazon.awssdk.services.bedrockruntime.model.ConversationRole;
16+
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest;
17+
import software.amazon.awssdk.services.bedrockruntime.model.Message;
18+
import software.amazon.awssdk.services.bedrockruntime.model.ToolResultBlock;
19+
import software.amazon.awssdk.services.bedrockruntime.model.ToolResultContentBlock;
20+
21+
import java.util.ArrayList;
22+
import java.util.Collection;
23+
import java.util.List;
24+
import java.util.concurrent.Flow;
25+
26+
public class ToolAwareUnifiedPublisher implements Flow.Publisher<StreamingUnifiedChatCompletionResults.Results> {
27+
private final AmazonBedrockBaseClient client;
28+
private ConverseStreamRequest request;
29+
30+
ToolAwareUnifiedPublisher(AmazonBedrockBaseClient client, ConverseStreamRequest request) {
31+
this.client = client;
32+
this.request = request;
33+
}
34+
35+
@SuppressWarnings("checkstyle:DescendantToken")
36+
@Override
37+
public void subscribe(Flow.Subscriber<? super StreamingUnifiedChatCompletionResults.Results> subscriber) {
38+
subscriber.onSubscribe(new Flow.Subscription() {
39+
boolean cancelled = false;
40+
41+
@SuppressWarnings("checkstyle:DescendantToken")
42+
@Override
43+
public void request(long n) {
44+
if (cancelled) {
45+
return;
46+
}
47+
try {
48+
var history = new ArrayList<>(request.messages());
49+
ConverseStreamRequest currentRequest = request;
50+
51+
while (!cancelled) {
52+
List<ToolUseInfo> toolUses = new ArrayList<>();
53+
String[] stopReasons = new String[1];
54+
var round = client.converseUnifiedStream(currentRequest.toBuilder().messages(history).build());
55+
56+
round.subscribe(new Flow.Subscriber<>() {
57+
@Override
58+
public void onSubscribe(Flow.Subscription subscription) {
59+
subscription.request(Long.MAX_VALUE);
60+
}
61+
62+
@Override
63+
public void onNext(StreamingUnifiedChatCompletionResults.Results results) {
64+
subscriber.onNext(results);
65+
66+
for (var chunk : results.chunks()) {
67+
for (var choice : chunk.choices()) {
68+
if (choice.finishReason() != null) {
69+
stopReasons[0] = choice.finishReason();
70+
}
71+
72+
var delta = choice.delta();
73+
if (delta == null) {
74+
continue;
75+
}
76+
77+
var calls = delta.toolCalls();
78+
if (calls != null && !calls.isEmpty()) {
79+
for (var call : calls) {
80+
String id = call.id();
81+
String name = call.function().name();
82+
83+
if (id != null && name != null) {
84+
toolUses.add(new ToolUseInfo(id, name));
85+
}
86+
87+
}
88+
}
89+
}
90+
}
91+
}
92+
93+
@Override
94+
public void onError(Throwable throwable) {
95+
subscriber.onError(throwable);
96+
97+
}
98+
99+
@Override
100+
public void onComplete() {
101+
102+
}
103+
});
104+
105+
106+
boolean toolRequested = "TOOL_USE".equalsIgnoreCase(stopReasons[0]) || !toolUses.isEmpty();
107+
108+
if (!toolRequested) {
109+
break;
110+
}
111+
112+
List<ContentBlock> toolResultBlocks = new ArrayList<>();
113+
for (var toolUse : toolUses) {
114+
115+
String jsonIn = toolUse.inputJson.toString();
116+
// String jsonOut = execute(toolUse.getName(), jsonIn);
117+
String jsonOut = "";
118+
119+
toolResultBlocks.add(
120+
ContentBlock.builder()
121+
.toolResult(ToolResultBlock.builder()
122+
.toolUseId(toolUse.getId())
123+
.content((Collection<ToolResultContentBlock>) Document.fromString(jsonOut))
124+
.build())
125+
.build());
126+
127+
Message toolResultMsg = Message.builder()
128+
.role(ConversationRole.USER)
129+
.content(toolResultBlocks)
130+
.build();
131+
132+
history.add(toolResultMsg);
133+
134+
currentRequest = currentRequest.toBuilder().messages(history).build();
135+
}
136+
}
137+
subscriber.onComplete();
138+
} catch (Throwable throwable) {
139+
subscriber.onError(throwable);
140+
}
141+
}
142+
143+
@Override
144+
public void cancel() {
145+
cancelled = true;
146+
}
147+
});
148+
}
149+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion;
9+
10+
public class ToolUseInfo {
11+
final String id;
12+
final String name;
13+
final StringBuilder inputJson = new StringBuilder();
14+
15+
public ToolUseInfo(String id, String name) {
16+
this.id = id;
17+
this.name = name;
18+
}
19+
20+
public String getId() {
21+
return id;
22+
}
23+
24+
public String getName() {
25+
return name;
26+
}
27+
28+
public StringBuilder getInputJson() {
29+
return inputJson;
30+
}
31+
}

0 commit comments

Comments
 (0)