diff --git a/orchestration/pom.xml b/orchestration/pom.xml index f87821a45..f976147b5 100644 --- a/orchestration/pom.xml +++ b/orchestration/pom.xml @@ -41,7 +41,7 @@ 94% 75% 93% - 100% + 97% diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java index 100291936..5ab354c16 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java @@ -35,16 +35,18 @@ public class OrchestrationChatResponse { *

Note: If there are multiple choices only the first one is returned * * @return the message content or empty string. - * @throws OrchestrationFilterException.Output if the content filter filtered the output. + * @throws OrchestrationClientException.Synchronous.OutputFilter if the content filter filtered + * the output. */ @Nonnull - public String getContent() throws OrchestrationFilterException.Output { + public String getContent() throws OrchestrationClientException.Synchronous.OutputFilter { final var choice = getChoice(); if ("content_filter".equals(choice.getFinishReason())) { final var filterDetails = Try.of(this::getOutputFilteringChoices).getOrElseGet(e -> Map.of()); final var message = "Content filter filtered the output."; - throw new OrchestrationFilterException.Output(message).setFilterDetails(filterDetails); + throw new OrchestrationClientException.Synchronous.OutputFilter(message) + .setFilterDetails(filterDetails); } return choice.getMessage().getContent(); } diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java index d7e165cdb..71e9d601f 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java @@ -116,13 +116,14 @@ public Stream streamChatCompletion( } private static void throwOnContentFilter(@Nonnull final OrchestrationChatCompletionDelta delta) - throws OrchestrationFilterException.Output { + throws OrchestrationClientException.Streaming.OutputFilter { final String finishReason = delta.getFinishReason(); if (finishReason != null && finishReason.equals("content_filter")) { final var filterDetails = Try.of(() -> getOutputFilteringChoices(delta)).getOrElseGet(e -> Map.of()); final var message = "Content filter filtered the output."; - throw new OrchestrationFilterException.Output(message).setFilterDetails(filterDetails); + throw new OrchestrationClientException.Streaming.OutputFilter(message) + .setFilterDetails(filterDetails); } } diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClientException.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClientException.java index 350679e4d..def28cf2f 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClientException.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClientException.java @@ -1,49 +1,65 @@ package com.sap.ai.sdk.orchestration; +import static com.sap.ai.sdk.orchestration.OrchestrationJacksonConfiguration.getOrchestrationObjectMapper; + +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.annotations.Beta; import com.sap.ai.sdk.core.common.ClientException; import com.sap.ai.sdk.core.common.ClientExceptionFactory; -import com.sap.ai.sdk.orchestration.OrchestrationFilterException.Input; +import com.sap.ai.sdk.orchestration.model.AzureContentSafetyInput; +import com.sap.ai.sdk.orchestration.model.AzureContentSafetyOutput; import com.sap.ai.sdk.orchestration.model.Error; import com.sap.ai.sdk.orchestration.model.ErrorResponse; import com.sap.ai.sdk.orchestration.model.ErrorResponseStreaming; import com.sap.ai.sdk.orchestration.model.ErrorStreaming; import com.sap.ai.sdk.orchestration.model.GenericModuleResult; +import com.sap.ai.sdk.orchestration.model.LlamaGuard38b; import com.sap.ai.sdk.orchestration.model.ModuleResults; import com.sap.ai.sdk.orchestration.model.ModuleResultsStreaming; +import io.vavr.control.Option; import java.util.Collections; import java.util.Map; import java.util.Optional; import javax.annotation.Nonnull; import javax.annotation.Nullable; +import lombok.AccessLevel; +import lombok.Getter; +import lombok.Setter; +import lombok.experimental.Accessors; import lombok.experimental.StandardException; /** Exception thrown by the {@link OrchestrationClient} in case of an error. */ @StandardException public class OrchestrationClientException extends ClientException { + private static final ObjectMapper MAPPER = getOrchestrationObjectMapper(); - static final ClientExceptionFactory FACTORY = - (message, clientError, cause) -> { - final var details = extractInputFilterDetails(clientError); - if (details.isEmpty()) { - return new OrchestrationClientException(message, cause).setClientError(clientError); - } - return new Input(message, cause).setFilterDetails(details).setClientError(clientError); - }; - - @SuppressWarnings("unchecked") + /** Details about the filters that caused the exception. */ + @Setter(AccessLevel.PACKAGE) + @Getter(AccessLevel.PACKAGE) + @Accessors(chain = true) @Nonnull - static Map extractInputFilterDetails(@Nullable final OrchestrationError error) { - if (error instanceof OrchestrationError.Synchronous synchronousError) { - return Optional.of(synchronousError.getErrorResponse()) - .map(ErrorResponse::getError) - .map(Error::getIntermediateResults) - .map(ModuleResults::getInputFiltering) - .map(GenericModuleResult::getData) - .map(map -> (Map) map) - .orElseGet(Collections::emptyMap); - } else if (error instanceof OrchestrationError.Streaming streamingError) { - return Optional.of(streamingError.getErrorResponse()) + protected Map filterDetails = Map.of(); + + /** Exception thrown during a streaming invocation. */ + @StandardException + public static class Streaming extends OrchestrationClientException { + static final ClientExceptionFactory FACTORY = + (message, clientError, cause) -> { + final var details = extractInputFilterDetails(clientError); + if (details.isEmpty()) { + return new Streaming(message, cause).setClientError(clientError); + } + return new InputFilter(message, cause) + .setFilterDetails(details) + .setClientError(clientError); + }; + + @SuppressWarnings("unchecked") + @Nonnull + private static Map extractInputFilterDetails( + @Nullable final OrchestrationError.Streaming error) { + return Optional.ofNullable(error) + .map(OrchestrationError.Streaming::getErrorResponse) .map(ErrorResponseStreaming::getError) .map(ErrorStreaming::getIntermediateResults) .map(ModuleResultsStreaming::getInputFiltering) @@ -52,57 +68,205 @@ static Map extractInputFilterDetails(@Nullable final Orchestrati .map(map -> (Map) map) .orElseGet(Collections::emptyMap); } - return Collections.emptyMap(); - } - @Override - @Nullable - public OrchestrationError getClientError() { - return (OrchestrationError) super.getClientError(); + /** + * Retrieves the {@link ErrorResponseStreaming} from the orchestration service, if available. + * + * @return The {@link ErrorResponseStreaming} object, or {@code null} if not available. + * @since 1.10.0 + */ + @Beta + @Nullable + public ErrorResponseStreaming getErrorResponse() { + return Option.of(getClientError()) + .map(OrchestrationError.Streaming::getErrorResponse) + .getOrNull(); + } + + /** + * Retrieves the client error details from the orchestration service, if available. + * + * @return The {@link OrchestrationError.Streaming} object, or {@code null} if not available. + * @since 1.10.0 + */ + @Override + public OrchestrationError.Streaming getClientError() { + return super.getClientError() instanceof OrchestrationError.Streaming e ? e : null; + } + + /** Exception thrown during a streaming invocation that contains input filter details. */ + @Beta + @StandardException + public static class InputFilter extends Streaming implements Filter.Input { + @Nonnull + @Override + public Map getFilterDetails() { + return super.getFilterDetails(); + } + } + + /** Exception thrown during a streaming invocation that contains output filter details. */ + @Beta + @StandardException + public static class OutputFilter extends Streaming implements Filter.Output { + @Nonnull + @Override + public Map getFilterDetails() { + return super.getFilterDetails(); + } + } } - /** - * Retrieves the {@link ErrorResponse} from the orchestration service, if available. - * - * @return The {@link ErrorResponse} object, or {@code null} if not available. - * @since 1.10.0 - */ - @Beta - @Nullable - public ErrorResponse getErrorResponse() { - if (getClientError() instanceof OrchestrationError.Synchronous orchestrationError) { - return orchestrationError.getErrorResponse(); + /** Exception thrown during a synchronous invocation. */ + @StandardException + public static class Synchronous extends OrchestrationClientException { + static final ClientExceptionFactory FACTORY = + (message, clientError, cause) -> { + final var details = extractInputFilterDetails(clientError); + if (details.isEmpty()) { + return new Synchronous(message, cause).setClientError(clientError); + } + return new InputFilter(message, cause) + .setFilterDetails(details) + .setClientError(clientError); + }; + + @SuppressWarnings("unchecked") + @Nonnull + private static Map extractInputFilterDetails( + @Nullable final OrchestrationError.Synchronous error) { + return Optional.ofNullable(error) + .map(OrchestrationError.Synchronous::getErrorResponse) + .map(ErrorResponse::getError) + .map(Error::getIntermediateResults) + .map(ModuleResults::getInputFiltering) + .map(GenericModuleResult::getData) + .map(map -> (Map) map) + .orElseGet(Collections::emptyMap); + } + + /** + * Retrieves the {@link ErrorResponse} from the orchestration service, if available. + * + * @return The {@link ErrorResponse} object, or {@code null} if not available. + * @since 1.10.0 + */ + @Beta + @Nullable + public ErrorResponse getErrorResponse() { + return Option.of(getClientError()) + .map(OrchestrationError.Synchronous::getErrorResponse) + .getOrNull(); + } + + /** + * Retrieves the client error details from the orchestration service, if available. + * + * @return The {@link OrchestrationError.Synchronous} object, or {@code null} if not available. + * @since 1.10.0 + */ + @Override + public OrchestrationError.Synchronous getClientError() { + return super.getClientError() instanceof OrchestrationError.Synchronous e ? e : null; + } + + /** Exception thrown during a synchronous invocation that contains input filter details. */ + @Beta + @StandardException + public static class InputFilter extends Synchronous implements Filter.Input { + @Nonnull + @Override + public Map getFilterDetails() { + return super.getFilterDetails(); + } + + /** + * Retrieves the HTTP status code from the original error response, if available. + * + * @return the HTTP status code, or {@code null} if not available + * @since 1.10.0 + */ + @Beta + @Nullable + public Integer getStatusCode() { + return Optional.ofNullable(getErrorResponse()) + .map(e -> e.getError().getCode()) + .orElse(null); + } + } + + /** Exception thrown during a synchronous invocation that contains output filter details. */ + @Beta + @StandardException + public static class OutputFilter extends Synchronous implements Filter.Output { + @Nonnull + @Override + public Map getFilterDetails() { + return super.getFilterDetails(); + } } - return null; } /** - * Retrieves the {@link ErrorResponseStreaming} from the orchestration service, if available. - * - * @return The {@link ErrorResponseStreaming} object, or {@code null} if not available. - * @since 1.10.0 + * Interface representing the filter details that can be included in an orchestration error + * response. */ @Beta - @Nullable - public ErrorResponseStreaming getErrorResponseStreaming() { - if (getClientError() instanceof OrchestrationError.Streaming orchestrationError) { - return orchestrationError.getErrorResponse(); + interface Filter { + /** + * Retrieves the filter details as a map. + * + * @return a map containing the filter details. + */ + @Nonnull + Map getFilterDetails(); + + /** + * Retrieves the Azure Content Safety input filter details, if available. + * + * @return the {@link AzureContentSafetyInput} object, or {@code null} if not available. + */ + @Nullable + default LlamaGuard38b getLlamaGuard38b() { + return Optional.ofNullable(getFilterDetails().get("llama_guard_3_8b")) + .map(obj -> MAPPER.convertValue(obj, LlamaGuard38b.class)) + .orElse(null); + } + + /** Interface for input filters that can be included in an orchestration error response. */ + interface Input extends Filter { + /** + * Retrieves the Azure Content Safety input filter details, if available. + * + * @return the {@link AzureContentSafetyInput} object, or {@code null} if not available. + */ + @Nullable + default AzureContentSafetyInput getAzureContentSafetyInput() { + return Optional.ofNullable(getFilterDetails().get("azure_content_safety")) + .map(obj -> MAPPER.convertValue(obj, AzureContentSafetyInput.class)) + .orElse(null); + } + } + + /** Interface for output filters that can be included in an orchestration error response. */ + interface Output extends Filter { + /** + * Retrieves the Azure Content Safety output filter details, if available. + * + * @return the {@link AzureContentSafetyOutput} object, or {@code null} if not available. + */ + @Nullable + default AzureContentSafetyOutput getAzureContentSafetyOutput() { + return Optional.ofNullable(getFilterDetails().get("azure_content_safety")) + .map(obj -> MAPPER.convertValue(obj, AzureContentSafetyOutput.class)) + .orElse(null); + } } - return null; } - /** - * Retrieves the HTTP status code from the original error response, if available. - * - * @return the HTTP status code, or {@code null} if not available - * @since 1.10.0 - */ - @Beta + @Override @Nullable - public Integer getStatusCode() { - return Optional.ofNullable(getErrorResponse()) - .map(ErrorResponse::getError) - .map(Error::getCode) - .orElse(null); + public OrchestrationError getClientError() { + return (OrchestrationError) super.getClientError(); } } diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationFilterException.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationFilterException.java deleted file mode 100644 index c62d74fa2..000000000 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationFilterException.java +++ /dev/null @@ -1,90 +0,0 @@ -package com.sap.ai.sdk.orchestration; - -import static com.sap.ai.sdk.orchestration.OrchestrationJacksonConfiguration.getOrchestrationObjectMapper; - -import com.google.common.annotations.Beta; -import com.sap.ai.sdk.orchestration.model.AzureContentSafetyInput; -import com.sap.ai.sdk.orchestration.model.AzureContentSafetyOutput; -import com.sap.ai.sdk.orchestration.model.LlamaGuard38b; -import java.util.Map; -import java.util.Optional; -import javax.annotation.Nonnull; -import javax.annotation.Nullable; -import lombok.AccessLevel; -import lombok.Getter; -import lombok.Setter; -import lombok.experimental.Accessors; -import lombok.experimental.StandardException; - -/** Base exception for errors occurring during orchestration filtering. */ -@Beta -@StandardException(access = AccessLevel.PRIVATE) -public class OrchestrationFilterException extends OrchestrationClientException { - - /** Details about the filters that caused the exception. */ - @Accessors(chain = true) - @Setter(AccessLevel.PACKAGE) - @Getter - @Nonnull - protected Map filterDetails = Map.of(); - - /** - * Retrieves LlamaGuard 3.8b details from {@code filterDetails}, if present. - * - * @return The LlamaGuard38b object, or {@code null} if not found or conversion fails. - * @throws IllegalArgumentException if the conversion of filter details to {@link LlamaGuard38b} - * fails due to invalid content. - */ - @Nullable - public LlamaGuard38b getLlamaGuard38b() { - return Optional.ofNullable(filterDetails.get("llama_guard_3_8b")) - .map(obj -> getOrchestrationObjectMapper().convertValue(obj, LlamaGuard38b.class)) - .orElse(null); - } - - /** Exception thrown when an error occurs during input filtering. */ - @StandardException - public static class Input extends OrchestrationFilterException { - - /** - * Retrieves Azure Content Safety input details from {@code filterDetails}, if present. - * - * @return The AzureContentSafetyInput object, or {@code null} if not found or conversion fails. - * @throws IllegalArgumentException if the conversion of filter details to {@link - * AzureContentSafetyInput} fails due to invalid content. - */ - @Nullable - public AzureContentSafetyInput getAzureContentSafetyInput() { - return Optional.ofNullable(filterDetails.get("azure_content_safety")) - .map( - obj -> - getOrchestrationObjectMapper().convertValue(obj, AzureContentSafetyInput.class)) - .orElse(null); - } - } - - /** - * Exception thrown when an error occurs during output filtering, specifically when the finish - * reason is a content filter. - */ - @StandardException - public static class Output extends OrchestrationFilterException { - - /** - * Retrieves Azure Content Safety output details from {@code filterDetails}, if present. - * - * @return The AzureContentSafetyOutput object, or {@code null} if not found or conversion - * fails. - * @throws IllegalArgumentException if the conversion of filter details to {@link - * AzureContentSafetyOutput} fails due to invalid content. - */ - @Nullable - public AzureContentSafetyOutput getAzureContentSafetyOutput() { - return Optional.ofNullable(filterDetails.get("azure_content_safety")) - .map( - obj -> - getOrchestrationObjectMapper().convertValue(obj, AzureContentSafetyOutput.class)) - .orElse(null); - } - } -} diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationHttpExecutor.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationHttpExecutor.java index 59954c929..c65060b37 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationHttpExecutor.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationHttpExecutor.java @@ -1,6 +1,5 @@ package com.sap.ai.sdk.orchestration; -import static com.sap.ai.sdk.orchestration.OrchestrationClientException.FACTORY; import static com.sap.ai.sdk.orchestration.OrchestrationJacksonConfiguration.getOrchestrationObjectMapper; import com.fasterxml.jackson.core.JsonProcessingException; @@ -49,7 +48,10 @@ T execute( val client = getHttpClient(); val handler = - new ClientResponseHandler<>(responseType, OrchestrationError.Synchronous.class, FACTORY) + new ClientResponseHandler<>( + responseType, + OrchestrationError.Synchronous.class, + OrchestrationClientException.Synchronous.FACTORY) .objectMapper(JACKSON); return client.execute(request, handler); @@ -76,7 +78,9 @@ Stream stream( val client = getHttpClient(); return new ClientStreamingHandler<>( - OrchestrationChatCompletionDelta.class, OrchestrationError.Streaming.class, FACTORY) + OrchestrationChatCompletionDelta.class, + OrchestrationError.Streaming.class, + OrchestrationClientException.Streaming.FACTORY) .objectMapper(JACKSON) .handleStreamingResponse(client.executeOpen(null, request, null)); diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java index 060ca8430..340bc4daa 100644 --- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java @@ -379,12 +379,11 @@ void testBadRequest() { assertThatThrownBy(() -> client.chatCompletion(prompt, config)) .isInstanceOfSatisfying( - OrchestrationClientException.class, + OrchestrationClientException.Synchronous.class, e -> { assertThat(e.getMessage()) .isEqualTo( "Request failed with status 400 (Bad Request): Missing required parameters: ['input']"); - assertThat(e.getErrorResponseStreaming()).isNull(); assertThat(e.getErrorResponse()).isNotNull(); assertThat(e.getErrorResponse().getError().getMessage()) .isEqualTo("Missing required parameters: ['input']"); @@ -449,7 +448,7 @@ void inputFilteringStrict() { assertThatThrownBy(() -> client.chatCompletion(prompt, configWithFilter)) .isInstanceOfSatisfying( - OrchestrationFilterException.Input.class, + OrchestrationClientException.Synchronous.InputFilter.class, e -> { assertThat(e.getMessage()) .isEqualTo( @@ -503,7 +502,7 @@ void outputFilteringStrict() { assertThatThrownBy(client.chatCompletion(prompt, configWithFilter)::getContent) .isInstanceOfSatisfying( - OrchestrationFilterException.Output.class, + OrchestrationClientException.Synchronous.OutputFilter.class, e -> { assertThat(e.getMessage()).isEqualTo("Content filter filtered the output."); assertThat(e.getFilterDetails()) @@ -520,7 +519,6 @@ void outputFilteringStrict() { "llama_guard_3_8b", Map.of("violent_crimes", true))); assertThat(e.getErrorResponse()).isNull(); - assertThat(e.getStatusCode()).isNull(); assertThat(e.getAzureContentSafetyOutput()).isNotNull(); assertThat(e.getAzureContentSafetyOutput().getHate()).isEqualTo(NUMBER_6); @@ -759,9 +757,10 @@ void testThrowsOnContentFilter() { // this must not throw, since the stream is lazily evaluated var stream = mock.streamChatCompletion(new OrchestrationPrompt(""), config); assertThatThrownBy(stream::toList) - .isInstanceOf(OrchestrationFilterException.Output.class) + .isInstanceOf(OrchestrationClientException.Streaming.OutputFilter.class) .hasMessage("Content filter filtered the output.") - .extracting(e -> ((OrchestrationFilterException.Output) e).getFilterDetails()) + .extracting( + e -> ((OrchestrationClientException.Streaming.OutputFilter) e).getFilterDetails()) .isEqualTo(Map.of("azure_content_safety", Map.of("hate", 0, "self_harm", 0))); } @@ -785,11 +784,9 @@ void streamChatCompletionOutputFilterErrorHandling() throws IOException { assertThatThrownBy(() -> stream.forEach(System.out::println)) .hasMessage("Content filter filtered the output.") .isInstanceOfSatisfying( - OrchestrationFilterException.Output.class, + OrchestrationClientException.Streaming.OutputFilter.class, e -> { assertThat(e.getErrorResponse()).isNull(); - assertThat(e.getErrorResponseStreaming()).isNull(); - assertThat(e.getStatusCode()).isNull(); assertThat(e.getFilterDetails()) .isEqualTo( diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java index 6f1d48b1f..bb5f130ed 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java @@ -5,7 +5,7 @@ import com.sap.ai.sdk.app.services.OrchestrationService; import com.sap.ai.sdk.orchestration.AzureFilterThreshold; import com.sap.ai.sdk.orchestration.OrchestrationChatResponse; -import com.sap.ai.sdk.orchestration.OrchestrationFilterException; +import com.sap.ai.sdk.orchestration.OrchestrationClientException; import com.sap.ai.sdk.orchestration.model.AzureContentSafetyInput; import com.sap.ai.sdk.orchestration.model.AzureContentSafetyOutput; import com.sap.ai.sdk.orchestration.model.DPIEntities; @@ -127,7 +127,7 @@ Object inputFiltering( final OrchestrationChatResponse response; try { response = service.inputFiltering(policy); - } catch (OrchestrationFilterException.Input e) { + } catch (OrchestrationClientException.Synchronous.InputFilter e) { final var msg = new StringBuilder( "[Http %d] Failed to obtain a response as the content was flagged by input filter. " @@ -159,7 +159,7 @@ Object outputFiltering( final String content; try { content = response.getContent(); - } catch (OrchestrationFilterException.Output e) { + } catch (OrchestrationClientException.Synchronous.OutputFilter e) { final var msg = new StringBuilder( "Failed to obtain a response as the content was flagged by output filter. "); @@ -188,7 +188,7 @@ Object llamaGuardInputFiltering( final OrchestrationChatResponse response; try { response = service.llamaGuardInputFilter(enabled); - } catch (OrchestrationFilterException.Input e) { + } catch (OrchestrationClientException.Synchronous.InputFilter e) { var msg = "[Http %d] Failed to obtain a response as the content was flagged by input filter. " .formatted(e.getStatusCode()); diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java index 379abd476..5543b6b15 100644 --- a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java +++ b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java @@ -15,7 +15,6 @@ import com.sap.ai.sdk.orchestration.Message; import com.sap.ai.sdk.orchestration.OrchestrationClient; import com.sap.ai.sdk.orchestration.OrchestrationClientException; -import com.sap.ai.sdk.orchestration.OrchestrationFilterException; import com.sap.ai.sdk.orchestration.OrchestrationModuleConfig; import com.sap.ai.sdk.orchestration.OrchestrationPrompt; import com.sap.ai.sdk.orchestration.TemplateConfig; @@ -221,7 +220,7 @@ void testInputFilteringStrict() { "Prompt filtered due to safety violations. Please modify the prompt and try again.") .hasMessageContaining("400 (Bad Request)") .isInstanceOfSatisfying( - OrchestrationFilterException.Input.class, + OrchestrationClientException.Synchronous.InputFilter.class, e -> { var actualAzureContentSafety = e.getAzureContentSafetyInput(); assertThat(actualAzureContentSafety).isNotNull(); @@ -253,7 +252,7 @@ void testOutputFilteringStrict() { assertThatThrownBy(response::getContent) .hasMessageContaining("Content filter filtered the output.") .isInstanceOfSatisfying( - OrchestrationFilterException.Output.class, + OrchestrationClientException.Synchronous.OutputFilter.class, e -> { var actualAzureContentSafety = e.getAzureContentSafetyOutput(); assertThat(actualAzureContentSafety).isNotNull(); @@ -280,12 +279,12 @@ void testOutputFilteringLenient() { @Test void testLlamaGuardEnabled() { assertThatThrownBy(() -> service.llamaGuardInputFilter(true)) - .isInstanceOf(OrchestrationFilterException.Input.class) + .isInstanceOf(OrchestrationClientException.Synchronous.InputFilter.class) .hasMessageContaining( "Prompt filtered due to safety violations. Please modify the prompt and try again.") .hasMessageContaining("400 (Bad Request)") .isInstanceOfSatisfying( - OrchestrationFilterException.Input.class, + OrchestrationClientException.Synchronous.InputFilter.class, e -> { var llamaGuard38b = e.getLlamaGuard38b(); assertThat(llamaGuard38b).isNotNull(); @@ -419,7 +418,7 @@ void testStreamingErrorHandlingInputFilter() { val configWithFilter = config.withInputFiltering(filterConfig); assertThatThrownBy(() -> client.streamChatCompletion(prompt, configWithFilter)) - .isInstanceOf(OrchestrationFilterException.Input.class) + .isInstanceOf(OrchestrationClientException.Streaming.InputFilter.class) .hasMessageContaining("status 400 (Bad Request)") .hasMessageContaining("Filtering Module - Input Filter"); }