diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java
index 100291936..4649e58ff 100644
--- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java
+++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java
@@ -35,16 +35,16 @@ public class OrchestrationChatResponse {
*
Note: If there are multiple choices only the first one is returned
*
* @return the message content or empty string.
- * @throws OrchestrationFilterException.Output if the content filter filtered the output.
+ * @throws OrchestrationClientException if the content filter filtered the output.
*/
@Nonnull
- public String getContent() throws OrchestrationFilterException.Output {
+ public String getContent() throws OrchestrationClientException {
final var choice = getChoice();
if ("content_filter".equals(choice.getFinishReason())) {
final var filterDetails = Try.of(this::getOutputFilteringChoices).getOrElseGet(e -> Map.of());
final var message = "Content filter filtered the output.";
- throw new OrchestrationFilterException.Output(message).setFilterDetails(filterDetails);
+ throw new OrchestrationClientException(message).setFilterDetails(filterDetails);
}
return choice.getMessage().getContent();
}
diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java
index d7e165cdb..f115f66a9 100644
--- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java
+++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java
@@ -116,13 +116,13 @@ public Stream streamChatCompletion(
}
private static void throwOnContentFilter(@Nonnull final OrchestrationChatCompletionDelta delta)
- throws OrchestrationFilterException.Output {
+ throws OrchestrationClientException {
final String finishReason = delta.getFinishReason();
if (finishReason != null && finishReason.equals("content_filter")) {
final var filterDetails =
Try.of(() -> getOutputFilteringChoices(delta)).getOrElseGet(e -> Map.of());
final var message = "Content filter filtered the output.";
- throw new OrchestrationFilterException.Output(message).setFilterDetails(filterDetails);
+ throw new OrchestrationClientException(message).setFilterDetails(filterDetails);
}
}
diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClientException.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClientException.java
index 350679e4d..9b3665232 100644
--- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClientException.java
+++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClientException.java
@@ -1,22 +1,31 @@
package com.sap.ai.sdk.orchestration;
+import static com.sap.ai.sdk.orchestration.OrchestrationJacksonConfiguration.getOrchestrationObjectMapper;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.annotations.Beta;
import com.sap.ai.sdk.core.common.ClientException;
import com.sap.ai.sdk.core.common.ClientExceptionFactory;
-import com.sap.ai.sdk.orchestration.OrchestrationFilterException.Input;
+import com.sap.ai.sdk.orchestration.model.AzureContentSafetyInput;
+import com.sap.ai.sdk.orchestration.model.AzureContentSafetyOutput;
import com.sap.ai.sdk.orchestration.model.Error;
import com.sap.ai.sdk.orchestration.model.ErrorResponse;
import com.sap.ai.sdk.orchestration.model.ErrorResponseStreaming;
import com.sap.ai.sdk.orchestration.model.ErrorStreaming;
import com.sap.ai.sdk.orchestration.model.GenericModuleResult;
+import com.sap.ai.sdk.orchestration.model.LlamaGuard38b;
import com.sap.ai.sdk.orchestration.model.ModuleResults;
import com.sap.ai.sdk.orchestration.model.ModuleResultsStreaming;
-import java.util.Collections;
import java.util.Map;
import java.util.Optional;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
+import lombok.AccessLevel;
+import lombok.Getter;
+import lombok.Setter;
+import lombok.experimental.Accessors;
import lombok.experimental.StandardException;
+import lombok.val;
/** Exception thrown by the {@link OrchestrationClient} in case of an error. */
@StandardException
@@ -24,35 +33,37 @@ public class OrchestrationClientException extends ClientException {
static final ClientExceptionFactory FACTORY =
(message, clientError, cause) -> {
- final var details = extractInputFilterDetails(clientError);
- if (details.isEmpty()) {
- return new OrchestrationClientException(message, cause).setClientError(clientError);
+ val result = new OrchestrationClientException(message, cause);
+ val details = extractModuleResults(clientError).map(GenericModuleResult::getData);
+ if (details.orElse(null) instanceof Map, ?> m) {
+ result.setFilterDetails((Map) m);
}
- return new Input(message, cause).setFilterDetails(details).setClientError(clientError);
+ return result.setClientError(clientError);
};
- @SuppressWarnings("unchecked")
+ private static final ObjectMapper OBJECT_MAPPER = getOrchestrationObjectMapper();
+
+ /** Details about the filters that caused the exception. */
+ @Accessors(chain = true)
+ @Setter(AccessLevel.PACKAGE)
+ @Getter
+ @Nullable
+ private Map filterDetails = null;
+
@Nonnull
- static Map extractInputFilterDetails(@Nullable final OrchestrationError error) {
- if (error instanceof OrchestrationError.Synchronous synchronousError) {
+ static Optional extractModuleResults(@Nullable final OrchestrationError e) {
+ if (e instanceof OrchestrationError.Synchronous synchronousError) {
return Optional.of(synchronousError.getErrorResponse())
.map(ErrorResponse::getError)
.map(Error::getIntermediateResults)
- .map(ModuleResults::getInputFiltering)
- .map(GenericModuleResult::getData)
- .map(map -> (Map) map)
- .orElseGet(Collections::emptyMap);
- } else if (error instanceof OrchestrationError.Streaming streamingError) {
+ .map(ModuleResults::getInputFiltering);
+ } else if (e instanceof OrchestrationError.Streaming streamingError) {
return Optional.of(streamingError.getErrorResponse())
.map(ErrorResponseStreaming::getError)
.map(ErrorStreaming::getIntermediateResults)
- .map(ModuleResultsStreaming::getInputFiltering)
- .map(GenericModuleResult::getData)
- .filter(Map.class::isInstance)
- .map(map -> (Map) map)
- .orElseGet(Collections::emptyMap);
+ .map(ModuleResultsStreaming::getInputFiltering);
}
- return Collections.emptyMap();
+ return Optional.empty();
}
@Override
@@ -105,4 +116,49 @@ public Integer getStatusCode() {
.map(Error::getCode)
.orElse(null);
}
+
+ /**
+ * Retrieves LlamaGuard 3.8b details from {@code filterDetails}, if present.
+ *
+ * @return The LlamaGuard38b object, or {@code null} if not found or conversion fails.
+ * @throws IllegalArgumentException if the conversion of filter details to {@link LlamaGuard38b}
+ * fails due to invalid content.
+ */
+ @Nullable
+ public LlamaGuard38b getLlamaGuard38b() {
+ return Optional.ofNullable(filterDetails)
+ .map(details -> details.get("llama_guard_3_8b"))
+ .map(obj -> OBJECT_MAPPER.convertValue(obj, LlamaGuard38b.class))
+ .orElse(null);
+ }
+
+ /**
+ * Retrieves Azure Content Safety input details from {@code filterDetails}, if present.
+ *
+ * @return The AzureContentSafetyInput object, or {@code null} if not found or conversion fails.
+ * @throws IllegalArgumentException if the conversion of filter details to {@link
+ * AzureContentSafetyInput} fails due to invalid content.
+ */
+ @Nullable
+ public AzureContentSafetyInput getAzureContentSafetyInput() {
+ return Optional.ofNullable(filterDetails)
+ .map(details -> details.get("azure_content_safety"))
+ .map(obj -> OBJECT_MAPPER.convertValue(obj, AzureContentSafetyInput.class))
+ .orElse(null);
+ }
+
+ /**
+ * Retrieves Azure Content Safety output details from {@code filterDetails}, if present.
+ *
+ * @return The AzureContentSafetyOutput object, or {@code null} if not found or conversion fails.
+ * @throws IllegalArgumentException if the conversion of filter details to {@link
+ * AzureContentSafetyOutput} fails due to invalid content.
+ */
+ @Nullable
+ public AzureContentSafetyOutput getAzureContentSafetyOutput() {
+ return Optional.ofNullable(filterDetails)
+ .map(details -> details.get("azure_content_safety"))
+ .map(obj -> OBJECT_MAPPER.convertValue(obj, AzureContentSafetyOutput.class))
+ .orElse(null);
+ }
}
diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationFilterException.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationFilterException.java
deleted file mode 100644
index c62d74fa2..000000000
--- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationFilterException.java
+++ /dev/null
@@ -1,90 +0,0 @@
-package com.sap.ai.sdk.orchestration;
-
-import static com.sap.ai.sdk.orchestration.OrchestrationJacksonConfiguration.getOrchestrationObjectMapper;
-
-import com.google.common.annotations.Beta;
-import com.sap.ai.sdk.orchestration.model.AzureContentSafetyInput;
-import com.sap.ai.sdk.orchestration.model.AzureContentSafetyOutput;
-import com.sap.ai.sdk.orchestration.model.LlamaGuard38b;
-import java.util.Map;
-import java.util.Optional;
-import javax.annotation.Nonnull;
-import javax.annotation.Nullable;
-import lombok.AccessLevel;
-import lombok.Getter;
-import lombok.Setter;
-import lombok.experimental.Accessors;
-import lombok.experimental.StandardException;
-
-/** Base exception for errors occurring during orchestration filtering. */
-@Beta
-@StandardException(access = AccessLevel.PRIVATE)
-public class OrchestrationFilterException extends OrchestrationClientException {
-
- /** Details about the filters that caused the exception. */
- @Accessors(chain = true)
- @Setter(AccessLevel.PACKAGE)
- @Getter
- @Nonnull
- protected Map filterDetails = Map.of();
-
- /**
- * Retrieves LlamaGuard 3.8b details from {@code filterDetails}, if present.
- *
- * @return The LlamaGuard38b object, or {@code null} if not found or conversion fails.
- * @throws IllegalArgumentException if the conversion of filter details to {@link LlamaGuard38b}
- * fails due to invalid content.
- */
- @Nullable
- public LlamaGuard38b getLlamaGuard38b() {
- return Optional.ofNullable(filterDetails.get("llama_guard_3_8b"))
- .map(obj -> getOrchestrationObjectMapper().convertValue(obj, LlamaGuard38b.class))
- .orElse(null);
- }
-
- /** Exception thrown when an error occurs during input filtering. */
- @StandardException
- public static class Input extends OrchestrationFilterException {
-
- /**
- * Retrieves Azure Content Safety input details from {@code filterDetails}, if present.
- *
- * @return The AzureContentSafetyInput object, or {@code null} if not found or conversion fails.
- * @throws IllegalArgumentException if the conversion of filter details to {@link
- * AzureContentSafetyInput} fails due to invalid content.
- */
- @Nullable
- public AzureContentSafetyInput getAzureContentSafetyInput() {
- return Optional.ofNullable(filterDetails.get("azure_content_safety"))
- .map(
- obj ->
- getOrchestrationObjectMapper().convertValue(obj, AzureContentSafetyInput.class))
- .orElse(null);
- }
- }
-
- /**
- * Exception thrown when an error occurs during output filtering, specifically when the finish
- * reason is a content filter.
- */
- @StandardException
- public static class Output extends OrchestrationFilterException {
-
- /**
- * Retrieves Azure Content Safety output details from {@code filterDetails}, if present.
- *
- * @return The AzureContentSafetyOutput object, or {@code null} if not found or conversion
- * fails.
- * @throws IllegalArgumentException if the conversion of filter details to {@link
- * AzureContentSafetyOutput} fails due to invalid content.
- */
- @Nullable
- public AzureContentSafetyOutput getAzureContentSafetyOutput() {
- return Optional.ofNullable(filterDetails.get("azure_content_safety"))
- .map(
- obj ->
- getOrchestrationObjectMapper().convertValue(obj, AzureContentSafetyOutput.class))
- .orElse(null);
- }
- }
-}
diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java
index 060ca8430..286fb1b2f 100644
--- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java
+++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java
@@ -449,7 +449,7 @@ void inputFilteringStrict() {
assertThatThrownBy(() -> client.chatCompletion(prompt, configWithFilter))
.isInstanceOfSatisfying(
- OrchestrationFilterException.Input.class,
+ OrchestrationClientException.class,
e -> {
assertThat(e.getMessage())
.isEqualTo(
@@ -503,7 +503,7 @@ void outputFilteringStrict() {
assertThatThrownBy(client.chatCompletion(prompt, configWithFilter)::getContent)
.isInstanceOfSatisfying(
- OrchestrationFilterException.Output.class,
+ OrchestrationClientException.class,
e -> {
assertThat(e.getMessage()).isEqualTo("Content filter filtered the output.");
assertThat(e.getFilterDetails())
@@ -759,9 +759,9 @@ void testThrowsOnContentFilter() {
// this must not throw, since the stream is lazily evaluated
var stream = mock.streamChatCompletion(new OrchestrationPrompt(""), config);
assertThatThrownBy(stream::toList)
- .isInstanceOf(OrchestrationFilterException.Output.class)
+ .isInstanceOf(OrchestrationClientException.class)
.hasMessage("Content filter filtered the output.")
- .extracting(e -> ((OrchestrationFilterException.Output) e).getFilterDetails())
+ .extracting(e -> ((OrchestrationClientException) e).getFilterDetails())
.isEqualTo(Map.of("azure_content_safety", Map.of("hate", 0, "self_harm", 0)));
}
@@ -785,7 +785,7 @@ void streamChatCompletionOutputFilterErrorHandling() throws IOException {
assertThatThrownBy(() -> stream.forEach(System.out::println))
.hasMessage("Content filter filtered the output.")
.isInstanceOfSatisfying(
- OrchestrationFilterException.Output.class,
+ OrchestrationClientException.class,
e -> {
assertThat(e.getErrorResponse()).isNull();
assertThat(e.getErrorResponseStreaming()).isNull();
diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java
index 6f1d48b1f..07f6ef14e 100644
--- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java
+++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java
@@ -5,7 +5,7 @@
import com.sap.ai.sdk.app.services.OrchestrationService;
import com.sap.ai.sdk.orchestration.AzureFilterThreshold;
import com.sap.ai.sdk.orchestration.OrchestrationChatResponse;
-import com.sap.ai.sdk.orchestration.OrchestrationFilterException;
+import com.sap.ai.sdk.orchestration.OrchestrationClientException;
import com.sap.ai.sdk.orchestration.model.AzureContentSafetyInput;
import com.sap.ai.sdk.orchestration.model.AzureContentSafetyOutput;
import com.sap.ai.sdk.orchestration.model.DPIEntities;
@@ -127,7 +127,7 @@ Object inputFiltering(
final OrchestrationChatResponse response;
try {
response = service.inputFiltering(policy);
- } catch (OrchestrationFilterException.Input e) {
+ } catch (OrchestrationClientException e) {
final var msg =
new StringBuilder(
"[Http %d] Failed to obtain a response as the content was flagged by input filter. "
@@ -159,7 +159,7 @@ Object outputFiltering(
final String content;
try {
content = response.getContent();
- } catch (OrchestrationFilterException.Output e) {
+ } catch (OrchestrationClientException e) {
final var msg =
new StringBuilder(
"Failed to obtain a response as the content was flagged by output filter. ");
@@ -188,7 +188,7 @@ Object llamaGuardInputFiltering(
final OrchestrationChatResponse response;
try {
response = service.llamaGuardInputFilter(enabled);
- } catch (OrchestrationFilterException.Input e) {
+ } catch (OrchestrationClientException e) {
var msg =
"[Http %d] Failed to obtain a response as the content was flagged by input filter. "
.formatted(e.getStatusCode());
diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java
index 379abd476..420cb7844 100644
--- a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java
+++ b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java
@@ -15,7 +15,6 @@
import com.sap.ai.sdk.orchestration.Message;
import com.sap.ai.sdk.orchestration.OrchestrationClient;
import com.sap.ai.sdk.orchestration.OrchestrationClientException;
-import com.sap.ai.sdk.orchestration.OrchestrationFilterException;
import com.sap.ai.sdk.orchestration.OrchestrationModuleConfig;
import com.sap.ai.sdk.orchestration.OrchestrationPrompt;
import com.sap.ai.sdk.orchestration.TemplateConfig;
@@ -221,7 +220,7 @@ void testInputFilteringStrict() {
"Prompt filtered due to safety violations. Please modify the prompt and try again.")
.hasMessageContaining("400 (Bad Request)")
.isInstanceOfSatisfying(
- OrchestrationFilterException.Input.class,
+ OrchestrationClientException.class,
e -> {
var actualAzureContentSafety = e.getAzureContentSafetyInput();
assertThat(actualAzureContentSafety).isNotNull();
@@ -253,7 +252,7 @@ void testOutputFilteringStrict() {
assertThatThrownBy(response::getContent)
.hasMessageContaining("Content filter filtered the output.")
.isInstanceOfSatisfying(
- OrchestrationFilterException.Output.class,
+ OrchestrationClientException.class,
e -> {
var actualAzureContentSafety = e.getAzureContentSafetyOutput();
assertThat(actualAzureContentSafety).isNotNull();
@@ -280,12 +279,12 @@ void testOutputFilteringLenient() {
@Test
void testLlamaGuardEnabled() {
assertThatThrownBy(() -> service.llamaGuardInputFilter(true))
- .isInstanceOf(OrchestrationFilterException.Input.class)
+ .isInstanceOf(OrchestrationClientException.class)
.hasMessageContaining(
"Prompt filtered due to safety violations. Please modify the prompt and try again.")
.hasMessageContaining("400 (Bad Request)")
.isInstanceOfSatisfying(
- OrchestrationFilterException.Input.class,
+ OrchestrationClientException.class,
e -> {
var llamaGuard38b = e.getLlamaGuard38b();
assertThat(llamaGuard38b).isNotNull();
@@ -419,7 +418,7 @@ void testStreamingErrorHandlingInputFilter() {
val configWithFilter = config.withInputFiltering(filterConfig);
assertThatThrownBy(() -> client.streamChatCompletion(prompt, configWithFilter))
- .isInstanceOf(OrchestrationFilterException.Input.class)
+ .isInstanceOf(OrchestrationClientException.class)
.hasMessageContaining("status 400 (Bad Request)")
.hasMessageContaining("Filtering Module - Input Filter");
}