Skip to content

Commit f23bf96

Browse files
committed
feat: Orchestration Prompt
1 parent eee6c5d commit f23bf96

File tree

9 files changed

+394
-170
lines changed

9 files changed

+394
-170
lines changed
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import com.sap.ai.sdk.orchestration.client.model.CompletionPostRequest;
4+
import com.sap.ai.sdk.orchestration.client.model.ModuleConfigs;
5+
import com.sap.ai.sdk.orchestration.client.model.TemplatingModuleConfig;
6+
import java.util.ArrayList;
7+
import javax.annotation.Nonnull;
8+
import lombok.AccessLevel;
9+
import lombok.NoArgsConstructor;
10+
import lombok.val;
11+
12+
/** Factory to create all DTOs from an orchestration configuration. */
13+
@NoArgsConstructor(access = AccessLevel.NONE)
14+
final class ModuleConfigFactory {
15+
@Nonnull
16+
static CompletionPostRequest toCompletionPostRequestDto(
17+
@Nonnull final OrchestrationPrompt prompt, @Nonnull final ModuleConfigs config) {
18+
// copying is required because we have to merge the prompt into the template config
19+
// also, users may modify the object before request execution
20+
val configCopy = copyModuleConfigs(config);
21+
configCopy.setTemplatingModuleConfig(
22+
toTemplateModuleConfigDto(prompt, config.getTemplatingModuleConfig()));
23+
24+
return CompletionPostRequest.create()
25+
.orchestrationConfig(
26+
com.sap.ai.sdk.orchestration.client.model.OrchestrationConfig.create()
27+
.moduleConfigurations(configCopy))
28+
.inputParams(prompt.getTemplateParameters());
29+
}
30+
31+
@Nonnull
32+
static TemplatingModuleConfig toTemplateModuleConfigDto(
33+
@Nonnull final OrchestrationPrompt prompt, @Nonnull final TemplatingModuleConfig template) {
34+
/*
35+
* Currently, we have to merge the prompt into the template configuration.
36+
* This works around the limitation that the template config isn't optional.
37+
* This comes at the risk that the prompt unintentionally contains the templating pattern "{{? .. }}".
38+
* In this case, the request will fail, since the templating module will try to resolve the parameter.
39+
* To be fixed with https://github.tools.sap/AI/llm-orchestration/issues/662
40+
*/
41+
val messagesWithPrompt = new ArrayList<>(template.getTemplate());
42+
messagesWithPrompt.addAll(prompt.getMessages());
43+
if (messagesWithPrompt.isEmpty()) {
44+
throw new IllegalStateException(
45+
"A prompt is required. Pass at least one message or configure a template with messages or a template reference.");
46+
}
47+
return TemplatingModuleConfig.create().template(messagesWithPrompt);
48+
}
49+
50+
static ModuleConfigs copyModuleConfigs(@Nonnull final ModuleConfigs configs) {
51+
return ModuleConfigs.create()
52+
.llmModuleConfig(configs.getLlmModuleConfig())
53+
.templatingModuleConfig(configs.getTemplatingModuleConfig())
54+
.maskingModuleConfig(configs.getMaskingModuleConfig())
55+
.filteringModuleConfig(configs.getFilteringModuleConfig());
56+
}
57+
}

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import com.sap.ai.sdk.core.AiCoreService;
1111
import com.sap.ai.sdk.orchestration.client.model.CompletionPostRequest;
1212
import com.sap.ai.sdk.orchestration.client.model.CompletionPostResponse;
13+
import com.sap.ai.sdk.orchestration.client.model.ModuleConfigs;
14+
import com.sap.ai.sdk.orchestration.client.model.OrchestrationConfig;
1315
import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor;
1416
import com.sap.cloud.sdk.cloudplatform.connectivity.exception.DestinationAccessException;
1517
import com.sap.cloud.sdk.cloudplatform.connectivity.exception.DestinationNotFoundException;
@@ -64,6 +66,23 @@ public OrchestrationClient(@Nonnull final AiCoreDeployment deployment) {
6466
this.deployment = () -> deployment;
6567
}
6668

