88package org .elasticsearch .xpack .inference .services .amazonbedrock .request .completion ;
99
1010import software .amazon .awssdk .core .document .Document ;
11- import software .amazon .awssdk .services .bedrockruntime .model .ContentBlockDeltaEvent ;
12- import software .amazon .awssdk .services .bedrockruntime .model .ContentBlockStartEvent ;
1311import software .amazon .awssdk .services .bedrockruntime .model .ConverseStreamRequest ;
14- import software .amazon .awssdk .services .bedrockruntime .model .ConverseStreamResponseHandler ;
15- import software .amazon .awssdk .services .bedrockruntime .model .MessageStopEvent ;
1612import software .amazon .awssdk .services .bedrockruntime .model .SpecificToolChoice ;
1713import software .amazon .awssdk .services .bedrockruntime .model .Tool ;
1814import software .amazon .awssdk .services .bedrockruntime .model .ToolChoice ;
3430import java .util .HashMap ;
3531import java .util .Map ;
3632import java .util .Objects ;
37- import java .util .concurrent .CompletableFuture ;
38- import java .util .concurrent .ExecutionException ;
3933import java .util .concurrent .Flow ;
4034
4135import static org .elasticsearch .xpack .inference .services .amazonbedrock .request .completion .AmazonBedrockConverseUtils .getUnifiedConverseMessageList ;
@@ -60,7 +54,7 @@ public AmazonBedrockUnifiedChatCompletionRequest(
6054
6155 public Flow .Publisher <StreamingUnifiedChatCompletionResults .Results > executeStreamChatCompletionRequest (
6256 AmazonBedrockBaseClient awsBedrockClient
63- ) throws ExecutionException , InterruptedException {
57+ ) {
6458 var converseStreamRequest = ConverseStreamRequest .builder ()
6559 .messages (getUnifiedConverseMessageList (requestEntity .messages ()))
6660 .modelId (amazonBedrockModel .model ());
@@ -92,55 +86,6 @@ public Flow.Publisher<StreamingUnifiedChatCompletionResults.Results> executeStre
9286 });
9387 }
9488
95- inferenceConfig (requestEntity ).ifPresent (converseStreamRequest ::inferenceConfig );
96- var response = awsBedrockClient .converseUnifiedStream (converseStreamRequest .build ());
97-
98- var toolRequested = new CompletableFuture <Boolean >();
99- final String [] toolUseIdHolder = new String [1 ];
100- final StringBuilder toolJsonArgs = new StringBuilder ();
101- final StringBuilder assistantText = new StringBuilder ();
102-
103- var handler = ConverseStreamResponseHandler .builder ().onEventStream (es -> es .subscribe (event -> {
104- switch (event .sdkEventType ()) {
105- case MESSAGE_START :
106- break ;
107- case CONTENT_BLOCK_START :
108- var start = ((ContentBlockStartEvent ) event ).start ();
109- if (start .toolUse () != null ) {
110- toolUseIdHolder [0 ] = start .toolUse ().toolUseId ();
111- }
112- break ;
113- case CONTENT_BLOCK_DELTA :
114- var delta = ((ContentBlockDeltaEvent ) event ).delta ();
115- if (delta .toolUse () != null && delta .toolUse ().input () != null ) {
116- toolJsonArgs .append (delta .toolUse ().input ());
117- }
118- if (delta .text () != null ) {
119- assistantText .append (delta .text ());
120- }
121- break ;
122- case MESSAGE_STOP :
123- var stop = ((MessageStopEvent ) event ).stopReason ();
124- if ("tool_use" .equalsIgnoreCase (stop .name ())) {
125- toolRequested .complete (true );
126- } else {
127- toolRequested .complete (false );
128- }
129- break ;
130- default :
131- }
132- })).onResponse (r -> toolRequested .complete (true )).onError (toolRequested ::completeExceptionally );
133-
134- handler .subscriber (converseStreamOutput -> getUnifiedConverseMessageList (requestEntity .messages ()).forEach (toolJsonArgs ::append ));
135-
136- if (Boolean .TRUE .equals (toolRequested .get ())) {
137- toolJsonArgs .toString ().contains ("args" );
138- Map <String , Object > result = Map .of ("tool_use" , toolUseIdHolder [0 ]);
139- // var toolResultBlock = ContentBlock
140- // .fromToolResult(ToolResultContentBlock.builder()
141- // .document(DocumentBlock.builder().context(result).build()));
142-
143- }
14489 inferenceConfig (requestEntity ).ifPresent (converseStreamRequest ::inferenceConfig );
14590 return awsBedrockClient .converseUnifiedStream (converseStreamRequest .build ());
14691 }
0 commit comments