Skip to content

Commit b87f0a9

Browse files
committed
Fix compilation (new Filtering classes)
1 parent 50262f3 commit b87f0a9

File tree

4 files changed

+64
-15
lines changed

4 files changed

+64
-15
lines changed

orchestration/src/main/java/com/sap/ai/sdk/orchestration/AzureContentFilter.java

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
package com.sap.ai.sdk.orchestration;
22

3-
import com.sap.ai.sdk.orchestration.model.AzureContentSafety;
4-
import com.sap.ai.sdk.orchestration.model.AzureContentSafetyFilterConfig;
3+
import com.sap.ai.sdk.orchestration.model.AzureContentSafetyInput;
4+
import com.sap.ai.sdk.orchestration.model.AzureContentSafetyInputFilterConfig;
5+
import com.sap.ai.sdk.orchestration.model.AzureContentSafetyOutput;
6+
import com.sap.ai.sdk.orchestration.model.AzureContentSafetyOutputFilterConfig;
57
import javax.annotation.Nonnull;
68
import javax.annotation.Nullable;
79
import lombok.NoArgsConstructor;
@@ -48,22 +50,46 @@ public class AzureContentFilter implements ContentFilter {
4850

4951
/**
5052
* Converts {@code AzureContentFilter} to its serializable counterpart {@link
51-
* AzureContentSafetyFilterConfig}.
53+
* AzureContentSafetyInputFilterConfig}.
5254
*
53-
* @return the corresponding {@code AzureContentSafetyFilterConfig} object.
55+
* @return the corresponding {@code AzureContentSafetyInputFilterConfig} object.
5456
* @throws IllegalArgumentException if no policies are set.
5557
*/
5658
@Override
5759
@Nonnull
58-
public AzureContentSafetyFilterConfig createConfig() {
60+
public AzureContentSafetyInputFilterConfig createInputConfig() {
5961
if (hate == null && selfHarm == null && sexual == null && violence == null) {
6062
throw new IllegalArgumentException("At least one filter category must be set");
6163
}
6264

63-
return AzureContentSafetyFilterConfig.create()
64-
.type(AzureContentSafetyFilterConfig.TypeEnum.AZURE_CONTENT_SAFETY)
65+
return AzureContentSafetyInputFilterConfig.create()
66+
.type(AzureContentSafetyInputFilterConfig.TypeEnum.AZURE_CONTENT_SAFETY)
6567
.config(
66-
AzureContentSafety.create()
68+
AzureContentSafetyInput.create()
69+
.hate(hate != null ? hate.getAzureThreshold() : null)
70+
.selfHarm(selfHarm != null ? selfHarm.getAzureThreshold() : null)
71+
.sexual(sexual != null ? sexual.getAzureThreshold() : null)
72+
.violence(violence != null ? violence.getAzureThreshold() : null));
73+
}
74+
75+
/**
76+
* Converts {@code AzureContentFilter} to its serializable counterpart {@link
77+
* AzureContentSafetyOutputFilterConfig}.
78+
*
79+
* @return the corresponding {@code AzureContentSafetyOutputFilterConfig} object.
80+
* @throws IllegalArgumentException if no policies are set.
81+
*/
82+
@Override
83+
@Nonnull
84+
public AzureContentSafetyOutputFilterConfig createOutputConfig() {
85+
if (hate == null && selfHarm == null && sexual == null && violence == null) {
86+
throw new IllegalArgumentException("At least one filter category must be set");
87+
}
88+
89+
return AzureContentSafetyOutputFilterConfig.create()
90+
.type(AzureContentSafetyOutputFilterConfig.TypeEnum.AZURE_CONTENT_SAFETY)
91+
.config(
92+
AzureContentSafetyOutput.create()
6793
.hate(hate != null ? hate.getAzureThreshold() : null)
6894
.selfHarm(selfHarm != null ? selfHarm.getAzureThreshold() : null)
6995
.sexual(sexual != null ? sexual.getAzureThreshold() : null)

orchestration/src/main/java/com/sap/ai/sdk/orchestration/ContentFilter.java

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package com.sap.ai.sdk.orchestration;
22

3-
import com.sap.ai.sdk.orchestration.model.FilterConfig;
3+
import com.sap.ai.sdk.orchestration.model.InputFilterConfig;
4+
import com.sap.ai.sdk.orchestration.model.OutputFilterConfig;
45
import javax.annotation.Nonnull;
56

67
/**
@@ -17,11 +18,20 @@
1718
public interface ContentFilter {
1819

1920
/**
20-
* A method that produces the serializable equivalent {@link FilterConfig} object from data
21+
* A method that produces the serializable equivalent {@link InputFilterConfig} object from data
2122
* encapsulated in the {@link ContentFilter} object.
2223
*
23-
* @return the corresponding {@code FilterConfig} object.
24+
* @return the corresponding {@code InputFilterConfig} object.
2425
*/
2526
@Nonnull
26-
FilterConfig createConfig();
27+
InputFilterConfig createInputConfig();
28+
29+
/**
30+
* A method that produces the serializable equivalent {@link OutputFilterConfig} object from data
31+
* encapsulated in the {@link ContentFilter} object.
32+
*
33+
* @return the corresponding {@code OutputFilterConfig} object.
34+
*/
35+
@Nonnull
36+
OutputFilterConfig createOutputConfig();
2737
}

orchestration/src/main/java/com/sap/ai/sdk/orchestration/LlamaGuardFilter.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,17 @@ public class LlamaGuardFilter implements ContentFilter {
4040

4141
@Nonnull
4242
@Override
43-
public LlamaGuard38bFilterConfig createConfig() {
43+
public LlamaGuard38bFilterConfig createInputConfig() {
44+
return createConfig();
45+
}
46+
47+
@Nonnull
48+
@Override
49+
public LlamaGuard38bFilterConfig createOutputConfig() {
50+
return createConfig();
51+
}
52+
53+
private LlamaGuard38bFilterConfig createConfig() {
4454
return LlamaGuard38bFilterConfig.create().type(LLAMA_GUARD_3_8B).config(config);
4555
}
4656
}

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfig.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ public OrchestrationModuleConfig withInputFiltering(
160160
allFilters.addAll(Arrays.asList(contentFilters));
161161

162162
final var filterConfigs =
163-
allFilters.stream().filter(Objects::nonNull).map(ContentFilter::createConfig).toList();
163+
allFilters.stream().filter(Objects::nonNull).map(ContentFilter::createInputConfig).toList();
164164

165165
final var inputFilter = InputFilteringConfig.create().filters(filterConfigs);
166166

@@ -195,7 +195,10 @@ public OrchestrationModuleConfig withOutputFiltering(
195195
allFilters.addAll(Arrays.asList(contentFilters));
196196

197197
final var filterConfigs =
198-
allFilters.stream().filter(Objects::nonNull).map(ContentFilter::createConfig).toList();
198+
allFilters.stream()
199+
.filter(Objects::nonNull)
200+
.map(ContentFilter::createOutputConfig)
201+
.toList();
199202

200203
final var outputFilter = OutputFilteringConfig.create().filters(filterConfigs);
201204

0 commit comments

Comments
 (0)