69+
/**
70+
* Generate a completion for the given prompt.
71+
*
72+
* @param prompt The {@link OrchestrationPrompt} to send to orchestration.
73+
* @param config The {@link ModuleConfigs} configuration to use for the completion.
74+
* @return the completion output
75+
* @throws OrchestrationClientException if the request fails.
76+
*/
77+
@Nonnull
78+
public CompletionPostResponse chatCompletion(
79+
@Nonnull final OrchestrationPrompt prompt, @Nonnull final ModuleConfigs config)
80+
throws OrchestrationClientException {
81+
82+
val request = toCompletionPostRequestDto(prompt, config);
83+
return executeRequest(request);
84+
}
85+
6786
/**
6887
* Generate a completion for the given prompt.
6988
*
@@ -112,6 +131,20 @@ public CompletionPostResponse executeRequest(@Nonnull final CompletionPostReques
112131
return executeRequest(postRequest);
113132
}
114133

134+
/**
135+
* Convert the given prompt and config into a low-level request DTO. The DTO allows for further
136+
* customization before sending the request.
137+
*
138+
* @param prompt The {@link OrchestrationPrompt} to generate a completion for.
139+
* @param config The {@link OrchestrationConfig } configuration to use for the completion.
140+
* @return The low-level request DTO to send to orchestration.
141+
*/
142+
@Nonnull
143+
public static CompletionPostRequest toCompletionPostRequestDto(
144+
@Nonnull final OrchestrationPrompt prompt, @Nonnull final ModuleConfigs config) {
145+
return ModuleConfigFactory.toCompletionPostRequestDto(prompt, config);
146+
}
147+
115148
@SuppressWarnings("UnstableApiUsage")
116149
@Nonnull
117150
CompletionPostResponse executeRequest(@Nonnull final BasicClassicHttpRequest request) {
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import com.sap.ai.sdk.orchestration.client.model.ChatMessage;
4+
import com.sap.ai.sdk.orchestration.client.model.OrchestrationConfig;
5+
import java.util.ArrayList;
6+
import java.util.Arrays;
7+
import java.util.List;
8+
import java.util.Map;
9+
import javax.annotation.Nonnull;
10+
import lombok.AccessLevel;
11+
import lombok.AllArgsConstructor;
12+
import lombok.Getter;
13+
import lombok.Value;
14+
import lombok.val;
15+
16+
/**
17+
* Represents a request that can be sent to the orchestration service, containing messages and
18+
* configuration for the orchestration modules. Prompts may be reused across multiple requests.
19+
*
20+
* @see OrchestrationClient
21+
* @see OrchestrationConfig
22+
*/
23+
@Value
24+
@Getter(AccessLevel.PACKAGE)
25+
@AllArgsConstructor
26+
public class OrchestrationPrompt {
27+
@Nonnull List<ChatMessage> messages;
28+
@Nonnull Map<String, String> templateParameters;
29+
30+
/**
31+
* Initialize a prompt with the given user message.
32+
*
33+
* @param message A user message.
34+
*/
35+
public OrchestrationPrompt(@Nonnull final String message) {
36+
this(List.of(ChatMessage.create().role("user").content(message)), Map.of());
37+
}
38+
39+
/**
40+
* Initialize a prompt from the given messages.
41+
*
42+
* @param message The first message.
43+
* @param messages Optionally, more messages.
44+
*/
45+
public OrchestrationPrompt(
46+
@Nonnull final ChatMessage message, @Nonnull final ChatMessage... messages) {
47+
val allMessages = new ArrayList<ChatMessage>();
48+
allMessages.add(message);
49+
allMessages.addAll(Arrays.asList(messages));
50+
this.messages = allMessages;
51+
this.templateParameters = Map.of();
52+
}
53+
54+
/**
55+
* Initialize a prompt based on template variables.
56+
*
57+
* @param inputParams The input parameters as entries of template variables and their contents.
58+
*/
59+
public OrchestrationPrompt(@Nonnull final Map<String, String> inputParams) {
60+
this(List.of(), inputParams);
61+
}
62+
}
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import static org.assertj.core.api.Assertions.assertThat;
4+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
5+
import static org.mockito.Mockito.mock;
6+
7+
import com.sap.ai.sdk.orchestration.client.model.ChatMessage;
8+
import com.sap.ai.sdk.orchestration.client.model.FilteringModuleConfig;
9+
import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig;
10+
import com.sap.ai.sdk.orchestration.client.model.MaskingModuleConfig;
11+
import com.sap.ai.sdk.orchestration.client.model.ModuleConfigs;
12+
import com.sap.ai.sdk.orchestration.client.model.TemplatingModuleConfig;
13+
import java.util.Map;
14+
import org.junit.jupiter.api.Test;
15+
16+
class ModuleConfigFactoryTest {
17+
18+
@Test
19+
void testThrowsOnMissingMessages() {
20+
var prompt = new OrchestrationPrompt(Map.of());
21+
var templateConfig = TemplatingModuleConfig.create().template();
22+
23+
assertThatThrownBy(() -> ModuleConfigFactory.toTemplateModuleConfigDto(prompt, templateConfig))
24+
.isInstanceOf(IllegalStateException.class)
25+
.hasMessageContaining("A prompt is required");
26+
}
27+
28+
@Test
29+
void testEmptyTemplateConfig() {
30+
var systemMessage = ChatMessage.create().role("system").content("foo");
31+
var userMessage = ChatMessage.create().role("user").content("Hello");
32+
33+
var expected = TemplatingModuleConfig.create().template(systemMessage, userMessage);
34+
35+
var prompt = new OrchestrationPrompt(systemMessage, userMessage);
36+
var actual =
37+
ModuleConfigFactory.toTemplateModuleConfigDto(
38+
prompt, TemplatingModuleConfig.create().template());
39+
40+
assertThat(actual).isEqualTo(expected);
41+
assertThat(actual.getTemplate())
42+
.describedAs(
43+
"The template should be copied to not modify an existing config which might be reused.")
44+
.isNotSameAs(expected.getTemplate());
45+
}
46+
47+
@Test
48+
void testMergingTemplateConfig() {
49+
var systemMessage = ChatMessage.create().role("system").content("foo");
50+
var userMessage = ChatMessage.create().role("user").content("Hello ");
51+
var userMessage2 = ChatMessage.create().role("user").content("World");
52+
53+
var expected =
54+
TemplatingModuleConfig.create().template(systemMessage, userMessage, userMessage2);
55+
56+
var prompt = new OrchestrationPrompt(userMessage2);
57+
var templateConfig = TemplatingModuleConfig.create().template(systemMessage, userMessage);
58+
var actual = ModuleConfigFactory.toTemplateModuleConfigDto(prompt, templateConfig);
59+
60+
assertThat(actual).isEqualTo(expected);
61+
}
62+
63+
@Test
64+
void testCopy() {
65+
var moduleConfigs =
66+
ModuleConfigs.create()
67+
.llmModuleConfig(mock(LLMModuleConfig.class))
68+
.templatingModuleConfig(mock(TemplatingModuleConfig.class))
69+
.filteringModuleConfig(mock(FilteringModuleConfig.class))
70+
.maskingModuleConfig(mock(MaskingModuleConfig.class));
71+
assertThat(ModuleConfigFactory.copyModuleConfigs(moduleConfigs))
72+
.isEqualTo(moduleConfigs)
73+
.isNotSameAs(moduleConfigs);
74+
}
75+
}

orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java

Lines changed: 35 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
@WireMockTest
5959
class OrchestrationUnitTest {
6060
private OrchestrationClient client;
61+
private ModuleConfigs config;
6162
private final Function<String, InputStream> fileLoader =
6263
filename -> Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream(filename));
6364

@@ -71,16 +72,6 @@ class OrchestrationUnitTest {
7172
"frequency_penalty", 0,
7273
"presence_penalty", 0));
7374

