diff --git a/docs/release_notes.md b/docs/release_notes.md index aff176462..8167109df 100644 --- a/docs/release_notes.md +++ b/docs/release_notes.md @@ -12,6 +12,7 @@ ### ✨ New Functionality +- [OpenAI] [Add convenience for tool definition, parsing function calls and tool execution](https://sap.github.io/ai-sdk/docs/java/foundation-models/openai/chat-completion#executing-tool-calls) - [Orchestration] Added new model DeepSeek-R1: `OrchestrationAiModel.DEEPSEEK_R1` - [Orchestration] [Tool execution fully enabled](https://sap.github.io/ai-sdk/docs/java/spring-ai/orchestration#tool-calling) diff --git a/foundation-models/openai/pom.xml b/foundation-models/openai/pom.xml index ce3cd93b1..0e880a326 100644 --- a/foundation-models/openai/pom.xml +++ b/foundation-models/openai/pom.xml @@ -77,6 +77,14 @@ com.fasterxml.jackson.core jackson-annotations + + com.github.victools + jsonschema-generator + + + com.github.victools + jsonschema-module-jackson + io.vavr vavr diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequest.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequest.java index 5f308a1c4..427749495 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequest.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequest.java @@ -9,6 +9,7 @@ import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequestAllOfResponseFormat; import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequestAllOfStop; import java.math.BigDecimal; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Objects; @@ -125,6 +126,15 @@ public class OpenAiChatCompletionRequest { /** List of tools that the model may invoke during the completion. */ @Nullable List tools; + /** + * List of tools that are executable at runtime of the application. + * + * @since 1.7.0 + */ + @Getter(value = AccessLevel.PACKAGE) + @Nullable + List toolsExecutable; + /** Option to control which tool is invoked by the model. */ @With(AccessLevel.PRIVATE) @Nullable @@ -179,6 +189,7 @@ public OpenAiChatCompletionRequest(@Nonnull final List messages) null, null, null, + null, null); } @@ -226,6 +237,7 @@ public OpenAiChatCompletionRequest withParallelToolCalls( this.streamOptions, this.responseFormat, this.tools, + this.toolsExecutable, this.toolChoice); } @@ -258,6 +270,7 @@ public OpenAiChatCompletionRequest withLogprobs(@Nonnull final Boolean logprobs) this.streamOptions, this.responseFormat, this.tools, + this.toolsExecutable, this.toolChoice); } @@ -312,10 +325,24 @@ CreateChatCompletionRequest createCreateChatCompletionRequest() { request.seed(this.seed); request.streamOptions(this.streamOptions); request.responseFormat(this.responseFormat); - request.tools(this.tools); + request.tools(getChatCompletionTools()); request.toolChoice(this.toolChoice); request.functionCall(null); request.functions(null); return request; } + + @Nullable + private List getChatCompletionTools() { + final var toolsCombined = new ArrayList(); + if (this.tools != null) { + toolsCombined.addAll(this.tools); + } + if (this.toolsExecutable != null) { + for (final OpenAiTool tool : this.toolsExecutable) { + toolsCombined.add(tool.createChatCompletionTool()); + } + } + return toolsCombined.isEmpty() ? null : toolsCombined; + } } diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionResponse.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionResponse.java index c1d0bb772..35fdc3e5e 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionResponse.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionResponse.java @@ -11,6 +11,7 @@ import java.util.List; import java.util.Objects; import javax.annotation.Nonnull; +import lombok.Getter; import lombok.RequiredArgsConstructor; import lombok.Setter; import lombok.Value; @@ -26,7 +27,12 @@ @Setter(value = NONE) public class OpenAiChatCompletionResponse { /** The original response from the OpenAI API. */ - @Nonnull final CreateChatCompletionResponse originalResponse; + @Nonnull CreateChatCompletionResponse originalResponse; + + /** The original request that was sent to the OpenAI API. */ + @Getter(NONE) + @Nonnull + OpenAiChatCompletionRequest originalRequest; /** * Gets the token usage from the original response. @@ -96,4 +102,16 @@ public OpenAiAssistantMessage getMessage() { return new OpenAiAssistantMessage(new OpenAiMessageContent(contentItems), openAiToolCalls); } + + /** + * Execute tool calls that were suggested by the assistant response. + * + * @return the list of tool messages that were serialized for the computed results. Empty list if + * no tools were called. + */ + @Nonnull + public List executeTools() { + final var tools = originalRequest.getToolsExecutable(); + return OpenAiTool.execute(tools != null ? tools : List.of(), getMessage()); + } } diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java index 7cfcadb35..88f8513a0 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java @@ -158,8 +158,8 @@ public OpenAiChatCompletionOutput chatCompletion(@Nonnull final String prompt) public OpenAiChatCompletionResponse chatCompletion( @Nonnull final OpenAiChatCompletionRequest request) throws OpenAiClientException { warnIfUnsupportedUsage(); - return new OpenAiChatCompletionResponse( - chatCompletion(request.createCreateChatCompletionRequest())); + final var response = chatCompletion(request.createCreateChatCompletionRequest()); + return new OpenAiChatCompletionResponse(response, request); } /** diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiTool.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiTool.java new file mode 100644 index 000000000..3d2e7e900 --- /dev/null +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiTool.java @@ -0,0 +1,224 @@ +package com.sap.ai.sdk.foundationmodels.openai; + +import static com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool.TypeEnum.FUNCTION; +import static java.util.function.UnaryOperator.identity; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.github.victools.jsonschema.generator.Option; +import com.github.victools.jsonschema.generator.OptionPreset; +import com.github.victools.jsonschema.generator.SchemaGenerator; +import com.github.victools.jsonschema.generator.SchemaGeneratorConfigBuilder; +import com.github.victools.jsonschema.generator.SchemaVersion; +import com.github.victools.jsonschema.module.jackson.JacksonModule; +import com.github.victools.jsonschema.module.jackson.JacksonOption; +import com.google.common.annotations.Beta; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.FunctionObject; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import lombok.AccessLevel; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.Setter; +import lombok.Value; +import lombok.With; +import lombok.extern.slf4j.Slf4j; + +/** + * Represents an OpenAI tool that can be used to define a function call in an OpenAI Chat Completion + * request. This tool generates a JSON schema based on the provided class representing the + * function's request structure. + * + * @see OpenAI Function + * @since 1.7.0 + */ +@Slf4j +@Beta +@Value +@With +@Getter(AccessLevel.PACKAGE) +@AllArgsConstructor(access = AccessLevel.PRIVATE) +public class OpenAiTool { + + private static final ObjectMapper JACKSON = new ObjectMapper(); + + /** The schema generator used to create JSON schemas. */ + @Nonnull private static final SchemaGenerator GENERATOR = createSchemaGenerator(); + + /** The name of the function. */ + @Setter(AccessLevel.NONE) + @Nonnull + String name; + + /** The function to execute a string argument to tool result object. */ + @Setter(AccessLevel.NONE) + @Nonnull + Function functionExecutor; + + /** schema to be used for the function call. */ + @Setter(AccessLevel.NONE) + @Nonnull + ObjectNode schema; + + /** An optional description of the function. */ + @Nullable String description; + + /** An optional flag indicating whether the function parameters should be treated strictly. */ + @Nullable Boolean strict; + + /** + * Instantiates a OpenAiTool builder instance on behalf of an executable function. + * + * @param function the function to be executed. + * @return an OpenAiTool builder instance. + * @param the type of the function input-argument class. + */ + @Nonnull + public static Builder1 forFunction(@Nonnull final Function function) { + return inputClass -> + name -> { + final Function exec = + s -> function.apply(deserializeArgument(inputClass, s)); + final var schema = GENERATOR.generateSchema(inputClass); + return new OpenAiTool(name, exec, schema, null, null); + }; + } + + /** + * Creates a new OpenAiTool instance with the specified function and input class. + * + * @param the type of the input class. + */ + public interface Builder1 { + /** + * Sets the name of the function. + * + * @param inputClass the class of the input object. + * @return a new OpenAiTool instance with the specified function and input class. + */ + @Nonnull + Builder2 withArgument(@Nonnull final Class inputClass); + } + + /** Creates a new OpenAiTool instance with the specified name. */ + public interface Builder2 { + /** + * Sets the name of the function. + * + * @param name the name of the function + * @return a new OpenAiTool instance with the specified name + */ + @Nonnull + OpenAiTool withName(@Nonnull final String name); + } + + @Nullable + static T deserializeArgument(@Nonnull final Class cl, @Nonnull final String s) { + try { + return JACKSON.readValue(s, cl); + } catch (JsonProcessingException e) { + throw new IllegalArgumentException("Failed to parse JSON string to class " + cl, e); + } + } + + ChatCompletionTool createChatCompletionTool() { + final var schemaMap = + OpenAiUtils.getOpenAiObjectMapper() + .convertValue(getSchema(), new TypeReference>() {}); + + return new ChatCompletionTool() + .type(FUNCTION) + .function( + new FunctionObject() + .name(getName()) + .description(getDescription()) + .parameters(schemaMap) + .strict(getStrict())); + } + + private static SchemaGenerator createSchemaGenerator() { + final var module = + new JacksonModule( + JacksonOption.RESPECT_JSONPROPERTY_REQUIRED, JacksonOption.RESPECT_JSONPROPERTY_ORDER); + return new SchemaGenerator( + new SchemaGeneratorConfigBuilder(SchemaVersion.DRAFT_2020_12, OptionPreset.PLAIN_JSON) + .without(Option.SCHEMA_VERSION_INDICATOR) + .with(module) + .build()); + } + + /** + * Executes the given tool calls with the provided tools and returns the results as a list of + * {@link OpenAiToolMessage} containing execution results encoded as JSON string. + * + * @param tools the list of tools to execute + * @param msg the assistant message containing a list of tool calls with arguments + * @return The list of tool messages with the results. + */ + @Beta + @Nonnull + static List execute( + @Nonnull final List tools, @Nonnull final OpenAiAssistantMessage msg) { + final var toolResults = executeInternal(tools, msg); + final var result = new ArrayList(); + for (final var entry : toolResults.entrySet()) { + final var functionCall = entry.getKey().getId(); + final var serializedValue = serializeObject(entry.getValue()); + result.add(OpenAiMessage.tool(serializedValue, functionCall)); + } + return result; + } + + /** + * Executes the given tool calls with the provided tools and returns the results as a list of + * {@link OpenAiToolMessage} containing execution results encoded as JSON string. + * + * @param tools the list of tools to execute + * @param msg the assistant message containing a list of tool calls with arguments + * @return a map that contains the function calls and their respective tool results. + */ + @Nonnull + private static Map executeInternal( + @Nonnull final List tools, @Nonnull final OpenAiAssistantMessage msg) { + final var result = new LinkedHashMap(); + final var toolMap = tools.stream().collect(Collectors.toMap(OpenAiTool::getName, identity())); + for (final OpenAiToolCall toolCall : msg.toolCalls()) { + if (toolCall instanceof OpenAiFunctionCall functionCall) { + final var tool = toolMap.get(functionCall.getName()); + if (tool == null) { + log.warn("Tool not found for function call: {}", functionCall.getName()); + continue; + } + final var toolResult = executeFunction(tool, functionCall); + result.put(functionCall, toolResult); + } + } + return result; + } + + @Nonnull + private static Object executeFunction( + @Nonnull final OpenAiTool tool, @Nonnull final OpenAiFunctionCall toolCall) { + final Function executor = tool.getFunctionExecutor(); + final String arguments = toolCall.getArguments(); + return executor.apply(arguments); + } + + @Nonnull + private static String serializeObject(@Nonnull final Object obj) throws IllegalArgumentException { + try { + return JACKSON.writeValueAsString(obj); + } catch (JsonProcessingException e) { + throw new IllegalArgumentException("Failed to serialize object to JSON", e); + } + } +} diff --git a/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequestTest.java b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequestTest.java index d222ddf05..88fcebae4 100644 --- a/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequestTest.java +++ b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequestTest.java @@ -4,11 +4,13 @@ import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestUserMessage; import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestUserMessageContent; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool; import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionToolChoiceOption; import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequestAllOfStop; import java.math.BigDecimal; import java.util.ArrayList; import java.util.List; +import java.util.Map; import org.junit.jupiter.api.Test; class OpenAiChatCompletionRequestTest { @@ -114,4 +116,44 @@ void messageListExternallyUnmodifiable() { .as("Modifying the original list should not affect the messages in the request object.") .hasSize(1); } + + @Test + void withOpenAiTools() { + record DummyRequest(String param1, int param2) {} + + var request = + new OpenAiChatCompletionRequest(OpenAiMessage.user("Hello, world")) + .withToolsExecutable( + List.of( + OpenAiTool.forFunction(r -> "result") + .withArgument(DummyRequest.class) + .withName("toolA") + .withDescription("descA") + .withStrict(true), + OpenAiTool.forFunction(r -> "result") + .withArgument(DummyRequest.class) + .withName("toolB") + .withDescription("descB") + .withStrict(true))); + + var lowLevelRequest = request.createCreateChatCompletionRequest(); + assertThat(lowLevelRequest.getTools()).hasSize(2); + + var toolA = lowLevelRequest.getTools().get(0); + assertThat(toolA).isInstanceOf(ChatCompletionTool.class); + assertThat(toolA.getType()).isEqualTo(ChatCompletionTool.TypeEnum.FUNCTION); + assertThat(toolA.getFunction().getName()).isEqualTo("toolA"); + assertThat(toolA.getFunction().getDescription()).isEqualTo("descA"); + assertThat(toolA.getFunction().isStrict()).isTrue(); + + assertThat(toolA.getFunction().getParameters()) + .isEqualTo( + Map.of( + "properties", + Map.of( + "param1", Map.of("type", "string"), + "param2", Map.of("type", "integer")), + "type", + "object")); + } } diff --git a/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiToolTest.java b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiToolTest.java new file mode 100644 index 000000000..41c911e3f --- /dev/null +++ b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiToolTest.java @@ -0,0 +1,109 @@ +package com.sap.ai.sdk.foundationmodels.openai; + +import static com.sap.ai.sdk.foundationmodels.openai.OpenAiTool.deserializeArgument; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import lombok.EqualsAndHashCode; +import org.junit.jupiter.api.Test; + +class OpenAiToolTest { + private static final OpenAiFunctionCall FUNCTION_CALL_A = + new OpenAiFunctionCall("1", "functionA", "{\"key\":\"value\"}"); + private static final OpenAiFunctionCall FUNCTION_CALL_B = + new OpenAiFunctionCall("2", "functionB", "{\"key\":\"value\"}"); + private static final OpenAiFunctionCall INVALID_FUNCTION_CALL_A = + new OpenAiFunctionCall("3", "functionA", "{invalid-json}"); + + private static final OpenAiMessageContent EMPTY_MSG_CONTENT = + new OpenAiMessageContent(Collections.emptyList()); + + private static class Dummy { + record Request(String key) {} + + record Response(String toolMsg) {} + + static final Function conCat = + request -> new Dummy.Response(request.key()); + } + + @Test + void getArgumentsAsMapValid() { + final var result = deserializeArgument(Map.class, FUNCTION_CALL_A.getArguments()); + assertThat(result).containsEntry("key", "value"); + } + + @Test + void getArgumentsAsMapInvalid() { + assertThatThrownBy(() -> deserializeArgument(Map.class, INVALID_FUNCTION_CALL_A.getArguments())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Failed to parse JSON string"); + } + + @Test + void getArgumentsAsObjectValid() { + final var result = deserializeArgument(Dummy.Request.class, FUNCTION_CALL_A.getArguments()); + assertThat(result).isInstanceOf(Dummy.Request.class); + assertThat(result.key()).isEqualTo("value"); + } + + @Test + void getArgumentsAsObjectInvalid() { + final var payload = INVALID_FUNCTION_CALL_A.getArguments(); + assertThatThrownBy(() -> deserializeArgument(Integer.class, payload)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Failed to parse JSON string"); + } + + @Test + void executeToolsValid() { + final var toolA = + OpenAiTool.forFunction(Dummy.conCat) + .withArgument(Dummy.Request.class) + .withName("functionA"); + final var assistMsg = new OpenAiAssistantMessage(EMPTY_MSG_CONTENT, List.of(FUNCTION_CALL_A)); + final var toolMsgs = OpenAiTool.execute(List.of(toolA), assistMsg); + + assertThat(toolMsgs).hasSize(1); + assertThat(toolMsgs.get(0).toolCallId()).isEqualTo("1"); + assertThat(((OpenAiTextItem) toolMsgs.get(0).content().items().get(0)).text()) + .isEqualTo("{\"toolMsg\":\"value\"}"); + } + + @Test + void executeToolsNoMatchingCall() { + final var toolA = + OpenAiTool.forFunction(Dummy.conCat) + .withArgument(Dummy.Request.class) + .withName("functionA"); + final var assistMsg = new OpenAiAssistantMessage(EMPTY_MSG_CONTENT, List.of(FUNCTION_CALL_B)); + final var toolMsgs = OpenAiTool.execute(List.of(toolA), assistMsg); + assertThat(toolMsgs).isEmpty(); + } + + @Test + void executeToolsThrowsOnSerializationError() { + @EqualsAndHashCode + class NonSerializableResponse { + private String result; + + NonSerializableResponse(String result) { + this.result = result; + } + } + + final Function badF = + request -> new NonSerializableResponse(request.key()); + final var toolA = + OpenAiTool.forFunction(badF).withArgument(Dummy.Request.class).withName("functionA"); + final var assistMsg = new OpenAiAssistantMessage(EMPTY_MSG_CONTENT, List.of(FUNCTION_CALL_A)); + + assertThatThrownBy(() -> OpenAiTool.execute(List.of(toolA), assistMsg)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Failed to serialize object to JSON"); + } +} diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OpenAiServiceV2.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OpenAiServiceV2.java index 6ddd0ead6..0502cafad 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OpenAiServiceV2.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OpenAiServiceV2.java @@ -3,30 +3,19 @@ import static com.sap.ai.sdk.foundationmodels.openai.OpenAiModel.GPT_4O; import static com.sap.ai.sdk.foundationmodels.openai.OpenAiModel.GPT_4O_MINI; import static com.sap.ai.sdk.foundationmodels.openai.OpenAiModel.TEXT_EMBEDDING_3_SMALL; -import static com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool.TypeEnum.FUNCTION; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.JsonMappingException; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.module.jsonSchema.JsonSchemaGenerator; import com.sap.ai.sdk.core.AiCoreService; -import com.sap.ai.sdk.foundationmodels.openai.OpenAiAssistantMessage; import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionDelta; import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionRequest; import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionResponse; import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient; import com.sap.ai.sdk.foundationmodels.openai.OpenAiEmbeddingRequest; import com.sap.ai.sdk.foundationmodels.openai.OpenAiEmbeddingResponse; -import com.sap.ai.sdk.foundationmodels.openai.OpenAiFunctionCall; import com.sap.ai.sdk.foundationmodels.openai.OpenAiImageItem; import com.sap.ai.sdk.foundationmodels.openai.OpenAiMessage; -import com.sap.ai.sdk.foundationmodels.openai.OpenAiToolCall; -import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool; -import com.sap.ai.sdk.foundationmodels.openai.generated.model.FunctionObject; +import com.sap.ai.sdk.foundationmodels.openai.OpenAiTool; import java.util.ArrayList; import java.util.List; -import java.util.Map; import java.util.stream.Stream; import javax.annotation.Nonnull; import lombok.extern.slf4j.Slf4j; @@ -36,8 +25,6 @@ @Service @Slf4j public class OpenAiServiceV2 { - private static final ObjectMapper JACKSON = new ObjectMapper(); - /** * Chat request to OpenAI * @@ -95,7 +82,8 @@ public OpenAiChatCompletionResponse chatCompletionImage(@Nonnull final String li } /** - * Executes a chat completion request to OpenAI with a tool that calculates the weather. + * Chat request to OpenAI with tool that gets the weather for a given location and unit. The tool + * executed and the result is sent back to the assistant. * * @param location The location to get the weather for. * @param unit The unit of temperature to use. @@ -104,62 +92,29 @@ public OpenAiChatCompletionResponse chatCompletionImage(@Nonnull final String li @Nonnull public OpenAiChatCompletionResponse chatCompletionToolExecution( @Nonnull final String location, @Nonnull final String unit) { - - // 1. Define the function - final Map schemaMap = generateSchema(WeatherMethod.Request.class); - final var function = - new FunctionObject() - .name("weather") - .description("Get the weather for the given location") - .parameters(schemaMap); - final var tool = new ChatCompletionTool().type(FUNCTION).function(function); + final OpenAiClient client = OpenAiClient.forModel(GPT_4O_MINI); final var messages = new ArrayList(); messages.add(OpenAiMessage.user("What's the weather in %s in %s?".formatted(location, unit))); - // Assistant will call the function - final var request = new OpenAiChatCompletionRequest(messages).withTools(List.of(tool)); - final OpenAiChatCompletionResponse response = - OpenAiClient.forModel(GPT_4O_MINI).chatCompletion(request); - - // 2. Optionally, execute the function. - final OpenAiAssistantMessage assistantMessage = response.getMessage(); - messages.add(assistantMessage); - - final OpenAiToolCall toolCall = assistantMessage.toolCalls().get(0); - if (!(toolCall instanceof OpenAiFunctionCall functionCall)) { - throw new IllegalArgumentException( - "Expected a function call, but got: %s".formatted(assistantMessage)); - } - - final WeatherMethod.Request arguments = - parseJson(functionCall.getArguments(), WeatherMethod.Request.class); - final WeatherMethod.Response weatherMethod = WeatherMethod.getCurrentWeather(arguments); - - messages.add(OpenAiMessage.tool(weatherMethod.toString(), functionCall.getId())); - - // Send back the results, and the model will incorporate them into its final response. - return OpenAiClient.forModel(GPT_4O_MINI).chatCompletion(request.withMessages(messages)); - } - - @Nonnull - private static T parseJson(@Nonnull final String rawJson, @Nonnull final Class clazz) { - try { - return JACKSON.readValue(rawJson, clazz); - } catch (JsonProcessingException e) { - throw new IllegalArgumentException("Failed to parse tool call arguments: " + rawJson, e); - } - } - - @Nonnull - private static Map generateSchema(@Nonnull final Class clazz) { - final var jsonSchemaGenerator = new JsonSchemaGenerator(JACKSON); - try { - final var schema = jsonSchemaGenerator.generateSchema(clazz); - return JACKSON.convertValue(schema, new TypeReference<>() {}); - } catch (JsonMappingException e) { - throw new IllegalArgumentException("Could not generate schema for " + clazz.getName(), e); - } + // 1. Define the function + final List tools = + List.of( + OpenAiTool.forFunction(WeatherMethod::getCurrentWeather) + .withArgument(WeatherMethod.Request.class) + .withName("weather") + .withDescription("Get the weather for the given location")); + + // 2. Assistant calls the function + final var request = new OpenAiChatCompletionRequest(messages).withToolsExecutable(tools); + final OpenAiChatCompletionResponse response = client.chatCompletion(request); + + // 3. Execute the tool calls + messages.add(response.getMessage()); + messages.addAll(response.executeTools()); + + // 4. Have model run the final request with incorporated tool results + return client.chatCompletion(request.withMessages(messages)); } /**