Skip to content

Commit 4d3868e

Browse files
tzolovnamsoo2
authored andcommitted
Additional tests for the spring-projects#1878
Add additional integration tests to ensure that the spring-projects#1878 issue is resolved Signed-off-by: Christian Tzolov <[email protected]> Signed-off-by: minsoo.nam <[email protected]>
1 parent c21190d commit 4d3868e

File tree

1 file changed

+50
-1
lines changed

1 file changed

+50
-1
lines changed

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/client/BedrockNovaChatClientIT.java

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,21 @@
2020
import java.time.Duration;
2121
import java.util.Set;
2222
import java.util.function.Supplier;
23+
import java.util.stream.Collectors;
2324

2425
import org.junit.jupiter.api.Test;
2526
import org.slf4j.Logger;
2627
import org.slf4j.LoggerFactory;
28+
import reactor.core.publisher.Flux;
2729
import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
2830
import software.amazon.awssdk.regions.Region;
2931

3032
import org.springframework.ai.bedrock.converse.BedrockProxyChatModel;
3133
import org.springframework.ai.bedrock.converse.RequiresAwsCredentials;
3234
import org.springframework.ai.chat.client.ChatClient;
35+
import org.springframework.ai.chat.client.ChatClient.StreamResponseSpec;
3336
import org.springframework.ai.chat.model.ChatModel;
37+
import org.springframework.ai.chat.model.ChatResponse;
3438
import org.springframework.ai.content.Media;
3539
import org.springframework.ai.model.tool.ToolCallingChatOptions;
3640
import org.springframework.ai.tool.annotation.Tool;
@@ -175,7 +179,7 @@ public record WeatherResponse(int temp, String unit) {
175179

176180
// https://github.com/spring-projects/spring-ai/issues/1878
177181
@Test
178-
void toolAnnotationWeatherForecastTest() {
182+
void toolAnnotationWeatherForecast() {
179183

180184
ChatClient chatClient = ChatClient.builder(this.chatModel).build();
181185

@@ -189,6 +193,27 @@ void toolAnnotationWeatherForecastTest() {
189193
assertThat(response).contains("20 degrees");
190194
}
191195

196+
@Test
197+
void toolAnnotationWeatherForecastStreaming() {
198+
199+
ChatClient chatClient = ChatClient.builder(this.chatModel).build();
200+
201+
Flux<ChatResponse> responses = chatClient.prompt()
202+
.tools(new DummyWeatherForcastTools())
203+
.user("Get current weather in Amsterdam")
204+
.stream()
205+
.chatResponse();
206+
207+
String content = responses.collectList()
208+
.block()
209+
.stream()
210+
.filter(cr -> cr.getResult() != null)
211+
.map(cr -> cr.getResult().getOutput().getText())
212+
.collect(Collectors.joining());
213+
214+
assertThat(content).contains("20 degrees");
215+
}
216+
192217
public static class DummyWeatherForcastTools {
193218

194219
@Tool(description = "Get the current weather forcast in Amsterdam")
@@ -217,6 +242,30 @@ void supplierBasedToolCalling() {
217242
assertThat(response.temp()).isEqualTo(30.0);
218243
}
219244

245+
@Test
246+
void supplierBasedToolCallingStreaming() {
247+
248+
ChatClient chatClient = ChatClient.builder(this.chatModel).build();
249+
250+
Flux<ChatResponse> responses = chatClient.prompt()
251+
.toolCallbacks(FunctionToolCallback.builder("weather", new WeatherService())
252+
.description("Get the current weather")
253+
.inputType(Void.class)
254+
.build())
255+
.user("Get current weather in Amsterdam")
256+
.stream()
257+
.chatResponse();
258+
259+
String content = responses.collectList()
260+
.block()
261+
.stream()
262+
.filter(cr -> cr.getResult() != null)
263+
.map(cr -> cr.getResult().getOutput().getText())
264+
.collect(Collectors.joining());
265+
266+
assertThat(content).contains("30.0");
267+
}
268+
220269
public static class WeatherService implements Supplier<WeatherService.Response> {
221270

222271
public record Response(double temp) {

0 commit comments

Comments
 (0)