Skip to content

Commit cd821c6

Browse files
authored
Chat Message Convenience [Redone] (#199)
* Introduce role based message classes - Better separate convenience and generated API * Clean up - Rename toChat.. to createChat.. - extract variables * Make Message interface unsealed * Adapt for package name changes * Introduce convenience static methods --------- Co-authored-by: Roshin Rajan Panackal <[email protected]>
1 parent 2b4e484 commit cd821c6

File tree

10 files changed

+176
-57
lines changed

10 files changed

+176
-57
lines changed
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import javax.annotation.Nonnull;
4+
import lombok.Value;
5+
import lombok.experimental.Accessors;
6+
7+
/** Represents a chat message as 'assistant' to the orchestration service. */
8+
@Value
9+
@Accessors(fluent = true)
10+
public class AssistantMessage implements Message {
11+
12+
/** The role of the assistant. */
13+
@Nonnull String role = "assistant";
14+
15+
/** The content of the message. */
16+
@Nonnull String content;
17+
}

orchestration/src/main/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformer.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ static CompletionPostRequest toCompletionPostRequest(
3131
.orchestrationConfig(
3232
OrchestrationConfig.create().moduleConfigurations(toModuleConfigs(configCopy)))
3333
.inputParams(prompt.getTemplateParameters())
34-
.messagesHistory(prompt.getMessagesHistory());
34+
.messagesHistory(
35+
prompt.getMessagesHistory().stream().map(Message::createChatMessage).toList());
3536
}
3637

3738
@Nonnull
@@ -46,7 +47,8 @@ static TemplatingModuleConfig toTemplateModuleConfig(
4647
*/
4748
val messages = template instanceof Template t ? t.getTemplate() : List.<ChatMessage>of();
4849
val messagesWithPrompt = new ArrayList<>(messages);
49-
messagesWithPrompt.addAll(prompt.getMessages());
50+
messagesWithPrompt.addAll(
51+
prompt.getMessages().stream().map(Message::createChatMessage).toList());
5052
if (messagesWithPrompt.isEmpty()) {
5153
throw new IllegalStateException(
5254
"A prompt is required. Pass at least one message or configure a template with messages or a template reference.");
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import com.sap.ai.sdk.orchestration.model.ChatMessage;
4+
import javax.annotation.Nonnull;
5+
6+
/** Interface representing convenience wrappers of chat message to the orchestration service. */
7+
public sealed interface Message permits UserMessage, AssistantMessage, SystemMessage {
8+
9+
/**
10+
* A convenience method to create a user message.
11+
*
12+
* @param msg the message content.
13+
* @return the user message.
14+
*/
15+
@Nonnull
16+
static UserMessage user(@Nonnull final String msg) {
17+
return new UserMessage(msg);
18+
}
19+
20+
/**
21+
* A convenience method to create an assistant message.
22+
*
23+
* @param msg the message content.
24+
* @return the assistant message.
25+
*/
26+
@Nonnull
27+
static AssistantMessage assistant(@Nonnull final String msg) {
28+
return new AssistantMessage(msg);
29+
}
30+
31+
/**
32+
* A convenience method to create a system message.
33+
*
34+
* @param msg the message content.
35+
* @return the system message.
36+
*/
37+
@Nonnull
38+
static SystemMessage system(@Nonnull final String msg) {
39+
return new SystemMessage(msg);
40+
}
41+
42+
/**
43+
* Converts the message to a serializable ChatMessage object.
44+
*
45+
* @return the corresponding {@code ChatMessage} object.
46+
*/
47+
@Nonnull
48+
default ChatMessage createChatMessage() {
49+
return ChatMessage.create().role(role()).content(content());
50+
}
51+
52+
/**
53+
* Returns the role of the assistant.
54+
*
55+
* @return the role.
56+
*/
57+
@Nonnull
58+
String role();
59+
60+
/**
61+
* Returns the content of the message.
62+
*
63+
* @return the content.
64+
*/
65+
@Nonnull
66+
String content();
67+
}

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,9 @@ public OrchestrationChatResponse executeRequestFromJsonModuleConfig(
169169
}
170170

171171
final ObjectNode requestJson = JACKSON.createObjectNode();
172-
requestJson.set("messages_history", JACKSON.valueToTree(prompt.getMessagesHistory()));
172+
final var chatMessageHistory =
173+
prompt.getMessagesHistory().stream().map(Message::createChatMessage).toList();
174+
requestJson.set("messages_history", JACKSON.valueToTree(chatMessageHistory));
173175
requestJson.set("input_params", JACKSON.valueToTree(prompt.getTemplateParameters()));
174176

175177
final JsonNode moduleConfigJson;

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationPrompt.java

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
package com.sap.ai.sdk.orchestration;
22

3-
import com.sap.ai.sdk.orchestration.model.ChatMessage;
43
import com.sap.ai.sdk.orchestration.model.OrchestrationConfig;
54
import java.util.ArrayList;
65
import java.util.Arrays;
@@ -22,17 +21,17 @@
2221
@Value
2322
@Getter(AccessLevel.PACKAGE)
2423
public class OrchestrationPrompt {
25-
@Nonnull List<ChatMessage> messages = new ArrayList<>();
24+
@Nonnull List<Message> messages = new ArrayList<>();
2625
@Nonnull Map<String, String> templateParameters = new TreeMap<>();
27-
@Nonnull List<ChatMessage> messagesHistory = new ArrayList<>();
26+
@Nonnull List<Message> messagesHistory = new ArrayList<>();
2827

2928
/**
3029
* Initialize a prompt with the given user message.
3130
*
3231
* @param message A user message.
3332
*/
3433
public OrchestrationPrompt(@Nonnull final String message) {
35-
messages.add(ChatMessage.create().role("user").content(message));
34+
messages.add(new UserMessage(message));
3635
}
3736

3837
/**
@@ -41,8 +40,7 @@ public OrchestrationPrompt(@Nonnull final String message) {
4140
* @param message The first message.
4241
* @param messages Optionally, more messages.
4342
*/
44-
public OrchestrationPrompt(
45-
@Nonnull final ChatMessage message, @Nonnull final ChatMessage... messages) {
43+
public OrchestrationPrompt(@Nonnull final Message message, @Nonnull final Message... messages) {
4644
this.messages.add(message);
4745
this.messages.addAll(Arrays.asList(messages));
4846
}
@@ -53,7 +51,7 @@ public OrchestrationPrompt(
5351
* @param inputParams The input parameters as entries of template variables and their contents.
5452
*/
5553
public OrchestrationPrompt(
56-
@Nonnull final Map<String, String> inputParams, @Nonnull final ChatMessage... messages) {
54+
@Nonnull final Map<String, String> inputParams, @Nonnull final Message... messages) {
5755
this.templateParameters.putAll(inputParams);
5856
this.messages.addAll(Arrays.asList(messages));
5957
}
@@ -64,7 +62,7 @@ public OrchestrationPrompt(
6462
* @param messagesHistory The chat history to add.
6563
*/
6664
@Nonnull
67-
public OrchestrationPrompt messageHistory(@Nonnull final List<ChatMessage> messagesHistory) {
65+
public OrchestrationPrompt messageHistory(@Nonnull final List<Message> messagesHistory) {
6866
this.messagesHistory.clear();
6967
this.messagesHistory.addAll(messagesHistory);
7068
return this;
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import javax.annotation.Nonnull;
4+
import lombok.Value;
5+
import lombok.experimental.Accessors;
6+
7+
/** Represents a chat message as 'system' to the orchestration service. */
8+
@Value
9+
@Accessors(fluent = true)
10+
public class SystemMessage implements Message {
11+
12+
/** The role of the assistant. */
13+
@Nonnull String role = "system";
14+
15+
/** The content of the message. */
16+
@Nonnull String content;
17+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import javax.annotation.Nonnull;
4+
import lombok.Value;
5+
import lombok.experimental.Accessors;
6+
7+
/** Represents a chat message as 'user' to the orchestration service. */
8+
@Value
9+
@Accessors(fluent = true)
10+
public class UserMessage implements Message {
11+
12+
/** The role of the assistant. */
13+
@Nonnull String role = "user";
14+
15+
/** The content of the message. */
16+
@Nonnull String content;
17+
}

orchestration/src/test/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformerTest.java

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import static org.assertj.core.api.Assertions.assertThat;
55
import static org.assertj.core.api.Assertions.assertThatThrownBy;
66

7-
import com.sap.ai.sdk.orchestration.model.ChatMessage;
87
import com.sap.ai.sdk.orchestration.model.Template;
98
import java.util.List;
109
import java.util.Map;
@@ -31,10 +30,12 @@ void testThrowsOnMissingMessages() {
3130

3231
@Test
3332
void testEmptyTemplateConfig() {
34-
var systemMessage = ChatMessage.create().role("system").content("foo");
35-
var userMessage = ChatMessage.create().role("user").content("Hello");
33+
var systemMessage = new SystemMessage("foo");
34+
var userMessage = new UserMessage("Hello");
3635

37-
var expected = Template.create().template(List.of(systemMessage, userMessage));
36+
var expected =
37+
Template.create()
38+
.template(List.of(systemMessage.createChatMessage(), userMessage.createChatMessage()));
3839

3940
var prompt = new OrchestrationPrompt(systemMessage, userMessage);
4041
var actual =
@@ -51,28 +52,36 @@ void testEmptyTemplateConfig() {
5152

5253
@Test
5354
void testMergingTemplateConfig() {
54-
var systemMessage = ChatMessage.create().role("system").content("foo");
55-
var userMessage = ChatMessage.create().role("user").content("Hello ");
56-
var userMessage2 = ChatMessage.create().role("user").content("World");
57-
58-
var expected = Template.create().template(List.of(systemMessage, userMessage, userMessage2));
55+
var systemMessage = new SystemMessage("foo");
56+
var userMessage = new UserMessage("Hello ");
57+
var userMessage2 = new UserMessage("World");
58+
59+
var expected =
60+
Template.create()
61+
.template(
62+
List.of(
63+
systemMessage.createChatMessage(),
64+
userMessage.createChatMessage(),
65+
userMessage2.createChatMessage()));
5966

6067
var prompt = new OrchestrationPrompt(userMessage2);
61-
var templateConfig = Template.create().template(List.of(systemMessage, userMessage));
68+
var templateConfig =
69+
Template.create()
70+
.template(List.of(systemMessage.createChatMessage(), userMessage.createChatMessage()));
6271
var actual = ConfigToRequestTransformer.toTemplateModuleConfig(prompt, templateConfig);
6372

6473
assertThat(actual).isEqualTo(expected);
6574
}
6675

6776
@Test
6877
void testMessagesHistory() {
69-
var systemMessage = ChatMessage.create().role("system").content("foo");
78+
var systemMessage = new SystemMessage("foo");
7079

7180
var prompt = new OrchestrationPrompt("bar").messageHistory(List.of(systemMessage));
7281
var actual =
7382
ConfigToRequestTransformer.toCompletionPostRequest(
7483
prompt, new OrchestrationModuleConfig().withLlmConfig(CUSTOM_GPT_35));
7584

76-
assertThat(actual.getMessagesHistory()).containsExactly(systemMessage);
85+
assertThat(actual.getMessagesHistory()).containsExactly(systemMessage.createChatMessage());
7786
}
7887
}

orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
import com.github.tomakehurst.wiremock.junit5.WireMockTest;
3131
import com.github.tomakehurst.wiremock.stubbing.Scenario;
3232
import com.sap.ai.sdk.core.AiCoreService;
33-
import com.sap.ai.sdk.orchestration.model.ChatMessage;
3433
import com.sap.ai.sdk.orchestration.model.CompletionPostRequest;
3534
import com.sap.ai.sdk.orchestration.model.DPIEntities;
3635
import com.sap.ai.sdk.orchestration.model.GenericModuleResult;
@@ -121,7 +120,7 @@ void testTemplating() throws IOException {
121120
.withBodyFile("templatingResponse.json")
122121
.withHeader("Content-Type", "application/json")));
123122

124-
final var template = ChatMessage.create().role("user").content("{{?input}}");
123+
final var template = new UserMessage("{{?input}}");
125124
final var inputParams =
126125
Map.of("input", "Reply with 'Orchestration Service is working!' in German");
127126

@@ -258,12 +257,11 @@ void messagesHistory() throws IOException {
258257
.withBodyFile("templatingResponse.json")
259258
.withHeader("Content-Type", "application/json")));
260259

261-
final List<ChatMessage> messagesHistory =
260+
final List<Message> messagesHistory =
262261
List.of(
263-
ChatMessage.create().role("user").content("What is the capital of France?"),
264-
ChatMessage.create().role("assistant").content("The capital of France is Paris."));
265-
final var message =
266-
ChatMessage.create().role("user").content("What is the typical food there?");
262+
new UserMessage("What is the capital of France?"),
263+
new AssistantMessage("The capital of France is Paris."));
264+
final var message = new UserMessage("What is the typical food there?");
267265

268266
prompt = new OrchestrationPrompt(message).messageHistory(messagesHistory);
269267

@@ -389,7 +387,7 @@ void testExecuteRequestFromJson() {
389387

390388
prompt =
391389
new OrchestrationPrompt(Map.of("foo", "bar"))
392-
.messageHistory(List.of(ChatMessage.create().role("user").content("Hello World!")));
390+
.messageHistory(List.of(new UserMessage("Hello World!")));
393391
final var configJson =
394392
"""
395393
{

0 commit comments

Comments
 (0)