-
Notifications
You must be signed in to change notification settings - Fork 16
feat: [OpenAI] Tool Definition and Call Parsing V2 (Incl Execution) #418
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 24 commits
6ebb2db
f7e9088
75fe7dd
e726226
484c14d
58792c6
cd03dac
c92954a
03453a4
36ed5bd
33f2ca8
87faa45
7ce6aef
bfce048
57ba2a1
36b0189
bcb42c8
c52fafa
66cef01
3038eb3
c53b04e
27fa975
e8764c8
9d3684d
ad89753
f8e3645
2da40a0
c4bfa20
d15ee63
91433a2
2239ad7
f297fed
7baf530
9f96e1a
ce3ba95
3667b1e
2de2592
454f7c5
ed01c22
1f4d2a2
6bd64ff
ca7fc25
3e69967
8f2346d
c21ef75
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,93 @@ | ||
| package com.sap.ai.sdk.foundationmodels.openai; | ||
|
|
||
| import static com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool.TypeEnum.FUNCTION; | ||
|
|
||
| import com.fasterxml.jackson.core.type.TypeReference; | ||
| import com.fasterxml.jackson.databind.JsonMappingException; | ||
| import com.fasterxml.jackson.databind.ObjectMapper; | ||
| import com.fasterxml.jackson.module.jsonSchema.JsonSchema; | ||
| import com.fasterxml.jackson.module.jsonSchema.JsonSchemaGenerator; | ||
| import com.google.common.annotations.Beta; | ||
| import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool; | ||
| import com.sap.ai.sdk.foundationmodels.openai.generated.model.FunctionObject; | ||
| import java.util.Map; | ||
| import java.util.function.Function; | ||
| import javax.annotation.Nonnull; | ||
| import javax.annotation.Nullable; | ||
| import lombok.AccessLevel; | ||
| import lombok.AllArgsConstructor; | ||
| import lombok.Data; | ||
| import lombok.Getter; | ||
| import lombok.experimental.Accessors; | ||
|
|
||
| /** | ||
| * Represents an OpenAI function tool that can be used to define a function call in an OpenAI Chat | ||
| * Completion request. This tool generates a JSON schema based on the provided class representing | ||
| * the function's request structure. | ||
| * | ||
| * @param <I> the type of the input argument for the function | ||
| * @see <a href="https://platform.openai.com/docs/guides/gpt/function-calling"/>OpenAI Function | ||
| * @since 1.7.0 | ||
| */ | ||
| @Beta | ||
| @Data | ||
| @Getter(AccessLevel.PACKAGE) | ||
| @Accessors(chain = true) | ||
| @AllArgsConstructor(access = AccessLevel.PRIVATE) | ||
| public class OpenAiTool<I> { | ||
rpanackal marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After discussing with @CharlesDuboisSAP, we decided to keep OpenAiTools as a class and not as an interface. This approach conflicts with design of tool calls as
The initial motivation was to have an interface that will implemented for possible new tool types other than function. Since, we have don't enough information to back this possibility and to avoid casting, we decided to go with the current state. Now, the question remains whether we also align tool call's design to match the current state with a breaking change - OpenAiFunctionCall <- OpenAiToolCall(interface)
+ OpenAiToolCall (class) |
||
| /** The name of the function. */ | ||
| private @Nonnull String name; | ||
|
|
||
| /** The model class for function request. */ | ||
| private @Nonnull Class<I> requestClass; | ||
|
|
||
| /** An optional description of the function. */ | ||
| private @Nullable String description; | ||
|
|
||
| /** An optional flag indicating whether the function parameters should be treated strictly. */ | ||
| private @Nullable Boolean strict; | ||
|
|
||
| /** The function to be called. */ | ||
| private @Nullable Function<I, ?> function; | ||
|
||
|
|
||
| /** | ||
| * Constructs an {@code OpenAiFunctionTool} with the specified name and a model class that | ||
| * captures the request to the function. | ||
| * | ||
| * @param name the name of the function | ||
| * @param requestClass the model class for function request | ||
| */ | ||
| public OpenAiTool(@Nonnull final String name, @Nonnull final Class<I> requestClass) { | ||
| this(name, requestClass, null, null, null); | ||
| } | ||
|
|
||
| @Nonnull | ||
| Object execute(@Nonnull final I argument) { | ||
| if (getFunction() == null) { | ||
| throw new IllegalStateException("Function must not be set to execute the tool."); | ||
| } | ||
| return getFunction().apply(argument); | ||
| } | ||
|
|
||
| ChatCompletionTool createChatCompletionTool() { | ||
| final var objectMapper = new ObjectMapper(); | ||
| JsonSchema schema = null; | ||
| try { | ||
| schema = new JsonSchemaGenerator(objectMapper).generateSchema(getRequestClass()); | ||
| } catch (JsonMappingException e) { | ||
| throw new IllegalArgumentException("Could not generate schema for " + getRequestClass(), e); | ||
| } | ||
|
|
||
| final var schemaMap = | ||
| objectMapper.convertValue(schema, new TypeReference<Map<String, Object>>() {}); | ||
|
|
||
| final var function = | ||
| new FunctionObject() | ||
| .name(getName()) | ||
| .description(getDescription()) | ||
| .parameters(schemaMap) | ||
| .strict(getStrict()); | ||
| return new ChatCompletionTool().type(FUNCTION).function(function); | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,68 @@ | ||
| package com.sap.ai.sdk.foundationmodels.openai; | ||
|
|
||
| import static lombok.AccessLevel.PRIVATE; | ||
|
|
||
| import com.fasterxml.jackson.core.JsonProcessingException; | ||
| import com.fasterxml.jackson.databind.ObjectMapper; | ||
| import com.google.common.annotations.Beta; | ||
| import java.util.List; | ||
| import java.util.stream.Collectors; | ||
| import javax.annotation.Nonnull; | ||
| import lombok.AllArgsConstructor; | ||
|
|
||
| /** | ||
| * A class for OpenAI tool execution. | ||
| * | ||
| * @since 1.7.0 | ||
| */ | ||
| @Beta | ||
| @AllArgsConstructor(access = PRIVATE) | ||
| public class OpenAiToolExecutor { | ||
newtork marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| private static final ObjectMapper JACKSON = new ObjectMapper(); | ||
|
|
||
| /** | ||
| * Executes the given tool calls with the provided tools and returns the results as a list of | ||
| * {@link OpenAiToolMessage} containing execution results encoded as JSON string. | ||
| * | ||
| * @param tools the list of tools to execute | ||
| * @param toolCalls the list of tool calls with arguments | ||
| * @return the list of tool messages with the results | ||
| * @throws IllegalArgumentException if the tool results cannot be serialized to JSON | ||
| */ | ||
| @Nonnull | ||
| public static List<OpenAiToolMessage> executeTools( | ||
| @Nonnull final List<OpenAiTool<?>> tools, @Nonnull final List<OpenAiToolCall> toolCalls) | ||
| throws IllegalArgumentException { | ||
|
|
||
| final var toolMap = tools.stream().collect(Collectors.toMap(OpenAiTool::getName, tool -> tool)); | ||
|
|
||
| return toolCalls.stream() | ||
| .filter(OpenAiFunctionCall.class::isInstance) | ||
| .map(OpenAiFunctionCall.class::cast) | ||
| .filter(functionCall -> toolMap.containsKey(functionCall.getName())) | ||
| .map( | ||
| functionCall -> { | ||
| final var tool = toolMap.get(functionCall.getName()); | ||
| final var result = executeFunction(tool, functionCall); | ||
| return OpenAiMessage.tool(serializeObject(result), functionCall.getId()); | ||
| }) | ||
| .toList(); | ||
| } | ||
|
|
||
| @Nonnull | ||
| private static <I> Object executeFunction( | ||
| @Nonnull final OpenAiTool<I> tool, @Nonnull final OpenAiFunctionCall toolCall) { | ||
| final I arguments = toolCall.getArgumentsAsObject(tool); | ||
| return tool.execute(arguments); | ||
| } | ||
|
|
||
| @Nonnull | ||
| private static String serializeObject(@Nonnull final Object obj) throws IllegalArgumentException { | ||
| try { | ||
| return JACKSON.writeValueAsString(obj); | ||
| } catch (JsonProcessingException e) { | ||
| throw new IllegalArgumentException("Failed to serialize object to JSON", e); | ||
| } | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.