Skip to content

Commit 0dfa7da

Browse files
ilayaperumalgchedim
authored andcommitted
Bedrock converse chat model to merge toolcalling chat options (spring-projects#4314)
- When building the chat client request, if the ToolCallingChatOptions is passed via prompt, merge the options into Bedrock chat options. Previously, this was ignored. - Add test to verify function calling when ToolCallingChatOptions is passed Auto-cherry-pick to 1.0.x Fixes: spring-projects#4314 Signed-off-by: Ilayaperumal Gopinathan <[email protected]>
1 parent 19eef31 commit 0dfa7da

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,10 @@ Prompt buildRequestPrompt(Prompt prompt) {
273273
if (prompt.getOptions() instanceof BedrockChatOptions bedrockChatOptions) {
274274
runtimeOptions = bedrockChatOptions.copy();
275275
}
276+
else if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
277+
runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class,
278+
BedrockChatOptions.class);
279+
}
276280
else {
277281
runtimeOptions = from(prompt.getOptions());
278282
}

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
import org.springframework.ai.converter.BeanOutputConverter;
4747
import org.springframework.ai.converter.ListOutputConverter;
4848
import org.springframework.ai.converter.MapOutputConverter;
49+
import org.springframework.ai.model.tool.ToolCallingChatOptions;
4950
import org.springframework.ai.tool.function.FunctionToolCallback;
5051
import org.springframework.beans.factory.annotation.Autowired;
5152
import org.springframework.beans.factory.annotation.Value;
@@ -279,6 +280,29 @@ void functionCallTest() {
279280
assertThat(generation.getOutput().getText()).contains("30", "10", "15");
280281
}
281282

283+
@Test
284+
void functionCallTestWithToolCallingOptions() {
285+
286+
UserMessage userMessage = new UserMessage(
287+
"What's the weather like in San Francisco, Tokyo and Paris? Return the result in Celsius.");
288+
289+
List<Message> messages = new ArrayList<>(List.of(userMessage));
290+
291+
var promptOptions = ToolCallingChatOptions.builder()
292+
.toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
293+
.description("Get the weather in location. Return in 36°C format")
294+
.inputType(MockWeatherService.Request.class)
295+
.build()))
296+
.build();
297+
298+
ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions));
299+
300+
logger.info("Response: {}", response);
301+
302+
Generation generation = response.getResult();
303+
assertThat(generation.getOutput().getText()).contains("30", "10", "15");
304+
}
305+
282306
@Test
283307
void streamFunctionCallTest() {
284308

0 commit comments

Comments
 (0)