Skip to content

Commit 173530d

Browse files
Jonas-Isr and MatKuhr
authored
Added Response Convenience (#157)
* Added Response Convenience * Applied review comments * Use Lombok's Value * Renaming and adding tests. --------- Co-authored-by: Jonas Israel <[email protected]> Co-authored-by: Matthias Kuhr <[email protected]>
1 parent 50b3c47 commit 173530d

File tree

7 files changed

+124
-37
lines changed

7 files changed

+124
-37
lines changed

docs/guides/ORCHESTRATION_CHAT_COMPLETION.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,7 @@ var prompt = new OrchestrationPrompt("Hello world! Why is this phrase so famous?
9191

9292
var result = client.chatCompletion(prompt, config);
9393

94-
String messageResult =
95-
result.getOrchestrationResult().getChoices().get(0).getMessage().getContent();
94+
String messageResult = result.getContent();
9695
```
9796

9897
In this example, the Orchestration service generates a response to the user message "Hello world! Why is this phrase so famous?".
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import static lombok.AccessLevel.PACKAGE;
4+
5+
import com.sap.ai.sdk.orchestration.client.model.CompletionPostResponse;
6+
import com.sap.ai.sdk.orchestration.client.model.LLMModuleResultSynchronous;
7+
import javax.annotation.Nonnull;
8+
import lombok.RequiredArgsConstructor;
9+
import lombok.Value;
10+
11+
/** Orchestration chat completion output, wrapping the original {@link CompletionPostResponse}. */
12+
@Value
13+
@RequiredArgsConstructor(access = PACKAGE)
14+
public class OrchestrationChatResponse {
15+
CompletionPostResponse originalResponse; // the wrapped service response, exposed through getOriginalResponse()
16+
17+
/**
18+
* Get the message content of the first choice in the orchestration result.
19+
*
20+
* <p>Note: If there are multiple choices, only the first one is returned.
21+
*
22+
* @return the message content of the first choice, or an empty string if there are no choices.
23+
* @throws OrchestrationClientException if the first choice's finish reason is {@code content_filter}, i.e. the content filter filtered the output.
24+
*/
25+
@Nonnull
26+
public String getContent() throws OrchestrationClientException {
27+
final var choices =
28+
((LLMModuleResultSynchronous) originalResponse.getOrchestrationResult()).getChoices(); // NOTE(review): unchecked cast; assumes the result is always the synchronous variant - confirm
29+
30+
if (choices.isEmpty()) { // no choices at all: report empty content instead of failing on get(0)
31+
return "";
32+
}
33+
34+
final var choice = choices.get(0); // only the first choice is considered
35+
36+
if ("content_filter".equals(choice.getFinishReason())) { // constant-first equals: safe if the finish reason is null
37+
throw new OrchestrationClientException("Content filter filtered the output.");
38+
}
39+
return choice.getMessage().getContent(); // NOTE(review): method is @Nonnull; presumably message/content are never null here - confirm
40+
}
41+
}

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,12 @@ public static CompletionPostRequest toCompletionPostRequest(
101101
* @throws OrchestrationClientException if the request fails.
102102
*/
103103
@Nonnull
104-
public CompletionPostResponse chatCompletion(
104+
public OrchestrationChatResponse chatCompletion(
105105
@Nonnull final OrchestrationPrompt prompt, @Nonnull final OrchestrationModuleConfig config)
106106
throws OrchestrationClientException {
107107

108108
val request = toCompletionPostRequest(prompt, config);
109-
return executeRequest(request);
109+
return new OrchestrationChatResponse(executeRequest(request));
110110
}
111111

112112
/**

orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,7 @@ void testCompletion() {
121121
final var result = client.chatCompletion(prompt, config);
122122

123123
assertThat(result).isNotNull();
124-
var orchestrationResult = (LLMModuleResultSynchronous) result.getOrchestrationResult();
125-
assertThat(orchestrationResult.getChoices().get(0).getMessage().getContent()).isNotEmpty();
124+
assertThat(result.getContent()).isNotEmpty();
126125
}
127126

128127
@Test
@@ -141,11 +140,12 @@ void testTemplating() throws IOException {
141140
final var result =
142141
client.chatCompletion(new OrchestrationPrompt(inputParams, template), config);
143142

144-
assertThat(result.getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91");
145-
assertThat(result.getModuleResults().getTemplating().get(0).getContent())
143+
final var response = result.getOriginalResponse();
144+
assertThat(response.getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91");
145+
assertThat(response.getModuleResults().getTemplating().get(0).getContent())
146146
.isEqualTo("Reply with 'Orchestration Service is working!' in German");
147-
assertThat(result.getModuleResults().getTemplating().get(0).getRole()).isEqualTo("user");
148-
var llm = (LLMModuleResultSynchronous) result.getModuleResults().getLlm();
147+
assertThat(response.getModuleResults().getTemplating().get(0).getRole()).isEqualTo("user");
148+
var llm = (LLMModuleResultSynchronous) response.getModuleResults().getLlm();
149149
assertThat(llm.getId()).isEqualTo("chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj");
150150
assertThat(llm.getObject()).isEqualTo("chat.completion");
151151
assertThat(llm.getCreated()).isEqualTo(1721224505);
@@ -160,7 +160,7 @@ void testTemplating() throws IOException {
160160
assertThat(usage.getCompletionTokens()).isEqualTo(7);
161161
assertThat(usage.getPromptTokens()).isEqualTo(19);
162162
assertThat(usage.getTotalTokens()).isEqualTo(26);
163-
var orchestrationResult = (LLMModuleResultSynchronous) result.getOrchestrationResult();
163+
var orchestrationResult = (LLMModuleResultSynchronous) response.getOrchestrationResult();
164164
assertThat(orchestrationResult.getId()).isEqualTo("chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj");
165165
assertThat(orchestrationResult.getObject()).isEqualTo("chat.completion");
166166
assertThat(orchestrationResult.getCreated()).isEqualTo(1721224505);
@@ -286,7 +286,8 @@ void messagesHistory() throws IOException {
286286

287287
final var result = client.chatCompletion(prompt, config);
288288

289-
assertThat(result.getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91");
289+
assertThat(result.getOriginalResponse().getRequestId())
290+
.isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91");
290291

291292
// verify that the history is sent correctly
292293
try (var requestInputStream = fileLoader.apply("messagesHistoryRequest.json")) {
@@ -310,13 +311,13 @@ void maskingPseudonymization() throws IOException {
310311
createMaskingConfig(DPIConfig.MethodEnum.PSEUDONYMIZATION, DPIEntities.PHONE);
311312

312313
final var result = client.chatCompletion(prompt, config.withMaskingConfig(maskingConfig));
314+
final var response = result.getOriginalResponse();
313315

314-
assertThat(result).isNotNull();
315-
GenericModuleResult inputMasking = result.getModuleResults().getInputMasking();
316+
assertThat(response).isNotNull();
317+
GenericModuleResult inputMasking = response.getModuleResults().getInputMasking();
316318
assertThat(inputMasking.getMessage()).isEqualTo("Input to LLM is masked successfully.");
317319
assertThat(inputMasking.getData()).isNotNull();
318-
final var choices = ((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices();
319-
assertThat(choices.get(0).getMessage().getContent()).contains("Hi Mallory");
320+
assertThat(result.getContent()).contains("Hi Mallory");
320321

321322
// verify that the request is sent correctly
322323
try (var requestInputStream = fileLoader.apply("maskingRequest.json")) {
@@ -412,4 +413,17 @@ void testErrorHandling() {
412413

413414
softly.assertAll();
414415
}
416+
417+
@Test
418+
void testEmptyChoicesResponse() {
// Stub the completion endpoint so that it answers with the emptyChoicesResponse.json body file.
419+
stubFor(
420+
post(urlPathEqualTo("/v2/inference/deployments/abcdef0123456789/completion"))
421+
.willReturn(
422+
aResponse()
423+
.withBodyFile("emptyChoicesResponse.json")
424+
.withHeader("Content-Type", "application/json")));
425+
final var result = client.chatCompletion(prompt, config);
426+
427+
// For this response the convenience accessor is expected to yield an empty string.
assertThat(result.getContent()).isEmpty();
428+
}
415429
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
{
2+
"request_id": "26ea36b5-c196-4806-a9a6-a686f0c6ad91",
3+
"module_results": {
4+
"templating": [
5+
{
6+
"role": "user",
7+
"content": "Reply with 'Orchestration Service is working!' in German"
8+
}
9+
],
10+
"llm": {
11+
"id": "chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj",
12+
"object": "chat.completion",
13+
"created": 1721224505,
14+
"model": "gpt-35-turbo-16k",
15+
"choices": [],
16+
"usage": {
17+
"completion_tokens": 7,
18+
"prompt_tokens": 19,
19+
"total_tokens": 26
20+
}
21+
}
22+
},
23+
"orchestration_result": {
24+
"id": "chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj",
25+
"object": "chat.completion",
26+
"created": 1721224505,
27+
"model": "gpt-35-turbo-16k",
28+
"choices": [],
29+
"usage": {
30+
"completion_tokens": 7,
31+
"prompt_tokens": 19,
32+
"total_tokens": 26
33+
}
34+
}
35+
}

sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
package com.sap.ai.sdk.app.controllers;
22

3+
import com.sap.ai.sdk.orchestration.OrchestrationChatResponse;
34
import com.sap.ai.sdk.orchestration.OrchestrationClient;
45
import com.sap.ai.sdk.orchestration.OrchestrationModuleConfig;
56
import com.sap.ai.sdk.orchestration.OrchestrationPrompt;
67
import com.sap.ai.sdk.orchestration.client.model.AzureContentSafety;
78
import com.sap.ai.sdk.orchestration.client.model.AzureContentSafetyFilterConfig;
89
import com.sap.ai.sdk.orchestration.client.model.AzureThreshold;
910
import com.sap.ai.sdk.orchestration.client.model.ChatMessage;
10-
import com.sap.ai.sdk.orchestration.client.model.CompletionPostResponse;
1111
import com.sap.ai.sdk.orchestration.client.model.DPIConfig;
1212
import com.sap.ai.sdk.orchestration.client.model.DPIEntities;
1313
import com.sap.ai.sdk.orchestration.client.model.DPIEntityConfig;
@@ -44,7 +44,7 @@ class OrchestrationController {
4444
*/
4545
@GetMapping("/completion")
4646
@Nonnull
47-
public CompletionPostResponse completion() {
47+
public OrchestrationChatResponse completion() {
4848
final var prompt = new OrchestrationPrompt("Hello world! Why is this phrase so famous?");
4949

5050
return client.chatCompletion(prompt, config);
@@ -57,7 +57,7 @@ public CompletionPostResponse completion() {
5757
*/
5858
@GetMapping("/template")
5959
@Nonnull
60-
public CompletionPostResponse template() {
60+
public OrchestrationChatResponse template() {
6161
final var template =
6262
new ChatMessage()
6363
.role("user")
@@ -78,7 +78,7 @@ public CompletionPostResponse template() {
7878
*/
7979
@GetMapping("/messagesHistory")
8080
@Nonnull
81-
public CompletionPostResponse messagesHistory() {
81+
public OrchestrationChatResponse messagesHistory() {
8282
final List<ChatMessage> messagesHistory =
8383
List.of(
8484
new ChatMessage().role("user").content("What is the capital of France?"),
@@ -98,7 +98,7 @@ public CompletionPostResponse messagesHistory() {
9898
*/
9999
@GetMapping("/filter/{threshold}")
100100
@Nonnull
101-
public CompletionPostResponse filter(
101+
public OrchestrationChatResponse filter(
102102
@Nonnull @PathVariable("threshold") final AzureThreshold threshold) {
103103
final var prompt =
104104
new OrchestrationPrompt(
@@ -145,7 +145,7 @@ private static FilteringModuleConfig createAzureContentFilter(
145145
*/
146146
@GetMapping("/maskingAnonymization")
147147
@Nonnull
148-
public CompletionPostResponse maskingAnonymization() {
148+
public OrchestrationChatResponse maskingAnonymization() {
149149
final var systemMessage =
150150
new ChatMessage()
151151
.role("system")
@@ -176,7 +176,7 @@ public CompletionPostResponse maskingAnonymization() {
176176
*/
177177
@GetMapping("/maskingPseudonymization")
178178
@Nonnull
179-
public CompletionPostResponse maskingPseudonymization() {
179+
public OrchestrationChatResponse maskingPseudonymization() {
180180
final var systemMessage =
181181
new ChatMessage()
182182
.role("system")

sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,13 @@ void testCompletion() {
2727
final var result = controller.completion();
2828

2929
assertThat(result).isNotNull();
30-
assertThat(
31-
((LLMModuleResultSynchronous) result.getOrchestrationResult())
32-
.getChoices()
33-
.get(0)
34-
.getMessage()
35-
.getContent())
36-
.isNotEmpty();
30+
assertThat(result.getContent()).isNotEmpty();
3731
}
3832

3933
@Test
4034
void testTemplate() {
41-
final var result = controller.template();
35+
final var response = controller.template();
36+
final var result = response.getOriginalResponse();
4237

4338
assertThat(result.getRequestId()).isNotEmpty();
4439
assertThat(result.getModuleResults().getTemplating().get(0).getContent())
@@ -64,20 +59,21 @@ void testTemplate() {
6459
assertThat(orchestrationResult.getCreated()).isGreaterThan(1);
6560
assertThat(orchestrationResult.getModel())
6661
.isEqualTo(OrchestrationController.LLM_CONFIG.getModelName());
67-
choices = ((LLMModuleResultSynchronous) orchestrationResult).getChoices();
62+
choices = orchestrationResult.getChoices();
6863
assertThat(choices.get(0).getIndex()).isZero();
6964
assertThat(choices.get(0).getMessage().getContent()).isNotEmpty();
7065
assertThat(choices.get(0).getMessage().getRole()).isEqualTo("assistant");
7166
assertThat(choices.get(0).getFinishReason()).isEqualTo("stop");
72-
usage = ((LLMModuleResultSynchronous) orchestrationResult).getUsage();
67+
usage = orchestrationResult.getUsage();
7368
assertThat(usage.getCompletionTokens()).isGreaterThan(1);
7469
assertThat(usage.getPromptTokens()).isGreaterThan(1);
7570
assertThat(usage.getTotalTokens()).isGreaterThan(1);
7671
}
7772

7873
@Test
7974
void testLenientContentFilter() {
80-
var result = controller.filter(AzureThreshold.NUMBER_4);
75+
var response = controller.filter(AzureThreshold.NUMBER_4);
76+
var result = response.getOriginalResponse();
8177
var llmChoice =
8278
((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices().get(0);
8379
assertThat(llmChoice.getFinishReason()).isEqualTo("stop");
@@ -97,15 +93,16 @@ void testStrictContentFilter() {
9793

9894
@Test
9995
void testMessagesHistory() {
100-
CompletionPostResponse result = controller.messagesHistory();
96+
CompletionPostResponse result = controller.messagesHistory().getOriginalResponse();
10197
final var choices = ((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices();
10298
assertThat(choices.get(0).getMessage().getContent()).isNotEmpty();
10399
}
104100

105101
@SuppressWarnings("unchecked")
106102
@Test
107103
void testMaskingAnonymization() {
108-
var result = controller.maskingAnonymization();
104+
var response = controller.maskingAnonymization();
105+
var result = response.getOriginalResponse();
109106
var llmChoice =
110107
((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices().get(0);
111108
assertThat(llmChoice.getFinishReason()).isEqualTo("stop");
@@ -124,7 +121,8 @@ void testMaskingAnonymization() {
124121
@SuppressWarnings("unchecked")
125122
@Test
126123
void testMaskingPseudonymization() {
127-
var result = controller.maskingPseudonymization();
124+
var response = controller.maskingPseudonymization();
125+
var result = response.getOriginalResponse();
128126
var llmChoice =
129127
((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices().get(0);
130128
assertThat(llmChoice.getFinishReason()).isEqualTo("stop");

0 commit comments

Comments
 (0)