Skip to content

Commit c971173

Browse files
Jonas-IsrMatKuhr
andauthored
Add Response Convenience (#173)
* Added Response Convenience * Applied review comments * Use Lombok's Value * Renaming and adding tests. * Add and test getTokenUsage * Add and test getAllMessages * Apply requested changes * Apply requested changes * Apply requested changes --------- Co-authored-by: Jonas Israel <[email protected]> Co-authored-by: Matthias Kuhr <[email protected]>
1 parent e495580 commit c971173

File tree

4 files changed

+56
-69
lines changed

4 files changed

+56
-69
lines changed

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,14 @@
22

33
import static lombok.AccessLevel.PACKAGE;
44

5+
import com.sap.ai.sdk.orchestration.client.model.ChatMessage;
56
import com.sap.ai.sdk.orchestration.client.model.CompletionPostResponse;
7+
import com.sap.ai.sdk.orchestration.client.model.LLMChoice;
68
import com.sap.ai.sdk.orchestration.client.model.LLMModuleResultSynchronous;
9+
import com.sap.ai.sdk.orchestration.client.model.TokenUsage;
10+
import java.util.ArrayList;
11+
import java.util.List;
12+
import java.util.Objects;
713
import javax.annotation.Nonnull;
814
import lombok.RequiredArgsConstructor;
915
import lombok.Value;
@@ -24,18 +30,47 @@ public class OrchestrationChatResponse {
2430
*/
2531
@Nonnull
2632
public String getContent() throws OrchestrationClientException {
27-
final var choices =
28-
((LLMModuleResultSynchronous) originalResponse.getOrchestrationResult()).getChoices();
29-
30-
if (choices.isEmpty()) {
31-
return "";
32-
}
33-
34-
final var choice = choices.get(0);
33+
final var choice = getCurrentChoice();
3534

3635
if ("content_filter".equals(choice.getFinishReason())) {
3736
throw new OrchestrationClientException("Content filter filtered the output.");
3837
}
3938
return choice.getMessage().getContent();
4039
}
40+
41+
/**
42+
* Get the token usage.
43+
*
44+
* @return The token usage.
45+
*/
46+
@Nonnull
47+
public TokenUsage getTokenUsage() {
48+
return ((LLMModuleResultSynchronous) originalResponse.getOrchestrationResult()).getUsage();
49+
}
50+
51+
/**
52+
* Get all messages. This can be used for subsequent prompts as a message history.
53+
*
54+
* @return A list of all messages.
55+
*/
56+
@Nonnull
57+
public List<ChatMessage> getAllMessages() {
58+
final var items = Objects.requireNonNull(originalResponse.getModuleResults().getTemplating());
59+
final var messages = new ArrayList<>(items);
60+
messages.add(getCurrentChoice().getMessage());
61+
return messages;
62+
}
63+
64+
/**
65+
* Get current choice.
66+
*
67+
* @return The current choice.
68+
*/
69+
@Nonnull
70+
private LLMChoice getCurrentChoice() {
71+
// We expect choices to be defined and never empty.
72+
return ((LLMModuleResultSynchronous) originalResponse.getOrchestrationResult())
73+
.getChoices()
74+
.get(0);
75+
}
4176
}

orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,9 @@ void testTemplating() throws IOException {
129129

130130
final var response = result.getOriginalResponse();
131131
assertThat(response.getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91");
132-
assertThat(response.getModuleResults().getTemplating().get(0).getContent())
132+
assertThat(result.getAllMessages().get(0).getContent())
133133
.isEqualTo("Reply with 'Orchestration Service is working!' in German");
134-
assertThat(response.getModuleResults().getTemplating().get(0).getRole()).isEqualTo("user");
134+
assertThat(result.getAllMessages().get(0).getRole()).isEqualTo("user");
135135
var llm = (LLMModuleResultSynchronous) response.getModuleResults().getLlm();
136136
assertThat(llm).isNotNull();
137137
assertThat(llm.getId()).isEqualTo("chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj");
@@ -144,7 +144,7 @@ void testTemplating() throws IOException {
144144
.isEqualTo("Orchestration Service funktioniert!");
145145
assertThat(choices.get(0).getMessage().getRole()).isEqualTo("assistant");
146146
assertThat(choices.get(0).getFinishReason()).isEqualTo("stop");
147-
var usage = llm.getUsage();
147+
var usage = result.getTokenUsage();
148148
assertThat(usage.getCompletionTokens()).isEqualTo(7);
149149
assertThat(usage.getPromptTokens()).isEqualTo(19);
150150
assertThat(usage.getTotalTokens()).isEqualTo(26);
@@ -159,7 +159,7 @@ void testTemplating() throws IOException {
159159
.isEqualTo("Orchestration Service funktioniert!");
160160
assertThat(choices.get(0).getMessage().getRole()).isEqualTo("assistant");
161161
assertThat(choices.get(0).getFinishReason()).isEqualTo("stop");
162-
usage = orchestrationResult.getUsage();
162+
usage = result.getTokenUsage();
163163
assertThat(usage.getCompletionTokens()).isEqualTo(7);
164164
assertThat(usage.getPromptTokens()).isEqualTo(19);
165165
assertThat(usage.getTotalTokens()).isEqualTo(26);
@@ -380,17 +380,4 @@ void testErrorHandling() {
380380

381381
softly.assertAll();
382382
}
383-
384-
@Test
385-
void testEmptyChoicesResponse() {
386-
stubFor(
387-
post(urlPathEqualTo("/v2/inference/deployments/abcdef0123456789/completion"))
388-
.willReturn(
389-
aResponse()
390-
.withBodyFile("emptyChoicesResponse.json")
391-
.withHeader("Content-Type", "application/json")));
392-
final var result = client.chatCompletion(prompt, config);
393-
394-
assertThat(result.getContent()).isEmpty();
395-
}
396383
}

orchestration/src/test/resources/__files/emptyChoicesResponse.json

Lines changed: 0 additions & 35 deletions
This file was deleted.

sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,14 @@ void testTemplate() {
3737
assertThat(controller.config.getLlmConfig()).isNotNull();
3838
final var modelName = controller.config.getLlmConfig().getModelName();
3939

40-
final var response = controller.template();
41-
final var result = response.getOriginalResponse();
40+
final var result = controller.template();
41+
final var response = result.getOriginalResponse();
4242

43-
assertThat(result.getRequestId()).isNotEmpty();
44-
assertThat(result.getModuleResults().getTemplating().get(0).getContent())
43+
assertThat(response.getRequestId()).isNotEmpty();
44+
assertThat(result.getAllMessages().get(0).getContent())
4545
.isEqualTo("Reply with 'Orchestration Service is working!' in German");
46-
assertThat(result.getModuleResults().getTemplating().get(0).getRole()).isEqualTo("user");
47-
var llm = (LLMModuleResultSynchronous) result.getModuleResults().getLlm();
46+
assertThat(result.getAllMessages().get(0).getRole()).isEqualTo("user");
47+
var llm = (LLMModuleResultSynchronous) response.getModuleResults().getLlm();
4848
assertThat(llm.getId()).isNotEmpty();
4949
assertThat(llm.getObject()).isEqualTo("chat.completion");
5050
assertThat(llm.getCreated()).isGreaterThan(1);
@@ -54,12 +54,12 @@ void testTemplate() {
5454
assertThat(choices.get(0).getMessage().getContent()).isNotEmpty();
5555
assertThat(choices.get(0).getMessage().getRole()).isEqualTo("assistant");
5656
assertThat(choices.get(0).getFinishReason()).isEqualTo("stop");
57-
var usage = llm.getUsage();
57+
var usage = result.getTokenUsage();
5858
assertThat(usage.getCompletionTokens()).isGreaterThan(1);
5959
assertThat(usage.getPromptTokens()).isGreaterThan(1);
6060
assertThat(usage.getTotalTokens()).isGreaterThan(1);
6161

62-
var orchestrationResult = ((LLMModuleResultSynchronous) result.getOrchestrationResult());
62+
var orchestrationResult = ((LLMModuleResultSynchronous) response.getOrchestrationResult());
6363
assertThat(orchestrationResult.getObject()).isEqualTo("chat.completion");
6464
assertThat(orchestrationResult.getCreated()).isGreaterThan(1);
6565
assertThat(orchestrationResult.getModel()).isEqualTo(modelName);
@@ -68,7 +68,7 @@ void testTemplate() {
6868
assertThat(choices.get(0).getMessage().getContent()).isNotEmpty();
6969
assertThat(choices.get(0).getMessage().getRole()).isEqualTo("assistant");
7070
assertThat(choices.get(0).getFinishReason()).isEqualTo("stop");
71-
usage = orchestrationResult.getUsage();
71+
usage = result.getTokenUsage();
7272
assertThat(usage.getCompletionTokens()).isGreaterThan(1);
7373
assertThat(usage.getPromptTokens()).isGreaterThan(1);
7474
assertThat(usage.getTotalTokens()).isGreaterThan(1);

0 commit comments

Comments
 (0)