Skip to content

Commit 27e56d5

Browse files
committed
Add OrchestrationConfig class
1 parent 922f469 commit 27e56d5

File tree

6 files changed

+126
-60
lines changed

6 files changed

+126
-60
lines changed
Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
package com.sap.ai.sdk.orchestration;
22

33
import com.sap.ai.sdk.orchestration.client.model.CompletionPostRequest;
4+
import com.sap.ai.sdk.orchestration.client.model.FilteringModuleConfig;
5+
import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig;
46
import com.sap.ai.sdk.orchestration.client.model.ModuleConfigs;
57
import com.sap.ai.sdk.orchestration.client.model.TemplatingModuleConfig;
68
import java.util.ArrayList;
9+
import java.util.List;
710
import javax.annotation.Nonnull;
11+
import javax.annotation.Nullable;
12+
13+
import io.vavr.control.Option;
814
import lombok.AccessLevel;
915
import lombok.NoArgsConstructor;
1016
import lombok.val;
@@ -14,31 +20,31 @@
1420
final class ModuleConfigFactory {
1521
@Nonnull
1622
static CompletionPostRequest toCompletionPostRequestDto(
17-
@Nonnull final OrchestrationPrompt prompt, @Nonnull final ModuleConfigs config) {
18-
// copying is required because we have to merge the prompt into the template config
19-
// also, users may modify the object before request execution
20-
val configCopy = copyModuleConfigs(config);
21-
configCopy.setTemplatingModuleConfig(
22-
toTemplateModuleConfigDto(prompt, config.getTemplatingModuleConfig()));
23+
@Nonnull final OrchestrationPrompt prompt, @Nonnull final OrchestrationModuleConfig config) {
24+
val template = toTemplateModuleConfigDto(prompt, config.getTemplate());
25+
// note that the config is immutable and implicitly copied here
26+
// copying is required here, to not alter the original config object, which might be reused for subsequent requests
27+
val configCopy = config.withTemplate(template);
2328

2429
return CompletionPostRequest.create()
2530
.orchestrationConfig(
2631
com.sap.ai.sdk.orchestration.client.model.OrchestrationConfig.create()
27-
.moduleConfigurations(configCopy))
32+
.moduleConfigurations(toModuleConfigsDto(configCopy)))
2833
.inputParams(prompt.getTemplateParameters());
2934
}
3035

