diff --git a/docs/guides/SPRING_AI_INTEGRATION.md b/docs/guides/SPRING_AI_INTEGRATION.md index 1da5722c0..4fde53798 100644 --- a/docs/guides/SPRING_AI_INTEGRATION.md +++ b/docs/guides/SPRING_AI_INTEGRATION.md @@ -5,6 +5,8 @@ - [Introduction](#introduction) - [Orchestration Chat Completion](#orchestration-chat-completion) - [Orchestration Masking](#orchestration-masking) +- [Stream chat completion](#stream-chat-completion) +- [Tool Calling](#tool-calling) ## Introduction @@ -32,7 +34,7 @@ First, add the Spring AI dependency to your `pom.xml`: :::note Spring AI Milestone Version Note that currently no stable version of Spring AI exists just yet. -The AI SDK currently uses the [M5 milestone](https://spring.io/blog/2024/12/23/spring-ai-1-0-0-m5-released). +The AI SDK currently uses the [M6 milestone](https://spring.io/blog/2025/02/14/spring-ai-1-0-0-m6-released). Please be aware that future versions of the AI SDK may increase the Spring AI version. ::: @@ -99,3 +101,40 @@ Flux responseFlux = _Note: A Spring endpoint can return `Flux` instead of `ResponseEntity`._ Please find [an example in our Spring Boot application](../../sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/SpringAiOrchestrationService.java). 
+ +## Tool Calling + +First define a function that will be called by the LLM: + +```java +class WeatherMethod { + enum Unit {C,F} + record Request(String location, Unit unit) {} + record Response(double temp, Unit unit) {} + + @Tool(description = "Get the weather in location") + Response getCurrentWeather(@ToolParam Request request) { + int temperature = request.location.hashCode() % 30; + return new Response(temperature, request.unit); + } +} +``` + +Then add your tool to the options: + +```java +ChatModel client = new OrchestrationChatModel(); +OrchestrationModuleConfig config = new OrchestrationModuleConfig().withLlmConfig(GPT_35_TURBO); +OrchestrationChatOptions options = new OrchestrationChatOptions(config); + +options.setToolCallbacks(List.of(ToolCallbacks.from(new WeatherMethod()))); + +options.setInternalToolExecutionEnabled(false);// tool execution is not yet available in orchestration + +Prompt prompt = new Prompt("What is the weather in Potsdam and in Toulouse?", options); + +ChatResponse response = client.call(prompt); +``` + +Please find [an example in our Spring Boot application](../../sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/SpringAiOrchestrationService.java). + diff --git a/docs/release-notes/release_notes.md b/docs/release-notes/release_notes.md index 27fd11ee5..011c003a3 100644 --- a/docs/release-notes/release_notes.md +++ b/docs/release-notes/release_notes.md @@ -12,7 +12,7 @@ ### ✨ New Functionality -- +- [Add Spring AI tool calling](../guides/SPRING_AI_INTEGRATION.md#tool-calling). 
### 📈 Improvements diff --git a/orchestration/pom.xml b/orchestration/pom.xml index d60a1ad14..fc077820d 100644 --- a/orchestration/pom.xml +++ b/orchestration/pom.xml @@ -31,11 +31,11 @@ ${project.basedir}/../ - 80% + 81% 92% 93% - 71% - 95% + 74% + 92% 100% diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/AssistantMessage.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/AssistantMessage.java index f9cb82d08..6d157a561 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/AssistantMessage.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/AssistantMessage.java @@ -1,14 +1,20 @@ package com.sap.ai.sdk.orchestration; import com.google.common.annotations.Beta; +import com.sap.ai.sdk.orchestration.model.ChatMessage; +import com.sap.ai.sdk.orchestration.model.ResponseMessageToolCall; +import com.sap.ai.sdk.orchestration.model.SingleChatMessage; import java.util.List; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import lombok.Getter; import lombok.Value; import lombok.experimental.Accessors; +import lombok.val; /** Represents a chat message as 'assistant' to the orchestration service. */ @Value +@Getter @Accessors(fluent = true) public class AssistantMessage implements Message { @@ -20,6 +26,9 @@ public class AssistantMessage implements Message { @Getter(onMethod_ = @Beta) MessageContent content; + /** Tool call if there is any. */ + @Nullable List toolCalls; + /** * Creates a new assistant message with the given single message. * @@ -27,5 +36,28 @@ public class AssistantMessage implements Message { */ public AssistantMessage(@Nonnull final String singleMessage) { content = new MessageContent(List.of(new TextItem(singleMessage))); + toolCalls = null; + } + + /** + * Creates a new assistant message with the given tool calls. 
+ * + * @param toolCalls list of tool call objects + */ + public AssistantMessage(@Nonnull final List toolCalls) { + content = new MessageContent(List.of()); + this.toolCalls = toolCalls; + } + + @Nonnull + @Override + public ChatMessage createChatMessage() { + if (toolCalls() != null) { + // content shouldn't be required for tool calls 🤷 + val message = SingleChatMessage.create().role(role).content(""); + message.setCustomField("tool_calls", toolCalls); + return message; + } + return Message.super.createChatMessage(); } } diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformer.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformer.java index 6c37731f4..8ef1009bb 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformer.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformer.java @@ -8,7 +8,6 @@ import com.sap.ai.sdk.orchestration.model.TemplatingModuleConfig; import io.vavr.control.Option; import java.util.ArrayList; -import java.util.List; import javax.annotation.Nonnull; import javax.annotation.Nullable; import lombok.AccessLevel; @@ -40,7 +39,7 @@ static CompletionPostRequest toCompletionPostRequest( @Nonnull static TemplatingModuleConfig toTemplateModuleConfig( - @Nonnull final OrchestrationPrompt prompt, @Nullable final TemplatingModuleConfig template) { + @Nonnull final OrchestrationPrompt prompt, @Nullable final TemplatingModuleConfig config) { /* * Currently, we have to merge the prompt into the template configuration. * This works around the limitation that the template config is required. @@ -48,8 +47,9 @@ static TemplatingModuleConfig toTemplateModuleConfig( * In this case, the request will fail, since the templating module will try to resolve the parameter. * To be fixed with https://github.tools.sap/AI/llm-orchestration/issues/662 */ - val messages = template instanceof Template t ? 
t.getTemplate() : List.of(); - val responseFormat = template instanceof Template t ? t.getResponseFormat() : null; + val template = config instanceof Template t ? t : Template.create().template(); + val messages = template.getTemplate(); + val responseFormat = template.getResponseFormat(); val messagesWithPrompt = new ArrayList<>(messages); messagesWithPrompt.addAll( prompt.getMessages().stream().map(Message::createChatMessage).toList()); @@ -57,7 +57,10 @@ static TemplatingModuleConfig toTemplateModuleConfig( throw new IllegalStateException( "A prompt is required. Pass at least one message or configure a template with messages or a template reference."); } - return Template.create().template(messagesWithPrompt).responseFormat(responseFormat); + return Template.create() + .template(messagesWithPrompt) + .tools(template.getTools()) + .responseFormat(responseFormat); } @Nonnull diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/Message.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/Message.java index 30a0b53ad..dcc303ca7 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/Message.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/Message.java @@ -13,7 +13,7 @@ import javax.annotation.Nonnull; /** Interface representing convenience wrappers of chat message to the orchestration service. */ -public sealed interface Message permits UserMessage, AssistantMessage, SystemMessage { +public sealed interface Message permits AssistantMessage, SystemMessage, ToolMessage, UserMessage { /** * A convenience method to create a user message from a string. 
diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/ToolMessage.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/ToolMessage.java new file mode 100644 index 000000000..fa6d40634 --- /dev/null +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/ToolMessage.java @@ -0,0 +1,39 @@ +package com.sap.ai.sdk.orchestration; + +import com.sap.ai.sdk.orchestration.model.ChatMessage; +import com.sap.ai.sdk.orchestration.model.SingleChatMessage; +import java.util.List; +import javax.annotation.Nonnull; +import lombok.Value; +import lombok.experimental.Accessors; + +/** + * Represents a chat message as 'tool' to the orchestration service. + * + * @since 1.4.0 + */ +@Value +@Accessors(fluent = true) +public class ToolMessage implements Message { + + /** The role of this message, always "tool". */ + @Nonnull String role = "tool"; + + @Nonnull String id; + + @Nonnull String content; + + @Nonnull + @Override + public MessageContent content() { + return new MessageContent(List.of(new TextItem(content))); + } + + @Nonnull + @Override + public ChatMessage createChatMessage() { + final SingleChatMessage message = SingleChatMessage.create().role(role()).content(content); + message.setCustomField("tool_call_id", id); + return message; + } +} diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatModel.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatModel.java index d23c3266f..b972333a0 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatModel.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatModel.java @@ -1,6 +1,7 @@ package com.sap.ai.sdk.orchestration.spring; import static com.sap.ai.sdk.orchestration.OrchestrationClient.toCompletionPostRequest; +import static com.sap.ai.sdk.orchestration.model.ResponseMessageToolCall.TypeEnum.FUNCTION; import com.google.common.annotations.Beta; import 
com.sap.ai.sdk.orchestration.AssistantMessage; @@ -8,18 +9,24 @@ import com.sap.ai.sdk.orchestration.OrchestrationClient; import com.sap.ai.sdk.orchestration.OrchestrationPrompt; import com.sap.ai.sdk.orchestration.SystemMessage; +import com.sap.ai.sdk.orchestration.ToolMessage; import com.sap.ai.sdk.orchestration.UserMessage; +import com.sap.ai.sdk.orchestration.model.ResponseMessageToolCall; +import com.sap.ai.sdk.orchestration.model.ResponseMessageToolCallFunction; import java.util.List; import java.util.Map; import java.util.function.Function; import javax.annotation.Nonnull; -import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import lombok.val; +import org.springframework.ai.chat.messages.AssistantMessage.ToolCall; import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.ToolResponseMessage; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.tool.DefaultToolCallingManager; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import reactor.core.publisher.Flux; /** @@ -29,9 +36,12 @@ */ @Beta @Slf4j -@RequiredArgsConstructor public class OrchestrationChatModel implements ChatModel { - @Nonnull private OrchestrationClient client; + @Nonnull private final OrchestrationClient client; + + @Nonnull + private final DefaultToolCallingManager toolCallingManager = + DefaultToolCallingManager.builder().build(); /** * Default constructor. @@ -39,18 +49,35 @@ public class OrchestrationChatModel implements ChatModel { * @since 1.2.0 */ public OrchestrationChatModel() { - this.client = new OrchestrationClient(); + this(new OrchestrationClient()); + } + + /** + * Constructor with a custom client. 
+ * + * @since 1.2.0 + */ + public OrchestrationChatModel(@Nonnull final OrchestrationClient client) { + this.client = client; } @Nonnull @Override public ChatResponse call(@Nonnull final Prompt prompt) { - if (prompt.getOptions() instanceof OrchestrationChatOptions options) { val orchestrationPrompt = toOrchestrationPrompt(prompt); - val response = client.chatCompletion(orchestrationPrompt, options.getConfig()); - return new OrchestrationSpringChatResponse(response); + val response = + new OrchestrationSpringChatResponse( + client.chatCompletion(orchestrationPrompt, options.getConfig())); + + if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) + && response.hasToolCalls()) { + val toolExecutionResult = toolCallingManager.executeToolCalls(prompt, response); + // Send the tool execution result back to the model. + return call(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions())); + } + return response; } throw new IllegalArgumentException( "Please add OrchestrationChatOptions to the Prompt: new Prompt(\"message\", new OrchestrationChatOptions(config))"); @@ -92,18 +119,48 @@ private OrchestrationPrompt toOrchestrationPrompt(@Nonnull final Prompt prompt) @Nonnull private static com.sap.ai.sdk.orchestration.Message[] toOrchestrationMessages( @Nonnull final List messages) { - final Function mapper = + final Function> mapper = msg -> switch (msg.getMessageType()) { case SYSTEM: - yield new SystemMessage(msg.getText()); + yield List.of(new SystemMessage(msg.getText())); case USER: - yield new UserMessage(msg.getText()); + yield List.of(new UserMessage(msg.getText())); case ASSISTANT: - yield new AssistantMessage(msg.getText()); + val springToolCalls = + ((org.springframework.ai.chat.messages.AssistantMessage) msg).getToolCalls(); + // An empty tool call list means a plain text message; keep its text instead of dropping it. + if (springToolCalls != null && !springToolCalls.isEmpty()) { + final List sdkToolCalls = + springToolCalls.stream() + .map(OrchestrationChatModel::toOrchestrationToolCall) + .toList(); + yield List.of(new 
AssistantMessage(sdkToolCalls)); + } + yield List.of(new AssistantMessage(msg.getText())); case TOOL: - throw new IllegalArgumentException("Tool messages are not supported"); + val toolResponses = ((ToolResponseMessage) msg).getResponses(); + yield toolResponses.stream() + .map( + r -> + (com.sap.ai.sdk.orchestration.Message) + new ToolMessage(r.id(), r.responseData())) + .toList(); }; - return messages.stream().map(mapper).toArray(com.sap.ai.sdk.orchestration.Message[]::new); + return messages.stream() + .map(mapper) + .flatMap(List::stream) + .toArray(com.sap.ai.sdk.orchestration.Message[]::new); + } + + @Nonnull + private static ResponseMessageToolCall toOrchestrationToolCall(@Nonnull final ToolCall toolCall) { + return ResponseMessageToolCall.create() + .id(toolCall.id()) + .type(FUNCTION) + .function( + ResponseMessageToolCallFunction.create() + .name(toolCall.name()) + .arguments(toolCall.arguments())); } } diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatOptions.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatOptions.java index 5d4a2e59c..8460d9409 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatOptions.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatOptions.java @@ -10,18 +10,25 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.annotations.Beta; import com.sap.ai.sdk.orchestration.OrchestrationModuleConfig; +import com.sap.ai.sdk.orchestration.model.ChatCompletionTool; +import com.sap.ai.sdk.orchestration.model.ChatCompletionTool.TypeEnum; +import com.sap.ai.sdk.orchestration.model.FunctionObject; import com.sap.ai.sdk.orchestration.model.LLMModuleConfig; +import com.sap.ai.sdk.orchestration.model.Template; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import javax.annotation.Nonnull; import 
javax.annotation.Nullable; import lombok.AccessLevel; import lombok.Data; import lombok.Getter; -import lombok.Setter; import lombok.val; import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.tool.ToolCallingChatOptions; /** * Configuration to be used for orchestration requests. @@ -30,16 +37,20 @@ */ @Beta @Data -@Getter(AccessLevel.NONE) -@Setter(AccessLevel.NONE) -public class OrchestrationChatOptions implements ChatOptions { +public class OrchestrationChatOptions implements ToolCallingChatOptions { private static final ObjectMapper JACKSON = getOrchestrationObjectMapper(); - @Getter(AccessLevel.PUBLIC) - @Setter(AccessLevel.PUBLIC) - @Nonnull - OrchestrationModuleConfig config; + @Nonnull private OrchestrationModuleConfig config; + + private List functionCallbacks; + + @Getter(AccessLevel.NONE) + private Boolean internalToolExecutionEnabled; + + private Set toolNames; + + private Map toolContext; /** * Returns the model to use for the chat. @@ -160,13 +171,10 @@ public T copy() { return (T) new OrchestrationChatOptions(copyConfig); } - @SuppressWarnings("unchecked") // getModelParams() returns Object, it should return Map + @SuppressWarnings("unchecked") @Nullable private T getLlmConfigParam(@Nonnull final String param) { - if (getLlmConfigNonNull().getModelParams() instanceof Map) { - return ((Map) getLlmConfigNonNull().getModelParams()).get(param); - } - return null; + return ((Map) getLlmConfigNonNull().getModelParams()).get(param); } @Nonnull @@ -175,4 +183,66 @@ private LLMModuleConfig getLlmConfigNonNull() { config.getLlmConfig(), "LLM config is not set. 
Please set it: new OrchestrationChatOptions(new OrchestrationModuleConfig().withLlmConfig(...))"); } + + @Nonnull + @Override + public List getToolCallbacks() { + return functionCallbacks; + } + + @Override + @Deprecated + public void setFunctionCallbacks(@Nonnull final List toolCallbacks) { + setToolCallbacks(toolCallbacks); + } + + @Override + public void setToolCallbacks(@Nonnull final List toolCallbacks) { + this.functionCallbacks = toolCallbacks; + final Template template = + Objects.requireNonNullElse( + (Template) config.getTemplateConfig(), Template.create().template()); + val tools = toolCallbacks.stream().map(OrchestrationChatOptions::toOrchestrationTool).toList(); + config = config.withTemplateConfig(template.tools(tools)); + } + + private static ChatCompletionTool toOrchestrationTool( + @Nonnull final FunctionCallback functionCallback) { + return ChatCompletionTool.create() + .type(TypeEnum.FUNCTION) + .function( + FunctionObject.create() + .name(functionCallback.getName()) + .description(functionCallback.getDescription()) + .parameters(ModelOptionsUtils.jsonToMap(functionCallback.getInputTypeSchema()))); + } + + @Override + @Nullable + public Boolean isInternalToolExecutionEnabled() { + return this.internalToolExecutionEnabled; + } + + @Nonnull + @Override + public Set getFunctions() { + return Set.of(); + } + + @Override + public void setFunctions(@Nonnull final Set functions) { + // val template = + // Objects.requireNonNullElse( + // (Template) config.getTemplateConfig(), Template.create().template()); + // val tools = + // functionNames.stream() + // .map( + // functionName -> + // ChatCompletionTool.create() + // .type(TypeEnum.FUNCTION) + // .function(FunctionObject.create().name(functionName))) + // .toList(); + // config = config.withTemplateConfig(template.tools(tools)); + throw new UnsupportedOperationException("Not implemented yet"); + } } diff --git 
a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationSpringChatDelta.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationSpringChatDelta.java index 6ccf94727..9c3f4b0a5 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationSpringChatDelta.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationSpringChatDelta.java @@ -76,8 +76,6 @@ static ChatResponseMetadata toChatResponseMetadata( @Nonnull private static DefaultUsage toDefaultUsage(@Nonnull final TokenUsage usage) { return new DefaultUsage( - usage.getPromptTokens().longValue(), - usage.getCompletionTokens().longValue(), - usage.getTotalTokens().longValue()); + usage.getPromptTokens(), usage.getCompletionTokens(), usage.getTotalTokens()); } } diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationSpringChatResponse.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationSpringChatResponse.java index afedbb278..66636c471 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationSpringChatResponse.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationSpringChatResponse.java @@ -6,11 +6,13 @@ import com.sap.ai.sdk.orchestration.model.LLMModuleResultSynchronous; import com.sap.ai.sdk.orchestration.model.TokenUsage; import java.util.List; +import java.util.Map; import javax.annotation.Nonnull; import lombok.EqualsAndHashCode; import lombok.Value; import lombok.val; import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.AssistantMessage.ToolCall; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.metadata.DefaultUsage; @@ -52,7 +54,18 @@ static Generation toGeneration(@Nonnull final LLMChoice choice) { 
if (!choice.getLogprobs().isEmpty()) { metadata.metadata("logprobs", choice.getLogprobs()); } - val message = new AssistantMessage(choice.getMessage().getContent()); + val toolCalls = + choice.getMessage().getToolCalls().stream() + .map( + toolCall -> + new ToolCall( + toolCall.getId(), + toolCall.getType().getValue(), + toolCall.getFunction().getName(), + toolCall.getFunction().getArguments())) + .toList(); + + val message = new AssistantMessage(choice.getMessage().getContent(), Map.of(), toolCalls); return new Generation(message, metadata.build()); } @@ -74,8 +87,6 @@ static ChatResponseMetadata toChatResponseMetadata( @Nonnull private static DefaultUsage toDefaultUsage(@Nonnull final TokenUsage usage) { return new DefaultUsage( - usage.getPromptTokens().longValue(), - usage.getCompletionTokens().longValue(), - usage.getTotalTokens().longValue()); + usage.getPromptTokens(), usage.getCompletionTokens(), usage.getTotalTokens()); } } diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/spring/MockWeatherService.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/spring/MockWeatherService.java new file mode 100644 index 000000000..46c79bb3c --- /dev/null +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/spring/MockWeatherService.java @@ -0,0 +1,44 @@ +package com.sap.ai.sdk.orchestration.spring; + +import java.util.function.Function; +import javax.annotation.Nonnull; + +/** Function for tool calls in Spring AI */ +public class MockWeatherService + implements Function { + + /** Unit of temperature */ + public enum Unit { + /** Celsius */ + C, + /** Fahrenheit */ + F + } + + /** + * Request for the weather + * + * @param location the city + * @param unit the unit of temperature + */ + public record Request(String location, Unit unit) {} + + /** + * Response for the weather + * + * @param temp the temperature + * @param unit the unit of temperature + */ + public record Response(double temp, Unit unit) {} + + /** + * Apply the 
function + * + * @param request the request + * @return the response + */ + @Nonnull + public Response apply(@Nonnull Request request) { + return new Response(30.0, Unit.C); + } +} diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatDeltaTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatDeltaTest.java index e9b6ff4bb..b0580aaaf 100644 --- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatDeltaTest.java +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatDeltaTest.java @@ -26,7 +26,7 @@ void testToGeneration() { Generation generation = OrchestrationSpringChatDelta.toGeneration(choice); - assertThat(generation.getOutput().getContent()).isEqualTo("Hello, world!"); + assertThat(generation.getOutput().getText()).isEqualTo("Hello, world!"); assertThat(generation.getMetadata().getFinishReason()).isEqualTo("stop"); assertThat(generation.getMetadata().get("index")).isEqualTo(0); } @@ -52,7 +52,7 @@ void testToChatResponseMetadata() { var usage = metadata.getUsage(); assertThat(usage.getPromptTokens()).isEqualTo(10L); - assertThat(usage.getGenerationTokens()).isEqualTo(20L); + assertThat(usage.getCompletionTokens()).isEqualTo(20L); assertThat(usage.getTotalTokens()).isEqualTo(30L); // delta without token usage diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatModelTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatModelTest.java index 78491d93e..9d5f08e89 100644 --- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatModelTest.java +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatModelTest.java @@ -1,9 +1,14 @@ package com.sap.ai.sdk.orchestration.spring; import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static 
com.github.tomakehurst.wiremock.client.WireMock.anyUrl; +import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson; import static com.github.tomakehurst.wiremock.client.WireMock.post; +import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor; import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo; +import static com.github.tomakehurst.wiremock.client.WireMock.verify; +import static com.github.tomakehurst.wiremock.stubbing.Scenario.STARTED; import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.GPT_35_TURBO_16K; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -22,8 +27,10 @@ import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination; import java.io.IOException; import java.io.InputStream; +import java.util.List; import java.util.Objects; import java.util.function.Function; +import lombok.val; import org.apache.hc.client5.http.classic.HttpClient; import org.apache.hc.core5.http.ContentType; import org.apache.hc.core5.http.io.entity.InputStreamEntity; @@ -32,8 +39,10 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.Mockito; +import org.springframework.ai.chat.messages.AssistantMessage.ToolCall; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.tool.ToolCallbacks; import reactor.core.publisher.Flux; @WireMockTest @@ -43,7 +52,7 @@ public class OrchestrationChatModelTest { filename -> Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream(filename)); private static OrchestrationChatModel client; - private static OrchestrationModuleConfig config; + private static OrchestrationChatOptions defaultOptions; private static Prompt prompt; @BeforeEach @@ -51,10 +60,10 @@ void 
setup(WireMockRuntimeInfo server) { final DefaultHttpDestination destination = DefaultHttpDestination.builder(server.getHttpBaseUrl()).build(); client = new OrchestrationChatModel(new OrchestrationClient(destination)); - config = new OrchestrationModuleConfig().withLlmConfig(GPT_35_TURBO_16K); - prompt = - new Prompt( - "Hello World! Why is this phrase so famous?", new OrchestrationChatOptions(config)); + defaultOptions = + new OrchestrationChatOptions( + new OrchestrationModuleConfig().withLlmConfig(GPT_35_TURBO_16K)); + prompt = new Prompt("Hello World! Why is this phrase so famous?", defaultOptions); ApacheHttpClient5Accessor.setHttpClientCache(ApacheHttpClient5Cache.DISABLED); } @@ -72,10 +81,10 @@ void testCompletion() { aResponse() .withBodyFile("templatingResponse.json") .withHeader("Content-Type", "application/json"))); - final var result = client.call(prompt); + val result = client.call(prompt); assertThat(result).isNotNull(); - assertThat(result.getResult().getOutput().getContent()).isNotEmpty(); + assertThat(result.getResult().getOutput().getText()).isNotEmpty(); } @Test @@ -103,14 +112,14 @@ void testThrowsOnMissingLlmConfig() { @Test void testStreamCompletion() throws IOException { - try (var inputStream = spy(fileLoader.apply("streamChatCompletion.txt"))) { + try (val inputStream = spy(fileLoader.apply("streamChatCompletion.txt"))) { - final var httpClient = mock(HttpClient.class); + val httpClient = mock(HttpClient.class); ApacheHttpClient5Accessor.setHttpClientFactory(destination -> httpClient); // Create a mock response - final var mockResponse = new BasicClassicHttpResponse(200, "OK"); - final var inputStreamEntity = new InputStreamEntity(inputStream, ContentType.TEXT_PLAIN); + val mockResponse = new BasicClassicHttpResponse(200, "OK"); + val inputStreamEntity = new InputStreamEntity(inputStream, ContentType.TEXT_PLAIN); mockResponse.setEntity(inputStreamEntity); mockResponse.setHeader("Content-Type", "text/event-flux"); @@ -118,13 +127,13 @@ void 
testStreamCompletion() throws IOException { doReturn(mockResponse).when(httpClient).executeOpen(any(), any(), any()); Flux flux = client.stream(prompt); - var deltaList = flux.toStream().toList(); + val deltaList = flux.toStream().toList(); assertThat(deltaList).hasSize(3); // the first delta doesn't have any content - assertThat(deltaList.get(0).getResult().getOutput().getContent()).isEqualTo(""); - assertThat(deltaList.get(1).getResult().getOutput().getContent()).isEqualTo("Sure"); - assertThat(deltaList.get(2).getResult().getOutput().getContent()).isEqualTo("!"); + assertThat(deltaList.get(0).getResult().getOutput().getText()).isEqualTo(""); + assertThat(deltaList.get(1).getResult().getOutput().getText()).isEqualTo("Sure"); + assertThat(deltaList.get(2).getResult().getOutput().getText()).isEqualTo("!"); assertThat(deltaList.get(0).getResult().getMetadata().getFinishReason()).isEqualTo(""); assertThat(deltaList.get(1).getResult().getMetadata().getFinishReason()).isEqualTo(""); @@ -133,4 +142,76 @@ void testStreamCompletion() throws IOException { Mockito.verify(inputStream, times(1)).close(); } } + + @Test + void testToolCallsWithoutExecution() throws IOException { + stubFor( + post(urlPathEqualTo("/completion")) + .willReturn( + aResponse() + .withBodyFile("toolCallsResponse.json") + .withHeader("Content-Type", "application/json"))); + + defaultOptions.setToolCallbacks(List.of(ToolCallbacks.from(new WeatherMethod()))); + defaultOptions.setInternalToolExecutionEnabled(false); + val prompt = new Prompt("What is the weather in Potsdam and in Toulouse?", defaultOptions); + val result = client.call(prompt); + + List toolCalls = result.getResult().getOutput().getToolCalls(); + assertThat(toolCalls).hasSize(2); + ToolCall toolCall1 = toolCalls.get(0); + ToolCall toolCall2 = toolCalls.get(1); + assertThat(toolCall1.type()).isEqualTo("function"); + assertThat(toolCall2.type()).isEqualTo("function"); + assertThat(toolCall1.name()).isEqualTo("getCurrentWeather"); + 
assertThat(toolCall2.name()).isEqualTo("getCurrentWeather"); + assertThat(toolCall1.arguments()) + .isEqualTo("{\"arg0\": {\"location\": \"Potsdam\", \"unit\": \"C\"}}"); + assertThat(toolCall2.arguments()) + .isEqualTo("{\"arg0\": {\"location\": \"Toulouse\", \"unit\": \"C\"}}"); + + try (var request1InputStream = fileLoader.apply("toolCallsRequest.json")) { + final String request1 = new String(request1InputStream.readAllBytes()); + verify(postRequestedFor(anyUrl()).withRequestBody(equalToJson(request1))); + } + } + + @Test + void testToolCallsWithExecution() throws IOException { + // https://platform.openai.com/docs/guides/function-calling + stubFor( + post(urlPathEqualTo("/completion")) + .inScenario("Tool Calls") + .whenScenarioStateIs(STARTED) + .willReturn( + aResponse() + .withBodyFile("toolCallsResponse.json") + .withHeader("Content-Type", "application/json")) + .willSetStateTo("Second Call")); + + stubFor( + post(urlPathEqualTo("/completion")) + .inScenario("Tool Calls") + .whenScenarioStateIs("Second Call") + .willReturn( + aResponse() + .withBodyFile("toolCallsResponse2.json") + .withHeader("Content-Type", "application/json"))); + + defaultOptions.setToolCallbacks(List.of(ToolCallbacks.from(new WeatherMethod()))); + val prompt = new Prompt("What is the weather in Potsdam and in Toulouse?", defaultOptions); + val result = client.call(prompt); + + assertThat(result.getResult().getOutput().getText()) + .isEqualTo("The current temperature in Potsdam is 30°C and in Toulouse 30°C."); + + try (var request1InputStream = fileLoader.apply("toolCallsRequest.json")) { + try (var request2InputStream = fileLoader.apply("toolCallsRequest2.json")) { + final String request1 = new String(request1InputStream.readAllBytes()); + final String request2 = new String(request2InputStream.readAllBytes()); + verify(postRequestedFor(anyUrl()).withRequestBody(equalToJson(request1))); + verify(postRequestedFor(anyUrl()).withRequestBody(equalToJson(request2))); + } + } + } } diff --git 
a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatResponseTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatResponseTest.java index c03c2a50c..87d0e7403 100644 --- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatResponseTest.java +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatResponseTest.java @@ -22,7 +22,7 @@ void testToGeneration() { Generation generation = OrchestrationSpringChatResponse.toGeneration(choice); - assertThat(generation.getOutput().getContent()).isEqualTo("Hello, world!"); + assertThat(generation.getOutput().getText()).isEqualTo("Hello, world!"); assertThat(generation.getMetadata().getFinishReason()).isEqualTo("stop"); assertThat(generation.getMetadata().get("index")).isEqualTo(0); } @@ -48,7 +48,7 @@ void testToChatResponseMetadata() { var usage = metadata.getUsage(); assertThat(usage.getPromptTokens()).isEqualTo(10L); - assertThat(usage.getGenerationTokens()).isEqualTo(20L); + assertThat(usage.getCompletionTokens()).isEqualTo(20L); assertThat(usage.getTotalTokens()).isEqualTo(30L); } } diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/spring/WeatherMethod.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/spring/WeatherMethod.java new file mode 100644 index 000000000..e2444bef0 --- /dev/null +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/spring/WeatherMethod.java @@ -0,0 +1,41 @@ +package com.sap.ai.sdk.orchestration.spring; + +import javax.annotation.Nonnull; +import org.springframework.ai.tool.annotation.Tool; +import org.springframework.ai.tool.annotation.ToolParam; + +public class WeatherMethod { + + /** Unit of temperature */ + public enum Unit { + /** Celsius */ + @SuppressWarnings("unused") + C, + /** Fahrenheit */ + @SuppressWarnings("unused") + F + } + + /** + * Request for the weather + * + * @param location the city + * @param unit the 
unit of temperature + */ + public record Request(String location, Unit unit) {} + + /** + * Response for the weather + * + * @param temp the temperature + * @param unit the unit of temperature + */ + public record Response(double temp, Unit unit) {} + + @Nonnull + @SuppressWarnings("unused") + @Tool(description = "Get the weather in location") + Response getCurrentWeather(@ToolParam @Nonnull Request request) { + return new Response(30, request.unit); + } +} diff --git a/orchestration/src/test/resources/__files/toolCallsResponse.json b/orchestration/src/test/resources/__files/toolCallsResponse.json new file mode 100644 index 000000000..5a0f30f0e --- /dev/null +++ b/orchestration/src/test/resources/__files/toolCallsResponse.json @@ -0,0 +1,91 @@ +{ + "request_id": "935d602a-021d-4da1-a8d9-b4bae42f5720", + "module_results": { + "templating": [ + { + "role": "user", + "content": "What is the weather in Potsdam and in Toulouse?" + } + ], + "llm": { + "id": "chatcmpl-AxUvsJ4kwEGSGFp6ha89MBWdU9lCW", + "object": "chat.completion", + "created": 1738743620, + "model": "gpt-3.5-turbo-1106", + "system_fingerprint": "fp_0165350fbb", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_LOyP7EVdeqFlGEmVzmPdCVey", + "type": "function", + "function": { + "name": "getCurrentWeather", + "arguments": "{\"arg0\": {\"location\": \"Potsdam\", \"unit\": \"C\"}}" + } + }, + { + "id": "call_bwFjnXCfCO4N3f0bMtFMlNSg", + "type": "function", + "function": { + "name": "getCurrentWeather", + "arguments": "{\"arg0\": {\"location\": \"Toulouse\", \"unit\": \"C\"}}" + } + } + ] + }, + "finish_reason": "tool_calls" + } + ], + "usage": { + "completion_tokens": 54, + "prompt_tokens": 67, + "total_tokens": 121 + } + } + }, + "orchestration_result": { + "id": "chatcmpl-AxUvsJ4kwEGSGFp6ha89MBWdU9lCW", + "object": "chat.completion", + "created": 1738743620, + "model": "gpt-3.5-turbo-1106", + "system_fingerprint": "fp_0165350fbb", + 
"choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_LOyP7EVdeqFlGEmVzmPdCVey", + "type": "function", + "function": { + "name": "getCurrentWeather", + "arguments": "{\"arg0\": {\"location\": \"Potsdam\", \"unit\": \"C\"}}" + } + }, + { + "id": "call_bwFjnXCfCO4N3f0bMtFMlNSg", + "type": "function", + "function": { + "name": "getCurrentWeather", + "arguments": "{\"arg0\": {\"location\": \"Toulouse\", \"unit\": \"C\"}}" + } + } + ] + }, + "finish_reason": "tool_calls" + } + ], + "usage": { + "completion_tokens": 54, + "prompt_tokens": 67, + "total_tokens": 121 + } + } +} diff --git a/orchestration/src/test/resources/__files/toolCallsResponse2.json b/orchestration/src/test/resources/__files/toolCallsResponse2.json new file mode 100644 index 000000000..53037d1c9 --- /dev/null +++ b/orchestration/src/test/resources/__files/toolCallsResponse2.json @@ -0,0 +1,87 @@ +{ + "request_id": "935d602a-021d-4da1-a8d9-b4bae42f5720", + "module_results": { + "templating": [ + { + "role": "user", + "content": "What is the weather in Potsdam and in Toulouse?" 
+ }, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_LOyP7EVdeqFlGEmVzmPdCVey", + "type": "function", + "function": { + "name": "getCurrentWeather", + "arguments": "{\"arg0\": {\"location\": \"Potsdam\", \"unit\": \"C\"}}" + } + }, + { + "id": "call_bwFjnXCfCO4N3f0bMtFMlNSg", + "type": "function", + "function": { + "name": "getCurrentWeather", + "arguments": "{\"arg0\": {\"location\": \"Toulouse\", \"unit\": \"C\"}}" + } + } + ] + }, + { + "role": "tool", + "content": "{\"temp\":30.0,\"unit\":\"C\"}", + "tool_call_id": "call_LOyP7EVdeqFlGEmVzmPdCVey" + }, + { + "role": "tool", + "content": "{\"temp\":30.0,\"unit\":\"C\"}", + "tool_call_id": "call_bwFjnXCfCO4N3f0bMtFMlNSg" + } + ], + "llm": { + "id": "chatcmpl-AxUvsJ4kwEGSGFp6ha89MBWdU9lCW", + "object": "chat.completion", + "created": 1738743620, + "model": "gpt-3.5-turbo-1106", + "system_fingerprint": "fp_0165350fbb", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "The current temperature in Potsdam is 30°C and in Toulouse 30°C." + }, + "finish_reason": "stop" + } + ], + "usage": { + "completion_tokens": 54, + "prompt_tokens": 67, + "total_tokens": 121 + } + } + }, + "orchestration_result": { + "id": "chatcmpl-AxUvsJ4kwEGSGFp6ha89MBWdU9lCW", + "object": "chat.completion", + "created": 1738743620, + "model": "gpt-3.5-turbo-1106", + "system_fingerprint": "fp_0165350fbb", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "The current temperature in Potsdam is 30°C and in Toulouse 30°C." 
+ }, + "finish_reason": "stop" + } + ], + "usage": { + "completion_tokens": 54, + "prompt_tokens": 67, + "total_tokens": 121 + } + } +} diff --git a/orchestration/src/test/resources/toolCallsRequest.json b/orchestration/src/test/resources/toolCallsRequest.json new file mode 100644 index 000000000..d92e7bce0 --- /dev/null +++ b/orchestration/src/test/resources/toolCallsRequest.json @@ -0,0 +1,62 @@ +{ + "orchestration_config": { + "module_configurations": { + "llm_module_config": { + "model_name": "gpt-35-turbo-16k", + "model_params": {}, + "model_version": "latest" + }, + "templating_module_config": { + "template": [ + { + "role": "user", + "content": "What is the weather in Potsdam and in Toulouse?" + } + ], + "defaults": {}, + "tools": [ + { + "type": "function", + "function": { + "description": "Get the weather in location", + "name": "getCurrentWeather", + "parameters": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "additionalProperties": false, + "type": "object", + "properties": { + "arg0": { + "type": "object", + "properties": { + "location": { + "type": "string" + }, + "unit": { + "type": "string", + "enum": [ + "C", + "F" + ] + } + }, + "required": [ + "location", + "unit" + ] + } + }, + "required": [ + "arg0" + ] + }, + "strict": false + } + } + ] + } + }, + "stream": false + }, + "input_params": {}, + "messages_history": [] +} diff --git a/orchestration/src/test/resources/toolCallsRequest2.json b/orchestration/src/test/resources/toolCallsRequest2.json new file mode 100644 index 000000000..4549a97dc --- /dev/null +++ b/orchestration/src/test/resources/toolCallsRequest2.json @@ -0,0 +1,94 @@ +{ + "orchestration_config": { + "module_configurations": { + "llm_module_config": { + "model_name": "gpt-35-turbo-16k", + "model_params": {}, + "model_version": "latest" + }, + "templating_module_config": { + "template": [ + { + "role": "user", + "content": "What is the weather in Potsdam and in Toulouse?" 
+ }, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "function": { + "name": "getCurrentWeather", + "arguments": "{\"arg0\": {\"location\": \"Potsdam\", \"unit\": \"C\"}}" + }, + "id": "call_LOyP7EVdeqFlGEmVzmPdCVey", + "type": "function" + }, + { + "function": { + "name": "getCurrentWeather", + "arguments": "{\"arg0\": {\"location\": \"Toulouse\", \"unit\": \"C\"}}" + }, + "id": "call_bwFjnXCfCO4N3f0bMtFMlNSg", + "type": "function" + } + ] + }, + { + "role": "tool", + "content": "{\"temp\":30.0,\"unit\":\"C\"}", + "tool_call_id": "call_LOyP7EVdeqFlGEmVzmPdCVey" + }, + { + "role": "tool", + "content": "{\"temp\":30.0,\"unit\":\"C\"}", + "tool_call_id": "call_bwFjnXCfCO4N3f0bMtFMlNSg" + } + ], + "defaults": {}, + "tools": [ + { + "type": "function", + "function": { + "description": "Get the weather in location", + "name": "getCurrentWeather", + "parameters": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "additionalProperties": false, + "type": "object", + "properties": { + "arg0": { + "type": "object", + "properties": { + "location": { + "type": "string" + }, + "unit": { + "type": "string", + "enum": [ + "C", + "F" + ] + } + }, + "required": [ + "location", + "unit" + ] + } + }, + "required": [ + "arg0" + ] + }, + "strict": false + } + } + ] + } + }, + "stream": false + }, + "input_params": {}, + "messages_history": [] +} diff --git a/pom.xml b/pom.xml index 53c02440b..32bdbb56f 100644 --- a/pom.xml +++ b/pom.xml @@ -64,7 +64,7 @@ 2.1.3 3.5.2 6.2.1 - 1.0.0-M5 + 1.0.0-M6 3.6.12 3.1.0 5.15.2 diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/SpringAiOrchestrationController.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/SpringAiOrchestrationController.java index f7b3521e3..2570325a8 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/SpringAiOrchestrationController.java +++ 
b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/SpringAiOrchestrationController.java @@ -5,6 +5,7 @@ import javax.annotation.Nonnull; import javax.annotation.Nullable; import lombok.val; +import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RequestMapping; @@ -21,12 +22,14 @@ class SpringAiOrchestrationController { @GetMapping("/completion") Object completion( @Nullable @RequestParam(value = "format", required = false) final String format) { - val response = (OrchestrationSpringChatResponse) service.completion(); + val response = service.completion(); if ("json".equals(format)) { - return response.getOrchestrationResponse().getOriginalResponse(); + return ((OrchestrationSpringChatResponse) response) + .getOrchestrationResponse() + .getOriginalResponse(); } - return response.getResult().getOutput().getContent(); + return response.getResult().getOutput().getText(); } @GetMapping("/streamChatCompletion") @@ -34,26 +37,46 @@ Object completion( Flux streamChatCompletion() { return service .streamChatCompletion() - .map(chatResponse -> chatResponse.getResult().getOutput().getContent()); + .map(chatResponse -> chatResponse.getResult().getOutput().getText()); } @GetMapping("/template") Object template(@Nullable @RequestParam(value = "format", required = false) final String format) { - val response = (OrchestrationSpringChatResponse) service.template(); + val response = service.template(); if ("json".equals(format)) { - return response.getOrchestrationResponse().getOriginalResponse(); + return ((OrchestrationSpringChatResponse) response) + .getOrchestrationResponse() + .getOriginalResponse(); } - return response.getResult().getOutput().getContent(); + return response.getResult().getOutput().getText(); } @GetMapping("/masking") Object masking(@Nullable @RequestParam(value = 
"format", required = false) final String format) { - val response = (OrchestrationSpringChatResponse) service.masking(); + val response = service.masking(); if ("json".equals(format)) { - return response.getOrchestrationResponse().getOriginalResponse(); + return ((OrchestrationSpringChatResponse) response) + .getOrchestrationResponse() + .getOriginalResponse(); } - return response.getResult().getOutput().getContent(); + return response.getResult().getOutput().getText(); + } + + @GetMapping("/toolCalling") + Object toolCalling( + @Nullable @RequestParam(value = "format", required = false) final String format) { + // tool execution broken on orchestration https://jira.tools.sap/browse/AI-86627 + val response = service.toolCalling(false); + + if ("json".equals(format)) { + return ((OrchestrationSpringChatResponse) response) + .getOrchestrationResponse() + .getOriginalResponse(); + } + final AssistantMessage message = response.getResult().getOutput(); + final String text = message.getText(); + return text.isEmpty() ? 
message.getToolCalls().toString() : text; } } diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/SpringAiOrchestrationService.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/SpringAiOrchestrationService.java index 0bdea4180..7f4e30d68 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/SpringAiOrchestrationService.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/SpringAiOrchestrationService.java @@ -7,6 +7,7 @@ import com.sap.ai.sdk.orchestration.model.DPIEntities; import com.sap.ai.sdk.orchestration.spring.OrchestrationChatModel; import com.sap.ai.sdk.orchestration.spring.OrchestrationChatOptions; +import java.util.List; import java.util.Map; import javax.annotation.Nonnull; import lombok.val; @@ -14,6 +15,7 @@ import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.tool.ToolCallbacks; import org.springframework.stereotype.Service; import reactor.core.publisher.Flux; @@ -87,4 +89,21 @@ public ChatResponse masking() { return client.call(prompt); } + + /** + * Turn a method into a tool by annotating it with @Tool. 
Spring AI + * Tool Method Declarative Specification + * + * @return the assistant response object + */ + @Nonnull + public ChatResponse toolCalling(final boolean internalToolExecutionEnabled) { + final OrchestrationChatOptions options = new OrchestrationChatOptions(config); + options.setToolCallbacks(List.of(ToolCallbacks.from(new WeatherMethod()))); + options.setInternalToolExecutionEnabled(internalToolExecutionEnabled); + + val prompt = new Prompt("What is the weather in Potsdam and in Toulouse?", options); + return client.call(prompt); + } } diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/WeatherMethod.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/WeatherMethod.java new file mode 100644 index 000000000..f7413a20d --- /dev/null +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/WeatherMethod.java @@ -0,0 +1,42 @@ +package com.sap.ai.sdk.app.services; + +import javax.annotation.Nonnull; +import org.springframework.ai.tool.annotation.Tool; +import org.springframework.ai.tool.annotation.ToolParam; + +class WeatherMethod { + + /** Unit of temperature */ + enum Unit { + /** Celsius */ + @SuppressWarnings("unused") + C, + /** Fahrenheit */ + @SuppressWarnings("unused") + F + } + + /** + * Request for the weather + * + * @param location the city + * @param unit the unit of temperature + */ + record Request(String location, Unit unit) {} + + /** + * Response for the weather + * + * @param temp the temperature + * @param unit the unit of temperature + */ + record Response(double temp, Unit unit) {} + + @Nonnull + @SuppressWarnings("unused") + @Tool(description = "Get the weather in location") + Response getCurrentWeather(@ToolParam @Nonnull final Request request) { + final int temperature = request.location.hashCode() % 30; + return new Response(temperature, request.unit); + } +} diff --git a/sample-code/spring-app/src/main/resources/static/index.html 
b/sample-code/spring-app/src/main/resources/static/index.html index 8f816ffe9..2dee34e4d 100644 --- a/sample-code/spring-app/src/main/resources/static/index.html +++ b/sample-code/spring-app/src/main/resources/static/index.html @@ -564,6 +564,16 @@
Orchestration Integration
+
  • +
    + +
    + Register a function that will be called when the user asks for the weather. +
    +
    +
  • diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/SpringAiOrchestrationTest.java b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/SpringAiOrchestrationTest.java index bdf079684..0ea05b7fe 100644 --- a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/SpringAiOrchestrationTest.java +++ b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/SpringAiOrchestrationTest.java @@ -1,11 +1,15 @@ package com.sap.ai.sdk.app.controllers; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import com.sap.ai.sdk.app.services.SpringAiOrchestrationService; +import com.sap.ai.sdk.orchestration.OrchestrationClientException; +import java.util.List; import java.util.concurrent.atomic.AtomicInteger; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.AssistantMessage.ToolCall; import org.springframework.ai.chat.model.ChatResponse; @Slf4j @@ -17,7 +21,7 @@ public class SpringAiOrchestrationTest { void testCompletion() { ChatResponse response = service.completion(); assertThat(response).isNotNull(); - assertThat(response.getResult().getOutput().getContent()).contains("Paris"); + assertThat(response.getResult().getOutput().getText()).contains("Paris"); } @Test @@ -30,7 +34,7 @@ void testStreamChatCompletion() { .forEach( delta -> { log.info("delta: {}", delta); - if (!delta.getResult().getOutput().getContent().isEmpty()) { + if (!delta.getResult().getOutput().getText().isEmpty()) { filledDeltaCount.incrementAndGet(); } }); @@ -44,13 +48,38 @@ void testStreamChatCompletion() { void testTemplate() { ChatResponse response = service.template(); assertThat(response).isNotNull(); - assertThat(response.getResult().getOutput().getContent()).isNotEmpty(); + assertThat(response.getResult().getOutput().getText()).isNotEmpty(); } @Test void testMasking() { ChatResponse response = 
service.masking(); assertThat(response).isNotNull(); - assertThat(response.getResult().getOutput().getContent()).isNotEmpty(); + assertThat(response.getResult().getOutput().getText()).isNotEmpty(); + } + + @Test + void testToolCallingWithoutExecution() { + ChatResponse response = service.toolCalling(false); + List toolCalls = response.getResult().getOutput().getToolCalls(); + assertThat(toolCalls).hasSize(2); + ToolCall toolCall1 = toolCalls.get(0); + ToolCall toolCall2 = toolCalls.get(1); + assertThat(toolCall1.type()).isEqualTo("function"); + assertThat(toolCall2.type()).isEqualTo("function"); + assertThat(toolCall1.name()).isEqualTo("getCurrentWeather"); + assertThat(toolCall2.name()).isEqualTo("getCurrentWeather"); + assertThat(toolCall1.arguments()) + .isEqualTo("{\"arg0\": {\"location\": \"Potsdam\", \"unit\": \"C\"}}"); + assertThat(toolCall2.arguments()) + .isEqualTo("{\"arg0\": {\"location\": \"Toulouse\", \"unit\": \"C\"}}"); + } + + @Test + void testToolCallingWithExecution() { + // tool execution broken on orchestration https://jira.tools.sap/browse/AI-86627 + assertThatThrownBy(() -> service.toolCalling(true)) + .isExactlyInstanceOf(OrchestrationClientException.class) + .hasMessageContaining("Request failed with status 400 Bad Request"); } }