Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/release_notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@
Interfaces with only one implementation were reduced.
As a result, the accessors for fields `OrchestrationModuleConfig.inputTranslationConfig` and `OrchestrationModuleConfig.outputTranslationConfig` now handle the implementing class explicitly.
The same applies to helper methods `DpiMasking#createConfig()` and `MaskingProvider#createConfig()`.
- [Orchestration] The method `createConfig()` is removed from `ContentFilter`, `AzureContentFilter` and `LlamaGuardFilter` and is replaced by `createInputFilterConfig()` and `createOutputFilterConfig()`.

### ✨ New Functionality

-
- [Orchestration] Added `AzureContentFilter#promptShield()` available for input filtering.

### 📈 Improvements

Expand Down
4 changes: 2 additions & 2 deletions orchestration/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@
</developers>
<properties>
<project.rootdir>${project.basedir}/../</project.rootdir>
<coverage.complexity>84%</coverage.complexity>
<coverage.complexity>82%</coverage.complexity>
<coverage.line>94%</coverage.line>
<coverage.instruction>95%</coverage.instruction>
<coverage.branch>79%</coverage.branch>
<coverage.branch>77%</coverage.branch>
<coverage.method>93%</coverage.method>
<coverage.class>100%</coverage.class>
</properties>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package com.sap.ai.sdk.orchestration;

