2020import java .time .Duration ;
2121import java .util .Set ;
2222import java .util .function .Supplier ;
23+ import java .util .stream .Collectors ;
2324
2425import org .junit .jupiter .api .Test ;
2526import org .slf4j .Logger ;
2627import org .slf4j .LoggerFactory ;
28+ import reactor .core .publisher .Flux ;
2729import software .amazon .awssdk .auth .credentials .EnvironmentVariableCredentialsProvider ;
2830import software .amazon .awssdk .regions .Region ;
2931
3032import org .springframework .ai .bedrock .converse .BedrockProxyChatModel ;
3133import org .springframework .ai .bedrock .converse .RequiresAwsCredentials ;
3234import org .springframework .ai .chat .client .ChatClient ;
35+ import org .springframework .ai .chat .client .ChatClient .StreamResponseSpec ;
3336import org .springframework .ai .chat .model .ChatModel ;
37+ import org .springframework .ai .chat .model .ChatResponse ;
3438import org .springframework .ai .content .Media ;
3539import org .springframework .ai .model .tool .ToolCallingChatOptions ;
3640import org .springframework .ai .tool .annotation .Tool ;
@@ -175,7 +179,7 @@ public record WeatherResponse(int temp, String unit) {
175179
176180 // https://github.com/spring-projects/spring-ai/issues/1878
177181 @ Test
178- void toolAnnotationWeatherForecastTest () {
182+ void toolAnnotationWeatherForecast () {
179183
180184 ChatClient chatClient = ChatClient .builder (this .chatModel ).build ();
181185
@@ -189,6 +193,27 @@ void toolAnnotationWeatherForecastTest() {
189193 assertThat (response ).contains ("20 degrees" );
190194 }
191195
196+ @ Test
197+ void toolAnnotationWeatherForecastStreaming () {
198+
199+ ChatClient chatClient = ChatClient .builder (this .chatModel ).build ();
200+
201+ Flux <ChatResponse > responses = chatClient .prompt ()
202+ .tools (new DummyWeatherForcastTools ())
203+ .user ("Get current weather in Amsterdam" )
204+ .stream ()
205+ .chatResponse ();
206+
207+ String content = responses .collectList ()
208+ .block ()
209+ .stream ()
210+ .filter (cr -> cr .getResult () != null )
211+ .map (cr -> cr .getResult ().getOutput ().getText ())
212+ .collect (Collectors .joining ());
213+
214+ assertThat (content ).contains ("20 degrees" );
215+ }
216+
192217 public static class DummyWeatherForcastTools {
193218
194219 @ Tool (description = "Get the current weather forcast in Amsterdam" )
@@ -217,6 +242,30 @@ void supplierBasedToolCalling() {
217242 assertThat (response .temp ()).isEqualTo (30.0 );
218243 }
219244
245+ @ Test
246+ void supplierBasedToolCallingStreaming () {
247+
248+ ChatClient chatClient = ChatClient .builder (this .chatModel ).build ();
249+
250+ Flux <ChatResponse > responses = chatClient .prompt ()
251+ .toolCallbacks (FunctionToolCallback .builder ("weather" , new WeatherService ())
252+ .description ("Get the current weather" )
253+ .inputType (Void .class )
254+ .build ())
255+ .user ("Get current weather in Amsterdam" )
256+ .stream ()
257+ .chatResponse ();
258+
259+ String content = responses .collectList ()
260+ .block ()
261+ .stream ()
262+ .filter (cr -> cr .getResult () != null )
263+ .map (cr -> cr .getResult ().getOutput ().getText ())
264+ .collect (Collectors .joining ());
265+
266+ assertThat (content ).contains ("30.0" );
267+ }
268+
220269 public static class WeatherService implements Supplier <WeatherService .Response > {
221270
222271 public record Response (double temp ) {
0 commit comments