diff --git a/pom.xml b/pom.xml index a37c49be5..1826b3321 100644 --- a/pom.xml +++ b/pom.xml @@ -72,6 +72,7 @@ 5.16.1 3.26.3 4.38.0 + 2.18.3 1.14.2 20250107 @@ -132,7 +133,12 @@ com.fasterxml.jackson.module jackson-module-parameter-names - 2.18.3 + ${jackson.version} + + + com.fasterxml.jackson.module + jackson-module-jsonSchema + ${jackson.version} com.github.victools diff --git a/sample-code/spring-app/pom.xml b/sample-code/spring-app/pom.xml index 46ffc7cb5..06399a8c0 100644 --- a/sample-code/spring-app/pom.xml +++ b/sample-code/spring-app/pom.xml @@ -70,6 +70,10 @@ com.sap.cloud.sdk.cloudplatform cloudplatform-core + + com.sap.cloud.sdk.cloudplatform + cloudplatform-connectivity + com.sap.cloud.sdk.datamodel openapi-core @@ -131,6 +135,14 @@ com.fasterxml.jackson.core jackson-annotations + + com.fasterxml.jackson.module + jackson-module-jsonSchema + + + com.fasterxml.jackson.core + jackson-core + ch.qos.logback @@ -177,10 +189,6 @@ assertj-core test - - com.sap.cloud.sdk.cloudplatform - cloudplatform-connectivity - diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java index e48b67b70..2abe83fb5 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java @@ -120,11 +120,13 @@ Object chatCompletionImage( return response.getChoices().get(0).getMessage(); } - @GetMapping("/chatCompletionTool") + @GetMapping("/chatCompletionToolExecution") @Nonnull - Object chatCompletionTools( - @Nullable @RequestParam(value = "format", required = false) final String format) { - final var response = service.chatCompletionTools(12); + Object chatCompletionToolExecution( + @Nullable @RequestParam(value = "format", required = false) final String format, + @Nonnull @RequestParam(value = "location", defaultValue = "Dubai") final String location, + @Nonnull @RequestParam(value = "unit", defaultValue = "°C") final String unit) { + final var response = service.chatCompletionToolExecution(location, unit); if ("json".equals(format)) { return response; } diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OpenAiService.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OpenAiService.java index 5ee361616..8196659aa 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OpenAiService.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OpenAiService.java @@ -5,6 +5,11 @@ import static com.sap.ai.sdk.foundationmodels.openai.OpenAiModel.TEXT_EMBEDDING_3_SMALL; import static com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionTool.ToolType.FUNCTION; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.JsonMappingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.module.jsonSchema.JsonSchemaGenerator; import com.sap.ai.sdk.core.AiCoreService; import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionDelta; @@ -13,8 +18,10 @@ import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionParameters; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionTool; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage; +import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatToolCall; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingOutput; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingParameters; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.stream.Stream; @@ -26,6 +33,7 @@ @Service @Slf4j public class OpenAiService { + private static final ObjectMapper JACKSON = new ObjectMapper(); /** * Chat request to OpenAI @@ -86,30 +94,79 @@ public OpenAiChatCompletionOutput chatCompletionImage(@Nonnull final String link } /** - * Chat request to OpenAI with a tool. + * Executes a chat completion request to OpenAI with a tool that calculates the weather. * - * @param months The number of months to be inferred in the tool - * @return the assistant message response + * @param location The location to get the weather for. + * @param unit The unit of temperature to use. + * @return The assistant message response. */ @Nonnull - public OpenAiChatCompletionOutput chatCompletionTools(final int months) { - final var question = - "A pair of rabbits is placed in a field. Each month, every pair produces one new pair, starting from the second month. How many rabbits will there be after %s months?" - .formatted(months); - final var par = Map.of("type", "object", "properties", Map.of("N", Map.of("type", "integer"))); + public OpenAiChatCompletionOutput chatCompletionToolExecution( + @Nonnull final String location, @Nonnull final String unit) { + + // 1. Define the function + final Map schemaMap = generateSchema(WeatherMethod.Request.class); final var function = new OpenAiChatCompletionFunction() - .setName("fibonacci") - .setDescription("Calculate the Fibonacci number for given sequence index.") - .setParameters(par); + .setName("weather") + .setDescription("Get the weather for the given location") + .setParameters(schemaMap); final var tool = new OpenAiChatCompletionTool().setType(FUNCTION).setFunction(function); + + final var messages = new ArrayList(); + messages.add( + new OpenAiChatMessage.OpenAiChatUserMessage() + .addText("What's the weather in %s in %s?".formatted(location, unit))); + + // Assistant will call the function final var request = new OpenAiChatCompletionParameters() - .addMessages(new OpenAiChatMessage.OpenAiChatUserMessage().addText(question)) - .setTools(List.of(tool)) - .setToolChoiceFunction("fibonacci"); + .addMessages(messages.toArray(OpenAiChatMessage[]::new)) + .setTools(List.of(tool)); + + final OpenAiChatCompletionOutput response = + OpenAiClient.forModel(GPT_4O_MINI).chatCompletion(request); + + // 2. Optionally, execute the function. + final OpenAiChatToolCall toolCall = + response.getChoices().get(0).getMessage().getToolCalls().get(0); + final WeatherMethod.Request arguments = + parseJson(toolCall.getFunction().getArguments(), WeatherMethod.Request.class); + final WeatherMethod.Response currentWeather = WeatherMethod.getCurrentWeather(arguments); + + final OpenAiChatMessage.OpenAiChatAssistantMessage assistantMessage = + response.getChoices().get(0).getMessage(); + messages.add(assistantMessage); + + final var toolMessage = + new OpenAiChatMessage.OpenAiChatToolMessage() + .setToolCallId(toolCall.getId()) + .setContent(currentWeather.toString()); + messages.add(toolMessage); + + final var finalRequest = + new OpenAiChatCompletionParameters() + .addMessages(messages.toArray(OpenAiChatMessage[]::new)); + + return OpenAiClient.forModel(GPT_4O_MINI).chatCompletion(finalRequest); + } + + private static T parseJson(@Nonnull final String rawJson, @Nonnull final Class clazz) { + try { + return JACKSON.readValue(rawJson, clazz); + } catch (JsonProcessingException e) { + throw new IllegalArgumentException("Failed to parse tool call arguments: " + rawJson, e); + } + } - return OpenAiClient.forModel(GPT_4O_MINI).chatCompletion(request); + private static Map generateSchema(@Nonnull final Class clazz) { + final var jsonSchemaGenerator = new JsonSchemaGenerator(JACKSON); + try { + final var schema = jsonSchemaGenerator.generateSchema(clazz); + return JACKSON.convertValue(schema, new TypeReference<>() {}); + } catch (JsonMappingException e) { + throw new IllegalArgumentException("Could not generate schema for " + clazz.getName(), e); + } } /** diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/services/OpenAiServiceV2.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OpenAiServiceV2.java similarity index 57% rename from sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/services/OpenAiServiceV2.java rename to sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OpenAiServiceV2.java index 388959d68..8eef87928 100644 --- a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/services/OpenAiServiceV2.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OpenAiServiceV2.java @@ -5,18 +5,26 @@ import static com.sap.ai.sdk.foundationmodels.openai.OpenAiModel.TEXT_EMBEDDING_3_SMALL; import static com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool.TypeEnum.FUNCTION; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.JsonMappingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.module.jsonSchema.JsonSchemaGenerator; import com.sap.ai.sdk.core.AiCoreService; +import com.sap.ai.sdk.foundationmodels.openai.OpenAiAssistantMessage; import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionDelta; import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionRequest; import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionResponse; import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient; import com.sap.ai.sdk.foundationmodels.openai.OpenAiEmbeddingRequest; import com.sap.ai.sdk.foundationmodels.openai.OpenAiEmbeddingResponse; +import com.sap.ai.sdk.foundationmodels.openai.OpenAiFunctionCall; import com.sap.ai.sdk.foundationmodels.openai.OpenAiImageItem; import com.sap.ai.sdk.foundationmodels.openai.OpenAiMessage; -import com.sap.ai.sdk.foundationmodels.openai.OpenAiToolChoice; +import com.sap.ai.sdk.foundationmodels.openai.OpenAiToolCall; import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool; import com.sap.ai.sdk.foundationmodels.openai.generated.model.FunctionObject; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.stream.Stream; @@ -28,6 +36,7 @@ @Service @Slf4j public class OpenAiServiceV2 { + private static final ObjectMapper JACKSON = new ObjectMapper(); /** * Chat request to OpenAI @@ -84,30 +93,71 @@ public OpenAiChatCompletionResponse chatCompletionImage(@Nonnull final String li } /** - * Chat request to OpenAI with a tool. + * Executes a chat completion request to OpenAI with a tool that calculates the weather. * - * @param months The number of months to be inferred in the tool - * @return the assistant message response + * @param location The location to get the weather for. + * @param unit The unit of temperature to use. + * @return The assistant message response. */ @Nonnull - public OpenAiChatCompletionResponse chatCompletionTools(final int months) { + public OpenAiChatCompletionResponse chatCompletionToolExecution( + @Nonnull final String location, @Nonnull final String unit) { + + // 1. Define the function + final Map schemaMap = generateSchema(WeatherMethod.Request.class); final var function = new FunctionObject() - .name("fibonacci") - .description("Calculate the Fibonacci number for given sequence index.") - .parameters( - Map.of("type", "object", "properties", Map.of("N", Map.of("type", "integer")))); - + .name("weather") + .description("Get the weather for the given location") + .parameters(schemaMap); final var tool = new ChatCompletionTool().type(FUNCTION).function(function); - final var request = - new OpenAiChatCompletionRequest( - "A pair of rabbits is placed in a field. Each month, every pair produces one new pair, starting from the second month. How many rabbits will there be after %s months?" - .formatted(months)) - .withTools(List.of(tool)) - .withToolChoice(OpenAiToolChoice.function("fibonacci")); + final var messages = new ArrayList(); + messages.add(OpenAiMessage.user("What's the weather in %s in %s?".formatted(location, unit))); + + // Assistant will call the function + final var request = new OpenAiChatCompletionRequest(messages).withTools(List.of(tool)); + final OpenAiChatCompletionResponse response = + OpenAiClient.forModel(GPT_4O_MINI).chatCompletion(request); + + // 2. Optionally, execute the function. + final OpenAiAssistantMessage assistantMessage = response.getMessage(); + messages.add(assistantMessage); + + final OpenAiToolCall toolCall = assistantMessage.toolCalls().get(0); + if (!(toolCall instanceof OpenAiFunctionCall functionCall)) { + throw new IllegalArgumentException( + "Expected a function call, but got: %s".formatted(assistantMessage)); + } + + final WeatherMethod.Request arguments = + parseJson(functionCall.getArguments(), WeatherMethod.Request.class); + final WeatherMethod.Response weatherMethod = WeatherMethod.getCurrentWeather(arguments); - return OpenAiClient.forModel(GPT_4O_MINI).chatCompletion(request); + messages.add(OpenAiMessage.tool(weatherMethod.toString(), functionCall.getId())); + + // Send back the results, and the model will incorporate them into its final response. + return OpenAiClient.forModel(GPT_4O_MINI).chatCompletion(request.withMessages(messages)); + } + + @Nonnull + private static T parseJson(@Nonnull final String rawJson, @Nonnull final Class clazz) { + try { + return JACKSON.readValue(rawJson, clazz); + } catch (JsonProcessingException e) { + throw new IllegalArgumentException("Failed to parse tool call arguments: " + rawJson, e); + } + } + + @Nonnull + private static Map generateSchema(@Nonnull final Class clazz) { + final var jsonSchemaGenerator = new JsonSchemaGenerator(JACKSON); + try { + final var schema = jsonSchemaGenerator.generateSchema(clazz); + return JACKSON.convertValue(schema, new TypeReference<>() {}); + } catch (JsonMappingException e) { + throw new IllegalArgumentException("Could not generate schema for " + clazz.getName(), e); + } } /** diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/WeatherMethod.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/WeatherMethod.java index f7413a20d..9e1a7decd 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/WeatherMethod.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/WeatherMethod.java @@ -35,7 +35,7 @@ record Response(double temp, Unit unit) {} @Nonnull @SuppressWarnings("unused") @Tool(description = "Get the weather in location") - Response getCurrentWeather(@ToolParam @Nonnull final Request request) { + static Response getCurrentWeather(@ToolParam @Nonnull final Request request) { final int temperature = request.location.hashCode() % 30; return new Response(temperature, request.unit); } diff --git a/sample-code/spring-app/src/main/resources/static/index.html b/sample-code/spring-app/src/main/resources/static/index.html index 051fc60f5..0019ecada 100644 --- a/sample-code/spring-app/src/main/resources/static/index.html +++ b/sample-code/spring-app/src/main/resources/static/index.html @@ -569,12 +569,12 @@
OpenAI
  • -
    - Chat request to OpenAI with a tool. + Chat request to OpenAI with an executed tool call.
  • diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OpenAiTest.java b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OpenAiTest.java index 1e2e1c3bd..7a029ebdd 100644 --- a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OpenAiTest.java +++ b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OpenAiTest.java @@ -73,16 +73,6 @@ void streamChatCompletion() { assertThat(totalOutput.getChoices().get(0).getContentFilterResults()).isNotNull(); } - @Test - void chatCompletionTools() { - final var completion = service.chatCompletionTools(12); - - final var message = completion.getChoices().get(0).getMessage(); - assertThat(message.getRole()).isEqualTo("assistant"); - assertThat(message.getToolCalls()).isNotNull(); - assertThat(message.getToolCalls().get(0).getFunction().getName()).isEqualTo("fibonacci"); - } - @Test void embedding() { final var embedding = service.embedding("Hello world"); @@ -101,4 +91,13 @@ void chatCompletionWithResource() { assertThat(message.getRole()).isEqualTo("assistant"); assertThat(message.getContent()).isNotEmpty(); } + + @Test + void chatCompletionToolExecution() { + final var completion = service.chatCompletionToolExecution("Dubai", "°C"); + + String content = completion.getContent(); + + assertThat(content).contains("°C"); + } } diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OpenAiV2Test.java b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OpenAiV2Test.java index 2fa3a86ee..9300561b6 100644 --- a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OpenAiV2Test.java +++ b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OpenAiV2Test.java @@ -71,16 +71,6 @@ void streamChatCompletion() { assertThat(usageRef.get().getCompletionTokens()).isGreaterThan(0); } - @Test - void chatCompletionTools() { - final var completion = service.chatCompletionTools(12); - - final var message = completion.getChoice().getMessage(); - assertThat(message.getRole()).isEqualTo(ASSISTANT); - assertThat(message.getToolCalls()).isNotNull(); - assertThat(message.getToolCalls().get(0).getFunction().getName()).isEqualTo("fibonacci"); - } - @Test void embedding() { final var embedding = service.embedding("Hello world"); @@ -102,4 +92,13 @@ void chatCompletionWithResource() { assertThat(completion.getChoice().getMessage().getRole()).isEqualTo(ASSISTANT); assertThat(completion.getContent()).isNotEmpty(); } + + @Test + void chatCompletionToolExecution() { + final var completion = service.chatCompletionToolExecution("Dubai", "°C"); + + String content = completion.getContent(); + + assertThat(content).contains("°C"); + } }