Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,16 @@ public class OrchestrationChatResponse {
* <p>Note: If there are multiple choices only the first one is returned
*
* @return the message content or empty string.
* @throws OrchestrationFilterException.Output if the content filter filtered the output.
* @throws OrchestrationClientException if the content filter filtered the output.
*/
@Nonnull
public String getContent() throws OrchestrationFilterException.Output {
public String getContent() throws OrchestrationClientException {
final var choice = getChoice();

if ("content_filter".equals(choice.getFinishReason())) {
final var filterDetails = Try.of(this::getOutputFilteringChoices).getOrElseGet(e -> Map.of());
final var message = "Content filter filtered the output.";
throw new OrchestrationFilterException.Output(message).setFilterDetails(filterDetails);
throw new OrchestrationClientException(message).setFilterDetails(filterDetails);
}
return choice.getMessage().getContent();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,13 @@ public Stream<String> streamChatCompletion(
}

private static void throwOnContentFilter(@Nonnull final OrchestrationChatCompletionDelta delta)
throws OrchestrationFilterException.Output {
throws OrchestrationClientException {
final String finishReason = delta.getFinishReason();
if (finishReason != null && finishReason.equals("content_filter")) {
final var filterDetails =
Try.of(() -> getOutputFilteringChoices(delta)).getOrElseGet(e -> Map.of());
final var message = "Content filter filtered the output.";
throw new OrchestrationFilterException.Output(message).setFilterDetails(filterDetails);
throw new OrchestrationClientException(message).setFilterDetails(filterDetails);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,58 +1,69 @@
package com.sap.ai.sdk.orchestration;

import static com.sap.ai.sdk.orchestration.OrchestrationJacksonConfiguration.getOrchestrationObjectMapper;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.annotations.Beta;
import com.sap.ai.sdk.core.common.ClientException;
import com.sap.ai.sdk.core.common.ClientExceptionFactory;
import com.sap.ai.sdk.orchestration.OrchestrationFilterException.Input;
import com.sap.ai.sdk.orchestration.model.AzureContentSafetyInput;
import com.sap.ai.sdk.orchestration.model.AzureContentSafetyOutput;
import com.sap.ai.sdk.orchestration.model.Error;
import com.sap.ai.sdk.orchestration.model.ErrorResponse;
import com.sap.ai.sdk.orchestration.model.ErrorResponseStreaming;
import com.sap.ai.sdk.orchestration.model.ErrorStreaming;
import com.sap.ai.sdk.orchestration.model.GenericModuleResult;
import com.sap.ai.sdk.orchestration.model.LlamaGuard38b;
import com.sap.ai.sdk.orchestration.model.ModuleResults;
import com.sap.ai.sdk.orchestration.model.ModuleResultsStreaming;
import java.util.Collections;
import java.util.Map;
import java.util.Optional;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.Setter;
import lombok.experimental.Accessors;
import lombok.experimental.StandardException;
import lombok.val;

/** Exception thrown by the {@link OrchestrationClient} in case of an error. */
@StandardException
public class OrchestrationClientException extends ClientException {

static final ClientExceptionFactory<OrchestrationClientException, OrchestrationError> FACTORY =
(message, clientError, cause) -> {
final var details = extractInputFilterDetails(clientError);
if (details.isEmpty()) {
return new OrchestrationClientException(message, cause).setClientError(clientError);
val result = new OrchestrationClientException(message, cause);
val details = extractModuleResults(clientError).map(GenericModuleResult::getData);
if (details.orElse(null) instanceof Map<?, ?> m) {
result.setFilterDetails((Map<String, Object>) m);
}
return new Input(message, cause).setFilterDetails(details).setClientError(clientError);
return result.setClientError(clientError);
};

@SuppressWarnings("unchecked")
private static final ObjectMapper OBJECT_MAPPER = getOrchestrationObjectMapper();

/** Details about the filters that caused the exception. */
@Accessors(chain = true)
@Setter(AccessLevel.PACKAGE)
@Getter
@Nullable
private Map<String, Object> filterDetails = null;

@Nonnull
static Map<String, Object> extractInputFilterDetails(@Nullable final OrchestrationError error) {
if (error instanceof OrchestrationError.Synchronous synchronousError) {
static Optional<GenericModuleResult> extractModuleResults(@Nullable final OrchestrationError e) {
if (e instanceof OrchestrationError.Synchronous synchronousError) {
return Optional.of(synchronousError.getErrorResponse())
.map(ErrorResponse::getError)
.map(Error::getIntermediateResults)
.map(ModuleResults::getInputFiltering)
.map(GenericModuleResult::getData)
.map(map -> (Map<String, Object>) map)
.orElseGet(Collections::emptyMap);
} else if (error instanceof OrchestrationError.Streaming streamingError) {
.map(ModuleResults::getInputFiltering);
} else if (e instanceof OrchestrationError.Streaming streamingError) {
return Optional.of(streamingError.getErrorResponse())
.map(ErrorResponseStreaming::getError)
.map(ErrorStreaming::getIntermediateResults)
.map(ModuleResultsStreaming::getInputFiltering)
.map(GenericModuleResult::getData)
.filter(Map.class::isInstance)
.map(map -> (Map<String, Object>) map)
.orElseGet(Collections::emptyMap);
.map(ModuleResultsStreaming::getInputFiltering);
}
return Collections.emptyMap();
return Optional.empty();
}

@Override
Expand Down Expand Up @@ -105,4 +116,49 @@ public Integer getStatusCode() {
.map(Error::getCode)
.orElse(null);
}

/**
* Retrieves LlamaGuard 3.8b details from {@code filterDetails}, if present.
*
* @return The LlamaGuard38b object, or {@code null} if not found or conversion fails.
* @throws IllegalArgumentException if the conversion of filter details to {@link LlamaGuard38b}
* fails due to invalid content.
*/
@Nullable
public LlamaGuard38b getLlamaGuard38b() {
return Optional.ofNullable(filterDetails)
.map(details -> details.get("llama_guard_3_8b"))
.map(obj -> OBJECT_MAPPER.convertValue(obj, LlamaGuard38b.class))
.orElse(null);
}

/**
* Retrieves Azure Content Safety input details from {@code filterDetails}, if present.
*
* @return The AzureContentSafetyInput object, or {@code null} if not found or conversion fails.
* @throws IllegalArgumentException if the conversion of filter details to {@link
* AzureContentSafetyInput} fails due to invalid content.
*/
@Nullable
public AzureContentSafetyInput getAzureContentSafetyInput() {
return Optional.ofNullable(filterDetails)
.map(details -> details.get("azure_content_safety"))
.map(obj -> OBJECT_MAPPER.convertValue(obj, AzureContentSafetyInput.class))
.orElse(null);
}

/**
* Retrieves Azure Content Safety output details from {@code filterDetails}, if present.
*
* @return The AzureContentSafetyOutput object, or {@code null} if not found or conversion fails.
* @throws IllegalArgumentException if the conversion of filter details to {@link
* AzureContentSafetyOutput} fails due to invalid content.
*/
@Nullable
public AzureContentSafetyOutput getAzureContentSafetyOutput() {
return Optional.ofNullable(filterDetails)
.map(details -> details.get("azure_content_safety"))
.map(obj -> OBJECT_MAPPER.convertValue(obj, AzureContentSafetyOutput.class))
.orElse(null);
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ void inputFilteringStrict() {

assertThatThrownBy(() -> client.chatCompletion(prompt, configWithFilter))
.isInstanceOfSatisfying(
OrchestrationFilterException.Input.class,
OrchestrationClientException.class,
e -> {
assertThat(e.getMessage())
.isEqualTo(
Expand Down Expand Up @@ -503,7 +503,7 @@ void outputFilteringStrict() {

assertThatThrownBy(client.chatCompletion(prompt, configWithFilter)::getContent)
.isInstanceOfSatisfying(
OrchestrationFilterException.Output.class,
OrchestrationClientException.class,
e -> {
assertThat(e.getMessage()).isEqualTo("Content filter filtered the output.");
assertThat(e.getFilterDetails())
Expand Down Expand Up @@ -759,9 +759,9 @@ void testThrowsOnContentFilter() {
// this must not throw, since the stream is lazily evaluated
var stream = mock.streamChatCompletion(new OrchestrationPrompt(""), config);
assertThatThrownBy(stream::toList)
.isInstanceOf(OrchestrationFilterException.Output.class)
.isInstanceOf(OrchestrationClientException.class)
.hasMessage("Content filter filtered the output.")
.extracting(e -> ((OrchestrationFilterException.Output) e).getFilterDetails())
.extracting(e -> ((OrchestrationClientException) e).getFilterDetails())
.isEqualTo(Map.of("azure_content_safety", Map.of("hate", 0, "self_harm", 0)));
}

Expand All @@ -785,7 +785,7 @@ void streamChatCompletionOutputFilterErrorHandling() throws IOException {
assertThatThrownBy(() -> stream.forEach(System.out::println))
.hasMessage("Content filter filtered the output.")
.isInstanceOfSatisfying(
OrchestrationFilterException.Output.class,
OrchestrationClientException.class,
e -> {
assertThat(e.getErrorResponse()).isNull();
assertThat(e.getErrorResponseStreaming()).isNull();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import com.sap.ai.sdk.app.services.OrchestrationService;
import com.sap.ai.sdk.orchestration.AzureFilterThreshold;
import com.sap.ai.sdk.orchestration.OrchestrationChatResponse;
import com.sap.ai.sdk.orchestration.OrchestrationFilterException;
import com.sap.ai.sdk.orchestration.OrchestrationClientException;
import com.sap.ai.sdk.orchestration.model.AzureContentSafetyInput;
import com.sap.ai.sdk.orchestration.model.AzureContentSafetyOutput;
import com.sap.ai.sdk.orchestration.model.DPIEntities;
Expand Down Expand Up @@ -127,7 +127,7 @@ Object inputFiltering(
final OrchestrationChatResponse response;
try {
response = service.inputFiltering(policy);
} catch (OrchestrationFilterException.Input e) {
} catch (OrchestrationClientException e) {
final var msg =
new StringBuilder(
"[Http %d] Failed to obtain a response as the content was flagged by input filter. "
Expand Down Expand Up @@ -159,7 +159,7 @@ Object outputFiltering(
final String content;
try {
content = response.getContent();
} catch (OrchestrationFilterException.Output e) {
} catch (OrchestrationClientException e) {
final var msg =
new StringBuilder(
"Failed to obtain a response as the content was flagged by output filter. ");
Expand Down Expand Up @@ -188,7 +188,7 @@ Object llamaGuardInputFiltering(
final OrchestrationChatResponse response;
try {
response = service.llamaGuardInputFilter(enabled);
} catch (OrchestrationFilterException.Input e) {
} catch (OrchestrationClientException e) {
var msg =
"[Http %d] Failed to obtain a response as the content was flagged by input filter. "
.formatted(e.getStatusCode());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import com.sap.ai.sdk.orchestration.Message;
import com.sap.ai.sdk.orchestration.OrchestrationClient;
import com.sap.ai.sdk.orchestration.OrchestrationClientException;
import com.sap.ai.sdk.orchestration.OrchestrationFilterException;
import com.sap.ai.sdk.orchestration.OrchestrationModuleConfig;
import com.sap.ai.sdk.orchestration.OrchestrationPrompt;
import com.sap.ai.sdk.orchestration.TemplateConfig;
Expand Down Expand Up @@ -221,7 +220,7 @@ void testInputFilteringStrict() {
"Prompt filtered due to safety violations. Please modify the prompt and try again.")
.hasMessageContaining("400 (Bad Request)")
.isInstanceOfSatisfying(
OrchestrationFilterException.Input.class,
OrchestrationClientException.class,
e -> {
var actualAzureContentSafety = e.getAzureContentSafetyInput();
assertThat(actualAzureContentSafety).isNotNull();
Expand Down Expand Up @@ -253,7 +252,7 @@ void testOutputFilteringStrict() {
assertThatThrownBy(response::getContent)
.hasMessageContaining("Content filter filtered the output.")
.isInstanceOfSatisfying(
OrchestrationFilterException.Output.class,
OrchestrationClientException.class,
e -> {
var actualAzureContentSafety = e.getAzureContentSafetyOutput();
assertThat(actualAzureContentSafety).isNotNull();
Expand All @@ -280,12 +279,12 @@ void testOutputFilteringLenient() {
@Test
void testLlamaGuardEnabled() {
assertThatThrownBy(() -> service.llamaGuardInputFilter(true))
.isInstanceOf(OrchestrationFilterException.Input.class)
.isInstanceOf(OrchestrationClientException.class)
.hasMessageContaining(
"Prompt filtered due to safety violations. Please modify the prompt and try again.")
.hasMessageContaining("400 (Bad Request)")
.isInstanceOfSatisfying(
OrchestrationFilterException.Input.class,
OrchestrationClientException.class,
e -> {
var llamaGuard38b = e.getLlamaGuard38b();
assertThat(llamaGuard38b).isNotNull();
Expand Down Expand Up @@ -419,7 +418,7 @@ void testStreamingErrorHandlingInputFilter() {
val configWithFilter = config.withInputFiltering(filterConfig);

assertThatThrownBy(() -> client.streamChatCompletion(prompt, configWithFilter))
.isInstanceOf(OrchestrationFilterException.Input.class)
.isInstanceOf(OrchestrationClientException.class)
.hasMessageContaining("status 400 (Bad Request)")
.hasMessageContaining("Filtering Module - Input Filter");
}
Expand Down