3136
@Nonnull
3237
static TemplatingModuleConfig toTemplateModuleConfigDto(
33-
@Nonnull final OrchestrationPrompt prompt, @Nonnull final TemplatingModuleConfig template) {
38+
@Nonnull final OrchestrationPrompt prompt, @Nullable final TemplatingModuleConfig template) {
3439
/*
3540
* Currently, we have to merge the prompt into the template configuration.
36-
* This works around the limitation that the template config isn't optional.
41+
* This works around the limitation that the template config is required.
3742
* This comes at the risk that the prompt unintentionally contains the templating pattern "{{? .. }}".
3843
* In this case, the request will fail, since the templating module will try to resolve the parameter.
3944
* To be fixed with https://github.tools.sap/AI/llm-orchestration/issues/662
4045
*/
41-
val messagesWithPrompt = new ArrayList<>(template.getTemplate());
46+
val messages = Option.of(template).map(TemplatingModuleConfig::getTemplate).getOrElse(List::of);
47+
val messagesWithPrompt = new ArrayList<>(messages);
4248
messagesWithPrompt.addAll(prompt.getMessages());
4349
if (messagesWithPrompt.isEmpty()) {
4450
throw new IllegalStateException(
@@ -47,11 +53,27 @@ static TemplatingModuleConfig toTemplateModuleConfigDto(
4753
return TemplatingModuleConfig.create().template(messagesWithPrompt);
4854
}
4955

50-
static ModuleConfigs copyModuleConfigs(@Nonnull final ModuleConfigs configs) {
51-
return ModuleConfigs.create()
52-
.llmModuleConfig(configs.getLlmModuleConfig())
53-
.templatingModuleConfig(configs.getTemplatingModuleConfig())
54-
.maskingModuleConfig(configs.getMaskingModuleConfig())
55-
.filteringModuleConfig(configs.getFilteringModuleConfig());
56+
@Nonnull
57+
static ModuleConfigs toModuleConfigsDto(@Nonnull final OrchestrationModuleConfig config) {
58+
val llmConfig =
59+
Option.of(config.getLlmConfig()).getOrElseThrow(() -> new IllegalStateException("LLM config is required."));
60+
61+
//noinspection DataFlowIssue the template is always non-null here
62+
val moduleConfig =
63+
ModuleConfigs.create()
64+
.llmModuleConfig(llmConfig)
65+
.templatingModuleConfig(config.getTemplate());
66+
67+
val maybeInputFilter = Option.of(config.getInputContentFilter());
68+
val maybeOutputFilter = Option.of(config.getOutputContentFilter());
69+
70+
if (maybeInputFilter.isDefined() || maybeOutputFilter.isDefined()) {
71+
val filter = FilteringModuleConfig.create();
72+
maybeInputFilter.forEach(filter::input);
73+
maybeOutputFilter.forEach(filter::output);
74+
}
75+
Option.of(config.getMaskingConfig()).forEach(moduleConfig::maskingModuleConfig);
76+
77+
return moduleConfig;
5678
}
5779
}

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -76,23 +76,10 @@ public OrchestrationClient(@Nonnull final AiCoreDeployment deployment) {
7676
*/
7777
@Nonnull
7878
public CompletionPostResponse chatCompletion(
79-
@Nonnull final OrchestrationPrompt prompt, @Nonnull final ModuleConfigs config)
79+
@Nonnull final OrchestrationPrompt prompt, @Nonnull final OrchestrationModuleConfig config)
8080
throws OrchestrationClientException {
8181

82-
val request = toCompletionPostRequestDto(prompt, config);
83-
return executeRequest(request);
84-
}
85-
86-
/**
87-
* Generate a completion for the given prompt.
88-
*
89-
* @param request The request to send to orchestration.
90-
* @return the completion output
91-
* @throws OrchestrationClientException if the request fails
92-
*/
93-
@Nonnull
94-
public CompletionPostResponse chatCompletion(@Nonnull final CompletionPostRequest request)
95-
throws OrchestrationClientException {
82+
val request = toCompletionPostRequest(prompt, config);
9683
return executeRequest(request);
9784
}
9885

@@ -140,8 +127,8 @@ public CompletionPostResponse executeRequest(@Nonnull final CompletionPostReques
140127
* @return The low-level request DTO to send to orchestration.
141128
*/
142129
@Nonnull
143-
public static CompletionPostRequest toCompletionPostRequestDto(
144-
@Nonnull final OrchestrationPrompt prompt, @Nonnull final ModuleConfigs config) {
130+
public static CompletionPostRequest toCompletionPostRequest(
131+
@Nonnull final OrchestrationPrompt prompt, @Nonnull final OrchestrationModuleConfig config) {
145132
return ModuleConfigFactory.toCompletionPostRequestDto(prompt, config);
146133
}
147134

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import javax.annotation.Nullable;
4+
5+
import com.sap.ai.sdk.orchestration.client.model.InputFilteringConfig;
6+
import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig;
7+
import com.sap.ai.sdk.orchestration.client.model.MaskingModuleConfig;
8+
import com.sap.ai.sdk.orchestration.client.model.OutputFilteringConfig;
9+
import com.sap.ai.sdk.orchestration.client.model.TemplatingModuleConfig;
10+
import lombok.AccessLevel;
11+
import lombok.AllArgsConstructor;
12+
import lombok.NoArgsConstructor;
13+
import lombok.Value;
14+
import lombok.With;
15+
16+
/**
17+
* Represents the configuration for the orchestration service. Allows for configuring the different
18+
* modules of the orchestration service via a fluent API.
19+
*
20+
* <p>The orchestration pipeline combines different modules into a single execution flow where the
21+
* output of one module serves as input for the next. The pipeline consists of the following
22+
* modules:
23+
*
24+
* <ul>
25+
* <li>LLM Config (Mandatory)
26+
* <li>Templating (Optional)
27+
* <li>Data Masking (Optional)
28+
* <li>Input Content Filtering (Optional)
29+
* <li>Output Content Filtering (Optional)
30+
* </ul>
31+
*/
32+
@Value
33+
@With
34+
@AllArgsConstructor(access = AccessLevel.PRIVATE)
35+
@NoArgsConstructor(force = true)
36+
public class OrchestrationModuleConfig {
37+
/**
38+
* The configured language model settings. This configuration is required when executing requests.
39+
*/
40+
@Nullable
41+
LLMModuleConfig llmConfig;
42+
43+
/**
44+
* A template to be populated with input parameters. Upon request execution, this template will be
45+
* enhanced with any messages and parameter values from {@link OrchestrationPrompt}.
46+
*/
47+
@Nullable
48+
TemplatingModuleConfig template;
49+
50+
/**
51+
* A masking configuration to pseudonymous or anonymize sensitive data in the input.
52+
*/
53+
@Nullable
54+
MaskingModuleConfig maskingConfig;
55+
56+
/**
57+
* A content filter to filter the prompt.
58+
*/
59+
@Nullable
60+
InputFilteringConfig inputContentFilter;
61+
62+
/**
63+
* A content filter to filter the language model response.
64+
*/
65+
@Nullable
66+
OutputFilteringConfig outputContentFilter;
67+
}

orchestration/src/test/java/com/sap/ai/sdk/orchestration/ModuleConfigFactoryTest.java

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,18 @@
1515

1616
class ModuleConfigFactoryTest {
1717

18+
@Test
19+
void testThrowsOnMissingLlmConfig() {
20+
assertThatThrownBy(() -> ModuleConfigFactory.toModuleConfigsDto(new OrchestrationModuleConfig()))
21+
.isInstanceOf(IllegalStateException.class)
22+
.hasMessageContaining("A prompt is required");
23+
}
24+
1825
@Test
1926
void testThrowsOnMissingMessages() {
2027
var prompt = new OrchestrationPrompt(Map.of());
21-
var templateConfig = TemplatingModuleConfig.create().template();
2228

23-
assertThatThrownBy(() -> ModuleConfigFactory.toTemplateModuleConfigDto(prompt, templateConfig))
29+
assertThatThrownBy(() -> ModuleConfigFactory.toTemplateModuleConfigDto(prompt, null))
2430
.isInstanceOf(IllegalStateException.class)
2531
.hasMessageContaining("A prompt is required");
2632
}
@@ -59,17 +65,4 @@ void testMergingTemplateConfig() {
5965

6066
assertThat(actual).isEqualTo(expected);
6167
}
62-
63-
@Test
64-
void testCopy() {
65-
var moduleConfigs =
66-
ModuleConfigs.create()
67-
.llmModuleConfig(mock(LLMModuleConfig.class))
68-
.templatingModuleConfig(mock(TemplatingModuleConfig.class))
69-
.filteringModuleConfig(mock(FilteringModuleConfig.class))
70-
.maskingModuleConfig(mock(MaskingModuleConfig.class));
71-
assertThat(ModuleConfigFactory.copyModuleConfigs(moduleConfigs))
72-
.isEqualTo(moduleConfigs)
73-
.isNotSameAs(moduleConfigs);
74-
}
7568
}

orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
@WireMockTest
5959
class OrchestrationUnitTest {
6060
private OrchestrationClient client;
61-
private ModuleConfigs config;
61+
private OrchestrationModuleConfig config;
6262
private final Function<String, InputStream> fileLoader =
6363
filename -> Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream(filename));
6464

@@ -100,10 +100,7 @@ void setup(WireMockRuntimeInfo server) {
100100
.forDeploymentByScenario("orchestration")
101101
.withResourceGroup("my-resource-group");
102102
client = new OrchestrationClient(deployment);
103-
config =
104-
ModuleConfigs.create()
105-
.llmModuleConfig(LLM_CONFIG)
106-
.templatingModuleConfig(TemplatingModuleConfig.create().template());
103+
config = new OrchestrationModuleConfig().withLlmConfig(LLM_CONFIG);
107104
}
108105

109106
@Test
@@ -306,7 +303,7 @@ void messagesHistory() throws IOException {
306303
ChatMessage.create().role("user").content("What is the typical food there?");
307304

308305
final var prompt = new OrchestrationPrompt(message);
309-
final var request = OrchestrationClient.toCompletionPostRequestDto(prompt, config);
306+
final var request = OrchestrationClient.toCompletionPostRequest(prompt, config);
310307
request.setMessagesHistory(messagesHistory);
311308

312309
final var result = client.chatCompletion(request);
@@ -382,7 +379,7 @@ void maskingAnonymization() throws IOException {
382379
void testGenericErrorHandling() {
383380
stubFor(post(anyUrl()).willReturn(serverError()));
384381

385-
assertThatThrownBy(() -> client.chatCompletion(mock(CompletionPostRequest.class)))
382+
assertThatThrownBy(() -> client.executeRequest(mock(CompletionPostRequest.class)))
386383
.isInstanceOf(OrchestrationClientException.class)
387384
.hasMessageContaining("500 Server Error");
388385
}
@@ -396,7 +393,7 @@ void testOrchestrationErrorParsing() {
396393
.withHeader("Content-Type", "application/json")
397394
.withBodyFile("errorResponse.json")));
398395

399-
assertThatThrownBy(() -> client.chatCompletion(mock(CompletionPostRequest.class)))
396+
assertThatThrownBy(() -> client.executeRequest(mock(CompletionPostRequest.class)))
400397
.isInstanceOf(OrchestrationClientException.class)
401398
.hasMessageContaining("400 Bad Request")
402399
.hasMessageContaining("'orchestration_config' is a required property");

sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ public CompletionPostResponse filter(@Nonnull @PathVariable("threshold") final S
170170
final var filter =
171171
FILTERING_CONFIG.apply(AzureThreshold.fromValue(Integer.parseInt(threshold)));
172172

173-
final var request = OrchestrationClient.toCompletionPostRequestDto(prompt, config);
173+
final var request = OrchestrationClient.toCompletionPostRequest(prompt, config);
174174
request.getOrchestrationConfig().getModuleConfigurations().setFilteringModuleConfig(filter);
175175

176176
return CLIENT.chatCompletion(request);
@@ -193,7 +193,7 @@ public CompletionPostResponse messagesHistory() {
193193
ChatMessage.create().role("user").content("What is the typical food there?");
194194
final var prompt = new OrchestrationPrompt(message);
195195

196-
final var request = OrchestrationClient.toCompletionPostRequestDto(prompt, config);
196+
final var request = OrchestrationClient.toCompletionPostRequest(prompt, config);
197197
request.setMessagesHistory(messagesHistory);
198198

199199
return CLIENT.chatCompletion(request);
@@ -217,7 +217,7 @@ public CompletionPostResponse maskingAnonymization() {
217217
.method(MaskingProviderConfig.MethodEnum.ANONYMIZATION)
218218
.entities(ALL_DPI_ENTITIES));
219219

220-
final var request = OrchestrationClient.toCompletionPostRequestDto(prompt, config);
220+
final var request = OrchestrationClient.toCompletionPostRequest(prompt, config);
221221
request
222222
.getOrchestrationConfig()
223223
.getModuleConfigurations()
@@ -245,7 +245,7 @@ public CompletionPostResponse maskingPseudonymization() {
245245
.method(MethodEnum.PSEUDONYMIZATION)
246246
.entities(ALL_DPI_ENTITIES));
247247

248-
final var request = OrchestrationClient.toCompletionPostRequestDto(prompt, config);
248+
final var request = OrchestrationClient.toCompletionPostRequest(prompt, config);
249249
request
250250
.getOrchestrationConfig()
251251
.getModuleConfigurations()

0 commit comments

Comments
 (0)