Skip to content

Commit cdac75a

Browse files
committed
Introduce role based message classes
- Better separate convenience and generated API
1 parent 5b69e3a commit cdac75a

File tree

10 files changed

+154
-57
lines changed

10 files changed

+154
-57
lines changed
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import com.sap.ai.sdk.orchestration.client.model.ChatMessage;
4+
import javax.annotation.Nonnull;
5+
import lombok.Value;
6+
import lombok.experimental.Accessors;
7+
8+
/** Represents a chat message as 'assistant' to the orchestration service. */
9+
@Value
10+
@Accessors(fluent = true)
11+
public class AssistantMessage implements Message {
12+
13+
/** The role of the assistant. */
14+
@Nonnull public static final String ROLE = "assistant";
15+
16+
@Nonnull String content;
17+
18+
/**
19+
* Converts the message to a serializable ChatMessage object.
20+
*
21+
* @return the corresponding {@code ChatMessage} object.
22+
*/
23+
@Nonnull
24+
public ChatMessage toChatMessage() {
25+
return ChatMessage.create().role(ROLE).content(content);
26+
}
27+
}

orchestration/src/main/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformer.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ static CompletionPostRequest toCompletionPostRequest(
3131
.orchestrationConfig(
3232
OrchestrationConfig.create().moduleConfigurations(toModuleConfigs(configCopy)))
3333
.inputParams(prompt.getTemplateParameters())
34-
.messagesHistory(prompt.getMessagesHistory());
34+
.messagesHistory(prompt.getMessagesHistory().stream().map(Message::toChatMessage).toList());
3535
}
3636