74-
private static final Function<TemplatingModuleConfig, CompletionPostRequest> TEMPLATE_CONFIG =
75-
(TemplatingModuleConfig templatingModuleConfig) ->
76-
CompletionPostRequest.create()
77-
.orchestrationConfig(
78-
OrchestrationConfig.create()
79-
.moduleConfigurations(
80-
ModuleConfigs.create()
81-
.llmModuleConfig(LLM_CONFIG)
82-
.templatingModuleConfig(templatingModuleConfig)));
83-
8475
@BeforeEach
8576
void setup(WireMockRuntimeInfo server) {
8677
stubFor(
@@ -109,27 +100,43 @@ void setup(WireMockRuntimeInfo server) {
109100
.forDeploymentByScenario("orchestration")
110101
.withResourceGroup("my-resource-group");
111102
client = new OrchestrationClient(deployment);
103+
config =
104+
ModuleConfigs.create()
105+
.llmModuleConfig(LLM_CONFIG)
106+
.templatingModuleConfig(TemplatingModuleConfig.create().template());
107+
}
108+
109+
@Test
110+
void testCompletion() {
111+
stubFor(
112+
post(urlPathEqualTo("/v2/inference/deployments/abcdef0123456789/completion"))
113+
.willReturn(
114+
aResponse()
115+
.withBodyFile("templatingResponse.json")
116+
.withHeader("Content-Type", "application/json")));
117+
final var result =
118+
client.chatCompletion(new OrchestrationPrompt("What is the capital of France?"), config);
119+
120+
assertThat(result).isNotNull();
121+
assertThat(result.getOrchestrationResult().getChoices().get(0).getMessage().getContent())
122+
.isNotEmpty();
112123
}
113124

114125
@Test
115-
void templating() throws IOException {
126+
void testTemplating() throws IOException {
116127
stubFor(
117128
post(urlPathEqualTo("/v2/inference/deployments/abcdef0123456789/completion"))
118129
.willReturn(
119130
aResponse()
120131
.withBodyFile("templatingResponse.json")
121132
.withHeader("Content-Type", "application/json")));
122133

123-
final var template = ChatMessage.create().role("user").content("{{?input}}");
134+
final var template = List.of(ChatMessage.create().role("user").content("{{?input}}"));
124135
final var inputParams =
125136
Map.of("input", "Reply with 'Orchestration Service is working!' in German");
126137

127-
final var config =
128-
TEMPLATE_CONFIG
129-
.apply(TemplatingModuleConfig.create().template(template))
130-
.inputParams(inputParams);
131-
132-
final var result = client.chatCompletion(config);
138+
final var result =
139+
client.chatCompletion(new OrchestrationPrompt(template, inputParams), config);
133140

134141
assertThat(result.getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91");
135142
assertThat(result.getModuleResults().getTemplating().get(0).getContent())
@@ -176,7 +183,7 @@ void templating() throws IOException {
176183
}
177184

178185
@Test
179-
void templatingBadRequest() {
186+
void testBadRequest() {
180187
stubFor(
181188
post(urlPathEqualTo("/v2/inference/deployments/abcdef0123456789/completion"))
182189
.willReturn(
@@ -191,17 +198,10 @@ void templatingBadRequest() {
191198
}
192199
""",
193200
SC_BAD_REQUEST)));
201+
var message = ChatMessage.create().role("user").content("What is the capital of {{?input}}?");
202+
final var prompt = new OrchestrationPrompt(List.of(message), Map.of());
194203

195-
final var template = ChatMessage.create().role("user").content("{{?input}}");
196-
// input params are omitted on purpose to trigger an error
197-
Map<String, String> inputParams = Map.of();
198-
199-
final var config =
200-
TEMPLATE_CONFIG
201-
.apply(TemplatingModuleConfig.create().template(template))
202-
.inputParams(inputParams);
203-
204-
assertThatThrownBy(() -> client.chatCompletion(config))
204+
assertThatThrownBy(() -> client.chatCompletion(prompt, config))
205205
.isInstanceOf(OrchestrationClientException.class)
206206
.hasMessage(
207207
"Request to orchestration service failed with status 400 Bad Request and error message: 'Missing required parameters: ['input']'");
@@ -305,21 +305,20 @@ void messagesHistory() throws IOException {
305305
final var message =
306306
ChatMessage.create().role("user").content("What is the typical food there?");
307307

308-
final var config =
309-
TEMPLATE_CONFIG
310-
.apply(TemplatingModuleConfig.create().template(message))
311-
.messagesHistory(messagesHistory);
308+
final var prompt = new OrchestrationPrompt(message);
309+
final var request = OrchestrationClient.toCompletionPostRequestDto(prompt, config);
310+
request.setMessagesHistory(messagesHistory);
312311

313-
final var result = client.chatCompletion(config);
312+
final var result = client.chatCompletion(request);
314313

315314
assertThat(result.getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91");
316315

317316
// verify that the history is sent correctly
318317
try (var requestInputStream = fileLoader.apply("messagesHistoryRequest.json")) {
319-
final String request = new String(requestInputStream.readAllBytes());
318+
final String requestBody = new String(requestInputStream.readAllBytes());
320319
verify(
321320
postRequestedFor(urlPathEqualTo("/v2/inference/deployments/abcdef0123456789/completion"))
322-
.withRequestBody(equalToJson(request)));
321+
.withRequestBody(equalToJson(requestBody)));
323322
}
324323
}
325324

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@
7373
<enforcer.skipBanGeneratedModulesReference>false</enforcer.skipBanGeneratedModulesReference>
7474
<!-- Test coverage -->
7575
<coverage.instruction>74%</coverage.instruction>
76-
<coverage.branch>68%</coverage.branch>
76+
<coverage.branch>60%</coverage.branch>
7777
<coverage.complexity>67%</coverage.complexity>
7878
<coverage.line>75%</coverage.line>
7979
<coverage.method>80%</coverage.method>

0 commit comments

Comments
 (0)