Skip to content

Commit 9f96e1a

Browse files
committed
Remove generic type argument from OpenAiTool class; Migrate error-prone constructor to builder pattern that enforces required values
1 parent 7baf530 commit 9f96e1a

File tree

5 files changed

+105
-65
lines changed

5 files changed

+105
-65
lines changed

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ public OpenAiChatCompletionRequest withToolChoice(@Nonnull final OpenAiToolChoic
292292
*/
293293
@Nonnull
294294
@Beta
295-
public OpenAiChatCompletionRequest withToolsExecutable(@Nonnull final List<OpenAiTool<?>> tools) {
295+
public OpenAiChatCompletionRequest withToolsExecutable(@Nonnull final List<OpenAiTool> tools) {
296296
return this.withTools(tools.stream().map(OpenAiTool::createChatCompletionTool).toList());
297297
}
298298

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiTool.java

Lines changed: 74 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import com.fasterxml.jackson.core.JsonProcessingException;
77
import com.fasterxml.jackson.core.type.TypeReference;
88
import com.fasterxml.jackson.databind.ObjectMapper;
9+
import com.fasterxml.jackson.databind.node.ObjectNode;
910
import com.github.victools.jsonschema.generator.Option;
1011
import com.github.victools.jsonschema.generator.OptionPreset;
1112
import com.github.victools.jsonschema.generator.SchemaGenerator;
@@ -26,74 +27,114 @@
2627
import javax.annotation.Nullable;
2728
import lombok.AccessLevel;
2829
import lombok.AllArgsConstructor;
29-
import lombok.Data;
3030
import lombok.Getter;
3131
import lombok.RequiredArgsConstructor;
32-
import lombok.experimental.Accessors;
32+
import lombok.Setter;
33+
import lombok.Value;
34+
import lombok.With;
3335
import lombok.extern.slf4j.Slf4j;
3436

3537
/**
3638
* Represents an OpenAI tool that can be used to define a function call in an OpenAI Chat Completion
3739
* request. This tool generates a JSON schema based on the provided class representing the
3840
* function's request structure.
3941
*
40-
* @param <InputT> the type of the input argument for the function
4142
* @see <a href="https://platform.openai.com/docs/guides/gpt/function-calling"/>OpenAI Function
4243
* @since 1.7.0
4344
*/
4445
@Slf4j
4546
@Beta
46-
@Data
47+
@Value
48+
@With
4749
@Getter(AccessLevel.PACKAGE)
48-
@Accessors(chain = true)
4950
@AllArgsConstructor(access = AccessLevel.PRIVATE)
50-
public class OpenAiTool<InputT> {
51+
public class OpenAiTool {
5152

5253
private static final ObjectMapper JACKSON = new ObjectMapper();
5354

5455
/** The schema generator used to create JSON schemas. */
5556
@Nonnull private static final SchemaGenerator GENERATOR = createSchemaGenerator();
5657

5758
/** The name of the function. */
58-
@Nonnull private String name;
59+
@Setter(AccessLevel.NONE)
60+
@Nonnull
61+
String name;
62+
63+
/** The function to execute a string argument to tool result object. */
64+
@Setter(AccessLevel.NONE)
65+
@Nonnull
66+
Function<String, Object> functionExecutor;
5967

60-
/** The model class for function request. */
61-
@Nonnull private Class<InputT> requestClass;
68+
/** schema to be used for the function call. */
69+
@Setter(AccessLevel.NONE)
70+
@Nonnull
71+
ObjectNode schema;
6272

6373
/** An optional description of the function. */
64-
@Nullable private String description;
74+
@Nullable String description;
6575

6676
/** An optional flag indicating whether the function parameters should be treated strictly. */
67-
@Nullable private Boolean strict;
77+
@Nullable Boolean strict;
6878

69-
/** The function to be called. */
70-
@Nullable private Function<InputT, ?> function;
79+
/**
80+
* Instantiates a OpenAiTool builder instance on behalf of an executable function.
81+
*
82+
* @param function the function to be executed.
83+
* @return an OpenAiTool builder instance.
84+
* @param <InputT> the type of the function input-argument class.
85+
*/
86+
@Nonnull
87+
public static <InputT> Builder1<InputT> forFunction(@Nonnull final Function<InputT, ?> function) {
88+
return inputClass ->
89+
name -> {
90+
final Function<String, Object> exec =
91+
s -> function.apply(deserializeArgument(inputClass, s));
92+
final var schema = GENERATOR.generateSchema(inputClass);
93+
return new OpenAiTool(name, exec, schema, null, null);
94+
};
95+
}
7196

7297
/**
73-
* Constructs an {@code OpenAiFunctionTool} with the specified name and a model class that
74-
* captures the request to the function.
98+
* Creates a new OpenAiTool instance with the specified function and input class.
7599
*
76-
* @param name the name of the function
77-
* @param requestClass the model class for function request
100+
* @param <InputT> the type of the input class.
78101
*/
79-
public OpenAiTool(@Nonnull final String name, @Nonnull final Class<InputT> requestClass) {
80-
this(name, requestClass, null, null, null);
102+
public interface Builder1<InputT> {
103+
/**
104+
* Sets the name of the function.
105+
*
106+
* @param inputClass the class of the input object.
107+
* @return a new OpenAiTool instance with the specified function and input class.
108+
*/
109+
@Nonnull
110+
Builder2 withArgument(@Nonnull final Class<InputT> inputClass);
81111
}
82112

83-
@Nonnull
84-
Object execute(@Nonnull final InputT argument) {
85-
if (getFunction() == null) {
86-
throw new IllegalStateException(
87-
"Tool " + name + " is missing a method reference to execute.");
113+
/** Creates a new OpenAiTool instance with the specified name. */
114+
public interface Builder2 {
115+
/**
116+
* Sets the name of the function.
117+
*
118+
* @param name the name of the function
119+
* @return a new OpenAiTool instance with the specified name
120+
*/
121+
@Nonnull
122+
OpenAiTool withName(@Nonnull final String name);
123+
}
124+
125+
@Nullable
126+
private static <T> T deserializeArgument(@Nonnull final Class<T> cl, @Nonnull final String s) {
127+
try {
128+
return JACKSON.readValue(s, cl);
129+
} catch (JsonProcessingException e) {
130+
throw new IllegalArgumentException("Failed to parse JSON string to class " + cl, e);
88131
}
89-
return getFunction().apply(argument);
90132
}
91133

92134
ChatCompletionTool createChatCompletionTool() {
93-
final var schema = GENERATOR.generateSchema(getRequestClass());
94135
final var schemaMap =
95136
OpenAiUtils.getOpenAiObjectMapper()
96-
.convertValue(schema, new TypeReference<Map<String, Object>>() {});
137+
.convertValue(getSchema(), new TypeReference<Map<String, Object>>() {});
97138

98139
return new ChatCompletionTool()
99140
.type(FUNCTION)
@@ -128,7 +169,7 @@ private static SchemaGenerator createSchemaGenerator() {
128169
@Beta
129170
@Nonnull
130171
public static Execution execute(
131-
@Nonnull final List<OpenAiTool<?>> tools, @Nonnull final OpenAiAssistantMessage msg)
172+
@Nonnull final List<OpenAiTool> tools, @Nonnull final OpenAiAssistantMessage msg)
132173
throws IllegalArgumentException {
133174
final var result = new LinkedHashMap<OpenAiFunctionCall, Object>();
134175

@@ -148,10 +189,11 @@ public static Execution execute(
148189
}
149190

150191
@Nonnull
151-
private static <I> Object executeFunction(
152-
@Nonnull final OpenAiTool<I> tool, @Nonnull final OpenAiFunctionCall toolCall) {
153-
final I arguments = toolCall.getArgumentsAsObject(tool.getRequestClass());
154-
return tool.execute(arguments);
192+
private static Object executeFunction(
193+
@Nonnull final OpenAiTool tool, @Nonnull final OpenAiFunctionCall toolCall) {
194+
final Function<String, Object> executor = tool.getFunctionExecutor();
195+
final String arguments = toolCall.getArguments();
196+
return executor.apply(arguments);
155197
}
156198

157199
@Nonnull

foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequestTest.java

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,12 +125,16 @@ record DummyRequest(String param1, int param2) {}
125125
new OpenAiChatCompletionRequest(OpenAiMessage.user("Hello, world"))
126126
.withToolsExecutable(
127127
List.of(
128-
new OpenAiTool<>("toolA", DummyRequest.class)
129-
.setDescription("descA")
130-
.setStrict(true),
131-
new OpenAiTool<>("toolB", DummyRequest.class)
132-
.setDescription("descB")
133-
.setStrict(false)));
128+
OpenAiTool.<DummyRequest>forFunction(r -> "result")
129+
.withArgument(DummyRequest.class)
130+
.withName("toolA")
131+
.withDescription("descA")
132+
.withStrict(true),
133+
OpenAiTool.<DummyRequest>forFunction(r -> "result")
134+
.withArgument(DummyRequest.class)
135+
.withName("toolB")
136+
.withDescription("descB")
137+
.withStrict(true)));
134138

135139
var lowLevelRequest = request.createCreateChatCompletionRequest();
136140
assertThat(lowLevelRequest.getTools()).hasSize(2);

foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiToolTest.java

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import java.util.Map;
99
import java.util.function.Function;
1010
import lombok.EqualsAndHashCode;
11-
import org.junit.jupiter.api.BeforeEach;
1211
import org.junit.jupiter.api.Test;
1312

1413
class OpenAiToolTest {
@@ -31,13 +30,6 @@ record Response(String toolMsg) {}
3130
request -> new Dummy.Response(request.key());
3231
}
3332

34-
private OpenAiTool<Dummy.Request> toolA;
35-
36-
@BeforeEach
37-
void setUp() {
38-
toolA = new OpenAiTool<>("functionA", Dummy.Request.class);
39-
}
40-
4133
@Test
4234
void getArgumentsAsMapValid() {
4335
final var result = FUNCTION_CALL_A.getArgumentsAsMap();
@@ -53,21 +45,24 @@ void getArgumentsAsMapInvalid() {
5345

5446
@Test
5547
void getArgumentsAsObjectValid() {
56-
final Dummy.Request result = FUNCTION_CALL_A.getArgumentsAsObject(toolA.getRequestClass());
48+
final Dummy.Request result = FUNCTION_CALL_A.getArgumentsAsObject(Dummy.Request.class);
5749
assertThat(result).isInstanceOf(Dummy.Request.class);
5850
assertThat(result.key()).isEqualTo("value");
5951
}
6052

6153
@Test
6254
void getArgumentsAsObjectInvalid() {
63-
assertThatThrownBy(() -> INVALID_FUNCTION_CALL_A.getArgumentsAsObject(toolA.getRequestClass()))
55+
assertThatThrownBy(() -> INVALID_FUNCTION_CALL_A.getArgumentsAsObject(Integer.class))
6456
.isInstanceOf(IllegalArgumentException.class)
6557
.hasMessageContaining("Failed to parse JSON string");
6658
}
6759

6860
@Test
6961
void executeToolsValid() {
70-
toolA.setFunction(Dummy.conCat);
62+
final var toolA =
63+
OpenAiTool.forFunction(Dummy.conCat)
64+
.withArgument(Dummy.Request.class)
65+
.withName("functionA");
7166
final var assistMsg = new OpenAiAssistantMessage(EMPTY_MSG_CONTENT, List.of(FUNCTION_CALL_A));
7267
final var execution = OpenAiTool.execute(List.of(toolA), assistMsg);
7368

@@ -83,19 +78,14 @@ void executeToolsValid() {
8378
.isEqualTo("{\"toolMsg\":\"value\"}");
8479
}
8580

86-
@Test
87-
void executeToolsThrowsOnNoFunction() {
88-
final var assistMsg = new OpenAiAssistantMessage(EMPTY_MSG_CONTENT, List.of(FUNCTION_CALL_A));
89-
assertThatThrownBy(() -> OpenAiTool.execute(List.of(toolA), assistMsg))
90-
.isInstanceOf(IllegalStateException.class)
91-
.hasMessageContaining("Tool functionA is missing a method reference to execute.");
92-
}
93-
9481
@Test
9582
void executeToolsNoMatchingCall() {
96-
final var toolAWithFunction = toolA.setFunction(Dummy.conCat);
83+
final var toolA =
84+
OpenAiTool.forFunction(Dummy.conCat)
85+
.withArgument(Dummy.Request.class)
86+
.withName("functionA");
9787
final var assistMsg = new OpenAiAssistantMessage(EMPTY_MSG_CONTENT, List.of(FUNCTION_CALL_B));
98-
final var executions = OpenAiTool.execute(List.of(toolAWithFunction), assistMsg);
88+
final var executions = OpenAiTool.execute(List.of(toolA), assistMsg);
9989
assertThat(executions.getResults()).isEmpty();
10090
assertThat(executions.getMessages()).isEmpty();
10191
}
@@ -111,7 +101,10 @@ class NonSerializableResponse {
111101
}
112102
}
113103

114-
toolA.setFunction(request -> new NonSerializableResponse(request.key()));
104+
final Function<Dummy.Request, Object> badF =
105+
request -> new NonSerializableResponse(request.key());
106+
final var toolA =
107+
OpenAiTool.forFunction(badF).withArgument(Dummy.Request.class).withName("functionA");
115108
final var assistMsg = new OpenAiAssistantMessage(EMPTY_MSG_CONTENT, List.of(FUNCTION_CALL_A));
116109
final var executions = OpenAiTool.execute(List.of(toolA), assistMsg);
117110

sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OpenAiServiceV2.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,12 @@ public OpenAiChatCompletionResponse chatCompletionToolExecution(
9999
messages.add(OpenAiMessage.user("What's the weather in %s in %s?".formatted(location, unit)));
100100

101101
// 1. Define the function
102-
final List<OpenAiTool<?>> tools =
102+
final List<OpenAiTool> tools =
103103
List.of(
104-
new OpenAiTool<>("weather", WeatherMethod.Request.class)
105-
.setDescription("Get the weather for the given location")
106-
.setFunction(WeatherMethod::getCurrentWeather));
104+
OpenAiTool.forFunction(WeatherMethod::getCurrentWeather)
105+
.withArgument(WeatherMethod.Request.class)
106+
.withName("weather")
107+
.withDescription("Get the weather for the given location"));
107108

108109
// 2. Assistant calls the function
109110
final var request = new OpenAiChatCompletionRequest(messages).withToolsExecutable(tools);

0 commit comments

Comments
 (0)