Skip to content

Commit 33f2ca8

Browse files
committed
refactor: update OpenAiFunctionTool to improve argument parsing and enhance type safety
- OpenAiFunctionTool getter package private - introduce getArgumentsAsObject(OpenAiFunctionTool)
1 parent 36ed5bd commit 33f2ca8

File tree

5 files changed

+51
-43
lines changed

5 files changed

+51
-43
lines changed

docs/release_notes.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
### ✨ New Functionality
1414

15-
-
15+
- [OpenAI] [Add convenience for tool definition and parsing function calls](https://sap.github.io/ai-sdk/docs/java/foundation-models/openai/chat-completion#executing-tool-calls)
1616

1717
### 📈 Improvements
1818

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiFunctionCall.java

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import static com.sap.ai.sdk.foundationmodels.openai.OpenAiUtils.getOpenAiObjectMapper;
44

55
import com.fasterxml.jackson.core.JsonProcessingException;
6+
import com.fasterxml.jackson.core.type.TypeReference;
67
import com.google.common.annotations.Beta;
8+
import java.lang.reflect.Type;
79
import java.util.Map;
810
import javax.annotation.Nonnull;
911
import lombok.AllArgsConstructor;
@@ -36,25 +38,40 @@ public class OpenAiFunctionCall implements OpenAiToolCall {
3638
*/
3739
@Nonnull
3840
public Map<String, Object> getArgumentsAsMap() throws IllegalArgumentException {
39-
return getArgumentsAsObject(Map.class);
41+
return parseArguments(new TypeReference<>() {});
4042
}
4143

4244
/**
43-
* Parses the arguments, encoded as a JSON string, into an object of the specified type.
45+
* Parses the arguments, encoded as a JSON string, into an object of type expected by a function
46+
* tool.
4447
*
45-
* @param clazz the class to convert the arguments to
48+
* @param tool the function tool the arguments are for
4649
* @param <T> the type of the class
4750
* @return the parsed arguments as an object
4851
* @throws IllegalArgumentException if parsing fails
4952
* @since 1.7.0
5053
*/
5154
@Nonnull
52-
public <T> T getArgumentsAsObject(@Nonnull final Class<T> clazz) throws IllegalArgumentException {
55+
public <T> T getArgumentsAsObject(@Nonnull final OpenAiFunctionTool tool)
56+
throws IllegalArgumentException {
57+
final var typeRef =
58+
new TypeReference<T>() {
59+
@Override
60+
public Type getType() {
61+
return tool.getRequestModel();
62+
}
63+
};
64+
return parseArguments(typeRef);
65+
}
66+
67+
@Nonnull
68+
private <T> T parseArguments(@Nonnull final TypeReference<T> typeReference)
69+
throws IllegalArgumentException {
5370
try {
54-
return getOpenAiObjectMapper().readValue(getArguments(), clazz);
71+
return getOpenAiObjectMapper().readValue(getArguments(), typeReference);
5572
} catch (JsonProcessingException e) {
5673
throw new IllegalArgumentException(
57-
"Failed to parse JSON string to class " + clazz.getTypeName(), e);
74+
"Failed to parse JSON string to class " + typeReference.getType(), e);
5875
}
5976
}
6077
}

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiFunctionTool.java

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import javax.annotation.Nullable;
1616
import lombok.AccessLevel;
1717
import lombok.AllArgsConstructor;
18+
import lombok.Getter;
1819
import lombok.Value;
1920
import lombok.With;
2021

@@ -29,14 +30,15 @@
2930
@Beta
3031
@Value
3132
@With
33+
@Getter(AccessLevel.PACKAGE)
3234
@AllArgsConstructor(access = AccessLevel.PRIVATE)
3335
public class OpenAiFunctionTool implements OpenAiTool {
3436

3537
/** The name of the function. */
3638
@Nonnull String name;
3739

3840
/** The model class for function request. */
39-
@Nonnull Class<?> clazz;
41+
@Nonnull Class<?> requestModel;
4042

4143
/** An optional description of the function. */
4244
@Nullable String description;
@@ -45,22 +47,24 @@ public class OpenAiFunctionTool implements OpenAiTool {
4547
@Nullable Boolean strict;
4648

4749
/**
48-
* Constructs an {@code OpenAiFunctionTool} with the specified name and request class.
50+
* Constructs an {@code OpenAiFunctionTool} with the specified name and a model class that
51+
* captures the request to the function.
4952
*
5053
* @param name the name of the function
51-
* @param clazz the model class for the function request
54+
* @param requestModel the model class for the function request
5255
*/
53-
public <T> OpenAiFunctionTool(@Nonnull final String name, @Nonnull final Class<T> clazz) {
54-
this(name, clazz, null, null);
56+
public <T> OpenAiFunctionTool(@Nonnull final String name, @Nonnull final Class<T> requestModel) {
57+
this(name, requestModel, null, null);
5558
}
5659

5760
ChatCompletionTool createChatCompletionTool() {
5861
final var objectMapper = new ObjectMapper();
5962
JsonSchema schema = null;
6063
try {
61-
schema = new JsonSchemaGenerator(objectMapper).generateSchema(clazz);
64+
schema = new JsonSchemaGenerator(objectMapper).generateSchema(requestModel);
6265
} catch (JsonMappingException e) {
63-
throw new IllegalArgumentException("Could not generate schema for " + clazz.getTypeName(), e);
66+
throw new IllegalArgumentException(
67+
"Could not generate schema for " + requestModel.getTypeName(), e);
6468
}
6569

6670
schema.setId(null);

foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiToolCallTest.java

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ class OpenAiToolCallTest {
1111
private static final OpenAiFunctionCall INVALID_FUNCTION_CALL =
1212
new OpenAiFunctionCall("1", "functionName", "{invalid-json}");
1313

14+
private static final OpenAiFunctionTool FUNCTION_TOOL =
15+
new OpenAiFunctionTool("functionName", DummyRequest.class);
16+
1417
record DummyRequest(String key) {}
1518

1619
@Test
@@ -28,21 +31,14 @@ void getArgumentsAsMapThrowsOnInvalidJson() {
2831

2932
@Test
3033
void getArgumentsAsObjectParsesValidJson() {
31-
var result = VALID_FUNCTION_CALL.getArgumentsAsObject(DummyRequest.class);
34+
var result = (DummyRequest) VALID_FUNCTION_CALL.getArgumentsAsObject(FUNCTION_TOOL);
3235
assertThat(result).isInstanceOf(DummyRequest.class);
3336
assertThat(result.key()).isEqualTo("value");
3437
}
3538

3639
@Test
3740
void getArgumentsAsObjectThrowsOnInvalidJson() {
38-
assertThatThrownBy(() -> INVALID_FUNCTION_CALL.getArgumentsAsObject(DummyRequest.class))
39-
.isInstanceOf(IllegalArgumentException.class)
40-
.hasMessageContaining("Failed to parse JSON string");
41-
}
42-
43-
@Test
44-
void getArgumentsAsObjectThrowsOnTypeMismatch() {
45-
assertThatThrownBy(() -> VALID_FUNCTION_CALL.getArgumentsAsObject(Integer.class))
41+
assertThatThrownBy(() -> INVALID_FUNCTION_CALL.getArgumentsAsObject(FUNCTION_TOOL))
4642
.isInstanceOf(IllegalArgumentException.class)
4743
.hasMessageContaining("Failed to parse JSON string");
4844
}

sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OpenAiServiceV2.java

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import static com.sap.ai.sdk.foundationmodels.openai.OpenAiModel.GPT_4O_MINI;
55
import static com.sap.ai.sdk.foundationmodels.openai.OpenAiModel.TEXT_EMBEDDING_3_SMALL;
66

7-
import com.fasterxml.jackson.databind.ObjectMapper;
87
import com.sap.ai.sdk.core.AiCoreService;
98
import com.sap.ai.sdk.foundationmodels.openai.OpenAiAssistantMessage;
109
import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionDelta;
@@ -29,8 +28,6 @@
2928
@Service
3029
@Slf4j
3130
public class OpenAiServiceV2 {
32-
private static final ObjectMapper JACKSON = new ObjectMapper();
33-
3431
/**
3532
* Chat request to OpenAI
3633
*
@@ -95,39 +92,33 @@ public OpenAiChatCompletionResponse chatCompletionImage(@Nonnull final String li
9592
@Nonnull
9693
public OpenAiChatCompletionResponse chatCompletionToolExecution(
9794
@Nonnull final String location, @Nonnull final String unit) {
95+
final var messages = new ArrayList<OpenAiMessage>();
96+
messages.add(OpenAiMessage.user("What's the weather in %s in %s?".formatted(location, unit)));
9897

9998
// 1. Define the function
100-
final var openAiTool =
99+
final var weatherFunction =
101100
new OpenAiFunctionTool("weather", WeatherMethod.Request.class)
102101
.withDescription("Get the weather for the given location");
103102

104-
final var messages = new ArrayList<OpenAiMessage>();
105-
messages.add(OpenAiMessage.user("What's the weather in %s in %s?".formatted(location, unit)));
106-
107-
// Assistant will call the function
103+
// 2. Assistant calls the function
108104
final var request =
109-
new OpenAiChatCompletionRequest(messages).withOpenAiTools(List.of(openAiTool));
110-
105+
new OpenAiChatCompletionRequest(messages).withOpenAiTools(List.of(weatherFunction));
111106
final OpenAiChatCompletionResponse response =
112107
OpenAiClient.forModel(GPT_4O_MINI).chatCompletion(request);
113-
114-
// 2. Optionally, execute the function.
115108
final OpenAiAssistantMessage assistantMessage = response.getMessage();
116-
messages.add(assistantMessage);
117109

110+
// 3. Execute the function
118111
final OpenAiToolCall toolCall = assistantMessage.toolCalls().get(0);
119112
if (!(toolCall instanceof OpenAiFunctionCall functionCall)) {
120113
throw new IllegalArgumentException(
121114
"Expected a function call, but got: %s".formatted(assistantMessage));
122115
}
116+
final WeatherMethod.Request arguments = functionCall.getArgumentsAsObject(weatherFunction);
117+
final WeatherMethod.Response currentWeather = WeatherMethod.getCurrentWeather(arguments);
123118

124-
final WeatherMethod.Request arguments =
125-
functionCall.getArgumentsAsObject(WeatherMethod.Request.class);
126-
final WeatherMethod.Response weatherMethod = WeatherMethod.getCurrentWeather(arguments);
127-
128-
messages.add(OpenAiMessage.tool(weatherMethod.toString(), functionCall.getId()));
129-
130-
// Send back the results, and the model will incorporate them into its final response.
119+
// 4. Send back the results, and the model will incorporate them into its final response.
120+
messages.add(assistantMessage);
121+
messages.add(OpenAiMessage.tool(currentWeather.toString(), functionCall.getId()));
131122
return OpenAiClient.forModel(GPT_4O_MINI).chatCompletion(request.withMessages(messages));
132123
}
133124

0 commit comments

Comments
 (0)