Skip to content

Commit 733053a

Browse files
committed
Custom interface base message classes
1 parent 0c609a3 commit 733053a

File tree

8 files changed

+92
-25
lines changed

8 files changed

+92
-25
lines changed

orchestration/src/main/java/com/sap/ai/sdk/orchestration/AssistantMessage.java

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,23 @@
33
import com.sap.ai.sdk.orchestration.client.model.ChatMessage;
44
import javax.annotation.Nonnull;
55

6-
public class AssistantMessage extends ChatMessage {
6+
/**
7+
* Represents a chat message as 'assistant' to the orchestration service.
8+
*
9+
* @param content
10+
*/
11+
public record AssistantMessage(@Nonnull String content) implements Message {
712

8-
public AssistantMessage(@Nonnull String content) {
9-
this.role("assistant").content(content);
13+
/** The role of the assistant. */
14+
public static final String ROLE = "assistant";
15+
16+
/**
17+
* Converts the message to a serializable ChatMessage object.
18+
*
19+
* @return the corresponding {@code ChatMessage} object.
20+
*/
21+
@Nonnull
22+
public ChatMessage toChatMessage() {
23+
return new ChatMessage().role(ROLE).content(content);
1024
}
1125
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import com.sap.ai.sdk.orchestration.client.model.ChatMessage;
4+
import javax.annotation.Nonnull;
5+
6+
/** Interface representing convenience wrappers of chat message to the orchestration service. */
7+
public sealed interface Message permits UserMessage, AssistantMessage, SystemMessage {
8+
9+
/**
10+
* Converts the message to a serializable ChatMessage object.
11+
*
12+
* @return the corresponding {@code ChatMessage} object.
13+
*/
14+
@Nonnull
15+
ChatMessage toChatMessage();
16+
}

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationPrompt.java

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ public class OrchestrationPrompt {
3232
* @param message A user message.
3333
*/
3434
public OrchestrationPrompt(@Nonnull final String message) {
35-
messages.add(new UserMessage(message));
35+
messages.add(new UserMessage(message).toChatMessage());
3636
}
3737

3838
/**
@@ -41,10 +41,9 @@ public OrchestrationPrompt(@Nonnull final String message) {
4141
* @param message The first message.
4242
* @param messages Optionally, more messages.
4343
*/
44-
public OrchestrationPrompt(
45-
@Nonnull final ChatMessage message, @Nonnull final ChatMessage... messages) {
46-
this.messages.add(message);
47-
this.messages.addAll(Arrays.asList(messages));
44+
public OrchestrationPrompt(@Nonnull final Message message, @Nonnull final Message... messages) {
45+
this.messages.add(message.toChatMessage());
46+
this.messages.addAll(Arrays.stream(messages).map(Message::toChatMessage).toList());
4847
}
4948

5049
/**
@@ -53,9 +52,9 @@ public OrchestrationPrompt(
5352
* @param inputParams The input parameters as entries of template variables and their contents.
5453
*/
5554
public OrchestrationPrompt(
56-
@Nonnull final Map<String, String> inputParams, @Nonnull final ChatMessage... messages) {
55+
@Nonnull final Map<String, String> inputParams, @Nonnull final Message... messages) {
5756
this.templateParameters.putAll(inputParams);
58-
this.messages.addAll(Arrays.asList(messages));
57+
this.messages.addAll(Arrays.stream(messages).map(Message::toChatMessage).toList());
5958
}
6059

6160
/**
@@ -64,9 +63,9 @@ public OrchestrationPrompt(
6463
* @param messagesHistory The chat history to add.
6564
*/
6665
@Nonnull
67-
public OrchestrationPrompt messageHistory(@Nonnull final List<ChatMessage> messagesHistory) {
66+
public OrchestrationPrompt messageHistory(@Nonnull final List<Message> messagesHistory) {
6867
this.messagesHistory.clear();
69-
this.messagesHistory.addAll(messagesHistory);
68+
this.messagesHistory.addAll(messagesHistory.stream().map(Message::toChatMessage).toList());
7069
return this;
7170
}
7271
}

orchestration/src/main/java/com/sap/ai/sdk/orchestration/SystemMessage.java

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,23 @@
33
import com.sap.ai.sdk.orchestration.client.model.ChatMessage;
44
import javax.annotation.Nonnull;
55

6-
public class SystemMessage extends ChatMessage {
6+
/**
7+
* Represents a chat message as 'system' to the orchestration service.
8+
*
9+
* @param content
10+
*/
11+
public record SystemMessage(@Nonnull String content) implements Message {
712

8-
public SystemMessage(@Nonnull String content) {
9-
this.role("system").content(content);
13+
/** The role of the assistant. */
14+
public static final String ROLE = "system";
15+
16+
/**
17+
* Converts the message to a serializable ChatMessage object.
18+
*
19+
* @return the corresponding {@code ChatMessage} object.
20+
*/
21+
@Nonnull
22+
public ChatMessage toChatMessage() {
23+
return new ChatMessage().role(ROLE).content(content);
1024
}
1125
}

orchestration/src/main/java/com/sap/ai/sdk/orchestration/UserMessage.java

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,23 @@
33
import com.sap.ai.sdk.orchestration.client.model.ChatMessage;
44
import javax.annotation.Nonnull;
55

6-
public class UserMessage extends ChatMessage {
6+
/**
7+
* Represents a chat message as 'user' to the orchestration service.
8+
*
9+
* @param content
10+
*/
11+
public record UserMessage(@Nonnull String content) implements Message {
712

8-
public UserMessage(@Nonnull String content) {
9-
this.role("user").content(content);
13+
/** The role of the assistant. */
14+
public static final String ROLE = "user";
15+
16+
/**
17+
* Converts the message to a serializable ChatMessage object.
18+
*
19+
* @return the corresponding {@code ChatMessage} object.
20+
*/
21+
@Nonnull
22+
public ChatMessage toChatMessage() {
23+
return new ChatMessage().role(ROLE).content(content);
1024
}
1125
}

orchestration/src/test/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformerTest.java

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ void testEmptyTemplateConfig() {
3333
var systemMessage = new SystemMessage("foo");
3434
var userMessage = new UserMessage("Hello");
3535

36-
var expected = new Template().template(List.of(systemMessage, userMessage));
36+
var expected =
37+
new Template()
38+
.template(List.of(systemMessage.toChatMessage(), userMessage.toChatMessage()));
3739

3840
var prompt = new OrchestrationPrompt(systemMessage, userMessage);
3941
var actual =
@@ -54,10 +56,18 @@ void testMergingTemplateConfig() {
5456
var userMessage = new UserMessage("Hello ");
5557
var userMessage2 = new UserMessage("World");
5658

57-
var expected = new Template().template(List.of(systemMessage, userMessage, userMessage2));
59+
var expected =
60+
new Template()
61+
.template(
62+
List.of(
63+
systemMessage.toChatMessage(),
64+
userMessage.toChatMessage(),
65+
userMessage2.toChatMessage()));
5866

5967
var prompt = new OrchestrationPrompt(userMessage2);
60-
var templateConfig = new Template().template(List.of(systemMessage, userMessage));
68+
var templateConfig =
69+
new Template()
70+
.template(List.of(systemMessage.toChatMessage(), userMessage.toChatMessage()));
6171
var actual = ConfigToRequestTransformer.toTemplateModuleConfig(prompt, templateConfig);
6272

6373
assertThat(actual).isEqualTo(expected);
@@ -72,6 +82,6 @@ void testMessagesHistory() {
7282
ConfigToRequestTransformer.toCompletionPostRequest(
7383
prompt, new OrchestrationModuleConfig().withLlmConfig(CUSTOM_GPT_35));
7484

75-
assertThat(actual.getMessagesHistory()).containsExactly(systemMessage);
85+
assertThat(actual.getMessagesHistory()).containsExactly(systemMessage.toChatMessage());
7686
}
7787
}

orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
import com.github.tomakehurst.wiremock.junit5.WireMockTest;
3030
import com.github.tomakehurst.wiremock.stubbing.Scenario;
3131
import com.sap.ai.sdk.core.AiCoreService;
32-
import com.sap.ai.sdk.orchestration.client.model.ChatMessage;
3332
import com.sap.ai.sdk.orchestration.client.model.CompletionPostRequest;
3433
import com.sap.ai.sdk.orchestration.client.model.DPIEntities;
3534
import com.sap.ai.sdk.orchestration.client.model.GenericModuleResult;
@@ -257,7 +256,7 @@ void messagesHistory() throws IOException {
257256
.withBodyFile("templatingResponse.json")
258257
.withHeader("Content-Type", "application/json")));
259258

260-
final List<ChatMessage> messagesHistory =
259+
final List<Message> messagesHistory =
261260
List.of(
262261
new UserMessage("What is the capital of France?"),
263262
new AssistantMessage("The capital of France is Paris."));

sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import com.sap.ai.sdk.orchestration.AzureContentFilter;
77
import com.sap.ai.sdk.orchestration.AzureFilterThreshold;
88
import com.sap.ai.sdk.orchestration.DpiMasking;
9+
import com.sap.ai.sdk.orchestration.Message;
910
import com.sap.ai.sdk.orchestration.OrchestrationChatResponse;
1011
import com.sap.ai.sdk.orchestration.OrchestrationClient;
1112
import com.sap.ai.sdk.orchestration.OrchestrationModuleConfig;
@@ -72,7 +73,7 @@ public OrchestrationChatResponse template() {
7273
@GetMapping("/messagesHistory")
7374
@Nonnull
7475
public OrchestrationChatResponse messagesHistory() {
75-
final List<ChatMessage> messagesHistory =
76+
final List<Message> messagesHistory =
7677
List.of(
7778
new UserMessage("What is the capital of France?"),
7879
new AssistantMessage("The capital of France is Paris."));

0 commit comments

Comments
 (0)