8
8
package org .elasticsearch .xpack .inference .services .amazonbedrock .request .completion ;
9
9
10
10
import software .amazon .awssdk .core .document .Document ;
11
+ import software .amazon .awssdk .services .bedrockruntime .model .ContentBlockDeltaEvent ;
12
+ import software .amazon .awssdk .services .bedrockruntime .model .ContentBlockStartEvent ;
11
13
import software .amazon .awssdk .services .bedrockruntime .model .ConverseStreamRequest ;
14
+ import software .amazon .awssdk .services .bedrockruntime .model .ConverseStreamResponseHandler ;
15
+ import software .amazon .awssdk .services .bedrockruntime .model .MessageStopEvent ;
12
16
import software .amazon .awssdk .services .bedrockruntime .model .SpecificToolChoice ;
13
17
import software .amazon .awssdk .services .bedrockruntime .model .Tool ;
14
18
import software .amazon .awssdk .services .bedrockruntime .model .ToolChoice ;
30
34
import java .util .HashMap ;
31
35
import java .util .Map ;
32
36
import java .util .Objects ;
37
+ import java .util .concurrent .CompletableFuture ;
38
+ import java .util .concurrent .ExecutionException ;
33
39
import java .util .concurrent .Flow ;
34
40
35
41
import static org .elasticsearch .xpack .inference .services .amazonbedrock .request .completion .AmazonBedrockConverseUtils .getUnifiedConverseMessageList ;
@@ -54,7 +60,7 @@ public AmazonBedrockUnifiedChatCompletionRequest(
54
60
55
61
public Flow .Publisher <StreamingUnifiedChatCompletionResults .Results > executeStreamChatCompletionRequest (
56
62
AmazonBedrockBaseClient awsBedrockClient
57
- ) {
63
+ ) throws ExecutionException , InterruptedException {
58
64
var converseStreamRequest = ConverseStreamRequest .builder ()
59
65
.messages (getUnifiedConverseMessageList (requestEntity .messages ()))
60
66
.modelId (amazonBedrockModel .model ());
@@ -86,6 +92,59 @@ public Flow.Publisher<StreamingUnifiedChatCompletionResults.Results> executeStre
86
92
});
87
93
}
88
94
95
+ inferenceConfig (requestEntity ).ifPresent (converseStreamRequest ::inferenceConfig );
96
+ var response = awsBedrockClient .converseUnifiedStream (converseStreamRequest .build ());
97
+
98
+ var toolRequested = new CompletableFuture <Boolean >();
99
+ final String [] toolUseIdHolder = new String [1 ];
100
+ final StringBuilder toolJsonArgs = new StringBuilder ();
101
+ final StringBuilder assistantText = new StringBuilder ();
102
+
103
+ var handler = ConverseStreamResponseHandler .builder ()
104
+ .onEventStream (es -> es .subscribe (event -> {
105
+ switch (event .sdkEventType ()) {
106
+ case MESSAGE_START :
107
+ break ;
108
+ case CONTENT_BLOCK_START :
109
+ var start = ((ContentBlockStartEvent ) event ).start ();
110
+ if (start .toolUse () != null ) {
111
+ toolUseIdHolder [0 ] = start .toolUse ().toolUseId ();
112
+ }
113
+ break ;
114
+ case CONTENT_BLOCK_DELTA :
115
+ var delta = ((ContentBlockDeltaEvent ) event ).delta ();
116
+ if (delta .toolUse () != null && delta .toolUse ().input () != null ) {
117
+ toolJsonArgs .append (delta .toolUse ().input ());
118
+ }
119
+ if (delta .text () != null ) {
120
+ assistantText .append (delta .text ());
121
+ }
122
+ break ;
123
+ case MESSAGE_STOP :
124
+ var stop = ((MessageStopEvent ) event ).stopReason ();
125
+ if ("tool_use" .equalsIgnoreCase (stop .name ())) {
126
+ toolRequested .complete (true );
127
+ } else {
128
+ toolRequested .complete (false );
129
+ }
130
+ break ;
131
+ default :
132
+ }
133
+ }))
134
+ .onResponse (r -> toolRequested .complete (true ))
135
+ .onError (toolRequested ::completeExceptionally );
136
+
137
+ handler .subscriber (converseStreamOutput ->
138
+ getUnifiedConverseMessageList (requestEntity .messages ()).forEach (toolJsonArgs ::append ));
139
+
140
+ if (Boolean .TRUE .equals (toolRequested .get ())) {
141
+ toolJsonArgs .toString ().contains ("args" );
142
+ Map <String , Object > result = Map .of ("tool_use" , toolUseIdHolder [0 ]);
143
+ // var toolResultBlock = ContentBlock
144
+ // .fromToolResult(ToolResultContentBlock.builder()
145
+ // .document(DocumentBlock.builder().context(result).build()));
146
+
147
+ }
89
148
inferenceConfig (requestEntity ).ifPresent (converseStreamRequest ::inferenceConfig );
90
149
return awsBedrockClient .converseUnifiedStream (converseStreamRequest .build ());
91
150
}
0 commit comments