diff --git a/docs/release_notes.md b/docs/release_notes.md index 5745d6681..490694854 100644 --- a/docs/release_notes.md +++ b/docs/release_notes.md @@ -12,7 +12,7 @@ ### ✨ New Functionality -- +- [OpenAI] [Add convenience for tool definition and parsing function calls](https://sap.github.io/ai-sdk/docs/java/foundation-models/openai/chat-completion#executing-tool-calls) ### 📈 Improvements diff --git a/foundation-models/openai/pom.xml b/foundation-models/openai/pom.xml index bb3eb499a..f9d31ef36 100644 --- a/foundation-models/openai/pom.xml +++ b/foundation-models/openai/pom.xml @@ -77,6 +77,10 @@ com.fasterxml.jackson.core jackson-annotations + + com.fasterxml.jackson.module + jackson-module-jsonSchema + io.vavr vavr diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequest.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequest.java index 5f308a1c4..a0249b382 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequest.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequest.java @@ -282,6 +282,30 @@ public OpenAiChatCompletionRequest withToolChoice(@Nonnull final OpenAiToolChoic return this.withToolChoice(choice.toolChoice); } + /** + * Sets the tools to be used in the request with convenience class {@code OpenAiTool}. + * + * @param tools the list of tools to be used + * @return a new OpenAiChatCompletionRequest instance with the specified tools + * @throws IllegalArgumentException if the tool type is not supported + * @since 1.7.0 + */ + @Nonnull + public OpenAiChatCompletionRequest withOpenAiTools(@Nonnull final List tools) { + return this.withTools( + tools.stream() + .map( + tool -> { + if (tool instanceof OpenAiFunctionTool) { + return ((OpenAiFunctionTool) tool).createChatCompletionTool(); + } else { + throw new IllegalArgumentException( + "Unsupported tool type: " + tool.getClass().getName()); + } + }) + .toList()); + } + /** * Converts the request to a generated model class CreateChatCompletionRequest. * diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiFunctionCall.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiFunctionCall.java index c3668d26b..2801f536d 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiFunctionCall.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiFunctionCall.java @@ -1,6 +1,12 @@ package com.sap.ai.sdk.foundationmodels.openai; +import static com.sap.ai.sdk.foundationmodels.openai.OpenAiUtils.getOpenAiObjectMapper; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; import com.google.common.annotations.Beta; +import java.lang.reflect.Type; +import java.util.Map; import javax.annotation.Nonnull; import lombok.AllArgsConstructor; import lombok.Value; @@ -22,4 +28,50 @@ public class OpenAiFunctionCall implements OpenAiToolCall { /** The arguments for the function call, encoded as a JSON string. */ @Nonnull String arguments; + + /** + * Parses the arguments, encoded as a JSON string, into a {@code Map}. + * + * @return a map of the arguments + * @throws IllegalArgumentException if parsing fails + * @since 1.7.0 + */ + @Nonnull + public Map getArgumentsAsMap() throws IllegalArgumentException { + return parseArguments(new TypeReference<>() {}); + } + + /** + * Parses the arguments, encoded as a JSON string, into an object of type expected by a function + * tool. + * + * @param tool the function tool the arguments are for + * @return the parsed arguments as an object + * @param the type of object accepted by the function tool + * @throws IllegalArgumentException if parsing arguments fails + * @since 1.7.0 + */ + @Nonnull + public T getArgumentsAsObject(@Nonnull final OpenAiFunctionTool tool) + throws IllegalArgumentException { + final var typeRef = + new TypeReference() { + @Override + public Type getType() { + return tool.getRequestModel(); + } + }; + return parseArguments(typeRef); + } + + @Nonnull + private T parseArguments(@Nonnull final TypeReference typeReference) + throws IllegalArgumentException { + try { + return getOpenAiObjectMapper().readValue(getArguments(), typeReference); + } catch (JsonProcessingException e) { + throw new IllegalArgumentException( + "Failed to parse JSON string to class " + typeReference.getType(), e); + } + } } diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiFunctionTool.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiFunctionTool.java new file mode 100644 index 000000000..9d8d61cb5 --- /dev/null +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiFunctionTool.java @@ -0,0 +1,83 @@ +package com.sap.ai.sdk.foundationmodels.openai; + +import static com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool.TypeEnum.FUNCTION; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.JsonMappingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.module.jsonSchema.JsonSchema; +import com.fasterxml.jackson.module.jsonSchema.JsonSchemaGenerator; +import com.google.common.annotations.Beta; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.FunctionObject; +import java.util.Map; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import lombok.AccessLevel; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.Value; +import lombok.With; + +/** + * Represents an OpenAI function tool that can be used to define a function call in an OpenAI Chat + * Completion request. This tool generates a JSON schema based on the provided class representing + * the function's request structure. + * + * @see OpenAI Function + * @since 1.7.0 + */ +@Beta +@Value +@With +@Getter(AccessLevel.PACKAGE) +@AllArgsConstructor(access = AccessLevel.PRIVATE) +public class OpenAiFunctionTool implements OpenAiTool { + + /** The name of the function. */ + @Nonnull String name; + + /** The model class for function request. */ + @Nonnull Class requestModel; + + /** An optional description of the function. */ + @Nullable String description; + + /** An optional flag indicating whether the function parameters should be treated strictly. */ + @Nullable Boolean strict; + + /** + * Constructs an {@code OpenAiFunctionTool} with the specified name and a model class that + * captures the request to the function. + * + * @param name the name of the function + * @param requestModel the model class for the function request + * @param the type of the request model + */ + public OpenAiFunctionTool(@Nonnull final String name, @Nonnull final Class requestModel) { + this(name, requestModel, null, null); + } + + ChatCompletionTool createChatCompletionTool() { + final var objectMapper = new ObjectMapper(); + JsonSchema schema = null; + try { + schema = new JsonSchemaGenerator(objectMapper).generateSchema(requestModel); + } catch (JsonMappingException e) { + throw new IllegalArgumentException( + "Could not generate schema for " + requestModel.getTypeName(), e); + } + + schema.setId(null); + final var schemaMap = + objectMapper.convertValue(schema, new TypeReference>() {}); + + final var function = + new FunctionObject() + .name(getName()) + .description(getDescription()) + .parameters(schemaMap) + .strict(getStrict()); + return new ChatCompletionTool().type(FUNCTION).function(function); + } +} diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiTool.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiTool.java new file mode 100644 index 000000000..3c6675854 --- /dev/null +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiTool.java @@ -0,0 +1,8 @@ +package com.sap.ai.sdk.foundationmodels.openai; + +/** + * Represents a tool that can be integrated into an OpenAI Chat Completion request. + * + * @since 1.7.0 + */ +public sealed interface OpenAiTool permits OpenAiFunctionTool {} diff --git a/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequestTest.java b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequestTest.java index d222ddf05..7bf06ab2b 100644 --- a/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequestTest.java +++ b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequestTest.java @@ -4,11 +4,13 @@ import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestUserMessage; import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestUserMessageContent; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool; import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionToolChoiceOption; import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequestAllOfStop; import java.math.BigDecimal; import java.util.ArrayList; import java.util.List; +import java.util.Map; import org.junit.jupiter.api.Test; class OpenAiChatCompletionRequestTest { @@ -114,4 +116,45 @@ void messageListExternallyUnmodifiable() { .as("Modifying the original list should not affect the messages in the request object.") .hasSize(1); } + + @Test + void withOpenAiTools() { + record DummyRequest(String param1, int param2) {} + + var request = + new OpenAiChatCompletionRequest(OpenAiMessage.user("Hello, world")) + .withOpenAiTools( + List.of( + new OpenAiFunctionTool("toolA", DummyRequest.class) + .withDescription("descA") + .withStrict(true), + new OpenAiFunctionTool("toolB", String.class) + .withDescription("descB") + .withStrict(false))); + + var lowLevelRequest = request.createCreateChatCompletionRequest(); + assertThat(lowLevelRequest.getTools()).hasSize(2); + + var toolA = lowLevelRequest.getTools().get(0); + assertThat(toolA).isInstanceOf(ChatCompletionTool.class); + assertThat(toolA.getType()).isEqualTo(ChatCompletionTool.TypeEnum.FUNCTION); + assertThat(toolA.getFunction().getName()).isEqualTo("toolA"); + assertThat(toolA.getFunction().getDescription()).isEqualTo("descA"); + assertThat(toolA.getFunction().isStrict()).isTrue(); + assertThat(toolA.getFunction().getParameters()) + .isEqualTo( + Map.of( + "properties", + Map.of("param1", Map.of("type", "string"), "param2", Map.of("type", "integer")), + "type", + "object")); + + var toolB = lowLevelRequest.getTools().get(1); + assertThat(toolB).isInstanceOf(ChatCompletionTool.class); + assertThat(toolB.getType()).isEqualTo(ChatCompletionTool.TypeEnum.FUNCTION); + assertThat(toolB.getFunction().getName()).isEqualTo("toolB"); + assertThat(toolB.getFunction().getDescription()).isEqualTo("descB"); + assertThat(toolB.getFunction().isStrict()).isFalse(); + assertThat(toolB.getFunction().getParameters()).isEqualTo(Map.of("type", "string")); + } } diff --git a/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiToolCallTest.java b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiToolCallTest.java new file mode 100644 index 000000000..3f1f5823b --- /dev/null +++ b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiToolCallTest.java @@ -0,0 +1,45 @@ +package com.sap.ai.sdk.foundationmodels.openai; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import org.junit.jupiter.api.Test; + +class OpenAiToolCallTest { + private static final OpenAiFunctionCall VALID_FUNCTION_CALL = + new OpenAiFunctionCall("1", "functionName", "{\"key\":\"value\"}"); + private static final OpenAiFunctionCall INVALID_FUNCTION_CALL = + new OpenAiFunctionCall("1", "functionName", "{invalid-json}"); + + private static final OpenAiFunctionTool FUNCTION_TOOL = + new OpenAiFunctionTool("functionName", DummyRequest.class); + + record DummyRequest(String key) {} + + @Test + void getArgumentsAsMapParsesValidJson() { + var result = VALID_FUNCTION_CALL.getArgumentsAsMap(); + assertThat(result).containsEntry("key", "value"); + } + + @Test + void getArgumentsAsMapThrowsOnInvalidJson() { + assertThatThrownBy(INVALID_FUNCTION_CALL::getArgumentsAsMap) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Failed to parse JSON string"); + } + + @Test + void getArgumentsAsObjectParsesValidJson() { + var result = (DummyRequest) VALID_FUNCTION_CALL.getArgumentsAsObject(FUNCTION_TOOL); + assertThat(result).isInstanceOf(DummyRequest.class); + assertThat(result.key()).isEqualTo("value"); + } + + @Test + void getArgumentsAsObjectThrowsOnInvalidJson() { + assertThatThrownBy(() -> INVALID_FUNCTION_CALL.getArgumentsAsObject(FUNCTION_TOOL)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Failed to parse JSON string"); + } +} diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OpenAiServiceV2.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OpenAiServiceV2.java index 6ddd0ead6..1e1b165e1 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OpenAiServiceV2.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OpenAiServiceV2.java @@ -3,13 +3,7 @@ import static com.sap.ai.sdk.foundationmodels.openai.OpenAiModel.GPT_4O; import static com.sap.ai.sdk.foundationmodels.openai.OpenAiModel.GPT_4O_MINI; import static com.sap.ai.sdk.foundationmodels.openai.OpenAiModel.TEXT_EMBEDDING_3_SMALL; -import static com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool.TypeEnum.FUNCTION; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.JsonMappingException; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.module.jsonSchema.JsonSchemaGenerator; import com.sap.ai.sdk.core.AiCoreService; import com.sap.ai.sdk.foundationmodels.openai.OpenAiAssistantMessage; import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionDelta; @@ -19,14 +13,12 @@ import com.sap.ai.sdk.foundationmodels.openai.OpenAiEmbeddingRequest; import com.sap.ai.sdk.foundationmodels.openai.OpenAiEmbeddingResponse; import com.sap.ai.sdk.foundationmodels.openai.OpenAiFunctionCall; +import com.sap.ai.sdk.foundationmodels.openai.OpenAiFunctionTool; import com.sap.ai.sdk.foundationmodels.openai.OpenAiImageItem; import com.sap.ai.sdk.foundationmodels.openai.OpenAiMessage; import com.sap.ai.sdk.foundationmodels.openai.OpenAiToolCall; -import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool; -import com.sap.ai.sdk.foundationmodels.openai.generated.model.FunctionObject; import java.util.ArrayList; import java.util.List; -import java.util.Map; import java.util.stream.Stream; import javax.annotation.Nonnull; import lombok.extern.slf4j.Slf4j; @@ -36,8 +28,6 @@ @Service @Slf4j public class OpenAiServiceV2 { - private static final ObjectMapper JACKSON = new ObjectMapper(); - /** * Chat request to OpenAI * @@ -104,64 +94,36 @@ public OpenAiChatCompletionResponse chatCompletionImage(@Nonnull final String li @Nonnull public OpenAiChatCompletionResponse chatCompletionToolExecution( @Nonnull final String location, @Nonnull final String unit) { - - // 1. Define the function - final Map schemaMap = generateSchema(WeatherMethod.Request.class); - final var function = - new FunctionObject() - .name("weather") - .description("Get the weather for the given location") - .parameters(schemaMap); - final var tool = new ChatCompletionTool().type(FUNCTION).function(function); - final var messages = new ArrayList(); messages.add(OpenAiMessage.user("What's the weather in %s in %s?".formatted(location, unit))); - // Assistant will call the function - final var request = new OpenAiChatCompletionRequest(messages).withTools(List.of(tool)); + // 1. Define the function + final var weatherFunction = + new OpenAiFunctionTool("weather", WeatherMethod.Request.class) + .withDescription("Get the weather for the given location"); + + // 2. Assistant calls the function + final var request = + new OpenAiChatCompletionRequest(messages).withOpenAiTools(List.of(weatherFunction)); final OpenAiChatCompletionResponse response = OpenAiClient.forModel(GPT_4O_MINI).chatCompletion(request); - - // 2. Optionally, execute the function. final OpenAiAssistantMessage assistantMessage = response.getMessage(); - messages.add(assistantMessage); + // 3. Execute the function final OpenAiToolCall toolCall = assistantMessage.toolCalls().get(0); if (!(toolCall instanceof OpenAiFunctionCall functionCall)) { throw new IllegalArgumentException( "Expected a function call, but got: %s".formatted(assistantMessage)); } + final WeatherMethod.Request arguments = functionCall.getArgumentsAsObject(weatherFunction); + final WeatherMethod.Response currentWeather = WeatherMethod.getCurrentWeather(arguments); - final WeatherMethod.Request arguments = - parseJson(functionCall.getArguments(), WeatherMethod.Request.class); - final WeatherMethod.Response weatherMethod = WeatherMethod.getCurrentWeather(arguments); - - messages.add(OpenAiMessage.tool(weatherMethod.toString(), functionCall.getId())); - - // Send back the results, and the model will incorporate them into its final response. + // 4. Send back the results, and the model will incorporate them into its final response. + messages.add(assistantMessage); + messages.add(OpenAiMessage.tool(currentWeather.toString(), functionCall.getId())); return OpenAiClient.forModel(GPT_4O_MINI).chatCompletion(request.withMessages(messages)); } - @Nonnull - private static T parseJson(@Nonnull final String rawJson, @Nonnull final Class clazz) { - try { - return JACKSON.readValue(rawJson, clazz); - } catch (JsonProcessingException e) { - throw new IllegalArgumentException("Failed to parse tool call arguments: " + rawJson, e); - } - } - - @Nonnull - private static Map generateSchema(@Nonnull final Class clazz) { - final var jsonSchemaGenerator = new JsonSchemaGenerator(JACKSON); - try { - final var schema = jsonSchemaGenerator.generateSchema(clazz); - return JACKSON.convertValue(schema, new TypeReference<>() {}); - } catch (JsonMappingException e) { - throw new IllegalArgumentException("Could not generate schema for " + clazz.getName(), e); - } - } - /** * Get the embedding of a text *