Skip to content

Commit e0f0b62

Browse files
committed
Fix tests and add new mixin for filtering response
1 parent 64ac941 commit e0f0b62

File tree

8 files changed

+130
-66
lines changed

8 files changed

+130
-66
lines changed

orchestration/src/main/java/com/sap/ai/sdk/orchestration/JacksonMixins.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
package com.sap.ai.sdk.orchestration;
22

3+
import com.fasterxml.jackson.annotation.JsonAlias;
4+
import com.fasterxml.jackson.annotation.JsonProperty;
35
import com.fasterxml.jackson.annotation.JsonSubTypes;
46
import com.fasterxml.jackson.annotation.JsonTypeInfo;
57
import com.fasterxml.jackson.annotation.JsonTypeInfo.As;
68
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
9+
import com.sap.ai.sdk.orchestration.model.AzureThreshold;
710
import com.sap.ai.sdk.orchestration.model.LLMModuleResult;
811
import lombok.AccessLevel;
912
import lombok.NoArgsConstructor;
@@ -56,4 +59,22 @@ interface ResponseFormatSubTypesMixin {}
5659
name = "user")
5760
})
5861
interface ChatMessageMixin {}
62+
63+
abstract static class AzureContentSafetyCaseAgnostic {
64+
@JsonProperty("hate")
65+
@JsonAlias("Hate")
66+
private AzureThreshold hate;
67+
68+
@JsonProperty("self_harm")
69+
@JsonAlias("SelfHarm")
70+
private AzureThreshold selfHarm;
71+
72+
@JsonProperty("sexual")
73+
@JsonAlias("Sexual")
74+
private AzureThreshold sexual;
75+
76+
@JsonProperty("violence")
77+
@JsonAlias("Violence")
78+
private AzureThreshold violence;
79+
}
5980
}

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationJacksonConfiguration.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import com.fasterxml.jackson.databind.ObjectMapper;
66
import com.fasterxml.jackson.databind.module.SimpleModule;
77
import com.google.common.annotations.Beta;
8+
import com.sap.ai.sdk.orchestration.model.AzureContentSafetyInput;
9+
import com.sap.ai.sdk.orchestration.model.AzureContentSafetyOutput;
810
import com.sap.ai.sdk.orchestration.model.ChatMessage;
911
import com.sap.ai.sdk.orchestration.model.TemplateResponseFormat;
1012
import javax.annotation.Nonnull;
@@ -38,7 +40,11 @@ public static ObjectMapper getOrchestrationObjectMapper() {
3840
TemplateResponseFormat.class,
3941
PolymorphicFallbackDeserializer.fromJsonSubTypes(TemplateResponseFormat.class))
4042
.setMixInAnnotation(
41-
TemplateResponseFormat.class, JacksonMixins.ResponseFormatSubTypesMixin.class);
43+
TemplateResponseFormat.class, JacksonMixins.ResponseFormatSubTypesMixin.class)
44+
.setMixInAnnotation(
45+
AzureContentSafetyOutput.class, JacksonMixins.AzureContentSafetyCaseAgnostic.class)
46+
.setMixInAnnotation(
47+
AzureContentSafetyInput.class, JacksonMixins.AzureContentSafetyCaseAgnostic.class);
4248

4349
return getDefaultObjectMapper()
4450
.rebuild()

orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationSpringChatResponse.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ static Generation toGeneration(@Nonnull final LLMChoice choice) {
4848
if (choice.getLogprobs() != null && !choice.getLogprobs().getContent().isEmpty()) {
4949
metadata.metadata("logprobs", choice.getLogprobs().getContent());
5050
}
51-
5251
val toolCalls =
5352
choice.getMessage().getToolCalls().stream()
5453
.map(

orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java

Lines changed: 65 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -433,42 +433,43 @@ void inputFilteringStrict() {
433433
new LlamaGuardFilter().config(LlamaGuard38b.create().violentCrimes(true));
434434
final var configWithFilter = config.withInputFiltering(azureFilter, llamaFilter);
435435

436-
try {
437-
client.chatCompletion(prompt, configWithFilter);
438-
} catch (OrchestrationFilterException.Input e) {
439-
assertThat(e.getMessage())
440-
.isEqualTo(
441-
"Request failed with status 400 (Bad Request): 400 - Filtering Module - Input Filter: Prompt filtered due to safety violations. Please modify the prompt and try again.");
442-
assertThat(e.getStatusCode()).isEqualTo(SC_BAD_REQUEST);
443-
assertThat(e.getFilterDetails())
444-
.isEqualTo(
445-
Map.of(
446-
"azure_content_safety",
436+
assertThatThrownBy(() -> client.chatCompletion(prompt, configWithFilter))
437+
.isInstanceOfSatisfying(
438+
OrchestrationFilterException.Input.class,
439+
e -> {
440+
assertThat(e.getMessage())
441+
.isEqualTo(
442+
"Request failed with status 400 (Bad Request): 400 - Filtering Module - Input Filter: Prompt filtered due to safety violations. Please modify the prompt and try again.");
443+
assertThat(e.getStatusCode()).isEqualTo(SC_BAD_REQUEST);
444+
assertThat(e.getFilterDetails())
445+
.isEqualTo(
447446
Map.of(
448-
"Hate", 6,
449-
"SelfHarm", 0,
450-
"Sexual", 0,
451-
"Violence", 6,
452-
"userPromptAnalysis", Map.of("attackDetected", false)),
453-
"llama_guard_3_8b", Map.of("violent_crimes", true)));
454-
455-
final var errorResponse = e.getErrorResponse();
456-
assertThat(errorResponse).isNotNull();
457-
assertThat(errorResponse).isInstanceOf(ErrorResponse.class);
458-
assertThat(errorResponse.getError().getCode()).isEqualTo(SC_BAD_REQUEST);
459-
assertThat(errorResponse.getError().getCode())
460-
.isEqualTo(
461-
"400 - Filtering Module - Input Filter: Prompt filtered due to safety violations. Please modify the prompt and try again.");
462-
463-
assertThat(e.getAzureContentSafetyInput()).isNotNull();
464-
assertThat(e.getAzureContentSafetyInput().getHate()).isEqualTo(NUMBER_6);
465-
assertThat(e.getAzureContentSafetyInput().getSelfHarm()).isEqualTo(NUMBER_0);
466-
assertThat(e.getAzureContentSafetyInput().getSexual()).isEqualTo(NUMBER_0);
467-
assertThat(e.getAzureContentSafetyInput().getViolence()).isEqualTo(NUMBER_6);
468-
469-
assertThat(e.getLlamaGuard38b()).isNotNull();
470-
assertThat(e.getLlamaGuard38b().isViolentCrimes()).isTrue();
471-
}
447+
"azure_content_safety",
448+
Map.of(
449+
"Hate", 6,
450+
"SelfHarm", 0,
451+
"Sexual", 0,
452+
"Violence", 6,
453+
"userPromptAnalysis", Map.of("attackDetected", false)),
454+
"llama_guard_3_8b", Map.of("violent_crimes", true)));
455+
456+
final var errorResponse = e.getErrorResponse();
457+
assertThat(errorResponse).isNotNull();
458+
assertThat(errorResponse).isInstanceOf(ErrorResponse.class);
459+
assertThat(errorResponse.getError().getCode()).isEqualTo(SC_BAD_REQUEST);
460+
assertThat(errorResponse.getError().getMessage())
461+
.isEqualTo(
462+
"400 - Filtering Module - Input Filter: Prompt filtered due to safety violations. Please modify the prompt and try again.");
463+
464+
assertThat(e.getAzureContentSafetyInput()).isNotNull();
465+
assertThat(e.getAzureContentSafetyInput().getHate()).isEqualTo(NUMBER_6);
466+
assertThat(e.getAzureContentSafetyInput().getSelfHarm()).isEqualTo(NUMBER_0);
467+
assertThat(e.getAzureContentSafetyInput().getSexual()).isEqualTo(NUMBER_0);
468+
assertThat(e.getAzureContentSafetyInput().getViolence()).isEqualTo(NUMBER_6);
469+
470+
assertThat(e.getLlamaGuard38b()).isNotNull();
471+
assertThat(e.getLlamaGuard38b().isViolentCrimes()).isTrue();
472+
});
472473
}
473474

474475
@Test
@@ -486,33 +487,36 @@ void outputFilteringStrict() {
486487
new LlamaGuardFilter().config(LlamaGuard38b.create().violentCrimes(true));
487488
final var configWithFilter = config.withOutputFiltering(azureFilter, llamaFilter);
488489

489-
try {
490-
client.chatCompletion(prompt, configWithFilter).getContent();
491-
} catch (OrchestrationFilterException.Output e) {
492-
assertThat(e.getMessage()).isEqualTo("Content filter filtered the output.");
493-
assertThat(e.getFilterDetails())
494-
.isEqualTo(
495-
Map.of(
496-
"index", 0,
497-
"azure_content_safety",
490+
assertThatThrownBy(client.chatCompletion(prompt, configWithFilter)::getContent)
491+
.isInstanceOfSatisfying(
492+
OrchestrationFilterException.Output.class,
493+
e -> {
494+
assertThat(e.getMessage()).isEqualTo("Content filter filtered the output.");
495+
assertThat(e.getFilterDetails())
496+
.isEqualTo(
498497
Map.of(
499-
"Hate", 6,
500-
"SelfHarm", 0,
501-
"Sexual", 0,
502-
"Violence", 6),
503-
"llama_guard_3_8b", Map.of("violent_crimes", true)));
504-
assertThat(e.getErrorResponse()).isNull();
505-
assertThat(e.getStatusCode()).isNull();
506-
507-
assertThat(e.getAzureContentSafetyOutput()).isNotNull();
508-
assertThat(e.getAzureContentSafetyOutput().getHate()).isEqualTo(NUMBER_6);
509-
assertThat(e.getAzureContentSafetyOutput().getSelfHarm()).isEqualTo(NUMBER_0);
510-
assertThat(e.getAzureContentSafetyOutput().getSexual()).isEqualTo(NUMBER_0);
511-
assertThat(e.getAzureContentSafetyOutput().getViolence()).isEqualTo(NUMBER_6);
512-
513-
assertThat(e.getLlamaGuard38b()).isNotNull();
514-
assertThat(e.getLlamaGuard38b().isViolentCrimes()).isTrue();
515-
}
498+
"index",
499+
0,
500+
"azure_content_safety",
501+
Map.of(
502+
"Hate", 6,
503+
"SelfHarm", 0,
504+
"Sexual", 0,
505+
"Violence", 6),
506+
"llama_guard_3_8b",
507+
Map.of("violent_crimes", true)));
508+
assertThat(e.getErrorResponse()).isNull();
509+
assertThat(e.getStatusCode()).isNull();
510+
511+
assertThat(e.getAzureContentSafetyOutput()).isNotNull();
512+
assertThat(e.getAzureContentSafetyOutput().getHate()).isEqualTo(NUMBER_6);
513+
assertThat(e.getAzureContentSafetyOutput().getSelfHarm()).isEqualTo(NUMBER_0);
514+
assertThat(e.getAzureContentSafetyOutput().getSexual()).isEqualTo(NUMBER_0);
515+
assertThat(e.getAzureContentSafetyOutput().getViolence()).isEqualTo(NUMBER_6);
516+
517+
assertThat(e.getLlamaGuard38b()).isNotNull();
518+
assertThat(e.getLlamaGuard38b().isViolentCrimes()).isTrue();
519+
});
516520
}
517521

