Skip to content

Commit 6ba9954

Browse files
CharlesDuboisSAPbot-sdk-jsnewtorka-d
authored
feat: Spring AI🍃Tool Calling (#320)
* Spring AI functions * Test function callback * Remove test * First tool call * Best effort * TODO * Formatting * Fix test * Removed TODO * Checkstyle * lombok * List of tool calls supported * Unit test wip * Unit test almost * Unit test finished * Green * Formatting * Green * new ToolCall class * Added documentation and release notes * Replaced ToolCall class with existing ResponseMessageToolCall * Tool call wip * Spring AI 1.0.0-M6 * Reduced options class size * Added tests * Updated release notes * Replaced FunctionCallbacks with ToolCallbacks * Updated docs * Updated docs * Removed useless code * Removed whitespace * Added deprecation notice * Update orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatModel.java Co-authored-by: Alexander Dümont <[email protected]> * composition for toolCallingManager * since 1.4.0 * dependencies * make fields final again --------- Co-authored-by: SAP Cloud SDK Bot <[email protected]> Co-authored-by: Alexander Dümont <[email protected]> Co-authored-by: Alexander Dümont <[email protected]>
1 parent 39efe77 commit 6ba9954

26 files changed

+949
-78
lines changed

docs/guides/SPRING_AI_INTEGRATION.md

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
- [Introduction](#introduction)
66
- [Orchestration Chat Completion](#orchestration-chat-completion)
77
- [Orchestration Masking](#orchestration-masking)
8+
- [Stream chat completion](#stream-chat-completion)
9+
- [Tool Calling](#tool-calling)
810

911
## Introduction
1012

@@ -32,7 +34,7 @@ First, add the Spring AI dependency to your `pom.xml`:
3234

3335
:::note Spring AI Milestone Version
3436
Note that currently no stable version of Spring AI exists just yet.
35-
The AI SDK currently uses the [M5 milestone](https://spring.io/blog/2024/12/23/spring-ai-1-0-0-m5-released).
37+
The AI SDK currently uses the [M6 milestone](https://spring.io/blog/2025/02/14/spring-ai-1-0-0-m6-released).
3638

3739
Please be aware that future versions of the AI SDK may increase the Spring AI version.
3840
:::
@@ -99,3 +101,40 @@ Flux<String> responseFlux =
99101
_Note: A Spring endpoint can return `Flux` instead of `ResponseEntity`._
100102

101103
Please find [an example in our Spring Boot application](../../sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/SpringAiOrchestrationService.java).
104+
105+
## Tool Calling
106+
107+
First define a function that will be called by the LLM:
108+
109+
```java
110+
class WeatherMethod {
111+
enum Unit {C,F}
112+
record Request(String location, Unit unit) {}
113+
record Response(double temp, Unit unit) {}
114+
115+
@Tool(description = "Get the weather in location")
116+
Response getCurrentWeather(@ToolParam Request request) {
117+
int temperature = request.location.hashCode() % 30;
118+
return new Response(temperature, request.unit);
119+
}
120+
}
121+
```
122+
123+
Then add your tool to the options:
124+
125+
```java
126+
ChatModel client = new OrchestrationChatModel();
127+
OrchestrationModuleConfig config = new OrchestrationModuleConfig().withLlmConfig(GPT_35_TURBO);
128+
OrchestrationChatOptions opts = new OrchestrationChatOptions(config);
129+
130+
options.setToolCallbacks(List.of(ToolCallbacks.from(new WeatherMethod())));
131+
132+
options.setInternalToolExecutionEnabled(false);// tool execution is not yet available in orchestration
133+
134+
Prompt prompt = new Prompt("What is the weather in Potsdam and in Toulouse?", options);
135+
136+
ChatResponse response = client.call(prompt);
137+
```
138+
139+
Please find [an example in our Spring Boot application](../../sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/SpringAiOrchestrationService.java).
140+

docs/release-notes/release_notes.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
### ✨ New Functionality
1414

15-
-
15+
- [Add Spring AI tool calling](../guides/SPRING_AI_INTEGRATION.md#tool-calling).
1616

1717
### 📈 Improvements
1818

orchestration/pom.xml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,11 @@
3131
</developers>
3232
<properties>
3333
<project.rootdir>${project.basedir}/../</project.rootdir>
34-
<coverage.complexity>80%</coverage.complexity>
34+
<coverage.complexity>81%</coverage.complexity>
3535
<coverage.line>92%</coverage.line>
3636
<coverage.instruction>93%</coverage.instruction>
37-
<coverage.branch>71%</coverage.branch>
38-
<coverage.method>95%</coverage.method>
37+
<coverage.branch>74%</coverage.branch>
38+
<coverage.method>92%</coverage.method>
3939
<coverage.class>100%</coverage.class>
4040
</properties>
4141

orchestration/src/main/java/com/sap/ai/sdk/orchestration/AssistantMessage.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
11
package com.sap.ai.sdk.orchestration;
22

33
import com.google.common.annotations.Beta;
4+
import com.sap.ai.sdk.orchestration.model.ChatMessage;
5+
import com.sap.ai.sdk.orchestration.model.ResponseMessageToolCall;
6+
import com.sap.ai.sdk.orchestration.model.SingleChatMessage;
47
import java.util.List;
58
import javax.annotation.Nonnull;
9+
import javax.annotation.Nullable;
610
import lombok.Getter;
711
import lombok.Value;
812
import lombok.experimental.Accessors;
13+
import lombok.val;
914

1015
/** Represents a chat message as 'assistant' to the orchestration service. */
1116
@Value
17+
@Getter
1218
@Accessors(fluent = true)
1319
public class AssistantMessage implements Message {
1420

@@ -20,12 +26,38 @@ public class AssistantMessage implements Message {
2026
@Getter(onMethod_ = @Beta)
2127
MessageContent content;
2228

29+
/** Tool call if there is any. */
30+
@Nullable List<ResponseMessageToolCall> toolCalls;
31+
2332
/**
2433
* Creates a new assistant message with the given single message.
2534
*
2635
* @param singleMessage the single message.
2736
*/
2837
public AssistantMessage(@Nonnull final String singleMessage) {
2938
content = new MessageContent(List.of(new TextItem(singleMessage)));
39+
toolCalls = null;
40+
}
41+
42+
/**
43+
* Creates a new assistant message with the given tool calls.
44+
*
45+
* @param toolCalls list of tool call objects
46+
*/
47+
public AssistantMessage(@Nonnull final List<ResponseMessageToolCall> toolCalls) {
48+
content = new MessageContent(List.of());
49+
this.toolCalls = toolCalls;
50+
}
51+
52+
@Nonnull
53+
@Override
54+
public ChatMessage createChatMessage() {
55+
if (toolCalls() != null) {
56+
// content shouldn't be required for tool calls 🤷
57+
val message = SingleChatMessage.create().role(role).content("");
58+
message.setCustomField("tool_calls", toolCalls);
59+
return message;
60+
}
61+
return Message.super.createChatMessage();
3062
}
3163
}

orchestration/src/main/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformer.java

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import com.sap.ai.sdk.orchestration.model.TemplatingModuleConfig;
99
import io.vavr.control.Option;
1010
import java.util.ArrayList;
11-
import java.util.List;
1211
import javax.annotation.Nonnull;
1312
import javax.annotation.Nullable;
1413
import lombok.AccessLevel;
@@ -40,24 +39,28 @@ static CompletionPostRequest toCompletionPostRequest(
4039

4140
@Nonnull
4241
static TemplatingModuleConfig toTemplateModuleConfig(
43-
@Nonnull final OrchestrationPrompt prompt, @Nullable final TemplatingModuleConfig template) {
42+
@Nonnull final OrchestrationPrompt prompt, @Nullable final TemplatingModuleConfig config) {
4443
/*
4544
* Currently, we have to merge the prompt into the template configuration.
4645
* This works around the limitation that the template config is required.
4746
* This comes at the risk that the prompt unintentionally contains the templating pattern "{{? .. }}".
4847
* In this case, the request will fail, since the templating module will try to resolve the parameter.
4948
* To be fixed with https://github.tools.sap/AI/llm-orchestration/issues/662
5049
*/
51-
val messages = template instanceof Template t ? t.getTemplate() : List.<ChatMessage>of();
52-
val responseFormat = template instanceof Template t ? t.getResponseFormat() : null;
50+
val template = config instanceof Template t ? t : Template.create().template();
51+
val messages = template.getTemplate();
52+
val responseFormat = template.getResponseFormat();
5353
val messagesWithPrompt = new ArrayList<>(messages);
5454
messagesWithPrompt.addAll(
5555
prompt.getMessages().stream().map(Message::createChatMessage).toList());
5656
if (messagesWithPrompt.isEmpty()) {
5757
throw new IllegalStateException(
5858
"A prompt is required. Pass at least one message or configure a template with messages or a template reference.");
5959
}
60-
return Template.create().template(messagesWithPrompt).responseFormat(responseFormat);
60+
return Template.create()
61+
.template(messagesWithPrompt)
62+
.tools(template.getTools())
63+
.responseFormat(responseFormat);
6164
}
6265

6366
@Nonnull

orchestration/src/main/java/com/sap/ai/sdk/orchestration/Message.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import javax.annotation.Nonnull;
1414

1515
/** Interface representing convenience wrappers of chat message to the orchestration service. */
16-
public sealed interface Message permits UserMessage, AssistantMessage, SystemMessage {
16+
public sealed interface Message permits AssistantMessage, SystemMessage, ToolMessage, UserMessage {
1717

1818
/**
1919
* A convenience method to create a user message from a string.
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import com.sap.ai.sdk.orchestration.model.ChatMessage;
4+
import com.sap.ai.sdk.orchestration.model.SingleChatMessage;
5+
import java.util.List;
6+
import javax.annotation.Nonnull;
7+
import lombok.Value;
8+
import lombok.experimental.Accessors;
9+
10+
/**
11+
* Represents a chat message as 'tool' to the orchestration service.
12+
*
13+
* @since 1.4.0
14+
*/
15+
@Value
16+
@Accessors(fluent = true)
17+
public class ToolMessage implements Message {
18+
19+
/** The role of the assistant. */
20+
@Nonnull String role = "tool";
21+
22+
@Nonnull String id;
23+
24+
@Nonnull String content;
25+
26+
@Nonnull
27+
@Override
28+
public MessageContent content() {
29+
return new MessageContent(List.of(new TextItem(content)));
30+
}
31+
32+
@Nonnull
33+
@Override
34+
public ChatMessage createChatMessage() {
35+
final SingleChatMessage message = SingleChatMessage.create().role(role()).content(content);
36+
message.setCustomField("tool_call_id", id);
37+
return message;
38+
}
39+
}

orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatModel.java

Lines changed: 69 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,32 @@
11
package com.sap.ai.sdk.orchestration.spring;
22

33
import static com.sap.ai.sdk.orchestration.OrchestrationClient.toCompletionPostRequest;
4+
import static com.sap.ai.sdk.orchestration.model.ResponseMessageToolCall.TypeEnum.FUNCTION;
45

56
import com.google.common.annotations.Beta;
67
import com.sap.ai.sdk.orchestration.AssistantMessage;
78
import com.sap.ai.sdk.orchestration.OrchestrationChatCompletionDelta;
89
import com.sap.ai.sdk.orchestration.OrchestrationClient;
910
import com.sap.ai.sdk.orchestration.OrchestrationPrompt;
1011
import com.sap.ai.sdk.orchestration.SystemMessage;
12+
import com.sap.ai.sdk.orchestration.ToolMessage;
1113
import com.sap.ai.sdk.orchestration.UserMessage;
14+
import com.sap.ai.sdk.orchestration.model.ResponseMessageToolCall;
15+
import com.sap.ai.sdk.orchestration.model.ResponseMessageToolCallFunction;
1216
import java.util.List;
1317
import java.util.Map;
1418
import java.util.function.Function;
1519
import javax.annotation.Nonnull;
16-
import lombok.RequiredArgsConstructor;
1720
import lombok.extern.slf4j.Slf4j;
1821
import lombok.val;
22+
import org.springframework.ai.chat.messages.AssistantMessage.ToolCall;
1923
import org.springframework.ai.chat.messages.Message;
24+
import org.springframework.ai.chat.messages.ToolResponseMessage;
2025
import org.springframework.ai.chat.model.ChatModel;
2126
import org.springframework.ai.chat.model.ChatResponse;
2227
import org.springframework.ai.chat.prompt.Prompt;
28+
import org.springframework.ai.model.tool.DefaultToolCallingManager;
29+
import org.springframework.ai.model.tool.ToolCallingChatOptions;
2330
import reactor.core.publisher.Flux;
2431

2532
/**
@@ -29,28 +36,48 @@
2936
*/
3037
@Beta
3138
@Slf4j
32-
@RequiredArgsConstructor
3339
public class OrchestrationChatModel implements ChatModel {
34-
@Nonnull private OrchestrationClient client;
40+
@Nonnull private final OrchestrationClient client;
41+
42+
@Nonnull
43+
private final DefaultToolCallingManager toolCallingManager =
44+
DefaultToolCallingManager.builder().build();
3545

3646
/**
3747
* Default constructor.
3848
*
3949
* @since 1.2.0
4050
*/
4151
public OrchestrationChatModel() {
42-
this.client = new OrchestrationClient();
52+
this(new OrchestrationClient());
53+
}
54+
55+
/**
56+
* Constructor with a custom client.
57+
*
58+
* @since 1.2.0
59+
*/
60+
public OrchestrationChatModel(@Nonnull final OrchestrationClient client) {
61+
this.client = client;
4362
}
4463

4564
@Nonnull
4665
@Override
4766
public ChatResponse call(@Nonnull final Prompt prompt) {
48-
4967
if (prompt.getOptions() instanceof OrchestrationChatOptions options) {
5068

5169
val orchestrationPrompt = toOrchestrationPrompt(prompt);
52-
val response = client.chatCompletion(orchestrationPrompt, options.getConfig());
53-
return new OrchestrationSpringChatResponse(response);
70+
val response =
71+
new OrchestrationSpringChatResponse(
72+
client.chatCompletion(orchestrationPrompt, options.getConfig()));
73+
74+
if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions())
75+
&& response.hasToolCalls()) {
76+
val toolExecutionResult = toolCallingManager.executeToolCalls(prompt, response);
77+
// Send the tool execution result back to the model.
78+
return call(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()));
79+
}
80+
return response;
5481
}
5582
throw new IllegalArgumentException(
5683
"Please add OrchestrationChatOptions to the Prompt: new Prompt(\"message\", new OrchestrationChatOptions(config))");
@@ -92,18 +119,47 @@ private OrchestrationPrompt toOrchestrationPrompt(@Nonnull final Prompt prompt)
92119
@Nonnull
93120
private static com.sap.ai.sdk.orchestration.Message[] toOrchestrationMessages(
94121
@Nonnull final List<Message> messages) {
95-
final Function<Message, com.sap.ai.sdk.orchestration.Message> mapper =
122+
final Function<Message, List<com.sap.ai.sdk.orchestration.Message>> mapper =
96123
msg ->
97124
switch (msg.getMessageType()) {
98125
case SYSTEM:
99-
yield new SystemMessage(msg.getText());
126+
yield List.of(new SystemMessage(msg.getText()));
100127
case USER:
101-
yield new UserMessage(msg.getText());
128+
yield List.of(new UserMessage(msg.getText()));
102129
case ASSISTANT:
103-
yield new AssistantMessage(msg.getText());
130+
val springToolCalls =
131+
((org.springframework.ai.chat.messages.AssistantMessage) msg).getToolCalls();
132+
if (springToolCalls != null) {
133+
final List<ResponseMessageToolCall> sdkToolCalls =
134+
springToolCalls.stream()
135+
.map(OrchestrationChatModel::toOrchestrationToolCall)
136+
.toList();
137+
yield List.of(new AssistantMessage(sdkToolCalls));
138+
}
139+
yield List.of(new AssistantMessage(msg.getText()));
104140
case TOOL:
105-
throw new IllegalArgumentException("Tool messages are not supported");
141+
val toolResponses = ((ToolResponseMessage) msg).getResponses();
142+
yield toolResponses.stream()
143+
.map(
144+
r ->
145+
(com.sap.ai.sdk.orchestration.Message)
146+
new ToolMessage(r.id(), r.responseData()))
147+
.toList();
106148
};
107-
return messages.stream().map(mapper).toArray(com.sap.ai.sdk.orchestration.Message[]::new);
149+
return messages.stream()
150+
.map(mapper)
151+
.flatMap(List::stream)
152+
.toArray(com.sap.ai.sdk.orchestration.Message[]::new);
153+
}
154+
155+
@Nonnull
156+
private static ResponseMessageToolCall toOrchestrationToolCall(@Nonnull final ToolCall toolCall) {
157+
return ResponseMessageToolCall.create()
158+
.id(toolCall.id())
159+
.type(FUNCTION)
160+
.function(
161+
ResponseMessageToolCallFunction.create()
162+
.name(toolCall.name())
163+
.arguments(toolCall.arguments()));
108164
}
109165
}

0 commit comments

Comments
 (0)