Skip to content

Commit 3038eb3

Browse files
committed
Introduce OpenAiToolExecutor
1 parent 66cef01 commit 3038eb3

File tree

8 files changed

+104
-81
lines changed

8 files changed

+104
-81
lines changed

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequest.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,8 +291,7 @@ public OpenAiChatCompletionRequest withToolChoice(@Nonnull final OpenAiToolChoic
291291
* @since 1.7.0
292292
*/
293293
@Nonnull
294-
public <I, O> OpenAiChatCompletionRequest withOpenAiTools(
295-
@Nonnull final List<OpenAiTool<I, O>> tools) {
294+
public OpenAiChatCompletionRequest withOpenAiTools(@Nonnull final List<OpenAiTool<?>> tools) {
296295
return this.withTools(tools.stream().map(OpenAiTool::createChatCompletionTool).toList());
297296
}
298297

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionResponse.java

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -99,27 +99,5 @@ public OpenAiAssistantMessage getMessage() {
9999
return new OpenAiAssistantMessage(new OpenAiMessageContent(contentItems), openAiToolCalls);
100100
}
101101

102-
public <T, R> List<OpenAiToolMessage> executeTools(List<OpenAiTool<T, R>> tools) {
103-
var toolMessages = new ArrayList<OpenAiToolMessage>();
104-
105-
for (var toolCall : getMessage().toolCalls()) {
106-
if (toolCall instanceof OpenAiFunctionCall functionCall) {
107-
for (OpenAiTool<T, R> tool : tools) {
108-
if (tool.getName().equals(functionCall.getName())) {
109-
T arguments = functionCall.getArgumentsAsObject(tool);
110-
R response = tool.execute(arguments);
111-
112-
String serializedResponse;
113-
try {
114-
serializedResponse = OpenAiUtils.getOpenAiObjectMapper().writeValueAsString(response);
115-
} catch (JsonProcessingException e) {
116-
throw new IllegalArgumentException(e);
117-
}
118-
toolMessages.add(OpenAiMessage.tool(serializedResponse, functionCall.getId()));
119-
}
120-
}
121-
}
122-
}
123-
return toolMessages;
124-
}
102+
125103
}

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiFunctionCall.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,16 +52,16 @@ public Map<String, Object> getArgumentsAsMap() throws IllegalArgumentException {
5252
* @since 1.7.0
5353
*/
5454
@Nonnull
55-
public <T> T getArgumentsAsObject(@Nonnull final OpenAiTool<T, ?> tool)
55+
public <T> T getArgumentsAsObject(@Nonnull final OpenAiTool<T> tool)
5656
throws IllegalArgumentException {
57-
final var typeRef =
58-
new TypeReference<T>() {
57+
58+
return parseArguments(
59+
new TypeReference<>() {
5960
@Override
6061
public Type getType() {
6162
return tool.getRequestClass();
6263
}
63-
};
64-
return parseArguments(typeRef);
64+
});
6565
}
6666

6767
@Nonnull

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiTool.java

Lines changed: 15 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -34,33 +34,22 @@
3434
@Getter(AccessLevel.PACKAGE)
3535
@Accessors(chain = true)
3636
@AllArgsConstructor(access = AccessLevel.PRIVATE)
37-
public class OpenAiTool<T, R> {
37+
public class OpenAiTool<I> {
3838

3939
/** The name of the function. */
4040
@Nonnull String name;
4141

4242
/** The model class for function request. */
43-
@Nonnull Class<T> requestClass;
43+
@Nonnull Class<I> requestClass;
4444

4545
/** An optional description of the function. */
46-
@Nullable String description;
46+
@Setter @Nullable String description;
4747

4848
/** An optional flag indicating whether the function parameters should be treated strictly. */
49-
@Nullable Boolean strict;
49+
@Setter @Nullable Boolean strict;
5050

5151
/** The function to be called. */
52-
@Setter(AccessLevel.NONE)
53-
@Nullable
54-
Function<T, R> function;
55-
56-
/** The response class for the function. */
57-
@Setter(AccessLevel.NONE)
58-
@Nullable
59-
Class<R> responseClass;
60-
61-
public static <I, O> OpenAiTool<I, O> of(@Nonnull String name, @Nonnull Class<I> requestClass) {
62-
return new OpenAiTool<>(name, requestClass);
63-
}
52+
@Setter @Nullable Function<I, ?> function;
6453

6554
/**
6655
* Constructs an {@code OpenAiFunctionTool} with the specified name and a model class that
@@ -69,28 +58,21 @@ public static <I, O> OpenAiTool<I, O> of(@Nonnull String name, @Nonnull Class<I>
6958
* @param name the name of the function
7059
* @param requestClass the model class for function request
7160
*/
72-
private OpenAiTool(@Nonnull final String name, @Nonnull final Class<T> requestClass) {
73-
this(name, requestClass, null, null, null, null);
61+
public OpenAiTool(@Nonnull final String name, @Nonnull final Class<I> requestClass) {
62+
this(name, requestClass, null, null, null);
7463
}
7564

76-
/**
77-
* Sets the function to be called and the response class for the function.
78-
*
79-
* @param function the function to be called
80-
* @param responseClass the response class for the function
81-
* @return this instance of {@code OpenAiFunctionTool}
82-
*/
8365
@Nonnull
84-
public OpenAiTool<T, R> setCallback(
85-
@Nonnull final Function<T, R> function, @Nonnull final Class<R> responseClass) {
86-
this.function = function;
87-
this.responseClass = responseClass;
88-
return this;
66+
Object execute(@Nonnull final I argument) {
67+
if (getFunction() == null) {
68+
throw new IllegalStateException("Callback function is not set");
69+
}
70+
return getFunction().apply(argument);
8971
}
9072

91-
@Nonnull
92-
R execute(@Nonnull final T argument) {
93-
return function.apply(argument);
73+
public OpenAiTool<I> setCallback(Function<I, ?> function) {
74+
this.function = function;
75+
return this;
9476
}
9577

9678
ChatCompletionTool createChatCompletionTool() {
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package com.sap.ai.sdk.foundationmodels.openai;
2+
3+
import static lombok.AccessLevel.PRIVATE;
4+
5+
import com.fasterxml.jackson.core.JsonProcessingException;
6+
import com.fasterxml.jackson.databind.ObjectMapper;
7+
import com.google.common.annotations.Beta;
8+
import java.util.List;
9+
import java.util.stream.Collectors;
10+
import javax.annotation.Nonnull;
11+
import lombok.AllArgsConstructor;
12+
13+
/**
14+
* A class for OpenAI tool execution.
15+
*
16+
* @since 1.7.0
17+
*/
18+
@Beta
19+
@AllArgsConstructor(access = PRIVATE)
20+
public class OpenAiToolExecutor {
21+
22+
private static final ObjectMapper JACKSON = new ObjectMapper();
23+
24+
/**
25+
* Executes the given tool calls with the provided tools and returns the results as a list of
26+
* {@link OpenAiToolMessage}.
27+
*
28+
* @param tools the list of tools to execute
29+
* @param toolCalls the list of tool calls with arguments
30+
* @return the list of tool messages with the results
31+
*/
32+
@Nonnull
33+
public static List<OpenAiToolMessage> executeTools(
34+
List<OpenAiTool<?>> tools, List<OpenAiToolCall> toolCalls) {
35+
36+
final var toolMap = tools.stream().collect(Collectors.toMap(OpenAiTool::getName, tool -> tool));
37+
38+
return toolCalls.stream()
39+
.filter(OpenAiFunctionCall.class::isInstance)
40+
.map(OpenAiFunctionCall.class::cast)
41+
.filter(functionCall -> toolMap.containsKey(functionCall.getName()))
42+
.map(
43+
functionCall -> {
44+
var tool = toolMap.get(functionCall.getName());
45+
var result = executeFunction(tool, functionCall);
46+
return OpenAiMessage.tool(serializeObject(result), functionCall.getId());
47+
})
48+
.toList();
49+
}
50+
51+
private static <I> Object executeFunction(OpenAiTool<I> tool, OpenAiFunctionCall toolCall) {
52+
I arguments = toolCall.getArgumentsAsObject(tool);
53+
return tool.execute(arguments);
54+
}
55+
56+
@Nonnull
57+
private static String serializeObject(@Nonnull final Object obj) {
58+
try {
59+
return JACKSON.writeValueAsString(obj);
60+
} catch (JsonProcessingException e) {
61+
throw new IllegalArgumentException(e);
62+
}
63+
}
64+
}

foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequestTest.java

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import java.util.ArrayList;
1212
import java.util.List;
1313
import java.util.Map;
14-
import java.util.function.Function;
1514
import org.junit.jupiter.api.Test;
1615

1716
class OpenAiChatCompletionRequestTest {
@@ -121,19 +120,15 @@ void messageListExternallyUnmodifiable() {
121120
@Test
122121
void withOpenAiTools() {
123122
record DummyRequest(String param1, int param2) {}
124-
record DummyResponse(String result) {}
125-
126-
Function<DummyRequest, DummyResponse> conCat =
127-
(request) -> new DummyResponse(request.param1 + request.param2);
128123

129124
var request =
130125
new OpenAiChatCompletionRequest(OpenAiMessage.user("Hello, world"))
131126
.withOpenAiTools(
132127
List.of(
133-
OpenAiTool.of("toolA", DummyRequest.class)
128+
new OpenAiTool<>("toolA", DummyRequest.class)
134129
.setDescription("descA")
135130
.setStrict(true),
136-
OpenAiTool.of("toolB", DummyRequest.class)
131+
new OpenAiTool<>("toolB", DummyRequest.class)
137132
.setDescription("descB")
138133
.setStrict(false)));
139134

foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiToolCallTest.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import static org.assertj.core.api.Assertions.assertThat;
44
import static org.assertj.core.api.Assertions.assertThatThrownBy;
55

6-
import com.fasterxml.jackson.core.type.TypeReference;
76
import java.util.function.Function;
87
import org.junit.jupiter.api.Test;
98

@@ -21,8 +20,8 @@ record Response(String result) {}
2120
static Function<Request, Response> conCat = request -> new Response(request.key());
2221
}
2322

24-
private static final OpenAiTool<Dummy.Request, Dummy.Request> TOOL =
25-
OpenAiTool.of("functionName", Dummy.Request.class);
23+
private static final OpenAiTool<Dummy.Request> TOOL =
24+
new OpenAiTool<>("functionName", Dummy.Request.class);
2625

2726
@Test
2827
void getArgumentsAsMapParsesValidJson() {
@@ -39,7 +38,7 @@ void getArgumentsAsMapThrowsOnInvalidJson() {
3938

4039
@Test
4140
void getArgumentsAsObjectParsesValidJson() {
42-
var result = VALID_FUNCTION_CALL.getArgumentsAsObject(TOOL);
41+
Dummy.Request result = VALID_FUNCTION_CALL.getArgumentsAsObject(TOOL);
4342
assertThat(result).isInstanceOf(Dummy.Request.class);
4443
assertThat(result.key()).isEqualTo("value");
4544
}

sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OpenAiServiceV2.java

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import com.sap.ai.sdk.foundationmodels.openai.OpenAiImageItem;
1616
import com.sap.ai.sdk.foundationmodels.openai.OpenAiMessage;
1717
import com.sap.ai.sdk.foundationmodels.openai.OpenAiTool;
18+
import com.sap.ai.sdk.foundationmodels.openai.OpenAiToolExecutor;
19+
import com.sap.ai.sdk.foundationmodels.openai.OpenAiToolMessage;
1820
import java.util.ArrayList;
1921
import java.util.List;
2022
import java.util.stream.Stream;
@@ -83,7 +85,8 @@ public OpenAiChatCompletionResponse chatCompletionImage(@Nonnull final String li
8385
}
8486

8587
/**
86-
* Executes a chat completion request to OpenAI with a tool that calculates the weather.
88+
* Chat request to OpenAI with tool that gets the weather for a given location and unit. The tool
89+
* executed and the result is sent back to the assistant.
8790
*
8891
* @param location The location to get the weather for.
8992
* @param unit The unit of temperature to use.
@@ -96,22 +99,25 @@ public OpenAiChatCompletionResponse chatCompletionToolExecution(
9699
messages.add(OpenAiMessage.user("What's the weather in %s in %s?".formatted(location, unit)));
97100

98101
// 1. Define the function
99-
final var tools =
102+
final List<OpenAiTool<?>> tools =
100103
List.of(
101-
OpenAiTool.<WeatherMethod.Request, WeatherMethod.Response>of(
102-
"weather", WeatherMethod.Request.class)
104+
new OpenAiTool<>("weather", WeatherMethod.Request.class)
103105
.setDescription("Get the weather for the given location")
104-
.setCallback(WeatherMethod::getCurrentWeather, WeatherMethod.Response.class));
106+
.setCallback(WeatherMethod::getCurrentWeather));
105107

106108
// 2. Assistant calls the function
107109
final var request = new OpenAiChatCompletionRequest(messages).withOpenAiTools(tools);
108110
final OpenAiChatCompletionResponse response =
109111
OpenAiClient.forModel(GPT_4O_MINI).chatCompletion(request);
110112
final OpenAiAssistantMessage assistantMessage = response.getMessage();
111113

112-
// 4. Send back the results, and the model will incorporate them into its final response.
114+
// 3. Execute the tool call for given tools
115+
List<OpenAiToolMessage> toolMessages =
116+
OpenAiToolExecutor.executeTools(tools, assistantMessage.toolCalls());
117+
118+
// 4. Send back the results for model will incorporate them into its final response.
113119
messages.add(assistantMessage);
114-
messages.addAll(response.executeTools(tools));
120+
messages.addAll(toolMessages);
115121

116122
return OpenAiClient.forModel(GPT_4O_MINI).chatCompletion(request.withMessages(messages));
117123
}

0 commit comments

Comments
 (0)