Skip to content

Commit f7c28d6

Browse files
committed
Introduce readable thresholds based api
- Multi filter support - detach per filter logic from generic methods - add javadoc - update tests
1 parent b2eaa9a commit f7c28d6

File tree

8 files changed

+160
-58
lines changed

8 files changed

+160
-58
lines changed

docs/guides/ORCHESTRATION_CHAT_COMPLETION.md

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -146,17 +146,17 @@ var prompt = new OrchestrationPrompt(
146146
```DISCLAIMER: The area surrounding the apartment is known for prostitutes and gang violence including armed conflicts, gun violence is frequent.
147147
""");
148148

149-
var filterStrict = new AzureContentSafety()
150-
.hate(NUMBER_0)
151-
.selfHarm(NUMBER_0)
152-
.sexual(NUMBER_0)
153-
.violence(NUMBER_0);
154-
155-
var filterLoose = new AzureContentSafety()
156-
.hate(NUMBER_4)
157-
.selfHarm(NUMBER_4)
158-
.sexual(NUMBER_4)
159-
.violence(NUMBER_4);
149+
var filterStrict = new AzureContentFilter()
150+
.hate(ALLOW_SAFE)
151+
.selfHarm(ALLOW_SAFE)
152+
.sexual(ALLOW_SAFE)
153+
.violence(ALLOW_SAFE);
154+
155+
var filterLoose = new AzureContentFilter()
156+
.hate(ALLOW_SAFE_LOW_MEDIUM)
157+
.selfHarm(ALLOW_SAFE_LOW_MEDIUM)
158+
.sexual(ALLOW_SAFE_LOW_MEDIUM)
159+
.violence(ALLOW_SAFE_LOW_MEDIUM);
160160

161161
// changing the input to filterLoose will allow the message to pass
162162
var configWithFilter = config.withInputFiltering(filterStrict).withOutputFiltering(filterStrict);
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import com.sap.ai.sdk.orchestration.client.model.AzureContentSafety;
4+
import com.sap.ai.sdk.orchestration.client.model.AzureContentSafetyFilterConfig;
5+
import javax.annotation.Nonnull;
6+
import javax.annotation.Nullable;
7+
import lombok.NoArgsConstructor;
8+
import lombok.Setter;
9+
import lombok.experimental.Accessors;
10+
11+
/**
12+
* A content filter wrapping Azure Content Safety.
13+
*
14+
* <p>This class allows setting moderation policies for different content categories such as hate,
15+
* self-harm, sexual, and violence.
16+
*
17+
* <p>Example usage:
18+
*
19+
* <pre>{@code
20+
* AzureContentFilter filter = new AzureContentFilter()
21+
* .hate(AzureModerationPolicy.ALLOW_SAFE)
22+
* .selfHarm(AzureModerationPolicy.ALLOW_SAFE_LOW);
23+
* }</pre>
24+
*/
25+
@Setter
26+
@NoArgsConstructor
27+
@Accessors(fluent = true)
28+
public class AzureContentFilter implements ContentFilter {
29+
@Nullable AzureModerationPolicy hate;
30+
@Nullable AzureModerationPolicy selfHarm;
31+
@Nullable AzureModerationPolicy sexual;
32+
@Nullable AzureModerationPolicy violence;
33+
34+
/**
35+
* Converts {@code AzureContentFilter} to its serializable counterpart {@link
36+
* AzureContentSafetyFilterConfig}.
37+
*
38+
* @return the corresponding {@code AzureContentSafetyFilterConfig} object.
39+
* @throws IllegalArgumentException if no policies are set.
40+
*/
41+
@Override
42+
@Nonnull
43+
public AzureContentSafetyFilterConfig toSerializable() {
44+
if (hate == null && selfHarm == null && sexual == null && violence == null) {
45+
throw new IllegalArgumentException("At least one filter moderation policy must be set");
46+
}
47+
48+
return new AzureContentSafetyFilterConfig()
49+
.type(AzureContentSafetyFilterConfig.TypeEnum.AZURE_CONTENT_SAFETY)
50+
.config(
51+
new AzureContentSafety()
52+
.hate(hate != null ? hate.getAzureThreshold() : null)
53+
.selfHarm(selfHarm != null ? selfHarm.getAzureThreshold() : null)
54+
.sexual(sexual != null ? sexual.getAzureThreshold() : null)
55+
.violence(violence != null ? violence.getAzureThreshold() : null));
56+
}
57+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import com.sap.ai.sdk.orchestration.client.model.AzureThreshold;
4+
import javax.annotation.Nonnull;
5+
import lombok.AllArgsConstructor;
6+
import lombok.Getter;
7+
8+
/** An Enum wrapping Azure thresholds with readable names. */
9+
@Getter
10+
@AllArgsConstructor
11+
public enum AzureModerationPolicy {
12+
ALLOW_SAFE(AzureThreshold.NUMBER_0),
13+
ALLOW_SAFE_LOW(AzureThreshold.NUMBER_2),
14+
ALLOW_SAFE_LOW_MEDIUM(AzureThreshold.NUMBER_4),
15+
ALLOW_ALL(AzureThreshold.NUMBER_6);
16+
17+
@Nonnull final AzureThreshold azureThreshold;
18+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import com.sap.ai.sdk.orchestration.client.model.FilterConfig;
4+
5+
/**
6+
* Interface representing convenience wrappers of serializable content filter that defines
7+
* moderation policies for different content categories.
8+
*/
9+
public interface ContentFilter {
10+
11+
/**
12+
* A method that produces the serializable equivalent {@link FilterConfig} object from data
13+
* encapsulated in the {@link ContentFilter} object.
14+
*
15+
* @return the corresponding {@code FilterConfig} object.
16+
*/
17+
FilterConfig toSerializable();
18+
}

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfig.java

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
package com.sap.ai.sdk.orchestration;
22

3-
import com.sap.ai.sdk.orchestration.client.model.AzureContentSafety;
4-
import com.sap.ai.sdk.orchestration.client.model.AzureContentSafetyFilterConfig;
53
import com.sap.ai.sdk.orchestration.client.model.FilteringModuleConfig;
64
import com.sap.ai.sdk.orchestration.client.model.InputFilteringConfig;
75
import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig;
86
import com.sap.ai.sdk.orchestration.client.model.MaskingModuleConfig;
97
import com.sap.ai.sdk.orchestration.client.model.OutputFilteringConfig;
108
import com.sap.ai.sdk.orchestration.client.model.TemplatingModuleConfig;
11-
import java.util.List;
9+
import java.util.Arrays;
1210
import javax.annotation.Nonnull;
1311
import javax.annotation.Nullable;
1412
import lombok.AccessLevel;
@@ -34,36 +32,43 @@
3432
* </ul>
3533
*/
3634
@Value
35+
@With
3736
@AllArgsConstructor(access = AccessLevel.PRIVATE)
3837
@NoArgsConstructor(force = true)
3938
public class OrchestrationModuleConfig {
4039
/**
4140
* The configured language model settings. This configuration is required when executing requests.
4241
*/
43-
@With @Nullable LLMModuleConfig llmConfig;
42+
@Nullable LLMModuleConfig llmConfig;
4443

4544
/**
4645
* A template to be populated with input parameters. Upon request execution, this template will be
4746
* enhanced with any messages and parameter values from {@link OrchestrationPrompt}.
4847
*/
49-
@With @Nullable TemplatingModuleConfig templateConfig;
48+
@Nullable TemplatingModuleConfig templateConfig;
5049

5150
/** A masking configuration to pseudonymous or anonymize sensitive data in the input. */
52-
@With @Nullable MaskingModuleConfig maskingConfig;
51+
@Nullable MaskingModuleConfig maskingConfig;
5352

5453
/** A content filter to filter the prompt. */
55-
@With(AccessLevel.PRIVATE)
56-
@Nullable
57-
FilteringModuleConfig filteringConfig;
54+
@Nullable FilteringModuleConfig filteringConfig;
5855

56+
/**
57+
* Adds input content filters to the orchestration configuration.
58+
*
59+
* <p>Preferred over {@link #withFilteringConfig(FilteringModuleConfig)} for adding input filters.
60+
*
61+
* @param contentFilters one or more content filters to apply to the input.
62+
* @return a new {@code OrchestrationModuleConfig} instance with the specified input filters
63+
* added.
64+
*/
5965
@Nonnull
6066
public OrchestrationModuleConfig withInputFiltering(
61-
@Nonnull final AzureContentSafety contentFilter) {
62-
var azureFilter =
63-
new AzureContentSafetyFilterConfig()
64-
.type(AzureContentSafetyFilterConfig.TypeEnum.AZURE_CONTENT_SAFETY)
65-
.config(contentFilter);
66-
var inputFilter = new InputFilteringConfig().filters(List.of(azureFilter));
67+
@Nonnull final ContentFilter... contentFilters) {
68+
69+
var filterConfigs = Arrays.stream(contentFilters).map(ContentFilter::toSerializable).toList();
70+
71+
var inputFilter = new InputFilteringConfig().filters(filterConfigs);
6772

6873
var newFilteringConfig =
6974
new FilteringModuleConfig()
@@ -73,14 +78,22 @@ public OrchestrationModuleConfig withInputFiltering(
7378
return this.withFilteringConfig(newFilteringConfig);
7479
}
7580

81+
/**
82+
* Adds output content filters to the orchestration configuration.
83+
*
84+
* <p>Preferred over {@link #withFilteringConfig(FilteringModuleConfig)} for adding output
85+
* filters.
86+
*
87+
* @param contentFilters one or more content filters to apply to the output.
88+
* @return a new {@code OrchestrationModuleConfig} instance with the specified output filters
89+
* added.
90+
*/
7691
@Nonnull
7792
public OrchestrationModuleConfig withOutputFiltering(
78-
@Nonnull final AzureContentSafety contentFilter) {
79-
var azureFilter =
80-
new AzureContentSafetyFilterConfig()
81-
.type(AzureContentSafetyFilterConfig.TypeEnum.AZURE_CONTENT_SAFETY)
82-
.config(contentFilter);
83-
var outputFilter = new OutputFilteringConfig().filters(List.of(azureFilter));
93+
@Nonnull final ContentFilter... contentFilters) {
94+
95+
var filterConfigs = Arrays.stream(contentFilters).map(ContentFilter::toSerializable).toList();
96+
var outputFilter = new OutputFilteringConfig().filters(filterConfigs);
8497

8598
var newFilteringConfig =
8699
new FilteringModuleConfig()

orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
import static com.github.tomakehurst.wiremock.client.WireMock.stubFor;
1717
import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo;
1818
import static com.github.tomakehurst.wiremock.client.WireMock.verify;
19-
import static com.sap.ai.sdk.orchestration.client.model.AzureThreshold.NUMBER_0;
20-
import static com.sap.ai.sdk.orchestration.client.model.AzureThreshold.NUMBER_4;
19+
import static com.sap.ai.sdk.orchestration.AzureModerationPolicy.ALLOW_SAFE;
20+
import static com.sap.ai.sdk.orchestration.AzureModerationPolicy.ALLOW_SAFE_LOW_MEDIUM;
2121
import static org.apache.hc.core5.http.HttpStatus.SC_BAD_REQUEST;
2222
import static org.assertj.core.api.Assertions.assertThat;
2323
import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -28,7 +28,6 @@
2828
import com.github.tomakehurst.wiremock.junit5.WireMockTest;
2929
import com.github.tomakehurst.wiremock.stubbing.Scenario;
3030
import com.sap.ai.sdk.core.AiCoreService;
31-
import com.sap.ai.sdk.orchestration.client.model.AzureContentSafety;
3231
import com.sap.ai.sdk.orchestration.client.model.ChatMessage;
3332
import com.sap.ai.sdk.orchestration.client.model.CompletionPostRequest;
3433
import com.sap.ai.sdk.orchestration.client.model.DPIConfig;
@@ -213,11 +212,11 @@ void filteringLoose() throws IOException {
213212
.withHeader("Content-Type", "application/json")));
214213

215214
final var filter =
216-
new AzureContentSafety()
217-
.hate(NUMBER_4)
218-
.selfHarm(NUMBER_4)
219-
.sexual(NUMBER_4)
220-
.violence(NUMBER_4);
215+
new AzureContentFilter()
216+
.hate(ALLOW_SAFE_LOW_MEDIUM)
217+
.selfHarm(ALLOW_SAFE_LOW_MEDIUM)
218+
.sexual(ALLOW_SAFE_LOW_MEDIUM)
219+
.violence(ALLOW_SAFE_LOW_MEDIUM);
221220

222221
client.chatCompletion(prompt, config.withInputFiltering(filter).withOutputFiltering(filter));
223222
// the result is asserted in the verify step below
@@ -241,11 +240,11 @@ void filteringStrict() {
241240
.willReturn(jsonResponse(response, SC_BAD_REQUEST)));
242241

243242
final var filter =
244-
new AzureContentSafety()
245-
.hate(NUMBER_0)
246-
.selfHarm(NUMBER_0)
247-
.sexual(NUMBER_0)
248-
.violence(NUMBER_0);
243+
new AzureContentFilter()
244+
.hate(ALLOW_SAFE)
245+
.selfHarm(ALLOW_SAFE)
246+
.sexual(ALLOW_SAFE)
247+
.violence(ALLOW_SAFE);
249248

250249
final var configWithFilter = config.withInputFiltering(filter).withOutputFiltering(filter);
251250

sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
package com.sap.ai.sdk.app.controllers;
22

3+
import com.sap.ai.sdk.orchestration.AzureContentFilter;
4+
import com.sap.ai.sdk.orchestration.AzureModerationPolicy;
35
import com.sap.ai.sdk.orchestration.OrchestrationChatResponse;
46
import com.sap.ai.sdk.orchestration.OrchestrationClient;
57
import com.sap.ai.sdk.orchestration.OrchestrationModuleConfig;
68
import com.sap.ai.sdk.orchestration.OrchestrationPrompt;
7-
import com.sap.ai.sdk.orchestration.client.model.AzureContentSafety;
8-
import com.sap.ai.sdk.orchestration.client.model.AzureThreshold;
99
import com.sap.ai.sdk.orchestration.client.model.ChatMessage;
1010
import com.sap.ai.sdk.orchestration.client.model.DPIConfig;
1111
import com.sap.ai.sdk.orchestration.client.model.DPIEntities;
@@ -89,13 +89,13 @@ public OrchestrationChatResponse messagesHistory() {
8989
/**
9090
* Apply both input and output filtering for a request to orchestration.
9191
*
92-
* @param threshold A high threshold is a loose filter, a low threshold is a strict filter
92+
* @param policy A high threshold is a loose filter, a low threshold is a strict filter
9393
* @return the result object
9494
*/
95-
@GetMapping("/filter/{threshold}")
95+
@GetMapping("/filter/{policy}")
9696
@Nonnull
9797
public OrchestrationChatResponse filter(
98-
@Nonnull @PathVariable("threshold") final AzureThreshold threshold) {
98+
@Nonnull @PathVariable("policy") final AzureModerationPolicy policy) {
9999
final var prompt =
100100
new OrchestrationPrompt(
101101
"""
@@ -104,11 +104,8 @@ public OrchestrationChatResponse filter(
104104
```DISCLAIMER: The area surrounding the apartment is known for prostitutes and gang violence including armed conflicts, gun violence is frequent.
105105
""");
106106
final var filterConfig =
107-
new AzureContentSafety()
108-
.hate(threshold)
109-
.selfHarm(threshold)
110-
.sexual(threshold)
111-
.violence(threshold);
107+
new AzureContentFilter().hate(policy).selfHarm(policy).sexual(policy).violence(policy);
108+
112109
final var configWithFilter =
113110
config.withInputFiltering(filterConfig).withOutputFiltering(filterConfig);
114111

sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import static org.assertj.core.api.Assertions.assertThat;
44
import static org.assertj.core.api.Assertions.assertThatThrownBy;
55

6+
import com.sap.ai.sdk.orchestration.AzureModerationPolicy;
67
import com.sap.ai.sdk.orchestration.OrchestrationClientException;
7-
import com.sap.ai.sdk.orchestration.client.model.AzureThreshold;
88
import com.sap.ai.sdk.orchestration.client.model.CompletionPostResponse;
99
import com.sap.ai.sdk.orchestration.client.model.LLMChoice;
1010
import com.sap.ai.sdk.orchestration.client.model.LLMModuleResultSynchronous;
@@ -72,7 +72,7 @@ void testTemplate() {
7272

7373
@Test
7474
void testLenientContentFilter() {
75-
var response = controller.filter(AzureThreshold.NUMBER_4);
75+
var response = controller.filter(AzureModerationPolicy.ALLOW_SAFE_LOW_MEDIUM);
7676
var result = response.getOriginalResponse();
7777
var llmChoice =
7878
((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices().get(0);
@@ -85,7 +85,7 @@ void testLenientContentFilter() {
8585

8686
@Test
8787
void testStrictContentFilter() {
88-
assertThatThrownBy(() -> controller.filter(AzureThreshold.NUMBER_0))
88+
assertThatThrownBy(() -> controller.filter(AzureModerationPolicy.ALLOW_SAFE))
8989
.isInstanceOf(OrchestrationClientException.class)
9090
.hasMessageContaining("400 Bad Request")
9191
.hasMessageContaining("Content filtered");

0 commit comments

Comments
 (0)