diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/JacksonMixins.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/JacksonMixins.java
new file mode 100644
index 000000000..a6b06e751
--- /dev/null
+++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/JacksonMixins.java
@@ -0,0 +1,25 @@
+package com.sap.ai.sdk.foundationmodels.openai;
+
+import com.fasterxml.jackson.annotation.JsonTypeInfo;
+import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionStreamResponse;
+import lombok.AccessLevel;
+import lombok.NoArgsConstructor;
+
+/**
+ * This class contains Jackson Mixins for customizing the serialization and deserialization behavior
+ * of certain classes in the OpenAI SDK.
+ */
+@NoArgsConstructor(access = AccessLevel.PRIVATE)
+final class JacksonMixins {
+
+ /**
+ * Mixin interface to customize the deserialization of CreateChatCompletionStreamResponse.
+ *
+ * <p>Disables type information inclusion and specifies the concrete class to use for
+ * deserialization.
+ */
+ @JsonTypeInfo(use = JsonTypeInfo.Id.NONE)
+ @JsonDeserialize(as = CreateChatCompletionStreamResponse.class)
+ interface DefaultChatCompletionCreate200ResponseMixIn {}
+}
diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiAssistantMessage.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiAssistantMessage.java
new file mode 100644
index 000000000..15cb3ffe0
--- /dev/null
+++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiAssistantMessage.java
@@ -0,0 +1,37 @@
+package com.sap.ai.sdk.foundationmodels.openai;
+
+import com.google.common.annotations.Beta;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestAssistantMessage;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestAssistantMessageContent;
+import javax.annotation.Nonnull;
+import lombok.Value;
+import lombok.experimental.Accessors;
+
+/**
+ * Represents a chat message as 'assistant' to OpenAI service.
+ *
+ * @since 1.4.0
+ */
+@Beta
+@Value
+@Accessors(fluent = true)
+class OpenAiAssistantMessage implements OpenAiMessage {
+
+ /** The role of the message. */
+ @Nonnull String role = "assistant";
+
+ /** The content of the message. */
+ @Nonnull String content;
+
+ /**
+ * Converts the message to a serializable object.
+ *
+ * @return the corresponding {@code ChatCompletionRequestAssistantMessage} object.
+ */
+ @Nonnull
+ ChatCompletionRequestAssistantMessage createChatCompletionRequestMessage() {
+ return new ChatCompletionRequestAssistantMessage()
+ .role(ChatCompletionRequestAssistantMessage.RoleEnum.fromValue(role()))
+ .content(ChatCompletionRequestAssistantMessageContent.create(content));
+ }
+}
diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionDelta.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionDelta.java
new file mode 100644
index 000000000..46b5416e5
--- /dev/null
+++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionDelta.java
@@ -0,0 +1,69 @@
+package com.sap.ai.sdk.foundationmodels.openai;
+
+import static com.sap.ai.sdk.foundationmodels.openai.OpenAiUtils.getOpenAiObjectMapper;
+import static lombok.AccessLevel.PACKAGE;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.google.common.annotations.Beta;
+import com.sap.ai.sdk.core.common.StreamedDelta;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.CompletionUsage;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionStreamResponse;
+import java.util.Map;
+import java.util.Objects;
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+import lombok.EqualsAndHashCode;
+import lombok.Getter;
+import lombok.RequiredArgsConstructor;
+import lombok.ToString;
+
+/**
+ * Represents an OpenAI chat completion output delta for streaming.
+ *
+ * @since 1.4.0
+ */
+@Beta
+@RequiredArgsConstructor(onConstructor_ = @JsonCreator, access = PACKAGE)
+@Getter
+@ToString
+@EqualsAndHashCode
+public class OpenAiChatCompletionDelta implements StreamedDelta {
+ /** The original response from the chat completion stream. */
+ @Nonnull private final CreateChatCompletionStreamResponse originalResponse;
+
+ @Nonnull
+ @Override
+ public String getDeltaContent() {
+ final var choices = getOriginalResponse().getChoices();
+ if (!choices.isEmpty() && choices.get(0).getIndex() == 0) {
+ final var message = choices.get(0).getDelta();
+ return Objects.requireNonNullElse(message.getContent(), "");
+ }
+ return "";
+ }
+
+ @Nullable
+ @Override
+ public String getFinishReason() {
+ final var choices = getOriginalResponse().getChoices();
+ if (!choices.isEmpty()) {
+ final var finishReason = choices.get(0).getFinishReason();
+ return finishReason != null ? finishReason.getValue() : null;
+ }
+ return null;
+ }
+
+ /**
+ * Retrieves the completion usage from the response, or null if it is not available.
+ *
+ * @return The completion usage or null.
+ */
+ @Nullable
+ public CompletionUsage getCompletionUsage() {
+ if (getOriginalResponse().getCustomFieldNames().contains("usage")
+ && getOriginalResponse().getCustomField("usage") instanceof Map<?, ?> usage) {
+ return getOpenAiObjectMapper().convertValue(usage, CompletionUsage.class);
+ }
+ return null;
+ }
+}
diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequest.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequest.java
new file mode 100644
index 000000000..18d23f148
--- /dev/null
+++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequest.java
@@ -0,0 +1,287 @@
+package com.sap.ai.sdk.foundationmodels.openai;
+
+import com.google.common.annotations.Beta;
+import com.google.common.collect.Lists;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionStreamOptions;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionToolChoiceOption;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequest;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequestAllOfResponseFormat;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequestAllOfStop;
+import java.math.BigDecimal;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+import lombok.AccessLevel;
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.Value;
+import lombok.With;
+import lombok.experimental.Tolerate;
+
+/**
+ * Represents a request for OpenAI chat completion, including conversation messages and parameters.
+ *
+ * @see <a href="https://platform.openai.com/docs/api-reference/chat/create">OpenAI
+ *     API Reference</a>
+ * @since 1.4.0
+ */
+@Beta
+@Value
+@With
+@AllArgsConstructor(access = AccessLevel.PRIVATE)
+@Getter(value = AccessLevel.NONE)
+public class OpenAiChatCompletionRequest {
+ /** List of messages from the conversation. */
+ @Nonnull List<OpenAiMessage> messages;
+
+ /** Up to 4 stop sequences that interrupt token generation and return a response without them. */
+ @Nullable List<String> stop;
+
+ /**
+ * Controls the randomness of the completion.
+ *
+ * <p>Lower values (e.g. 0.0) make the model more deterministic and repetitive, while higher
+ * values (e.g. 1.0) make the model more random and creative.
+ */
+ @Nullable BigDecimal temperature;
+
+ /**
+ * Controls the cumulative probability threshold used for nucleus sampling. Alternative to {@link
+ * #temperature}.
+ *
+ * <p>Lower values (e.g. 0.1) limit the model to consider only the smallest set of tokens whose
+ * combined probabilities add up to at least 10% of the total.
+ */
+ @Nullable BigDecimal topP;
+
+ /** Maximum number of tokens that can be generated for the completion. */
+ @Nullable Integer maxTokens;
+
+ /**
+ * Maximum number of tokens that can be generated for the completion, including consumed reasoning
+ * tokens. This field supersedes {@link #maxTokens} and should be used with newer models.
+ */
+ @Nullable Integer maxCompletionTokens;
+
+ /**
+ * Encourage new topic by penalising token based on their presence in the completion.
+ *
+ * <p>Value should be in range [-2, 2].
+ */
+ @Nullable BigDecimal presencePenalty;
+
+ /**
+ * Encourage new topic by penalising tokens based on their frequency in the completion.
+ *
+ * <p>Value should be in range [-2, 2].
+ */
+ @Nullable BigDecimal frequencyPenalty;
+
+ /**
+ * A map that adjusts the likelihood of specified tokens by adding a bias value (between -100 and
+ * 100) to the logits before sampling. Extreme values can effectively ban or enforce the selection
+ * of tokens.
+ */
+ @Nullable Map<String, Integer> logitBias;
+
+ /**
+ * Unique identifier for the end-user making the request. This can help with monitoring and abuse
+ * detection.
+ */
+ @Nullable String user;
+
+ /** Whether to include log probabilities in the response. */
+ @With(AccessLevel.NONE)
+ @Nullable
+ Boolean logprobs;
+
+ /**
+ * Number of top log probabilities to return for each token. An integer between 0 and 20. This is
+ * only relevant if {@code logprobs} is enabled.
+ */
+ @Nullable Integer topLogprobs;
+
+ /** Number of completions to generate. */
+ @Nullable Integer n;
+
+ /** Whether to allow parallel tool calls. */
+ @With(AccessLevel.NONE)
+ @Nullable
+ Boolean parallelToolCalls;
+
+ /** Seed for random number generation. */
+ @Nullable Integer seed;
+
+ /** Options for streaming the completion response. */
+ @Nullable ChatCompletionStreamOptions streamOptions;
+
+ /** Response format for the completion. */
+ @Nullable CreateChatCompletionRequestAllOfResponseFormat responseFormat;
+
+ /** List of tools that the model may invoke during the completion. */
+ @Nullable List<ChatCompletionTool> tools;
+
+ /** Option to control which tool is invoked by the model. */
+ @Nullable ChatCompletionToolChoiceOption toolChoice;
+
+ /**
+ * Creates an OpenAiChatCompletionPrompt with a string as user message.
+ *
+ * @param message the message to be added to the prompt
+ */
+ @Tolerate
+ public OpenAiChatCompletionRequest(@Nonnull final String message) {
+ this(OpenAiMessage.user(message));
+ }
+
+ /**
+ * Creates an OpenAiChatCompletionPrompt with multiple unpacked messages.
+ *
+ * @param message the primary message to be added to the prompt
+ * @param messages additional messages to be added to the prompt
+ */
+ @Tolerate
+ public OpenAiChatCompletionRequest(
+ @Nonnull final OpenAiMessage message, @Nonnull final OpenAiMessage... messages) {
+ this(
+ Lists.asList(message, messages),
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null);
+ }
+
+ /**
+ * Adds stop sequences to the request.
+ *
+ * @param sequence the primary stop sequence
+ * @param sequences additional stop sequences
+ * @return a new OpenAiChatCompletionRequest instance with the specified stop sequences
+ */
+ @Tolerate
+ @Nonnull
+ public OpenAiChatCompletionRequest withStop(
+ @Nonnull final String sequence, @Nonnull final String... sequences) {
+ return this.withStop(Lists.asList(sequence, sequences));
+ }
+
+ /**
+ * Sets the parallel tool calls option.
+ *
+ * @param parallelToolCalls Whether to allow parallel tool calls.
+ * @return A new instance with the specified option.
+ */
+ @Nonnull
+ public OpenAiChatCompletionRequest withParallelToolCalls(
+ @Nonnull final Boolean parallelToolCalls) {
+ return Objects.equals(this.parallelToolCalls, parallelToolCalls)
+ ? this
+ : new OpenAiChatCompletionRequest(
+ this.messages,
+ this.stop,
+ this.temperature,
+ this.topP,
+ this.maxTokens,
+ this.maxCompletionTokens,
+ this.presencePenalty,
+ this.frequencyPenalty,
+ this.logitBias,
+ this.user,
+ this.logprobs,
+ this.topLogprobs,
+ this.n,
+ parallelToolCalls,
+ this.seed,
+ this.streamOptions,
+ this.responseFormat,
+ this.tools,
+ this.toolChoice);
+ }
+
+ /**
+ * Sets the log probabilities option.
+ *
+ * @param logprobs Whether to include log probabilities in the response.
+ * @return A new instance with the specified option.
+ */
+ @Nonnull
+ public OpenAiChatCompletionRequest withLogprobs(@Nonnull final Boolean logprobs) {
+ return Objects.equals(this.logprobs, logprobs)
+ ? this
+ : new OpenAiChatCompletionRequest(
+ this.messages,
+ this.stop,
+ this.temperature,
+ this.topP,
+ this.maxTokens,
+ this.maxCompletionTokens,
+ this.presencePenalty,
+ this.frequencyPenalty,
+ this.logitBias,
+ this.user,
+ logprobs,
+ this.topLogprobs,
+ this.n,
+ this.parallelToolCalls,
+ this.seed,
+ this.streamOptions,
+ this.responseFormat,
+ this.tools,
+ this.toolChoice);
+ }
+
+ /**
+ * Converts the request to a generated model class CreateChatCompletionRequest.
+ *
+ * @return the CreateChatCompletionRequest
+ */
+ CreateChatCompletionRequest createCreateChatCompletionRequest() {
+ final var request = new CreateChatCompletionRequest();
+ this.messages.forEach(
+ message ->
+ request.addMessagesItem(OpenAiUtils.createChatCompletionRequestMessage(message)));
+
+ request.stop(this.stop != null ? CreateChatCompletionRequestAllOfStop.create(this.stop) : null);
+
+ request.temperature(this.temperature);
+ request.topP(this.topP);
+
+ request.stream(null);
+ request.maxTokens(this.maxTokens);
+ request.maxCompletionTokens(this.maxCompletionTokens);
+ request.presencePenalty(this.presencePenalty);
+ request.frequencyPenalty(this.frequencyPenalty);
+ request.logitBias(this.logitBias);
+ request.user(this.user);
+ request.logprobs(this.logprobs);
+ request.topLogprobs(this.topLogprobs);
+ request.n(this.n);
+ request.parallelToolCalls(this.parallelToolCalls);
+ request.seed(this.seed);
+ request.streamOptions(this.streamOptions);
+ request.responseFormat(this.responseFormat);
+ request.tools(this.tools);
+ request.toolChoice(this.toolChoice);
+ request.functionCall(null);
+ request.functions(null);
+ return request;
+ }
+}
diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionResponse.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionResponse.java
new file mode 100644
index 000000000..4da73fe6f
--- /dev/null
+++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionResponse.java
@@ -0,0 +1,64 @@
+package com.sap.ai.sdk.foundationmodels.openai;
+
+import static com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionResponseChoicesInner.FinishReasonEnum.CONTENT_FILTER;
+import static lombok.AccessLevel.NONE;
+import static lombok.AccessLevel.PACKAGE;
+
+import com.google.common.annotations.Beta;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.CompletionUsage;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionResponse;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionResponseChoicesInner;
+import java.util.Objects;
+import javax.annotation.Nonnull;
+import lombok.RequiredArgsConstructor;
+import lombok.Setter;
+import lombok.Value;
+
+/**
+ * Represents the output of an OpenAI chat completion.
+ *
+ * @since 1.4.0
+ */
+@Beta
+@Value
+@RequiredArgsConstructor(access = PACKAGE)
+@Setter(value = NONE)
+public class OpenAiChatCompletionResponse {
+ /** The original response from the OpenAI API. */
+ @Nonnull final CreateChatCompletionResponse originalResponse;
+
+ /**
+ * Gets the token usage from the original response.
+ *
+ * @return the token usage
+ */
+ @Nonnull
+ public CompletionUsage getTokenUsage() {
+ return getOriginalResponse().getUsage();
+ }
+
+ /**
+ * Gets the first choice from the original response.
+ *
+ * @return the first choice
+ */
+ @Nonnull
+ public CreateChatCompletionResponseChoicesInner getChoice() {
+ return getOriginalResponse().getChoices().get(0);
+ }
+
+ /**
+ * Gets the content of the first choice.
+ *
+ * @return the content of the first choice
+ * @throws OpenAiClientException if the content is filtered by the content filter
+ */
+ @Nonnull
+ public String getContent() {
+ if (CONTENT_FILTER.equals(getOriginalResponse().getChoices().get(0).getFinishReason())) {
+ throw new OpenAiClientException("Content filter filtered the output.");
+ }
+
+ return Objects.requireNonNullElse(getChoice().getMessage().getContent(), "");
+ }
+}
diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java
index 8699201a7..fe5415dfb 100644
--- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java
+++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java
@@ -1,6 +1,6 @@
package com.sap.ai.sdk.foundationmodels.openai;
-import static com.sap.ai.sdk.core.JacksonConfiguration.getDefaultObjectMapper;
+import static com.sap.ai.sdk.foundationmodels.openai.OpenAiUtils.getOpenAiObjectMapper;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
@@ -10,14 +10,17 @@
import com.sap.ai.sdk.core.common.ClientResponseHandler;
import com.sap.ai.sdk.core.common.ClientStreamingHandler;
import com.sap.ai.sdk.core.common.StreamedDelta;
-import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionDelta;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionStreamOptions;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequest;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionResponse;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreate200Response;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequest;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionOutput;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionParameters;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.OpenAiChatSystemMessage;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.OpenAiChatUserMessage;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingOutput;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingParameters;
-import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiError;
import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor;
import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination;
import com.sap.cloud.sdk.cloudplatform.connectivity.Destination;
@@ -39,7 +42,8 @@
@RequiredArgsConstructor(access = AccessLevel.PRIVATE)
public final class OpenAiClient {
private static final String DEFAULT_API_VERSION = "2024-02-01";
- static final ObjectMapper JACKSON = getDefaultObjectMapper();
+ static final ObjectMapper JACKSON = getOpenAiObjectMapper();
+
@Nullable private String systemPrompt = null;
@Nonnull private final Destination destination;
@@ -107,7 +111,11 @@ public static OpenAiClient withCustomDestination(@Nonnull final Destination dest
}
/**
- * Add a system prompt before user prompts.
+ * Use this method to set a system prompt that should be used across multiple chat completions
+ * with basic string prompts {@link #streamChatCompletionDeltas(OpenAiChatCompletionParameters)}.
+ *
+ * <p>Note: The system prompt is ignored on chat completions invoked with
+ * OpenAiChatCompletionPrompt.
*
* @param systemPrompt the system prompt
* @return the client
@@ -119,7 +127,7 @@ public OpenAiClient withSystemPrompt(@Nonnull final String systemPrompt) {
}
/**
- * Generate a completion for the given user prompt.
+ * Generate a completion for the given string prompt as user.
*
* @param prompt a text message.
* @return the completion output
@@ -137,9 +145,41 @@ public OpenAiChatCompletionOutput chatCompletion(@Nonnull final String prompt)
}
/**
- * Generate a completion for the given prompt.
+ * Generate a completion for the given conversation and request parameters.
+ *
+ * @param request the completion request.
+ * @return the completion output
+ * @throws OpenAiClientException if the request fails
+ * @since 1.4.0
+ */
+ @Beta
+ @Nonnull
+ public OpenAiChatCompletionResponse chatCompletion(
+ @Nonnull final OpenAiChatCompletionRequest request) throws OpenAiClientException {
+ warnIfUnsupportedUsage();
+ return new OpenAiChatCompletionResponse(
+ chatCompletion(request.createCreateChatCompletionRequest()));
+ }
+
+ /**
+ * Generate a completion for the given low-level request object.
*
- * @param parameters the prompt, including messages and other parameters.
+ * @param request the completion request.
+ * @return the completion output
+ * @throws OpenAiClientException if the request fails
+ * @since 1.4.0
+ */
+ @Beta
+ @Nonnull
+ public CreateChatCompletionResponse chatCompletion(
+ @Nonnull final CreateChatCompletionRequest request) throws OpenAiClientException {
+ return execute("/chat/completions", request, CreateChatCompletionResponse.class);
+ }
+
+ /**
+ * Generate a completion for the given conversation and request parameters.
+ *
+ * @param parameters the completion request.
* @return the completion output
* @throws OpenAiClientException if the request fails
*/
@@ -151,9 +191,10 @@ public OpenAiChatCompletionOutput chatCompletion(
}
/**
- * Stream a completion for the given prompt. Returns a lazily populated stream of text
- * chunks. To access more details about the individual chunks, use {@link
- * #streamChatCompletionDeltas(OpenAiChatCompletionParameters)}.
+ * Stream a completion for the given string prompt as user.
+ *
+ * <p>Returns a lazily populated stream of text chunks. To access more details about the
+ * individual chunks, use {@link #streamChatCompletionDeltas(OpenAiChatCompletionRequest)}.
*
* <p>The stream should be consumed using a try-with-resources block to ensure that the underlying
* HTTP connection is closed.
@@ -171,19 +212,21 @@ public OpenAiChatCompletionOutput chatCompletion(
* Stream#parallel()} on this stream is not supported.
*
* @param prompt a text message.
- * @return A stream of message deltas
+ * @return A stream of text chunks
* @throws OpenAiClientException if the request fails or if the finish reason is content_filter
- * @see #streamChatCompletionDeltas(OpenAiChatCompletionParameters)
+ * @see #streamChatCompletionDeltas(OpenAiChatCompletionRequest)
*/
@Nonnull
public Stream<String> streamChatCompletion(@Nonnull final String prompt)
throws OpenAiClientException {
- final OpenAiChatCompletionParameters parameters = new OpenAiChatCompletionParameters();
- if (systemPrompt != null) {
- parameters.addMessages(new OpenAiChatSystemMessage().setContent(systemPrompt));
- }
- parameters.addMessages(new OpenAiChatUserMessage().addText(prompt));
- return streamChatCompletionDeltas(parameters)
+ final var userPrompt = OpenAiMessage.user(prompt);
+
+ final var request =
+ systemPrompt != null
+ ? new OpenAiChatCompletionRequest(OpenAiMessage.system(systemPrompt), userPrompt)
+ : new OpenAiChatCompletionRequest(userPrompt);
+
+ return streamChatCompletionDeltas(request.createCreateChatCompletionRequest())
.peek(OpenAiClient::throwOnContentFilter)
.map(OpenAiChatCompletionDelta::getDeltaContent);
}
@@ -196,8 +239,10 @@ private static void throwOnContentFilter(@Nonnull final OpenAiChatCompletionDelt
}
/**
- * Stream a completion for the given prompt. Returns a lazily populated stream of delta
- * objects. To simply stream the text chunks use {@link #streamChatCompletion(String)}
+ * Stream a completion for the given conversation and request parameters.
+ *
+ * <p>Returns a lazily populated stream of delta objects. To simply stream the text chunks
+ * use {@link #streamChatCompletion(String)}
*
* <p>The stream should be consumed using a try-with-resources block to ensure that the underlying
* HTTP connection is closed.
@@ -205,7 +250,7 @@ private static void throwOnContentFilter(@Nonnull final OpenAiChatCompletionDelt
* <p>Example:
*
* <pre>{@code
- * try (var stream = client.streamChatCompletionDeltas(params)) {
+ * try (var stream = client.streamChatCompletionDeltas(prompt)) {
* stream
* .peek(delta -> System.out.println(delta.getUsage()))
* .map(OpenAiChatCompletionDelta::getDeltaContent)
@@ -217,17 +262,75 @@ private static void throwOnContentFilter(@Nonnull final OpenAiChatCompletionDelt
* block until all chunks are consumed. Also, for obvious reasons, invoking {@link
* Stream#parallel()} on this stream is not supported.
*
- * @param parameters The prompt, including messages and other parameters.
+ * @param request The prompt, including a list of messages.
* @return A stream of message deltas
* @throws OpenAiClientException if the request fails or if the finish reason is content_filter
* @see #streamChatCompletion(String)
+ * @since 1.4.0
*/
+ @Beta
@Nonnull
public Stream<OpenAiChatCompletionDelta> streamChatCompletionDeltas(
- @Nonnull final OpenAiChatCompletionParameters parameters) throws OpenAiClientException {
+ @Nonnull final OpenAiChatCompletionRequest request) throws OpenAiClientException {
+ return streamChatCompletionDeltas(request.createCreateChatCompletionRequest());
+ }
+
+ /**
+ * Stream a completion for the given low-level request object. Returns a lazily populated
+ * stream of delta objects.
+ *
+ * @param request The completion request.
+ * @return A stream of message deltas
+ * @throws OpenAiClientException if the request fails or if the finish reason is content_filter
+ * @see #streamChatCompletionDeltas(OpenAiChatCompletionRequest) for a higher-level API
+ * @since 1.4.0
+ */
+ @Beta
+ @Nonnull
+ public Stream<OpenAiChatCompletionDelta> streamChatCompletionDeltas(
+ @Nonnull final CreateChatCompletionRequest request) throws OpenAiClientException {
+ request.stream(true).streamOptions(new ChatCompletionStreamOptions().includeUsage(true));
+ return executeStream("/chat/completions", request, OpenAiChatCompletionDelta.class);
+ }
+
+ /**
+ * Stream a completion for the given conversation and request parameters.
+ *
+ * <p>Returns a lazily populated stream of delta objects. To simply stream the text chunks
+ * use {@link #streamChatCompletion(String)}
+ *
+ * <p>The stream should be consumed using a try-with-resources block to ensure that the underlying
+ * HTTP connection is closed.
+ *
+ * <p>Example:
+ *
+ * <pre>{@code
+ * try (var stream = client.streamChatCompletionDeltas(request)) {
+ * stream
+ * .peek(delta -> System.out.println(delta.getUsage()))
+ * .map(com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionDelta::getDeltaContent)
+ * .forEach(System.out::println);
+ * }
+ * }</pre>
+ *
+ * <p>Please keep in mind that using a terminal stream operation like {@link Stream#forEach} will
+ * block until all chunks are consumed. Also, for obvious reasons, invoking {@link
+ * Stream#parallel()} on this stream is not supported.
+ *
+ * @param parameters The prompt, including a list of messages.
+ * @return A stream of message deltas
+ * @throws OpenAiClientException if the request fails or if the finish reason is content_filter
+ */
+ @Nonnull
+ public Stream<com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionDelta>
+ streamChatCompletionDeltas(@Nonnull final OpenAiChatCompletionParameters parameters)
+ throws OpenAiClientException {
warnIfUnsupportedUsage();
parameters.enableStreaming();
- return executeStream("/chat/completions", parameters, OpenAiChatCompletionDelta.class);
+ return executeStream(
+ "/chat/completions",
+ parameters,
+ com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionDelta.class);
}
private void warnIfUnsupportedUsage() {
@@ -237,6 +340,22 @@ private void warnIfUnsupportedUsage() {
}
}
+ /**
+ * Get a vector representation of a given request with input that can be easily consumed by
+ * machine learning models and algorithms.
+ *
+ * @param request the request with input text.
+ * @return the embedding output
+ * @throws OpenAiClientException if the request fails
+ * @since 1.4.0
+ */
+ @Beta
+ @Nonnull
+ public EmbeddingsCreate200Response embedding(@Nonnull final EmbeddingsCreateRequest request)
+ throws OpenAiClientException {
+ return execute("/embeddings", request, EmbeddingsCreate200Response.class);
+ }
+
/**
* Get a vector representation of a given input that can be easily consumed by machine learning
* models and algorithms.
@@ -300,6 +419,7 @@ private Stream streamRequest(
try {
final var client = ApacheHttpClient5Accessor.getHttpClient(destination);
return new ClientStreamingHandler<>(deltaType, OpenAiError.class, OpenAiClientException::new)
+ .objectMapper(JACKSON)
.handleStreamingResponse(client.executeOpen(null, request, null));
} catch (final IOException e) {
throw new OpenAiClientException("Request to OpenAI model failed", e);
diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiError.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiError.java
new file mode 100644
index 000000000..4f1011ff0
--- /dev/null
+++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiError.java
@@ -0,0 +1,35 @@
+package com.sap.ai.sdk.foundationmodels.openai;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.google.common.annotations.Beta;
+import com.sap.ai.sdk.core.common.ClientError;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ErrorResponse;
+import javax.annotation.Nonnull;
+import lombok.AccessLevel;
+import lombok.AllArgsConstructor;
+import lombok.Value;
+import lombok.experimental.Delegate;
+
+/**
+ * Represents an error response from the OpenAI API.
+ *
+ * @since 1.4.0
+ */
+@Beta
+@Value
+@AllArgsConstructor(onConstructor = @__({@JsonCreator}), access = AccessLevel.PROTECTED)
+public class OpenAiError implements ClientError {
+ /** The original error response from the OpenAI API. */
+ @Delegate(types = {ClientError.class})
+ ErrorResponse originalResponse;
+
+ /**
+ * Gets the error message from the contained original response.
+ *
+ * @return the error message
+ */
+ @Nonnull
+ public String getMessage() {
+ return originalResponse.getError().getMessage();
+ }
+}
diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiMessage.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiMessage.java
new file mode 100644
index 000000000..a8a25855d
--- /dev/null
+++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiMessage.java
@@ -0,0 +1,46 @@
+package com.sap.ai.sdk.foundationmodels.openai;
+
+import com.google.common.annotations.Beta;
+import javax.annotation.Nonnull;
+
+/**
+ * Interface representing convenience wrappers of chat message to the openai service.
+ *
+ * @since 1.4.0
+ */
+@Beta
+public interface OpenAiMessage {
+
+ /**
+ * A convenience method to create a user message.
+ *
+ * @param msg the message content.
+ * @return the user message.
+ */
+ @Nonnull
+ static OpenAiMessage user(@Nonnull final String msg) {
+ return new OpenAiUserMessage(msg);
+ }
+
+ /**
+ * A convenience method to create an assistant message.
+ *
+ * @param msg the message content.
+ * @return the assistant message.
+ */
+ @Nonnull
+ static OpenAiMessage assistant(@Nonnull final String msg) {
+ return new OpenAiAssistantMessage(msg);
+ }
+
+ /**
+ * A convenience method to create a system message.
+ *
+ * @param msg the message content.
+ * @return the system message.
+ */
+ @Nonnull
+ static OpenAiMessage system(@Nonnull final String msg) {
+ return new OpenAiSystemMessage(msg);
+ }
+}
diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiSystemMessage.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiSystemMessage.java
new file mode 100644
index 000000000..c24f6c277
--- /dev/null
+++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiSystemMessage.java
@@ -0,0 +1,37 @@
+package com.sap.ai.sdk.foundationmodels.openai;
+
+import com.google.common.annotations.Beta;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestSystemMessage;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestSystemMessageContent;
+import javax.annotation.Nonnull;
+import lombok.Value;
+import lombok.experimental.Accessors;
+
+/**
+ * Represents a chat message with the 'system' role for the OpenAI service.
+ *
+ * @since 1.4.0
+ */
+@Beta
+@Value
+@Accessors(fluent = true)
+class OpenAiSystemMessage implements OpenAiMessage {
+
+  /** The role of the message. */
+  @Nonnull String role = "system";
+
+  /** The content of the message. */
+  @Nonnull String content;
+
+  /**
+   * Converts the message to a serializable object.
+   *
+   * @return the corresponding {@code ChatCompletionRequestSystemMessage} object.
+   */
+  @Nonnull
+  ChatCompletionRequestSystemMessage createChatCompletionRequestMessage() {
+    return new ChatCompletionRequestSystemMessage()
+        .role(ChatCompletionRequestSystemMessage.RoleEnum.fromValue(role()))
+        .content(ChatCompletionRequestSystemMessageContent.create(content()));
+  }
+}
diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiUserMessage.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiUserMessage.java
new file mode 100644
index 000000000..597340a3e
--- /dev/null
+++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiUserMessage.java
@@ -0,0 +1,37 @@
+package com.sap.ai.sdk.foundationmodels.openai;
+
+import com.google.common.annotations.Beta;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestUserMessage;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestUserMessageContent;
+import javax.annotation.Nonnull;
+import lombok.Value;
+import lombok.experimental.Accessors;
+
+/**
+ * Represents a chat message with the 'user' role for the OpenAI service.
+ *
+ * @since 1.4.0
+ */
+@Beta
+@Value
+@Accessors(fluent = true)
+class OpenAiUserMessage implements OpenAiMessage {
+
+  /** The role of the message. */
+  @Nonnull String role = "user";
+
+  /** The content of the message. */
+  @Nonnull String content;
+
+  /**
+   * Converts the message to a serializable object.
+   *
+   * @return the corresponding {@code ChatCompletionRequestUserMessage} object.
+   */
+  @Nonnull
+  ChatCompletionRequestUserMessage createChatCompletionRequestMessage() {
+    return new ChatCompletionRequestUserMessage()
+        .role(ChatCompletionRequestUserMessage.RoleEnum.fromValue(role()))
+        .content(ChatCompletionRequestUserMessageContent.create(content()));
+  }
+}
diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiUtils.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiUtils.java
new file mode 100644
index 000000000..2abe27dd8
--- /dev/null
+++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiUtils.java
@@ -0,0 +1,57 @@
+package com.sap.ai.sdk.foundationmodels.openai;
+
+import static com.sap.ai.sdk.core.JacksonConfiguration.getDefaultObjectMapper;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.annotations.Beta;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestMessage;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionsCreate200Response;
+import javax.annotation.Nonnull;
+import lombok.AccessLevel;
+import lombok.NoArgsConstructor;
+
+/**
+ * Utility class for handling OpenAI module.
+ *
+ * <p>Only intended for internal usage within this SDK.
+ *
+ * @since 1.4.0
+ */
+@Beta
+// utility class: suppress instantiation, consistent with JacksonMixins
+@NoArgsConstructor(access = AccessLevel.PRIVATE)
+class OpenAiUtils {
+
+  /**
+   * Converts an OpenAiMessage to a ChatCompletionRequestMessage.
+   *
+   * @param message the OpenAiMessage to convert
+   * @return the corresponding ChatCompletionRequestMessage
+   * @throws IllegalArgumentException if the message type is unknown
+   */
+  @Nonnull
+  static ChatCompletionRequestMessage createChatCompletionRequestMessage(
+      @Nonnull final OpenAiMessage message) throws IllegalArgumentException {
+    if (message instanceof OpenAiUserMessage userMessage) {
+      return userMessage.createChatCompletionRequestMessage();
+    } else if (message instanceof OpenAiAssistantMessage assistantMessage) {
+      return assistantMessage.createChatCompletionRequestMessage();
+    } else if (message instanceof OpenAiSystemMessage systemMessage) {
+      return systemMessage.createChatCompletionRequestMessage();
+    } else {
+      throw new IllegalArgumentException("Unknown message type: " + message.getClass());
+    }
+  }
+
+  /**
+   * Default object mapper used for JSON de-/serialization.
+   *
+   * @return A new object mapper with the default configuration.
+   */
+  @Nonnull
+  static ObjectMapper getOpenAiObjectMapper() {
+    return getDefaultObjectMapper()
+        .addMixIn(
+            ChatCompletionsCreate200Response.class,
+            JacksonMixins.DefaultChatCompletionCreate200ResponseMixIn.class);
+  }
+}
diff --git a/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/BaseOpenAiClientTest.java b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/BaseOpenAiClientTest.java
new file mode 100644
index 000000000..122b074d7
--- /dev/null
+++ b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/BaseOpenAiClientTest.java
@@ -0,0 +1,184 @@
+package com.sap.ai.sdk.foundationmodels.openai;
+
+import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
+import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl;
+import static com.github.tomakehurst.wiremock.client.WireMock.badRequest;
+import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
+import static com.github.tomakehurst.wiremock.client.WireMock.noContent;
+import static com.github.tomakehurst.wiremock.client.WireMock.okXml;
+import static com.github.tomakehurst.wiremock.client.WireMock.post;
+import static com.github.tomakehurst.wiremock.client.WireMock.serverError;
+import static com.github.tomakehurst.wiremock.client.WireMock.stubFor;
+import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+
+import com.fasterxml.jackson.core.JsonParseException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo;
+import com.github.tomakehurst.wiremock.junit5.WireMockTest;
+import com.github.tomakehurst.wiremock.stubbing.Scenario;
+import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor;
+import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Cache;
+import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination;
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.Objects;
+import java.util.function.Function;
+import javax.annotation.Nonnull;
+import org.apache.hc.client5.http.classic.HttpClient;
+import org.apache.hc.core5.http.ContentType;
+import org.apache.hc.core5.http.io.entity.InputStreamEntity;
+import org.apache.hc.core5.http.message.BasicClassicHttpResponse;
+import org.assertj.core.api.SoftAssertions;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+
+/** Common WireMock stubs, error scenarios, and helpers shared by OpenAI client tests. */
+@WireMockTest
+abstract class BaseOpenAiClientTest {
+
+  protected static final ObjectMapper MAPPER = new ObjectMapper();
+  protected static OpenAiClient client;
+  protected final Function<String, InputStream> fileLoader =
+      filename -> Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream(filename));
+
+  static void stubForChatCompletion() {
+
+    stubFor(
+        post(urlPathEqualTo("/chat/completions"))
+            .withQueryParam("api-version", equalTo("2024-02-01"))
+            .willReturn(
+                aResponse()
+                    .withBodyFile("chatCompletionResponse.json")
+                    .withHeader("Content-Type", "application/json")));
+  }
+
+  static void stubForEmbedding() {
+    stubFor(
+        post(urlPathEqualTo("/embeddings"))
+            .willReturn(
+                aResponse()
+                    .withBodyFile("embeddingResponse.json")
+                    .withHeader("Content-Type", "application/json")));
+  }
+
+  static void stubForChatCompletionTool() {
+    stubFor(
+        post(urlPathEqualTo("/chat/completions"))
+            .willReturn(
+                aResponse()
+                    .withHeader("Content-Type", "application/json")
+                    .withBodyFile("chatCompletionToolResponse.json")));
+  }
+
+  static void stubForErrorHandling() {
+    final var errorJson =
+        """
+        { "error": { "code": null, "message": "foo", "type": "invalid stuff" } }
+        """;
+    stubFor(
+        post(anyUrl())
+            .inScenario("Errors")
+            .whenScenarioStateIs(Scenario.STARTED)
+            .willReturn(serverError())
+            .willSetStateTo("1"));
+    stubFor(
+        post(anyUrl())
+            .inScenario("Errors")
+            .whenScenarioStateIs("1")
+            .willReturn(
+                badRequest().withBody(errorJson).withHeader("Content-type", "application/json"))
+            .willSetStateTo("2"));
+    stubFor(
+        post(anyUrl())
+            .inScenario("Errors")
+            .whenScenarioStateIs("2")
+            .willReturn(
+                badRequest()
+                    .withBody("{ broken json")
+                    .withHeader("Content-type", "application/json"))
+            .willSetStateTo("3"));
+    stubFor(
+        post(anyUrl())
+            .inScenario("Errors")
+            .whenScenarioStateIs("3")
+            // non-empty XML body: must trigger the JSON parse failure, not the "empty response" branch
+            .willReturn(okXml("<xml></xml>"))
+            .willSetStateTo("4"));
+    stubFor(post(anyUrl()).inScenario("Errors").whenScenarioStateIs("4").willReturn(noContent()));
+  }
+
+  static void assertForErrorHandling(@Nonnull final Runnable request) {
+
+    final var softly = new SoftAssertions();
+
+    softly
+        .assertThatThrownBy(request::run)
+        .describedAs("Server errors should be handled")
+        .isInstanceOf(OpenAiClientException.class)
+        .hasMessageContaining("500");
+
+    softly
+        .assertThatThrownBy(request::run)
+        .describedAs("Error objects from OpenAI should be interpreted")
+        .isInstanceOf(OpenAiClientException.class)
+        .hasMessageContaining("error message: 'foo'");
+
+    softly
+        .assertThatThrownBy(request::run)
+        .describedAs("Failures while parsing error message should be handled")
+        .isInstanceOf(OpenAiClientException.class)
+        .hasMessageContaining("400")
+        .extracting(e -> e.getSuppressed()[0])
+        .isInstanceOf(JsonParseException.class);
+
+    softly
+        .assertThatThrownBy(request::run)
+        .describedAs("Non-JSON responses should be handled")
+        .isInstanceOf(OpenAiClientException.class)
+        .hasMessageContaining("Failed to parse");
+
+    softly
+        .assertThatThrownBy(request::run)
+        .describedAs("Empty responses should be handled")
+        .isInstanceOf(OpenAiClientException.class)
+        .hasMessageContaining("was empty");
+
+    softly.assertAll();
+  }
+
+  @BeforeEach
+  void setup(WireMockRuntimeInfo server) {
+    final DefaultHttpDestination destination =
+        DefaultHttpDestination.builder(server.getHttpBaseUrl()).build();
+    client = OpenAiClient.withCustomDestination(destination);
+    ApacheHttpClient5Accessor.setHttpClientCache(ApacheHttpClient5Cache.DISABLED);
+  }
+
+  @AfterEach
+  void reset() {
+    ApacheHttpClient5Accessor.setHttpClientCache(null);
+    ApacheHttpClient5Accessor.setHttpClientFactory(null);
+  }
+
+  InputStream stubStreamChatCompletion(String responseFile) throws IOException {
+    var inputStream = spy(fileLoader.apply(responseFile));
+
+    final var httpClient = mock(HttpClient.class);
+    ApacheHttpClient5Accessor.setHttpClientFactory(destination -> httpClient);
+
+    // Create a mock response
+    final var mockResponse = new BasicClassicHttpResponse(200, "OK");
+    final var inputStreamEntity = new InputStreamEntity(inputStream, ContentType.TEXT_PLAIN);
+    mockResponse.setEntity(inputStreamEntity);
+    mockResponse.setHeader("Content-Type", "text/event-stream");
+
+    // Configure the HttpClient mock to return the mock response
+    doReturn(mockResponse).when(httpClient).executeOpen(any(), any(), any());
+
+    return inputStream;
+  }
+}
diff --git a/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequestTest.java b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequestTest.java
new file mode 100644
index 000000000..0ede0f272
--- /dev/null
+++ b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequestTest.java
@@ -0,0 +1,48 @@
+package com.sap.ai.sdk.foundationmodels.openai;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestUserMessage;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestUserMessageContent;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequestAllOfStop;
+import java.math.BigDecimal;
+import java.util.List;
+import org.junit.jupiter.api.Test;
+
+class OpenAiChatCompletionRequestTest {
+
+  @Test
+  void stopSequence() {
+    var request =
+        new OpenAiChatCompletionRequest(OpenAiMessage.user("Hello, world"))
+            .withStop("stop1", "stop2 stopNot3", "stop3");
+
+    var lowLevelRequest = request.createCreateChatCompletionRequest();
+    assertThat(
+            ((CreateChatCompletionRequestAllOfStop.InnerStrings) lowLevelRequest.getStop())
+                .values())
+        .containsExactly("stop1", "stop2 stopNot3", "stop3");
+  }
+
+  @Test
+  void createWithExistingRequest() {
+    var originalRequest =
+        new OpenAiChatCompletionRequest(OpenAiMessage.user("First message"))
+            .withSeed(123)
+            .withTemperature(BigDecimal.valueOf(0.5));
+
+    var newRequest = originalRequest.withMessages(List.of(OpenAiMessage.user("Another message")));
+
+    var lowLevelRequest = newRequest.createCreateChatCompletionRequest();
+
+    assertThat(newRequest).isNotEqualTo(originalRequest);
+    assertThat(lowLevelRequest.getMessages()).hasSize(1);
+    assertThat(lowLevelRequest.getMessages().get(0))
+        .isInstanceOf(ChatCompletionRequestUserMessage.class);
+    assertThat(
+            ((ChatCompletionRequestUserMessage) lowLevelRequest.getMessages().get(0)).getContent())
+        .isEqualTo(ChatCompletionRequestUserMessageContent.create("Another message"));
+    assertThat(lowLevelRequest.getSeed()).isEqualTo(123);
+    assertThat(lowLevelRequest.getTemperature()).isEqualTo(BigDecimal.valueOf(0.5));
+  }
+}
diff --git a/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientGeneratedTest.java b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientGeneratedTest.java
new file mode 100644
index 000000000..fbf01573e
--- /dev/null
+++ b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientGeneratedTest.java
@@ -0,0 +1,605 @@
+package com.sap.ai.sdk.foundationmodels.openai;
+
+import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
+import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl;
+import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
+import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson;
+import static com.github.tomakehurst.wiremock.client.WireMock.exactly;
+import static com.github.tomakehurst.wiremock.client.WireMock.okJson;
+import static com.github.tomakehurst.wiremock.client.WireMock.post;
+import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor;
+import static com.github.tomakehurst.wiremock.client.WireMock.stubFor;
+import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo;
+import static com.github.tomakehurst.wiremock.client.WireMock.verify;
+import static com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionStreamResponseDelta.RoleEnum.ASSISTANT;
+import static com.sap.ai.sdk.foundationmodels.openai.generated.model.ContentFilterSeverityResult.SeverityEnum.SAFE;
+import static com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionResponse.ObjectEnum.CHAT_COMPLETION;
+import static com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionResponseChoicesInner.FinishReasonEnum.STOP;
+import static com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionStreamResponse.ObjectEnum.CHAT_COMPLETION_CHUNK;
+import static com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionStreamResponse.ObjectEnum.UNKNOWN_DEFAULT_OPEN_API;
+import static com.sap.ai.sdk.foundationmodels.openai.generated.model.ToolCallType.FUNCTION;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.when;
+
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionNamedToolChoice;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionNamedToolChoiceFunction;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionResponseMessageRole;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionToolChoiceOption;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ContentFilterPromptResults;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequest;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionStreamResponseChoicesInner;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequest;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequestInput;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.FunctionObject;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.PromptFilterResult;
+import io.vavr.control.Try;
+import java.io.IOException;
+import java.math.BigDecimal;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Stream;
+import javax.annotation.Nonnull;
+import lombok.SneakyThrows;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.MethodSource;
+import org.mockito.Mockito;
+
+class OpenAiClientGeneratedTest extends BaseOpenAiClientTest {
+
+ @Test
+ void openAiModels() {
+ var model = OpenAiModel.GPT_4;
+ var newModel = model.withVersion("v1");
+
+ assertThat(model.name()).isEqualTo("gpt-4");
+ assertThat(model.version()).isNull();
+
+ assertThat(newModel.name()).isEqualTo("gpt-4");
+ assertThat(newModel.version()).isEqualTo("v1");
+
+ assertThat(model).isNotSameAs(newModel);
+ }
+
+ private static Runnable[] errorHandlingCalls() {
+ return new Runnable[] {
+ () -> client.chatCompletion(new OpenAiChatCompletionRequest("")),
+ () ->
+ client
+ .streamChatCompletionDeltas(new OpenAiChatCompletionRequest(""))
+ // the stream needs to be consumed to parse the response
+ .forEach(System.out::println)
+ };
+ }
+
+ @ParameterizedTest
+ @MethodSource("errorHandlingCalls")
+ void chatCompletionErrorHandling(@Nonnull final Runnable request) {
+ stubForErrorHandling();
+ assertForErrorHandling(request);
+ }
+
+ @Test
+ void apiVersion() {
+ stubFor(post(anyUrl()).willReturn(okJson("{}")));
+ Try.of(() -> client.chatCompletion(new OpenAiChatCompletionRequest("")));
+
+ verify(
+ exactly(1),
+ postRequestedFor(anyUrl()).withQueryParam("api-version", equalTo("2024-02-01")));
+
+ Try.of(
+ () -> client.withApiVersion("fooBar").chatCompletion(new OpenAiChatCompletionRequest("")));
+ verify(exactly(1), postRequestedFor(anyUrl()).withQueryParam("api-version", equalTo("fooBar")));
+
+ assertThat(client)
+ .describedAs(
+ "withApiVersion should return a new object, the sut object should remain unchanged")
+ .isNotSameAs(client.withApiVersion("fooBar"));
+ Try.of(() -> client.chatCompletion(new OpenAiChatCompletionRequest("")));
+ verify(
+ exactly(2),
+ postRequestedFor(anyUrl()).withQueryParam("api-version", equalTo("2024-02-01")));
+ }
+
+ @Test
+ void chatCompletion() {
+
+ stubForChatCompletion();
+
+ final var systemMessage = OpenAiMessage.system("You are a helpful AI");
+ final var userMessage = OpenAiMessage.user("Hello World! Why is this phrase so famous?");
+ final var prompt = new OpenAiChatCompletionRequest(systemMessage, userMessage);
+ final var result = client.chatCompletion(prompt).getOriginalResponse();
+
+ assertThat(result).isNotNull();
+ assertThat(result.getCreated()).isEqualTo(1727436279);
+ assertThat(result.getId()).isEqualTo("chatcmpl-AC3NPPYlxem8kRBBAX9EBObMMsrnf");
+ assertThat(result.getModel()).isEqualTo("gpt-35-turbo");
+ assertThat(result.getObject()).isEqualTo(CHAT_COMPLETION);
+ assertThat(result.getSystemFingerprint()).isEqualTo("fp_e49e4201a9");
+
+ assertThat(result.getUsage()).isNotNull();
+ assertThat(result.getUsage().getCompletionTokens()).isEqualTo(20);
+ assertThat(result.getUsage().getPromptTokens()).isEqualTo(13);
+ assertThat(result.getUsage().getTotalTokens()).isEqualTo(33);
+
+ assertThat(result.getPromptFilterResults()).hasSize(1);
+ assertThat(result.getPromptFilterResults().get(0).getPromptIndex()).isZero();
+
+ var promptFilterResults = result.getPromptFilterResults().get(0).getContentFilterResults();
+ assertThat(promptFilterResults).isNotNull();
+ assertThat(promptFilterResults.getSexual()).isNotNull();
+ assertThat(promptFilterResults.getSexual().isFiltered()).isFalse();
+ assertThat(promptFilterResults.getSexual().getSeverity()).isEqualTo(SAFE);
+ assertThat(promptFilterResults.getViolence()).isNotNull();
+ assertThat(promptFilterResults.getViolence().isFiltered()).isFalse();
+ assertThat(promptFilterResults.getViolence().getSeverity()).isEqualTo(SAFE);
+ assertThat(promptFilterResults.getHate()).isNotNull();
+ assertThat(promptFilterResults.getHate().isFiltered()).isFalse();
+ assertThat(promptFilterResults.getHate().getSeverity()).isEqualTo(SAFE);
+ assertThat(promptFilterResults.getSelfHarm()).isNotNull();
+ assertThat(promptFilterResults.getSelfHarm().getSeverity()).isEqualTo(SAFE);
+ assertThat(promptFilterResults.getSelfHarm().isFiltered()).isFalse();
+ assertThat(promptFilterResults.getProfanity()).isNull();
+ assertThat(promptFilterResults.getError()).isNull();
+ assertThat(promptFilterResults.getJailbreak()).isNull();
+
+ assertThat(result.getChoices()).hasSize(1);
+
+ var choice = result.getChoices().get(0);
+ assertThat(choice.getFinishReason()).isEqualTo(STOP);
+ assertThat(choice.getIndex()).isZero();
+ assertThat(choice.getMessage().getContent())
+ .isEqualTo(
+ "I'm an AI and cannot answer that question as beauty is subjective and varies from person to person.");
+ assertThat(choice.getMessage().getRole())
+ .isEqualTo(ChatCompletionResponseMessageRole.ASSISTANT);
+ assertThat(choice.getMessage().getToolCalls()).isNull();
+
+ var contentFilterResults = choice.getContentFilterResults();
+ assertThat(contentFilterResults).isNotNull();
+ assertThat(contentFilterResults.getSexual()).isNotNull();
+ assertThat(contentFilterResults.getSexual().isFiltered()).isFalse();
+ assertThat(contentFilterResults.getSexual().getSeverity()).isEqualTo(SAFE);
+ assertThat(contentFilterResults.getViolence()).isNotNull();
+ assertThat(contentFilterResults.getViolence().isFiltered()).isFalse();
+ assertThat(contentFilterResults.getViolence().getSeverity()).isEqualTo(SAFE);
+ assertThat(contentFilterResults.getHate()).isNotNull();
+ assertThat(contentFilterResults.getHate().isFiltered()).isFalse();
+ assertThat(contentFilterResults.getHate().getSeverity()).isEqualTo(SAFE);
+ assertThat(contentFilterResults.getSelfHarm()).isNotNull();
+ assertThat(contentFilterResults.getSelfHarm().getSeverity()).isEqualTo(SAFE);
+ assertThat(contentFilterResults.getSelfHarm().isFiltered()).isFalse();
+ assertThat(contentFilterResults.getProfanity()).isNull();
+ assertThat(contentFilterResults.getError()).isNull();
+
+ verify(
+ postRequestedFor(urlPathEqualTo("/chat/completions"))
+ .withQueryParam("api-version", equalTo("2024-02-01"))
+ .withRequestBody(
+ equalToJson(
+ """
+ {
+ "messages" : [ {
+ "content" : "You are a helpful AI",
+ "role" : "system"
+ }, {
+ "content" : "Hello World! Why is this phrase so famous?",
+ "role" : "user"
+ } ]
+ }""")));
+ }
+
+ @Test
+ @DisplayName("Chat history is not implemented yet")
+ void history() {
+ stubFor(
+ post(urlPathEqualTo("/chat/completions"))
+ .willReturn(
+ aResponse()
+ .withBodyFile("chatCompletionResponse.json")
+ .withHeader("Content-Type", "application/json")));
+
+ client.chatCompletion(new OpenAiChatCompletionRequest("First message"));
+
+ verify(
+ exactly(1),
+ postRequestedFor(urlPathEqualTo("/chat/completions"))
+ .withRequestBody(
+ equalToJson(
+ """
+ {
+ "messages" : [{
+ "content" : "First message",
+ "role" : "user"
+ }]
+ }""")));
+
+ var response = client.chatCompletion(new OpenAiChatCompletionRequest("Second message"));
+
+ assertThat(response.getContent()).isNotNull();
+ assertThat(response.getContent())
+ .isEqualTo(
+ "I'm an AI and cannot answer that question as beauty is subjective and varies from person to person.");
+
+ verify(
+ exactly(1),
+ postRequestedFor(urlPathEqualTo("/chat/completions"))
+ .withRequestBody(
+ equalToJson(
+ """
+ {
+ "messages" : [{
+ "content" : "Second message",
+ "role" : "user"
+ }]
+ }""")));
+ }
+
+ @Test
+ void embedding() {
+ stubForEmbedding();
+
+ final var result =
+ client.embedding(
+ new EmbeddingsCreateRequest()
+ .input(EmbeddingsCreateRequestInput.create("Hello World")));
+
+ assertThat(result).isNotNull();
+ assertThat(result.getModel()).isEqualTo("ada");
+ assertThat(result.getObject()).isEqualTo("list");
+
+ assertThat(result.getUsage()).isNotNull();
+ assertThat(result.getUsage().getPromptTokens()).isEqualTo(2);
+ assertThat(result.getUsage().getTotalTokens()).isEqualTo(2);
+
+ assertThat(result.getData()).isNotNull().hasSize(1);
+ var embeddingData = result.getData().get(0);
+ assertThat(embeddingData).isNotNull();
+ assertThat(embeddingData.getObject()).isEqualTo("embedding");
+ assertThat(embeddingData.getIndex()).isZero();
+ assertThat(embeddingData.getEmbedding())
+ .isNotNull()
+ .isNotEmpty()
+ .containsExactly(
+ new BigDecimal("0.0"),
+ new BigDecimal("3.4028235E+38"),
+ new BigDecimal("1.4E-45"),
+ new BigDecimal("1.23"),
+ new BigDecimal("-4.56"));
+
+ verify(
+ postRequestedFor(urlPathEqualTo("/embeddings"))
+ .withRequestBody(
+ equalToJson(
+ """
+ {"input": "Hello World"}""")));
+ }
+
+ @Test
+ void testThrowsOnContentFilter() {
+ var mock = mock(OpenAiClient.class);
+ when(mock.streamChatCompletion(any())).thenCallRealMethod();
+
+ var deltaWithContentFilter = mock(OpenAiChatCompletionDelta.class);
+ when(deltaWithContentFilter.getFinishReason()).thenReturn("content_filter");
+ when(mock.streamChatCompletionDeltas((CreateChatCompletionRequest) any()))
+ .thenReturn(Stream.of(deltaWithContentFilter));
+
+ // this must not throw, since the stream is lazily evaluated
+ var stream = mock.streamChatCompletion("");
+ assertThatThrownBy(stream::toList)
+ .isInstanceOf(OpenAiClientException.class)
+ .hasMessageContaining("Content filter");
+ }
+
+ @Test
+ void streamChatCompletionDeltasErrorHandling() throws IOException {
+ try (var inputStream = stubStreamChatCompletion("streamChatCompletionError.txt")) {
+
+ final var request =
+ new OpenAiChatCompletionRequest(
+ "Can you give me the first 100 numbers of the Fibonacci sequence?");
+
+ try (var stream = client.streamChatCompletionDeltas(request)) {
+ assertThatThrownBy(() -> stream.forEach(System.out::println))
+ .isInstanceOf(OpenAiClientException.class)
+ .hasMessage("Failed to parse response and error message: 'exceeded token rate limit'");
+ }
+
+ Mockito.verify(inputStream, times(1)).close();
+ }
+ }
+
+ @SneakyThrows
+ @Test
+ void streamChatCompletionWithString() {
+ try (var inputStream = stubStreamChatCompletion("streamChatCompletion.txt")) {
+ final var userMessage = "Hello World! Why is this phrase so famous?";
+ client.withSystemPrompt("You are a helpful AI");
+ final var result = client.streamChatCompletion(userMessage).toList();
+
+ assertThat(result).hasSize(5);
+ // the first two and the last delta don't have any content
+ assertThat(result.get(0)).isEmpty();
+ assertThat(result.get(1)).isEmpty();
+ assertThat(result.get(2)).isEqualTo("Sure");
+ assertThat(result.get(3)).isEqualTo("!");
+ assertThat(result.get(4)).isEmpty();
+
+ Mockito.verify(inputStream, times(1)).close();
+ }
+ }
+
+ @Test
+ void streamChatCompletionDeltas() throws IOException {
+
+ try (var inputStream = stubStreamChatCompletion("streamChatCompletion.txt")) {
+
+ final var request =
+ new OpenAiChatCompletionRequest(
+ "Can you give me the first 100 numbers of the Fibonacci sequence?");
+
+ try (Stream stream = client.streamChatCompletionDeltas(request)) {
+ final List deltaList = stream.toList();
+
+ assertThat(deltaList).hasSize(5);
+
+ final var delta0 = deltaList.get(0).getOriginalResponse();
+ final var delta1 = deltaList.get(1).getOriginalResponse();
+ final var delta2 = deltaList.get(2).getOriginalResponse();
+ final var delta3 = deltaList.get(3).getOriginalResponse();
+ final var delta4 = deltaList.get(4).getOriginalResponse();
+
+ assertThat(delta0.getCreated()).isZero();
+ assertThat(delta1.getCreated()).isEqualTo(1724825677);
+ assertThat(delta2.getCreated()).isEqualTo(1724825677);
+ assertThat(delta3.getCreated()).isEqualTo(1724825677);
+ assertThat(delta4.getCreated()).isEqualTo(1724825677);
+
+ assertThat(delta0.getId()).isEmpty();
+ assertThat(delta1.getId()).isEqualTo("chatcmpl-A16EvnkgEm6AdxY0NoOmGPjsJucQ1");
+ assertThat(delta2.getId()).isEqualTo("chatcmpl-A16EvnkgEm6AdxY0NoOmGPjsJucQ1");
+ assertThat(delta3.getId()).isEqualTo("chatcmpl-A16EvnkgEm6AdxY0NoOmGPjsJucQ1");
+ assertThat(delta4.getId()).isEqualTo("chatcmpl-A16EvnkgEm6AdxY0NoOmGPjsJucQ1");
+
+ assertThat(delta0.getModel()).isEmpty();
+ assertThat(delta1.getModel()).isEqualTo("gpt-35-turbo");
+ assertThat(delta2.getModel()).isEqualTo("gpt-35-turbo");
+ assertThat(delta3.getModel()).isEqualTo("gpt-35-turbo");
+ assertThat(delta4.getModel()).isEqualTo("gpt-35-turbo");
+
+ assertThat(delta0.getObject()).isEqualTo(UNKNOWN_DEFAULT_OPEN_API);
+ assertThat(delta1.getObject()).isEqualTo(CHAT_COMPLETION_CHUNK);
+ assertThat(delta2.getObject()).isEqualTo(CHAT_COMPLETION_CHUNK);
+ assertThat(delta3.getObject()).isEqualTo(CHAT_COMPLETION_CHUNK);
+ assertThat(delta4.getObject()).isEqualTo(CHAT_COMPLETION_CHUNK);
+
+ assertThat(delta0.getSystemFingerprint()).isNull();
+ assertThat(delta1.getSystemFingerprint()).isEqualTo("fp_e49e4201a9");
+ assertThat(delta2.getSystemFingerprint()).isEqualTo("fp_e49e4201a9");
+ assertThat(delta3.getSystemFingerprint()).isEqualTo("fp_e49e4201a9");
+ assertThat(delta4.getSystemFingerprint()).isEqualTo("fp_e49e4201a9");
+
+ assertThat(delta0.getCustomFieldNames()).contains("prompt_filter_results");
+ assertThat(delta1.getCustomFieldNames()).doesNotContain("prompt_filter_results");
+ assertThat(delta2.getCustomFieldNames()).doesNotContain("prompt_filter_results");
+ assertThat(delta3.getCustomFieldNames()).doesNotContain("prompt_filter_results");
+ assertThat(delta4.getCustomFieldNames()).doesNotContain("prompt_filter_results");
+ var promptFilterResults = (List>) delta0.getCustomField("prompt_filter_results");
+ final var promptFilter0 =
+ MAPPER.convertValue(promptFilterResults.get(0), PromptFilterResult.class);
+ assertThat(promptFilter0).isNotNull();
+ assertThat(promptFilter0.getPromptIndex()).isZero();
+ assertFilter(promptFilter0.getContentFilterResults());
+ assertThat(promptFilter0.getContentFilterResults()).isNotNull();
+ assertFilter(promptFilter0.getContentFilterResults());
+
+ // delta0.choices
+ assertThat(delta0.getChoices()).isEmpty();
+
+ // delta1.choices
+ assertThat(delta1.getChoices()).hasSize(1);
+ var choice1 = delta1.getChoices().get(0);
+ assertThat(choice1.getFinishReason()).isNull();
+ assertThat(choice1.getIndex()).isZero();
+ assertThat(choice1.getDelta().getContent()).isEmpty();
+ assertThat(choice1.getDelta().getRole()).isEqualTo(ASSISTANT);
+ assertThat(choice1.getCustomField("content_filter_results")).isNotNull();
+ assertThat(choice1.getCustomField("content_filter_results")).isEqualTo(Map.of());
+
+ // delta2.choices
+ assertThat(delta2.getChoices()).hasSize(1);
+ var choice2 = delta2.getChoices().get(0);
+ assertThat(choice2.getFinishReason()).isNull();
+ assertThat(choice2.getIndex()).isZero();
+ assertThat(choice2.getDelta().getContent()).isEqualTo("Sure");
+ assertThat(choice2.getDelta().getRole()).isNull();
+ assertThat(choice2.getCustomField("content_filter_results")).isNotNull();
+ final var contentFilter2 =
+ MAPPER.convertValue(
+ choice2.getCustomField("content_filter_results"), ContentFilterPromptResults.class);
+ assertThat(contentFilter2).isNotNull();
+ assertFilter(contentFilter2);
+
+ // delta3.choices
+ assertThat(delta3.getChoices()).hasSize(1);
+ var choice3 = delta3.getChoices().get(0);
+ assertThat(choice3.getFinishReason()).isNull();
+ assertThat(choice3.getIndex()).isZero();
+ assertThat(choice3.getDelta().getContent()).isEqualTo("!");
+ assertThat(choice3.getDelta().getRole()).isNull();
+ assertThat(choice3.getCustomField("content_filter_results")).isNotNull();
+ var contentFilter3 =
+ MAPPER.convertValue(
+ choice3.getCustomField("content_filter_results"), ContentFilterPromptResults.class);
+ assertThat(contentFilter3).isNotNull();
+ assertFilter(contentFilter3);
+
+ // delta4.choices
+ assertThat(delta4.getChoices()).hasSize(1);
+ var choice4 = delta4.getChoices().get(0);
+ assertThat(choice4.getFinishReason())
+ .isEqualTo(CreateChatCompletionStreamResponseChoicesInner.FinishReasonEnum.STOP);
+ assertThat(choice4.getIndex()).isZero();
+ assertThat(choice4.getDelta().getContent()).isNull();
+ assertThat(choice4.getDelta().getRole()).isNull();
+ assertThat(choice4.getCustomField("content_filter_results")).isEqualTo(Map.of());
+ }
+
+ Mockito.verify(inputStream, times(1)).close();
+ }
+ }
+
+ @Test
+ void streamCompletionDeltasResponseConvenience() throws IOException {
+ try (var inputStream = stubStreamChatCompletion("streamChatCompletion.txt")) {
+
+ final var request =
+ new OpenAiChatCompletionRequest(
+ "Can you give me the first 100 numbers of the Fibonacci sequence?");
+
+ try (Stream<OpenAiChatCompletionDelta> stream = client.streamChatCompletionDeltas(request)) {
+ final List<OpenAiChatCompletionDelta> deltaList = stream.toList();
+
+ assertThat(deltaList).hasSize(5);
+
+ assertThat(deltaList.get(0).getFinishReason()).isNull();
+ assertThat(deltaList.get(1).getFinishReason()).isNull();
+ assertThat(deltaList.get(2).getFinishReason()).isNull();
+ assertThat(deltaList.get(3).getFinishReason()).isNull();
+ assertThat(deltaList.get(3).getFinishReason()).isNull();
+ assertThat(deltaList.get(4).getFinishReason()).isEqualTo("stop");
+
+ assertThat(deltaList.get(0).getDeltaContent()).isEmpty();
+ assertThat(deltaList.get(1).getDeltaContent()).isEmpty();
+ assertThat(deltaList.get(2).getDeltaContent()).isEqualTo("Sure");
+ assertThat(deltaList.get(3).getDeltaContent()).isEqualTo("!");
+ assertThat(deltaList.get(4).getDeltaContent()).isEmpty();
+
+ assertThat(deltaList.get(0).getCompletionUsage()).isNull();
+ assertThat(deltaList.get(1).getCompletionUsage()).isNull();
+ assertThat(deltaList.get(2).getCompletionUsage()).isNull();
+ assertThat(deltaList.get(3).getCompletionUsage()).isNull();
+ assertThat(deltaList.get(4).getCompletionUsage()).isNotNull();
+ assertThat(deltaList.get(4).getCompletionUsage().getCompletionTokens()).isEqualTo(607);
+ assertThat(deltaList.get(4).getCompletionUsage().getPromptTokens()).isEqualTo(21);
+ assertThat(deltaList.get(4).getCompletionUsage().getTotalTokens()).isEqualTo(628);
+ }
+
+ Mockito.verify(inputStream, times(1)).close();
+ }
+ }
+
+ void assertFilter(ContentFilterPromptResults filter) {
+ assertThat(filter).isNotNull();
+ assertThat(filter.getHate()).isNotNull();
+ assertThat(filter.getHate().isFiltered()).isFalse();
+ assertThat(filter.getHate().getSeverity()).isEqualTo(SAFE);
+ assertThat(filter.getSelfHarm()).isNotNull();
+ assertThat(filter.getSelfHarm().isFiltered()).isFalse();
+ assertThat(filter.getSelfHarm().getSeverity()).isEqualTo(SAFE);
+ assertThat(filter.getSexual()).isNotNull();
+ assertThat(filter.getSexual().isFiltered()).isFalse();
+ assertThat(filter.getSexual().getSeverity()).isEqualTo(SAFE);
+ assertThat(filter.getViolence()).isNotNull();
+ assertThat(filter.getViolence().isFiltered()).isFalse();
+ assertThat(filter.getViolence().getSeverity()).isEqualTo(SAFE);
+ assertThat(filter.getJailbreak()).isNull();
+ assertThat(filter.getProfanity()).isNull();
+ assertThat(filter.getError()).isNull();
+ }
+
+ @Test
+ void chatCompletionTool() {
+ stubForChatCompletionTool();
+
+ final var function =
+ new FunctionObject()
+ .name("fibonacci")
+ .parameters(
+ Map.of("type", "object", "properties", Map.of("N", Map.of("type", "integer"))));
+
+ final var tool =
+ new ChatCompletionTool().type(ChatCompletionTool.TypeEnum.FUNCTION).function(function);
+
+ final var toolChoice =
+ ChatCompletionToolChoiceOption.create(
+ new ChatCompletionNamedToolChoice()
+ .type(ChatCompletionNamedToolChoice.TypeEnum.FUNCTION)
+ .function(new ChatCompletionNamedToolChoiceFunction().name("fibonacci")));
+
+ final var request =
+ new OpenAiChatCompletionRequest(
+ "A pair of rabbits is placed in a field. Each month, every pair produces one new pair, starting from the second month. How many rabbits will there be after 12 months?")
+ .withTools(List.of(tool))
+ .withToolChoice(toolChoice);
+
+ var response = client.chatCompletion(request).getOriginalResponse();
+
+ assertThat(response).isNotNull();
+ assertThat(response.getChoices()).hasSize(1);
+ assertThat(response.getChoices().get(0).getFinishReason()).isEqualTo(STOP);
+ assertThat(response.getChoices().get(0).getMessage().getRole())
+ .isEqualTo(ChatCompletionResponseMessageRole.ASSISTANT);
+ assertThat(response.getChoices().get(0).getMessage().getToolCalls()).hasSize(1);
+ assertThat(response.getChoices().get(0).getMessage().getToolCalls().get(0).getId())
+ .isEqualTo("call_CUYGJf2j7FRWJMHT3PN3aGxK");
+ assertThat(response.getChoices().get(0).getMessage().getToolCalls().get(0).getType())
+ .isEqualTo(FUNCTION);
+ assertThat(
+ response.getChoices().get(0).getMessage().getToolCalls().get(0).getFunction().getName())
+ .isEqualTo("fibonacci");
+ assertThat(
+ response
+ .getChoices()
+ .get(0)
+ .getMessage()
+ .getToolCalls()
+ .get(0)
+ .getFunction()
+ .getArguments())
+ .isEqualTo("{\"N\":12}");
+
+ verify(
+ postRequestedFor(anyUrl())
+ .withRequestBody(
+ equalToJson(
+ """
+ {
+ "messages" : [ {
+ "content" : "A pair of rabbits is placed in a field. Each month, every pair produces one new pair, starting from the second month. How many rabbits will there be after 12 months?",
+ "role" : "user"
+ } ],
+ "tools" : [ {
+ "type" : "function",
+ "function" : {
+ "name" : "fibonacci",
+ "parameters" : {
+ "type" : "object",
+ "properties" : {
+ "N" : {
+ "type" : "integer"
+ }
+ }
+ },
+ "strict" : false
+ }
+ } ],
+ "tool_choice" : {
+ "type" : "function",
+ "function" : {
+ "name" : "fibonacci"
+ }
+ }
+ }
+ """)));
+ }
+}
diff --git a/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientTest.java b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientTest.java
index 67c1c8a24..eb57255a2 100644
--- a/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientTest.java
+++ b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientTest.java
@@ -1,180 +1,54 @@
package com.sap.ai.sdk.foundationmodels.openai;
import static com.github.tomakehurst.wiremock.client.WireMock.*;
+import static com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionTool.ToolType.FUNCTION;
+import static com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.*;
import static com.sap.ai.sdk.foundationmodels.openai.model.OpenAiContentFilterSeverityResult.Severity.SAFE;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.Mockito.doReturn;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
-import static org.mockito.Mockito.when;
-import com.fasterxml.jackson.core.JsonParseException;
-import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo;
-import com.github.tomakehurst.wiremock.junit5.WireMockTest;
-import com.github.tomakehurst.wiremock.stubbing.Scenario;
-import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionChoice;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionDelta;
+import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionFunction;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionOutput;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionParameters;
-import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.OpenAiChatSystemMessage;
-import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.OpenAiChatUserMessage;
+import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionTool;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiContentFilterPromptResults;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingParameters;
-import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor;
-import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Cache;
-import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination;
-import io.vavr.control.Try;
import java.io.IOException;
-import java.io.InputStream;
import java.util.List;
-import java.util.Objects;
+import java.util.Map;
import java.util.concurrent.Callable;
-import java.util.function.Function;
import java.util.stream.Stream;
import javax.annotation.Nonnull;
import lombok.SneakyThrows;
-import org.apache.hc.client5.http.classic.HttpClient;
-import org.apache.hc.core5.http.ContentType;
-import org.apache.hc.core5.http.io.entity.InputStreamEntity;
-import org.apache.hc.core5.http.message.BasicClassicHttpResponse;
-import org.assertj.core.api.SoftAssertions;
-import org.junit.jupiter.api.AfterEach;
-import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.Mockito;
-@WireMockTest
-class OpenAiClientTest {
- private static OpenAiClient client;
- private final Function<String, InputStream> fileLoader =
- filename -> Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream(filename));
-
- @BeforeEach
- void setup(WireMockRuntimeInfo server) {
- final DefaultHttpDestination destination =
- DefaultHttpDestination.builder(server.getHttpBaseUrl()).build();
- client = OpenAiClient.withCustomDestination(destination);
- ApacheHttpClient5Accessor.setHttpClientCache(ApacheHttpClient5Cache.DISABLED);
- }
-
- @AfterEach
- void reset() {
- ApacheHttpClient5Accessor.setHttpClientCache(null);
- ApacheHttpClient5Accessor.setHttpClientFactory(null);
- }
-
- @Test
- void apiVersion() {
- stubFor(post(anyUrl()).willReturn(okJson("{}")));
- Try.of(() -> client.chatCompletion(new OpenAiChatCompletionParameters()));
-
- verify(
- exactly(1),
- postRequestedFor(anyUrl()).withQueryParam("api-version", equalTo("2024-02-01")));
-
- Try.of(
- () -> client.withApiVersion("fooBar").chatCompletion(new OpenAiChatCompletionParameters()));
- verify(exactly(1), postRequestedFor(anyUrl()).withQueryParam("api-version", equalTo("fooBar")));
-
- assertThat(client)
- .describedAs(
- "withApiVersion should return a new object, the sut object should remain unchanged")
- .isNotSameAs(client.withApiVersion("fooBar"));
- Try.of(() -> client.chatCompletion(new OpenAiChatCompletionParameters()));
- verify(
- exactly(2),
- postRequestedFor(anyUrl()).withQueryParam("api-version", equalTo("2024-02-01")));
- }
+class OpenAiClientTest extends BaseOpenAiClientTest {
private static Runnable[] errorHandlingCalls() {
return new Runnable[] {
- () -> client.chatCompletion(new OpenAiChatCompletionParameters()),
+ () -> client.chatCompletion(""),
() ->
client
- .streamChatCompletionDeltas(new OpenAiChatCompletionParameters())
+ .streamChatCompletionDeltas(
+ new OpenAiChatCompletionParameters()
+ .addMessages(new OpenAiChatUserMessage().addText("")))
// the stream needs to be consumed to parse the response
- .forEach(System.out::println)
+ .forEach(System.out::println),
};
}
@ParameterizedTest
@MethodSource("errorHandlingCalls")
void chatCompletionErrorHandling(@Nonnull final Runnable request) {
- final var errorJson =
- """
- { "error": { "code": null, "message": "foo", "type": "invalid stuff" } }
- """;
- stubFor(
- post(anyUrl())
- .inScenario("Errors")
- .whenScenarioStateIs(Scenario.STARTED)
- .willReturn(serverError())
- .willSetStateTo("1"));
- stubFor(
- post(anyUrl())
- .inScenario("Errors")
- .whenScenarioStateIs("1")
- .willReturn(
- badRequest().withBody(errorJson).withHeader("Content-type", "application/json"))
- .willSetStateTo("2"));
- stubFor(
- post(anyUrl())
- .inScenario("Errors")
- .whenScenarioStateIs("2")
- .willReturn(
- badRequest()
- .withBody("{ broken json")
- .withHeader("Content-type", "application/json"))
- .willSetStateTo("3"));
- stubFor(
- post(anyUrl())
- .inScenario("Errors")
- .whenScenarioStateIs("3")
- .willReturn(okXml(""))
- .willSetStateTo("4"));
- stubFor(post(anyUrl()).inScenario("Errors").whenScenarioStateIs("4").willReturn(noContent()));
-
- final var softly = new SoftAssertions();
-
- softly
- .assertThatThrownBy(request::run)
- .describedAs("Server errors should be handled")
- .isInstanceOf(OpenAiClientException.class)
- .hasMessageContaining("500");
-
- softly
- .assertThatThrownBy(request::run)
- .describedAs("Error objects from OpenAI should be interpreted")
- .isInstanceOf(OpenAiClientException.class)
- .hasMessageContaining("error message: 'foo'");
-
- softly
- .assertThatThrownBy(request::run)
- .describedAs("Failures while parsing error message should be handled")
- .isInstanceOf(OpenAiClientException.class)
- .hasMessageContaining("400")
- .extracting(e -> e.getSuppressed()[0])
- .isInstanceOf(JsonParseException.class);
-
- softly
- .assertThatThrownBy(request::run)
- .describedAs("Non-JSON responses should be handled")
- .isInstanceOf(OpenAiClientException.class)
- .hasMessageContaining("Failed to parse");
-
- softly
- .assertThatThrownBy(request::run)
- .describedAs("Empty responses should be handled")
- .isInstanceOf(OpenAiClientException.class)
- .hasMessageContaining("was empty");
-
- softly.assertAll();
+
+ stubForErrorHandling();
+ assertForErrorHandling(request);
}
private static Callable<OpenAiChatCompletionOutput>[] chatCompletionCalls() {
@@ -198,97 +72,94 @@ private static Callable<OpenAiChatCompletionOutput>[] chatCompletionCalls() {
@ParameterizedTest
@MethodSource("chatCompletionCalls")
void chatCompletion(@Nonnull final Callable<OpenAiChatCompletionOutput> request) {
- try (var inputStream = fileLoader.apply("__files/chatCompletionResponse.json")) {
-
- final String response = new String(inputStream.readAllBytes());
- // with query parameter api-version=2024-02-01
- stubFor(
- post(urlPathEqualTo("/chat/completions"))
- .withQueryParam("api-version", equalTo("2024-02-01"))
- .willReturn(okJson(response)));
-
- final OpenAiChatCompletionOutput result = request.call();
-
- assertThat(result).isNotNull();
- assertThat(result.getCreated()).isEqualTo(1727436279);
- assertThat(result.getId()).isEqualTo("chatcmpl-AC3NPPYlxem8kRBBAX9EBObMMsrnf");
- assertThat(result.getModel()).isEqualTo("gpt-35-turbo");
- assertThat(result.getObject()).isEqualTo("chat.completion");
- assertThat(result.getSystemFingerprint()).isEqualTo("fp_e49e4201a9");
-
- assertThat(result.getUsage()).isNotNull();
- assertThat(result.getUsage().getCompletionTokens()).isEqualTo(20);
- assertThat(result.getUsage().getPromptTokens()).isEqualTo(13);
- assertThat(result.getUsage().getTotalTokens()).isEqualTo(33);
-
- assertThat(result.getPromptFilterResults()).hasSize(1);
- assertThat(result.getPromptFilterResults().get(0).getPromptIndex()).isEqualTo(0);
- OpenAiContentFilterPromptResults promptFilterResults =
- result.getPromptFilterResults().get(0).getContentFilterResults();
- assertThat(promptFilterResults).isNotNull();
- assertThat(promptFilterResults.getSexual()).isNotNull();
- assertThat(promptFilterResults.getSexual().isFiltered()).isFalse();
- assertThat(promptFilterResults.getSexual().getSeverity()).isEqualTo(SAFE);
- assertThat(promptFilterResults.getViolence()).isNotNull();
- assertThat(promptFilterResults.getViolence().isFiltered()).isFalse();
- assertThat(promptFilterResults.getViolence().getSeverity()).isEqualTo(SAFE);
- assertThat(promptFilterResults.getHate()).isNotNull();
- assertThat(promptFilterResults.getHate().isFiltered()).isFalse();
- assertThat(promptFilterResults.getHate().getSeverity()).isEqualTo(SAFE);
- assertThat(promptFilterResults.getSelfHarm()).isNotNull();
- assertThat(promptFilterResults.getSelfHarm().getSeverity()).isEqualTo(SAFE);
- assertThat(promptFilterResults.getSelfHarm().isFiltered()).isFalse();
- assertThat(promptFilterResults.getProfanity()).isNull();
- assertThat(promptFilterResults.getError()).isNull();
- assertThat(promptFilterResults.getJailbreak()).isNull();
-
- assertThat(result.getChoices()).hasSize(1);
- OpenAiChatCompletionChoice choice = result.getChoices().get(0);
- assertThat(choice.getFinishReason()).isEqualTo("stop");
- assertThat(choice.getIndex()).isEqualTo(0);
- assertThat(choice.getMessage().getContent())
- .isEqualTo(
- "I'm an AI and cannot answer that question as beauty is subjective and varies from person to person.");
- assertThat(choice.getMessage().getRole()).isEqualTo("assistant");
- assertThat(choice.getMessage().getToolCalls()).isNull();
-
- OpenAiContentFilterPromptResults contentFilterResults = choice.getContentFilterResults();
- assertThat(contentFilterResults).isNotNull();
- assertThat(contentFilterResults.getSexual()).isNotNull();
- assertThat(contentFilterResults.getSexual().isFiltered()).isFalse();
- assertThat(contentFilterResults.getSexual().getSeverity()).isEqualTo(SAFE);
- assertThat(contentFilterResults.getViolence()).isNotNull();
- assertThat(contentFilterResults.getViolence().isFiltered()).isFalse();
- assertThat(contentFilterResults.getViolence().getSeverity()).isEqualTo(SAFE);
- assertThat(contentFilterResults.getHate()).isNotNull();
- assertThat(contentFilterResults.getHate().isFiltered()).isFalse();
- assertThat(contentFilterResults.getHate().getSeverity()).isEqualTo(SAFE);
- assertThat(contentFilterResults.getSelfHarm()).isNotNull();
- assertThat(contentFilterResults.getSelfHarm().getSeverity()).isEqualTo(SAFE);
- assertThat(contentFilterResults.getSelfHarm().isFiltered()).isFalse();
- assertThat(contentFilterResults.getProfanity()).isNull();
- assertThat(contentFilterResults.getError()).isNull();
- assertThat(contentFilterResults.getJailbreak()).isNull();
-
- verify(
- postRequestedFor(urlPathEqualTo("/chat/completions"))
- .withQueryParam("api-version", equalTo("2024-02-01"))
- .withRequestBody(
- equalToJson(
- """
- {
- "messages" : [ {
- "role" : "system",
- "content" : "You are a helpful AI"
- }, {
- "role" : "user",
- "content" : [ {
- "type" : "text",
- "text" : "Hello World! Why is this phrase so famous?"
- } ]
- } ]
- }""")));
- }
+
+ stubForChatCompletion();
+
+ var result = request.call();
+
+ assertThat(result).isNotNull();
+ assertThat(result.getCreated()).isEqualTo(1727436279);
+ assertThat(result.getId()).isEqualTo("chatcmpl-AC3NPPYlxem8kRBBAX9EBObMMsrnf");
+ assertThat(result.getModel()).isEqualTo("gpt-35-turbo");
+ assertThat(result.getObject()).isEqualTo("chat.completion");
+ assertThat(result.getSystemFingerprint()).isEqualTo("fp_e49e4201a9");
+
+ assertThat(result.getUsage()).isNotNull();
+ assertThat(result.getUsage().getCompletionTokens()).isEqualTo(20);
+ assertThat(result.getUsage().getPromptTokens()).isEqualTo(13);
+ assertThat(result.getUsage().getTotalTokens()).isEqualTo(33);
+
+ assertThat(result.getPromptFilterResults()).hasSize(1);
+ assertThat(result.getPromptFilterResults().get(0).getPromptIndex()).isEqualTo(0);
+ assertThat(result.getContent())
+ .isEqualTo(
+ "I'm an AI and cannot answer that question as beauty is subjective and varies from person to person.");
+
+ var promptFilterResults = result.getPromptFilterResults().get(0).getContentFilterResults();
+ assertThat(promptFilterResults).isNotNull();
+ assertThat(promptFilterResults.getSexual()).isNotNull();
+ assertThat(promptFilterResults.getSexual().isFiltered()).isFalse();
+ assertThat(promptFilterResults.getSexual().getSeverity()).isEqualTo(SAFE);
+ assertThat(promptFilterResults.getViolence()).isNotNull();
+ assertThat(promptFilterResults.getViolence().isFiltered()).isFalse();
+ assertThat(promptFilterResults.getViolence().getSeverity()).isEqualTo(SAFE);
+ assertThat(promptFilterResults.getHate()).isNotNull();
+ assertThat(promptFilterResults.getHate().isFiltered()).isFalse();
+ assertThat(promptFilterResults.getHate().getSeverity()).isEqualTo(SAFE);
+ assertThat(promptFilterResults.getSelfHarm()).isNotNull();
+ assertThat(promptFilterResults.getSelfHarm().getSeverity()).isEqualTo(SAFE);
+ assertThat(promptFilterResults.getSelfHarm().isFiltered()).isFalse();
+ assertThat(promptFilterResults.getProfanity()).isNull();
+ assertThat(promptFilterResults.getError()).isNull();
+ assertThat(promptFilterResults.getJailbreak()).isNull();
+
+ assertThat(result.getChoices()).hasSize(1);
+
+ var choice = result.getChoices().get(0);
+ assertThat(choice.getFinishReason()).isEqualTo("stop");
+ assertThat(choice.getIndex()).isZero();
+ assertThat(choice.getMessage().getContent())
+ .isEqualTo(
+ "I'm an AI and cannot answer that question as beauty is subjective and varies from person to person.");
+ assertThat(choice.getMessage().getRole()).isEqualTo("assistant");
+ assertThat(choice.getMessage().getToolCalls()).isNull();
+
+ var contentFilterResults = choice.getContentFilterResults();
+ assertThat(contentFilterResults).isNotNull();
+ assertThat(contentFilterResults.getSexual()).isNotNull();
+ assertThat(contentFilterResults.getSexual().isFiltered()).isFalse();
+ assertThat(contentFilterResults.getSexual().getSeverity()).isEqualTo(SAFE);
+ assertThat(contentFilterResults.getViolence()).isNotNull();
+ assertThat(contentFilterResults.getViolence().isFiltered()).isFalse();
+ assertThat(contentFilterResults.getViolence().getSeverity()).isEqualTo(SAFE);
+ assertThat(contentFilterResults.getHate()).isNotNull();
+ assertThat(contentFilterResults.getHate().isFiltered()).isFalse();
+ assertThat(contentFilterResults.getHate().getSeverity()).isEqualTo(SAFE);
+ assertThat(contentFilterResults.getSelfHarm()).isNotNull();
+ assertThat(contentFilterResults.getSelfHarm().getSeverity()).isEqualTo(SAFE);
+ assertThat(contentFilterResults.getSelfHarm().isFiltered()).isFalse();
+ assertThat(contentFilterResults.getProfanity()).isNull();
+ assertThat(contentFilterResults.getError()).isNull();
+ assertThat(contentFilterResults.getJailbreak()).isNull();
+
+ verify(
+ postRequestedFor(urlPathEqualTo("/chat/completions"))
+ .withQueryParam("api-version", equalTo("2024-02-01"))
+ .withRequestBody(
+ equalToJson(
+ """
+ {
+ "messages" : [ {
+ "role" : "system",
+ "content" : "You are a helpful AI"
+ }, {
+ "role" : "user",
+ "content" : [ {
+ "type" : "text",
+ "text" : "Hello World! Why is this phrase so famous?"
+ } ]
+ } ]
+ }""")));
}
@Test
@@ -301,7 +172,7 @@ void history() {
.withBodyFile("chatCompletionResponse.json")
.withHeader("Content-Type", "application/json")));
- client.withSystemPrompt("system prompt").chatCompletion("chat completion 1");
+ client.chatCompletion("First message");
verify(
exactly(1),
@@ -311,18 +182,20 @@ void history() {
"""
{
"messages" : [ {
- "role" : "system",
- "content" : "system prompt"
- }, {
"role" : "user",
"content" : [ {
"type" : "text",
- "text" : "chat completion 1"
+ "text" : "First message"
} ]
} ]
}""")));
- client.withSystemPrompt("system prompt").chatCompletion("chat completion 2");
+ var response = client.chatCompletion("Second message");
+
+ assertThat(response.getContent()).isNotNull();
+ assertThat(response.getContent())
+ .isEqualTo(
+ "I'm an AI and cannot answer that question as beauty is subjective and varies from person to person.");
verify(
exactly(1),
@@ -331,27 +204,19 @@ void history() {
equalToJson(
"""
{
- "messages" : [ {
- "role" : "system",
- "content" : "system prompt"
- }, {
- "role" : "user",
- "content" : [ {
- "type" : "text",
- "text" : "chat completion 2"
- } ]
- } ]
- }""")));
+ "messages" : [ {
+ "role" : "user",
+ "content" : [ {
+ "type" : "text",
+ "text" : "Second message"
+ } ]
+ } ]
+ }""")));
}
@Test
void embedding() {
- stubFor(
- post(urlPathEqualTo("/embeddings"))
- .willReturn(
- aResponse()
- .withBodyFile("embeddingResponse.json")
- .withHeader("Content-Type", "application/json")));
+ stubForEmbedding();
final var request = new OpenAiEmbeddingParameters().setInput("Hello World");
final var result = client.embedding(request);
@@ -369,7 +234,7 @@ void embedding() {
var embeddingData = result.getData().get(0);
assertThat(embeddingData).isNotNull();
assertThat(embeddingData.getObject()).isEqualTo("embedding");
- assertThat(embeddingData.getIndex()).isEqualTo(0);
+ assertThat(embeddingData.getIndex()).isZero();
assertThat(embeddingData.getEmbedding())
.isNotNull()
.isNotEmpty()
@@ -380,40 +245,12 @@ void embedding() {
.withRequestBody(
equalToJson(
"""
- {"input":["Hello World"]}""")));
- }
-
- @Test
- void testThrowsOnContentFilter() {
- var mock = mock(OpenAiClient.class);
- when(mock.streamChatCompletion(any())).thenCallRealMethod();
-
- var deltaWithContentFilter = mock(OpenAiChatCompletionDelta.class);
- when(deltaWithContentFilter.getFinishReason()).thenReturn("content_filter");
- when(mock.streamChatCompletionDeltas(any())).thenReturn(Stream.of(deltaWithContentFilter));
-
- // this must not throw, since the stream is lazily evaluated
- var stream = mock.streamChatCompletion("");
- assertThatThrownBy(stream::toList)
- .isInstanceOf(OpenAiClientException.class)
- .hasMessageContaining("Content filter");
+ {"input": ["Hello World"]}""")));
}
@Test
void streamChatCompletionDeltasErrorHandling() throws IOException {
- try (var inputStream = spy(fileLoader.apply("streamChatCompletionError.txt"))) {
-
- final var httpClient = mock(HttpClient.class);
- ApacheHttpClient5Accessor.setHttpClientFactory(destination -> httpClient);
-
- // Create a mock response
- final var mockResponse = new BasicClassicHttpResponse(200, "OK");
- final var inputStreamEntity = new InputStreamEntity(inputStream, ContentType.TEXT_PLAIN);
- mockResponse.setEntity(inputStreamEntity);
- mockResponse.setHeader("Content-Type", "text/event-stream");
-
- // Configure the HttpClient mock to return the mock response
- doReturn(mockResponse).when(httpClient).executeOpen(any(), any(), any());
+ try (var inputStream = stubStreamChatCompletion("streamChatCompletionError.txt")) {
final var request =
new OpenAiChatCompletionParameters()
@@ -421,7 +258,7 @@ void streamChatCompletionDeltasErrorHandling() throws IOException {
new OpenAiChatUserMessage()
.addText("Can you give me the first 100 numbers of the Fibonacci sequence?"));
- try (Stream<OpenAiChatCompletionDelta> stream = client.streamChatCompletionDeltas(request)) {
+ try (var stream = client.streamChatCompletionDeltas(request)) {
assertThatThrownBy(() -> stream.forEach(System.out::println))
.isInstanceOf(OpenAiClientException.class)
.hasMessage("Failed to parse response and error message: 'exceeded token rate limit'");
@@ -433,19 +270,7 @@ void streamChatCompletionDeltasErrorHandling() throws IOException {
@Test
void streamChatCompletionDeltas() throws IOException {
- try (var inputStream = spy(fileLoader.apply("streamChatCompletion.txt"))) {
-
- final var httpClient = mock(HttpClient.class);
- ApacheHttpClient5Accessor.setHttpClientFactory(destination -> httpClient);
-
- // Create a mock response
- final var mockResponse = new BasicClassicHttpResponse(200, "OK");
- final var inputStreamEntity = new InputStreamEntity(inputStream, ContentType.TEXT_PLAIN);
- mockResponse.setEntity(inputStreamEntity);
- mockResponse.setHeader("Content-Type", "text/event-stream");
-
- // Configure the HttpClient mock to return the mock response
- doReturn(mockResponse).when(httpClient).executeOpen(any(), any(), any());
+ try (var inputStream = stubStreamChatCompletion("streamChatCompletion.txt")) {
final var request =
new OpenAiChatCompletionParameters()
@@ -461,11 +286,11 @@ void streamChatCompletionDeltas() throws IOException {
assertThat(deltaList).hasSize(5);
// the first two and the last delta don't have any content
- assertThat(deltaList.get(0).getDeltaContent()).isEqualTo("");
- assertThat(deltaList.get(1).getDeltaContent()).isEqualTo("");
+ assertThat(deltaList.get(0).getDeltaContent()).isEmpty();
+ assertThat(deltaList.get(1).getDeltaContent()).isEmpty();
assertThat(deltaList.get(2).getDeltaContent()).isEqualTo("Sure");
assertThat(deltaList.get(3).getDeltaContent()).isEqualTo("!");
- assertThat(deltaList.get(4).getDeltaContent()).isEqualTo("");
+ assertThat(deltaList.get(4).getDeltaContent()).isEmpty();
assertThat(deltaList.get(0).getSystemFingerprint()).isNull();
assertThat(deltaList.get(1).getSystemFingerprint()).isEqualTo("fp_e49e4201a9");
@@ -490,15 +315,16 @@ void streamChatCompletionDeltas() throws IOException {
assertThat(deltaList.get(4).getChoices()).hasSize(1);
final var delta0 = deltaList.get(0);
- assertThat(delta0.getId()).isEqualTo("");
- assertThat(delta0.getCreated()).isEqualTo(0);
- assertThat(delta0.getModel()).isEqualTo("");
- assertThat(delta0.getObject()).isEqualTo("");
+ assertThat(delta0.getId()).isEmpty();
+ assertThat(delta0.getCreated()).isZero();
+ assertThat(delta0.getModel()).isEmpty();
+ assertThat(delta0.getObject()).isEmpty();
assertThat(delta0.getUsage()).isNull();
assertThat(delta0.getChoices()).isEmpty();
+ assertThat(delta0.getFinishReason()).isNull();
// prompt filter results are only present in delta 0
assertThat(delta0.getPromptFilterResults()).isNotNull();
- assertThat(delta0.getPromptFilterResults().get(0).getPromptIndex()).isEqualTo(0);
+ assertThat(delta0.getPromptFilterResults().get(0).getPromptIndex()).isZero();
final var promptFilter0 = delta0.getPromptFilterResults().get(0).getContentFilterResults();
assertThat(promptFilter0).isNotNull();
assertFilter(promptFilter0);
@@ -510,9 +336,10 @@ void streamChatCompletionDeltas() throws IOException {
assertThat(delta2.getObject()).isEqualTo("chat.completion.chunk");
assertThat(delta2.getUsage()).isNull();
assertThat(delta2.getPromptFilterResults()).isNull();
+ assertThat(delta2.getFinishReason()).isNull();
+
final var choices2 = delta2.getChoices().get(0);
assertThat(choices2.getIndex()).isEqualTo(0);
- assertThat(choices2.getFinishReason()).isNull();
assertThat(choices2.getMessage()).isNotNull();
// the role is only defined in delta 1, but it defaults to "assistant" for all deltas
assertThat(choices2.getMessage().getRole()).isEqualTo("assistant");
@@ -524,8 +351,8 @@ void streamChatCompletionDeltas() throws IOException {
final var delta3 = deltaList.get(3);
assertThat(delta3.getDeltaContent()).isEqualTo("!");
+ assertThat(deltaList.get(4).getFinishReason()).isEqualTo("stop");
final var delta4Choice = deltaList.get(4).getChoices().get(0);
- assertThat(delta4Choice.getFinishReason()).isEqualTo("stop");
assertThat(delta4Choice.getMessage()).isNotNull();
// the role is only defined in delta 1, but it defaults to "assistant" for all deltas
assertThat(delta4Choice.getMessage().getRole()).isEqualTo("assistant");
@@ -576,4 +403,81 @@ void assertFilter(OpenAiContentFilterPromptResults filter) {
assertThat(filter.getProfanity()).isNull();
assertThat(filter.getError()).isNull();
}
+
+ @Test
+ void chatCompletionTool() {
+ stubForChatCompletionTool();
+
+ final var question =
+ "A pair of rabbits is placed in a field. Each month, every pair produces one new pair, starting from the second month. How many rabbits will there be after 12 months?";
+ final var par = Map.of("type", "object", "properties", Map.of("N", Map.of("type", "integer")));
+ final var function = new OpenAiChatCompletionFunction().setName("fibonacci").setParameters(par);
+ final var tool = new OpenAiChatCompletionTool().setType(FUNCTION).setFunction(function);
+ final var request =
+ new OpenAiChatCompletionParameters()
+ .addMessages(new OpenAiChatUserMessage().addText(question))
+ .setTools(List.of(tool))
+ .setToolChoiceFunction("fibonacci");
+
+ var response = client.chatCompletion(request);
+
+ assertThat(response).isNotNull();
+ assertThat(response.getChoices()).hasSize(1);
+ assertThat(response.getChoices().get(0).getFinishReason()).isEqualTo("stop");
+ assertThat(response.getChoices().get(0).getMessage().getRole()).isEqualTo("assistant");
+ assertThat(response.getChoices().get(0).getMessage().getToolCalls()).hasSize(1);
+ assertThat(response.getChoices().get(0).getMessage().getToolCalls().get(0).getId())
+ .isEqualTo("call_CUYGJf2j7FRWJMHT3PN3aGxK");
+ assertThat(response.getChoices().get(0).getMessage().getToolCalls().get(0).getType())
+ .isEqualTo("function");
+ assertThat(
+ response.getChoices().get(0).getMessage().getToolCalls().get(0).getFunction().getName())
+ .isEqualTo("fibonacci");
+ assertThat(
+ response
+ .getChoices()
+ .get(0)
+ .getMessage()
+ .getToolCalls()
+ .get(0)
+ .getFunction()
+ .getArguments())
+ .isEqualTo("{\"N\":12}");
+
+ verify(
+ postRequestedFor(anyUrl())
+ .withRequestBody(
+ equalToJson(
+ """
+ {
+ "messages" : [ {
+ "role" : "user",
+ "content" : [ {
+ "type" : "text",
+ "text" : "A pair of rabbits is placed in a field. Each month, every pair produces one new pair, starting from the second month. How many rabbits will there be after 12 months?"
+ } ]
+ } ],
+ "tools" : [ {
+ "type" : "function",
+ "function" : {
+ "name" : "fibonacci",
+ "parameters" : {
+ "type" : "object",
+ "properties" : {
+ "N" : {
+ "type" : "integer"
+ }
+ }
+ }
+ }
+ } ],
+ "tool_choice" : {
+ "function" : {
+ "name" : "fibonacci"
+ },
+ "type" : "function"
+ }
+ }
+ """)));
+ }
}
diff --git a/foundation-models/openai/src/test/resources/__files/chatCompletionToolResponse.json b/foundation-models/openai/src/test/resources/__files/chatCompletionToolResponse.json
new file mode 100644
index 000000000..f2cfab223
--- /dev/null
+++ b/foundation-models/openai/src/test/resources/__files/chatCompletionToolResponse.json
@@ -0,0 +1,56 @@
+{
+ "choices": [
+ {
+ "content_filter_results": {},
+ "finish_reason": "stop",
+ "index": 0,
+ "message": {
+ "content": null,
+ "role": "assistant",
+ "tool_calls": [
+ {
+ "function": {
+ "arguments": "{\"N\":12}",
+ "name": "fibonacci"
+ },
+ "id": "call_CUYGJf2j7FRWJMHT3PN3aGxK",
+ "type": "function"
+ }
+ ]
+ }
+ }
+ ],
+ "created": 1738867288,
+ "id": "chatcmpl-Ay16WxAjYvRDJ12ImPi8E1T1TZHoP",
+ "model": "gpt-3.5-turbo-1106",
+ "object": "chat.completion",
+ "prompt_filter_results": [
+ {
+ "content_filter_results": {
+ "hate": {
+ "filtered": false,
+ "severity": "safe"
+ },
+ "self_harm": {
+ "filtered": false,
+ "severity": "safe"
+ },
+ "sexual": {
+ "filtered": false,
+ "severity": "safe"
+ },
+ "violence": {
+ "filtered": false,
+ "severity": "safe"
+ }
+ },
+ "prompt_index": 0
+ }
+ ],
+ "system_fingerprint": "fp_0165350fbb",
+ "usage": {
+ "completion_tokens": 5,
+ "prompt_tokens": 93,
+ "total_tokens": 98
+ }
+}
\ No newline at end of file
diff --git a/sample-code/spring-app/pom.xml b/sample-code/spring-app/pom.xml
index 34d2be7fe..a8c348c85 100644
--- a/sample-code/spring-app/pom.xml
+++ b/sample-code/spring-app/pom.xml
@@ -114,10 +114,6 @@
com.fasterxml.jackson.core
jackson-databind
-
- com.fasterxml.jackson.core
- jackson-core
-
com.fasterxml.jackson.core
jackson-annotations
diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java
index 26ea351bc..e48b67b70 100644
--- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java
+++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java
@@ -1,12 +1,14 @@
package com.sap.ai.sdk.app.controllers;
-import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.annotation.JsonAutoDetect;
+import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.sap.ai.sdk.app.services.OpenAiService;
-import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionOutput;
+import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiUsage;
import com.sap.cloud.sdk.cloudplatform.thread.ThreadContextExecutors;
import java.io.IOException;
import java.util.Arrays;
+import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import lombok.extern.slf4j.Slf4j;
@@ -25,6 +27,8 @@
@SuppressWarnings("unused")
public class OpenAiController {
@Autowired private OpenAiService service;
+ private static final ObjectMapper MAPPER =
+ new ObjectMapper().setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
@GetMapping("/chatCompletion")
@Nonnull
@@ -44,16 +48,19 @@ ResponseEntity streamChatCompletionDeltas() {
final var message = "Can you give me the first 100 numbers of the Fibonacci sequence?";
final var stream = service.streamChatCompletionDeltas(message);
final var emitter = new ResponseBodyEmitter();
+ final var totalUsage = new AtomicReference();
final Runnable consumeStream =
() -> {
- final var totalOutput = new OpenAiChatCompletionOutput();
- // try-with-resources ensures the stream is closed
try (stream) {
- stream
- .peek(totalOutput::addDelta)
- .forEach(delta -> send(emitter, delta.getDeltaContent()));
+ stream.forEach(
+ delta -> {
+ // getUsage() replaces the former getCompletionUsage(MAPPER) accessor
+ final var usage = delta.getUsage();
+ totalUsage.compareAndExchange(null, usage);
+ send(emitter, delta.getDeltaContent());
+ });
} finally {
- send(emitter, "\n\n-----Total Output-----\n\n" + objectToJson(totalOutput));
+ send(emitter, "\n\n-----Total Usage-----\n\n" + totalUsage.get());
emitter.complete();
}
};
@@ -100,20 +107,6 @@ public static void send(@Nonnull final ResponseBodyEmitter emitter, @Nonnull fin
}
}
- /**
- * Convert an object to JSON
- *
- * @param obj The object to convert
- * @return The JSON representation of the object
- */
- private static String objectToJson(@Nonnull final Object obj) {
- try {
- return new ObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(obj);
- } catch (final JsonProcessingException ignored) {
- return "Could not parse object to JSON";
- }
- }
-
@GetMapping("/chatCompletionImage")
@Nonnull
Object chatCompletionImage(
@@ -124,15 +117,14 @@ Object chatCompletionImage(
if ("json".equals(format)) {
return response;
}
- return response.getContent();
+ return response.getChoices().get(0).getMessage();
}
@GetMapping("/chatCompletionTool")
@Nonnull
Object chatCompletionTools(
@Nullable @RequestParam(value = "format", required = false) final String format) {
- final var response =
- service.chatCompletionTools("Calculate the Fibonacci number for given sequence index.");
+ final var response = service.chatCompletionTools(12);
if ("json".equals(format)) {
return response;
}
diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OpenAiService.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OpenAiService.java
index 5d5f26209..0032093d6 100644
--- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OpenAiService.java
+++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OpenAiService.java
@@ -88,18 +88,19 @@ public OpenAiChatCompletionOutput chatCompletionImage(@Nonnull final String link
/**
* Chat request to OpenAI with a tool.
*
- * @param prompt The prompt to send to the assistant
+ * @param months The number of months interpolated into the question for the tool
* @return the assistant message response
*/
@Nonnull
- public OpenAiChatCompletionOutput chatCompletionTools(@Nonnull final String prompt) {
+ public OpenAiChatCompletionOutput chatCompletionTools(final int months) {
final var question =
- "A pair of rabbits is placed in a field. Each month, every pair produces one new pair, starting from the second month. How many rabbits will there be after 12 months?";
+ "A pair of rabbits is placed in a field. Each month, every pair produces one new pair, starting from the second month. How many rabbits will there be after %s months?"
+ .formatted(months);
final var par = Map.of("type", "object", "properties", Map.of("N", Map.of("type", "integer")));
final var function =
new OpenAiChatCompletionFunction()
.setName("fibonacci")
- .setDescription(prompt)
+ .setDescription("Calculate the Fibonacci number for given sequence index.")
.setParameters(par);
final var tool = new OpenAiChatCompletionTool().setType(FUNCTION).setFunction(function);
final var request =
diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/NewOpenAiTest.java b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/NewOpenAiTest.java
new file mode 100644
index 000000000..89c43ca0a
--- /dev/null
+++ b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/NewOpenAiTest.java
@@ -0,0 +1,103 @@
+package com.sap.ai.sdk.app.controllers;
+
+import static com.sap.ai.sdk.foundationmodels.openai.OpenAiModel.GPT_35_TURBO;
+import static com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionResponseMessageRole.ASSISTANT;
+import static org.assertj.core.api.Assertions.assertThat;
+
+import com.sap.ai.sdk.app.services.NewOpenAiService;
+import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionRequest;
+import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient;
+import com.sap.ai.sdk.foundationmodels.openai.OpenAiMessage;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.CompletionUsage;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
+import lombok.extern.slf4j.Slf4j;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+@Slf4j
+class NewOpenAiTest {
+ NewOpenAiService service;
+
+ @BeforeEach
+ void setUp() {
+ service = new NewOpenAiService();
+ }
+
+ @Test
+ void chatCompletion() {
+ final var completion = service.chatCompletion("Who is the prettiest");
+
+ assertThat(completion.getChoice().getMessage().getRole()).isEqualTo(ASSISTANT);
+ assertThat(completion.getContent()).isNotEmpty();
+ }
+
+ @Test
+ void chatCompletionImage() {
+ final var completion =
+ service.chatCompletionImage(
+ "https://upload.wikimedia.org/wikipedia/commons/thumb/5/59/SAP_2011_logo.svg/440px-SAP_2011_logo.svg.png");
+
+ final var message = completion.getChoices().get(0).getMessage();
+ assertThat(message.getRole()).isEqualTo(ASSISTANT);
+ assertThat(message.getContent()).isNotEmpty();
+ }
+
+ @Test
+ void streamChatCompletion() {
+ final var userMessage = OpenAiMessage.user("Who is the prettiest?");
+ final var prompt = new OpenAiChatCompletionRequest(userMessage);
+
+ final var totalOutput = new AtomicReference();
+ final var filledDeltaCount = new AtomicInteger(0);
+ OpenAiClient.forModel(GPT_35_TURBO)
+ .streamChatCompletionDeltas(prompt)
+ // forEach consumes all elements, closing the stream at the end
+ .forEach(
+ delta -> {
+ final var usage = delta.getCompletionUsage();
+ totalOutput.compareAndExchange(null, usage);
+ final String deltaContent = delta.getDeltaContent();
+ log.info("delta: {}", delta);
+ if (!deltaContent.isEmpty()) {
+ filledDeltaCount.incrementAndGet();
+ }
+ });
+
+ // the first two and the last delta don't have any content
+ // see OpenAiChatCompletionDelta#getDeltaContent
+ assertThat(filledDeltaCount.get()).isGreaterThan(0);
+
+ assertThat(totalOutput.get().getTotalTokens()).isGreaterThan(0);
+ assertThat(totalOutput.get().getPromptTokens()).isEqualTo(14);
+ assertThat(totalOutput.get().getCompletionTokens()).isGreaterThan(0);
+ }
+
+ @Test
+ void chatCompletionTools() {
+ final var completion = service.chatCompletionTools(12);
+
+ final var message = completion.getChoice().getMessage();
+ assertThat(message.getRole()).isEqualTo(ASSISTANT);
+ assertThat(message.getToolCalls()).isNotNull();
+ assertThat(message.getToolCalls().get(0).getFunction().getName()).isEqualTo("fibonacci");
+ }
+
+ @Test
+ void embedding() {
+ final var embedding = service.embedding("Hello world");
+
+ assertThat(embedding.getData().get(0).getEmbedding()).hasSizeGreaterThan(1);
+ assertThat(embedding.getModel()).isEqualTo("ada");
+ assertThat(embedding.getObject()).isEqualTo("list");
+ }
+
+ @Test
+ void chatCompletionWithResource() {
+ final var completion =
+ service.chatCompletionWithResource("ai-sdk-java-e2e", "Where is the nearest coffee shop?");
+
+ assertThat(completion.getChoice().getMessage().getRole()).isEqualTo(ASSISTANT);
+ assertThat(completion.getContent()).isNotEmpty();
+ }
+}
diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OpenAiTest.java b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OpenAiTest.java
index ec07fb037..78ff6fccb 100644
--- a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OpenAiTest.java
+++ b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OpenAiTest.java
@@ -75,8 +75,7 @@ void streamChatCompletion() {
@Test
void chatCompletionTools() {
- final var completion =
- service.chatCompletionTools("Calculate the Fibonacci number for given sequence index.");
+ final var completion = service.chatCompletionTools(12);
final var message = completion.getChoices().get(0).getMessage();
assertThat(message.getRole()).isEqualTo("assistant");
diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/services/NewOpenAiService.java b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/services/NewOpenAiService.java
new file mode 100644
index 000000000..f81881278
--- /dev/null
+++ b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/services/NewOpenAiService.java
@@ -0,0 +1,179 @@
+package com.sap.ai.sdk.app.services;
+
+import static com.sap.ai.sdk.foundationmodels.openai.OpenAiModel.GPT_35_TURBO;
+import static com.sap.ai.sdk.foundationmodels.openai.OpenAiModel.GPT_4O;
+import static com.sap.ai.sdk.foundationmodels.openai.OpenAiModel.TEXT_EMBEDDING_ADA_002;
+import static com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestMessageContentPartImage.TypeEnum.IMAGE_URL;
+import static com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestMessageContentPartImageImageUrl.DetailEnum.HIGH;
+import static com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestMessageContentPartText.TypeEnum.TEXT;
+import static com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestUserMessage.RoleEnum.USER;
+import static com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool.TypeEnum.FUNCTION;
+
+import com.sap.ai.sdk.core.AiCoreService;
+import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionDelta;
+import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionRequest;
+import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionResponse;
+import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient;
+import com.sap.ai.sdk.foundationmodels.openai.OpenAiMessage;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionNamedToolChoice;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionNamedToolChoiceFunction;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestMessageContentPartImage;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestMessageContentPartImageImageUrl;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestMessageContentPartText;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestUserMessage;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestUserMessageContent;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionToolChoiceOption;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequest;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionResponse;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreate200Response;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequest;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequestInput;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.FunctionObject;
+import java.net.URI;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Stream;
+import javax.annotation.Nonnull;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.stereotype.Service;
+
+/** Service class for OpenAI service */
+@Service
+@Slf4j
+public class NewOpenAiService {
+
+ /**
+ * Chat request to OpenAI
+ *
+ * @param prompt The prompt to send to the assistant
+ * @return the assistant message response
+ */
+ @Nonnull
+ public OpenAiChatCompletionResponse chatCompletion(@Nonnull final String prompt) {
+ return OpenAiClient.forModel(GPT_35_TURBO)
+ .chatCompletion(new OpenAiChatCompletionRequest(prompt));
+ }
+
+ /**
+ * Asynchronous stream of an OpenAI chat request
+ *
+ * @return a stream of chat completion deltas from the assistant
+ */
+ @Nonnull
+ public Stream streamChatCompletionDeltas(
+ @Nonnull final String message) {
+ final var request = new OpenAiChatCompletionRequest(OpenAiMessage.user(message));
+
+ return OpenAiClient.forModel(GPT_35_TURBO).streamChatCompletionDeltas(request);
+ }
+
+ /**
+ * Asynchronous stream of an OpenAI chat request
+ *
+ * @return a stream of assistant message content chunks
+ */
+ @Nonnull
+ public Stream streamChatCompletion(@Nonnull final String message) {
+ return OpenAiClient.forModel(GPT_35_TURBO)
+ .withSystemPrompt("Be a good, honest AI and answer the following question:")
+ .streamChatCompletion(message);
+ }
+
+ /**
+ * Chat request to OpenAI with an image
+ *
+ * @param linkToImage The link to the image
+ * @return the assistant message response
+ */
+ @Nonnull
+ public CreateChatCompletionResponse chatCompletionImage(@Nonnull final String linkToImage) {
+ final var partText =
+ new ChatCompletionRequestMessageContentPartText()
+ .type(TEXT)
+ .text("Describe the following image.");
+ final var partImageUrl =
+ new ChatCompletionRequestMessageContentPartImageImageUrl()
+ .url(URI.create(linkToImage))
+ .detail(HIGH);
+ final var partImage =
+ new ChatCompletionRequestMessageContentPartImage().type(IMAGE_URL).imageUrl(partImageUrl);
+ final var userMessage =
+ new ChatCompletionRequestUserMessage()
+ .role(USER)
+ .content(ChatCompletionRequestUserMessageContent.create(List.of(partText, partImage)));
+ final var request =
+ new CreateChatCompletionRequest()
+ .addMessagesItem(userMessage)
+ .functions(null)
+ .tools(null)
+ .parallelToolCalls(null);
+
+ return OpenAiClient.forModel(GPT_4O).chatCompletion(request);
+ }
+
+ /**
+ * Chat request to OpenAI with a tool.
+ *
+ * @param months The number of months interpolated into the question for the tool
+ * @return the assistant message response
+ */
+ @Nonnull
+ public OpenAiChatCompletionResponse chatCompletionTools(final int months) {
+ final var function =
+ new FunctionObject()
+ .name("fibonacci")
+ .description("Calculate the Fibonacci number for given sequence index.")
+ .parameters(
+ Map.of("type", "object", "properties", Map.of("N", Map.of("type", "integer"))));
+
+ final var tool = new ChatCompletionTool().type(FUNCTION).function(function);
+
+ final var toolChoice =
+ ChatCompletionToolChoiceOption.create(
+ new ChatCompletionNamedToolChoice()
+ .type(ChatCompletionNamedToolChoice.TypeEnum.FUNCTION)
+ .function(new ChatCompletionNamedToolChoiceFunction().name("fibonacci")));
+
+ final var request =
+ new OpenAiChatCompletionRequest(
+ "A pair of rabbits is placed in a field. Each month, every pair produces one new pair, starting from the second month. How many rabbits will there be after %s months?"
+ .formatted(months))
+ .withTools(List.of(tool))
+ .withToolChoice(toolChoice);
+
+ return OpenAiClient.forModel(GPT_35_TURBO).chatCompletion(request);
+ }
+
+ /**
+ * Get the embedding of a text
+ *
+ * @param input The text to embed
+ * @return the embedding response
+ */
+ @Nonnull
+ public EmbeddingsCreate200Response embedding(@Nonnull final String input) {
+ final var request =
+ new EmbeddingsCreateRequest().input(EmbeddingsCreateRequestInput.create(input));
+
+ return OpenAiClient.forModel(TEXT_EMBEDDING_ADA_002).embedding(request);
+ }
+
+ /**
+ * Chat request to OpenAI filtering by resource group
+ *
+ * @param resourceGroup The resource group to use
+ * @param prompt The prompt to send to the assistant
+ * @return the assistant message response
+ */
+ @Nonnull
+ public OpenAiChatCompletionResponse chatCompletionWithResource(
+ @Nonnull final String resourceGroup, @Nonnull final String prompt) {
+
+ final var destination =
+ new AiCoreService().getInferenceDestination(resourceGroup).forModel(GPT_4O);
+
+ return OpenAiClient.withCustomDestination(destination)
+ .chatCompletion(new OpenAiChatCompletionRequest(prompt));
+ }
+}