|
4 | 4 | import static org.assertj.core.api.Assertions.assertThatThrownBy; |
5 | 5 |
|
6 | 6 | import com.sap.ai.sdk.app.services.SpringAiOrchestrationService; |
| 7 | +import com.sap.ai.sdk.orchestration.AzureFilterThreshold; |
7 | 8 | import com.sap.ai.sdk.orchestration.OrchestrationClientException; |
| 9 | +import com.sap.ai.sdk.orchestration.spring.OrchestrationSpringChatResponse; |
8 | 10 | import java.util.List; |
9 | 11 | import java.util.concurrent.atomic.AtomicInteger; |
10 | 12 | import lombok.extern.slf4j.Slf4j; |
@@ -58,6 +60,62 @@ void testMasking() { |
58 | 60 | assertThat(response.getResult().getOutput().getText()).isNotEmpty(); |
59 | 61 | } |
60 | 62 |
|
| 63 | + @Test |
| 64 | + void testInputFilteringStrict() { |
| 65 | + var policy = AzureFilterThreshold.ALLOW_SAFE; |
| 66 | + |
| 67 | + assertThatThrownBy(() -> service.inputFiltering(policy)) |
| 68 | + .isInstanceOf(OrchestrationClientException.class) |
| 69 | + .hasMessageContaining( |
| 70 | + "Content filtered due to safety violations. Please modify the prompt and try again.") |
| 71 | + .hasMessageContaining("400 Bad Request"); |
| 72 | + } |
| 73 | + |
| 74 | + @Test |
| 75 | + void testInputFilteringLenient() { |
| 76 | + var policy = AzureFilterThreshold.ALLOW_ALL; |
| 77 | + |
| 78 | + var response = service.inputFiltering(policy); |
| 79 | + |
| 80 | + assertThat(response.getResult().getMetadata().getFinishReason()).isEqualTo("stop"); |
| 81 | + assertThat(response.getResult().getOutput().getText()).isNotEmpty(); |
| 82 | + |
| 83 | + var filterResult = |
| 84 | + ((OrchestrationSpringChatResponse) response) |
| 85 | + .getOrchestrationResponse() |
| 86 | + .getOriginalResponse() |
| 87 | + .getModuleResults() |
| 88 | + .getInputFiltering(); |
| 89 | + assertThat(filterResult.getMessage()).contains("passed"); |
| 90 | + } |
| 91 | + |
| 92 | + @Test |
| 93 | + void testOutputFilteringStrict() { |
| 94 | + var policy = AzureFilterThreshold.ALLOW_SAFE; |
| 95 | + |
| 96 | + assertThatThrownBy(() -> service.outputFiltering(policy)) |
| 97 | + .isInstanceOf(OrchestrationClientException.class) |
| 98 | + .hasMessageContaining("Content filter filtered the output."); |
| 99 | + } |
| 100 | + |
| 101 | + @Test |
| 102 | + void testOutputFilteringLenient() { |
| 103 | + var policy = AzureFilterThreshold.ALLOW_ALL; |
| 104 | + |
| 105 | + var response = service.outputFiltering(policy); |
| 106 | + |
| 107 | + assertThat(response.getResult().getMetadata().getFinishReason()).isEqualTo("stop"); |
| 108 | + assertThat(response.getResult().getOutput().getText()).isNotEmpty(); |
| 109 | + |
| 110 | + var filterResult = |
| 111 | + ((OrchestrationSpringChatResponse) response) |
| 112 | + .getOrchestrationResponse() |
| 113 | + .getOriginalResponse() |
| 114 | + .getModuleResults() |
| 115 | + .getOutputFiltering(); |
| 116 | + assertThat(filterResult.getMessage()).containsPattern("0 of \\d+ choices failed"); |
| 117 | + } |
| 118 | + |
61 | 119 | @Test |
62 | 120 | void testToolCallingWithoutExecution() { |
63 | 121 | ChatResponse response = service.toolCalling(false); |
|
0 commit comments