Skip to content

Commit 3b97b98

Browse files
List of tool calls supported
1 parent 9cfc138 commit 3b97b98

File tree

3 files changed

+47
-23
lines changed

3 files changed

+47
-23
lines changed

orchestration/src/main/java/com/sap/ai/sdk/orchestration/AssistantMessage.java

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import com.sap.ai.sdk.orchestration.model.ChatMessage;
44
import com.sap.ai.sdk.orchestration.model.SingleChatMessage;
5+
import java.util.List;
56
import java.util.Map;
67
import javax.annotation.Nonnull;
78
import javax.annotation.Nullable;
@@ -21,7 +22,7 @@ public final class AssistantMessage implements Message {
2122
@Nonnull String content;
2223

2324
/** Tool call if there is any. */
24-
@Nullable ToolCall toolCall = null;
25+
@Nullable List<ToolCall> toolCalls = null;
2526

2627
/**
2728
* Represents a tool call.
@@ -43,24 +44,33 @@ public AssistantMessage(@Nonnull final String content) {
4344
}
4445

4546
/**
46-
* Creates a new assistant message with the given tool call.
47+
* Creates a new assistant message with the given tool calls.
4748
*
48-
* @param toolCall the tool call object
49+
* @param toolCalls list of tool call objects
4950
*/
50-
public AssistantMessage(@Nonnull final ToolCall toolCall) {
51+
public AssistantMessage(@Nonnull final List<ToolCall> toolCalls) {
5152
content = "";
52-
this.toolCall = toolCall;
53+
this.toolCalls = toolCalls;
5354
}
5455

5556
@Nonnull
5657
@Override
5758
public ChatMessage createChatMessage() {
58-
if (toolCall() != null) {
59-
val function = Map.of("name", toolCall().name(), "arguments", toolCall().arguments());
60-
val toolCallMap = Map.of("id", toolCall().id(), toolCall().type(), function);
59+
if (toolCalls() != null) {
60+
final List<Map<String, Object>> toolCallList =
61+
toolCalls().stream()
62+
.map(
63+
toolCall -> {
64+
val function =
65+
Map.of("name", toolCall.name(), "arguments", toolCall.arguments());
66+
return Map.of(
67+
"id", toolCall.id(), "type", toolCall.type(), toolCall.type(), function);
68+
})
69+
.toList();
6170

62-
val message = SingleChatMessage.create().role(role).content(""); // content shouldn't be set
63-
message.setCustomField("tool_calls", toolCallMap);
71+
// content shouldn't be required for tool calls 🤷
72+
val message = SingleChatMessage.create().role(role).content("");
73+
message.setCustomField("tool_calls", toolCallList);
6474
return message;
6575
}
6676
return Message.super.createChatMessage();

orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatModel.java

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -123,28 +123,42 @@ private OrchestrationPrompt toOrchestrationPrompt(@Nonnull final Prompt prompt)
123123
@Nonnull
124124
private static com.sap.ai.sdk.orchestration.Message[] toOrchestrationMessages(
125125
@Nonnull final List<Message> messages) {
126-
final Function<Message, com.sap.ai.sdk.orchestration.Message> mapper =
126+
final Function<Message, List<com.sap.ai.sdk.orchestration.Message>> mapper =
127127
msg ->
128128
switch (msg.getMessageType()) {
129129
case SYSTEM:
130-
yield new SystemMessage(msg.getText());
130+
yield List.of(new SystemMessage(msg.getText()));
131131
case USER:
132-
yield new UserMessage(msg.getText());
132+
yield List.of(new UserMessage(msg.getText()));
133133
case ASSISTANT:
134134
final List<ToolCall> toolCalls =
135135
((org.springframework.ai.chat.messages.AssistantMessage) msg).getToolCalls();
136136
if (toolCalls != null) {
137-
val toolCall = toolCalls.get(0);
138-
yield new AssistantMessage(
139-
new AssistantMessage.ToolCall(
140-
toolCall.id(), toolCall.type(), toolCall.name(), toolCall.arguments()));
137+
final List<AssistantMessage.ToolCall> toolCallList =
138+
toolCalls.stream()
139+
.map(
140+
toolCall ->
141+
new AssistantMessage.ToolCall(
142+
toolCall.id(),
143+
toolCall.type(),
144+
toolCall.name(),
145+
toolCall.arguments()))
146+
.toList();
147+
yield List.of(new AssistantMessage(toolCallList));
141148
}
142-
yield new AssistantMessage(msg.getText());
149+
yield List.of(new AssistantMessage(msg.getText()));
143150
case TOOL:
144-
val responses = ((ToolResponseMessage) msg).getResponses();
145-
val response = responses.get(0);
146-
yield new ToolMessage(response.id(), response.responseData());
151+
val toolResponses = ((ToolResponseMessage) msg).getResponses();
152+
yield toolResponses.stream()
153+
.map(
154+
r ->
155+
(com.sap.ai.sdk.orchestration.Message)
156+
new ToolMessage(r.id(), r.responseData()))
157+
.toList();
147158
};
148-
return messages.stream().map(mapper).toArray(com.sap.ai.sdk.orchestration.Message[]::new);
159+
return messages.stream()
160+
.map(mapper)
161+
.flatMap(List::stream)
162+
.toArray(com.sap.ai.sdk.orchestration.Message[]::new);
149163
}
150164
}

sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/SpringAiOrchestrationService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public ChatResponse completion() {
4242
.description("Get the weather in location") // (2) function description
4343
.inputType(MockWeatherService.Request.class) // (3) function input type
4444
.build()));
45-
val prompt = new Prompt("What is the weather in Potsdam?", defaultOptions);
45+
val prompt = new Prompt("What is the weather in Potsdam and in Toulouse?", defaultOptions);
4646

4747
return client.call(prompt);
4848
}

0 commit comments

Comments
 (0)