Skip to content

Commit 55ca0a4

Browse files
committed
Set filters conveniently
- Reuse the safety enum values - Carry the immutability characteristic - Disallow setting empty filter
1 parent 50b3c47 commit 55ca0a4

File tree

3 files changed

+67
-60
lines changed

3 files changed

+67
-60
lines changed

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfig.java

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
package com.sap.ai.sdk.orchestration;
22

3+
import com.sap.ai.sdk.orchestration.client.model.AzureContentSafety;
4+
import com.sap.ai.sdk.orchestration.client.model.AzureContentSafetyFilterConfig;
35
import com.sap.ai.sdk.orchestration.client.model.FilteringModuleConfig;
6+
import com.sap.ai.sdk.orchestration.client.model.InputFilteringConfig;
47
import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig;
58
import com.sap.ai.sdk.orchestration.client.model.MaskingModuleConfig;
9+
import com.sap.ai.sdk.orchestration.client.model.OutputFilteringConfig;
610
import com.sap.ai.sdk.orchestration.client.model.TemplatingModuleConfig;
11+
import java.util.List;
12+
import javax.annotation.Nonnull;
713
import javax.annotation.Nullable;
814
import lombok.AccessLevel;
915
import lombok.AllArgsConstructor;
@@ -28,24 +34,57 @@
2834
* </ul>
2935
*/
3036
@Value
31-
@With
3237
@AllArgsConstructor(access = AccessLevel.PRIVATE)
3338
@NoArgsConstructor(force = true)
3439
public class OrchestrationModuleConfig {
3540
/**
3641
* The configured language model settings. This configuration is required when executing requests.
3742
*/
38-
@Nullable LLMModuleConfig llmConfig;
43+
@With @Nullable LLMModuleConfig llmConfig;
3944

4045
/**
4146
* A template to be populated with input parameters. Upon request execution, this template will be
4247
* enhanced with any messages and parameter values from {@link OrchestrationPrompt}.
4348
*/
44-
@Nullable TemplatingModuleConfig templateConfig;
49+
@With @Nullable TemplatingModuleConfig templateConfig;
4550

4651
/** A masking configuration to pseudonymous or anonymize sensitive data in the input. */
47-
@Nullable MaskingModuleConfig maskingConfig;
52+
@With @Nullable MaskingModuleConfig maskingConfig;
4853

4954
/** A content filter to filter the prompt. */
50-
@Nullable FilteringModuleConfig filteringConfig;
55+
@With(AccessLevel.PRIVATE)
56+
@Nullable
57+
FilteringModuleConfig filteringConfig;
58+
59+
@Nonnull
60+
public OrchestrationModuleConfig withInputFiltering(AzureContentSafety contentFilter) {
61+
var azureFilter =
62+
new AzureContentSafetyFilterConfig()
63+
.type(AzureContentSafetyFilterConfig.TypeEnum.AZURE_CONTENT_SAFETY)
64+
.config(contentFilter);
65+
var inputFilters = new InputFilteringConfig().filters(List.of(azureFilter));
66+
67+
var newFilteringConfig =
68+
new FilteringModuleConfig()
69+
.input(inputFilters)
70+
.output(this.filteringConfig != null ? this.filteringConfig.getOutput() : null);
71+
72+
return this.withFilteringConfig(newFilteringConfig);
73+
}
74+
75+
@Nonnull
76+
public OrchestrationModuleConfig withOutputFiltering(AzureContentSafety safety) {
77+
var filter =
78+
new AzureContentSafetyFilterConfig()
79+
.type(AzureContentSafetyFilterConfig.TypeEnum.AZURE_CONTENT_SAFETY)
80+
.config(safety);
81+
var outputFilteringConfig = new OutputFilteringConfig().filters(List.of(filter));
82+
83+
var newFilteringConfig =
84+
new FilteringModuleConfig()
85+
.output(outputFilteringConfig)
86+
.input(this.filteringConfig != null ? this.filteringConfig.getInput() : null);
87+
88+
return this.withFilteringConfig(newFilteringConfig);
89+
}
5190
}

orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,15 @@
2929
import com.github.tomakehurst.wiremock.stubbing.Scenario;
3030
import com.sap.ai.sdk.core.AiCoreService;
3131
import com.sap.ai.sdk.orchestration.client.model.AzureContentSafety;
32-
import com.sap.ai.sdk.orchestration.client.model.AzureContentSafetyFilterConfig;
33-
import com.sap.ai.sdk.orchestration.client.model.AzureThreshold;
3432
import com.sap.ai.sdk.orchestration.client.model.ChatMessage;
3533
import com.sap.ai.sdk.orchestration.client.model.CompletionPostRequest;
3634
import com.sap.ai.sdk.orchestration.client.model.DPIConfig;
3735
import com.sap.ai.sdk.orchestration.client.model.DPIEntities;
3836
import com.sap.ai.sdk.orchestration.client.model.DPIEntityConfig;
39-
import com.sap.ai.sdk.orchestration.client.model.FilteringModuleConfig;
4037
import com.sap.ai.sdk.orchestration.client.model.GenericModuleResult;
41-
import com.sap.ai.sdk.orchestration.client.model.InputFilteringConfig;
4238
import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig;
4339
import com.sap.ai.sdk.orchestration.client.model.LLMModuleResultSynchronous;
4440
import com.sap.ai.sdk.orchestration.client.model.MaskingModuleConfig;
45-
import com.sap.ai.sdk.orchestration.client.model.OutputFilteringConfig;
4641
import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination;
4742
import java.io.IOException;
4843
import java.io.InputStream;
@@ -217,9 +212,14 @@ void filteringLoose() throws IOException {
217212
.withBodyFile("filteringLooseResponse.json")
218213
.withHeader("Content-Type", "application/json")));
219214

