diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/JacksonMixins.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/JacksonMixins.java new file mode 100644 index 000000000..a6b06e751 --- /dev/null +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/JacksonMixins.java @@ -0,0 +1,25 @@ +package com.sap.ai.sdk.foundationmodels.openai; + +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionStreamResponse; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; + +/** + * This class contains Jackson Mixins for customizing the serialization and deserialization behavior + * of certain classes in the OpenAI SDK. + */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +final class JacksonMixins { + + /** + * Mixin interface to customize the deserialization of CreateChatCompletionStreamResponse. + * + *

Disables type information inclusion and specifies the concrete class to use for + * deserialization. + */ + @JsonTypeInfo(use = JsonTypeInfo.Id.NONE) + @JsonDeserialize(as = CreateChatCompletionStreamResponse.class) + interface DefaultChatCompletionCreate200ResponseMixIn {} +} diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiAssistantMessage.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiAssistantMessage.java new file mode 100644 index 000000000..15cb3ffe0 --- /dev/null +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiAssistantMessage.java @@ -0,0 +1,37 @@ +package com.sap.ai.sdk.foundationmodels.openai; + +import com.google.common.annotations.Beta; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestAssistantMessage; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestAssistantMessageContent; +import javax.annotation.Nonnull; +import lombok.Value; +import lombok.experimental.Accessors; + +/** + * Represents a chat message as 'assistant' to OpenAI service. + * + * @since 1.4.0 + */ +@Beta +@Value +@Accessors(fluent = true) +class OpenAiAssistantMessage implements OpenAiMessage { + + /** The role of the message. */ + @Nonnull String role = "assistant"; + + /** The content of the message. */ + @Nonnull String content; + + /** + * Converts the message to a serializable object. + * + * @return the corresponding {@code ChatCompletionRequestAssistantMessage} object. 
+ */ + @Nonnull + ChatCompletionRequestAssistantMessage createChatCompletionRequestMessage() { + return new ChatCompletionRequestAssistantMessage() + .role(ChatCompletionRequestAssistantMessage.RoleEnum.fromValue(role())) + .content(ChatCompletionRequestAssistantMessageContent.create(content)); + } +} diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionDelta.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionDelta.java new file mode 100644 index 000000000..46b5416e5 --- /dev/null +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionDelta.java @@ -0,0 +1,69 @@ +package com.sap.ai.sdk.foundationmodels.openai; + +import static com.sap.ai.sdk.foundationmodels.openai.OpenAiUtils.getOpenAiObjectMapper; +import static lombok.AccessLevel.PACKAGE; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.google.common.annotations.Beta; +import com.sap.ai.sdk.core.common.StreamedDelta; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.CompletionUsage; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionStreamResponse; +import java.util.Map; +import java.util.Objects; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.ToString; + +/** + * Represents an OpenAI chat completion output delta for streaming. + * + * @since 1.4.0 + */ +@Beta +@RequiredArgsConstructor(onConstructor_ = @JsonCreator, access = PACKAGE) +@Getter +@ToString +@EqualsAndHashCode +public class OpenAiChatCompletionDelta implements StreamedDelta { + /** The original response from the chat completion stream. 
*/ + @Nonnull private final CreateChatCompletionStreamResponse originalResponse; + + @Nonnull + @Override + public String getDeltaContent() { + final var choices = getOriginalResponse().getChoices(); + if (!choices.isEmpty() && choices.get(0).getIndex() == 0) { + final var message = choices.get(0).getDelta(); + return Objects.requireNonNullElse(message.getContent(), ""); + } + return ""; + } + + @Nullable + @Override + public String getFinishReason() { + final var choices = getOriginalResponse().getChoices(); + if (!choices.isEmpty()) { + final var finishReason = choices.get(0).getFinishReason(); + return finishReason != null ? finishReason.getValue() : null; + } + return null; + } + + /** + * Retrieves the completion usage from the response, or null if it is not available. + * + * @return The completion usage or null. + */ + @Nullable + public CompletionUsage getCompletionUsage() { + if (getOriginalResponse().getCustomFieldNames().contains("usage") + && getOriginalResponse().getCustomField("usage") instanceof Map usage) { + return getOpenAiObjectMapper().convertValue(usage, CompletionUsage.class); + } + return null; + } +} diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequest.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequest.java new file mode 100644 index 000000000..18d23f148 --- /dev/null +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequest.java @@ -0,0 +1,287 @@ +package com.sap.ai.sdk.foundationmodels.openai; + +import com.google.common.annotations.Beta; +import com.google.common.collect.Lists; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionStreamOptions; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionToolChoiceOption; +import 
com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequest; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequestAllOfResponseFormat; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequestAllOfStop; +import java.math.BigDecimal; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import lombok.AccessLevel; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.Value; +import lombok.With; +import lombok.experimental.Tolerate; + +/** + * Represents a request for OpenAI chat completion, including conversation messages and parameters. + * + * @see OpenAI + * API Reference + * @since 1.4.0 + */ +@Beta +@Value +@With +@AllArgsConstructor(access = AccessLevel.PRIVATE) +@Getter(value = AccessLevel.NONE) +public class OpenAiChatCompletionRequest { + /** List of messages from the conversation. */ + @Nonnull List messages; + + /** Upto 4 Stop sequences to interrupts token generation and returns a response without them. */ + @Nullable List stop; + + /** + * Controls the randomness of the completion. + * + *

Lower values (e.g. 0.0) make the model more deterministic and repetitive, while higher + * values (e.g. 1.0) make the model more random and creative. + */ + @Nullable BigDecimal temperature; + + /** + * Controls the cumulative probability threshold used for nucleus sampling. Alternative to {@link + * #temperature}. + * + *

Lower values (e.g. 0.1) limit the model to consider only the smallest set of tokens whose + * combined probabilities add up to at least 10% of the total. + */ + @Nullable BigDecimal topP; + + /** Maximum number of tokens that can be generated for the completion. */ + @Nullable Integer maxTokens; + + /** + * Maximum number of tokens that can be generated for the completion, including consumed reasoning + * tokens. This field supersedes {@link #maxTokens} and should be used with newer models. + */ + @Nullable Integer maxCompletionTokens; + + /** + * Encourage new topic by penalising token based on their presence in the completion. + * + *

Value should be in range [-2, 2]. + */ + @Nullable BigDecimal presencePenalty; + + /** + * Encourage new topic by penalising tokens based on their frequency in the completion. + * + *

Value should be in range [-2, 2]. + */ + @Nullable BigDecimal frequencyPenalty; + + /** + * A map that adjusts the likelihood of specified tokens by adding a bias value (between -100 and + * 100) to the logits before sampling. Extreme values can effectively ban or enforce the selection + * of tokens. + */ + @Nullable Map logitBias; + + /** + * Unique identifier for the end-user making the request. This can help with monitoring and abuse + * detection. + */ + @Nullable String user; + + /** Whether to include log probabilities in the response. */ + @With(AccessLevel.NONE) + @Nullable + Boolean logprobs; + + /** + * Number of top log probabilities to return for each token. An integer between 0 and 20. This is + * only relevant if {@code logprobs} is enabled. + */ + @Nullable Integer topLogprobs; + + /** Number of completions to generate. */ + @Nullable Integer n; + + /** Whether to allow parallel tool calls. */ + @With(AccessLevel.NONE) + @Nullable + Boolean parallelToolCalls; + + /** Seed for random number generation. */ + @Nullable Integer seed; + + /** Options for streaming the completion response. */ + @Nullable ChatCompletionStreamOptions streamOptions; + + /** Response format for the completion. */ + @Nullable CreateChatCompletionRequestAllOfResponseFormat responseFormat; + + /** List of tools that the model may invoke during the completion. */ + @Nullable List tools; + + /** Option to control which tool is invoked by the model. */ + @Nullable ChatCompletionToolChoiceOption toolChoice; + + /** + * Creates an OpenAiChatCompletionPrompt with string as user message. + * + * @param message the message to be added to the prompt + */ + @Tolerate + public OpenAiChatCompletionRequest(@Nonnull final String message) { + this(OpenAiMessage.user(message)); + } + + /** + * Creates an OpenAiChatCompletionPrompt with a multiple unpacked messages. 
+ * + * @param message the primary message to be added to the prompt + * @param messages additional messages to be added to the prompt + */ + @Tolerate + public OpenAiChatCompletionRequest( + @Nonnull final OpenAiMessage message, @Nonnull final OpenAiMessage... messages) { + this( + Lists.asList(message, messages), + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null); + } + + /** + * Adds stop sequences to the request. + * + * @param sequence the primary stop sequence + * @param sequences additional stop sequences + * @return a new OpenAiChatCompletionRequest instance with the specified stop sequences + */ + @Tolerate + @Nonnull + public OpenAiChatCompletionRequest withStop( + @Nonnull final String sequence, @Nonnull final String... sequences) { + return this.withStop(Lists.asList(sequence, sequences)); + } + + /** + * Sets the parallel tool calls option. + * + * @param parallelToolCalls Whether to allow parallel tool calls. + * @return A new instance with the specified option. + */ + @Nonnull + public OpenAiChatCompletionRequest withParallelToolCalls( + @Nonnull final Boolean parallelToolCalls) { + return Objects.equals(this.parallelToolCalls, parallelToolCalls) + ? this + : new OpenAiChatCompletionRequest( + this.messages, + this.stop, + this.temperature, + this.topP, + this.maxTokens, + this.maxCompletionTokens, + this.presencePenalty, + this.frequencyPenalty, + this.logitBias, + this.user, + this.logprobs, + this.topLogprobs, + this.n, + parallelToolCalls, + this.seed, + this.streamOptions, + this.responseFormat, + this.tools, + this.toolChoice); + } + + /** + * Sets the log probabilities option. + * + * @param logprobs Whether to include log probabilities in the response. + * @return A new instance with the specified option. 
+ */ + @Nonnull + public OpenAiChatCompletionRequest withLogprobs(@Nonnull final Boolean logprobs) { + return Objects.equals(this.logprobs, logprobs) + ? this + : new OpenAiChatCompletionRequest( + this.messages, + this.stop, + this.temperature, + this.topP, + this.maxTokens, + this.maxCompletionTokens, + this.presencePenalty, + this.frequencyPenalty, + this.logitBias, + this.user, + logprobs, + this.topLogprobs, + this.n, + this.parallelToolCalls, + this.seed, + this.streamOptions, + this.responseFormat, + this.tools, + this.toolChoice); + } + + /** + * Converts the request to a generated model class CreateChatCompletionRequest. + * + * @return the CreateChatCompletionRequest + */ + CreateChatCompletionRequest createCreateChatCompletionRequest() { + final var request = new CreateChatCompletionRequest(); + this.messages.forEach( + message -> + request.addMessagesItem(OpenAiUtils.createChatCompletionRequestMessage(message))); + + request.stop(this.stop != null ? CreateChatCompletionRequestAllOfStop.create(this.stop) : null); + + request.temperature(this.temperature); + request.topP(this.topP); + + request.stream(null); + request.maxTokens(this.maxTokens); + request.maxCompletionTokens(this.maxCompletionTokens); + request.presencePenalty(this.presencePenalty); + request.frequencyPenalty(this.frequencyPenalty); + request.logitBias(this.logitBias); + request.user(this.user); + request.logprobs(this.logprobs); + request.topLogprobs(this.topLogprobs); + request.n(this.n); + request.parallelToolCalls(this.parallelToolCalls); + request.seed(this.seed); + request.streamOptions(this.streamOptions); + request.responseFormat(this.responseFormat); + request.tools(this.tools); + request.toolChoice(this.toolChoice); + request.functionCall(null); + request.functions(null); + return request; + } +} diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionResponse.java 
b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionResponse.java new file mode 100644 index 000000000..4da73fe6f --- /dev/null +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionResponse.java @@ -0,0 +1,64 @@ +package com.sap.ai.sdk.foundationmodels.openai; + +import static com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionResponseChoicesInner.FinishReasonEnum.CONTENT_FILTER; +import static lombok.AccessLevel.NONE; +import static lombok.AccessLevel.PACKAGE; + +import com.google.common.annotations.Beta; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.CompletionUsage; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionResponse; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionResponseChoicesInner; +import java.util.Objects; +import javax.annotation.Nonnull; +import lombok.RequiredArgsConstructor; +import lombok.Setter; +import lombok.Value; + +/** + * Represents the output of an OpenAI chat completion. * + * + * @since 1.4.0 + */ +@Beta +@Value +@RequiredArgsConstructor(access = PACKAGE) +@Setter(value = NONE) +public class OpenAiChatCompletionResponse { + /** The original response from the OpenAI API. */ + @Nonnull final CreateChatCompletionResponse originalResponse; + + /** + * Gets the token usage from the original response. + * + * @return the token usage + */ + @Nonnull + public CompletionUsage getTokenUsage() { + return getOriginalResponse().getUsage(); + } + + /** + * Gets the first choice from the original response. + * + * @return the first choice + */ + @Nonnull + public CreateChatCompletionResponseChoicesInner getChoice() { + return getOriginalResponse().getChoices().get(0); + } + + /** + * Gets the content of the first choice. 
+ * + * @return the content of the first choice + * @throws OpenAiClientException if the content is filtered by the content filter + */ + @Nonnull + public String getContent() { + if (CONTENT_FILTER.equals(getOriginalResponse().getChoices().get(0).getFinishReason())) { + throw new OpenAiClientException("Content filter filtered the output."); + } + + return Objects.requireNonNullElse(getChoice().getMessage().getContent(), ""); + } +} diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java index 8699201a7..fe5415dfb 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java @@ -1,6 +1,6 @@ package com.sap.ai.sdk.foundationmodels.openai; -import static com.sap.ai.sdk.core.JacksonConfiguration.getDefaultObjectMapper; +import static com.sap.ai.sdk.foundationmodels.openai.OpenAiUtils.getOpenAiObjectMapper; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; @@ -10,14 +10,17 @@ import com.sap.ai.sdk.core.common.ClientResponseHandler; import com.sap.ai.sdk.core.common.ClientStreamingHandler; import com.sap.ai.sdk.core.common.StreamedDelta; -import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionDelta; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionStreamOptions; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequest; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionResponse; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreate200Response; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequest; import 
com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionOutput; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionParameters; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.OpenAiChatSystemMessage; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.OpenAiChatUserMessage; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingOutput; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingParameters; -import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiError; import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor; import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination; import com.sap.cloud.sdk.cloudplatform.connectivity.Destination; @@ -39,7 +42,8 @@ @RequiredArgsConstructor(access = AccessLevel.PRIVATE) public final class OpenAiClient { private static final String DEFAULT_API_VERSION = "2024-02-01"; - static final ObjectMapper JACKSON = getDefaultObjectMapper(); + static final ObjectMapper JACKSON = getOpenAiObjectMapper(); + @Nullable private String systemPrompt = null; @Nonnull private final Destination destination; @@ -107,7 +111,11 @@ public static OpenAiClient withCustomDestination(@Nonnull final Destination dest } /** - * Add a system prompt before user prompts. + * Use this method to set a system prompt that should be used across multiple chat completions + * with basic string prompts {@link #streamChatCompletionDeltas(OpenAiChatCompletionParameters)}. + * + *

Note: The system prompt is ignored on chat completions invoked with + * OpenAiChatCompletionPrompt. * * @param systemPrompt the system prompt * @return the client @@ -119,7 +127,7 @@ public OpenAiClient withSystemPrompt(@Nonnull final String systemPrompt) { } /** - * Generate a completion for the given user prompt. + * Generate a completion for the given string prompt as user. * * @param prompt a text message. * @return the completion output @@ -137,9 +145,41 @@ public OpenAiChatCompletionOutput chatCompletion(@Nonnull final String prompt) } /** - * Generate a completion for the given prompt. + * Generate a completion for the given conversation and request parameters. + * + * @param request the completion request. + * @return the completion output + * @throws OpenAiClientException if the request fails + * @since 1.4.0 + */ + @Beta + @Nonnull + public OpenAiChatCompletionResponse chatCompletion( + @Nonnull final OpenAiChatCompletionRequest request) throws OpenAiClientException { + warnIfUnsupportedUsage(); + return new OpenAiChatCompletionResponse( + chatCompletion(request.createCreateChatCompletionRequest())); + } + + /** + * Generate a completion for the given low-level request object. * - * @param parameters the prompt, including messages and other parameters. + * @param request the completion request. + * @return the completion output + * @throws OpenAiClientException if the request fails + * @since 1.4.0 + */ + @Beta + @Nonnull + public CreateChatCompletionResponse chatCompletion( + @Nonnull final CreateChatCompletionRequest request) throws OpenAiClientException { + return execute("/chat/completions", request, CreateChatCompletionResponse.class); + } + + /** + * Generate a completion for the given conversation and request parameters. + * + * @param parameters the completion request. 
* @return the completion output * @throws OpenAiClientException if the request fails */ @@ -151,9 +191,10 @@ public OpenAiChatCompletionOutput chatCompletion( } /** - * Stream a completion for the given prompt. Returns a lazily populated stream of text - * chunks. To access more details about the individual chunks, use {@link - * #streamChatCompletionDeltas(OpenAiChatCompletionParameters)}. + * Stream a completion for the given string prompt as user. + * + *

Returns a lazily populated stream of text chunks. To access more details about the + * individual chunks, use {@link #streamChatCompletionDeltas(OpenAiChatCompletionRequest)}. * *

The stream should be consumed using a try-with-resources block to ensure that the underlying * HTTP connection is closed. @@ -171,19 +212,21 @@ public OpenAiChatCompletionOutput chatCompletion( * Stream#parallel()} on this stream is not supported. * * @param prompt a text message. - * @return A stream of message deltas + * @return A stream of text chunks * @throws OpenAiClientException if the request fails or if the finish reason is content_filter - * @see #streamChatCompletionDeltas(OpenAiChatCompletionParameters) + * @see #streamChatCompletionDeltas(OpenAiChatCompletionRequest) */ @Nonnull public Stream streamChatCompletion(@Nonnull final String prompt) throws OpenAiClientException { - final OpenAiChatCompletionParameters parameters = new OpenAiChatCompletionParameters(); - if (systemPrompt != null) { - parameters.addMessages(new OpenAiChatSystemMessage().setContent(systemPrompt)); - } - parameters.addMessages(new OpenAiChatUserMessage().addText(prompt)); - return streamChatCompletionDeltas(parameters) + final var userPrompt = OpenAiMessage.user(prompt); + + final var request = + systemPrompt != null + ? new OpenAiChatCompletionRequest(OpenAiMessage.system(systemPrompt), userPrompt) + : new OpenAiChatCompletionRequest(userPrompt); + + return streamChatCompletionDeltas(request.createCreateChatCompletionRequest()) .peek(OpenAiClient::throwOnContentFilter) .map(OpenAiChatCompletionDelta::getDeltaContent); } @@ -196,8 +239,10 @@ private static void throwOnContentFilter(@Nonnull final OpenAiChatCompletionDelt } /** - * Stream a completion for the given prompt. Returns a lazily populated stream of delta - * objects. To simply stream the text chunks use {@link #streamChatCompletion(String)} + * Stream a completion for the given conversation and request parameters. + * + *

Returns a lazily populated stream of delta objects. To simply stream the text chunks + * use {@link #streamChatCompletion(String)} * *

The stream should be consumed using a try-with-resources block to ensure that the underlying * HTTP connection is closed. @@ -205,7 +250,7 @@ private static void throwOnContentFilter(@Nonnull final OpenAiChatCompletionDelt *

Example: * *

{@code
-   * try (var stream = client.streamChatCompletionDeltas(params)) {
+   * try (var stream = client.streamChatCompletionDeltas(request)) {
    *       stream
    *           .peek(delta -> System.out.println(delta.getUsage()))
    *           .map(OpenAiChatCompletionDelta::getDeltaContent)
@@ -217,17 +262,75 @@ private static void throwOnContentFilter(@Nonnull final OpenAiChatCompletionDelt
    * block until all chunks are consumed. Also, for obvious reasons, invoking {@link
    * Stream#parallel()} on this stream is not supported.
    *
-   * @param parameters The prompt, including messages and other parameters.
+   * @param request The completion request, including a list of messages.
    * @return A stream of message deltas
    * @throws OpenAiClientException if the request fails or if the finish reason is content_filter
    * @see #streamChatCompletion(String)
+   * @since 1.4.0
    */
+  @Beta
   @Nonnull
   public Stream streamChatCompletionDeltas(
-      @Nonnull final OpenAiChatCompletionParameters parameters) throws OpenAiClientException {
+      @Nonnull final OpenAiChatCompletionRequest request) throws OpenAiClientException {
+    return streamChatCompletionDeltas(request.createCreateChatCompletionRequest());
+  }
+
+  /**
+   * Stream a completion for the given low-level request object. Returns a lazily populated
+   * stream of delta objects.
+   *
+   * @param request The completion request.
+   * @return A stream of message deltas
+   * @throws OpenAiClientException if the request fails or if the finish reason is content_filter
+   * @see #streamChatCompletionDeltas(OpenAiChatCompletionRequest) for a higher-level API
+   * @since 1.4.0
+   */
+  @Beta
+  @Nonnull
+  public Stream streamChatCompletionDeltas(
+      @Nonnull final CreateChatCompletionRequest request) throws OpenAiClientException {
+    request.stream(true).streamOptions(new ChatCompletionStreamOptions().includeUsage(true));
+    return executeStream("/chat/completions", request, OpenAiChatCompletionDelta.class);
+  }
+
+  /**
+   * Stream a completion for the given conversation and request parameters.
+   *
+   * 

Returns a lazily populated stream of delta objects. To simply stream the text chunks + * use {@link #streamChatCompletion(String)} + * + *

The stream should be consumed using a try-with-resources block to ensure that the underlying + * HTTP connection is closed. + * + *

Example: + * + *

{@code
+   * try (var stream = client.streamChatCompletionDeltas(request)) {
+   *       stream
+   *           .peek(delta -> System.out.println(delta.getUsage()))
+   *           .map(com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionDelta::getDeltaContent)
+   *           .forEach(System.out::println);
+   * }
+   * }
+ * + *

Please keep in mind that using a terminal stream operation like {@link Stream#forEach} will + * block until all chunks are consumed. Also, for obvious reasons, invoking {@link + * Stream#parallel()} on this stream is not supported. + * + * @param parameters The prompt, including a list of messages. + * @return A stream of message deltas + * @throws OpenAiClientException if the request fails or if the finish reason is content_filter + */ + @Nonnull + public Stream + streamChatCompletionDeltas(@Nonnull final OpenAiChatCompletionParameters parameters) + throws OpenAiClientException { warnIfUnsupportedUsage(); parameters.enableStreaming(); - return executeStream("/chat/completions", parameters, OpenAiChatCompletionDelta.class); + return executeStream( + "/chat/completions", + parameters, + com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionDelta.class); } private void warnIfUnsupportedUsage() { @@ -237,6 +340,22 @@ private void warnIfUnsupportedUsage() { } } + /** + * Get a vector representation of a given request with input that can be easily consumed by + * machine learning models and algorithms. + * + * @param request the request with input text. + * @return the embedding output + * @throws OpenAiClientException if the request fails + * @since 1.4.0 + */ + @Beta + @Nonnull + public EmbeddingsCreate200Response embedding(@Nonnull final EmbeddingsCreateRequest request) + throws OpenAiClientException { + return execute("/embeddings", request, EmbeddingsCreate200Response.class); + } + /** * Get a vector representation of a given input that can be easily consumed by machine learning * models and algorithms. 
@@ -300,6 +419,7 @@ private Stream streamRequest( try { final var client = ApacheHttpClient5Accessor.getHttpClient(destination); return new ClientStreamingHandler<>(deltaType, OpenAiError.class, OpenAiClientException::new) + .objectMapper(JACKSON) .handleStreamingResponse(client.executeOpen(null, request, null)); } catch (final IOException e) { throw new OpenAiClientException("Request to OpenAI model failed", e); diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiError.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiError.java new file mode 100644 index 000000000..4f1011ff0 --- /dev/null +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiError.java @@ -0,0 +1,35 @@ +package com.sap.ai.sdk.foundationmodels.openai; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.google.common.annotations.Beta; +import com.sap.ai.sdk.core.common.ClientError; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ErrorResponse; +import javax.annotation.Nonnull; +import lombok.AccessLevel; +import lombok.AllArgsConstructor; +import lombok.Value; +import lombok.experimental.Delegate; + +/** + * Represents an error response from the OpenAI API. + * + * @since 1.4.0 + */ +@Beta +@Value +@AllArgsConstructor(onConstructor = @__({@JsonCreator}), access = AccessLevel.PROTECTED) +public class OpenAiError implements ClientError { + /** The original error response from the OpenAI API. */ + @Delegate(types = {ClientError.class}) + ErrorResponse originalResponse; + + /** + * Gets the error message from the contained original response. 
+ * + * @return the error message + */ + @Nonnull + public String getMessage() { + return originalResponse.getError().getMessage(); + } +} diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiMessage.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiMessage.java new file mode 100644 index 000000000..a8a25855d --- /dev/null +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiMessage.java @@ -0,0 +1,46 @@ +package com.sap.ai.sdk.foundationmodels.openai; + +import com.google.common.annotations.Beta; +import javax.annotation.Nonnull; + +/** + * Interface representing convenience wrappers of chat message to the openai service. + * + * @since 1.4.0 + */ +@Beta +public interface OpenAiMessage { + + /** + * A convenience method to create a user message. + * + * @param msg the message content. + * @return the user message. + */ + @Nonnull + static OpenAiMessage user(@Nonnull final String msg) { + return new OpenAiUserMessage(msg); + } + + /** + * A convenience method to create an assistant message. + * + * @param msg the message content. + * @return the assistant message. + */ + @Nonnull + static OpenAiMessage assistant(@Nonnull final String msg) { + return new OpenAiAssistantMessage(msg); + } + + /** + * A convenience method to create a system message. + * + * @param msg the message content. + * @return the system message. 
+   */
+  @Nonnull
+  static OpenAiMessage system(@Nonnull final String msg) {
+    return new OpenAiSystemMessage(msg);
+  }
+}
diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiSystemMessage.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiSystemMessage.java
new file mode 100644
index 000000000..c24f6c277
--- /dev/null
+++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiSystemMessage.java
@@ -0,0 +1,37 @@
+package com.sap.ai.sdk.foundationmodels.openai;
+
+import com.google.common.annotations.Beta;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestSystemMessage;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestSystemMessageContent;
+import javax.annotation.Nonnull;
+import lombok.Value;
+import lombok.experimental.Accessors;
+
+/**
+ * Represents a chat message as 'system' to OpenAI service.
+ *
+ * @since 1.4.0
+ */
+@Beta
+@Value
+@Accessors(fluent = true)
+class OpenAiSystemMessage implements OpenAiMessage {
+
+  /** The role of the message. */
+  @Nonnull String role = "system";
+
+  /** The content of the message. */
+  @Nonnull String content;
+
+  /**
+   * Converts the message to a serializable object.
+   *
+   * @return the corresponding {@code ChatCompletionRequestSystemMessage} object.
+   */
+  @Nonnull
+  ChatCompletionRequestSystemMessage createChatCompletionRequestMessage() {
+    return new ChatCompletionRequestSystemMessage()
+        .role(ChatCompletionRequestSystemMessage.RoleEnum.fromValue(role()))
+        .content(ChatCompletionRequestSystemMessageContent.create(content()));
+  }
+}
diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiUserMessage.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiUserMessage.java
new file mode 100644
index 000000000..597340a3e
--- /dev/null
+++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiUserMessage.java
@@ -0,0 +1,37 @@
+package com.sap.ai.sdk.foundationmodels.openai;
+
+import com.google.common.annotations.Beta;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestUserMessage;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestUserMessageContent;
+import javax.annotation.Nonnull;
+import lombok.Value;
+import lombok.experimental.Accessors;
+
+/**
+ * Represents a chat message as 'user' to OpenAI service.
+ *
+ * @since 1.4.0
+ */
+@Beta
+@Value
+@Accessors(fluent = true)
+class OpenAiUserMessage implements OpenAiMessage {
+
+  /** The role of the message. */
+  @Nonnull String role = "user";
+
+  /** The content of the message. */
+  @Nonnull String content;
+
+  /**
+   * Converts the message to a serializable object.
+   *
+   * @return the corresponding {@code ChatCompletionRequestUserMessage} object.
+ */ + @Nonnull + ChatCompletionRequestUserMessage createChatCompletionRequestMessage() { + return new ChatCompletionRequestUserMessage() + .role(ChatCompletionRequestUserMessage.RoleEnum.fromValue(role())) + .content(ChatCompletionRequestUserMessageContent.create(content())); + } +} diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiUtils.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiUtils.java new file mode 100644 index 000000000..2abe27dd8 --- /dev/null +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiUtils.java @@ -0,0 +1,54 @@ +package com.sap.ai.sdk.foundationmodels.openai; + +import static com.sap.ai.sdk.core.JacksonConfiguration.getDefaultObjectMapper; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.annotations.Beta; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestMessage; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionsCreate200Response; +import javax.annotation.Nonnull; + +/** + * Utility class for handling OpenAI module. + * + *

Only intended for internal usage within this SDK. + * + * @since 1.4.0 + */ +@Beta +class OpenAiUtils { + + /** + * Converts an OpenAiMessage to a ChatCompletionRequestMessage. + * + * @param message the OpenAiMessage to convert + * @return the corresponding ChatCompletionRequestMessage + * @throws IllegalArgumentException if the message type is unknown + */ + @Nonnull + static ChatCompletionRequestMessage createChatCompletionRequestMessage( + @Nonnull final OpenAiMessage message) throws IllegalArgumentException { + if (message instanceof OpenAiUserMessage userMessage) { + return userMessage.createChatCompletionRequestMessage(); + } else if (message instanceof OpenAiAssistantMessage assistantMessage) { + return assistantMessage.createChatCompletionRequestMessage(); + } else if (message instanceof OpenAiSystemMessage systemMessage) { + return systemMessage.createChatCompletionRequestMessage(); + } else { + throw new IllegalArgumentException("Unknown message type: " + message.getClass()); + } + } + + /** + * Default object mapper used for JSON de-/serialization. + * + * @return A new object mapper with the default configuration. 
+ */ + @Nonnull + static ObjectMapper getOpenAiObjectMapper() { + return getDefaultObjectMapper() + .addMixIn( + ChatCompletionsCreate200Response.class, + JacksonMixins.DefaultChatCompletionCreate200ResponseMixIn.class); + } +} diff --git a/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/BaseOpenAiClientTest.java b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/BaseOpenAiClientTest.java new file mode 100644 index 000000000..122b074d7 --- /dev/null +++ b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/BaseOpenAiClientTest.java @@ -0,0 +1,182 @@ +package com.sap.ai.sdk.foundationmodels.openai; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl; +import static com.github.tomakehurst.wiremock.client.WireMock.badRequest; +import static com.github.tomakehurst.wiremock.client.WireMock.equalTo; +import static com.github.tomakehurst.wiremock.client.WireMock.noContent; +import static com.github.tomakehurst.wiremock.client.WireMock.okXml; +import static com.github.tomakehurst.wiremock.client.WireMock.post; +import static com.github.tomakehurst.wiremock.client.WireMock.serverError; +import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; +import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; + +import com.fasterxml.jackson.core.JsonParseException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; +import com.github.tomakehurst.wiremock.junit5.WireMockTest; +import com.github.tomakehurst.wiremock.stubbing.Scenario; +import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor; +import 
com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Cache; +import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination; +import java.io.IOException; +import java.io.InputStream; +import java.util.Objects; +import java.util.function.Function; +import javax.annotation.Nonnull; +import org.apache.hc.client5.http.classic.HttpClient; +import org.apache.hc.core5.http.ContentType; +import org.apache.hc.core5.http.io.entity.InputStreamEntity; +import org.apache.hc.core5.http.message.BasicClassicHttpResponse; +import org.assertj.core.api.SoftAssertions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; + +@WireMockTest +abstract class BaseOpenAiClientTest { + + protected static final ObjectMapper MAPPER = new ObjectMapper(); + protected static OpenAiClient client; + protected final Function fileLoader = + filename -> Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream(filename)); + + static void stubForChatCompletion() { + + stubFor( + post(urlPathEqualTo("/chat/completions")) + .withQueryParam("api-version", equalTo("2024-02-01")) + .willReturn( + aResponse() + .withBodyFile("chatCompletionResponse.json") + .withHeader("Content-Type", "application/json"))); + } + + static void stubForEmbedding() { + stubFor( + post(urlPathEqualTo("/embeddings")) + .willReturn( + aResponse() + .withBodyFile("embeddingResponse.json") + .withHeader("Content-Type", "application/json"))); + } + + static void stubForChatCompletionTool() { + stubFor( + post(urlPathEqualTo("/chat/completions")) + .willReturn( + aResponse() + .withHeader("Content-Type", "application/json") + .withBodyFile("chatCompletionToolResponse.json"))); + } + + static void stubForErrorHandling() { + final var errorJson = + """ + { "error": { "code": null, "message": "foo", "type": "invalid stuff" } } + """; + stubFor( + post(anyUrl()) + .inScenario("Errors") + .whenScenarioStateIs(Scenario.STARTED) + .willReturn(serverError()) + 
.willSetStateTo("1")); + stubFor( + post(anyUrl()) + .inScenario("Errors") + .whenScenarioStateIs("1") + .willReturn( + badRequest().withBody(errorJson).withHeader("Content-type", "application/json")) + .willSetStateTo("2")); + stubFor( + post(anyUrl()) + .inScenario("Errors") + .whenScenarioStateIs("2") + .willReturn( + badRequest() + .withBody("{ broken json") + .withHeader("Content-type", "application/json")) + .willSetStateTo("3")); + stubFor( + post(anyUrl()) + .inScenario("Errors") + .whenScenarioStateIs("3") + .willReturn(okXml("")) + .willSetStateTo("4")); + stubFor(post(anyUrl()).inScenario("Errors").whenScenarioStateIs("4").willReturn(noContent())); + } + + static void assertForErrorHandling(@Nonnull final Runnable request) { + + final var softly = new SoftAssertions(); + + softly + .assertThatThrownBy(request::run) + .describedAs("Server errors should be handled") + .isInstanceOf(OpenAiClientException.class) + .hasMessageContaining("500"); + + softly + .assertThatThrownBy(request::run) + .describedAs("Error objects from OpenAI should be interpreted") + .isInstanceOf(OpenAiClientException.class) + .hasMessageContaining("error message: 'foo'"); + + softly + .assertThatThrownBy(request::run) + .describedAs("Failures while parsing error message should be handled") + .isInstanceOf(OpenAiClientException.class) + .hasMessageContaining("400") + .extracting(e -> e.getSuppressed()[0]) + .isInstanceOf(JsonParseException.class); + + softly + .assertThatThrownBy(request::run) + .describedAs("Non-JSON responses should be handled") + .isInstanceOf(OpenAiClientException.class) + .hasMessageContaining("Failed to parse"); + + softly + .assertThatThrownBy(request::run) + .describedAs("Empty responses should be handled") + .isInstanceOf(OpenAiClientException.class) + .hasMessageContaining("was empty"); + + softly.assertAll(); + } + + @BeforeEach + void setup(WireMockRuntimeInfo server) { + final DefaultHttpDestination destination = + 
DefaultHttpDestination.builder(server.getHttpBaseUrl()).build(); + client = OpenAiClient.withCustomDestination(destination); + ApacheHttpClient5Accessor.setHttpClientCache(ApacheHttpClient5Cache.DISABLED); + } + + @AfterEach + void reset() { + ApacheHttpClient5Accessor.setHttpClientCache(null); + ApacheHttpClient5Accessor.setHttpClientFactory(null); + } + + InputStream stubStreamChatCompletion(String responseFile) throws IOException { + var inputStream = spy(fileLoader.apply(responseFile)); + + final var httpClient = mock(HttpClient.class); + ApacheHttpClient5Accessor.setHttpClientFactory(destination -> httpClient); + + // Create a mock response + final var mockResponse = new BasicClassicHttpResponse(200, "OK"); + final var inputStreamEntity = new InputStreamEntity(inputStream, ContentType.TEXT_PLAIN); + mockResponse.setEntity(inputStreamEntity); + mockResponse.setHeader("Content-Type", "text/event-stream"); + + // Configure the HttpClient mock to return the mock response + doReturn(mockResponse).when(httpClient).executeOpen(any(), any(), any()); + + return inputStream; + } +} diff --git a/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequestTest.java b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequestTest.java new file mode 100644 index 000000000..0ede0f272 --- /dev/null +++ b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequestTest.java @@ -0,0 +1,48 @@ +package com.sap.ai.sdk.foundationmodels.openai; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestUserMessage; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestUserMessageContent; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequestAllOfStop; +import java.math.BigDecimal; +import 
java.util.List; +import org.junit.jupiter.api.Test; + +class OpenAiChatCompletionRequestTest { + + @Test + void stopSequence() { + var request = + new OpenAiChatCompletionRequest(OpenAiMessage.user("Hello, world")) + .withStop("stop1", "stop2 stopNot3", "stop3"); + + var lowLevelRequest = request.createCreateChatCompletionRequest(); + assertThat( + ((CreateChatCompletionRequestAllOfStop.InnerStrings) lowLevelRequest.getStop()) + .values()) + .containsExactly("stop1", "stop2 stopNot3", "stop3"); + } + + @Test + void createWithExistingRequest() { + var originalRequest = + new OpenAiChatCompletionRequest(OpenAiMessage.user("First message")) + .withSeed(123) + .withTemperature(BigDecimal.valueOf(0.5)); + + var newRequest = originalRequest.withMessages(List.of(OpenAiMessage.user("Another message"))); + + var lowlevelRequest = newRequest.createCreateChatCompletionRequest(); + + assertThat(newRequest).isNotEqualTo(originalRequest); + assertThat(lowlevelRequest.getMessages()).hasSize(1); + assertThat(lowlevelRequest.getMessages().get(0)) + .isInstanceOf(ChatCompletionRequestUserMessage.class); + assertThat( + ((ChatCompletionRequestUserMessage) lowlevelRequest.getMessages().get(0)).getContent()) + .isEqualTo(ChatCompletionRequestUserMessageContent.create("Another message")); + assertThat(lowlevelRequest.getSeed()).isEqualTo(123); + assertThat(lowlevelRequest.getTemperature()).isEqualTo(BigDecimal.valueOf(0.5)); + } +} diff --git a/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientGeneratedTest.java b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientGeneratedTest.java new file mode 100644 index 000000000..fbf01573e --- /dev/null +++ b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientGeneratedTest.java @@ -0,0 +1,605 @@ +package com.sap.ai.sdk.foundationmodels.openai; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import 
static com.github.tomakehurst.wiremock.client.WireMock.anyUrl; +import static com.github.tomakehurst.wiremock.client.WireMock.equalTo; +import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson; +import static com.github.tomakehurst.wiremock.client.WireMock.exactly; +import static com.github.tomakehurst.wiremock.client.WireMock.okJson; +import static com.github.tomakehurst.wiremock.client.WireMock.post; +import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor; +import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; +import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo; +import static com.github.tomakehurst.wiremock.client.WireMock.verify; +import static com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionStreamResponseDelta.RoleEnum.ASSISTANT; +import static com.sap.ai.sdk.foundationmodels.openai.generated.model.ContentFilterSeverityResult.SeverityEnum.SAFE; +import static com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionResponse.ObjectEnum.CHAT_COMPLETION; +import static com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionResponseChoicesInner.FinishReasonEnum.STOP; +import static com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionStreamResponse.ObjectEnum.CHAT_COMPLETION_CHUNK; +import static com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionStreamResponse.ObjectEnum.UNKNOWN_DEFAULT_OPEN_API; +import static com.sap.ai.sdk.foundationmodels.openai.generated.model.ToolCallType.FUNCTION; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.when; + +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionNamedToolChoice; 
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionNamedToolChoiceFunction; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionResponseMessageRole; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionToolChoiceOption; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ContentFilterPromptResults; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequest; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionStreamResponseChoicesInner; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequest; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequestInput; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.FunctionObject; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.PromptFilterResult; +import io.vavr.control.Try; +import java.io.IOException; +import java.math.BigDecimal; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; +import javax.annotation.Nonnull; +import lombok.SneakyThrows; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mockito; + +class OpenAiClientGeneratedTest extends BaseOpenAiClientTest { + + @Test + void openAiModels() { + var model = OpenAiModel.GPT_4; + var newModel = model.withVersion("v1"); + + assertThat(model.name()).isEqualTo("gpt-4"); + assertThat(model.version()).isNull(); + + assertThat(newModel.name()).isEqualTo("gpt-4"); + assertThat(newModel.version()).isEqualTo("v1"); + + assertThat(model).isNotSameAs(newModel); + } + + private static Runnable[] errorHandlingCalls() { + return new Runnable[] { + () -> client.chatCompletion(new 
OpenAiChatCompletionRequest("")), + () -> + client + .streamChatCompletionDeltas(new OpenAiChatCompletionRequest("")) + // the stream needs to be consumed to parse the response + .forEach(System.out::println) + }; + } + + @ParameterizedTest + @MethodSource("errorHandlingCalls") + void chatCompletionErrorHandling(@Nonnull final Runnable request) { + stubForErrorHandling(); + assertForErrorHandling(request); + } + + @Test + void apiVersion() { + stubFor(post(anyUrl()).willReturn(okJson("{}"))); + Try.of(() -> client.chatCompletion(new OpenAiChatCompletionRequest(""))); + + verify( + exactly(1), + postRequestedFor(anyUrl()).withQueryParam("api-version", equalTo("2024-02-01"))); + + Try.of( + () -> client.withApiVersion("fooBar").chatCompletion(new OpenAiChatCompletionRequest(""))); + verify(exactly(1), postRequestedFor(anyUrl()).withQueryParam("api-version", equalTo("fooBar"))); + + assertThat(client) + .describedAs( + "withApiVersion should return a new object, the sut object should remain unchanged") + .isNotSameAs(client.withApiVersion("fooBar")); + Try.of(() -> client.chatCompletion(new OpenAiChatCompletionRequest(""))); + verify( + exactly(2), + postRequestedFor(anyUrl()).withQueryParam("api-version", equalTo("2024-02-01"))); + } + + @Test + void chatCompletion() { + + stubForChatCompletion(); + + final var systemMessage = OpenAiMessage.system("You are a helpful AI"); + final var userMessage = OpenAiMessage.user("Hello World! 
Why is this phrase so famous?"); + final var prompt = new OpenAiChatCompletionRequest(systemMessage, userMessage); + final var result = client.chatCompletion(prompt).getOriginalResponse(); + + assertThat(result).isNotNull(); + assertThat(result.getCreated()).isEqualTo(1727436279); + assertThat(result.getId()).isEqualTo("chatcmpl-AC3NPPYlxem8kRBBAX9EBObMMsrnf"); + assertThat(result.getModel()).isEqualTo("gpt-35-turbo"); + assertThat(result.getObject()).isEqualTo(CHAT_COMPLETION); + assertThat(result.getSystemFingerprint()).isEqualTo("fp_e49e4201a9"); + + assertThat(result.getUsage()).isNotNull(); + assertThat(result.getUsage().getCompletionTokens()).isEqualTo(20); + assertThat(result.getUsage().getPromptTokens()).isEqualTo(13); + assertThat(result.getUsage().getTotalTokens()).isEqualTo(33); + + assertThat(result.getPromptFilterResults()).hasSize(1); + assertThat(result.getPromptFilterResults().get(0).getPromptIndex()).isZero(); + + var promptFilterResults = result.getPromptFilterResults().get(0).getContentFilterResults(); + assertThat(promptFilterResults).isNotNull(); + assertThat(promptFilterResults.getSexual()).isNotNull(); + assertThat(promptFilterResults.getSexual().isFiltered()).isFalse(); + assertThat(promptFilterResults.getSexual().getSeverity()).isEqualTo(SAFE); + assertThat(promptFilterResults.getViolence()).isNotNull(); + assertThat(promptFilterResults.getViolence().isFiltered()).isFalse(); + assertThat(promptFilterResults.getViolence().getSeverity()).isEqualTo(SAFE); + assertThat(promptFilterResults.getHate()).isNotNull(); + assertThat(promptFilterResults.getHate().isFiltered()).isFalse(); + assertThat(promptFilterResults.getHate().getSeverity()).isEqualTo(SAFE); + assertThat(promptFilterResults.getSelfHarm()).isNotNull(); + assertThat(promptFilterResults.getSelfHarm().getSeverity()).isEqualTo(SAFE); + assertThat(promptFilterResults.getSelfHarm().isFiltered()).isFalse(); + assertThat(promptFilterResults.getProfanity()).isNull(); + 
assertThat(promptFilterResults.getError()).isNull(); + assertThat(promptFilterResults.getJailbreak()).isNull(); + + assertThat(result.getChoices()).hasSize(1); + + var choice = result.getChoices().get(0); + assertThat(choice.getFinishReason()).isEqualTo(STOP); + assertThat(choice.getIndex()).isZero(); + assertThat(choice.getMessage().getContent()) + .isEqualTo( + "I'm an AI and cannot answer that question as beauty is subjective and varies from person to person."); + assertThat(choice.getMessage().getRole()) + .isEqualTo(ChatCompletionResponseMessageRole.ASSISTANT); + assertThat(choice.getMessage().getToolCalls()).isNull(); + + var contentFilterResults = choice.getContentFilterResults(); + assertThat(contentFilterResults).isNotNull(); + assertThat(contentFilterResults.getSexual()).isNotNull(); + assertThat(contentFilterResults.getSexual().isFiltered()).isFalse(); + assertThat(contentFilterResults.getSexual().getSeverity()).isEqualTo(SAFE); + assertThat(contentFilterResults.getViolence()).isNotNull(); + assertThat(contentFilterResults.getViolence().isFiltered()).isFalse(); + assertThat(contentFilterResults.getViolence().getSeverity()).isEqualTo(SAFE); + assertThat(contentFilterResults.getHate()).isNotNull(); + assertThat(contentFilterResults.getHate().isFiltered()).isFalse(); + assertThat(contentFilterResults.getHate().getSeverity()).isEqualTo(SAFE); + assertThat(contentFilterResults.getSelfHarm()).isNotNull(); + assertThat(contentFilterResults.getSelfHarm().getSeverity()).isEqualTo(SAFE); + assertThat(contentFilterResults.getSelfHarm().isFiltered()).isFalse(); + assertThat(contentFilterResults.getProfanity()).isNull(); + assertThat(contentFilterResults.getError()).isNull(); + + verify( + postRequestedFor(urlPathEqualTo("/chat/completions")) + .withQueryParam("api-version", equalTo("2024-02-01")) + .withRequestBody( + equalToJson( + """ + { + "messages" : [ { + "content" : "You are a helpful AI", + "role" : "system" + }, { + "content" : "Hello World! 
Why is this phrase so famous?", + "role" : "user" + } ] + }"""))); + } + + @Test + @DisplayName("Chat history is not implemented yet") + void history() { + stubFor( + post(urlPathEqualTo("/chat/completions")) + .willReturn( + aResponse() + .withBodyFile("chatCompletionResponse.json") + .withHeader("Content-Type", "application/json"))); + + client.chatCompletion(new OpenAiChatCompletionRequest("First message")); + + verify( + exactly(1), + postRequestedFor(urlPathEqualTo("/chat/completions")) + .withRequestBody( + equalToJson( + """ + { + "messages" : [{ + "content" : "First message", + "role" : "user" + }] + }"""))); + + var response = client.chatCompletion(new OpenAiChatCompletionRequest("Second message")); + + assertThat(response.getContent()).isNotNull(); + assertThat(response.getContent()) + .isEqualTo( + "I'm an AI and cannot answer that question as beauty is subjective and varies from person to person."); + + verify( + exactly(1), + postRequestedFor(urlPathEqualTo("/chat/completions")) + .withRequestBody( + equalToJson( + """ + { + "messages" : [{ + "content" : "Second message", + "role" : "user" + }] + }"""))); + } + + @Test + void embedding() { + stubForEmbedding(); + + final var result = + client.embedding( + new EmbeddingsCreateRequest() + .input(EmbeddingsCreateRequestInput.create("Hello World"))); + + assertThat(result).isNotNull(); + assertThat(result.getModel()).isEqualTo("ada"); + assertThat(result.getObject()).isEqualTo("list"); + + assertThat(result.getUsage()).isNotNull(); + assertThat(result.getUsage().getPromptTokens()).isEqualTo(2); + assertThat(result.getUsage().getTotalTokens()).isEqualTo(2); + + assertThat(result.getData()).isNotNull().hasSize(1); + var embeddingData = result.getData().get(0); + assertThat(embeddingData).isNotNull(); + assertThat(embeddingData.getObject()).isEqualTo("embedding"); + assertThat(embeddingData.getIndex()).isZero(); + assertThat(embeddingData.getEmbedding()) + .isNotNull() + .isNotEmpty() + .containsExactly( + 
new BigDecimal("0.0"), + new BigDecimal("3.4028235E+38"), + new BigDecimal("1.4E-45"), + new BigDecimal("1.23"), + new BigDecimal("-4.56")); + + verify( + postRequestedFor(urlPathEqualTo("/embeddings")) + .withRequestBody( + equalToJson( + """ + {"input": "Hello World"}"""))); + } + + @Test + void testThrowsOnContentFilter() { + var mock = mock(OpenAiClient.class); + when(mock.streamChatCompletion(any())).thenCallRealMethod(); + + var deltaWithContentFilter = mock(OpenAiChatCompletionDelta.class); + when(deltaWithContentFilter.getFinishReason()).thenReturn("content_filter"); + when(mock.streamChatCompletionDeltas((CreateChatCompletionRequest) any())) + .thenReturn(Stream.of(deltaWithContentFilter)); + + // this must not throw, since the stream is lazily evaluated + var stream = mock.streamChatCompletion(""); + assertThatThrownBy(stream::toList) + .isInstanceOf(OpenAiClientException.class) + .hasMessageContaining("Content filter"); + } + + @Test + void streamChatCompletionDeltasErrorHandling() throws IOException { + try (var inputStream = stubStreamChatCompletion("streamChatCompletionError.txt")) { + + final var request = + new OpenAiChatCompletionRequest( + "Can you give me the first 100 numbers of the Fibonacci sequence?"); + + try (var stream = client.streamChatCompletionDeltas(request)) { + assertThatThrownBy(() -> stream.forEach(System.out::println)) + .isInstanceOf(OpenAiClientException.class) + .hasMessage("Failed to parse response and error message: 'exceeded token rate limit'"); + } + + Mockito.verify(inputStream, times(1)).close(); + } + } + + @SneakyThrows + @Test + void streamChatCompletionWithString() { + try (var inputStream = stubStreamChatCompletion("streamChatCompletion.txt")) { + final var userMessage = "Hello World! 
Why is this phrase so famous?"; + client.withSystemPrompt("You are a helpful AI"); + final var result = client.streamChatCompletion(userMessage).toList(); + + assertThat(result).hasSize(5); + // the first two and the last delta don't have any content + assertThat(result.get(0)).isEmpty(); + assertThat(result.get(1)).isEmpty(); + assertThat(result.get(2)).isEqualTo("Sure"); + assertThat(result.get(3)).isEqualTo("!"); + assertThat(result.get(4)).isEmpty(); + + Mockito.verify(inputStream, times(1)).close(); + } + } + + @Test + void streamChatCompletionDeltas() throws IOException { + + try (var inputStream = stubStreamChatCompletion("streamChatCompletion.txt")) { + + final var request = + new OpenAiChatCompletionRequest( + "Can you give me the first 100 numbers of the Fibonacci sequence?"); + + try (Stream stream = client.streamChatCompletionDeltas(request)) { + final List deltaList = stream.toList(); + + assertThat(deltaList).hasSize(5); + + final var delta0 = deltaList.get(0).getOriginalResponse(); + final var delta1 = deltaList.get(1).getOriginalResponse(); + final var delta2 = deltaList.get(2).getOriginalResponse(); + final var delta3 = deltaList.get(3).getOriginalResponse(); + final var delta4 = deltaList.get(4).getOriginalResponse(); + + assertThat(delta0.getCreated()).isZero(); + assertThat(delta1.getCreated()).isEqualTo(1724825677); + assertThat(delta2.getCreated()).isEqualTo(1724825677); + assertThat(delta3.getCreated()).isEqualTo(1724825677); + assertThat(delta4.getCreated()).isEqualTo(1724825677); + + assertThat(delta0.getId()).isEmpty(); + assertThat(delta1.getId()).isEqualTo("chatcmpl-A16EvnkgEm6AdxY0NoOmGPjsJucQ1"); + assertThat(delta2.getId()).isEqualTo("chatcmpl-A16EvnkgEm6AdxY0NoOmGPjsJucQ1"); + assertThat(delta3.getId()).isEqualTo("chatcmpl-A16EvnkgEm6AdxY0NoOmGPjsJucQ1"); + assertThat(delta4.getId()).isEqualTo("chatcmpl-A16EvnkgEm6AdxY0NoOmGPjsJucQ1"); + + assertThat(delta0.getModel()).isEmpty(); + 
assertThat(delta1.getModel()).isEqualTo("gpt-35-turbo"); + assertThat(delta2.getModel()).isEqualTo("gpt-35-turbo"); + assertThat(delta3.getModel()).isEqualTo("gpt-35-turbo"); + assertThat(delta4.getModel()).isEqualTo("gpt-35-turbo"); + + assertThat(delta0.getObject()).isEqualTo(UNKNOWN_DEFAULT_OPEN_API); + assertThat(delta1.getObject()).isEqualTo(CHAT_COMPLETION_CHUNK); + assertThat(delta2.getObject()).isEqualTo(CHAT_COMPLETION_CHUNK); + assertThat(delta3.getObject()).isEqualTo(CHAT_COMPLETION_CHUNK); + assertThat(delta4.getObject()).isEqualTo(CHAT_COMPLETION_CHUNK); + + assertThat(delta0.getSystemFingerprint()).isNull(); + assertThat(delta1.getSystemFingerprint()).isEqualTo("fp_e49e4201a9"); + assertThat(delta2.getSystemFingerprint()).isEqualTo("fp_e49e4201a9"); + assertThat(delta3.getSystemFingerprint()).isEqualTo("fp_e49e4201a9"); + assertThat(delta4.getSystemFingerprint()).isEqualTo("fp_e49e4201a9"); + + assertThat(delta0.getCustomFieldNames()).contains("prompt_filter_results"); + assertThat(delta1.getCustomFieldNames()).doesNotContain("prompt_filter_results"); + assertThat(delta2.getCustomFieldNames()).doesNotContain("prompt_filter_results"); + assertThat(delta3.getCustomFieldNames()).doesNotContain("prompt_filter_results"); + assertThat(delta4.getCustomFieldNames()).doesNotContain("prompt_filter_results"); + var promptFilterResults = (List) delta0.getCustomField("prompt_filter_results"); + final var promptFilter0 = + MAPPER.convertValue(promptFilterResults.get(0), PromptFilterResult.class); + assertThat(promptFilter0).isNotNull(); + assertThat(promptFilter0.getPromptIndex()).isZero(); + assertFilter(promptFilter0.getContentFilterResults()); + assertThat(promptFilter0.getContentFilterResults()).isNotNull(); + assertFilter(promptFilter0.getContentFilterResults()); + + // delta0.choices + assertThat(delta0.getChoices()).isEmpty(); + + // delta1.choices + assertThat(delta1.getChoices()).hasSize(1); + var choice1 = delta1.getChoices().get(0); + 
assertThat(choice1.getFinishReason()).isNull(); + assertThat(choice1.getIndex()).isZero(); + assertThat(choice1.getDelta().getContent()).isEmpty(); + assertThat(choice1.getDelta().getRole()).isEqualTo(ASSISTANT); + assertThat(choice1.getCustomField("content_filter_results")).isNotNull(); + assertThat(choice1.getCustomField("content_filter_results")).isEqualTo(Map.of()); + + // delta2.choices + assertThat(delta2.getChoices()).hasSize(1); + var choice2 = delta2.getChoices().get(0); + assertThat(choice2.getFinishReason()).isNull(); + assertThat(choice2.getIndex()).isZero(); + assertThat(choice2.getDelta().getContent()).isEqualTo("Sure"); + assertThat(choice2.getDelta().getRole()).isNull(); + assertThat(choice2.getCustomField("content_filter_results")).isNotNull(); + final var contentFilter2 = + MAPPER.convertValue( + choice2.getCustomField("content_filter_results"), ContentFilterPromptResults.class); + assertThat(contentFilter2).isNotNull(); + assertFilter(contentFilter2); + + // delta3.choices + assertThat(delta3.getChoices()).hasSize(1); + var choice3 = delta3.getChoices().get(0); + assertThat(choice3.getFinishReason()).isNull(); + assertThat(choice3.getIndex()).isZero(); + assertThat(choice3.getDelta().getContent()).isEqualTo("!"); + assertThat(choice3.getDelta().getRole()).isNull(); + assertThat(choice3.getCustomField("content_filter_results")).isNotNull(); + var contentFilter3 = + MAPPER.convertValue( + choice3.getCustomField("content_filter_results"), ContentFilterPromptResults.class); + assertThat(contentFilter3).isNotNull(); + assertFilter(contentFilter3); + + // delta4.choices + assertThat(delta4.getChoices()).hasSize(1); + var choice4 = delta4.getChoices().get(0); + assertThat(choice4.getFinishReason()) + .isEqualTo(CreateChatCompletionStreamResponseChoicesInner.FinishReasonEnum.STOP); + assertThat(choice4.getIndex()).isZero(); + assertThat(choice4.getDelta().getContent()).isNull(); + assertThat(choice4.getDelta().getRole()).isNull(); + 
assertThat(choice4.getCustomField("content_filter_results")).isEqualTo(Map.of()); + } + + Mockito.verify(inputStream, times(1)).close(); + } + } + + @Test + void streamCompletionDeltasResponseConvenience() throws IOException { + try (var inputStream = stubStreamChatCompletion("streamChatCompletion.txt")) { + + final var request = + new OpenAiChatCompletionRequest( + "Can you give me the first 100 numbers of the Fibonacci sequence?"); + + try (Stream stream = client.streamChatCompletionDeltas(request)) { + final List deltaList = stream.toList(); + + assertThat(deltaList).hasSize(5); + + assertThat(deltaList.get(0).getFinishReason()).isNull(); + assertThat(deltaList.get(1).getFinishReason()).isNull(); + assertThat(deltaList.get(2).getFinishReason()).isNull(); + assertThat(deltaList.get(3).getFinishReason()).isNull(); + assertThat(deltaList.get(3).getFinishReason()).isNull(); + assertThat(deltaList.get(4).getFinishReason()).isEqualTo("stop"); + + assertThat(deltaList.get(0).getDeltaContent()).isEmpty(); + assertThat(deltaList.get(1).getDeltaContent()).isEmpty(); + assertThat(deltaList.get(2).getDeltaContent()).isEqualTo("Sure"); + assertThat(deltaList.get(3).getDeltaContent()).isEqualTo("!"); + assertThat(deltaList.get(4).getDeltaContent()).isEmpty(); + + assertThat(deltaList.get(0).getCompletionUsage()).isNull(); + assertThat(deltaList.get(1).getCompletionUsage()).isNull(); + assertThat(deltaList.get(2).getCompletionUsage()).isNull(); + assertThat(deltaList.get(3).getCompletionUsage()).isNull(); + assertThat(deltaList.get(4).getCompletionUsage()).isNotNull(); + assertThat(deltaList.get(4).getCompletionUsage().getCompletionTokens()).isEqualTo(607); + assertThat(deltaList.get(4).getCompletionUsage().getPromptTokens()).isEqualTo(21); + assertThat(deltaList.get(4).getCompletionUsage().getTotalTokens()).isEqualTo(628); + } + + Mockito.verify(inputStream, times(1)).close(); + } + } + + void assertFilter(ContentFilterPromptResults filter) { + assertThat(filter).isNotNull(); 
+ assertThat(filter.getHate()).isNotNull(); + assertThat(filter.getHate().isFiltered()).isFalse(); + assertThat(filter.getHate().getSeverity()).isEqualTo(SAFE); + assertThat(filter.getSelfHarm()).isNotNull(); + assertThat(filter.getSelfHarm().isFiltered()).isFalse(); + assertThat(filter.getSelfHarm().getSeverity()).isEqualTo(SAFE); + assertThat(filter.getSexual()).isNotNull(); + assertThat(filter.getSexual().isFiltered()).isFalse(); + assertThat(filter.getSexual().getSeverity()).isEqualTo(SAFE); + assertThat(filter.getViolence()).isNotNull(); + assertThat(filter.getViolence().isFiltered()).isFalse(); + assertThat(filter.getViolence().getSeverity()).isEqualTo(SAFE); + assertThat(filter.getJailbreak()).isNull(); + assertThat(filter.getProfanity()).isNull(); + assertThat(filter.getError()).isNull(); + } + + @Test + void chatCompletionTool() { + stubForChatCompletionTool(); + + final var function = + new FunctionObject() + .name("fibonacci") + .parameters( + Map.of("type", "object", "properties", Map.of("N", Map.of("type", "integer")))); + + final var tool = + new ChatCompletionTool().type(ChatCompletionTool.TypeEnum.FUNCTION).function(function); + + final var toolChoice = + ChatCompletionToolChoiceOption.create( + new ChatCompletionNamedToolChoice() + .type(ChatCompletionNamedToolChoice.TypeEnum.FUNCTION) + .function(new ChatCompletionNamedToolChoiceFunction().name("fibonacci"))); + + final var request = + new OpenAiChatCompletionRequest( + "A pair of rabbits is placed in a field. Each month, every pair produces one new pair, starting from the second month. 
How many rabbits will there be after 12 months?") + .withTools(List.of(tool)) + .withToolChoice(toolChoice); + + var response = client.chatCompletion(request).getOriginalResponse(); + + assertThat(response).isNotNull(); + assertThat(response.getChoices()).hasSize(1); + assertThat(response.getChoices().get(0).getFinishReason()).isEqualTo(STOP); + assertThat(response.getChoices().get(0).getMessage().getRole()) + .isEqualTo(ChatCompletionResponseMessageRole.ASSISTANT); + assertThat(response.getChoices().get(0).getMessage().getToolCalls()).hasSize(1); + assertThat(response.getChoices().get(0).getMessage().getToolCalls().get(0).getId()) + .isEqualTo("call_CUYGJf2j7FRWJMHT3PN3aGxK"); + assertThat(response.getChoices().get(0).getMessage().getToolCalls().get(0).getType()) + .isEqualTo(FUNCTION); + assertThat( + response.getChoices().get(0).getMessage().getToolCalls().get(0).getFunction().getName()) + .isEqualTo("fibonacci"); + assertThat( + response + .getChoices() + .get(0) + .getMessage() + .getToolCalls() + .get(0) + .getFunction() + .getArguments()) + .isEqualTo("{\"N\":12}"); + + verify( + postRequestedFor(anyUrl()) + .withRequestBody( + equalToJson( + """ + { + "messages" : [ { + "content" : "A pair of rabbits is placed in a field. Each month, every pair produces one new pair, starting from the second month. 
How many rabbits will there be after 12 months?", + "role" : "user" + } ], + "tools" : [ { + "type" : "function", + "function" : { + "name" : "fibonacci", + "parameters" : { + "type" : "object", + "properties" : { + "N" : { + "type" : "integer" + } + } + }, + "strict" : false + } + } ], + "tool_choice" : { + "type" : "function", + "function" : { + "name" : "fibonacci" + } + } + } + """))); + } +} diff --git a/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientTest.java b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientTest.java index 67c1c8a24..eb57255a2 100644 --- a/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientTest.java +++ b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientTest.java @@ -1,180 +1,54 @@ package com.sap.ai.sdk.foundationmodels.openai; import static com.github.tomakehurst.wiremock.client.WireMock.*; +import static com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionTool.ToolType.FUNCTION; +import static com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.*; import static com.sap.ai.sdk.foundationmodels.openai.model.OpenAiContentFilterSeverityResult.Severity.SAFE; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; -import static org.mockito.Mockito.when; -import com.fasterxml.jackson.core.JsonParseException; -import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; -import com.github.tomakehurst.wiremock.junit5.WireMockTest; -import com.github.tomakehurst.wiremock.stubbing.Scenario; -import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionChoice; import 
com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionDelta; +import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionFunction; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionOutput; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionParameters; -import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.OpenAiChatSystemMessage; -import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.OpenAiChatUserMessage; +import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionTool; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiContentFilterPromptResults; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingParameters; -import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor; -import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Cache; -import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination; -import io.vavr.control.Try; import java.io.IOException; -import java.io.InputStream; import java.util.List; -import java.util.Objects; +import java.util.Map; import java.util.concurrent.Callable; -import java.util.function.Function; import java.util.stream.Stream; import javax.annotation.Nonnull; import lombok.SneakyThrows; -import org.apache.hc.client5.http.classic.HttpClient; -import org.apache.hc.core5.http.ContentType; -import org.apache.hc.core5.http.io.entity.InputStreamEntity; -import org.apache.hc.core5.http.message.BasicClassicHttpResponse; -import org.assertj.core.api.SoftAssertions; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; import org.mockito.Mockito; -@WireMockTest -class OpenAiClientTest { - private static OpenAiClient client; - private final Function fileLoader = 
- filename -> Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream(filename)); - - @BeforeEach - void setup(WireMockRuntimeInfo server) { - final DefaultHttpDestination destination = - DefaultHttpDestination.builder(server.getHttpBaseUrl()).build(); - client = OpenAiClient.withCustomDestination(destination); - ApacheHttpClient5Accessor.setHttpClientCache(ApacheHttpClient5Cache.DISABLED); - } - - @AfterEach - void reset() { - ApacheHttpClient5Accessor.setHttpClientCache(null); - ApacheHttpClient5Accessor.setHttpClientFactory(null); - } - - @Test - void apiVersion() { - stubFor(post(anyUrl()).willReturn(okJson("{}"))); - Try.of(() -> client.chatCompletion(new OpenAiChatCompletionParameters())); - - verify( - exactly(1), - postRequestedFor(anyUrl()).withQueryParam("api-version", equalTo("2024-02-01"))); - - Try.of( - () -> client.withApiVersion("fooBar").chatCompletion(new OpenAiChatCompletionParameters())); - verify(exactly(1), postRequestedFor(anyUrl()).withQueryParam("api-version", equalTo("fooBar"))); - - assertThat(client) - .describedAs( - "withApiVersion should return a new object, the sut object should remain unchanged") - .isNotSameAs(client.withApiVersion("fooBar")); - Try.of(() -> client.chatCompletion(new OpenAiChatCompletionParameters())); - verify( - exactly(2), - postRequestedFor(anyUrl()).withQueryParam("api-version", equalTo("2024-02-01"))); - } +class OpenAiClientTest extends BaseOpenAiClientTest { private static Runnable[] errorHandlingCalls() { return new Runnable[] { - () -> client.chatCompletion(new OpenAiChatCompletionParameters()), + () -> client.chatCompletion(""), () -> client - .streamChatCompletionDeltas(new OpenAiChatCompletionParameters()) + .streamChatCompletionDeltas( + new OpenAiChatCompletionParameters() + .addMessages(new OpenAiChatUserMessage().addText(""))) // the stream needs to be consumed to parse the response - .forEach(System.out::println) + .forEach(System.out::println), }; } @ParameterizedTest 
@MethodSource("errorHandlingCalls") void chatCompletionErrorHandling(@Nonnull final Runnable request) { - final var errorJson = - """ - { "error": { "code": null, "message": "foo", "type": "invalid stuff" } } - """; - stubFor( - post(anyUrl()) - .inScenario("Errors") - .whenScenarioStateIs(Scenario.STARTED) - .willReturn(serverError()) - .willSetStateTo("1")); - stubFor( - post(anyUrl()) - .inScenario("Errors") - .whenScenarioStateIs("1") - .willReturn( - badRequest().withBody(errorJson).withHeader("Content-type", "application/json")) - .willSetStateTo("2")); - stubFor( - post(anyUrl()) - .inScenario("Errors") - .whenScenarioStateIs("2") - .willReturn( - badRequest() - .withBody("{ broken json") - .withHeader("Content-type", "application/json")) - .willSetStateTo("3")); - stubFor( - post(anyUrl()) - .inScenario("Errors") - .whenScenarioStateIs("3") - .willReturn(okXml("")) - .willSetStateTo("4")); - stubFor(post(anyUrl()).inScenario("Errors").whenScenarioStateIs("4").willReturn(noContent())); - - final var softly = new SoftAssertions(); - - softly - .assertThatThrownBy(request::run) - .describedAs("Server errors should be handled") - .isInstanceOf(OpenAiClientException.class) - .hasMessageContaining("500"); - - softly - .assertThatThrownBy(request::run) - .describedAs("Error objects from OpenAI should be interpreted") - .isInstanceOf(OpenAiClientException.class) - .hasMessageContaining("error message: 'foo'"); - - softly - .assertThatThrownBy(request::run) - .describedAs("Failures while parsing error message should be handled") - .isInstanceOf(OpenAiClientException.class) - .hasMessageContaining("400") - .extracting(e -> e.getSuppressed()[0]) - .isInstanceOf(JsonParseException.class); - - softly - .assertThatThrownBy(request::run) - .describedAs("Non-JSON responses should be handled") - .isInstanceOf(OpenAiClientException.class) - .hasMessageContaining("Failed to parse"); - - softly - .assertThatThrownBy(request::run) - .describedAs("Empty responses should be 
handled") - .isInstanceOf(OpenAiClientException.class) - .hasMessageContaining("was empty"); - - softly.assertAll(); + + stubForErrorHandling(); + assertForErrorHandling(request); } private static Callable[] chatCompletionCalls() { @@ -198,97 +72,94 @@ private static Callable[] chatCompletionCalls() { @ParameterizedTest @MethodSource("chatCompletionCalls") void chatCompletion(@Nonnull final Callable request) { - try (var inputStream = fileLoader.apply("__files/chatCompletionResponse.json")) { - - final String response = new String(inputStream.readAllBytes()); - // with query parameter api-version=2024-02-01 - stubFor( - post(urlPathEqualTo("/chat/completions")) - .withQueryParam("api-version", equalTo("2024-02-01")) - .willReturn(okJson(response))); - - final OpenAiChatCompletionOutput result = request.call(); - - assertThat(result).isNotNull(); - assertThat(result.getCreated()).isEqualTo(1727436279); - assertThat(result.getId()).isEqualTo("chatcmpl-AC3NPPYlxem8kRBBAX9EBObMMsrnf"); - assertThat(result.getModel()).isEqualTo("gpt-35-turbo"); - assertThat(result.getObject()).isEqualTo("chat.completion"); - assertThat(result.getSystemFingerprint()).isEqualTo("fp_e49e4201a9"); - - assertThat(result.getUsage()).isNotNull(); - assertThat(result.getUsage().getCompletionTokens()).isEqualTo(20); - assertThat(result.getUsage().getPromptTokens()).isEqualTo(13); - assertThat(result.getUsage().getTotalTokens()).isEqualTo(33); - - assertThat(result.getPromptFilterResults()).hasSize(1); - assertThat(result.getPromptFilterResults().get(0).getPromptIndex()).isEqualTo(0); - OpenAiContentFilterPromptResults promptFilterResults = - result.getPromptFilterResults().get(0).getContentFilterResults(); - assertThat(promptFilterResults).isNotNull(); - assertThat(promptFilterResults.getSexual()).isNotNull(); - assertThat(promptFilterResults.getSexual().isFiltered()).isFalse(); - assertThat(promptFilterResults.getSexual().getSeverity()).isEqualTo(SAFE); - 
assertThat(promptFilterResults.getViolence()).isNotNull(); - assertThat(promptFilterResults.getViolence().isFiltered()).isFalse(); - assertThat(promptFilterResults.getViolence().getSeverity()).isEqualTo(SAFE); - assertThat(promptFilterResults.getHate()).isNotNull(); - assertThat(promptFilterResults.getHate().isFiltered()).isFalse(); - assertThat(promptFilterResults.getHate().getSeverity()).isEqualTo(SAFE); - assertThat(promptFilterResults.getSelfHarm()).isNotNull(); - assertThat(promptFilterResults.getSelfHarm().getSeverity()).isEqualTo(SAFE); - assertThat(promptFilterResults.getSelfHarm().isFiltered()).isFalse(); - assertThat(promptFilterResults.getProfanity()).isNull(); - assertThat(promptFilterResults.getError()).isNull(); - assertThat(promptFilterResults.getJailbreak()).isNull(); - - assertThat(result.getChoices()).hasSize(1); - OpenAiChatCompletionChoice choice = result.getChoices().get(0); - assertThat(choice.getFinishReason()).isEqualTo("stop"); - assertThat(choice.getIndex()).isEqualTo(0); - assertThat(choice.getMessage().getContent()) - .isEqualTo( - "I'm an AI and cannot answer that question as beauty is subjective and varies from person to person."); - assertThat(choice.getMessage().getRole()).isEqualTo("assistant"); - assertThat(choice.getMessage().getToolCalls()).isNull(); - - OpenAiContentFilterPromptResults contentFilterResults = choice.getContentFilterResults(); - assertThat(contentFilterResults).isNotNull(); - assertThat(contentFilterResults.getSexual()).isNotNull(); - assertThat(contentFilterResults.getSexual().isFiltered()).isFalse(); - assertThat(contentFilterResults.getSexual().getSeverity()).isEqualTo(SAFE); - assertThat(contentFilterResults.getViolence()).isNotNull(); - assertThat(contentFilterResults.getViolence().isFiltered()).isFalse(); - assertThat(contentFilterResults.getViolence().getSeverity()).isEqualTo(SAFE); - assertThat(contentFilterResults.getHate()).isNotNull(); - assertThat(contentFilterResults.getHate().isFiltered()).isFalse(); 
- assertThat(contentFilterResults.getHate().getSeverity()).isEqualTo(SAFE); - assertThat(contentFilterResults.getSelfHarm()).isNotNull(); - assertThat(contentFilterResults.getSelfHarm().getSeverity()).isEqualTo(SAFE); - assertThat(contentFilterResults.getSelfHarm().isFiltered()).isFalse(); - assertThat(contentFilterResults.getProfanity()).isNull(); - assertThat(contentFilterResults.getError()).isNull(); - assertThat(contentFilterResults.getJailbreak()).isNull(); - - verify( - postRequestedFor(urlPathEqualTo("/chat/completions")) - .withQueryParam("api-version", equalTo("2024-02-01")) - .withRequestBody( - equalToJson( - """ - { - "messages" : [ { - "role" : "system", - "content" : "You are a helpful AI" - }, { - "role" : "user", - "content" : [ { - "type" : "text", - "text" : "Hello World! Why is this phrase so famous?" - } ] - } ] - }"""))); - } + + stubForChatCompletion(); + + var result = request.call(); + + assertThat(result).isNotNull(); + assertThat(result.getCreated()).isEqualTo(1727436279); + assertThat(result.getId()).isEqualTo("chatcmpl-AC3NPPYlxem8kRBBAX9EBObMMsrnf"); + assertThat(result.getModel()).isEqualTo("gpt-35-turbo"); + assertThat(result.getObject()).isEqualTo("chat.completion"); + assertThat(result.getSystemFingerprint()).isEqualTo("fp_e49e4201a9"); + + assertThat(result.getUsage()).isNotNull(); + assertThat(result.getUsage().getCompletionTokens()).isEqualTo(20); + assertThat(result.getUsage().getPromptTokens()).isEqualTo(13); + assertThat(result.getUsage().getTotalTokens()).isEqualTo(33); + + assertThat(result.getPromptFilterResults()).hasSize(1); + assertThat(result.getPromptFilterResults().get(0).getPromptIndex()).isEqualTo(0); + assertThat(result.getContent()) + .isEqualTo( + "I'm an AI and cannot answer that question as beauty is subjective and varies from person to person."); + + var promptFilterResults = result.getPromptFilterResults().get(0).getContentFilterResults(); + assertThat(promptFilterResults).isNotNull(); + 
assertThat(promptFilterResults.getSexual()).isNotNull(); + assertThat(promptFilterResults.getSexual().isFiltered()).isFalse(); + assertThat(promptFilterResults.getSexual().getSeverity()).isEqualTo(SAFE); + assertThat(promptFilterResults.getViolence()).isNotNull(); + assertThat(promptFilterResults.getViolence().isFiltered()).isFalse(); + assertThat(promptFilterResults.getViolence().getSeverity()).isEqualTo(SAFE); + assertThat(promptFilterResults.getHate()).isNotNull(); + assertThat(promptFilterResults.getHate().isFiltered()).isFalse(); + assertThat(promptFilterResults.getHate().getSeverity()).isEqualTo(SAFE); + assertThat(promptFilterResults.getSelfHarm()).isNotNull(); + assertThat(promptFilterResults.getSelfHarm().getSeverity()).isEqualTo(SAFE); + assertThat(promptFilterResults.getSelfHarm().isFiltered()).isFalse(); + assertThat(promptFilterResults.getProfanity()).isNull(); + assertThat(promptFilterResults.getError()).isNull(); + assertThat(promptFilterResults.getJailbreak()).isNull(); + + assertThat(result.getChoices()).hasSize(1); + + var choice = result.getChoices().get(0); + assertThat(choice.getFinishReason()).isEqualTo("stop"); + assertThat(choice.getIndex()).isZero(); + assertThat(choice.getMessage().getContent()) + .isEqualTo( + "I'm an AI and cannot answer that question as beauty is subjective and varies from person to person."); + assertThat(choice.getMessage().getRole()).isEqualTo("assistant"); + assertThat(choice.getMessage().getToolCalls()).isNull(); + + var contentFilterResults = choice.getContentFilterResults(); + assertThat(contentFilterResults).isNotNull(); + assertThat(contentFilterResults.getSexual()).isNotNull(); + assertThat(contentFilterResults.getSexual().isFiltered()).isFalse(); + assertThat(contentFilterResults.getSexual().getSeverity()).isEqualTo(SAFE); + assertThat(contentFilterResults.getViolence()).isNotNull(); + assertThat(contentFilterResults.getViolence().isFiltered()).isFalse(); + 
assertThat(contentFilterResults.getViolence().getSeverity()).isEqualTo(SAFE); + assertThat(contentFilterResults.getHate()).isNotNull(); + assertThat(contentFilterResults.getHate().isFiltered()).isFalse(); + assertThat(contentFilterResults.getHate().getSeverity()).isEqualTo(SAFE); + assertThat(contentFilterResults.getSelfHarm()).isNotNull(); + assertThat(contentFilterResults.getSelfHarm().getSeverity()).isEqualTo(SAFE); + assertThat(contentFilterResults.getSelfHarm().isFiltered()).isFalse(); + assertThat(contentFilterResults.getProfanity()).isNull(); + assertThat(contentFilterResults.getError()).isNull(); + assertThat(contentFilterResults.getJailbreak()).isNull(); + + verify( + postRequestedFor(urlPathEqualTo("/chat/completions")) + .withQueryParam("api-version", equalTo("2024-02-01")) + .withRequestBody( + equalToJson( + """ + { + "messages" : [ { + "role" : "system", + "content" : "You are a helpful AI" + }, { + "role" : "user", + "content" : [ { + "type" : "text", + "text" : "Hello World! Why is this phrase so famous?" 
+ } ] + } ] + }"""))); } @Test @@ -301,7 +172,7 @@ void history() { .withBodyFile("chatCompletionResponse.json") .withHeader("Content-Type", "application/json"))); - client.withSystemPrompt("system prompt").chatCompletion("chat completion 1"); + client.chatCompletion("First message"); verify( exactly(1), @@ -311,18 +182,20 @@ void history() { """ { "messages" : [ { - "role" : "system", - "content" : "system prompt" - }, { "role" : "user", "content" : [ { "type" : "text", - "text" : "chat completion 1" + "text" : "First message" } ] } ] }"""))); - client.withSystemPrompt("system prompt").chatCompletion("chat completion 2"); + var response = client.chatCompletion("Second message"); + + assertThat(response.getContent()).isNotNull(); + assertThat(response.getContent()) + .isEqualTo( + "I'm an AI and cannot answer that question as beauty is subjective and varies from person to person."); verify( exactly(1), @@ -331,27 +204,19 @@ void history() { equalToJson( """ { - "messages" : [ { - "role" : "system", - "content" : "system prompt" - }, { - "role" : "user", - "content" : [ { - "type" : "text", - "text" : "chat completion 2" - } ] - } ] - }"""))); + "messages" : [ { + "role" : "user", + "content" : [ { + "type" : "text", + "text" : "Second message" + } ] + } ] + }"""))); } @Test void embedding() { - stubFor( - post(urlPathEqualTo("/embeddings")) - .willReturn( - aResponse() - .withBodyFile("embeddingResponse.json") - .withHeader("Content-Type", "application/json"))); + stubForEmbedding(); final var request = new OpenAiEmbeddingParameters().setInput("Hello World"); final var result = client.embedding(request); @@ -369,7 +234,7 @@ void embedding() { var embeddingData = result.getData().get(0); assertThat(embeddingData).isNotNull(); assertThat(embeddingData.getObject()).isEqualTo("embedding"); - assertThat(embeddingData.getIndex()).isEqualTo(0); + assertThat(embeddingData.getIndex()).isZero(); assertThat(embeddingData.getEmbedding()) .isNotNull() .isNotEmpty() @@ -380,40 
+245,12 @@ void embedding() { .withRequestBody( equalToJson( """ - {"input":["Hello World"]}"""))); - } - - @Test - void testThrowsOnContentFilter() { - var mock = mock(OpenAiClient.class); - when(mock.streamChatCompletion(any())).thenCallRealMethod(); - - var deltaWithContentFilter = mock(OpenAiChatCompletionDelta.class); - when(deltaWithContentFilter.getFinishReason()).thenReturn("content_filter"); - when(mock.streamChatCompletionDeltas(any())).thenReturn(Stream.of(deltaWithContentFilter)); - - // this must not throw, since the stream is lazily evaluated - var stream = mock.streamChatCompletion(""); - assertThatThrownBy(stream::toList) - .isInstanceOf(OpenAiClientException.class) - .hasMessageContaining("Content filter"); + {"input": ["Hello World"]}"""))); } @Test void streamChatCompletionDeltasErrorHandling() throws IOException { - try (var inputStream = spy(fileLoader.apply("streamChatCompletionError.txt"))) { - - final var httpClient = mock(HttpClient.class); - ApacheHttpClient5Accessor.setHttpClientFactory(destination -> httpClient); - - // Create a mock response - final var mockResponse = new BasicClassicHttpResponse(200, "OK"); - final var inputStreamEntity = new InputStreamEntity(inputStream, ContentType.TEXT_PLAIN); - mockResponse.setEntity(inputStreamEntity); - mockResponse.setHeader("Content-Type", "text/event-stream"); - - // Configure the HttpClient mock to return the mock response - doReturn(mockResponse).when(httpClient).executeOpen(any(), any(), any()); + try (var inputStream = stubStreamChatCompletion("streamChatCompletionError.txt")) { final var request = new OpenAiChatCompletionParameters() @@ -421,7 +258,7 @@ void streamChatCompletionDeltasErrorHandling() throws IOException { new OpenAiChatUserMessage() .addText("Can you give me the first 100 numbers of the Fibonacci sequence?")); - try (Stream stream = client.streamChatCompletionDeltas(request)) { + try (var stream = client.streamChatCompletionDeltas(request)) { assertThatThrownBy(() -> 
stream.forEach(System.out::println)) .isInstanceOf(OpenAiClientException.class) .hasMessage("Failed to parse response and error message: 'exceeded token rate limit'"); @@ -433,19 +270,7 @@ void streamChatCompletionDeltasErrorHandling() throws IOException { @Test void streamChatCompletionDeltas() throws IOException { - try (var inputStream = spy(fileLoader.apply("streamChatCompletion.txt"))) { - - final var httpClient = mock(HttpClient.class); - ApacheHttpClient5Accessor.setHttpClientFactory(destination -> httpClient); - - // Create a mock response - final var mockResponse = new BasicClassicHttpResponse(200, "OK"); - final var inputStreamEntity = new InputStreamEntity(inputStream, ContentType.TEXT_PLAIN); - mockResponse.setEntity(inputStreamEntity); - mockResponse.setHeader("Content-Type", "text/event-stream"); - - // Configure the HttpClient mock to return the mock response - doReturn(mockResponse).when(httpClient).executeOpen(any(), any(), any()); + try (var inputStream = stubStreamChatCompletion("streamChatCompletion.txt")) { final var request = new OpenAiChatCompletionParameters() @@ -461,11 +286,11 @@ void streamChatCompletionDeltas() throws IOException { assertThat(deltaList).hasSize(5); // the first two and the last delta don't have any content - assertThat(deltaList.get(0).getDeltaContent()).isEqualTo(""); - assertThat(deltaList.get(1).getDeltaContent()).isEqualTo(""); + assertThat(deltaList.get(0).getDeltaContent()).isEmpty(); + assertThat(deltaList.get(1).getDeltaContent()).isEmpty(); assertThat(deltaList.get(2).getDeltaContent()).isEqualTo("Sure"); assertThat(deltaList.get(3).getDeltaContent()).isEqualTo("!"); - assertThat(deltaList.get(4).getDeltaContent()).isEqualTo(""); + assertThat(deltaList.get(4).getDeltaContent()).isEmpty(); assertThat(deltaList.get(0).getSystemFingerprint()).isNull(); assertThat(deltaList.get(1).getSystemFingerprint()).isEqualTo("fp_e49e4201a9"); @@ -490,15 +315,16 @@ void streamChatCompletionDeltas() throws IOException { 
assertThat(deltaList.get(4).getChoices()).hasSize(1); final var delta0 = deltaList.get(0); - assertThat(delta0.getId()).isEqualTo(""); - assertThat(delta0.getCreated()).isEqualTo(0); - assertThat(delta0.getModel()).isEqualTo(""); - assertThat(delta0.getObject()).isEqualTo(""); + assertThat(delta0.getId()).isEmpty(); + assertThat(delta0.getCreated()).isZero(); + assertThat(delta0.getModel()).isEmpty(); + assertThat(delta0.getObject()).isEmpty(); assertThat(delta0.getUsage()).isNull(); assertThat(delta0.getChoices()).isEmpty(); + assertThat(delta0.getFinishReason()).isNull(); // prompt filter results are only present in delta 0 assertThat(delta0.getPromptFilterResults()).isNotNull(); - assertThat(delta0.getPromptFilterResults().get(0).getPromptIndex()).isEqualTo(0); + assertThat(delta0.getPromptFilterResults().get(0).getPromptIndex()).isZero(); final var promptFilter0 = delta0.getPromptFilterResults().get(0).getContentFilterResults(); assertThat(promptFilter0).isNotNull(); assertFilter(promptFilter0); @@ -510,9 +336,10 @@ void streamChatCompletionDeltas() throws IOException { assertThat(delta2.getObject()).isEqualTo("chat.completion.chunk"); assertThat(delta2.getUsage()).isNull(); assertThat(delta2.getPromptFilterResults()).isNull(); + assertThat(delta2.getFinishReason()).isNull(); + final var choices2 = delta2.getChoices().get(0); assertThat(choices2.getIndex()).isEqualTo(0); - assertThat(choices2.getFinishReason()).isNull(); assertThat(choices2.getMessage()).isNotNull(); // the role is only defined in delta 1, but it defaults to "assistant" for all deltas assertThat(choices2.getMessage().getRole()).isEqualTo("assistant"); @@ -524,8 +351,8 @@ void streamChatCompletionDeltas() throws IOException { final var delta3 = deltaList.get(3); assertThat(delta3.getDeltaContent()).isEqualTo("!"); + assertThat(deltaList.get(4).getFinishReason()).isEqualTo("stop"); final var delta4Choice = deltaList.get(4).getChoices().get(0); - 
assertThat(delta4Choice.getFinishReason()).isEqualTo("stop"); assertThat(delta4Choice.getMessage()).isNotNull(); // the role is only defined in delta 1, but it defaults to "assistant" for all deltas assertThat(delta4Choice.getMessage().getRole()).isEqualTo("assistant"); @@ -576,4 +403,81 @@ void assertFilter(OpenAiContentFilterPromptResults filter) { assertThat(filter.getProfanity()).isNull(); assertThat(filter.getError()).isNull(); } + + @Test + void chatCompletionTool() { + stubForChatCompletionTool(); + + final var question = + "A pair of rabbits is placed in a field. Each month, every pair produces one new pair, starting from the second month. How many rabbits will there be after 12 months?"; + final var par = Map.of("type", "object", "properties", Map.of("N", Map.of("type", "integer"))); + final var function = new OpenAiChatCompletionFunction().setName("fibonacci").setParameters(par); + final var tool = new OpenAiChatCompletionTool().setType(FUNCTION).setFunction(function); + final var request = + new OpenAiChatCompletionParameters() + .addMessages(new OpenAiChatUserMessage().addText(question)) + .setTools(List.of(tool)) + .setToolChoiceFunction("fibonacci"); + + var response = client.chatCompletion(request); + + assertThat(response).isNotNull(); + assertThat(response.getChoices()).hasSize(1); + assertThat(response.getChoices().get(0).getFinishReason()).isEqualTo("stop"); + assertThat(response.getChoices().get(0).getMessage().getRole()).isEqualTo("assistant"); + assertThat(response.getChoices().get(0).getMessage().getToolCalls()).hasSize(1); + assertThat(response.getChoices().get(0).getMessage().getToolCalls().get(0).getId()) + .isEqualTo("call_CUYGJf2j7FRWJMHT3PN3aGxK"); + assertThat(response.getChoices().get(0).getMessage().getToolCalls().get(0).getType()) + .isEqualTo("function"); + assertThat( + response.getChoices().get(0).getMessage().getToolCalls().get(0).getFunction().getName()) + .isEqualTo("fibonacci"); + assertThat( + response + .getChoices() + 
.get(0) + .getMessage() + .getToolCalls() + .get(0) + .getFunction() + .getArguments()) + .isEqualTo("{\"N\":12}"); + + verify( + postRequestedFor(anyUrl()) + .withRequestBody( + equalToJson( + """ + { + "messages" : [ { + "role" : "user", + "content" : [ { + "type" : "text", + "text" : "A pair of rabbits is placed in a field. Each month, every pair produces one new pair, starting from the second month. How many rabbits will there be after 12 months?" + } ] + } ], + "tools" : [ { + "type" : "function", + "function" : { + "name" : "fibonacci", + "parameters" : { + "type" : "object", + "properties" : { + "N" : { + "type" : "integer" + } + } + } + } + } ], + "tool_choice" : { + "function" : { + "name" : "fibonacci" + }, + "type" : "function" + } + } + """))); + } } diff --git a/foundation-models/openai/src/test/resources/__files/chatCompletionToolResponse.json b/foundation-models/openai/src/test/resources/__files/chatCompletionToolResponse.json new file mode 100644 index 000000000..f2cfab223 --- /dev/null +++ b/foundation-models/openai/src/test/resources/__files/chatCompletionToolResponse.json @@ -0,0 +1,56 @@ +{ + "choices": [ + { + "content_filter_results": {}, + "finish_reason": "stop", + "index": 0, + "message": { + "content": null, + "role": "assistant", + "tool_calls": [ + { + "function": { + "arguments": "{\"N\":12}", + "name": "fibonacci" + }, + "id": "call_CUYGJf2j7FRWJMHT3PN3aGxK", + "type": "function" + } + ] + } + } + ], + "created": 1738867288, + "id": "chatcmpl-Ay16WxAjYvRDJ12ImPi8E1T1TZHoP", + "model": "gpt-3.5-turbo-1106", + "object": "chat.completion", + "prompt_filter_results": [ + { + "content_filter_results": { + "hate": { + "filtered": false, + "severity": "safe" + }, + "self_harm": { + "filtered": false, + "severity": "safe" + }, + "sexual": { + "filtered": false, + "severity": "safe" + }, + "violence": { + "filtered": false, + "severity": "safe" + } + }, + "prompt_index": 0 + } + ], + "system_fingerprint": "fp_0165350fbb", + "usage": { + 
"completion_tokens": 5, + "prompt_tokens": 93, + "total_tokens": 98 + } +} \ No newline at end of file diff --git a/sample-code/spring-app/pom.xml b/sample-code/spring-app/pom.xml index 34d2be7fe..a8c348c85 100644 --- a/sample-code/spring-app/pom.xml +++ b/sample-code/spring-app/pom.xml @@ -114,10 +114,6 @@ com.fasterxml.jackson.core jackson-databind - - com.fasterxml.jackson.core - jackson-core - com.fasterxml.jackson.core jackson-annotations diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java index 26ea351bc..e48b67b70 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java @@ -1,12 +1,14 @@ package com.sap.ai.sdk.app.controllers; -import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.annotation.JsonAutoDetect; +import com.fasterxml.jackson.annotation.PropertyAccessor; import com.fasterxml.jackson.databind.ObjectMapper; import com.sap.ai.sdk.app.services.OpenAiService; -import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionOutput; +import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiUsage; import com.sap.cloud.sdk.cloudplatform.thread.ThreadContextExecutors; import java.io.IOException; import java.util.Arrays; +import java.util.concurrent.atomic.AtomicReference; import javax.annotation.Nonnull; import javax.annotation.Nullable; import lombok.extern.slf4j.Slf4j; @@ -25,6 +27,8 @@ @SuppressWarnings("unused") public class OpenAiController { @Autowired private OpenAiService service; + private static final ObjectMapper MAPPER = + new ObjectMapper().setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY); @GetMapping("/chatCompletion") @Nonnull @@ -44,16 +48,19 @@ ResponseEntity streamChatCompletionDeltas() { final 
var message = "Can you give me the first 100 numbers of the Fibonacci sequence?"; final var stream = service.streamChatCompletionDeltas(message); final var emitter = new ResponseBodyEmitter(); + final var totalUsage = new AtomicReference(); final Runnable consumeStream = () -> { - final var totalOutput = new OpenAiChatCompletionOutput(); - // try-with-resources ensures the stream is closed try (stream) { - stream - .peek(totalOutput::addDelta) - .forEach(delta -> send(emitter, delta.getDeltaContent())); + stream.forEach( + delta -> { + // Instead of getCompletionUsage(MAPPER), we now use getUsage() + final var usage = delta.getUsage(); + totalUsage.compareAndExchange(null, usage); + send(emitter, delta.getDeltaContent()); + }); } finally { - send(emitter, "\n\n-----Total Output-----\n\n" + objectToJson(totalOutput)); + send(emitter, "\n\n-----Total Usage-----\n\n" + totalUsage.get()); emitter.complete(); } }; @@ -100,20 +107,6 @@ public static void send(@Nonnull final ResponseBodyEmitter emitter, @Nonnull fin } } - /** - * Convert an object to JSON - * - * @param obj The object to convert - * @return The JSON representation of the object - */ - private static String objectToJson(@Nonnull final Object obj) { - try { - return new ObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(obj); - } catch (final JsonProcessingException ignored) { - return "Could not parse object to JSON"; - } - } - @GetMapping("/chatCompletionImage") @Nonnull Object chatCompletionImage( @@ -124,15 +117,14 @@ Object chatCompletionImage( if ("json".equals(format)) { return response; } - return response.getContent(); + return response.getChoices().get(0).getMessage(); } @GetMapping("/chatCompletionTool") @Nonnull Object chatCompletionTools( @Nullable @RequestParam(value = "format", required = false) final String format) { - final var response = - service.chatCompletionTools("Calculate the Fibonacci number for given sequence index."); + final var response = 
service.chatCompletionTools(12); if ("json".equals(format)) { return response; } diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OpenAiService.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OpenAiService.java index 5d5f26209..0032093d6 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OpenAiService.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OpenAiService.java @@ -88,18 +88,19 @@ public OpenAiChatCompletionOutput chatCompletionImage(@Nonnull final String link /** * Chat request to OpenAI with a tool. * - * @param prompt The prompt to send to the assistant + * @param months The number of months to be inferred in the tool * @return the assistant message response */ @Nonnull - public OpenAiChatCompletionOutput chatCompletionTools(@Nonnull final String prompt) { + public OpenAiChatCompletionOutput chatCompletionTools(final int months) { final var question = - "A pair of rabbits is placed in a field. Each month, every pair produces one new pair, starting from the second month. How many rabbits will there be after 12 months?"; + "A pair of rabbits is placed in a field. Each month, every pair produces one new pair, starting from the second month. How many rabbits will there be after %s months?" 
+ .formatted(months); final var par = Map.of("type", "object", "properties", Map.of("N", Map.of("type", "integer"))); final var function = new OpenAiChatCompletionFunction() .setName("fibonacci") - .setDescription(prompt) + .setDescription("Calculate the Fibonacci number for given sequence index.") .setParameters(par); final var tool = new OpenAiChatCompletionTool().setType(FUNCTION).setFunction(function); final var request = diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/NewOpenAiTest.java b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/NewOpenAiTest.java new file mode 100644 index 000000000..89c43ca0a --- /dev/null +++ b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/NewOpenAiTest.java @@ -0,0 +1,103 @@ +package com.sap.ai.sdk.app.controllers; + +import static com.sap.ai.sdk.foundationmodels.openai.OpenAiModel.GPT_35_TURBO; +import static com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionResponseMessageRole.ASSISTANT; +import static org.assertj.core.api.Assertions.assertThat; + +import com.sap.ai.sdk.app.services.NewOpenAiService; +import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionRequest; +import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient; +import com.sap.ai.sdk.foundationmodels.openai.OpenAiMessage; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.CompletionUsage; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import lombok.extern.slf4j.Slf4j; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +@Slf4j +class NewOpenAiTest { + NewOpenAiService service; + + @BeforeEach + void setUp() { + service = new NewOpenAiService(); + } + + @Test + void chatCompletion() { + final var completion = service.chatCompletion("Who is the prettiest"); + + assertThat(completion.getChoice().getMessage().getRole()).isEqualTo(ASSISTANT); + 
assertThat(completion.getContent()).isNotEmpty(); + } + + @Test + void chatCompletionImage() { + final var completion = + service.chatCompletionImage( + "https://upload.wikimedia.org/wikipedia/commons/thumb/5/59/SAP_2011_logo.svg/440px-SAP_2011_logo.svg.png"); + + final var message = completion.getChoices().get(0).getMessage(); + assertThat(message.getRole()).isEqualTo(ASSISTANT); + assertThat(message.getContent()).isNotEmpty(); + } + + @Test + void streamChatCompletion() { + final var userMessage = OpenAiMessage.user("Who is the prettiest?"); + final var prompt = new OpenAiChatCompletionRequest(userMessage); + + final var totalOutput = new AtomicReference(); + final var filledDeltaCount = new AtomicInteger(0); + OpenAiClient.forModel(GPT_35_TURBO) + .streamChatCompletionDeltas(prompt) + // foreach consumes all elements, closing the stream at the end + .forEach( + delta -> { + final var usage = delta.getCompletionUsage(); + totalOutput.compareAndExchange(null, usage); + final String deltaContent = delta.getDeltaContent(); + log.info("delta: {}", delta); + if (!deltaContent.isEmpty()) { + filledDeltaCount.incrementAndGet(); + } + }); + + // the first two and the last delta don't have any content + // see OpenAiChatCompletionDelta#getDeltaContent + assertThat(filledDeltaCount.get()).isGreaterThan(0); + + assertThat(totalOutput.get().getTotalTokens()).isGreaterThan(0); + assertThat(totalOutput.get().getPromptTokens()).isEqualTo(14); + assertThat(totalOutput.get().getCompletionTokens()).isGreaterThan(0); + } + + @Test + void chatCompletionTools() { + final var completion = service.chatCompletionTools(12); + + final var message = completion.getChoice().getMessage(); + assertThat(message.getRole()).isEqualTo(ASSISTANT); + assertThat(message.getToolCalls()).isNotNull(); + assertThat(message.getToolCalls().get(0).getFunction().getName()).isEqualTo("fibonacci"); + } + + @Test + void embedding() { + final var embedding = service.embedding("Hello world"); + + 
assertThat(embedding.getData().get(0).getEmbedding()).hasSizeGreaterThan(1); + assertThat(embedding.getModel()).isEqualTo("ada"); + assertThat(embedding.getObject()).isEqualTo("list"); + } + + @Test + void chatCompletionWithResource() { + final var completion = + service.chatCompletionWithResource("ai-sdk-java-e2e", "Where is the nearest coffee shop?"); + + assertThat(completion.getChoice().getMessage().getRole()).isEqualTo(ASSISTANT); + assertThat(completion.getContent()).isNotEmpty(); + } +} diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OpenAiTest.java b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OpenAiTest.java index ec07fb037..78ff6fccb 100644 --- a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OpenAiTest.java +++ b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OpenAiTest.java @@ -75,8 +75,7 @@ void streamChatCompletion() { @Test void chatCompletionTools() { - final var completion = - service.chatCompletionTools("Calculate the Fibonacci number for given sequence index."); + final var completion = service.chatCompletionTools(12); final var message = completion.getChoices().get(0).getMessage(); assertThat(message.getRole()).isEqualTo("assistant"); diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/services/NewOpenAiService.java b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/services/NewOpenAiService.java new file mode 100644 index 000000000..f81881278 --- /dev/null +++ b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/services/NewOpenAiService.java @@ -0,0 +1,179 @@ +package com.sap.ai.sdk.app.services; + +import static com.sap.ai.sdk.foundationmodels.openai.OpenAiModel.GPT_35_TURBO; +import static com.sap.ai.sdk.foundationmodels.openai.OpenAiModel.GPT_4O; +import static com.sap.ai.sdk.foundationmodels.openai.OpenAiModel.TEXT_EMBEDDING_ADA_002; +import static 
com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestMessageContentPartImage.TypeEnum.IMAGE_URL; +import static com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestMessageContentPartImageImageUrl.DetailEnum.HIGH; +import static com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestMessageContentPartText.TypeEnum.TEXT; +import static com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestUserMessage.RoleEnum.USER; +import static com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool.TypeEnum.FUNCTION; + +import com.sap.ai.sdk.core.AiCoreService; +import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionDelta; +import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionRequest; +import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionResponse; +import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient; +import com.sap.ai.sdk.foundationmodels.openai.OpenAiMessage; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionNamedToolChoice; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionNamedToolChoiceFunction; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestMessageContentPartImage; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestMessageContentPartImageImageUrl; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestMessageContentPartText; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestUserMessage; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestUserMessageContent; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionToolChoiceOption; +import 
com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequest; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionResponse; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreate200Response; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequest; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequestInput; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.FunctionObject; +import java.net.URI; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; +import javax.annotation.Nonnull; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Service; + +/** Service class for OpenAI service */ +@Service +@Slf4j +public class NewOpenAiService { + + /** + * Chat request to OpenAI + * + * @param prompt The prompt to send to the assistant + * @return the assistant message response + */ + @Nonnull + public OpenAiChatCompletionResponse chatCompletion(@Nonnull final String prompt) { + return OpenAiClient.forModel(GPT_35_TURBO) + .chatCompletion(new OpenAiChatCompletionRequest(prompt)); + } + + /** + * Asynchronous stream of an OpenAI chat request + * + * @return the emitter that streams the assistant message response + */ + @Nonnull + public Stream streamChatCompletionDeltas( + @Nonnull final String message) { + final var request = new OpenAiChatCompletionRequest(OpenAiMessage.user(message)); + + return OpenAiClient.forModel(GPT_35_TURBO).streamChatCompletionDeltas(request); + } + + /** + * Asynchronous stream of an OpenAI chat request + * + * @return the emitter that streams the assistant message response + */ + @Nonnull + public Stream streamChatCompletion(@Nonnull final String message) { + return OpenAiClient.forModel(GPT_35_TURBO) + .withSystemPrompt("Be a good, honest AI and answer the following question:") + .streamChatCompletion(message); + } + + /** + * Chat request to OpenAI 
with an image + * + * @param linkToImage The link to the image + * @return the assistant message response + */ + @Nonnull + public CreateChatCompletionResponse chatCompletionImage(@Nonnull final String linkToImage) { + final var partText = + new ChatCompletionRequestMessageContentPartText() + .type(TEXT) + .text("Describe the following image."); + final var partImageUrl = + new ChatCompletionRequestMessageContentPartImageImageUrl() + .url(URI.create(linkToImage)) + .detail(HIGH); + final var partImage = + new ChatCompletionRequestMessageContentPartImage().type(IMAGE_URL).imageUrl(partImageUrl); + final var userMessage = + new ChatCompletionRequestUserMessage() + .role(USER) + .content(ChatCompletionRequestUserMessageContent.create(List.of(partText, partImage))); + final var request = + new CreateChatCompletionRequest() + .addMessagesItem(userMessage) + .functions(null) + .tools(null) + .parallelToolCalls(null); + + return OpenAiClient.forModel(GPT_4O).chatCompletion(request); + } + + /** + * Chat request to OpenAI with a tool. + * + * @param months The number of months to be inferred in the tool + * @return the assistant message response + */ + @Nonnull + public OpenAiChatCompletionResponse chatCompletionTools(final int months) { + final var function = + new FunctionObject() + .name("fibonacci") + .description("Calculate the Fibonacci number for given sequence index.") + .parameters( + Map.of("type", "object", "properties", Map.of("N", Map.of("type", "integer")))); + + final var tool = new ChatCompletionTool().type(FUNCTION).function(function); + + final var toolChoice = + ChatCompletionToolChoiceOption.create( + new ChatCompletionNamedToolChoice() + .type(ChatCompletionNamedToolChoice.TypeEnum.FUNCTION) + .function(new ChatCompletionNamedToolChoiceFunction().name("fibonacci"))); + + final var request = + new OpenAiChatCompletionRequest( + "A pair of rabbits is placed in a field. Each month, every pair produces one new pair, starting from the second month. 
How many rabbits will there be after %s months?" + .formatted(months)) + .withTools(List.of(tool)) + .withToolChoice(toolChoice); + + return OpenAiClient.forModel(GPT_35_TURBO).chatCompletion(request); + } + + /** + * Get the embedding of a text + * + * @param input The text to embed + * @return the embedding response + */ + @Nonnull + public EmbeddingsCreate200Response embedding(@Nonnull final String input) { + final var request = + new EmbeddingsCreateRequest().input(EmbeddingsCreateRequestInput.create(input)); + + return OpenAiClient.forModel(TEXT_EMBEDDING_ADA_002).embedding(request); + } + + /** + * Chat request to OpenAI filtering by resource group + * + * @param resourceGroup The resource group to use + * @param prompt The prompt to send to the assistant + * @return the assistant message response + */ + @Nonnull + public OpenAiChatCompletionResponse chatCompletionWithResource( + @Nonnull final String resourceGroup, @Nonnull final String prompt) { + + final var destination = + new AiCoreService().getInferenceDestination(resourceGroup).forModel(GPT_4O); + + return OpenAiClient.withCustomDestination(destination) + .chatCompletion(new OpenAiChatCompletionRequest(prompt)); + } +}