Skip to content

Commit f5fb9fd

Browse files
committed
Renaming and adding tests.
1 parent b99270b commit f5fb9fd

File tree

7 files changed

+79
-21
lines changed

7 files changed

+79
-21
lines changed

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionOutput.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ public class OpenAiChatCompletionOutput extends OpenAiCompletionOutput
4040
*/
4141
@Nonnull
4242
public String getContent() throws OpenAiClientException {
43+
if (getChoices().isEmpty()) {
44+
return "";
45+
}
4346
if ("content_filter".equals(getChoices().get(0).getFinishReason())) {
4447
throw new OpenAiClientException("Content filter filtered the output.");
4548
}

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationResponse.java renamed to orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatResponse.java

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
/** Orchestration chat completion output. */
1212
@Value
1313
@RequiredArgsConstructor(access = PACKAGE)
14-
public class OrchestrationResponse {
15-
CompletionPostResponse data;
14+
public class OrchestrationChatResponse {
15+
CompletionPostResponse originalResponse;
1616

1717
/**
1818
* Get the message content from the output.
@@ -24,8 +24,14 @@ public class OrchestrationResponse {
2424
*/
2525
@Nonnull
2626
public String getContent() throws OrchestrationClientException {
27-
final var choice =
28-
((LLMModuleResultSynchronous) data.getOrchestrationResult()).getChoices().get(0);
27+
final var choices =
28+
((LLMModuleResultSynchronous) originalResponse.getOrchestrationResult()).getChoices();
29+
30+
if (choices.isEmpty()) {
31+
return "";
32+
}
33+
34+
final var choice = choices.get(0);
2935

3036
if ("content_filter".equals(choice.getFinishReason())) {
3137
throw new OrchestrationClientException("Content filter filtered the output.");

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,12 @@ public static CompletionPostRequest toCompletionPostRequest(
101101
* @throws OrchestrationClientException if the request fails.
102102
*/
103103
@Nonnull
104-
public OrchestrationResponse chatCompletion(
104+
public OrchestrationChatResponse chatCompletion(
105105
@Nonnull final OrchestrationPrompt prompt, @Nonnull final OrchestrationModuleConfig config)
106106
throws OrchestrationClientException {
107107

108108
val request = toCompletionPostRequest(prompt, config);
109-
return new OrchestrationResponse(executeRequest(request));
109+
return new OrchestrationChatResponse(executeRequest(request));
110110
}
111111

112112
/**

orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ void testTemplating() throws IOException {
140140
final var result =
141141
client.chatCompletion(new OrchestrationPrompt(inputParams, template), config);
142142

143-
final var response = result.getData();
143+
final var response = result.getOriginalResponse();
144144
assertThat(response.getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91");
145145
assertThat(response.getModuleResults().getTemplating().get(0).getContent())
146146
.isEqualTo("Reply with 'Orchestration Service is working!' in German");
@@ -286,7 +286,8 @@ void messagesHistory() throws IOException {
286286

287287
final var result = client.chatCompletion(prompt, config);
288288

289-
assertThat(result.getData().getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91");
289+
assertThat(result.getOriginalResponse().getRequestId())
290+
.isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91");
290291

291292
// verify that the history is sent correctly
292293
try (var requestInputStream = fileLoader.apply("messagesHistoryRequest.json")) {
@@ -310,7 +311,7 @@ void maskingPseudonymization() throws IOException {
310311
createMaskingConfig(DPIConfig.MethodEnum.PSEUDONYMIZATION, DPIEntities.PHONE);
311312

312313
final var result = client.chatCompletion(prompt, config.withMaskingConfig(maskingConfig));
313-
final var response = result.getData();
314+
final var response = result.getOriginalResponse();
314315

315316
assertThat(response).isNotNull();
316317
GenericModuleResult inputMasking = response.getModuleResults().getInputMasking();
@@ -412,4 +413,17 @@ void testErrorHandling() {
412413

413414
softly.assertAll();
414415
}
416+
417+
@Test
418+
void testEmptyChoicesResponse() {
419+
stubFor(
420+
post(urlPathEqualTo("/v2/inference/deployments/abcdef0123456789/completion"))
421+
.willReturn(
422+
aResponse()
423+
.withBodyFile("emptyChoicesResponse.json")
424+
.withHeader("Content-Type", "application/json")));
425+
final var result = client.chatCompletion(prompt, config);
426+
427+
assertThat(result.getContent()).isEmpty();
428+
}
415429
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
{
2+
"request_id": "26ea36b5-c196-4806-a9a6-a686f0c6ad91",
3+
"module_results": {
4+
"templating": [
5+
{
6+
"role": "user",
7+
"content": "Reply with 'Orchestration Service is working!' in German"
8+
}
9+
],
10+
"llm": {
11+
"id": "chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj",
12+
"object": "chat.completion",
13+
"created": 1721224505,
14+
"model": "gpt-35-turbo-16k",
15+
"choices": [],
16+
"usage": {
17+
"completion_tokens": 7,
18+
"prompt_tokens": 19,
19+
"total_tokens": 26
20+
}
21+
}
22+
},
23+
"orchestration_result": {
24+
"id": "chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj",
25+
"object": "chat.completion",
26+
"created": 1721224505,
27+
"model": "gpt-35-turbo-16k",
28+
"choices": [],
29+
"usage": {
30+
"completion_tokens": 7,
31+
"prompt_tokens": 19,
32+
"total_tokens": 26
33+
}
34+
}
35+
}

sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
package com.sap.ai.sdk.app.controllers;
22

3+
import com.sap.ai.sdk.orchestration.OrchestrationChatResponse;
34
import com.sap.ai.sdk.orchestration.OrchestrationClient;
45
import com.sap.ai.sdk.orchestration.OrchestrationModuleConfig;
56
import com.sap.ai.sdk.orchestration.OrchestrationPrompt;
6-
import com.sap.ai.sdk.orchestration.OrchestrationResponse;
77
import com.sap.ai.sdk.orchestration.client.model.AzureContentSafety;
88
import com.sap.ai.sdk.orchestration.client.model.AzureContentSafetyFilterConfig;
99
import com.sap.ai.sdk.orchestration.client.model.AzureThreshold;
@@ -44,7 +44,7 @@ class OrchestrationController {
4444
*/
4545
@GetMapping("/completion")
4646
@Nonnull
47-
public OrchestrationResponse completion() {
47+
public OrchestrationChatResponse completion() {
4848
final var prompt = new OrchestrationPrompt("Hello world! Why is this phrase so famous?");
4949

5050
return client.chatCompletion(prompt, config);
@@ -57,7 +57,7 @@ public OrchestrationResponse completion() {
5757
*/
5858
@GetMapping("/template")
5959
@Nonnull
60-
public OrchestrationResponse template() {
60+
public OrchestrationChatResponse template() {
6161
final var template =
6262
new ChatMessage()
6363
.role("user")
@@ -78,7 +78,7 @@ public OrchestrationResponse template() {
7878
*/
7979
@GetMapping("/messagesHistory")
8080
@Nonnull
81-
public OrchestrationResponse messagesHistory() {
81+
public OrchestrationChatResponse messagesHistory() {
8282
final List<ChatMessage> messagesHistory =
8383
List.of(
8484
new ChatMessage().role("user").content("What is the capital of France?"),
@@ -98,7 +98,7 @@ public OrchestrationResponse messagesHistory() {
9898
*/
9999
@GetMapping("/filter/{threshold}")
100100
@Nonnull
101-
public OrchestrationResponse filter(
101+
public OrchestrationChatResponse filter(
102102
@Nonnull @PathVariable("threshold") final AzureThreshold threshold) {
103103
final var prompt =
104104
new OrchestrationPrompt(
@@ -145,7 +145,7 @@ private static FilteringModuleConfig createAzureContentFilter(
145145
*/
146146
@GetMapping("/maskingAnonymization")
147147
@Nonnull
148-
public OrchestrationResponse maskingAnonymization() {
148+
public OrchestrationChatResponse maskingAnonymization() {
149149
final var systemMessage =
150150
new ChatMessage()
151151
.role("system")
@@ -176,7 +176,7 @@ public OrchestrationResponse maskingAnonymization() {
176176
*/
177177
@GetMapping("/maskingPseudonymization")
178178
@Nonnull
179-
public OrchestrationResponse maskingPseudonymization() {
179+
public OrchestrationChatResponse maskingPseudonymization() {
180180
final var systemMessage =
181181
new ChatMessage()
182182
.role("system")

sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ void testCompletion() {
3333
@Test
3434
void testTemplate() {
3535
final var response = controller.template();
36-
final var result = response.getData();
36+
final var result = response.getOriginalResponse();
3737

3838
assertThat(result.getRequestId()).isNotEmpty();
3939
assertThat(result.getModuleResults().getTemplating().get(0).getContent())
@@ -73,7 +73,7 @@ void testTemplate() {
7373
@Test
7474
void testLenientContentFilter() {
7575
var response = controller.filter(AzureThreshold.NUMBER_4);
76-
var result = response.getData();
76+
var result = response.getOriginalResponse();
7777
var llmChoice =
7878
((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices().get(0);
7979
assertThat(llmChoice.getFinishReason()).isEqualTo("stop");
@@ -93,7 +93,7 @@ void testStrictContentFilter() {
9393

9494
@Test
9595
void testMessagesHistory() {
96-
CompletionPostResponse result = controller.messagesHistory().getData();
96+
CompletionPostResponse result = controller.messagesHistory().getOriginalResponse();
9797
final var choices = ((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices();
9898
assertThat(choices.get(0).getMessage().getContent()).isNotEmpty();
9999
}
@@ -102,7 +102,7 @@ void testMessagesHistory() {
102102
@Test
103103
void testMaskingAnonymization() {
104104
var response = controller.maskingAnonymization();
105-
var result = response.getData();
105+
var result = response.getOriginalResponse();
106106
var llmChoice =
107107
((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices().get(0);
108108
assertThat(llmChoice.getFinishReason()).isEqualTo("stop");
@@ -122,7 +122,7 @@ void testMaskingAnonymization() {
122122
@Test
123123
void testMaskingPseudonymization() {
124124
var response = controller.maskingPseudonymization();
125-
var result = response.getData();
125+
var result = response.getOriginalResponse();
126126
var llmChoice =
127127
((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices().get(0);
128128
assertThat(llmChoice.getFinishReason()).isEqualTo("stop");

0 commit comments

Comments
 (0)