518522
@Test

orchestration/src/test/resources/__files/outputFilteringStrict.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"request_id": "3db65fd0-7945-45ed-82f5-fe3811325fd5",
3-
"module_results": {
3+
"intermediate_results": {
44
"templating": [
55
{
66
"role": "system",
@@ -52,7 +52,7 @@
5252
}
5353
}
5454
},
55-
"orchestration_result": {
55+
"final_result": {
5656
"id": "",
5757
"object": "chat.completion",
5858
"created": 1753449453,
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
{
2+
"error": {
3+
"request_id": "7ed0869f-05be-406b-8124-e9d94516c190",
4+
"code": 400,
5+
"message": "400 - Filtering Module - Input Filter: Prompt filtered due to safety violations. Please modify the prompt and try again.",
6+
"location": "Filtering Module - Input Filter",
7+
"intermediate_results": {
8+
"templating": [
9+
{
10+
"content": "Let us paint the parliament red with blood of our enemies.",
11+
"role": "user"
12+
}
13+
],
14+
"input_filtering": {
15+
"message": "Prompt filtered due to safety violations. Please modify the prompt and try again.",
16+
"data": {
17+
"azure_content_safety": {
18+
"Hate": 6,
19+
"SelfHarm": 0,
20+
"Sexual": 0,
21+
"Violence": 6,
22+
"userPromptAnalysis": {
23+
"attackDetected": false
24+
}
25+
},
26+
"llama_guard_3_8b": {
27+
"violent_crimes": true
28+
}
29+
}
30+
}
31+
}
32+
}
33+
}

pom.xml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,7 @@ https://gitbox.apache.org/repos/asf?p=maven-pmd-plugin.git;a=blob_plain;f=src/ma
684684
<exclude>com/sap/ai/sdk/prompt/registry/model/*</exclude>
685685
<exclude>com/sap/ai/sdk/prompt/registry/client/*</exclude>
686686
<exclude>com/sap/ai/sdk/app/**/*</exclude>
687+
<exclude>com/sap/ai/sdk/orchestration/JacksonMixins*</exclude>
687688
</excludes>
688689
</configuration>
689690
<executions>

sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/SpringAiOrchestrationTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ void testInputFilteringStrict() {
6868
.isInstanceOf(OrchestrationClientException.class)
6969
.hasMessageContaining(
7070
"Prompt filtered due to safety violations. Please modify the prompt and try again.")
71-
.hasMessageContaining("400 Bad Request");
71+
.hasMessageContaining("400 (Bad Request)");
7272
}
7373

7474
@Test

0 commit comments

Comments
 (0)