Skip to content

Commit ea5f239

Browse files
committed
Fix type of Orchestration Response Messages
1 parent 8a92c34 commit ea5f239

File tree

4 files changed

+50
-16
lines changed

4 files changed

+50
-16
lines changed

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ public String chatCompletion(@Nonnull final String userPrompt)
8282
response.originalResponseDto());
8383
throw new OrchestrationClientException("Output content filter triggered");
8484
}
85-
return response.assistantMessage().getContent();
85+
return response.assistantMessage().content();
8686
}
8787

8888
/**

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationResponse.java

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import java.util.Arrays;
88
import java.util.List;
99
import javax.annotation.Nonnull;
10+
import javax.annotation.Nullable;
1011
import lombok.RequiredArgsConstructor;
1112
import lombok.val;
1213

@@ -21,8 +22,8 @@
2122
* result details.
2223
*/
2324
public record OrchestrationResponse(
24-
@Nonnull ChatMessage assistantMessage,
25-
@Nonnull List<ChatMessage> allMessages,
25+
@Nonnull AssistantMessage assistantMessage,
26+
@Nonnull List<Message> allMessages,
2627
@Nonnull FinishReason finishReason,
2728
@Nonnull TokenUsage tokenUsage,
2829
@Nonnull CompletionPostResponse originalResponseDto) {
@@ -55,11 +56,37 @@ static FinishReason fromValue(@Nonnull final String value) {
5556
static OrchestrationResponse fromCompletionPostResponseDTO(
5657
@Nonnull final CompletionPostResponse response) {
5758
val choice = response.getOrchestrationResult().getChoices().get(0);
58-
val message = choice.getMessage();
59+
val message = new AssistantMessage(choice.getMessage().getContent());
5960
val finishReason = FinishReason.fromValue(choice.getFinishReason());
6061
val tokenUsage = response.getOrchestrationResult().getUsage();
61-
val allMessages = new ArrayList<>(response.getModuleResults().getTemplating());
62+
val allMessages = new ArrayList<Message>();
63+
response.getModuleResults().getTemplating().stream()
64+
.map(OrchestrationResponse::fromChatMessage)
65+
.forEach(allMessages::add);
6266
allMessages.add(message);
6367
return new OrchestrationResponse(message, allMessages, finishReason, tokenUsage, response);
6468
}
69+
70+
@Nonnull
71+
static Message fromChatMessage(@Nonnull final ChatMessage chatMessage) {
72+
return switch (chatMessage.getRole()) {
73+
case "system" -> new SystemMessage(chatMessage.getContent());
74+
case "user" -> new UserMessage(chatMessage.getContent());
75+
case "assistant" -> new AssistantMessage(chatMessage.getContent());
76+
default ->
77+
new Message() {
78+
@Nonnull
79+
@Override
80+
public String type() {
81+
return chatMessage.getRole();
82+
}
83+
84+
@Nullable
85+
@Override
86+
public String content() {
87+
return chatMessage.getContent();
88+
}
89+
};
90+
};
91+
}
6592
}

orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationResponseTest.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
class OrchestrationResponseTest {
1616
@Test
1717
void testFromCompletionPostResponseDTO() {
18-
var message1 = mock(ChatMessage.class);
19-
var message2 = mock(ChatMessage.class);
20-
var message3 = mock(ChatMessage.class);
18+
var message1 = ChatMessage.create().role("system").content("foo");
19+
var message2 = ChatMessage.create().role("user").content("bar");
20+
var message3 = ChatMessage.create().role("assistant").content("baz");
2121
var moduleResults = ModuleResults.create().templating(List.of(message1, message2));
2222

2323
var orchestrationResult =
@@ -37,8 +37,10 @@ void testFromCompletionPostResponseDTO() {
3737

3838
var result = OrchestrationResponse.fromCompletionPostResponseDTO(postResponse);
3939

40-
assertThat(result.assistantMessage()).isSameAs(message3);
41-
assertThat(result.allMessages()).containsExactly(message1, message2, message3);
40+
assertThat(result.assistantMessage()).isEqualTo(new AssistantMessage("baz"));
41+
assertThat(result.allMessages())
42+
.containsExactly(
43+
new SystemMessage("foo"), new UserMessage("bar"), new AssistantMessage("baz"));
4244
assertThat(result.finishReason()).isEqualTo(OrchestrationResponse.FinishReason.STOP);
4345
assertThat(result.originalResponseDto()).isSameAs(postResponse);
4446
}

sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
import static org.assertj.core.api.Assertions.assertThatThrownBy;
66

77
import com.sap.ai.sdk.foundationmodels.openai.OpenAiModel;
8+
import com.sap.ai.sdk.orchestration.AssistantMessage;
89
import com.sap.ai.sdk.orchestration.AzureContentFilter;
910
import com.sap.ai.sdk.orchestration.OrchestrationClientException;
1011
import com.sap.ai.sdk.orchestration.OrchestrationResponse;
12+
import com.sap.ai.sdk.orchestration.UserMessage;
1113
import java.util.List;
1214
import java.util.Map;
1315
import org.assertj.core.api.InstanceOfAssertFactories;
@@ -27,8 +29,11 @@ void testCompletion() {
2729
final var response = controller.completion();
2830

2931
assertThat(response.finishReason()).isEqualTo(OrchestrationResponse.FinishReason.STOP);
30-
assertThat(response.assistantMessage().getContent()).isNotEmpty();
31-
assertThat(response.assistantMessage().getRole()).isEqualTo("assistant");
32+
assertThat(response.assistantMessage())
33+
.isInstanceOf(AssistantMessage.class)
34+
.extracting(AssistantMessage::content)
35+
.asString()
36+
.isNotEmpty();
3237
assertThat(response.tokenUsage().getPromptTokens()).isPositive();
3338
assertThat(response.tokenUsage().getCompletionTokens()).isPositive();
3439
assertThat(response.tokenUsage().getTotalTokens()).isPositive();
@@ -70,16 +75,16 @@ void testTemplate() {
7075
var result = controller.template();
7176

7277
var templateResult = result.allMessages().get(0);
73-
assertThat(templateResult.getContent())
78+
assertThat(templateResult.content())
7479
.isEqualTo("Reply with 'The Orchestration Service is working!' in german");
75-
assertThat(templateResult.getRole()).isEqualTo("user");
80+
assertThat(templateResult).isInstanceOf(UserMessage.class);
7681
}
7782

7883
@Test
7984
void testLenientContentFilter() {
8085
var result = controller.filter(AzureContentFilter.Sensitivity.LOW);
8186
assertThat(result.finishReason()).isEqualTo(OrchestrationResponse.FinishReason.STOP);
82-
assertThat(result.assistantMessage().getContent()).isNotEmpty();
87+
assertThat(result.assistantMessage().content()).isNotEmpty();
8388

8489
var filterResult = result.originalResponseDto().getModuleResults().getInputFiltering();
8590
assertThat(filterResult.getMessage()).contains("passed");
@@ -98,7 +103,7 @@ void testStrictContentFilter() {
98103
void testMasking() {
99104
var result = controller.masking();
100105

101-
assertThat(result.assistantMessage().getContent()).contains("[email protected]");
106+
assertThat(result.assistantMessage().content()).contains("[email protected]");
102107
var maskingResult = result.originalResponseDto().getModuleResults().getInputMasking();
103108
var data = (Map<String, Object>) maskingResult.getData();
104109
var maskedMessage = ((List<Map<String, Object>>) data.get("masked_template")).get(0);

0 commit comments

Comments
 (0)