|
1 | 1 | package com.sap.ai.sdk.foundationmodels.openai; |
2 | 2 |
|
3 | 3 | import static com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool.TypeEnum.FUNCTION; |
| 4 | +import static java.util.function.UnaryOperator.identity; |
4 | 5 |
|
| 6 | +import com.fasterxml.jackson.core.JsonProcessingException; |
5 | 7 | import com.fasterxml.jackson.core.type.TypeReference; |
| 8 | +import com.fasterxml.jackson.databind.ObjectMapper; |
6 | 9 | import com.github.victools.jsonschema.generator.Option; |
7 | 10 | import com.github.victools.jsonschema.generator.OptionPreset; |
8 | 11 | import com.github.victools.jsonschema.generator.SchemaGenerator; |
|
13 | 16 | import com.google.common.annotations.Beta; |
14 | 17 | import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool; |
15 | 18 | import com.sap.ai.sdk.foundationmodels.openai.generated.model.FunctionObject; |
| 19 | +import java.util.ArrayList; |
| 20 | +import java.util.LinkedHashMap; |
| 21 | +import java.util.List; |
16 | 22 | import java.util.Map; |
17 | 23 | import java.util.function.Function; |
| 24 | +import java.util.stream.Collectors; |
18 | 25 | import javax.annotation.Nonnull; |
19 | 26 | import javax.annotation.Nullable; |
20 | 27 | import lombok.AccessLevel; |
21 | 28 | import lombok.AllArgsConstructor; |
22 | 29 | import lombok.Data; |
23 | 30 | import lombok.Getter; |
| 31 | +import lombok.RequiredArgsConstructor; |
24 | 32 | import lombok.experimental.Accessors; |
| 33 | +import lombok.extern.slf4j.Slf4j; |
25 | 34 |
|
26 | 35 | /** |
27 | 36 | * Represents an OpenAI tool that can be used to define a function call in an OpenAI Chat Completion |
|
32 | 41 | * @see <a href="https://platform.openai.com/docs/guides/gpt/function-calling"/>OpenAI Function |
33 | 42 | * @since 1.7.0 |
34 | 43 | */ |
| 44 | +@Slf4j |
35 | 45 | @Beta |
36 | 46 | @Data |
37 | 47 | @Getter(AccessLevel.PACKAGE) |
38 | 48 | @Accessors(chain = true) |
39 | 49 | @AllArgsConstructor(access = AccessLevel.PRIVATE) |
40 | 50 | public class OpenAiTool<InputT> { |
41 | 51 |
|
| 52 | + private static final ObjectMapper JACKSON = new ObjectMapper(); |
| 53 | + |
42 | 54 | /** The schema generator used to create JSON schemas. */ |
43 | 55 | @Nonnull private static final SchemaGenerator GENERATOR = createSchemaGenerator(); |
44 | 56 |
|
@@ -71,7 +83,8 @@ public OpenAiTool(@Nonnull final String name, @Nonnull final Class<InputT> reque |
71 | 83 | @Nonnull |
72 | 84 | Object execute(@Nonnull final InputT argument) { |
73 | 85 | if (getFunction() == null) { |
74 | | - throw new IllegalStateException("No function configured to execute."); |
| 86 | + throw new IllegalStateException( |
| 87 | + "Tool " + name + " is missing a method reference to execute."); |
75 | 88 | } |
76 | 89 | return getFunction().apply(argument); |
77 | 90 | } |
@@ -102,4 +115,78 @@ private static SchemaGenerator createSchemaGenerator() { |
102 | 115 | .with(module) |
103 | 116 | .build()); |
104 | 117 | } |
| 118 | + |
| 119 | + /** |
| 120 | + * Executes the given tool calls with the provided tools and returns the results as a list of |
| 121 | + * {@link OpenAiToolMessage} containing execution results encoded as JSON string. |
| 122 | + * |
| 123 | + * @param tools the list of tools to execute |
| 124 | + * @param msg the assistant message containing a list of tool calls with arguments |
| 125 | + * @return a result object that contains the list of tool messages with the results |
| 126 | + * @throws IllegalStateException if a tool is missing a method reference for function execution. |
| 127 | + */ |
| 128 | + @Beta |
| 129 | + @Nonnull |
| 130 | + public static Execution execute( |
| 131 | + @Nonnull final List<OpenAiTool<?>> tools, @Nonnull final OpenAiAssistantMessage msg) |
| 132 | + throws IllegalArgumentException { |
| 133 | + final var result = new LinkedHashMap<OpenAiFunctionCall, Object>(); |
| 134 | + |
| 135 | + final var toolMap = tools.stream().collect(Collectors.toMap(OpenAiTool::getName, identity())); |
| 136 | + for (final OpenAiToolCall toolCall : msg.toolCalls()) { |
| 137 | + if (toolCall instanceof OpenAiFunctionCall functionCall) { |
| 138 | + final var tool = toolMap.get(functionCall.getName()); |
| 139 | + if (tool == null) { |
| 140 | + log.warn("Tool not found for function call: {}", functionCall.getName()); |
| 141 | + continue; |
| 142 | + } |
| 143 | + final var toolResult = executeFunction(tool, functionCall); |
| 144 | + result.put(functionCall, toolResult); |
| 145 | + } |
| 146 | + } |
| 147 | + return new Execution(result); |
| 148 | + } |
| 149 | + |
| 150 | + @Nonnull |
| 151 | + private static <I> Object executeFunction( |
| 152 | + @Nonnull final OpenAiTool<I> tool, @Nonnull final OpenAiFunctionCall toolCall) { |
| 153 | + final I arguments = toolCall.getArgumentsAsObject(tool.getRequestClass()); |
| 154 | + return tool.execute(arguments); |
| 155 | + } |
| 156 | + |
| 157 | + @Nonnull |
| 158 | + private static String serializeObject(@Nonnull final Object obj) throws IllegalArgumentException { |
| 159 | + try { |
| 160 | + return JACKSON.writeValueAsString(obj); |
| 161 | + } catch (JsonProcessingException e) { |
| 162 | + throw new IllegalArgumentException("Failed to serialize object to JSON", e); |
| 163 | + } |
| 164 | + } |
| 165 | + |
| 166 | + /** |
| 167 | + * Represents the result of executing a tool call, containing the results of the function calls. |
| 168 | + */ |
| 169 | + @RequiredArgsConstructor |
| 170 | + @Beta |
| 171 | + public static class Execution { |
| 172 | + @Getter @Beta @Nonnull private final Map<OpenAiFunctionCall, Object> results; |
| 173 | + |
| 174 | + /** |
| 175 | + * Creates a new list of serialized OpenAI tool messages. |
| 176 | + * |
| 177 | + * @return the list of serialized OpenAI tool messages. |
| 178 | + * @throws IllegalArgumentException if the tool results cannot be serialized to JSON |
| 179 | + */ |
| 180 | + @Beta |
| 181 | + @Nonnull |
| 182 | + public List<OpenAiToolMessage> getMessages() { |
| 183 | + final var result = new ArrayList<OpenAiToolMessage>(); |
| 184 | + for (final var entry : getResults().entrySet()) { |
| 185 | + final var functionCall = entry.getKey().getId(); |
| 186 | + final var serializedValue = serializeObject(entry.getValue()); |
| 187 | + result.add(OpenAiMessage.tool(serializedValue, functionCall)); |
| 188 | + } |
| 189 | + return result; |
| 190 | + } |
| 191 | + } |
105 | 192 | } |
0 commit comments