import com.sap.ai.sdk.orchestration.model.AzureContentSafety;
import com.sap.ai.sdk.orchestration.model.AzureContentSafetyFilterConfig;
import com.sap.ai.sdk.orchestration.model.AzureContentSafetyInput;
import com.sap.ai.sdk.orchestration.model.AzureContentSafetyInputFilterConfig;
import com.sap.ai.sdk.orchestration.model.AzureContentSafetyOutput;
import com.sap.ai.sdk.orchestration.model.AzureContentSafetyOutputFilterConfig;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import lombok.NoArgsConstructor;
Expand Down Expand Up @@ -46,24 +48,52 @@ public class AzureContentFilter implements ContentFilter {
/* The filter category for violence content. */
@Nullable AzureFilterThreshold violence;

/* A flag to set prompt shield on input filer.*/
@Nullable Boolean promptShield;

/**
* Converts {@link AzureContentFilter} to its serializable counterpart {@link
* AzureContentSafetyInputFilterConfig}.
*
* @return the corresponding {@link AzureContentSafetyInputFilterConfig} object.
* @throws IllegalArgumentException if no policies are set.
*/
@Override
@Nonnull
public AzureContentSafetyInputFilterConfig createInputFilterConfig() {
if (hate == null && selfHarm == null && sexual == null && violence == null) {
throw new IllegalArgumentException("At least one filter category must be set");
}

return AzureContentSafetyInputFilterConfig.create()
.type(AzureContentSafetyInputFilterConfig.TypeEnum.AZURE_CONTENT_SAFETY)
.config(
AzureContentSafetyInput.create()
.hate(hate != null ? hate.getAzureThreshold() : null)
.selfHarm(selfHarm != null ? selfHarm.getAzureThreshold() : null)
.sexual(sexual != null ? sexual.getAzureThreshold() : null)
.violence(violence != null ? violence.getAzureThreshold() : null)
.promptShield(promptShield != null ? promptShield : null));
}

/**
* Converts {@code AzureContentFilter} to its serializable counterpart {@link
* AzureContentSafetyFilterConfig}.
* Converts {@link AzureContentFilter} to its serializable counterpart {@link
* AzureContentSafetyOutput}.
*
* @return the corresponding {@code AzureContentSafetyFilterConfig} object.
* @return the corresponding {@link AzureContentSafetyOutputFilterConfig} object.
* @throws IllegalArgumentException if no policies are set.
*/
@Override
@Nonnull
public AzureContentSafetyFilterConfig createConfig() {
public AzureContentSafetyOutputFilterConfig createOutputFilterConfig() {
if (hate == null && selfHarm == null && sexual == null && violence == null) {
throw new IllegalArgumentException("At least one filter category must be set");
}

return AzureContentSafetyFilterConfig.create()
.type(AzureContentSafetyFilterConfig.TypeEnum.AZURE_CONTENT_SAFETY)
return AzureContentSafetyOutputFilterConfig.create()
.type(AzureContentSafetyOutputFilterConfig.TypeEnum.AZURE_CONTENT_SAFETY)
.config(
AzureContentSafety.create()
AzureContentSafetyOutput.create()
.hate(hate != null ? hate.getAzureThreshold() : null)
.selfHarm(selfHarm != null ? selfHarm.getAzureThreshold() : null)
.sexual(sexual != null ? sexual.getAzureThreshold() : null)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.sap.ai.sdk.orchestration;

import com.sap.ai.sdk.orchestration.model.FilterConfig;
import com.sap.ai.sdk.orchestration.model.InputFilterConfig;
import com.sap.ai.sdk.orchestration.model.OutputFilterConfig;
import javax.annotation.Nonnull;

/**
Expand All @@ -17,11 +18,20 @@
public interface ContentFilter {

/**
* A method that produces the serializable equivalent {@link FilterConfig} object from data
* A method that produces the serializable equivalent {@link InputFilterConfig} object from data
* encapsulated in the {@link ContentFilter} object.
*
* @return the corresponding {@code FilterConfig} object.
* @return the corresponding {@link InputFilterConfig} object.
*/
@Nonnull
FilterConfig createConfig();
InputFilterConfig createInputFilterConfig();

/**
* A method that produces the serializable equivalent {@link OutputFilterConfig} object from data
* encapsulated in the {@link ContentFilter} object.
*
* @return the corresponding {@link OutputFilterConfig} object.
*/
@Nonnull
OutputFilterConfig createOutputFilterConfig();
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.sap.ai.sdk.orchestration.model.LLMModuleResultSynchronous;
import com.sap.ai.sdk.orchestration.model.LLMModuleResult;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;

@NoArgsConstructor(access = AccessLevel.PRIVATE)
final class JacksonMixins {
/** Mixin to enforce a specific subtype to be deserialized always. */
@JsonTypeInfo(use = JsonTypeInfo.Id.NONE)
@JsonDeserialize(as = LLMModuleResultSynchronous.class)
@JsonDeserialize(as = LLMModuleResult.class)
interface LLMModuleResultMixIn {}

@JsonTypeInfo(use = JsonTypeInfo.Id.NONE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,13 @@ public class LlamaGuardFilter implements ContentFilter {

@Nonnull
@Override
public LlamaGuard38bFilterConfig createConfig() {
public LlamaGuard38bFilterConfig createInputFilterConfig() {
return LlamaGuard38bFilterConfig.create().type(LLAMA_GUARD_3_8B).config(config);
}

@Nonnull
@Override
public LlamaGuard38bFilterConfig createOutputFilterConfig() {
return createInputFilterConfig();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import com.sap.ai.sdk.orchestration.model.AssistantChatMessage;
import com.sap.ai.sdk.orchestration.model.ChatMessage;
import com.sap.ai.sdk.orchestration.model.ChatMessageContent;
import com.sap.ai.sdk.orchestration.model.CompletionPostResponseSynchronous;
import com.sap.ai.sdk.orchestration.model.LLMChoiceSynchronous;
import com.sap.ai.sdk.orchestration.model.CompletionPostResponse;
import com.sap.ai.sdk.orchestration.model.LLMChoice;
import com.sap.ai.sdk.orchestration.model.SystemChatMessage;
import com.sap.ai.sdk.orchestration.model.TokenUsage;
import com.sap.ai.sdk.orchestration.model.ToolChatMessage;
Expand All @@ -22,7 +22,7 @@
@Value
@RequiredArgsConstructor(access = PACKAGE)
public class OrchestrationChatResponse {
CompletionPostResponseSynchronous originalResponse;
CompletionPostResponse originalResponse;

/**
* Get the message content from the output.
Expand Down Expand Up @@ -97,10 +97,10 @@ public List<Message> getAllMessages() throws IllegalArgumentException {
/**
* Get the LLM response. Useful for accessing the finish reason or further data like logprobs.
*
* @return The (first, in case of multiple) {@link LLMChoiceSynchronous}.
* @return The (first, in case of multiple) {@link LLMChoice}.
*/
@Nonnull
public LLMChoiceSynchronous getChoice() {
public LLMChoice getChoice() {
// We expect choices to be defined and never empty.
return originalResponse.getOrchestrationResult().getChoices().get(0);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import com.google.common.annotations.Beta;
import com.sap.ai.sdk.core.AiCoreService;
import com.sap.ai.sdk.orchestration.model.CompletionPostRequest;
import com.sap.ai.sdk.orchestration.model.CompletionPostResponseSynchronous;
import com.sap.ai.sdk.orchestration.model.CompletionPostResponse;
import com.sap.ai.sdk.orchestration.model.EmbeddingsPostRequest;
import com.sap.ai.sdk.orchestration.model.EmbeddingsPostResponse;
import com.sap.ai.sdk.orchestration.model.ModuleConfigs;
Expand Down Expand Up @@ -138,9 +138,9 @@ private static void throwOnContentFilter(@Nonnull final OrchestrationChatComplet
* @throws OrchestrationClientException If the request fails.
*/
@Nonnull
public CompletionPostResponseSynchronous executeRequest(
@Nonnull final CompletionPostRequest request) throws OrchestrationClientException {
return executor.execute("/completion", request, CompletionPostResponseSynchronous.class);
public CompletionPostResponse executeRequest(@Nonnull final CompletionPostRequest request)
throws OrchestrationClientException {
return executor.execute("/completion", request, CompletionPostResponse.class);
}

/**
Expand Down Expand Up @@ -182,7 +182,7 @@ public OrchestrationChatResponse executeRequestFromJsonModuleConfig(
requestJson.set("orchestration_config", moduleConfigJson);

return new OrchestrationChatResponse(
executor.execute("/completion", requestJson, CompletionPostResponseSynchronous.class));
executor.execute("/completion", requestJson, CompletionPostResponse.class));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import com.fasterxml.jackson.annotation.JsonCreator;
import com.google.common.annotations.Beta;
import com.sap.ai.sdk.core.common.ClientError;
import com.sap.ai.sdk.orchestration.model.ErrorResponseSynchronous;
import com.sap.ai.sdk.orchestration.model.ErrorResponse;
import javax.annotation.Nonnull;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
Expand All @@ -18,7 +18,7 @@
@Value
@Beta
public class OrchestrationError implements ClientError {
ErrorResponseSynchronous originalResponse;
ErrorResponse originalResponse;

/**
* Gets the error message from the contained original response.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,10 @@ public OrchestrationModuleConfig withInputFiltering(
allFilters.addAll(Arrays.asList(contentFilters));

final var filterConfigs =
allFilters.stream().filter(Objects::nonNull).map(ContentFilter::createConfig).toList();
allFilters.stream()
.filter(Objects::nonNull)
.map(ContentFilter::createInputFilterConfig)
.toList();

final var inputFilter = InputFilteringConfig.create().filters(filterConfigs);

Expand Down Expand Up @@ -194,7 +197,10 @@ public OrchestrationModuleConfig withOutputFiltering(
allFilters.addAll(Arrays.asList(contentFilters));

final var filterConfigs =
allFilters.stream().filter(Objects::nonNull).map(ContentFilter::createConfig).toList();
allFilters.stream()
.filter(Objects::nonNull)
.map(ContentFilter::createOutputFilterConfig)
.toList();

final var outputFilter = OutputFilteringConfig.create().filters(filterConfigs);

Expand Down
Loading