Skip to content

Commit 57ba2a1

Browse files
finito
1 parent bfce048 commit 57ba2a1

File tree

6 files changed

+56
-67
lines changed

6 files changed

+56
-67
lines changed

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequest.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,8 @@ public class OpenAiChatCompletionRequest {
123123
@Nullable CreateChatCompletionRequestAllOfResponseFormat responseFormat;
124124

125125
/** List of tools that the model may invoke during the completion. */
126-
@Nullable List<ChatCompletionTool> tools;
126+
@Getter(value = AccessLevel.PACKAGE)
127+
@Nullable List<OpenAiFunctionTool<?, ?>> tools;
127128

128129
/** Option to control which tool is invoked by the model. */
129130
@With(AccessLevel.PRIVATE)
@@ -297,7 +298,7 @@ public OpenAiChatCompletionRequest withOpenAiTools(@Nonnull final List<OpenAiToo
297298
.map(
298299
tool -> {
299300
if (tool instanceof OpenAiFunctionTool) {
300-
return ((OpenAiFunctionTool) tool).createChatCompletionTool();
301+
return ((OpenAiFunctionTool<?,?>) tool).createChatCompletionTool();
301302
} else {
302303
throw new IllegalArgumentException(
303304
"Unsupported tool type: " + tool.getClass().getName());
@@ -336,7 +337,7 @@ CreateChatCompletionRequest createCreateChatCompletionRequest() {
336337
request.seed(this.seed);
337338
request.streamOptions(this.streamOptions);
338339
request.responseFormat(this.responseFormat);
339-
request.tools(this.tools);
340+
request.tools(this.tools.createChatCompletionTool());
340341
request.toolChoice(this.toolChoice);
341342
request.functionCall(null);
342343
request.functions(null);

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionResponse.java

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,20 @@
44
import static lombok.AccessLevel.NONE;
55
import static lombok.AccessLevel.PACKAGE;
66

7+
import com.fasterxml.jackson.core.type.TypeReference;
78
import com.google.common.annotations.Beta;
9+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool;
810
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CompletionUsage;
911
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionResponse;
1012
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionResponseChoicesInner;
11-
12-
import java.util.ArrayList;
1313
import java.util.List;
1414
import java.util.Objects;
1515
import javax.annotation.Nonnull;
16+
import javax.annotation.Nullable;
1617
import lombok.RequiredArgsConstructor;
1718
import lombok.Setter;
1819
import lombok.Value;
20+
import lombok.val;
1921

2022
/**
2123
* Represents the output of an OpenAI chat completion. *
@@ -28,7 +30,9 @@
2830
@Setter(value = NONE)
2931
public class OpenAiChatCompletionResponse {
3032
/** The original response from the OpenAI API. */
31-
@Nonnull final CreateChatCompletionResponse originalResponse;
33+
@Nonnull CreateChatCompletionResponse originalResponse;
34+
35+
@Nullable List<ChatCompletionTool> functions;
3236

3337
/**
3438
* Gets the token usage from the original response.
@@ -99,20 +103,30 @@ public OpenAiAssistantMessage getMessage() {
99103
return new OpenAiAssistantMessage(new OpenAiMessageContent(contentItems), openAiToolCalls);
100104
}
101105

102-
public List<OpenAiToolMessage> executeTools() {
103-
var toolMessages = new ArrayList<OpenAiToolMessage>();
104-
for (var toolcall : getMessage().toolCalls()) {
105-
if (toolcall instanceof OpenAiFunctionCall functionCall) {
106-
if (functionCall.getName() == "weather") {
107-
final WeatherMethod.Request arguments =
108-
functionCall.getArgumentsAsObject(weatherFunction);
109-
final WeatherMethod.Response currentWeather = WeatherMethod.getCurrentWeather(arguments);
110-
toolMessages.add(currentWeather.toString(), functionCall.getId());
111-
}
112-
} else {
113-
throw new IllegalArgumentException(
114-
"Expected a function call, but got: %s".formatted(assistantMessage));
106+
public <T, R> List<OpenAiToolMessage> executeTools( List<OpenAiTool> tools ) {
107+
return getMessage().toolCalls().stream()
108+
.filter(toolCall -> toolCall instanceof OpenAiFunctionCall)
109+
.map(toolCall -> (OpenAiFunctionCall) toolCall)
110+
.map(
111+
functionCall -> {
112+
OpenAiFunctionTool<T, R> request = findFunction(tools, functionCall.getName());
113+
T arguments = functionCall.parseArguments(new TypeReference<T>() {});
114+
R response = request.call(arguments);
115+
return OpenAiMessage.tool(response, functionCall.getId());
116+
})
117+
.toList();
118+
}
119+
120+
@Nullable
121+
private <T, R> OpenAiFunctionTool<T, R> findFunction(List<OpenAiTool> tools, String name) {
122+
if (functions == null) {
123+
return null;
124+
}
125+
for (OpenAiTool tool : tools) {
126+
if (tool instanceof OpenAiFunctionTool<?, ?> function && function.getName().equals(name)) {
127+
return (OpenAiFunctionTool<T, R>) function;
115128
}
116129
}
130+
return null;
117131
}
118132
}

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ public OpenAiChatCompletionResponse chatCompletion(
159159
@Nonnull final OpenAiChatCompletionRequest request) throws OpenAiClientException {
160160
warnIfUnsupportedUsage();
161161
return new OpenAiChatCompletionResponse(
162-
chatCompletion(request.createCreateChatCompletionRequest()));
162+
chatCompletion(request.createCreateChatCompletionRequest()), request.getTools());
163163
}
164164

165165
/**

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiFunctionCall.java

Lines changed: 6 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
import com.fasterxml.jackson.core.JsonProcessingException;
66
import com.fasterxml.jackson.core.type.TypeReference;
77
import com.google.common.annotations.Beta;
8-
import java.lang.reflect.Type;
9-
import java.util.Map;
108
import javax.annotation.Nonnull;
119
import lombok.AllArgsConstructor;
1210
import lombok.Value;
@@ -29,49 +27,25 @@ public class OpenAiFunctionCall implements OpenAiToolCall {
2927
/** The arguments for the function call, encoded as a JSON string. */
3028
@Nonnull String arguments;
3129

32-
/**
33-
* Parses the arguments, encoded as a JSON string, into a {@code Map<String, Object>}.
34-
*
35-
* @return a map of the arguments
36-
* @throws IllegalArgumentException if parsing fails
37-
* @since 1.7.0
38-
*/
39-
@Nonnull
40-
public Map<String, Object> getArgumentsAsMap() throws IllegalArgumentException {
41-
return parseArguments(new TypeReference<>() {});
42-
}
43-
4430
/**
4531
* Parses the arguments, encoded as a JSON string, into an object of type expected by a function
4632
* tool.
4733
*
48-
* @param tool the function tool the arguments are for
49-
* @param <T> the type of the class
34+
* @param request the type of the class
5035
* @return the parsed arguments as an object
5136
* @throws IllegalArgumentException if parsing fails
5237
* @since 1.7.0
5338
*/
5439
@Nonnull
55-
public <T> T getArgumentsAsObject(@Nonnull final OpenAiFunctionTool tool)
56-
throws IllegalArgumentException {
57-
final var typeRef =
58-
new TypeReference<T>() {
59-
@Override
60-
public Type getType() {
61-
return tool.getFunction();
62-
}
63-
};
64-
return parseArguments(typeRef);
65-
}
66-
67-
@Nonnull
68-
private <T> T parseArguments(@Nonnull final TypeReference<T> typeReference)
40+
<T> T parseArguments(@Nonnull final TypeReference<T> request)
6941
throws IllegalArgumentException {
7042
try {
71-
return getOpenAiObjectMapper().readValue(getArguments(), typeReference);
43+
return getOpenAiObjectMapper().readValue(arguments, request);
7244
} catch (JsonProcessingException e) {
7345
throw new IllegalArgumentException(
74-
"Failed to parse JSON string to class " + typeReference.getType(), e);
46+
"Failed to parse JSON string to class " + request, e);
7547
}
7648
}
49+
50+
7751
}

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiFunctionTool.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@ public class OpenAiFunctionTool<T, R> implements OpenAiTool {
4141
/** The model class for function request. */
4242
@Nonnull Function<T, R> function;
4343

44-
/** The model class for function response. */
45-
4644
/** An optional description of the function. */
4745
@Nullable String description;
4846

@@ -60,14 +58,19 @@ public OpenAiFunctionTool(@Nonnull final String name, @Nonnull final Function<T,
6058
this(name, function, null, null);
6159
}
6260

61+
@Nonnull
62+
R call(@Nonnull final T request) {
63+
return function.apply(request);
64+
}
65+
6366
ChatCompletionTool createChatCompletionTool() {
6467
final var objectMapper = new ObjectMapper();
6568
JsonSchema schema = null;
6669
try {
67-
schema = new JsonSchemaGenerator(objectMapper).generateSchema(Class<T>.class);
70+
schema = new JsonSchemaGenerator(objectMapper).generateSchema(new TypeReference<T>() {}.getClass());
6871
} catch (JsonMappingException e) {
6972
throw new IllegalArgumentException(
70-
"Could not generate schema for " + function.getTypeName(), e);
73+
"Could not generate schema for " + name, e);
7174
}
7275

7376
schema.setId(null);

sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OpenAiServiceV2.java

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import com.sap.ai.sdk.foundationmodels.openai.OpenAiFunctionTool;
1616
import com.sap.ai.sdk.foundationmodels.openai.OpenAiImageItem;
1717
import com.sap.ai.sdk.foundationmodels.openai.OpenAiMessage;
18+
import com.sap.ai.sdk.foundationmodels.openai.OpenAiTool;
1819
import java.util.ArrayList;
1920
import java.util.List;
2021
import java.util.stream.Stream;
@@ -96,24 +97,20 @@ public OpenAiChatCompletionResponse chatCompletionToolExecution(
9697
messages.add(OpenAiMessage.user("What's the weather in %s in %s?".formatted(location, unit)));
9798

9899
// 1. Define the function
99-
final var weatherFunction =
100-
new OpenAiFunctionTool(
101-
"weather", WeatherMethod.Request.class, WeatherMethod.Response.class)
102-
.withDescription("Get the weather for the given location");
100+
final List<OpenAiTool> tools =
101+
List.of(
102+
new OpenAiFunctionTool<>("weather", WeatherMethod::getCurrentWeather)
103+
.withDescription("Get the weather for the given location"));
103104

104105
// 2. Assistant calls the function
105-
final var request =
106-
new OpenAiChatCompletionRequest(messages).withOpenAiTools(List.of(weatherFunction));
106+
final var request = new OpenAiChatCompletionRequest(messages).withOpenAiTools(tools);
107107
final OpenAiChatCompletionResponse response =
108108
OpenAiClient.forModel(GPT_4O_MINI).chatCompletion(request);
109109
final OpenAiAssistantMessage assistantMessage = response.getMessage();
110-
111-
// 3. Execute the function
112-
113-
110+
114111
// 4. Send back the results, and the model will incorporate them into its final response.
115112
messages.add(assistantMessage);
116-
messages.add(OpenAiMessage.tool(currentWeather.toString(), functionCall.getId()));
113+
messages.addAll(response.executeTools(tools));
117114
return OpenAiClient.forModel(GPT_4O_MINI).chatCompletion(request.withMessages(messages));
118115
}
119116

0 commit comments

Comments
 (0)