Skip to content

Commit c52fafa

Browse files
committed
With purely explicit types
1 parent bcb42c8 commit c52fafa

File tree

7 files changed

+94
-49
lines changed

7 files changed

+94
-49
lines changed

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,8 @@ public OpenAiChatCompletionRequest withToolChoice(@Nonnull final OpenAiToolChoic
291291
* @since 1.7.0
292292
*/
293293
@Nonnull
294-
public OpenAiChatCompletionRequest withOpenAiTools(@Nonnull final List<OpenAiTool> tools) {
294+
public <I, O> OpenAiChatCompletionRequest withOpenAiTools(
295+
@Nonnull final List<OpenAiTool<I, O>> tools) {
295296
return this.withTools(tools.stream().map(OpenAiTool::createChatCompletionTool).toList());
296297
}
297298

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionResponse.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import static lombok.AccessLevel.NONE;
55
import static lombok.AccessLevel.PACKAGE;
66

7-
import com.fasterxml.jackson.core.type.TypeReference;
87
import com.google.common.annotations.Beta;
98
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CompletionUsage;
109
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionResponse;
@@ -99,15 +98,15 @@ public OpenAiAssistantMessage getMessage() {
9998
return new OpenAiAssistantMessage(new OpenAiMessageContent(contentItems), openAiToolCalls);
10099
}
101100

102-
public <T, R> List<OpenAiToolMessage> executeTools(List<OpenAiTool> tools) {
101+
public <T, R> List<OpenAiToolMessage> executeTools(List<OpenAiTool<T, R>> tools) {
103102
var toolMessages = new ArrayList<OpenAiToolMessage>();
104103

105104
for (var toolCall : getMessage().toolCalls()) {
106105
if (toolCall instanceof OpenAiFunctionCall functionCall) {
107106
for (OpenAiTool<T, R> tool : tools) {
108107
if (tool.getName().equals(functionCall.getName())) {
109-
T arguments = functionCall.getArgumentsAsObject(new TypeReference<T>() {});
110-
R response = tool.call(arguments);
108+
T arguments = functionCall.getArgumentsAsObject(tool);
109+
R response = tool.execute(arguments);
111110
toolMessages.add(OpenAiMessage.tool(response.toString(), functionCall.getId()));
112111
}
113112
}

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiFunctionCall.java

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import com.fasterxml.jackson.core.JsonProcessingException;
66
import com.fasterxml.jackson.core.type.TypeReference;
77
import com.google.common.annotations.Beta;
8+
import java.lang.reflect.Type;
89
import java.util.Map;
910
import javax.annotation.Nonnull;
1011
import lombok.AllArgsConstructor;
@@ -37,17 +38,40 @@ public class OpenAiFunctionCall implements OpenAiToolCall {
3738
*/
3839
@Nonnull
3940
public Map<String, Object> getArgumentsAsMap() throws IllegalArgumentException {
40-
return getArgumentsAsObject(new TypeReference<>() {});
41+
return parseArguments(new TypeReference<>() {});
42+
}
43+
44+
/**
45+
* Parses the arguments, encoded as a JSON string, into an object of type expected by a function
46+
* tool.
47+
*
48+
* @param tool the function tool the arguments are for
49+
* @param <T> the type of the class
50+
* @return the parsed arguments as an object
51+
* @throws IllegalArgumentException if parsing fails
52+
* @since 1.7.0
53+
*/
54+
@Nonnull
55+
public <T> T getArgumentsAsObject(@Nonnull final OpenAiTool<T, ?> tool)
56+
throws IllegalArgumentException {
57+
final var typeRef =
58+
new TypeReference<T>() {
59+
@Override
60+
public Type getType() {
61+
return tool.getRequestClass();
62+
}
63+
};
64+
return parseArguments(typeRef);
4165
}
4266

4367
@Nonnull
44-
<T> T getArgumentsAsObject(@Nonnull final TypeReference<T> typeReference)
68+
private <T> T parseArguments(@Nonnull final TypeReference<T> typeReference)
4569
throws IllegalArgumentException {
4670
try {
47-
return getOpenAiObjectMapper().readValue(arguments, typeReference);
71+
return getOpenAiObjectMapper().readValue(getArguments(), typeReference);
4872
} catch (JsonProcessingException e) {
4973
throw new IllegalArgumentException(
50-
"Failed to parse JSON string to class " + typeReference, e);
74+
"Failed to parse JSON string to class " + typeReference.getType(), e);
5175
}
5276
}
5377
}

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiTool.java

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616
import javax.annotation.Nullable;
1717
import lombok.AccessLevel;
1818
import lombok.AllArgsConstructor;
19+
import lombok.Data;
1920
import lombok.Getter;
20-
import lombok.Value;
21-
import lombok.With;
21+
import lombok.Setter;
22+
import lombok.experimental.Accessors;
2223

2324
/**
2425
* Represents an OpenAI function tool that can be used to define a function call in an OpenAI Chat
@@ -29,38 +30,67 @@
2930
* @since 1.7.0
3031
*/
3132
@Beta
32-
@Value
33-
@With
33+
@Data
3434
@Getter(AccessLevel.PACKAGE)
35+
@Accessors(chain = true)
3536
@AllArgsConstructor(access = AccessLevel.PRIVATE)
3637
public class OpenAiTool<T, R> {
3738

3839
/** The name of the function. */
3940
@Nonnull String name;
4041

4142
/** The model class for function request. */
42-
@Nonnull Function<T, R> function;
43+
@Nonnull Class<T> requestClass;
4344

4445
/** An optional description of the function. */
4546
@Nullable String description;
4647

4748
/** An optional flag indicating whether the function parameters should be treated strictly. */
4849
@Nullable Boolean strict;
4950

51+
/** The function to be called. */
52+
@Setter(AccessLevel.NONE)
53+
@Nullable
54+
Function<T, R> function;
55+
56+
/** The response class for the function. */
57+
@Setter(AccessLevel.NONE)
58+
@Nullable
59+
Class<R> responseClass;
60+
61+
public static <I, O> OpenAiTool<I, O> of(@Nonnull String name, @Nonnull Class<I> requestClass) {
62+
return new OpenAiTool<>(name, requestClass);
63+
}
64+
5065
/**
5166
* Constructs an {@code OpenAiFunctionTool} with the specified name and a model class that
5267
* captures the request to the function.
5368
*
5469
* @param name the name of the function
55-
* @param function the model class for function request
70+
* @param requestClass the model class for function request
71+
*/
72+
private OpenAiTool(@Nonnull final String name, @Nonnull final Class<T> requestClass) {
73+
this(name, requestClass, null, null, null, null);
74+
}
75+
76+
/**
77+
* Sets the function to be called and the response class for the function.
78+
*
79+
* @param function the function to be called
80+
* @param responseClass the response class for the function
81+
* @return this instance of {@code OpenAiFunctionTool}
5682
*/
57-
public OpenAiTool(@Nonnull final String name, @Nonnull final Function<T, R> function) {
58-
this(name, function, null, null);
83+
@Nonnull
84+
public OpenAiTool<T, R> setCallback(
85+
@Nonnull final Function<T, R> function, @Nonnull final Class<R> responseClass) {
86+
this.function = function;
87+
this.responseClass = responseClass;
88+
return this;
5989
}
6090

6191
@Nonnull
62-
R call(@Nonnull final T request) {
63-
return function.apply(request);
92+
R execute(@Nonnull final T argument) {
93+
return function.apply(argument);
6494
}
6595

6696
ChatCompletionTool createChatCompletionTool() {
@@ -74,7 +104,6 @@ ChatCompletionTool createChatCompletionTool() {
74104
throw new IllegalArgumentException("Could not generate schema for " + name, e);
75105
}
76106

77-
schema.setId(null);
78107
final var schemaMap =
79108
objectMapper.convertValue(schema, new TypeReference<Map<String, Object>>() {});
80109

foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequestTest.java

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,12 @@ record DummyResponse(String result) {}
130130
new OpenAiChatCompletionRequest(OpenAiMessage.user("Hello, world"))
131131
.withOpenAiTools(
132132
List.of(
133-
new OpenAiTool("toolA", conCat).withDescription("descA").withStrict(true),
134-
new OpenAiTool("toolB", conCat).withDescription("descB").withStrict(false)));
133+
OpenAiTool.of("toolA", DummyRequest.class)
134+
.setDescription("descA")
135+
.setStrict(true),
136+
OpenAiTool.of("toolB", DummyRequest.class)
137+
.setDescription("descB")
138+
.setStrict(false)));
135139

136140
var lowLevelRequest = request.createCreateChatCompletionRequest();
137141
assertThat(lowLevelRequest.getTools()).hasSize(2);
@@ -142,31 +146,18 @@ record DummyResponse(String result) {}
142146
assertThat(toolA.getFunction().getName()).isEqualTo("toolA");
143147
assertThat(toolA.getFunction().getDescription()).isEqualTo("descA");
144148
assertThat(toolA.getFunction().isStrict()).isTrue();
145-
/// {"id"="urn:jsonschema:com:sap:ai:sdk:foundationmodels:openai:OpenAiTool:1",
146-
// "properties"={"type"={"id"="urn:jsonschema:java:lang:reflect:Type",
147-
// "properties"={"typeName"={"type"="string"}}, "type"="object"}}, "type"="object"}
149+
148150
assertThat(toolA.getFunction().getParameters())
149151
.isEqualTo(
150152
Map.of(
153+
"id", "urn:jsonschema:com:sap:ai:sdk:foundationmodels:openai:OpenAiTool:1",
151154
"properties",
152-
Map.of(
153-
"type",
154155
Map.of(
155-
"properties",
156-
Map.of("typeName", Map.of("type", "string")),
157156
"type",
158-
"object")),
159-
"type",
160-
"object"));
161-
assertThat(toolA.getFunction().getParameters())
162-
.isEqualTo(Map.of("type", "string", "typeName", "java.lang.reflect.Type"));
163-
164-
var toolB = lowLevelRequest.getTools().get(1);
165-
assertThat(toolB).isInstanceOf(ChatCompletionTool.class);
166-
assertThat(toolB.getType()).isEqualTo(ChatCompletionTool.TypeEnum.FUNCTION);
167-
assertThat(toolB.getFunction().getName()).isEqualTo("toolB");
168-
assertThat(toolB.getFunction().getDescription()).isEqualTo("descB");
169-
assertThat(toolB.getFunction().isStrict()).isFalse();
170-
assertThat(toolB.getFunction().getParameters()).isEqualTo(Map.of("type", "string"));
157+
Map.of(
158+
"id", "urn:jsonschema:java:lang:reflect:Type",
159+
"properties", Map.of("typeName", Map.of("type", "string")),
160+
"type", "object")),
161+
"type", "object"));
171162
}
172163
}

foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiToolCallTest.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ record Response(String result) {}
2222
}
2323

2424
private static final OpenAiTool<Dummy.Request, Dummy.Request> TOOL =
25-
new OpenAiTool("functionName", Dummy.conCat);
25+
OpenAiTool.of("functionName", Dummy.Request.class);
2626

2727
@Test
2828
void getArgumentsAsMapParsesValidJson() {
@@ -39,15 +39,14 @@ void getArgumentsAsMapThrowsOnInvalidJson() {
3939

4040
@Test
4141
void getArgumentsAsObjectParsesValidJson() {
42-
var result = VALID_FUNCTION_CALL.getArgumentsAsObject(new TypeReference<Dummy.Request>() {});
42+
var result = VALID_FUNCTION_CALL.getArgumentsAsObject(TOOL);
4343
assertThat(result).isInstanceOf(Dummy.Request.class);
4444
assertThat(result.key()).isEqualTo("value");
4545
}
4646

4747
@Test
4848
void getArgumentsAsObjectThrowsOnInvalidJson() {
49-
assertThatThrownBy(
50-
() -> INVALID_FUNCTION_CALL.getArgumentsAsObject(new TypeReference<Dummy.Request>() {}))
49+
assertThatThrownBy(() -> INVALID_FUNCTION_CALL.getArgumentsAsObject(TOOL))
5150
.isInstanceOf(IllegalArgumentException.class)
5251
.hasMessageContaining("Failed to parse JSON string");
5352
}

sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OpenAiServiceV2.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,12 @@ public OpenAiChatCompletionResponse chatCompletionToolExecution(
9696
messages.add(OpenAiMessage.user("What's the weather in %s in %s?".formatted(location, unit)));
9797

9898
// 1. Define the function
99-
final List<OpenAiTool> tools =
99+
final var tools =
100100
List.of(
101-
new OpenAiTool<>("weather", WeatherMethod::getCurrentWeather)
102-
.withDescription("Get the weather for the given location"));
101+
OpenAiTool.<WeatherMethod.Request, WeatherMethod.Response>of(
102+
"weather", WeatherMethod.Request.class)
103+
.setDescription("Get the weather for the given location")
104+
.setCallback(WeatherMethod::getCurrentWeather, WeatherMethod.Response.class));
103105

104106
// 2. Assistant calls the function
105107
final var request = new OpenAiChatCompletionRequest(messages).withOpenAiTools(tools);

0 commit comments

Comments
 (0)