220-
final var filter = createAzureContentFilter(NUMBER_4);
215+
final var filter =
216+
new AzureContentSafety()
217+
.hate(NUMBER_4)
218+
.selfHarm(NUMBER_4)
219+
.sexual(NUMBER_4)
220+
.violence(NUMBER_4);
221221

222-
client.chatCompletion(prompt, config.withFilteringConfig(filter));
222+
client.chatCompletion(prompt, config.withInputFiltering(filter).withOutputFiltering(filter));
223223
// the result is asserted in the verify step below
224224

225225
// verify that null fields are absent from the sent request
@@ -240,33 +240,21 @@ void filteringStrict() {
240240
post(urlPathEqualTo("/v2/inference/deployments/abcdef0123456789/completion"))
241241
.willReturn(jsonResponse(response, SC_BAD_REQUEST)));
242242

243-
final var filter = createAzureContentFilter(NUMBER_0);
243+
final var filter =
244+
new AzureContentSafety()
245+
.hate(NUMBER_0)
246+
.selfHarm(NUMBER_0)
247+
.sexual(NUMBER_0)
248+
.violence(NUMBER_0);
244249

245-
final var configWithFilter = config.withFilteringConfig(filter);
250+
final var configWithFilter = config.withInputFiltering(filter).withOutputFiltering(filter);
246251

247252
assertThatThrownBy(() -> client.chatCompletion(prompt, configWithFilter))
248253
.isInstanceOf(OrchestrationClientException.class)
249254
.hasMessage(
250255
"Request to orchestration service failed with status 400 Bad Request and error message: 'Content filtered due to Safety violations. Please modify the prompt and try again.'");
251256
}
252257

253-
private static FilteringModuleConfig createAzureContentFilter(
254-
@Nonnull final AzureThreshold threshold) {
255-
final var filter =
256-
new AzureContentSafetyFilterConfig()
257-
.type(AzureContentSafetyFilterConfig.TypeEnum.AZURE_CONTENT_SAFETY)
258-
.config(
259-
new AzureContentSafety()
260-
.hate(threshold)
261-
.selfHarm(threshold)
262-
.sexual(threshold)
263-
.violence(threshold));
264-
265-
return new FilteringModuleConfig()
266-
.input(new InputFilteringConfig().filters(List.of(filter)))
267-
.output(new OutputFilteringConfig().filters(List.of(filter)));
268-
}
269-
270258
@Test
271259
void messagesHistory() throws IOException {
272260
stubFor(

sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java

Lines changed: 9 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,14 @@
44
import com.sap.ai.sdk.orchestration.OrchestrationModuleConfig;
55
import com.sap.ai.sdk.orchestration.OrchestrationPrompt;
66
import com.sap.ai.sdk.orchestration.client.model.AzureContentSafety;
7-
import com.sap.ai.sdk.orchestration.client.model.AzureContentSafetyFilterConfig;
87
import com.sap.ai.sdk.orchestration.client.model.AzureThreshold;
98
import com.sap.ai.sdk.orchestration.client.model.ChatMessage;
109
import com.sap.ai.sdk.orchestration.client.model.CompletionPostResponse;
1110
import com.sap.ai.sdk.orchestration.client.model.DPIConfig;
1211
import com.sap.ai.sdk.orchestration.client.model.DPIEntities;
1312
import com.sap.ai.sdk.orchestration.client.model.DPIEntityConfig;
14-
import com.sap.ai.sdk.orchestration.client.model.FilteringModuleConfig;
15-
import com.sap.ai.sdk.orchestration.client.model.InputFilteringConfig;
1613
import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig;
1714
import com.sap.ai.sdk.orchestration.client.model.MaskingModuleConfig;
18-
import com.sap.ai.sdk.orchestration.client.model.OutputFilteringConfig;
1915
import com.sap.ai.sdk.orchestration.client.model.Template;
2016
import java.util.Arrays;
2117
import java.util.List;
@@ -107,34 +103,18 @@ public CompletionPostResponse filter(
107103
108104
```DISCLAIMER: The area surrounding the apartment is known for prostitutes and gang violence including armed conflicts, gun violence is frequent.
109105
""");
110-
final var filterConfig = createAzureContentFilter(threshold);
111-
final var configWithFilter = config.withFilteringConfig(filterConfig);
106+
final var filterConfig =
107+
new AzureContentSafety()
108+
.hate(threshold)
109+
.selfHarm(threshold)
110+
.sexual(threshold)
111+
.violence(threshold);
112+
final var configWithFilter =
113+
config.withInputFiltering(filterConfig).withOutputFiltering(filterConfig);
112114

113115
return client.chatCompletion(prompt, configWithFilter);
114116
}
115-
116-
/**
117-
* Helper method to build filter configurations.
118-
*
119-
* @param threshold The threshold to be applied across all filter categories.
120-
* @return A new filter configuration object.
121-
*/
122-
private static FilteringModuleConfig createAzureContentFilter(
123-
@Nonnull final AzureThreshold threshold) {
124-
final var filter =
125-
new AzureContentSafetyFilterConfig()
126-
.type(AzureContentSafetyFilterConfig.TypeEnum.AZURE_CONTENT_SAFETY)
127-
.config(
128-
new AzureContentSafety()
129-
.hate(threshold)
130-
.selfHarm(threshold)
131-
.sexual(threshold)
132-
.violence(threshold));
133-
134-
return new FilteringModuleConfig()
135-
.input(new InputFilteringConfig().filters(List.of(filter)))
136-
.output(new OutputFilteringConfig().filters(List.of(filter)));
137-
}
117+
138118

139119
/**
140120
* Let the orchestration service evaluate the feedback on the AI SDK provided by a hypothetical

0 commit comments

Comments
 (0)