Skip to content

Commit c28d11d

Browse files
feat: Llama🦙Guard (#308)
* feat: Llama Guard Filter * trigger filter * Added e2e test * Unit test * Formatting * doc * wrong package * javadoc * Merge conflicts * Revert object mapper * Added missing dependency --------- Co-authored-by: SAP Cloud SDK Bot <[email protected]>
1 parent 96e1695 commit c28d11d

File tree

10 files changed

+157
-12
lines changed

10 files changed

+157
-12
lines changed

docs/guides/ORCHESTRATION_CHAT_COMPLETION.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,11 @@ var filterLoose = new AzureContentFilter()
158158
.sexual(ALLOW_SAFE_LOW_MEDIUM)
159159
.violence(ALLOW_SAFE_LOW_MEDIUM);
160160

161+
// choose the Llama Guard filter and/or the Azure filter
162+
var llamaGuardFilter = new LlamaGuardFilter().config(LlamaGuard38b.create().selfHarm(true));
163+
161164
// changing the input to filterLoose will allow the message to pass
162-
var configWithFilter = config.withInputFiltering(filterStrict).withOutputFiltering(filterStrict);
165+
var configWithFilter = config.withInputFiltering(filterStrict).withOutputFiltering(filterStrict, llamaGuardFilter);
163166

164167
// this fails with Bad Request because the strict filter prohibits the input message
165168
var result =

docs/release-notes/release_notes.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
### ✨ New Functionality
1414

15-
-
15+
- [Add Orchestration `LlamaGuardFilter`](../guides/ORCHESTRATION_CHAT_COMPLETION.md#chat-completion-filter).
1616

1717
### 📈 Improvements
1818

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import static com.sap.ai.sdk.orchestration.model.LlamaGuard38bFilterConfig.TypeEnum.LLAMA_GUARD_3_8B;
4+
5+
import com.sap.ai.sdk.orchestration.model.LlamaGuard38b;
6+
import com.sap.ai.sdk.orchestration.model.LlamaGuard38bFilterConfig;
7+
import javax.annotation.Nonnull;
8+
import lombok.Setter;
9+
import lombok.experimental.Accessors;
10+
11+
/**
12+
* A content filter wrapping Llama Guard filter config.
13+
*
14+
* <p>This class allows setting filters for different content categories such as hate, self-harm,
15+
* sexual, and violence.
16+
*
17+
* <p>Example usage:
18+
*
19+
* <pre>{@code
20+
* // values not set are disabled by default
21+
* val config =
22+
* LlamaGuard38b.create()
23+
* .violentCrimes(true)
24+
* .selfHarm(true);
25+
* val filterConfig = new LlamaGuardFilter().config(config);
26+
* }</pre>
27+
*
28+
* @see <a
29+
* href="https://help.sap.com/docs/sap-ai-core/sap-ai-core-service-guide/input-filtering">SAP AI
30+
* Core: Orchestration - Input Filtering</a>
31+
* @see <a
32+
* href="https://help.sap.com/docs/sap-ai-core/sap-ai-core-service-guide/output-filtering">SAP
33+
* AI Core: Orchestration - Output Filtering</a>
34+
*/
35+
@Accessors(fluent = true)
36+
@Setter
37+
public class LlamaGuardFilter implements ContentFilter {
38+
39+
private LlamaGuard38b config = LlamaGuard38b.create();
40+
41+
@Nonnull
42+
@Override
43+
public LlamaGuard38bFilterConfig createConfig() {
44+
return LlamaGuard38bFilterConfig.create().type(LLAMA_GUARD_3_8B).config(config);
45+
}
46+
}

orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@
5353
import com.sap.ai.sdk.orchestration.model.LLMModuleResult;
5454
import com.sap.ai.sdk.orchestration.model.LLMModuleResultSynchronous;
5555
import com.sap.ai.sdk.orchestration.model.LlamaGuard38b;
56-
import com.sap.ai.sdk.orchestration.model.LlamaGuard38bFilterConfig;
5756
import com.sap.ai.sdk.orchestration.model.ModuleConfigs;
5857
import com.sap.ai.sdk.orchestration.model.MultiChatMessage;
5958
import com.sap.ai.sdk.orchestration.model.OrchestrationConfig;
@@ -309,11 +308,7 @@ void filteringLoose() throws IOException {
309308
.sexual(ALLOW_SAFE_LOW_MEDIUM)
310309
.violence(ALLOW_SAFE_LOW_MEDIUM);
311310

312-
final ContentFilter llamaFilter =
313-
() ->
314-
LlamaGuard38bFilterConfig.create()
315-
.type(LlamaGuard38bFilterConfig.TypeEnum.LLAMA_GUARD_3_8B)
316-
.config(LlamaGuard38b.create().selfHarm(true));
311+
final var llamaFilter = new LlamaGuardFilter().config(LlamaGuard38b.create().selfHarm(true));
317312

318313
client.chatCompletion(
319314
prompt,

sample-code/spring-app/pom.xml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,10 @@
118118
<groupId>com.fasterxml.jackson.core</groupId>
119119
<artifactId>jackson-core</artifactId>
120120
</dependency>
121+
<dependency>
122+
<groupId>com.fasterxml.jackson.core</groupId>
123+
<artifactId>jackson-annotations</artifactId>
124+
</dependency>
121125
<!-- scope "runtime" -->
122126
<dependency>
123127
<groupId>ch.qos.logback</groupId>

sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/Application.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
package com.sap.ai.sdk.app;
22

3+
import static com.sap.ai.sdk.core.JacksonConfiguration.getDefaultObjectMapper;
4+
5+
import com.fasterxml.jackson.annotation.JsonAutoDetect.Visibility;
6+
import com.fasterxml.jackson.annotation.PropertyAccessor;
37
import com.fasterxml.jackson.databind.ObjectMapper;
4-
import com.sap.ai.sdk.orchestration.OrchestrationJacksonConfiguration;
58
import javax.annotation.Nonnull;
69
import org.springframework.boot.SpringApplication;
710
import org.springframework.boot.autoconfigure.SpringBootApplication;
@@ -17,16 +20,16 @@
1720
public class Application {
1821

1922
/**
20-
* Temporary workaround to fix the issue with the Orchestration spec.
23+
* Changes Spring Boot's default object mapper to fix serialization issues.
2124
*
22-
* @return a modified object mapper that works for Orchestration.
25+
* @return a modified object mapper
2326
*/
2427
@Bean
2528
@Primary
2629
@SuppressWarnings("unused")
2730
@Nonnull
2831
public ObjectMapper objectMapper() {
29-
return OrchestrationJacksonConfiguration.getOrchestrationObjectMapper();
32+
return getDefaultObjectMapper().setVisibility(PropertyAccessor.FIELD, Visibility.ANY);
3033
}
3134

3235
/**

sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,27 @@ Object outputFiltering(
124124
return response.getContent();
125125
}
126126

127+
@GetMapping("/llamaGuardFilter/{enabled}")
128+
@Nonnull
129+
Object llamaGuardInputFiltering(
130+
@Nullable @RequestParam(value = "format", required = false) final String format,
131+
@PathVariable("enabled") final boolean enabled) {
132+
133+
final OrchestrationChatResponse response;
134+
try {
135+
response = service.llamaGuardInputFilter(enabled);
136+
} catch (OrchestrationClientException e) {
137+
final var msg = "Failed to obtain a response as the content was flagged by input filter.";
138+
log.debug(msg, e);
139+
return ResponseEntity.internalServerError().body(msg);
140+
}
141+
142+
if ("json".equals(format)) {
143+
return response;
144+
}
145+
return response.getContent();
146+
}
147+
127148
@GetMapping("/maskingAnonymization")
128149
@Nonnull
129150
Object maskingAnonymization(

sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OrchestrationService.java

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import com.sap.ai.sdk.orchestration.AzureFilterThreshold;
99
import com.sap.ai.sdk.orchestration.DpiMasking;
1010
import com.sap.ai.sdk.orchestration.Grounding;
11+
import com.sap.ai.sdk.orchestration.LlamaGuardFilter;
1112
import com.sap.ai.sdk.orchestration.Message;
1213
import com.sap.ai.sdk.orchestration.OrchestrationChatResponse;
1314
import com.sap.ai.sdk.orchestration.OrchestrationClient;
@@ -18,6 +19,7 @@
1819
import com.sap.ai.sdk.orchestration.model.DataRepositoryType;
1920
import com.sap.ai.sdk.orchestration.model.DocumentGroundingFilter;
2021
import com.sap.ai.sdk.orchestration.model.GroundingFilterSearchConfiguration;
22+
import com.sap.ai.sdk.orchestration.model.LlamaGuard38b;
2123
import com.sap.ai.sdk.orchestration.model.SearchDocumentKeyValueListPair;
2224
import com.sap.ai.sdk.orchestration.model.SearchSelectOptionEnum;
2325
import com.sap.ai.sdk.orchestration.model.Template;
@@ -150,6 +152,47 @@ public OrchestrationChatResponse outputFiltering(@Nonnull final AzureFilterThres
150152
return client.chatCompletion(prompt, configWithFilter);
151153
}
152154

155+
/**
156+
* Apply the Llama Guard input filter to a fixed sample prompt.
157+
*
158+
* @see <a
159+
* href="https://help.sap.com/docs/sap-ai-core/sap-ai-core-service-guide/input-filtering">SAP
160+
* AI Core: Orchestration - Input Filtering</a>
161+
* @throws OrchestrationClientException if the input filter rejects the prompt
162+
* @param filter whether to enable all Llama Guard filter categories
163+
* @return the assistant response object
164+
*/
165+
@Nonnull
166+
public OrchestrationChatResponse llamaGuardInputFilter(final boolean filter)
167+
throws OrchestrationClientException {
168+
val prompt =
169+
new OrchestrationPrompt("'We shall spill blood tonight', said the operation in-charge.");
170+
171+
// values not set are disabled by default
172+
val config =
173+
LlamaGuard38b.create()
174+
.violentCrimes(filter)
175+
.nonViolentCrimes(filter)
176+
.sexCrimes(filter)
177+
.childExploitation(filter)
178+
.defamation(filter)
179+
.specializedAdvice(filter)
180+
.privacy(filter)
181+
.intellectualProperty(filter)
182+
.indiscriminateWeapons(filter)
183+
.hate(filter)
184+
.selfHarm(filter)
185+
.sexualContent(filter)
186+
.elections(filter)
187+
.codeInterpreterAbuse(filter);
188+
189+
val filterConfig = new LlamaGuardFilter().config(config);
190+
191+
val configWithFilter = this.config.withInputFiltering(filterConfig);
192+
193+
return client.chatCompletion(prompt, configWithFilter);
194+
}
195+
153196
/**
154197
* Let the orchestration service evaluate the feedback on the AI SDK provided by a hypothetical
155198
* user. Anonymize any names given as they are not relevant for judging the sentiment of the

sample-code/spring-app/src/main/resources/static/index.html

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,16 @@ <h2>Orchestration</h2>
273273
</div>
274274
</div>
275275
</li>
276+
<li class="list-group-item">
277+
<div class="info-tooltip">
278+
<button type="submit" formaction="/orchestration/llamaGuardFilter/false" class="link-offset-2-hover link-underline link-underline-opacity-0 link-underline-opacity-75-hover endpoint">
279+
<code>/orchestration/llamaGuardFilter/false</code>
280+
</button>
281+
<div class="tooltip-content">
282+
Send a request to orchestration with all Llama Guard input filter categories disabled.
283+
</div>
284+
</div>
285+
</li>
276286
<li class="list-group-item">
277287
<div class="info-tooltip">
278288
<button type="submit" formaction="/orchestration/maskingAnonymization" class="link-offset-2-hover link-underline link-underline-opacity-0 link-underline-opacity-75-hover endpoint">

sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,4 +228,24 @@ void testOutputFilteringLenient() {
228228
var filterResult = response.getOriginalResponse().getModuleResults().getOutputFiltering();
229229
assertThat(filterResult.getMessage()).containsPattern("0 of \\d+ choices failed");
230230
}
231+
232+
@Test
233+
void testLlamaGuardEnabled() {
234+
assertThatThrownBy(() -> service.llamaGuardInputFilter(true))
235+
.isInstanceOf(OrchestrationClientException.class)
236+
.hasMessageContaining(
237+
"Content filtered due to safety violations. Please modify the prompt and try again.")
238+
.hasMessageContaining("400 Bad Request");
239+
}
240+
241+
@Test
242+
void testLlamaGuardDisabled() {
243+
var response = service.llamaGuardInputFilter(false);
244+
245+
assertThat(response.getChoice().getFinishReason()).isEqualTo("stop");
246+
assertThat(response.getContent()).isNotEmpty();
247+
248+
var filterResult = response.getOriginalResponse().getModuleResults().getInputFiltering();
249+
assertThat(filterResult.getMessage()).contains("passed");
250+
}
231251
}

0 commit comments

Comments
 (0)