3737
@Nonnull
@@ -46,7 +46,7 @@ static TemplatingModuleConfig toTemplateModuleConfig(
4646
*/
4747
val messages = template instanceof Template t ? t.getTemplate() : List.<ChatMessage>of();
4848
val messagesWithPrompt = new ArrayList<>(messages);
49-
messagesWithPrompt.addAll(prompt.getMessages());
49+
messagesWithPrompt.addAll(prompt.getMessages().stream().map(Message::toChatMessage).toList());
5050
if (messagesWithPrompt.isEmpty()) {
5151
throw new IllegalStateException(
5252
"A prompt is required. Pass at least one message or configure a template with messages or a template reference.");
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import com.sap.ai.sdk.orchestration.client.model.ChatMessage;
4+
import javax.annotation.Nonnull;
5+
6+
/** Interface representing convenience wrappers of chat message to the orchestration service. */
7+
public sealed interface Message permits UserMessage, AssistantMessage, SystemMessage {
8+
9+
/**
10+
* Converts the message to a serializable ChatMessage object.
11+
*
12+
* @return the corresponding {@code ChatMessage} object.
13+
*/
14+
@Nonnull
15+
ChatMessage toChatMessage();
16+
}

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,10 @@ public OrchestrationChatResponse executeRequestFromJsonModuleConfig(
169169
}
170170

171171
final ObjectNode requestJson = JACKSON.createObjectNode();
172-
requestJson.set("messages_history", JACKSON.valueToTree(prompt.getMessagesHistory()));
172+
requestJson.set(
173+
"messages_history",
174+
JACKSON.valueToTree(
175+
prompt.getMessagesHistory().stream().map(Message::toChatMessage).toList()));
173176
requestJson.set("input_params", JACKSON.valueToTree(prompt.getTemplateParameters()));
174177

175178
final JsonNode moduleConfigJson;

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationPrompt.java

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
package com.sap.ai.sdk.orchestration;
22

3-
import com.sap.ai.sdk.orchestration.client.model.ChatMessage;
43
import com.sap.ai.sdk.orchestration.client.model.OrchestrationConfig;
54
import java.util.ArrayList;
65
import java.util.Arrays;
@@ -22,17 +21,17 @@
2221
@Value
2322
@Getter(AccessLevel.PACKAGE)
2423
public class OrchestrationPrompt {
25-
@Nonnull List<ChatMessage> messages = new ArrayList<>();
24+
@Nonnull List<Message> messages = new ArrayList<>();
2625
@Nonnull Map<String, String> templateParameters = new TreeMap<>();
27-
@Nonnull List<ChatMessage> messagesHistory = new ArrayList<>();
26+
@Nonnull List<Message> messagesHistory = new ArrayList<>();
2827

2928
/**
3029
* Initialize a prompt with the given user message.
3130
*
3231
* @param message A user message.
3332
*/
3433
public OrchestrationPrompt(@Nonnull final String message) {
35-
messages.add(ChatMessage.create().role("user").content(message));
34+
messages.add(new UserMessage(message));
3635
}
3736

3837
/**
@@ -41,8 +40,7 @@ public OrchestrationPrompt(@Nonnull final String message) {
4140
* @param message The first message.
4241
* @param messages Optionally, more messages.
4342
*/
44-
public OrchestrationPrompt(
45-
@Nonnull final ChatMessage message, @Nonnull final ChatMessage... messages) {
43+
public OrchestrationPrompt(@Nonnull final Message message, @Nonnull final Message... messages) {
4644
this.messages.add(message);
4745
this.messages.addAll(Arrays.asList(messages));
4846
}
@@ -53,7 +51,7 @@ public OrchestrationPrompt(
5351
* @param inputParams The input parameters as entries of template variables and their contents.
5452
*/
5553
public OrchestrationPrompt(
56-
@Nonnull final Map<String, String> inputParams, @Nonnull final ChatMessage... messages) {
54+
@Nonnull final Map<String, String> inputParams, @Nonnull final Message... messages) {
5755
this.templateParameters.putAll(inputParams);
5856
this.messages.addAll(Arrays.asList(messages));
5957
}
@@ -64,7 +62,7 @@ public OrchestrationPrompt(
6462
* @param messagesHistory The chat history to add.
6563
*/
6664
@Nonnull
67-
public OrchestrationPrompt messageHistory(@Nonnull final List<ChatMessage> messagesHistory) {
65+
public OrchestrationPrompt messageHistory(@Nonnull final List<Message> messagesHistory) {
6866
this.messagesHistory.clear();
6967
this.messagesHistory.addAll(messagesHistory);
7068
return this;
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import com.sap.ai.sdk.orchestration.client.model.ChatMessage;
4+
import javax.annotation.Nonnull;
5+
import lombok.Value;
6+
import lombok.experimental.Accessors;
7+
8+
/** Represents a chat message as 'system' to the orchestration service. */
9+
@Value
10+
@Accessors(fluent = true)
11+
public class SystemMessage implements Message {
12+
13+
/** The role of the assistant. */
14+
@Nonnull public static final String ROLE = "system";
15+
16+
@Nonnull String content;
17+
18+
/**
19+
* Converts the message to a serializable ChatMessage object.
20+
*
21+
* @return the corresponding {@code ChatMessage} object.
22+
*/
23+
@Nonnull
24+
public ChatMessage toChatMessage() {
25+
return ChatMessage.create().role(ROLE).content(content);
26+
}
27+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import com.sap.ai.sdk.orchestration.client.model.ChatMessage;
4+
import javax.annotation.Nonnull;
5+
import lombok.Value;
6+
import lombok.experimental.Accessors;
7+
8+
/** Represents a chat message as 'user' to the orchestration service. */
9+
@Value
10+
@Accessors(fluent = true)
11+
public class UserMessage implements Message {
12+
13+
/** The role of the assistant. */
14+
@Nonnull public static final String ROLE = "user";
15+
16+
@Nonnull String content;
17+
18+
/**
19+
* Converts the message to a serializable ChatMessage object.
20+
*
21+
* @return the corresponding {@code ChatMessage} object.
22+
*/
23+
@Nonnull
24+
public ChatMessage toChatMessage() {
25+
return ChatMessage.create().role(ROLE).content(content);
26+
}
27+
}

orchestration/src/test/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformerTest.java

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import static org.assertj.core.api.Assertions.assertThat;
55
import static org.assertj.core.api.Assertions.assertThatThrownBy;
66

7-
import com.sap.ai.sdk.orchestration.client.model.ChatMessage;
87
import com.sap.ai.sdk.orchestration.client.model.Template;
98
import java.util.List;
109
import java.util.Map;
@@ -31,10 +30,12 @@ void testThrowsOnMissingMessages() {
3130

3231
@Test
3332
void testEmptyTemplateConfig() {
34-
var systemMessage = ChatMessage.create().role("system").content("foo");
35-
var userMessage = ChatMessage.create().role("user").content("Hello");
33+
var systemMessage = new SystemMessage("foo");
34+
var userMessage = new UserMessage("Hello");
3635

37-
var expected = Template.create().template(List.of(systemMessage, userMessage));
36+
var expected =
37+
Template.create()
38+
.template(List.of(systemMessage.toChatMessage(), userMessage.toChatMessage()));
3839

3940
var prompt = new OrchestrationPrompt(systemMessage, userMessage);
4041
var actual =
@@ -51,28 +52,36 @@ void testEmptyTemplateConfig() {
5152

5253
@Test
5354
void testMergingTemplateConfig() {
54-
var systemMessage = ChatMessage.create().role("system").content("foo");
55-
var userMessage = ChatMessage.create().role("user").content("Hello ");
56-
var userMessage2 = ChatMessage.create().role("user").content("World");
57-
58-
var expected = Template.create().template(List.of(systemMessage, userMessage, userMessage2));
55+
var systemMessage = new SystemMessage("foo");
56+
var userMessage = new UserMessage("Hello ");
57+
var userMessage2 = new UserMessage("World");
58+
59+
var expected =
60+
Template.create()
61+
.template(
62+
List.of(
63+
systemMessage.toChatMessage(),
64+
userMessage.toChatMessage(),
65+
userMessage2.toChatMessage()));
5966

6067
var prompt = new OrchestrationPrompt(userMessage2);
61-
var templateConfig = Template.create().template(List.of(systemMessage, userMessage));
68+
var templateConfig =
69+
Template.create()
70+
.template(List.of(systemMessage.toChatMessage(), userMessage.toChatMessage()));
6271
var actual = ConfigToRequestTransformer.toTemplateModuleConfig(prompt, templateConfig);
6372

6473
assertThat(actual).isEqualTo(expected);
6574
}
6675

6776
@Test
6877
void testMessagesHistory() {
69-
var systemMessage = ChatMessage.create().role("system").content("foo");
78+
var systemMessage = new SystemMessage("foo");
7079

7180
var prompt = new OrchestrationPrompt("bar").messageHistory(List.of(systemMessage));
7281
var actual =
7382
ConfigToRequestTransformer.toCompletionPostRequest(
7483
prompt, new OrchestrationModuleConfig().withLlmConfig(CUSTOM_GPT_35));
7584

76-
assertThat(actual.getMessagesHistory()).containsExactly(systemMessage);
85+
assertThat(actual.getMessagesHistory()).containsExactly(systemMessage.toChatMessage());
7786
}
7887
}

orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
import com.github.tomakehurst.wiremock.junit5.WireMockTest;
3030
import com.github.tomakehurst.wiremock.stubbing.Scenario;
3131
import com.sap.ai.sdk.core.AiCoreService;
32-
import com.sap.ai.sdk.orchestration.client.model.ChatMessage;
3332
import com.sap.ai.sdk.orchestration.client.model.CompletionPostRequest;
3433
import com.sap.ai.sdk.orchestration.client.model.DPIEntities;
3534
import com.sap.ai.sdk.orchestration.client.model.GenericModuleResult;
@@ -120,7 +119,7 @@ void testTemplating() throws IOException {
120119
.withBodyFile("templatingResponse.json")
121120
.withHeader("Content-Type", "application/json")));
122121

123-
final var template = ChatMessage.create().role("user").content("{{?input}}");
122+
final var template = new UserMessage("{{?input}}");
124123
final var inputParams =
125124
Map.of("input", "Reply with 'Orchestration Service is working!' in German");
126125

@@ -257,12 +256,11 @@ void messagesHistory() throws IOException {
257256
.withBodyFile("templatingResponse.json")
258257
.withHeader("Content-Type", "application/json")));
259258

260-
final List<ChatMessage> messagesHistory =
259+
final List<Message> messagesHistory =
261260
List.of(
262-
ChatMessage.create().role("user").content("What is the capital of France?"),
263-
ChatMessage.create().role("assistant").content("The capital of France is Paris."));
264-
final var message =
265-
ChatMessage.create().role("user").content("What is the typical food there?");
261+
new UserMessage("What is the capital of France?"),
262+
new AssistantMessage("The capital of France is Paris."));
263+
final var message = new UserMessage("What is the typical food there?");
266264

267265
prompt = new OrchestrationPrompt(message).messageHistory(messagesHistory);
268266

@@ -388,7 +386,7 @@ void testExecuteRequestFromJson() {
388386

389387
prompt =
390388
new OrchestrationPrompt(Map.of("foo", "bar"))
391-
.messageHistory(List.of(ChatMessage.create().role("user").content("Hello World!")));
389+
.messageHistory(List.of(new UserMessage("Hello World!")));
392390
final var configJson =
393391
"""
394392
{

sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java

Lines changed: 18 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,17 @@
22

33
import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.GPT_35_TURBO;
44

5+
import com.sap.ai.sdk.orchestration.AssistantMessage;
56
import com.sap.ai.sdk.orchestration.AzureContentFilter;
67
import com.sap.ai.sdk.orchestration.AzureFilterThreshold;
78
import com.sap.ai.sdk.orchestration.DpiMasking;
9+
import com.sap.ai.sdk.orchestration.Message;
810
import com.sap.ai.sdk.orchestration.OrchestrationChatResponse;
911
import com.sap.ai.sdk.orchestration.OrchestrationClient;
1012
import com.sap.ai.sdk.orchestration.OrchestrationModuleConfig;
1113
import com.sap.ai.sdk.orchestration.OrchestrationPrompt;
12-
import com.sap.ai.sdk.orchestration.client.model.ChatMessage;
14+
import com.sap.ai.sdk.orchestration.SystemMessage;
15+
import com.sap.ai.sdk.orchestration.UserMessage;
1316
import com.sap.ai.sdk.orchestration.client.model.DPIEntities;
1417
import com.sap.ai.sdk.orchestration.client.model.Template;
1518
import java.util.List;
@@ -51,10 +54,8 @@ public OrchestrationChatResponse completion() {
5154
@Nonnull
5255
public OrchestrationChatResponse template() {
5356
final var template =
54-
ChatMessage.create()
55-
.role("user")
56-
.content("Reply with 'Orchestration Service is working!' in {{?language}}");
57-
final var templatingConfig = Template.create().template(List.of(template));
57+
new UserMessage("Reply with 'Orchestration Service is working!' in {{?language}}");
58+
final var templatingConfig = Template.create().template(List.of(template.toChatMessage()));
5859
final var configWithTemplate = config.withTemplateConfig(templatingConfig);
5960

6061
final var inputParams = Map.of("language", "German");
@@ -71,12 +72,11 @@ public OrchestrationChatResponse template() {
7172
@GetMapping("/messagesHistory")
7273
@Nonnull
7374
public OrchestrationChatResponse messagesHistory() {
74-
final List<ChatMessage> messagesHistory =
75+
final List<Message> messagesHistory =
7576
List.of(
76-
ChatMessage.create().role("user").content("What is the capital of France?"),
77-
ChatMessage.create().role("assistant").content("The capital of France is Paris."));
78-
final var message =
79-
ChatMessage.create().role("user").content("What is the typical food there?");
77+
new UserMessage("What is the capital of France?"),
78+
new AssistantMessage("The capital of France is Paris."));
79+
final var message = new UserMessage("What is the typical food there?");
8080

8181
final var prompt = new OrchestrationPrompt(message).messageHistory(messagesHistory);
8282

@@ -120,15 +120,11 @@ public OrchestrationChatResponse filter(
120120
@Nonnull
121121
public OrchestrationChatResponse maskingAnonymization() {
122122
final var systemMessage =
123-
ChatMessage.create()
124-
.role("system")
125-
.content(
126-
"Please evaluate the following user feedback and judge if the sentiment is positive or negative.");
123+
new SystemMessage(
124+
"Please evaluate the following user feedback and judge if the sentiment is positive or negative.");
127125
final var userMessage =
128-
ChatMessage.create()
129-
.role("user")
130-
.content(
131-
"""
126+
new UserMessage(
127+
"""
132128
I think the SDK is good, but could use some further enhancements.
133129
My architect Alice and manager Bob pointed out that we need the grounding capabilities, which aren't supported yet.
134130
""");
@@ -150,18 +146,14 @@ public OrchestrationChatResponse maskingAnonymization() {
150146
@Nonnull
151147
public OrchestrationChatResponse maskingPseudonymization() {
152148
final var systemMessage =
153-
ChatMessage.create()
154-
.role("system")
155-
.content(
156-
"""
149+
new SystemMessage(
150+
"""
157151
Please write an initial response to the below user feedback, stating that we are working on the feedback and will get back to them soon.
158152
Please make sure to address the user in person and end with "Best regards, the AI SDK team".
159153
""");
160154
final var userMessage =
161-
ChatMessage.create()
162-
.role("user")
163-
.content(
164-
"""
155+
new UserMessage(
156+
"""
165157
Username: Mallory
166158
userEmail: [email protected]
167159
Date: 2022-01-01

0 commit comments

Comments
 (0)