Skip to content

Commit 6bf9b60

Browse files
Add orchestration convenient filtering API (#169)
* Set filters conveniently - Reuse the safety enum values - Carry the immutability characteristic - Disallow setting empty filter * Update docs * Formatting * consistent variable naming * Introduce readable thresholds based api - Multi filter support - detach per filter logic from generic methods - add javadoc - update tests * Formatting * change method name * Improve docs * Formatting * Improve docs and confirm checkstyle * Update tests * Update test organization * Update naming * Update docs again for naming change --------- Co-authored-by: Roshin Rajan Panackal <[email protected]> Co-authored-by: SAP Cloud SDK Bot <[email protected]>
1 parent 5e57895 commit 6bf9b60

File tree

10 files changed

+305
-143
lines changed

10 files changed

+305
-143
lines changed

docs/guides/ORCHESTRATION_CHAT_COMPLETION.md

Lines changed: 14 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -146,33 +146,20 @@ var prompt = new OrchestrationPrompt(
146146
```DISCLAIMER: The area surrounding the apartment is known for prostitutes and gang violence including armed conflicts, gun violence is frequent.
147147
""");
148148

149-
var filterStrict =
150-
FilterConfig.create()
151-
.type(FilterConfig.TypeEnum.AZURE_CONTENT_SAFETY)
152-
.config(
153-
AzureContentSafety.create()
154-
.hate(NUMBER_0)
155-
.selfHarm(NUMBER_0)
156-
.sexual(NUMBER_0)
157-
.violence(NUMBER_0));
158-
159-
var filterLoose =
160-
FilterConfig.create()
161-
.type(FilterConfig.TypeEnum.AZURE_CONTENT_SAFETY)
162-
.config(
163-
AzureContentSafety.create()
164-
.hate(NUMBER_4)
165-
.selfHarm(NUMBER_4)
166-
.sexual(NUMBER_4)
167-
.violence(NUMBER_4));
168-
169-
var filteringConfig =
170-
FilteringModuleConfig.create()
171-
// changing the input to filterLoose will allow the message to pass
172-
.input(InputFilteringConfig.create().filters(filterStrict))
173-
.output(OutputFilteringConfig.create().filters(filterStrict));
174-
175-
var configWithFilter = config.withFilteringConfig(filteringConfig);
149+
var filterStrict = new AzureContentFilter()
150+
.hate(ALLOW_SAFE)
151+
.selfHarm(ALLOW_SAFE)
152+
.sexual(ALLOW_SAFE)
153+
.violence(ALLOW_SAFE);
154+
155+
var filterLoose = new AzureContentFilter()
156+
.hate(ALLOW_SAFE_LOW_MEDIUM)
157+
.selfHarm(ALLOW_SAFE_LOW_MEDIUM)
158+
.sexual(ALLOW_SAFE_LOW_MEDIUM)
159+
.violence(ALLOW_SAFE_LOW_MEDIUM);
160+
161+
// changing the input to filterLoose will allow the message to pass
162+
var configWithFilter = config.withInputFiltering(filterStrict).withOutputFiltering(filterStrict);
176163

177164
// this fails with Bad Request because the strict filter prohibits the input message
178165
var result =
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import com.sap.ai.sdk.orchestration.client.model.AzureContentSafety;
4+
import com.sap.ai.sdk.orchestration.client.model.AzureContentSafetyFilterConfig;
5+
import javax.annotation.Nonnull;
6+
import javax.annotation.Nullable;
7+
import lombok.NoArgsConstructor;
8+
import lombok.Setter;
9+
import lombok.experimental.Accessors;
10+
11+
/**
12+
* A content filter wrapping Azure Content Safety.
13+
*
14+
* <p>This class allows setting filtration thresholds for different content categories such as hate,
15+
* self-harm, sexual, and violence.
16+
*
17+
* <p>Example usage:
18+
*
19+
* <pre>{@code
20+
* AzureContentFilter filter = new AzureContentFilter()
21+
* .hate(AzureFilterThreshold.ALLOW_SAFE)
22+
* .selfHarm(AzureFilterThreshold.ALLOW_SAFE_LOW);
23+
* }</pre>
24+
*/
25+
@Setter
26+
@NoArgsConstructor
27+
@Accessors(fluent = true)
28+
public class AzureContentFilter implements ContentFilter {
29+
30+
/* The filter category for hate content. */
31+
@Nullable AzureFilterThreshold hate;
32+
33+
/* The filter category for self-harm content. */
34+
@Nullable AzureFilterThreshold selfHarm;
35+
36+
/* The filter category for sexual content. */
37+
@Nullable AzureFilterThreshold sexual;
38+
39+
/* The filter category for violence content. */
40+
@Nullable AzureFilterThreshold violence;
41+
42+
/**
43+
* Converts {@code AzureContentFilter} to its serializable counterpart {@link
44+
* AzureContentSafetyFilterConfig}.
45+
*
46+
* @return the corresponding {@code AzureContentSafetyFilterConfig} object.
47+
* @throws IllegalArgumentException if no policies are set.
48+
*/
49+
@Override
50+
@Nonnull
51+
public AzureContentSafetyFilterConfig createConfig() {
52+
if (hate == null && selfHarm == null && sexual == null && violence == null) {
53+
throw new IllegalArgumentException("At least one filter category must be set");
54+
}
55+
56+
return new AzureContentSafetyFilterConfig()
57+
.type(AzureContentSafetyFilterConfig.TypeEnum.AZURE_CONTENT_SAFETY)
58+
.config(
59+
new AzureContentSafety()
60+
.hate(hate != null ? hate.getAzureThreshold() : null)
61+
.selfHarm(selfHarm != null ? selfHarm.getAzureThreshold() : null)
62+
.sexual(sexual != null ? sexual.getAzureThreshold() : null)
63+
.violence(violence != null ? violence.getAzureThreshold() : null));
64+
}
65+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import com.sap.ai.sdk.orchestration.client.model.AzureThreshold;
4+
import javax.annotation.Nonnull;
5+
import lombok.AllArgsConstructor;
6+
import lombok.Getter;
7+
8+
/** An Enum wrapping Azure thresholds with readable names. */
9+
@Getter
10+
@AllArgsConstructor
11+
public enum AzureFilterThreshold {
12+
/** Only safe content is allowed. */
13+
ALLOW_SAFE(AzureThreshold.NUMBER_0),
14+
15+
/** Safe and low-risk content is allowed. */
16+
ALLOW_SAFE_LOW(AzureThreshold.NUMBER_2),
17+
18+
/** Safe, low-risk, and medium-risk content is allowed. */
19+
ALLOW_SAFE_LOW_MEDIUM(AzureThreshold.NUMBER_4),
20+
21+
/** All content is allowed. */
22+
ALLOW_ALL(AzureThreshold.NUMBER_6);
23+
24+
@Nonnull final AzureThreshold azureThreshold;
25+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import com.sap.ai.sdk.orchestration.client.model.FilterConfig;
4+
import javax.annotation.Nonnull;
5+
6+
/**
7+
* Interface representing convenience wrappers of serializable content filter that defines
8+
* thresholds for different content categories.
9+
*/
10+
public interface ContentFilter {
11+
12+
/**
13+
* A method that produces the serializable equivalent {@link FilterConfig} object from data
14+
* encapsulated in the {@link ContentFilter} object.
15+
*
16+
* @return the corresponding {@code FilterConfig} object.
17+
*/
18+
@Nonnull
19+
FilterConfig createConfig();
20+
}

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfig.java

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
package com.sap.ai.sdk.orchestration;
22

33
import com.sap.ai.sdk.orchestration.client.model.FilteringModuleConfig;
4+
import com.sap.ai.sdk.orchestration.client.model.InputFilteringConfig;
45
import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig;
56
import com.sap.ai.sdk.orchestration.client.model.MaskingModuleConfig;
7+
import com.sap.ai.sdk.orchestration.client.model.OutputFilteringConfig;
68
import com.sap.ai.sdk.orchestration.client.model.TemplatingModuleConfig;
9+
import java.util.ArrayList;
710
import java.util.Arrays;
11+
import java.util.Objects;
812
import javax.annotation.Nonnull;
913
import javax.annotation.Nullable;
1014
import lombok.AccessLevel;
@@ -84,4 +88,65 @@ public OrchestrationModuleConfig withMaskingConfig(
8488

8589
return withMaskingConfig(newMaskingConfig);
8690
}
91+
92+
/**
93+
* Adds input content filters to the orchestration configuration.
94+
*
95+
* <p>Preferred over {@link #withFilteringConfig(FilteringModuleConfig)} for adding input filters.
96+
*
97+
* @param contentFilters one or more content filters to apply to the input.
98+
* @return a new {@code OrchestrationModuleConfig} instance with the specified input filters
99+
* added.
100+
*/
101+
@Nonnull
102+
public OrchestrationModuleConfig withInputFiltering(
103+
@Nonnull final ContentFilter contentFilter, @Nonnull final ContentFilter... contentFilters) {
104+
105+
final var allFilters = new ArrayList<ContentFilter>();
106+
allFilters.add(contentFilter);
107+
allFilters.addAll(Arrays.asList(contentFilters));
108+
109+
final var filterConfigs =
110+
allFilters.stream().filter(Objects::nonNull).map(ContentFilter::createConfig).toList();
111+
112+
final var inputFilter = new InputFilteringConfig().filters(filterConfigs);
113+
114+
final var newFilteringConfig =
115+
new FilteringModuleConfig()
116+
.input(inputFilter)
117+
.output(this.filteringConfig != null ? this.filteringConfig.getOutput() : null);
118+
119+
return this.withFilteringConfig(newFilteringConfig);
120+
}
121+
122+
/**
123+
* Adds output content filters to the orchestration configuration.
124+
*
125+
* <p>Preferred over {@link #withFilteringConfig(FilteringModuleConfig)} for adding output
126+
* filters.
127+
*
128+
* @param contentFilters one or more content filters to apply to the output.
129+
* @return a new {@code OrchestrationModuleConfig} instance with the specified output filters
130+
* added.
131+
*/
132+
@Nonnull
133+
public OrchestrationModuleConfig withOutputFiltering(
134+
@Nonnull final ContentFilter contentFilter, @Nonnull final ContentFilter... contentFilters) {
135+
136+
final var allFilters = new ArrayList<ContentFilter>();
137+
allFilters.add(contentFilter);
138+
allFilters.addAll(Arrays.asList(contentFilters));
139+
140+
final var filterConfigs =
141+
allFilters.stream().filter(Objects::nonNull).map(ContentFilter::createConfig).toList();
142+
143+
final var outputFilter = new OutputFilteringConfig().filters(filterConfigs);
144+
145+
final var newFilteringConfig =
146+
new FilteringModuleConfig()
147+
.output(outputFilter)
148+
.input(this.filteringConfig != null ? this.filteringConfig.getInput() : null);
149+
150+
return this.withFilteringConfig(newFilteringConfig);
151+
}
87152
}

orchestration/src/test/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformerTest.java

Lines changed: 0 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
11
package com.sap.ai.sdk.orchestration;
22

3-
import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.GPT_4O;
43
import static com.sap.ai.sdk.orchestration.OrchestrationUnitTest.CUSTOM_GPT_35;
54
import static org.assertj.core.api.Assertions.assertThat;
65
import static org.assertj.core.api.Assertions.assertThatThrownBy;
76

87
import com.sap.ai.sdk.orchestration.client.model.ChatMessage;
9-
import com.sap.ai.sdk.orchestration.client.model.DPIConfig;
10-
import com.sap.ai.sdk.orchestration.client.model.DPIEntities;
118
import com.sap.ai.sdk.orchestration.client.model.Template;
129
import java.util.List;
1310
import java.util.Map;
@@ -78,50 +75,4 @@ void testMessagesHistory() {
7875

7976
assertThat(actual.getMessagesHistory()).containsExactly(systemMessage);
8077
}
81-
82-
@Test
83-
void testDpiMaskingConfig() {
84-
var maskingConfig = DpiMasking.anonymization().withEntities(DPIEntities.ADDRESS);
85-
var config =
86-
new OrchestrationModuleConfig()
87-
.withLlmConfig(CUSTOM_GPT_35)
88-
.withMaskingConfig(maskingConfig);
89-
90-
var actual = ConfigToRequestTransformer.toModuleConfigs(config);
91-
92-
assertThat(actual.getMaskingModuleConfig()).isNotNull();
93-
assertThat(actual.getMaskingModuleConfig().getMaskingProviders()).hasSize(1);
94-
DPIConfig dpiConfig = (DPIConfig) actual.getMaskingModuleConfig().getMaskingProviders().get(0);
95-
assertThat(dpiConfig.getMethod()).isEqualTo(DPIConfig.MethodEnum.ANONYMIZATION);
96-
assertThat(dpiConfig.getEntities()).hasSize(1);
97-
assertThat(dpiConfig.getEntities().get(0).getType()).isEqualTo(DPIEntities.ADDRESS);
98-
99-
var configModified = config.withMaskingConfig(maskingConfig);
100-
assertThat(configModified.getMaskingConfig()).isNotNull();
101-
assertThat(configModified.getMaskingConfig().getMaskingProviders())
102-
.withFailMessage("withMaskingConfig() should overwrite the existing config and not append")
103-
.hasSize(1);
104-
}
105-
106-
@Test
107-
void testLLMConfig() {
108-
Map<String, Object> params = Map.of("foo", "bar");
109-
String version = "2024-05-13";
110-
OrchestrationAiModel aiModel = GPT_4O.withModelParams(params).withModelVersion(version);
111-
var config = new OrchestrationModuleConfig().withLlmConfig(aiModel);
112-
113-
var actual = ConfigToRequestTransformer.toModuleConfigs(config);
114-
115-
assertThat(actual.getLlmModuleConfig()).isNotNull();
116-
assertThat(actual.getLlmModuleConfig().getModelName()).isEqualTo(GPT_4O.getModelName());
117-
assertThat(actual.getLlmModuleConfig().getModelParams()).isEqualTo(params);
118-
assertThat(actual.getLlmModuleConfig().getModelVersion()).isEqualTo(version);
119-
120-
assertThat(GPT_4O.getModelParams())
121-
.withFailMessage("Static models should be unchanged")
122-
.isEmpty();
123-
assertThat(GPT_4O.getModelVersion())
124-
.withFailMessage("Static models should be unchanged")
125-
.isEqualTo("latest");
126-
}
12778
}
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import static com.sap.ai.sdk.orchestration.AzureFilterThreshold.ALLOW_SAFE_LOW_MEDIUM;
4+
import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.GPT_4O;
5+
import static org.assertj.core.api.Assertions.assertThat;
6+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
7+
8+
import com.sap.ai.sdk.orchestration.client.model.DPIConfig;
9+
import com.sap.ai.sdk.orchestration.client.model.DPIEntities;
10+
import java.util.Map;
11+
import org.junit.jupiter.api.Test;
12+
13+
class OrchestrationModuleConfigTest {
14+
15+
@Test
16+
void testStackingInputAndOutputFilter() {
17+
final var config = new OrchestrationModuleConfig().withLlmConfig(GPT_4O);
18+
19+
final var filter =
20+
new AzureContentFilter()
21+
.hate(ALLOW_SAFE_LOW_MEDIUM)
22+
.selfHarm(ALLOW_SAFE_LOW_MEDIUM)
23+
.sexual(ALLOW_SAFE_LOW_MEDIUM)
24+
.violence(ALLOW_SAFE_LOW_MEDIUM);
25+
26+
final var configWithInputFirst = config.withInputFiltering(filter).withOutputFiltering(filter);
27+
assertThat(configWithInputFirst.getFilteringConfig()).isNotNull();
28+
assertThat(configWithInputFirst.getFilteringConfig().getInput()).isNotNull();
29+
30+
final var configWithOutputFirst = config.withOutputFiltering(filter).withInputFiltering(filter);
31+
assertThat(configWithOutputFirst.getFilteringConfig()).isNotNull();
32+
assertThat(configWithOutputFirst.getFilteringConfig().getOutput()).isNotNull();
33+
}
34+
35+
@Test
36+
void testThrowOnEmptyFilterConfig() {
37+
38+
final var config = new OrchestrationModuleConfig().withLlmConfig(GPT_4O);
39+
40+
assertThatThrownBy(() -> config.withInputFiltering(new AzureContentFilter()))
41+
.isInstanceOf(IllegalArgumentException.class)
42+
.hasMessage("At least one filter category must be set");
43+
assertThatThrownBy(() -> config.withOutputFiltering(new AzureContentFilter()))
44+
.isInstanceOf(IllegalArgumentException.class)
45+
.hasMessage("At least one filter category must be set");
46+
}
47+
48+
@Test
49+
void testDpiMaskingConfig() {
50+
var maskingConfig = DpiMasking.anonymization().withEntities(DPIEntities.ADDRESS);
51+
var config =
52+
new OrchestrationModuleConfig().withLlmConfig(GPT_4O).withMaskingConfig(maskingConfig);
53+
54+
assertThat(config.getMaskingConfig()).isNotNull();
55+
assertThat(config.getMaskingConfig().getMaskingProviders()).hasSize(1);
56+
DPIConfig dpiConfig = (DPIConfig) config.getMaskingConfig().getMaskingProviders().get(0);
57+
assertThat(dpiConfig.getMethod()).isEqualTo(DPIConfig.MethodEnum.ANONYMIZATION);
58+
assertThat(dpiConfig.getEntities()).hasSize(1);
59+
assertThat(dpiConfig.getEntities().get(0).getType()).isEqualTo(DPIEntities.ADDRESS);
60+
61+
var configModified = config.withMaskingConfig(maskingConfig);
62+
assertThat(configModified.getMaskingConfig()).isNotNull();
63+
assertThat(configModified.getMaskingConfig().getMaskingProviders())
64+
.withFailMessage("withMaskingConfig() should overwrite the existing config and not append")
65+
.hasSize(1);
66+
}
67+
68+
@Test
69+
void testLLMConfig() {
70+
Map<String, Object> params = Map.of("foo", "bar");
71+
String version = "2024-05-13";
72+
OrchestrationAiModel aiModel = GPT_4O.withModelParams(params).withModelVersion(version);
73+
var config = new OrchestrationModuleConfig().withLlmConfig(aiModel);
74+
75+
assertThat(config.getLlmConfig()).isNotNull();
76+
assertThat(config.getLlmConfig().getModelName()).isEqualTo(GPT_4O.getModelName());
77+
assertThat(config.getLlmConfig().getModelParams()).isEqualTo(params);
78+
assertThat(config.getLlmConfig().getModelVersion()).isEqualTo(version);
79+
80+
assertThat(GPT_4O.getModelParams())
81+
.withFailMessage("Static models should be unchanged")
82+
.isEmpty();
83+
assertThat(GPT_4O.getModelVersion())
84+
.withFailMessage("Static models should be unchanged")
85+
.isEqualTo("latest");
86+
}
87+
}

0 commit comments

Comments
 (0)