diff --git a/core/src/main/java/com/sap/ai/sdk/core/common/ClientException.java b/core/src/main/java/com/sap/ai/sdk/core/common/ClientException.java index 221d670a7..cd311993f 100644 --- a/core/src/main/java/com/sap/ai/sdk/core/common/ClientException.java +++ b/core/src/main/java/com/sap/ai/sdk/core/common/ClientException.java @@ -1,6 +1,10 @@ package com.sap.ai.sdk.core.common; import com.google.common.annotations.Beta; +import javax.annotation.Nullable; +import lombok.AccessLevel; +import lombok.Getter; +import lombok.Setter; import lombok.experimental.StandardException; /** @@ -10,4 +14,14 @@ */ @Beta @StandardException -public class ClientException extends RuntimeException {} +public class ClientException extends RuntimeException { + + /** + * Wraps a structured error payload received from the remote service, if available. This can be + * used to extract more detailed error information. + */ + @Nullable + @Getter(onMethod_ = @Beta, value = AccessLevel.PROTECTED) + @Setter(onMethod_ = @Beta, value = AccessLevel.PROTECTED) + ClientError clientError; +} diff --git a/core/src/main/java/com/sap/ai/sdk/core/common/ClientExceptionFactory.java b/core/src/main/java/com/sap/ai/sdk/core/common/ClientExceptionFactory.java new file mode 100644 index 000000000..991c11c52 --- /dev/null +++ b/core/src/main/java/com/sap/ai/sdk/core/common/ClientExceptionFactory.java @@ -0,0 +1,37 @@ +package com.sap.ai.sdk.core.common; + +import com.google.common.annotations.Beta; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +/** + * A factory whose implementations can provide customized exception types and error mapping logic + * for different service clients or error scenarios. + * + * @param The subtype of {@link ClientException} to be created by this factory. + * @param The subtype of {@link ClientError} payload that can be processed by this factory. + */ +@Beta +public interface ClientExceptionFactory { + + /** + * Creates an exception with a message and optional cause. + * + * @param message A descriptive message for the exception. + * @param cause An optional cause of the exception, can be null if not applicable. + * @return An instance of the specified {@link ClientException} type + */ + @Nonnull + E build(@Nonnull final String message, @Nullable final Throwable cause); + + /** + * Creates an exception from a given message and an HTTP error response that has been successfully + * deserialized into a {@link ClientError} object. + * + * @param message A descriptive message for the exception. + * @param clientError The structured {@link ClientError} object deserialized from the response. + * @return An instance of the specified {@link ClientException} type + */ + @Nonnull + E buildFromClientError(@Nonnull final String message, @Nonnull final R clientError); +} diff --git a/core/src/main/java/com/sap/ai/sdk/core/common/ClientResponseHandler.java b/core/src/main/java/com/sap/ai/sdk/core/common/ClientResponseHandler.java index ecfb56a6b..b28b6691f 100644 --- a/core/src/main/java/com/sap/ai/sdk/core/common/ClientResponseHandler.java +++ b/core/src/main/java/com/sap/ai/sdk/core/common/ClientResponseHandler.java @@ -6,36 +6,40 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.annotations.Beta; import io.vavr.control.Try; -import java.io.IOException; import java.nio.charset.StandardCharsets; -import java.util.Objects; -import java.util.function.BiFunction; +import java.util.Optional; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.apache.hc.core5.http.ClassicHttpResponse; import org.apache.hc.core5.http.ContentType; import org.apache.hc.core5.http.HttpEntity; -import org.apache.hc.core5.http.ParseException; import org.apache.hc.core5.http.io.HttpClientResponseHandler; import org.apache.hc.core5.http.io.entity.EntityUtils; /** * Parse incoming JSON responses and handles any errors. For internal use only. * - * @param The type of the response. + * @param The type of the successful response. * @param The type of the exception to throw. + * @param The type of the error response. * @since 1.1.0 */ @Beta @Slf4j @RequiredArgsConstructor -public class ClientResponseHandler +public class ClientResponseHandler implements HttpClientResponseHandler { - @Nonnull final Class responseType; - @Nonnull private final Class errorType; - @Nonnull final BiFunction exceptionConstructor; + /** The HTTP success response type */ + @Nonnull final Class successType; + + /** The HTTP error response type */ + @Nonnull final Class errorType; + + /** The factory to create exceptions for Http 4xx/5xx responses. */ + @Nonnull final ClientExceptionFactory exceptionFactory; /** The parses for JSON responses, will be private once we can remove mixins */ @Nonnull ObjectMapper objectMapper = getDefaultObjectMapper(); @@ -48,7 +52,7 @@ public class ClientResponseHandler */ @Beta @Nonnull - public ClientResponseHandler objectMapper(@Nonnull final ObjectMapper jackson) { + public ClientResponseHandler objectMapper(@Nonnull final ObjectMapper jackson) { objectMapper = jackson; return this; } @@ -64,91 +68,105 @@ public ClientResponseHandler objectMapper(@Nonnull final ObjectMapper jack @Override public T handleResponse(@Nonnull final ClassicHttpResponse response) throws E { if (response.getCode() >= 300) { - buildExceptionAndThrow(response); + buildAndThrowException(response); } - return parseResponse(response); + return parseSuccess(response); } // The InputStream of the HTTP entity is closed by EntityUtils.toString @SuppressWarnings("PMD.CloseResource") @Nonnull - private T parseResponse(@Nonnull final ClassicHttpResponse response) throws E { + private T parseSuccess(@Nonnull final ClassicHttpResponse response) throws E { final HttpEntity responseEntity = response.getEntity(); if (responseEntity == null) { - throw exceptionConstructor.apply("Response was empty.", null); + throw exceptionFactory.build("The HTTP Response is empty", null); } - val content = getContent(responseEntity); + + val content = + tryGetContent(responseEntity) + .getOrElseThrow(e -> exceptionFactory.build("Failed to parse response entity.", e)); try { - return objectMapper.readValue(content, responseType); + return objectMapper.readValue(content, successType); } catch (final JsonProcessingException e) { - log.error("Failed to parse response to type {}", responseType); - throw exceptionConstructor.apply("Failed to parse response", e); + log.error("Failed to parse response to type {}", successType); + throw exceptionFactory.build("Failed to parse response", e); } } @Nonnull - private String getContent(@Nonnull final HttpEntity entity) { - try { - return EntityUtils.toString(entity, StandardCharsets.UTF_8); - } catch (IOException | ParseException e) { - throw exceptionConstructor.apply("Failed to read response content", e); - } + private Try tryGetContent(@Nonnull final HttpEntity entity) { + return Try.of(() -> EntityUtils.toString(entity, StandardCharsets.UTF_8)); } /** - * Parse the error response and throw an exception. + * Process the error response and throw an exception. * - * @param response The response to process + * @param httpResponse The response to process + * @throws ClientException if the response is an error (4xx/5xx) */ @SuppressWarnings("PMD.CloseResource") - public void buildExceptionAndThrow(@Nonnull final ClassicHttpResponse response) throws E { - val exception = - exceptionConstructor.apply( - "Request failed with status %s %s" - .formatted(response.getCode(), response.getReasonPhrase()), - null); - val entity = response.getEntity(); + protected void buildAndThrowException(@Nonnull final ClassicHttpResponse httpResponse) throws E { + + val entity = httpResponse.getEntity(); + if (entity == null) { - throw exception; + val message = getErrorMessage(httpResponse, "The HTTP Response is empty"); + throw exceptionFactory.build(message, null); } - val maybeContent = Try.of(() -> getContent(entity)); + val maybeContent = tryGetContent(entity); if (maybeContent.isFailure()) { - exception.addSuppressed(maybeContent.getCause()); - throw exception; + val message = getErrorMessage(httpResponse, "Failed to read the response content"); + val baseException = exceptionFactory.build(message, null); + baseException.addSuppressed(maybeContent.getCause()); + throw baseException; } val content = maybeContent.get(); - if (content.isBlank()) { - throw exception; + if (content == null || content.isBlank()) { + val message = getErrorMessage(httpResponse, "Empty or blank response content"); + throw exceptionFactory.build(message, null); } log.error( "The service responded with an HTTP {} ({})", - response.getCode(), - response.getReasonPhrase()); + httpResponse.getCode(), + httpResponse.getReasonPhrase()); val contentType = ContentType.parse(entity.getContentType()); if (!ContentType.APPLICATION_JSON.isSameMimeType(contentType)) { - throw exception; + val message = getErrorMessage(httpResponse, "The response Content-Type is not JSON"); + throw exceptionFactory.build(message, null); } - parseErrorAndThrow(content, exception); + parseErrorResponseAndThrow(content, httpResponse); } /** - * Parse the error response and throw an exception. + * Parses the JSON content of an error response and throws a module specific exception. * - * @param errorResponse the error response, most likely a unique JSON class. - * @param baseException a base exception to add the error message to. + * @param content The JSON content of the error response. + * @param httpResponse The HTTP response that contains the error. + * @throws ClientException if the response is an error (4xx/5xx) */ - public void parseErrorAndThrow( - @Nonnull final String errorResponse, @Nonnull final E baseException) throws E { - val maybeError = Try.of(() -> objectMapper.readValue(errorResponse, errorType)); - if (maybeError.isFailure()) { - baseException.addSuppressed(maybeError.getCause()); + protected void parseErrorResponseAndThrow( + @Nonnull final String content, @Nonnull final ClassicHttpResponse httpResponse) throws E { + val maybeClientError = Try.of(() -> objectMapper.readValue(content, errorType)); + if (maybeClientError.isFailure()) { + val message = getErrorMessage(httpResponse, "Failed to parse the JSON error response"); + val baseException = exceptionFactory.build(message, null); + baseException.addSuppressed(maybeClientError.getCause()); throw baseException; } + final R clientError = maybeClientError.get(); + val message = getErrorMessage(httpResponse, clientError.getMessage()); + throw exceptionFactory.buildFromClientError(message, clientError); + } + + private static String getErrorMessage( + @Nonnull final ClassicHttpResponse httpResponse, @Nullable final String additionalMessage) { + val baseErrorMessage = + "Request failed with status %d (%s)" + .formatted(httpResponse.getCode(), httpResponse.getReasonPhrase()); - val error = Objects.requireNonNullElse(maybeError.get().getMessage(), ""); - val message = "%s and error message: '%s'".formatted(baseException.getMessage(), error); - throw exceptionConstructor.apply(message, baseException); + val message = Optional.ofNullable(additionalMessage).orElse(""); + return message.isEmpty() ? baseErrorMessage : "%s: %s".formatted(baseErrorMessage, message); } } diff --git a/core/src/main/java/com/sap/ai/sdk/core/common/ClientStreamingHandler.java b/core/src/main/java/com/sap/ai/sdk/core/common/ClientStreamingHandler.java index 1627a1232..1d27edb9d 100644 --- a/core/src/main/java/com/sap/ai/sdk/core/common/ClientStreamingHandler.java +++ b/core/src/main/java/com/sap/ai/sdk/core/common/ClientStreamingHandler.java @@ -3,7 +3,6 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.annotations.Beta; import java.io.IOException; -import java.util.function.BiFunction; import java.util.stream.Stream; import javax.annotation.Nonnull; import lombok.extern.slf4j.Slf4j; @@ -14,12 +13,14 @@ * * @param The type of the response. * @param The type of the exception to throw. + * @param The type of the error. * @since 1.2.0 */ @Beta @Slf4j -public class ClientStreamingHandler - extends ClientResponseHandler { +public class ClientStreamingHandler< + D extends StreamedDelta, R extends ClientError, E extends ClientException> + extends ClientResponseHandler { /** * Set the {@link ObjectMapper} to use for parsing JSON responses. @@ -28,7 +29,7 @@ public class ClientStreamingHandler objectMapper(@Nonnull final ObjectMapper jackson) { + public ClientStreamingHandler objectMapper(@Nonnull final ObjectMapper jackson) { super.objectMapper(jackson); return this; } @@ -38,13 +39,13 @@ public ClientStreamingHandler objectMapper(@Nonnull final ObjectMapper jac * * @param deltaType The type of the response. * @param errorType The type of the error. - * @param exceptionType The type of the exception to throw. + * @param exceptionFactory The factory to create exceptions. */ public ClientStreamingHandler( @Nonnull final Class deltaType, - @Nonnull final Class errorType, - @Nonnull final BiFunction exceptionType) { - super(deltaType, errorType, exceptionType); + @Nonnull final Class errorType, + @Nonnull final ClientExceptionFactory exceptionFactory) { + super(deltaType, errorType, exceptionFactory); } /** @@ -59,26 +60,27 @@ public ClientStreamingHandler( @Nonnull public Stream handleStreamingResponse(@Nonnull final ClassicHttpResponse response) throws E { if (response.getCode() >= 300) { - super.buildExceptionAndThrow(response); + super.buildAndThrowException(response); } - return IterableStreamConverter.lines(response.getEntity(), exceptionConstructor) + + return IterableStreamConverter.lines(response.getEntity(), exceptionFactory) // half of the lines are empty newlines, the last line is "data: [DONE]" .filter(line -> !line.isEmpty() && !"data: [DONE]".equals(line.trim())) .peek( line -> { if (!line.startsWith("data: ")) { final String msg = "Failed to parse response"; - super.parseErrorAndThrow(line, exceptionConstructor.apply(msg, null)); + throw exceptionFactory.build(msg, null); } }) .map( line -> { final String data = line.substring(5); // remove "data: " try { - return objectMapper.readValue(data, responseType); + return objectMapper.readValue(data, successType); } catch (final IOException e) { // exception message e gets lost - log.error("Failed to parse delta chunk to type {}", responseType); - throw exceptionConstructor.apply("Failed to parse delta chunk", e); + log.error("Failed to parse delta chunk to type {}", successType); + throw exceptionFactory.build("Failed to parse delta chunk", e); } }); } diff --git a/core/src/main/java/com/sap/ai/sdk/core/common/IterableStreamConverter.java b/core/src/main/java/com/sap/ai/sdk/core/common/IterableStreamConverter.java index 574a05283..fb06af6fc 100644 --- a/core/src/main/java/com/sap/ai/sdk/core/common/IterableStreamConverter.java +++ b/core/src/main/java/com/sap/ai/sdk/core/common/IterableStreamConverter.java @@ -13,7 +13,6 @@ import java.util.NoSuchElementException; import java.util.Spliterators; import java.util.concurrent.Callable; -import java.util.function.BiFunction; import java.util.function.Function; import java.util.stream.Stream; import java.util.stream.StreamSupport; @@ -91,25 +90,25 @@ public T next() { * when an exception occurred. * * @param entity The HTTP entity object. - * @param exceptionType The type of the client exception to throw in case of an error. + * @param exceptionFactory The exception factory to use for creating exceptions. * @return A sequential Stream object. * @throws ClientException if the provided HTTP entity object is {@code null} or empty. */ @SuppressWarnings("PMD.CloseResource") // Stream is closed automatically when consumed @Nonnull - static Stream lines( + static Stream lines( @Nullable final HttpEntity entity, - @Nonnull final BiFunction exceptionType) + @Nonnull final ClientExceptionFactory exceptionFactory) throws ClientException { if (entity == null) { - throw exceptionType.apply("Orchestration service response was empty.", null); + throw exceptionFactory.build("The HTTP Response is empty", null); } final InputStream inputStream; try { inputStream = entity.getContent(); } catch (final IOException e) { - throw exceptionType.apply("Failed to read response content.", e); + throw exceptionFactory.build("Failed to read response content.", e); } final var reader = new BufferedReader(new InputStreamReader(inputStream, UTF_8), BUFFER_SIZE); @@ -122,7 +121,7 @@ static Stream lines( "Could not close input stream with error: {} (ignored)", e.getClass().getSimpleName())); final Function errHandler = - e -> exceptionType.apply("Parsing response content was interrupted.", e); + e -> exceptionFactory.build("Parsing response content was interrupted", e); final var iterator = new IterableStreamConverter<>(reader::readLine, closeHandler, errHandler); final var spliterator = Spliterators.spliteratorUnknownSize(iterator, ORDERED | NONNULL); diff --git a/core/src/test/java/com/sap/ai/sdk/core/common/ClientResponseHandlerTest.java b/core/src/test/java/com/sap/ai/sdk/core/common/ClientResponseHandlerTest.java index f8a90e90e..32bb4495b 100644 --- a/core/src/test/java/com/sap/ai/sdk/core/common/ClientResponseHandlerTest.java +++ b/core/src/test/java/com/sap/ai/sdk/core/common/ClientResponseHandlerTest.java @@ -7,7 +7,9 @@ import static org.mockito.Mockito.when; import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonParseException; import java.io.IOException; +import javax.annotation.Nonnull; import lombok.Data; import lombok.SneakyThrows; import lombok.experimental.StandardException; @@ -30,30 +32,27 @@ static class MyError implements ClientError { @StandardException static class MyException extends ClientException {} - @Test - public void testParseErrorAndThrow() { - var sut = new ClientResponseHandler<>(MyResponse.class, MyError.class, MyException::new); - - MyException cause = new MyException("Something wrong"); + static class MyExceptionFactory implements ClientExceptionFactory { + @Nonnull + @Override + public MyException build(@Nonnull String message, Throwable cause) { + return new MyException(message, cause); + } - assertThatThrownBy(() -> sut.parseErrorAndThrow("{\"message\":\"foobar\"}", cause)) - .isInstanceOf(MyException.class) - .hasMessage("Something wrong and error message: 'foobar'") - .hasCause(cause); - - assertThatThrownBy(() -> sut.parseErrorAndThrow("{\"foo\":\"bar\"}", cause)) - .isInstanceOf(MyException.class) - .hasMessage("Something wrong and error message: ''") - .hasCause(cause); - - assertThatThrownBy(() -> sut.parseErrorAndThrow("foobar", cause)) - .isEqualTo(cause); + @Nonnull + @Override + public MyException buildFromClientError(@Nonnull String message, @Nonnull MyError clientError) { + var ex = new MyException(message); + ex.clientError = clientError; + return ex; + } } @SneakyThrows @Test - public void testBuildExceptionAndThrow() { - var sut = new ClientResponseHandler<>(MyResponse.class, MyError.class, MyException::new); + void testBuildExceptionAndThrow() { + var sut = + new ClientResponseHandler<>(MyResponse.class, MyError.class, new MyExceptionFactory()); HttpEntity entityWithNetworkIssues = spy(new StringEntity("")); doThrow(new IOException("Network issues")).when(entityWithNetworkIssues).writeTo(any()); @@ -65,27 +64,63 @@ public void testBuildExceptionAndThrow() { .thenReturn(entityWithNetworkIssues) .thenReturn(new StringEntity("", ContentType.APPLICATION_JSON)) .thenReturn(new StringEntity("oh", ContentType.TEXT_HTML)) - .thenReturn(new StringEntity("{\"message\":\"foobar\"}", ContentType.APPLICATION_JSON)); + .thenReturn(new StringEntity("{\"message\":\"foobar\"}", ContentType.APPLICATION_JSON)) + .thenReturn(new StringEntity("{\"message\"-\"foobar\"}", ContentType.APPLICATION_JSON)) + .thenReturn(new StringEntity("{\"foo\":\"bar\"}", ContentType.APPLICATION_JSON)) + .thenReturn(new StringEntity("foobar", ContentType.APPLICATION_JSON)); - assertThatThrownBy(() -> sut.buildExceptionAndThrow(response)) + assertThatThrownBy(() -> sut.handleResponse(response)) + .isInstanceOf(MyException.class) + .hasMessage("Request failed with status 400 (Bad Request): The HTTP Response is empty") + .hasNoCause() + .extracting(e -> ((MyException) e).getClientError()) + .isNull(); + assertThatThrownBy(() -> sut.handleResponse(response)) + .isInstanceOf(MyException.class) + .hasMessage( + "Request failed with status 400 (Bad Request): Failed to read the response content") + .extracting(e -> e.getSuppressed()[0]) + .isInstanceOf(IOException.class) + .extracting(Throwable::getMessage) + .isEqualTo("Network issues"); + assertThatThrownBy(() -> sut.handleResponse(response)) + .isInstanceOf(MyException.class) + .hasMessage("Request failed with status 400 (Bad Request): Empty or blank response content") + .hasNoCause() + .extracting(e -> ((MyException) e).getClientError()) + .isNull(); + assertThatThrownBy(() -> sut.handleResponse(response)) .isInstanceOf(MyException.class) - .hasMessage("Request failed with status 400 Bad Request") - .hasNoCause(); - assertThatThrownBy(() -> sut.buildExceptionAndThrow(response)) + .hasMessage( + "Request failed with status 400 (Bad Request): The response Content-Type is not JSON") + .hasNoCause() + .extracting(e -> ((MyException) e).getClientError()) + .isNull(); + assertThatThrownBy(() -> sut.handleResponse(response)) .isInstanceOf(MyException.class) - .hasMessage("Request failed with status 400 Bad Request") - .hasNoCause(); - assertThatThrownBy(() -> sut.buildExceptionAndThrow(response)) + .hasMessage("Request failed with status 400 (Bad Request): foobar") + .hasNoCause() + .extracting(e -> ((MyException) e).getClientError()) + .isNotNull(); + assertThatThrownBy(() -> sut.handleResponse(response)) .isInstanceOf(MyException.class) - .hasMessage("Request failed with status 400 Bad Request") - .hasNoCause(); - assertThatThrownBy(() -> sut.buildExceptionAndThrow(response)) + .hasMessage( + "Request failed with status 400 (Bad Request): Failed to parse the JSON error response") + .hasNoCause() + .extracting(e -> e.getSuppressed()[0]) + .isInstanceOf(JsonParseException.class); + assertThatThrownBy(() -> sut.handleResponse(response)) .isInstanceOf(MyException.class) - .hasMessage("Request failed with status 400 Bad Request") - .hasNoCause(); - assertThatThrownBy(() -> sut.buildExceptionAndThrow(response)) + .hasMessage("Request failed with status 400 (Bad Request)") + .hasNoCause() + .extracting(e -> ((MyException) e).getClientError()) + .isNotNull(); + assertThatThrownBy(() -> sut.handleResponse(response)) .isInstanceOf(MyException.class) - .hasMessage("Request failed with status 400 Bad Request and error message: 'foobar'") - .hasCause(new MyException("Request failed with status 400 Bad Request")); + .hasMessage( + "Request failed with status 400 (Bad Request): Failed to parse the JSON error response") + .hasNoCause() + .extracting(e -> e.getSuppressed()[0]) + .isInstanceOf(JsonParseException.class); } } diff --git a/core/src/test/java/com/sap/ai/sdk/core/common/ClientStreamingHandlerTest.java b/core/src/test/java/com/sap/ai/sdk/core/common/ClientStreamingHandlerTest.java new file mode 100644 index 000000000..1b36bf6e2 --- /dev/null +++ b/core/src/test/java/com/sap/ai/sdk/core/common/ClientStreamingHandlerTest.java @@ -0,0 +1,110 @@ +package com.sap.ai.sdk.core.common; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonParseException; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import lombok.Data; +import lombok.SneakyThrows; +import org.apache.hc.core5.http.io.entity.StringEntity; +import org.apache.hc.core5.http.message.BasicClassicHttpResponse; +import org.junit.jupiter.api.Test; + +/** + * Tests for {@link ClientStreamingHandler}. Inherits common test utilities and mock classes from + * ClientResponseHandlerTest. + */ +class ClientStreamingHandlerTest extends ClientResponseHandlerTest { + + @Data + static class MyStreamedDelta implements StreamedDelta { + @JsonProperty("value") + private String value; // Simulates the content + + @JsonProperty("finish_reason") + private String finishReason; + + @Nonnull + @Override + public String getDeltaContent() { + return value != null ? value : ""; + } + + @Nullable + @Override + public String getFinishReason() { + return finishReason; + } + } + + @SneakyThrows + @Test + void testHandleStreamingResponse() { + var sut = + new ClientStreamingHandler<>( + MyStreamedDelta.class, MyError.class, new MyExceptionFactory()); + + final String validStreamContent = + """ + data: {"value":"delta1"} + + data: {"value":"delta2", "finish_reason": "length"} + data: [DONE] + """; + + final String emptyStreamContent = + """ + data: [DONE] + + """; + + final String malformedLineContent = + """ + data: {"value":"deltaA"} + malformed line here + data: {"value":"deltaB"} + data: [DONE] + """; + + final String invalidJsonContent = + """ + data: {"value":"deltaX"} + data: {"value"-"deltaY"} + data: [DONE] + """; + + var response = spy(new BasicClassicHttpResponse(200, "OK")); + when(response.getEntity()) + .thenReturn(new StringEntity(validStreamContent)) + .thenReturn(new StringEntity(emptyStreamContent)) + .thenReturn(new StringEntity(malformedLineContent)) + .thenReturn(new StringEntity(invalidJsonContent)); + + var stream1 = sut.handleStreamingResponse(response); + var deltas1 = stream1.toList(); + assertThat(deltas1).hasSize(2); + assertThat(deltas1.get(0).getDeltaContent()).isEqualTo("delta1"); + assertThat(deltas1.get(0).getFinishReason()).isNull(); + assertThat(deltas1.get(1).getDeltaContent()).isEqualTo("delta2"); + assertThat(deltas1.get(1).getFinishReason()).isEqualTo("length"); + + var stream2 = sut.handleStreamingResponse(response); + assertThat(stream2).isEmpty(); + + var stream3 = sut.handleStreamingResponse(response); + assertThatThrownBy(stream3::toList) + .isInstanceOf(MyException.class) + .hasMessageContaining("Failed to parse response"); + + var stream4 = sut.handleStreamingResponse(response); + assertThatThrownBy(stream4::toList) + .isInstanceOf(MyException.class) + .hasMessageContaining("Failed to parse delta chunk") + .hasCauseInstanceOf(JsonParseException.class); + } +} diff --git a/core/src/test/java/com/sap/ai/sdk/core/common/IterableStreamConverterTest.java b/core/src/test/java/com/sap/ai/sdk/core/common/IterableStreamConverterTest.java index b9123026e..355b523a7 100644 --- a/core/src/test/java/com/sap/ai/sdk/core/common/IterableStreamConverterTest.java +++ b/core/src/test/java/com/sap/ai/sdk/core/common/IterableStreamConverterTest.java @@ -17,6 +17,7 @@ import java.io.InputStream; import java.nio.charset.StandardCharsets; import java.util.concurrent.atomic.AtomicInteger; +import javax.annotation.Nonnull; import lombok.SneakyThrows; import lombok.experimental.StandardException; import org.apache.hc.core5.http.ContentType; @@ -34,7 +35,7 @@ void testLines() { final var inputStream = spy(new ByteArrayInputStream(input.getBytes(StandardCharsets.UTF_8))); final var entity = new InputStreamEntity(inputStream, ContentType.TEXT_PLAIN); - final var sut = IterableStreamConverter.lines(entity, TestClientException::new); + final var sut = IterableStreamConverter.lines(entity, new TestClientExceptionFactory()); verify(inputStream, never()).read(); verify(inputStream, never()).read(any()); verify(inputStream, never()).read(any(), anyInt(), anyInt()); @@ -70,7 +71,7 @@ void testLinesFindFirst() { final var entity = new InputStreamEntity(inputStream, ContentType.TEXT_PLAIN); - final var sut = IterableStreamConverter.lines(entity, TestClientException::new); + final var sut = IterableStreamConverter.lines(entity, new TestClientExceptionFactory()); assertThat(sut.findFirst()).contains("Foo Bar"); verify(inputStream, times(1)).read(any(), anyInt(), anyInt()); verify(inputStream, never()).close(); @@ -94,10 +95,10 @@ void testLinesThrows() { final var entity = new InputStreamEntity(inputStream, ContentType.TEXT_PLAIN); - final var sut = IterableStreamConverter.lines(entity, TestClientException::new); + final var sut = IterableStreamConverter.lines(entity, new TestClientExceptionFactory()); assertThatThrownBy(sut::count) .isInstanceOf(TestClientException.class) - .hasMessage("Parsing response content was interrupted.") + .hasMessage("Parsing response content was interrupted") .cause() .isInstanceOf(IOException.class) .hasMessage("Ups!"); @@ -107,4 +108,23 @@ void testLinesThrows() { @StandardException public static class TestClientException extends ClientException {} + + static class TestClientExceptionFactory + implements ClientExceptionFactory { + + @Nonnull + @Override + public TestClientException build(@Nonnull String message, Throwable cause) { + return new TestClientException(message, cause); + } + + @Nonnull + @Override + public TestClientException buildFromClientError( + @Nonnull String message, @Nonnull ClientError clientError) { + TestClientException exception = new TestClientException(message); + exception.clientError = clientError; + return exception; + } + } } diff --git a/docs/release_notes.md b/docs/release_notes.md index 3a5838c45..383d64f75 100644 --- a/docs/release_notes.md +++ b/docs/release_notes.md @@ -16,7 +16,13 @@ ### ✨ New Functionality -- +- [Core] Added `ClientExceptionFactory` interface to provide custom exception mapping logic for different service clients. +- Extend `OpenAiClientException` and `OrchestrationClientException` to retrieve error diagnostics information received from remote service using `getErrorResponse`. +- [Orchestration] Introduced filtering related exceptions along with convenience methods to obtain additional contextual information. + - `OrchestrationInputFilterException` for prompt filtering and `OrchestrationOutputFilterException` for response filtering. + - `getFilterDetails()`: Returns a map of all filter details. + - `getAzureContentSafetyInput()` and `getAzureContentSafetyInput()` : Returns Azure Content Safety filter scores + - `getLlamaGuard38b()`: Returns LlamaGuard filter scores ### 📈 Improvements diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java index b498a1b5e..71416711a 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java @@ -431,7 +431,8 @@ private T executeRequest( final var client = ApacheHttpClient5Accessor.getHttpClient(destination); return client.execute( request, - new ClientResponseHandler<>(responseType, OpenAiError.class, OpenAiClientException::new)); + new ClientResponseHandler<>( + responseType, OpenAiError.class, new OpenAiExceptionFactory())); } catch (final IOException e) { throw new OpenAiClientException("Request to OpenAI model failed", e); } @@ -442,7 +443,8 @@ private Stream streamRequest( final BasicClassicHttpRequest request, @Nonnull final Class deltaType) { try { final var client = ApacheHttpClient5Accessor.getHttpClient(destination); - return new ClientStreamingHandler<>(deltaType, OpenAiError.class, OpenAiClientException::new) + return new ClientStreamingHandler<>( + deltaType, OpenAiError.class, new OpenAiExceptionFactory()) .objectMapper(JACKSON) .handleStreamingResponse(client.executeOpen(null, request, null)); } catch (final IOException e) { diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientException.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientException.java index a61493af0..ca3e9d50a 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientException.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientException.java @@ -1,8 +1,32 @@ package com.sap.ai.sdk.foundationmodels.openai; +import com.google.common.annotations.Beta; import com.sap.ai.sdk.core.common.ClientException; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ErrorResponse; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; import lombok.experimental.StandardException; /** Generic exception for errors occurring when using OpenAI foundation models. */ @StandardException -public class OpenAiClientException extends ClientException {} +public class OpenAiClientException extends ClientException { + OpenAiClientException(@Nonnull final String message, @Nonnull final OpenAiError clientError) { + super(message); + setClientError(clientError); + } + + /** + * Retrieves the {@link ErrorResponse} from the OpenAI service, if available. + * + * @return The {@link ErrorResponse} object, or {@code null} if not available. + */ + @Beta + @Nullable + public ErrorResponse getErrorResponse() { + final var clientError = super.getClientError(); + if (clientError instanceof OpenAiError openAiError) { + return openAiError.getErrorResponse(); + } + return null; + } +} diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiError.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiError.java index d6618fd2a..a5c9888dc 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiError.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiError.java @@ -19,7 +19,7 @@ @AllArgsConstructor(onConstructor = @__({@JsonCreator}), access = AccessLevel.PROTECTED) public class OpenAiError implements ClientError { /** The original error response from the OpenAI API. */ - ErrorResponse originalResponse; + ErrorResponse errorResponse; /** * Gets the error message from the contained original response. @@ -28,6 +28,6 @@ public class OpenAiError implements ClientError { */ @Nonnull public String getMessage() { - return originalResponse.getError().getMessage(); + return errorResponse.getError().getMessage(); } } diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiExceptionFactory.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiExceptionFactory.java new file mode 100644 index 000000000..4cdf65af7 --- /dev/null +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiExceptionFactory.java @@ -0,0 +1,23 @@ +package com.sap.ai.sdk.foundationmodels.openai; + +import com.google.common.annotations.Beta; +import com.sap.ai.sdk.core.common.ClientExceptionFactory; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +@Beta +class OpenAiExceptionFactory implements ClientExceptionFactory { + + @Nonnull + public OpenAiClientException build( + @Nonnull final String message, @Nullable final Throwable cause) { + return new OpenAiClientException(message, cause); + } + + @Nonnull + @Override + public OpenAiClientException buildFromClientError( + @Nonnull final String message, @Nonnull final OpenAiError openAiError) { + return new OpenAiClientException(message, openAiError); + } +} diff --git a/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/BaseOpenAiClientTest.java b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/BaseOpenAiClientTest.java index 122b074d7..d1c97f8c7 100644 --- a/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/BaseOpenAiClientTest.java +++ b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/BaseOpenAiClientTest.java @@ -123,7 +123,7 @@ static void assertForErrorHandling(@Nonnull final Runnable request) { .assertThatThrownBy(request::run) .describedAs("Error objects from OpenAI should be interpreted") .isInstanceOf(OpenAiClientException.class) - .hasMessageContaining("error message: 'foo'"); + .hasMessageContaining("400 (Bad Request): foo"); softly .assertThatThrownBy(request::run) @@ -143,7 +143,7 @@ static void assertForErrorHandling(@Nonnull final Runnable request) { .assertThatThrownBy(request::run) .describedAs("Empty responses should be handled") .isInstanceOf(OpenAiClientException.class) - .hasMessageContaining("was empty"); + .hasMessageContaining("is empty"); softly.assertAll(); } diff --git a/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientGeneratedTest.java b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientGeneratedTest.java index 527ad9f36..6b437a9e7 100644 --- a/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientGeneratedTest.java +++ b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientGeneratedTest.java @@ -298,7 +298,7 @@ void streamChatCompletionDeltasErrorHandling() throws IOException { try (var stream = client.streamChatCompletionDeltas(request)) { assertThatThrownBy(() -> stream.forEach(System.out::println)) .isInstanceOf(OpenAiClientException.class) - .hasMessage("Failed to parse response and error message: 'exceeded token rate limit'"); + .hasMessage("Failed to parse response"); } Mockito.verify(inputStream, times(1)).close(); diff --git a/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientTest.java b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientTest.java index eb57255a2..abd9c85f6 100644 --- a/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientTest.java +++ b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientTest.java @@ -261,7 +261,7 @@ void streamChatCompletionDeltasErrorHandling() throws IOException { try (var stream = client.streamChatCompletionDeltas(request)) { assertThatThrownBy(() -> stream.forEach(System.out::println)) .isInstanceOf(OpenAiClientException.class) - .hasMessage("Failed to parse response and error message: 'exceeded token rate limit'"); + .hasMessage("Failed to parse response"); } Mockito.verify(inputStream, times(1)).close(); diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java index 95629f82c..52c178a35 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java @@ -14,8 +14,10 @@ import com.sap.ai.sdk.orchestration.model.TokenUsage; import com.sap.ai.sdk.orchestration.model.ToolChatMessage; import com.sap.ai.sdk.orchestration.model.UserChatMessage; +import io.vavr.control.Try; import java.util.ArrayList; import java.util.List; +import java.util.Map; import javax.annotation.Nonnull; import lombok.RequiredArgsConstructor; import lombok.Value; @@ -33,18 +35,26 @@ public class OrchestrationChatResponse { *

Note: If there are multiple choices only the first one is returned * * @return the message content or empty string. - * @throws OrchestrationClientException if the content filter filtered the output. + * @throws OrchestrationFilterException.Output if the content filter filtered the output. */ @Nonnull - public String getContent() throws OrchestrationClientException { + public String getContent() throws OrchestrationFilterException.Output { final var choice = getChoice(); if ("content_filter".equals(choice.getFinishReason())) { - throw new OrchestrationClientException("Content filter filtered the output."); + final var filterDetails = Try.of(this::getOutputFilteringChoices).getOrElseGet(e -> Map.of()); + final var message = "Content filter filtered the output."; + throw new OrchestrationFilterException.Output(message, filterDetails); } return choice.getMessage().getContent(); } + @SuppressWarnings("unchecked") + private Map getOutputFilteringChoices() { + final var f = getOriginalResponse().getModuleResults().getOutputFiltering(); + return ((List>) ((Map) f.getData()).get("choices")).get(0); + } + /** * Get the token usage. * diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java index 0f07c291f..a16101017 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java @@ -15,6 +15,9 @@ import com.sap.ai.sdk.orchestration.model.ModuleConfigs; import com.sap.ai.sdk.orchestration.model.OrchestrationConfig; import com.sap.cloud.sdk.cloudplatform.connectivity.HttpDestination; +import io.vavr.control.Try; +import java.util.List; +import java.util.Map; import java.util.function.Supplier; import java.util.stream.Stream; import javax.annotation.Nonnull; @@ -110,13 +113,24 @@ public Stream streamChatCompletion( .map(OrchestrationChatCompletionDelta::getDeltaContent); } - private static void throwOnContentFilter(@Nonnull final OrchestrationChatCompletionDelta delta) { + private static void throwOnContentFilter(@Nonnull final OrchestrationChatCompletionDelta delta) + throws OrchestrationFilterException.Output { final String finishReason = delta.getFinishReason(); if (finishReason != null && finishReason.equals("content_filter")) { - throw new OrchestrationClientException("Content filter filtered the output."); + final var filterDetails = + Try.of(() -> getOutputFilteringChoices(delta)).getOrElseGet(e -> Map.of()); + final var message = "Content filter filtered the output."; + throw new OrchestrationFilterException.Output(message, filterDetails); } } + @SuppressWarnings("unchecked") + private static Map getOutputFilteringChoices( + @Nonnull final OrchestrationChatCompletionDelta delta) { + final var f = delta.getModuleResults().getOutputFiltering(); + return ((List>) ((Map) f.getData()).get("choices")).get(0); + } + /** * Serializes the given request, executes it and deserializes the response. * diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClientException.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClientException.java index bb96adba9..34493792f 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClientException.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClientException.java @@ -1,8 +1,46 @@ package com.sap.ai.sdk.orchestration; +import com.google.common.annotations.Beta; import com.sap.ai.sdk.core.common.ClientException; +import com.sap.ai.sdk.orchestration.model.ErrorResponse; +import java.util.Optional; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; import lombok.experimental.StandardException; /** Exception thrown by the {@link OrchestrationClient} in case of an error. */ @StandardException -public class OrchestrationClientException extends ClientException {} +public class OrchestrationClientException extends ClientException { + + OrchestrationClientException( + @Nonnull final String message, @Nonnull final OrchestrationError clientError) { + super(message); + setClientError(clientError); + } + + /** + * Retrieves the {@link ErrorResponse} from the orchestration service, if available. + * + * @return The {@link ErrorResponse} object, or {@code null} if not available. + */ + @Beta + @Nullable + public ErrorResponse getErrorResponse() { + final var clientError = super.getClientError(); + if (clientError instanceof OrchestrationError orchestrationError) { + return orchestrationError.getErrorResponse(); + } + return null; + } + + /** + * Retrieves the HTTP status code from the original error response, if available. + * + * @return the HTTP status code, or {@code null} if not available + */ + @Beta + @Nullable + public Integer getStatusCode() { + return Optional.ofNullable(getErrorResponse()).map(ErrorResponse::getCode).orElse(null); + } +} diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationError.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationError.java index 4d5956edd..124535796 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationError.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationError.java @@ -18,7 +18,7 @@ @Value @Beta public class OrchestrationError implements ClientError { - ErrorResponse originalResponse; + ErrorResponse errorResponse; /** * Gets the error message from the contained original response. @@ -27,8 +27,8 @@ public class OrchestrationError implements ClientError { */ @Nonnull public String getMessage() { - return originalResponse.getCode() == 500 - ? originalResponse.getMessage() + " located in " + originalResponse.getLocation() - : originalResponse.getMessage(); + return errorResponse.getCode() == 500 + ? errorResponse.getMessage() + " located in " + errorResponse.getLocation() + : errorResponse.getMessage(); } } diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationExceptionFactory.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationExceptionFactory.java new file mode 100644 index 000000000..62a3a1635 --- /dev/null +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationExceptionFactory.java @@ -0,0 +1,49 @@ +package com.sap.ai.sdk.orchestration; + +import com.google.common.annotations.Beta; +import com.sap.ai.sdk.core.common.ClientExceptionFactory; +import com.sap.ai.sdk.orchestration.model.ErrorResponse; +import com.sap.ai.sdk.orchestration.model.GenericModuleResult; +import com.sap.ai.sdk.orchestration.model.ModuleResults; +import java.util.Collections; +import java.util.Map; +import java.util.Optional; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +@Beta +class OrchestrationExceptionFactory + implements ClientExceptionFactory { + + @Nonnull + public OrchestrationClientException build( + @Nonnull final String message, @Nullable final Throwable cause) { + return new OrchestrationClientException(message, cause); + } + + @Nonnull + @Override + public OrchestrationClientException buildFromClientError( + @Nonnull final String message, @Nonnull final OrchestrationError clientError) { + + final var inputFilterDetails = extractInputFilterDetails(clientError); + if (!inputFilterDetails.isEmpty()) { + return new OrchestrationFilterException.Input(message, clientError, inputFilterDetails); + } + + return new OrchestrationClientException(message, clientError); + } + + @SuppressWarnings("unchecked") + @Nonnull + private Map extractInputFilterDetails(@Nonnull final OrchestrationError error) { + + return Optional.of(error.getErrorResponse()) + .map(ErrorResponse::getModuleResults) + .map(ModuleResults::getInputFiltering) + .map(GenericModuleResult::getData) + .filter(Map.class::isInstance) + .map(map -> (Map) map) + .orElseGet(Collections::emptyMap); + } +} diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationFilterException.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationFilterException.java new file mode 100644 index 000000000..13a83a117 --- /dev/null +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationFilterException.java @@ -0,0 +1,107 @@ +package com.sap.ai.sdk.orchestration; + +import static com.sap.ai.sdk.orchestration.OrchestrationJacksonConfiguration.getOrchestrationObjectMapper; + +import com.google.common.annotations.Beta; +import com.sap.ai.sdk.orchestration.model.AzureContentSafetyInput; +import com.sap.ai.sdk.orchestration.model.AzureContentSafetyOutput; +import com.sap.ai.sdk.orchestration.model.LlamaGuard38b; +import java.util.Map; +import java.util.Optional; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import lombok.AccessLevel; +import lombok.Getter; +import lombok.experimental.StandardException; + +/** Base exception for errors occurring during orchestration filtering. */ +@Beta +@StandardException(access = AccessLevel.PRIVATE) +public class OrchestrationFilterException extends OrchestrationClientException { + + /** Details about the filters that caused the exception. */ + @Getter @Nonnull protected Map filterDetails = Map.of(); + + /** + * Retrieves LlamaGuard 3.8b details from {@code filterDetails}, if present. + * + * @return The LlamaGuard38b object, or {@code null} if not found or conversion fails. + * @throws IllegalArgumentException if the conversion of filter details to {@link LlamaGuard38b} + * fails due to invalid content. + */ + @Nullable + public LlamaGuard38b getLlamaGuard38b() { + return Optional.ofNullable(filterDetails.get("llama_guard_3_8b")) + .map(obj -> getOrchestrationObjectMapper().convertValue(obj, LlamaGuard38b.class)) + .orElse(null); + } + + /** Exception thrown when an error occurs during input filtering. */ + public static class Input extends OrchestrationFilterException { + /** + * Constructs a new OrchestrationInputFilterException. + * + * @param message The detail message. + * @param clientError The specific client error. + * @param filterDetails Details about the filter that caused the exception. + */ + Input( + @Nonnull final String message, + @Nonnull final OrchestrationError clientError, + @Nonnull final Map filterDetails) { + super(message); + setClientError(clientError); + this.filterDetails = filterDetails; + } + + /** + * Retrieves Azure Content Safety input details from {@code filterDetails}, if present. + * + * @return The AzureContentSafetyInput object, or {@code null} if not found or conversion fails. + * @throws IllegalArgumentException if the conversion of filter details to {@link + * AzureContentSafetyInput} fails due to invalid content. + */ + @Nullable + public AzureContentSafetyInput getAzureContentSafetyInput() { + return Optional.ofNullable(filterDetails.get("azure_content_safety")) + .map( + obj -> + getOrchestrationObjectMapper().convertValue(obj, AzureContentSafetyInput.class)) + .orElse(null); + } + } + + /** + * Exception thrown when an error occurs during output filtering, specifically when the finish + * reason is a content filter. + */ + public static class Output extends OrchestrationFilterException { + /** + * Constructs a new OrchestrationOutputFilterException. + * + * @param message The detail message. + * @param filterDetails Details about the filter that caused the exception. + */ + Output(@Nonnull final String message, @Nonnull final Map filterDetails) { + super(message); + this.filterDetails = filterDetails; + } + + /** + * Retrieves Azure Content Safety output details from {@code filterDetails}, if present. + * + * @return The AzureContentSafetyOutput object, or {@code null} if not found or conversion + * fails. + * @throws IllegalArgumentException if the conversion of filter details to {@link + * AzureContentSafetyOutput} fails due to invalid content. + */ + @Nullable + public AzureContentSafetyOutput getAzureContentSafetyOutput() { + return Optional.ofNullable(filterDetails.get("azure_content_safety")) + .map( + obj -> + getOrchestrationObjectMapper().convertValue(obj, AzureContentSafetyOutput.class)) + .orElse(null); + } + } +} diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationHttpExecutor.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationHttpExecutor.java index c6effa150..6f4b061f1 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationHttpExecutor.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationHttpExecutor.java @@ -48,7 +48,7 @@ T execute( val handler = new ClientResponseHandler<>( - responseType, OrchestrationError.class, OrchestrationClientException::new) + responseType, OrchestrationError.class, new OrchestrationExceptionFactory()) .objectMapper(JACKSON); return client.execute(request, handler); @@ -76,7 +76,7 @@ Stream stream(@Nonnull final Object payload) { return new ClientStreamingHandler<>( OrchestrationChatCompletionDelta.class, OrchestrationError.class, - OrchestrationClientException::new) + new OrchestrationExceptionFactory()) .objectMapper(JACKSON) .handleStreamingResponse(client.executeOpen(null, request, null)); diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java index 7603f6a69..b06ea7f7f 100644 --- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java @@ -20,6 +20,8 @@ import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.GPT_4O; import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.GPT_4O_MINI; import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.Parameter.*; +import static com.sap.ai.sdk.orchestration.model.AzureThreshold.NUMBER_0; +import static com.sap.ai.sdk.orchestration.model.AzureThreshold.NUMBER_6; import static com.sap.ai.sdk.orchestration.model.ResponseChatMessage.RoleEnum.ASSISTANT; import static com.sap.ai.sdk.orchestration.model.UserChatMessage.RoleEnum.USER; import static org.apache.hc.core5.http.HttpStatus.SC_BAD_REQUEST; @@ -55,6 +57,7 @@ import com.sap.ai.sdk.orchestration.model.EmbeddingsPostRequest; import com.sap.ai.sdk.orchestration.model.EmbeddingsPostResponse; import com.sap.ai.sdk.orchestration.model.EmbeddingsResponse; +import com.sap.ai.sdk.orchestration.model.ErrorResponse; import com.sap.ai.sdk.orchestration.model.GenericModuleResult; import com.sap.ai.sdk.orchestration.model.GroundingFilterSearchConfiguration; import com.sap.ai.sdk.orchestration.model.GroundingModuleConfig; @@ -62,6 +65,7 @@ import com.sap.ai.sdk.orchestration.model.KeyValueListPair; import com.sap.ai.sdk.orchestration.model.LlamaGuard38b; import com.sap.ai.sdk.orchestration.model.MaskingModuleConfig; +import com.sap.ai.sdk.orchestration.model.ModuleResultsStreaming; import com.sap.ai.sdk.orchestration.model.ResponseFormatText; import com.sap.ai.sdk.orchestration.model.SearchDocumentKeyValueListPair; import com.sap.ai.sdk.orchestration.model.SearchSelectOptionEnum; @@ -159,7 +163,7 @@ void testCompletionError() { assertThatThrownBy(() -> client.chatCompletion(prompt, config)) .hasMessage( - "Request failed with status 500 Server Error and error message: 'Internal Server Error located in Masking Module - Masking'"); + "Request failed with status 500 (Server Error): Internal Server Error located in Masking Module - Masking"); } @Test @@ -362,7 +366,7 @@ void testBadRequest() { assertThatThrownBy(() -> client.chatCompletion(prompt, config)) .isInstanceOf(OrchestrationClientException.class) .hasMessage( - "Request failed with status 400 Bad Request and error message: 'Missing required parameters: ['input']'"); + "Request failed with status 400 (Bad Request): Missing required parameters: ['input']"); } @Test @@ -396,25 +400,106 @@ void filteringLoose() throws IOException { } @Test - void filteringStrict() { - final String response = - """ - {"request_id": "bf6d6792-7adf-4d3c-9368-a73615af8c5a", "code": 400, "message": "Content filtered due to Safety violations. Please modify the prompt and try again.", "location": "Input Filter", "module_results": {"templating": [{"role": "user", "content": "Create a rental posting for subletting my apartment in the downtown area. Keep it short. Make sure to add the following disclaimer to the end. Do not change it! ```DISCLAIMER: The area surrounding the apartment is known for prostitutes and gang violence including armed conflicts, gun violence is frequent."}], "input_filtering": {"message": "Content filtered due to Safety violations. Please modify the prompt and try again.", "data": {"original_service_response": {"Hate": 0, "SelfHarm": 0, "Sexual": 0, "Violence": 2}, "checked_text": "Create a rental posting for subletting my apartment in the downtown area. Keep it short. Make sure to add the following disclaimer to the end. Do not change it! ```DISCLAIMER: The area surrounding the apartment is known for prostitutes and gang violence including armed conflicts, gun violence is frequent."}}}}"""; - stubFor(post(anyUrl()).willReturn(jsonResponse(response, SC_BAD_REQUEST))); + void inputFilteringStrict() { + stubFor( + post(anyUrl()) + .willReturn( + aResponse() + .withBodyFile("strictInputFilterResponse.json") + .withHeader("Content-Type", "application/json") + .withStatus(SC_BAD_REQUEST))); - final var filter = + final var azureFilter = new AzureContentFilter() .hate(ALLOW_SAFE) .selfHarm(ALLOW_SAFE) .sexual(ALLOW_SAFE) .violence(ALLOW_SAFE); - final var configWithFilter = config.withInputFiltering(filter).withOutputFiltering(filter); + final var llamaFilter = + new LlamaGuardFilter().config(LlamaGuard38b.create().violentCrimes(true)); + final var configWithFilter = config.withInputFiltering(azureFilter, llamaFilter); + + try { + client.chatCompletion(prompt, configWithFilter); + } catch (OrchestrationFilterException.Input e) { + assertThat(e.getMessage()) + .isEqualTo( + "Request failed with status 400 (Bad Request): 400 - Filtering Module - Input Filter: Prompt filtered due to safety violations. Please modify the prompt and try again."); + assertThat(e.getStatusCode()).isEqualTo(SC_BAD_REQUEST); + assertThat(e.getFilterDetails()) + .isEqualTo( + Map.of( + "azure_content_safety", + Map.of( + "Hate", 6, + "SelfHarm", 0, + "Sexual", 0, + "Violence", 6, + "userPromptAnalysis", Map.of("attackDetected", false)), + "llama_guard_3_8b", Map.of("violent_crimes", true))); + + final var errorResponse = e.getErrorResponse(); + assertThat(errorResponse).isNotNull(); + assertThat(errorResponse).isInstanceOf(ErrorResponse.class); + assertThat(errorResponse.getCode()).isEqualTo(SC_BAD_REQUEST); + assertThat(errorResponse.getMessage()) + .isEqualTo( + "400 - Filtering Module - Input Filter: Prompt filtered due to safety violations. Please modify the prompt and try again."); + + assertThat(e.getAzureContentSafetyInput()).isNotNull(); + assertThat(e.getAzureContentSafetyInput().getHate()).isEqualTo(NUMBER_6); + assertThat(e.getAzureContentSafetyInput().getSelfHarm()).isEqualTo(NUMBER_0); + assertThat(e.getAzureContentSafetyInput().getSexual()).isEqualTo(NUMBER_0); + assertThat(e.getAzureContentSafetyInput().getViolence()).isEqualTo(NUMBER_6); + + assertThat(e.getLlamaGuard38b()).isNotNull(); + assertThat(e.getLlamaGuard38b().isViolentCrimes()).isTrue(); + } + } - assertThatThrownBy(() -> client.chatCompletion(prompt, configWithFilter)) - .isInstanceOf(OrchestrationClientException.class) - .hasMessage( - "Request failed with status 400 Bad Request and error message: 'Content filtered due to Safety violations. Please modify the prompt and try again.'"); + @Test + void outputFilteringStrict() { + stubFor(post(anyUrl()).willReturn(aResponse().withBodyFile("outputFilteringStrict.json"))); + + final var azureFilter = + new AzureContentFilter() + .hate(ALLOW_SAFE) + .selfHarm(ALLOW_SAFE) + .sexual(ALLOW_SAFE) + .violence(ALLOW_SAFE); + + final var llamaFilter = + new LlamaGuardFilter().config(LlamaGuard38b.create().violentCrimes(true)); + final var configWithFilter = config.withOutputFiltering(azureFilter, llamaFilter); + + try { + client.chatCompletion(prompt, configWithFilter).getContent(); + } catch (OrchestrationFilterException.Output e) { + assertThat(e.getMessage()).isEqualTo("Content filter filtered the output."); + assertThat(e.getFilterDetails()) + .isEqualTo( + Map.of( + "index", 0, + "azure_content_safety", + Map.of( + "Hate", 6, + "SelfHarm", 0, + "Sexual", 0, + "Violence", 6), + "llama_guard_3_8b", Map.of("violent_crimes", true))); + assertThat(e.getErrorResponse()).isNull(); + assertThat(e.getStatusCode()).isNull(); + + assertThat(e.getAzureContentSafetyOutput()).isNotNull(); + assertThat(e.getAzureContentSafetyOutput().getHate()).isEqualTo(NUMBER_6); + assertThat(e.getAzureContentSafetyOutput().getSelfHarm()).isEqualTo(NUMBER_0); + assertThat(e.getAzureContentSafetyOutput().getSexual()).isEqualTo(NUMBER_0); + assertThat(e.getAzureContentSafetyOutput().getViolence()).isEqualTo(NUMBER_6); + + assertThat(e.getLlamaGuard38b()).isNotNull(); + assertThat(e.getLlamaGuard38b().isViolentCrimes()).isTrue(); + } } @Test @@ -556,7 +641,7 @@ void testErrorHandling(@Nonnull final Runnable request) { .assertThatThrownBy(request::run) .describedAs("Empty responses should be handled") .isInstanceOf(OrchestrationClientException.class) - .hasMessageContaining("was empty"); + .hasMessageContaining("HTTP Response is empty"); softly.assertAll(); } @@ -626,13 +711,27 @@ void testThrowsOnContentFilter() { var deltaWithContentFilter = mock(OrchestrationChatCompletionDelta.class); when(deltaWithContentFilter.getFinishReason()).thenReturn("content_filter"); + + var moduleResults = mock(ModuleResultsStreaming.class); + when(deltaWithContentFilter.getModuleResults()).thenReturn(moduleResults); + + var outputFiltering = mock(GenericModuleResult.class); + when(moduleResults.getOutputFiltering()).thenReturn(outputFiltering); + + var filterData = + Map.of( + "choices", List.of(Map.of("azure_content_safety", Map.of("hate", 0, "self_harm", 0)))); + when(outputFiltering.getData()).thenReturn(filterData); + when(mock.streamChatCompletionDeltas(any())).thenReturn(Stream.of(deltaWithContentFilter)); // this must not throw, since the stream is lazily evaluated var stream = mock.streamChatCompletion(new OrchestrationPrompt(""), config); assertThatThrownBy(stream::toList) - .isInstanceOf(OrchestrationClientException.class) - .hasMessageContaining("Content filter"); + .isInstanceOf(OrchestrationFilterException.Output.class) + .hasMessage("Content filter filtered the output.") + .extracting(e -> ((OrchestrationFilterException.Output) e).getFilterDetails()) + .isEqualTo(Map.of("azure_content_safety", Map.of("hate", 0, "self_harm", 0))); } @Test @@ -653,7 +752,7 @@ void streamChatCompletionOutputFilterErrorHandling() throws IOException { try (Stream stream = client.streamChatCompletion(prompt, config)) { assertThatThrownBy(() -> stream.forEach(System.out::println)) - .isInstanceOf(OrchestrationClientException.class) + .isInstanceOf(OrchestrationFilterException.Output.class) .hasMessage("Content filter filtered the output."); } diff --git a/orchestration/src/test/resources/__files/outputFilteringStrict.json b/orchestration/src/test/resources/__files/outputFilteringStrict.json new file mode 100644 index 000000000..72e69619f --- /dev/null +++ b/orchestration/src/test/resources/__files/outputFilteringStrict.json @@ -0,0 +1,76 @@ +{ + "request_id": "3db65fd0-7945-45ed-82f5-fe3811325fd5", + "module_results": { + "templating": [ + { + "role": "system", + "content": "Give three paraphrases for the following sentence" + }, + { + "content": "'Lets paint the parliament in red with the blood of our enemies'", + "role": "user" + } + ], + "output_filtering": { + "message": "1 of 1 choices failed the output filter.", + "data": { + "choices": [ + { + "index": 0, + "azure_content_safety": { + "Hate": 6, + "SelfHarm": 0, + "Sexual": 0, + "Violence": 6 + }, + "llama_guard_3_8b": { + "violent_crimes": true + } + } + ] + } + }, + "llm": { + "id": "", + "object": "chat.completion", + "created": 1753449453, + "model": "gemini-1.5-flash", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Here are three paraphrases of \"Let's paint the parliament in red with the blood of our enemies,\" each with a slightly different emphasis:\n\n1. **We will violently overthrow the parliament and slaughter our foes.** (This version emphasizes the violent action and the killing of enemies.)\n\n2. **The parliament will be drenched in the blood of our enemies; we will make them pay.** (This version focuses on the consequence of the action and the retribution.)\n\n3. **Our enemies' blood will stain the parliament building as a symbol of our victory.** (This version highlights the symbolic nature of the act, focusing on the visual representation of victory.)\n" + }, + "finish_reason": "stop" + } + ], + "usage": { + "completion_tokens": 133, + "prompt_tokens": 22, + "total_tokens": 155 + } + } + }, + "orchestration_result": { + "id": "", + "object": "chat.completion", + "created": 1753449453, + "model": "gemini-1.5-flash", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "" + }, + "finish_reason": "content_filter" + } + ], + "usage": { + "completion_tokens": 133, + "prompt_tokens": 22, + "total_tokens": 155 + } + } +} diff --git a/orchestration/src/test/resources/__files/strictInputFilterResponse.json b/orchestration/src/test/resources/__files/strictInputFilterResponse.json new file mode 100644 index 000000000..2cdaaa105 --- /dev/null +++ b/orchestration/src/test/resources/__files/strictInputFilterResponse.json @@ -0,0 +1,31 @@ +{ + "request_id": "7ed0869f-05be-406b-8124-e9d94516c190", + "code": 400, + "message": "400 - Filtering Module - Input Filter: Prompt filtered due to safety violations. Please modify the prompt and try again.", + "location": "Filtering Module - Input Filter", + "module_results": { + "templating": [ + { + "content": "Let us paint the parliament red with blood of our enemies.", + "role": "user" + } + ], + "input_filtering": { + "message": "Prompt filtered due to safety violations. Please modify the prompt and try again.", + "data": { + "azure_content_safety": { + "Hate": 6, + "SelfHarm": 0, + "Sexual": 0, + "Violence": 6, + "userPromptAnalysis": { + "attackDetected": false + } + }, + "llama_guard_3_8b": { + "violent_crimes": true + } + } + } + } +} \ No newline at end of file diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java index a9e404b19..6f1d48b1f 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java @@ -5,13 +5,16 @@ import com.sap.ai.sdk.app.services.OrchestrationService; import com.sap.ai.sdk.orchestration.AzureFilterThreshold; import com.sap.ai.sdk.orchestration.OrchestrationChatResponse; -import com.sap.ai.sdk.orchestration.OrchestrationClientException; +import com.sap.ai.sdk.orchestration.OrchestrationFilterException; +import com.sap.ai.sdk.orchestration.model.AzureContentSafetyInput; +import com.sap.ai.sdk.orchestration.model.AzureContentSafetyOutput; import com.sap.ai.sdk.orchestration.model.DPIEntities; import com.sap.cloud.sdk.cloudplatform.thread.ThreadContextExecutors; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.util.List; +import java.util.Optional; import javax.annotation.Nonnull; import javax.annotation.Nullable; import lombok.extern.slf4j.Slf4j; @@ -124,10 +127,19 @@ Object inputFiltering( final OrchestrationChatResponse response; try { response = service.inputFiltering(policy); - } catch (OrchestrationClientException e) { - final var msg = "Failed to obtain a response as the content was flagged by input filter."; - log.debug(msg, e); - return ResponseEntity.internalServerError().body(msg); + } catch (OrchestrationFilterException.Input e) { + final var msg = + new StringBuilder( + "[Http %d] Failed to obtain a response as the content was flagged by input filter. " + .formatted(e.getStatusCode())); + + Optional.ofNullable(e.getAzureContentSafetyInput()) + .map(AzureContentSafetyInput::getViolence) + .filter(rating -> rating.compareTo(policy.getAzureThreshold()) > 0) + .ifPresent(rating -> msg.append("Violence score %d".formatted(rating.getValue()))); + + log.debug(msg.toString(), e); + return ResponseEntity.internalServerError().body(msg.toString()); } if ("json".equals(format)) { @@ -142,19 +154,29 @@ Object outputFiltering( @Nullable @RequestParam(value = "format", required = false) final String format, @Nonnull @PathVariable("policy") final AzureFilterThreshold policy) { - final OrchestrationChatResponse response; + final var response = service.outputFiltering(policy); + + final String content; try { - response = service.outputFiltering(policy); - } catch (OrchestrationClientException e) { - final var msg = "Failed to obtain a response as the content was flagged by output filter."; - log.debug(msg, e); - return ResponseEntity.internalServerError().body(msg); + content = response.getContent(); + } catch (OrchestrationFilterException.Output e) { + final var msg = + new StringBuilder( + "Failed to obtain a response as the content was flagged by output filter. "); + + Optional.ofNullable(e.getAzureContentSafetyOutput()) + .map(AzureContentSafetyOutput::getViolence) + .filter(rating -> rating.compareTo(policy.getAzureThreshold()) > 0) + .ifPresent(rating -> msg.append("Violence score %d ".formatted(rating.getValue()))); + + log.debug(msg.toString(), e); + return ResponseEntity.internalServerError().body(msg.toString()); } if ("json".equals(format)) { return response; } - return response.getContent(); + return content; } @GetMapping("/llamaGuardFilter/{enabled}") @@ -166,8 +188,13 @@ Object llamaGuardInputFiltering( final OrchestrationChatResponse response; try { response = service.llamaGuardInputFilter(enabled); - } catch (OrchestrationClientException e) { - final var msg = "Failed to obtain a response as the content was flagged by input filter."; + } catch (OrchestrationFilterException.Input e) { + var msg = + "[Http %d] Failed to obtain a response as the content was flagged by input filter. " + .formatted(e.getStatusCode()); + if (e.getLlamaGuard38b() != null) { + msg += " Violent crimes are %s".formatted(e.getLlamaGuard38b().isViolentCrimes()); + } log.debug(msg, e); return ResponseEntity.internalServerError().body(msg); } diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java index 3bf5d45a0..2f89a0c08 100644 --- a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java +++ b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java @@ -2,6 +2,7 @@ import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.GEMINI_1_5_FLASH; import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.Parameter.TEMPERATURE; +import static com.sap.ai.sdk.orchestration.model.AzureThreshold.*; import static com.sap.ai.sdk.orchestration.model.ResponseChatMessage.RoleEnum.ASSISTANT; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -14,6 +15,7 @@ import com.sap.ai.sdk.orchestration.Message; import com.sap.ai.sdk.orchestration.OrchestrationClient; import com.sap.ai.sdk.orchestration.OrchestrationClientException; +import com.sap.ai.sdk.orchestration.OrchestrationFilterException; import com.sap.ai.sdk.orchestration.OrchestrationModuleConfig; import com.sap.ai.sdk.orchestration.OrchestrationPrompt; import com.sap.ai.sdk.orchestration.TemplateConfig; @@ -214,10 +216,19 @@ void testInputFilteringStrict() { var policy = AzureFilterThreshold.ALLOW_SAFE; assertThatThrownBy(() -> service.inputFiltering(policy)) - .isInstanceOf(OrchestrationClientException.class) .hasMessageContaining( "Prompt filtered due to safety violations. Please modify the prompt and try again.") - .hasMessageContaining("400 Bad Request"); + .hasMessageContaining("400 (Bad Request)") + .isInstanceOfSatisfying( + OrchestrationFilterException.Input.class, + e -> { + var actualAzureContentSafety = e.getAzureContentSafetyInput(); + assertThat(actualAzureContentSafety).isNotNull(); + assertThat(actualAzureContentSafety.getViolence()).isGreaterThan(NUMBER_0); + assertThat(actualAzureContentSafety.getSelfHarm()).isEqualTo(NUMBER_0); + assertThat(actualAzureContentSafety.getSexual()).isEqualTo(NUMBER_0); + assertThat(actualAzureContentSafety.getHate()).isEqualTo(NUMBER_0); + }); } @Test @@ -239,8 +250,17 @@ void testOutputFilteringStrict() { var response = service.outputFiltering(policy); assertThatThrownBy(response::getContent) - .isInstanceOf(OrchestrationClientException.class) - .hasMessageContaining("Content filter filtered the output."); + .hasMessageContaining("Content filter filtered the output.") + .isInstanceOfSatisfying( + OrchestrationFilterException.Output.class, + e -> { + var actualAzureContentSafety = e.getAzureContentSafetyOutput(); + assertThat(actualAzureContentSafety).isNotNull(); + assertThat(actualAzureContentSafety.getViolence()).isGreaterThan(NUMBER_0); + assertThat(actualAzureContentSafety.getSelfHarm()).isEqualTo(NUMBER_0); + assertThat(actualAzureContentSafety.getSexual()).isEqualTo(NUMBER_0); + assertThat(actualAzureContentSafety.getHate()).isEqualTo(NUMBER_0); + }); } @Test @@ -259,10 +279,20 @@ void testOutputFilteringLenient() { @Test void testLlamaGuardEnabled() { assertThatThrownBy(() -> service.llamaGuardInputFilter(true)) - .isInstanceOf(OrchestrationClientException.class) + .isInstanceOf(OrchestrationFilterException.Input.class) .hasMessageContaining( "Prompt filtered due to safety violations. Please modify the prompt and try again.") - .hasMessageContaining("400 Bad Request"); + .hasMessageContaining("400 (Bad Request)") + .isInstanceOfSatisfying( + OrchestrationFilterException.Input.class, + e -> { + var llamaGuard38b = e.getLlamaGuard38b(); + assertThat(llamaGuard38b).isNotNull(); + assertThat(llamaGuard38b.isViolentCrimes()).isTrue(); + assertThat(llamaGuard38b.isHate()).isFalse(); + assertThat(llamaGuard38b.isChildExploitation()).isFalse(); + assertThat(llamaGuard38b.isDefamation()).isFalse(); + }); } @Test @@ -377,7 +407,7 @@ void testStreamingErrorHandlingTemplate() { assertThatThrownBy(() -> client.streamChatCompletion(prompt, configWithTemplate)) .isInstanceOf(OrchestrationClientException.class) - .hasMessageContaining("status 400 Bad Request") + .hasMessageContaining("status 400 (Bad Request)") .hasMessageContaining("Error processing template:"); } @@ -388,8 +418,8 @@ void testStreamingErrorHandlingInputFilter() { val configWithFilter = config.withInputFiltering(filterConfig); assertThatThrownBy(() -> client.streamChatCompletion(prompt, configWithFilter)) - .isInstanceOf(OrchestrationClientException.class) - .hasMessageContaining("status 400 Bad Request") + .isInstanceOf(OrchestrationFilterException.Input.class) + .hasMessageContaining("status 400 (Bad Request)") .hasMessageContaining("Filtering Module - Input Filter"); } @@ -402,7 +432,7 @@ void testStreamingErrorHandlingMasking() { assertThatThrownBy(() -> client.streamChatCompletion(prompt, configWithMasking)) .isInstanceOf(OrchestrationClientException.class) - .hasMessageContaining("status 400 Bad Request") + .hasMessageContaining("status 400 (Bad Request)") .hasMessageContaining("'unknown_default_open_api' is not one of"); }