diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java index 100291936..4649e58ff 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java @@ -35,16 +35,16 @@ public class OrchestrationChatResponse { *

Note: If there are multiple choices only the first one is returned * * @return the message content or empty string. - * @throws OrchestrationFilterException.Output if the content filter filtered the output. + * @throws OrchestrationClientException if the content filter filtered the output. */ @Nonnull - public String getContent() throws OrchestrationFilterException.Output { + public String getContent() throws OrchestrationClientException { final var choice = getChoice(); if ("content_filter".equals(choice.getFinishReason())) { final var filterDetails = Try.of(this::getOutputFilteringChoices).getOrElseGet(e -> Map.of()); final var message = "Content filter filtered the output."; - throw new OrchestrationFilterException.Output(message).setFilterDetails(filterDetails); + throw new OrchestrationClientException(message).setFilterDetails(filterDetails); } return choice.getMessage().getContent(); } diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java index d7e165cdb..f115f66a9 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java @@ -116,13 +116,13 @@ public Stream streamChatCompletion( } private static void throwOnContentFilter(@Nonnull final OrchestrationChatCompletionDelta delta) - throws OrchestrationFilterException.Output { + throws OrchestrationClientException { final String finishReason = delta.getFinishReason(); if (finishReason != null && finishReason.equals("content_filter")) { final var filterDetails = Try.of(() -> getOutputFilteringChoices(delta)).getOrElseGet(e -> Map.of()); final var message = "Content filter filtered the output."; - throw new OrchestrationFilterException.Output(message).setFilterDetails(filterDetails); + throw new OrchestrationClientException(message).setFilterDetails(filterDetails); } } diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClientException.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClientException.java index 350679e4d..9b3665232 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClientException.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClientException.java @@ -1,22 +1,31 @@ package com.sap.ai.sdk.orchestration; +import static com.sap.ai.sdk.orchestration.OrchestrationJacksonConfiguration.getOrchestrationObjectMapper; + +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.annotations.Beta; import com.sap.ai.sdk.core.common.ClientException; import com.sap.ai.sdk.core.common.ClientExceptionFactory; -import com.sap.ai.sdk.orchestration.OrchestrationFilterException.Input; +import com.sap.ai.sdk.orchestration.model.AzureContentSafetyInput; +import com.sap.ai.sdk.orchestration.model.AzureContentSafetyOutput; import com.sap.ai.sdk.orchestration.model.Error; import com.sap.ai.sdk.orchestration.model.ErrorResponse; import com.sap.ai.sdk.orchestration.model.ErrorResponseStreaming; import com.sap.ai.sdk.orchestration.model.ErrorStreaming; import com.sap.ai.sdk.orchestration.model.GenericModuleResult; +import com.sap.ai.sdk.orchestration.model.LlamaGuard38b; import com.sap.ai.sdk.orchestration.model.ModuleResults; import com.sap.ai.sdk.orchestration.model.ModuleResultsStreaming; -import java.util.Collections; import java.util.Map; import java.util.Optional; import javax.annotation.Nonnull; import javax.annotation.Nullable; +import lombok.AccessLevel; +import lombok.Getter; +import lombok.Setter; +import lombok.experimental.Accessors; import lombok.experimental.StandardException; +import lombok.val; /** Exception thrown by the {@link OrchestrationClient} in case of an error. */ @StandardException @@ -24,35 +33,37 @@ public class OrchestrationClientException extends ClientException { static final ClientExceptionFactory FACTORY = (message, clientError, cause) -> { - final var details = extractInputFilterDetails(clientError); - if (details.isEmpty()) { - return new OrchestrationClientException(message, cause).setClientError(clientError); + val result = new OrchestrationClientException(message, cause); + val details = extractModuleResults(clientError).map(GenericModuleResult::getData); + if (details.orElse(null) instanceof Map m) { + result.setFilterDetails((Map) m); } - return new Input(message, cause).setFilterDetails(details).setClientError(clientError); + return result.setClientError(clientError); }; - @SuppressWarnings("unchecked") + private static final ObjectMapper OBJECT_MAPPER = getOrchestrationObjectMapper(); + + /** Details about the filters that caused the exception. */ + @Accessors(chain = true) + @Setter(AccessLevel.PACKAGE) + @Getter + @Nullable + private Map filterDetails = null; + @Nonnull - static Map extractInputFilterDetails(@Nullable final OrchestrationError error) { - if (error instanceof OrchestrationError.Synchronous synchronousError) { + static Optional extractModuleResults(@Nullable final OrchestrationError e) { + if (e instanceof OrchestrationError.Synchronous synchronousError) { return Optional.of(synchronousError.getErrorResponse()) .map(ErrorResponse::getError) .map(Error::getIntermediateResults) - .map(ModuleResults::getInputFiltering) - .map(GenericModuleResult::getData) - .map(map -> (Map) map) - .orElseGet(Collections::emptyMap); - } else if (error instanceof OrchestrationError.Streaming streamingError) { + .map(ModuleResults::getInputFiltering); + } else if (e instanceof OrchestrationError.Streaming streamingError) { return Optional.of(streamingError.getErrorResponse()) .map(ErrorResponseStreaming::getError) .map(ErrorStreaming::getIntermediateResults) - .map(ModuleResultsStreaming::getInputFiltering) - .map(GenericModuleResult::getData) - .filter(Map.class::isInstance) - .map(map -> (Map) map) - .orElseGet(Collections::emptyMap); + .map(ModuleResultsStreaming::getInputFiltering); } - return Collections.emptyMap(); + return Optional.empty(); } @Override @@ -105,4 +116,49 @@ public Integer getStatusCode() { .map(Error::getCode) .orElse(null); } + + /** + * Retrieves LlamaGuard 3.8b details from {@code filterDetails}, if present. + * + * @return The LlamaGuard38b object, or {@code null} if not found or conversion fails. + * @throws IllegalArgumentException if the conversion of filter details to {@link LlamaGuard38b} + * fails due to invalid content. + */ + @Nullable + public LlamaGuard38b getLlamaGuard38b() { + return Optional.ofNullable(filterDetails) + .map(details -> details.get("llama_guard_3_8b")) + .map(obj -> OBJECT_MAPPER.convertValue(obj, LlamaGuard38b.class)) + .orElse(null); + } + + /** + * Retrieves Azure Content Safety input details from {@code filterDetails}, if present. + * + * @return The AzureContentSafetyInput object, or {@code null} if not found or conversion fails. + * @throws IllegalArgumentException if the conversion of filter details to {@link + * AzureContentSafetyInput} fails due to invalid content. + */ + @Nullable + public AzureContentSafetyInput getAzureContentSafetyInput() { + return Optional.ofNullable(filterDetails) + .map(details -> details.get("azure_content_safety")) + .map(obj -> OBJECT_MAPPER.convertValue(obj, AzureContentSafetyInput.class)) + .orElse(null); + } + + /** + * Retrieves Azure Content Safety output details from {@code filterDetails}, if present. + * + * @return The AzureContentSafetyOutput object, or {@code null} if not found or conversion fails. + * @throws IllegalArgumentException if the conversion of filter details to {@link + * AzureContentSafetyOutput} fails due to invalid content. + */ + @Nullable + public AzureContentSafetyOutput getAzureContentSafetyOutput() { + return Optional.ofNullable(filterDetails) + .map(details -> details.get("azure_content_safety")) + .map(obj -> OBJECT_MAPPER.convertValue(obj, AzureContentSafetyOutput.class)) + .orElse(null); + } } diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationFilterException.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationFilterException.java deleted file mode 100644 index c62d74fa2..000000000 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationFilterException.java +++ /dev/null @@ -1,90 +0,0 @@ -package com.sap.ai.sdk.orchestration; - -import static com.sap.ai.sdk.orchestration.OrchestrationJacksonConfiguration.getOrchestrationObjectMapper; - -import com.google.common.annotations.Beta; -import com.sap.ai.sdk.orchestration.model.AzureContentSafetyInput; -import com.sap.ai.sdk.orchestration.model.AzureContentSafetyOutput; -import com.sap.ai.sdk.orchestration.model.LlamaGuard38b; -import java.util.Map; -import java.util.Optional; -import javax.annotation.Nonnull; -import javax.annotation.Nullable; -import lombok.AccessLevel; -import lombok.Getter; -import lombok.Setter; -import lombok.experimental.Accessors; -import lombok.experimental.StandardException; - -/** Base exception for errors occurring during orchestration filtering. */ -@Beta -@StandardException(access = AccessLevel.PRIVATE) -public class OrchestrationFilterException extends OrchestrationClientException { - - /** Details about the filters that caused the exception. */ - @Accessors(chain = true) - @Setter(AccessLevel.PACKAGE) - @Getter - @Nonnull - protected Map filterDetails = Map.of(); - - /** - * Retrieves LlamaGuard 3.8b details from {@code filterDetails}, if present. - * - * @return The LlamaGuard38b object, or {@code null} if not found or conversion fails. - * @throws IllegalArgumentException if the conversion of filter details to {@link LlamaGuard38b} - * fails due to invalid content. - */ - @Nullable - public LlamaGuard38b getLlamaGuard38b() { - return Optional.ofNullable(filterDetails.get("llama_guard_3_8b")) - .map(obj -> getOrchestrationObjectMapper().convertValue(obj, LlamaGuard38b.class)) - .orElse(null); - } - - /** Exception thrown when an error occurs during input filtering. */ - @StandardException - public static class Input extends OrchestrationFilterException { - - /** - * Retrieves Azure Content Safety input details from {@code filterDetails}, if present. - * - * @return The AzureContentSafetyInput object, or {@code null} if not found or conversion fails. - * @throws IllegalArgumentException if the conversion of filter details to {@link - * AzureContentSafetyInput} fails due to invalid content. - */ - @Nullable - public AzureContentSafetyInput getAzureContentSafetyInput() { - return Optional.ofNullable(filterDetails.get("azure_content_safety")) - .map( - obj -> - getOrchestrationObjectMapper().convertValue(obj, AzureContentSafetyInput.class)) - .orElse(null); - } - } - - /** - * Exception thrown when an error occurs during output filtering, specifically when the finish - * reason is a content filter. - */ - @StandardException - public static class Output extends OrchestrationFilterException { - - /** - * Retrieves Azure Content Safety output details from {@code filterDetails}, if present. - * - * @return The AzureContentSafetyOutput object, or {@code null} if not found or conversion - * fails. - * @throws IllegalArgumentException if the conversion of filter details to {@link - * AzureContentSafetyOutput} fails due to invalid content. - */ - @Nullable - public AzureContentSafetyOutput getAzureContentSafetyOutput() { - return Optional.ofNullable(filterDetails.get("azure_content_safety")) - .map( - obj -> - getOrchestrationObjectMapper().convertValue(obj, AzureContentSafetyOutput.class)) - .orElse(null); - } - } -} diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java index 060ca8430..286fb1b2f 100644 --- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java @@ -449,7 +449,7 @@ void inputFilteringStrict() { assertThatThrownBy(() -> client.chatCompletion(prompt, configWithFilter)) .isInstanceOfSatisfying( - OrchestrationFilterException.Input.class, + OrchestrationClientException.class, e -> { assertThat(e.getMessage()) .isEqualTo( @@ -503,7 +503,7 @@ void outputFilteringStrict() { assertThatThrownBy(client.chatCompletion(prompt, configWithFilter)::getContent) .isInstanceOfSatisfying( - OrchestrationFilterException.Output.class, + OrchestrationClientException.class, e -> { assertThat(e.getMessage()).isEqualTo("Content filter filtered the output."); assertThat(e.getFilterDetails()) @@ -759,9 +759,9 @@ void testThrowsOnContentFilter() { // this must not throw, since the stream is lazily evaluated var stream = mock.streamChatCompletion(new OrchestrationPrompt(""), config); assertThatThrownBy(stream::toList) - .isInstanceOf(OrchestrationFilterException.Output.class) + .isInstanceOf(OrchestrationClientException.class) .hasMessage("Content filter filtered the output.") - .extracting(e -> ((OrchestrationFilterException.Output) e).getFilterDetails()) + .extracting(e -> ((OrchestrationClientException) e).getFilterDetails()) .isEqualTo(Map.of("azure_content_safety", Map.of("hate", 0, "self_harm", 0))); } @@ -785,7 +785,7 @@ void streamChatCompletionOutputFilterErrorHandling() throws IOException { assertThatThrownBy(() -> stream.forEach(System.out::println)) .hasMessage("Content filter filtered the output.") .isInstanceOfSatisfying( - OrchestrationFilterException.Output.class, + OrchestrationClientException.class, e -> { assertThat(e.getErrorResponse()).isNull(); assertThat(e.getErrorResponseStreaming()).isNull(); diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java index 6f1d48b1f..07f6ef14e 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java @@ -5,7 +5,7 @@ import com.sap.ai.sdk.app.services.OrchestrationService; import com.sap.ai.sdk.orchestration.AzureFilterThreshold; import com.sap.ai.sdk.orchestration.OrchestrationChatResponse; -import com.sap.ai.sdk.orchestration.OrchestrationFilterException; +import com.sap.ai.sdk.orchestration.OrchestrationClientException; import com.sap.ai.sdk.orchestration.model.AzureContentSafetyInput; import com.sap.ai.sdk.orchestration.model.AzureContentSafetyOutput; import com.sap.ai.sdk.orchestration.model.DPIEntities; @@ -127,7 +127,7 @@ Object inputFiltering( final OrchestrationChatResponse response; try { response = service.inputFiltering(policy); - } catch (OrchestrationFilterException.Input e) { + } catch (OrchestrationClientException e) { final var msg = new StringBuilder( "[Http %d] Failed to obtain a response as the content was flagged by input filter. " @@ -159,7 +159,7 @@ Object outputFiltering( final String content; try { content = response.getContent(); - } catch (OrchestrationFilterException.Output e) { + } catch (OrchestrationClientException e) { final var msg = new StringBuilder( "Failed to obtain a response as the content was flagged by output filter. "); @@ -188,7 +188,7 @@ Object llamaGuardInputFiltering( final OrchestrationChatResponse response; try { response = service.llamaGuardInputFilter(enabled); - } catch (OrchestrationFilterException.Input e) { + } catch (OrchestrationClientException e) { var msg = "[Http %d] Failed to obtain a response as the content was flagged by input filter. " .formatted(e.getStatusCode()); diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java index 379abd476..420cb7844 100644 --- a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java +++ b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java @@ -15,7 +15,6 @@ import com.sap.ai.sdk.orchestration.Message; import com.sap.ai.sdk.orchestration.OrchestrationClient; import com.sap.ai.sdk.orchestration.OrchestrationClientException; -import com.sap.ai.sdk.orchestration.OrchestrationFilterException; import com.sap.ai.sdk.orchestration.OrchestrationModuleConfig; import com.sap.ai.sdk.orchestration.OrchestrationPrompt; import com.sap.ai.sdk.orchestration.TemplateConfig; @@ -221,7 +220,7 @@ void testInputFilteringStrict() { "Prompt filtered due to safety violations. Please modify the prompt and try again.") .hasMessageContaining("400 (Bad Request)") .isInstanceOfSatisfying( - OrchestrationFilterException.Input.class, + OrchestrationClientException.class, e -> { var actualAzureContentSafety = e.getAzureContentSafetyInput(); assertThat(actualAzureContentSafety).isNotNull(); @@ -253,7 +252,7 @@ void testOutputFilteringStrict() { assertThatThrownBy(response::getContent) .hasMessageContaining("Content filter filtered the output.") .isInstanceOfSatisfying( - OrchestrationFilterException.Output.class, + OrchestrationClientException.class, e -> { var actualAzureContentSafety = e.getAzureContentSafetyOutput(); assertThat(actualAzureContentSafety).isNotNull(); @@ -280,12 +279,12 @@ void testOutputFilteringLenient() { @Test void testLlamaGuardEnabled() { assertThatThrownBy(() -> service.llamaGuardInputFilter(true)) - .isInstanceOf(OrchestrationFilterException.Input.class) + .isInstanceOf(OrchestrationClientException.class) .hasMessageContaining( "Prompt filtered due to safety violations. Please modify the prompt and try again.") .hasMessageContaining("400 (Bad Request)") .isInstanceOfSatisfying( - OrchestrationFilterException.Input.class, + OrchestrationClientException.class, e -> { var llamaGuard38b = e.getLlamaGuard38b(); assertThat(llamaGuard38b).isNotNull(); @@ -419,7 +418,7 @@ void testStreamingErrorHandlingInputFilter() { val configWithFilter = config.withInputFiltering(filterConfig); assertThatThrownBy(() -> client.streamChatCompletion(prompt, configWithFilter)) - .isInstanceOf(OrchestrationFilterException.Input.class) + .isInstanceOf(OrchestrationClientException.class) .hasMessageContaining("status 400 (Bad Request)") .hasMessageContaining("Filtering Module - Input Filter"); }