Skip to content

Commit c657c06

Browse files
committed
Applied review comments
1 parent 45ae6fa commit c657c06

File tree

6 files changed

+40
-31
lines changed

6 files changed

+40
-31
lines changed

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionOutput.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,6 @@ public class OpenAiChatCompletionOutput extends OpenAiCompletionOutput
4040
*/
4141
@Nonnull
4242
public String getContent() throws OpenAiClientException {
43-
if (getChoices().isEmpty()) {
44-
return "";
45-
}
4643
if ("content_filter".equals(getChoices().get(0).getFinishReason())) {
4744
throw new OpenAiClientException("Content filter filtered the output.");
4845
}

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ public OrchestrationResponse chatCompletion(
104104
throws OrchestrationClientException {
105105

106106
val request = toCompletionPostRequest(prompt, config);
107-
return executeRequest(request);
107+
return new OrchestrationResponse(executeRequest(request));
108108
}
109109

110110
/**
@@ -128,7 +128,7 @@ public OrchestrationResponse chatCompletion(
128128
* @throws OrchestrationClientException If the request fails.
129129
*/
130130
@Nonnull
131-
public OrchestrationResponse executeRequest(@Nonnull final CompletionPostRequest request)
131+
public CompletionPostResponse executeRequest(@Nonnull final CompletionPostRequest request)
132132
throws OrchestrationClientException {
133133
final BasicClassicHttpRequest postRequest = new HttpPost("/completion");
134134
try {
@@ -143,13 +143,13 @@ public OrchestrationResponse executeRequest(@Nonnull final CompletionPostRequest
143143
}
144144

145145
@Nonnull
146-
OrchestrationResponse executeRequest(@Nonnull final BasicClassicHttpRequest request) {
146+
CompletionPostResponse executeRequest(@Nonnull final BasicClassicHttpRequest request) {
147147
try {
148148
val destination = deployment.get().destination();
149149
log.debug("Using destination {} to connect to orchestration service", destination);
150150
val client = ApacheHttpClient5Accessor.getHttpClient(destination);
151151
return client.execute(
152-
request, new OrchestrationResponseHandler<>(OrchestrationResponse.class));
152+
request, new OrchestrationResponseHandler<>(CompletionPostResponse.class));
153153
} catch (NoSuchElementException
154154
| DestinationAccessException
155155
| DestinationNotFoundException
Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
package com.sap.ai.sdk.orchestration;
22

33
import com.sap.ai.sdk.orchestration.client.model.CompletionPostResponse;
4+
import com.sap.ai.sdk.orchestration.client.model.LLMModuleResultSynchronous;
45
import javax.annotation.Nonnull;
6+
import lombok.Getter;
7+
import lombok.RequiredArgsConstructor;
58

69
/** Orchestration chat completion output. */
7-
public class OrchestrationResponse extends CompletionPostResponse {
10+
@RequiredArgsConstructor
11+
@Getter
12+
public class OrchestrationResponse {
13+
private final CompletionPostResponse data;
14+
815
/**
916
* Get the message content from the output.
1017
*
@@ -15,12 +22,12 @@ public class OrchestrationResponse extends CompletionPostResponse {
1522
*/
1623
@Nonnull
1724
public String getContent() throws OrchestrationClientException {
18-
if (getOrchestrationResult().getChoices().isEmpty()) {
19-
return "";
20-
}
21-
if ("content_filter".equals(getOrchestrationResult().getChoices().get(0).getFinishReason())) {
25+
final var choice =
26+
((LLMModuleResultSynchronous) data.getOrchestrationResult()).getChoices().get(0);
27+
28+
if ("content_filter".equals(choice.getFinishReason())) {
2229
throw new OrchestrationClientException("Content filter filtered the output.");
2330
}
24-
return getOrchestrationResult().getChoices().get(0).getMessage().getContent();
31+
return choice.getMessage().getContent();
2532
}
2633
}

orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -140,11 +140,12 @@ void testTemplating() throws IOException {
140140
final var result =
141141
client.chatCompletion(new OrchestrationPrompt(inputParams, template), config);
142142

143-
assertThat(result.getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91");
144-
assertThat(result.getModuleResults().getTemplating().get(0).getContent())
143+
final var response = result.getData();
144+
assertThat(response.getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91");
145+
assertThat(response.getModuleResults().getTemplating().get(0).getContent())
145146
.isEqualTo("Reply with 'Orchestration Service is working!' in German");
146-
assertThat(result.getModuleResults().getTemplating().get(0).getRole()).isEqualTo("user");
147-
var llm = (LLMModuleResultSynchronous) result.getModuleResults().getLlm();
147+
assertThat(response.getModuleResults().getTemplating().get(0).getRole()).isEqualTo("user");
148+
var llm = (LLMModuleResultSynchronous) response.getModuleResults().getLlm();
148149
assertThat(llm.getId()).isEqualTo("chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj");
149150
assertThat(llm.getObject()).isEqualTo("chat.completion");
150151
assertThat(llm.getCreated()).isEqualTo(1721224505);
@@ -159,7 +160,7 @@ void testTemplating() throws IOException {
159160
assertThat(usage.getCompletionTokens()).isEqualTo(7);
160161
assertThat(usage.getPromptTokens()).isEqualTo(19);
161162
assertThat(usage.getTotalTokens()).isEqualTo(26);
162-
var orchestrationResult = (LLMModuleResultSynchronous) result.getOrchestrationResult();
163+
var orchestrationResult = (LLMModuleResultSynchronous) response.getOrchestrationResult();
163164
assertThat(orchestrationResult.getId()).isEqualTo("chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj");
164165
assertThat(orchestrationResult.getObject()).isEqualTo("chat.completion");
165166
assertThat(orchestrationResult.getCreated()).isEqualTo(1721224505);
@@ -285,7 +286,7 @@ void messagesHistory() throws IOException {
285286

286287
final var result = client.chatCompletion(prompt, config);
287288

288-
assertThat(result.getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91");
289+
assertThat(result.getData().getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91");
289290

290291
// verify that the history is sent correctly
291292
try (var requestInputStream = fileLoader.apply("messagesHistoryRequest.json")) {
@@ -309,13 +310,13 @@ void maskingAnonymization() throws IOException {
309310
createMaskingConfig(DPIConfig.MethodEnum.ANONYMIZATION, DPIEntities.PHONE);
310311

311312
final var result = client.chatCompletion(prompt, config.withMaskingConfig(maskingConfig));
313+
final var response = result.getData();
312314

313-
assertThat(result).isNotNull();
314-
GenericModuleResult inputMasking = result.getModuleResults().getInputMasking();
315+
assertThat(response).isNotNull();
316+
GenericModuleResult inputMasking = response.getModuleResults().getInputMasking();
315317
assertThat(inputMasking.getMessage()).isEqualTo("Input to LLM is masked successfully.");
316318
assertThat(inputMasking.getData()).isNotNull();
317-
final var choices = ((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices();
318-
assertThat(choices.get(0).getMessage().getContent())
319+
assertThat(result.getContent())
319320
.isEqualTo(
320321
"I'm sorry, I cannot provide information about specific individuals, including their nationality.");
321322

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474
<enforcer.skipBanGeneratedModulesReference>false</enforcer.skipBanGeneratedModulesReference>
7575
<!-- Test coverage -->
7676
<coverage.instruction>74%</coverage.instruction>
77-
<coverage.branch>66%</coverage.branch>
77+
<coverage.branch>67%</coverage.branch>
7878
<coverage.complexity>67%</coverage.complexity>
7979
<coverage.line>75%</coverage.line>
8080
<coverage.method>80%</coverage.method>

sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ void testCompletion() {
3232

3333
@Test
3434
void testTemplate() {
35-
final var result = controller.template();
35+
final var response = controller.template();
36+
final var result = response.getData();
3637

3738
assertThat(result.getRequestId()).isNotEmpty();
3839
assertThat(result.getModuleResults().getTemplating().get(0).getContent())
@@ -58,20 +59,21 @@ void testTemplate() {
5859
assertThat(orchestrationResult.getCreated()).isGreaterThan(1);
5960
assertThat(orchestrationResult.getModel())
6061
.isEqualTo(OrchestrationController.LLM_CONFIG.getModelName());
61-
choices = ((LLMModuleResultSynchronous) orchestrationResult).getChoices();
62+
choices = orchestrationResult.getChoices();
6263
assertThat(choices.get(0).getIndex()).isZero();
6364
assertThat(choices.get(0).getMessage().getContent()).isNotEmpty();
6465
assertThat(choices.get(0).getMessage().getRole()).isEqualTo("assistant");
6566
assertThat(choices.get(0).getFinishReason()).isEqualTo("stop");
66-
usage = ((LLMModuleResultSynchronous) orchestrationResult).getUsage();
67+
usage = orchestrationResult.getUsage();
6768
assertThat(usage.getCompletionTokens()).isGreaterThan(1);
6869
assertThat(usage.getPromptTokens()).isGreaterThan(1);
6970
assertThat(usage.getTotalTokens()).isGreaterThan(1);
7071
}
7172

7273
@Test
7374
void testLenientContentFilter() {
74-
var result = controller.filter(AzureThreshold.NUMBER_4);
75+
var response = controller.filter(AzureThreshold.NUMBER_4);
76+
var result = response.getData();
7577
var llmChoice =
7678
((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices().get(0);
7779
assertThat(llmChoice.getFinishReason()).isEqualTo("stop");
@@ -91,15 +93,16 @@ void testStrictContentFilter() {
9193

9294
@Test
9395
void testMessagesHistory() {
94-
CompletionPostResponse result = controller.messagesHistory();
96+
CompletionPostResponse result = controller.messagesHistory().getData();
9597
final var choices = ((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices();
9698
assertThat(choices.get(0).getMessage().getContent()).isNotEmpty();
9799
}
98100

99101
@SuppressWarnings("unchecked")
100102
@Test
101103
void testMaskingAnonymization() {
102-
var result = controller.maskingAnonymization();
104+
var response = controller.maskingAnonymization();
105+
var result = response.getData();
103106
var llmChoice =
104107
((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices().get(0);
105108
assertThat(llmChoice.getFinishReason()).isEqualTo("stop");
@@ -118,7 +121,8 @@ void testMaskingAnonymization() {
118121
@SuppressWarnings("unchecked")
119122
@Test
120123
void testMaskingPseudonymization() {
121-
var result = controller.maskingPseudonymization();
124+
var response = controller.maskingPseudonymization();
125+
var result = response.getData();
122126
var llmChoice =
123127
((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices().get(0);
124128
assertThat(llmChoice.getFinishReason()).isEqualTo("stop");

0 commit comments

Comments
 (0)