Skip to content

Commit 7e3e693

Browse files
committed
Merge remote-tracking branch 'origin/main' into feat/orch-convenient-filtering
2 parents 55ca0a4 + cf661ad commit 7e3e693

File tree

9 files changed

+148
-42
lines changed

9 files changed

+148
-42
lines changed

.github/workflows/perform-release.yml

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,6 @@ jobs:
8888
with:
8989
distribution: "sapmachine"
9090
java-version: ${{ env.JAVA_VERSION }}
91-
server-id: ossrh
92-
server-username: MAVEN_CENTRAL_USER # env variable for username in deploy
93-
server-password: MAVEN_CENTRAL_PASSWORD # env variable for token in deploy
9491

9592
- name: "Download Release Asset"
9693
id: download-asset
@@ -113,8 +110,8 @@ jobs:
113110
114111
- name: "Deploy"
115112
run: |
116-
MVN_ARGS="${{ env.MVN_CLI_ARGS }} deploy -Drelease -s settings.xml"
117-
mvn $MVN_ARGS
113+
MVN_ARGS="${{ env.MVN_CLI_ARGS }} -Drelease -s settings.xml"
114+
mvn deploy $MVN_ARGS
118115
env:
119116
MAVEN_GPG_PASSPHRASE: ${{ secrets.PGP_PASSPHRASE }}
120117

docs/guides/ORCHESTRATION_CHAT_COMPLETION.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,7 @@ var prompt = new OrchestrationPrompt("Hello world! Why is this phrase so famous?
9191

9292
var result = client.chatCompletion(prompt, config);
9393

94-
String messageResult =
95-
result.getOrchestrationResult().getChoices().get(0).getMessage().getContent();
94+
String messageResult = result.getContent();
9695
```
9796

9897
In this example, the Orchestration service generates a response to the user message "Hello world! Why is this phrase so famous?".
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import static lombok.AccessLevel.PACKAGE;
4+
5+
import com.sap.ai.sdk.orchestration.client.model.CompletionPostResponse;
6+
import com.sap.ai.sdk.orchestration.client.model.LLMModuleResultSynchronous;
7+
import javax.annotation.Nonnull;
8+
import lombok.RequiredArgsConstructor;
9+
import lombok.Value;
10+
11+
/** Orchestration chat completion output. */
12+
@Value
13+
@RequiredArgsConstructor(access = PACKAGE)
14+
public class OrchestrationChatResponse {
15+
CompletionPostResponse originalResponse;
16+
17+
/**
18+
* Get the message content from the output.
19+
*
20+
* <p>Note: If there are multiple choices only the first one is returned
21+
*
22+
* @return the message content or empty string.
23+
* @throws OrchestrationClientException if the content filter filtered the output.
24+
*/
25+
@Nonnull
26+
public String getContent() throws OrchestrationClientException {
27+
final var choices =
28+
((LLMModuleResultSynchronous) originalResponse.getOrchestrationResult()).getChoices();
29+
30+
if (choices.isEmpty()) {
31+
return "";
32+
}
33+
34+
final var choice = choices.get(0);
35+
36+
if ("content_filter".equals(choice.getFinishReason())) {
37+
throw new OrchestrationClientException("Content filter filtered the output.");
38+
}
39+
return choice.getMessage().getContent();
40+
}
41+
}

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,12 @@ public static CompletionPostRequest toCompletionPostRequest(
101101
* @throws OrchestrationClientException if the request fails.
102102
*/
103103
@Nonnull
104-
public CompletionPostResponse chatCompletion(
104+
public OrchestrationChatResponse chatCompletion(
105105
@Nonnull final OrchestrationPrompt prompt, @Nonnull final OrchestrationModuleConfig config)
106106
throws OrchestrationClientException {
107107

108108
val request = toCompletionPostRequest(prompt, config);
109-
return executeRequest(request);
109+
return new OrchestrationChatResponse(executeRequest(request));
110110
}
111111

112112
/**

orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,7 @@ void testCompletion() {
116116
final var result = client.chatCompletion(prompt, config);
117117

118118
assertThat(result).isNotNull();
119-
var orchestrationResult = (LLMModuleResultSynchronous) result.getOrchestrationResult();
120-
assertThat(orchestrationResult.getChoices().get(0).getMessage().getContent()).isNotEmpty();
119+
assertThat(result.getContent()).isNotEmpty();
121120
}
122121

123122
@Test
@@ -136,11 +135,12 @@ void testTemplating() throws IOException {
136135
final var result =
137136
client.chatCompletion(new OrchestrationPrompt(inputParams, template), config);
138137

139-
assertThat(result.getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91");
140-
assertThat(result.getModuleResults().getTemplating().get(0).getContent())
138+
final var response = result.getOriginalResponse();
139+
assertThat(response.getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91");
140+
assertThat(response.getModuleResults().getTemplating().get(0).getContent())
141141
.isEqualTo("Reply with 'Orchestration Service is working!' in German");
142-
assertThat(result.getModuleResults().getTemplating().get(0).getRole()).isEqualTo("user");
143-
var llm = (LLMModuleResultSynchronous) result.getModuleResults().getLlm();
142+
assertThat(response.getModuleResults().getTemplating().get(0).getRole()).isEqualTo("user");
143+
var llm = (LLMModuleResultSynchronous) response.getModuleResults().getLlm();
144144
assertThat(llm.getId()).isEqualTo("chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj");
145145
assertThat(llm.getObject()).isEqualTo("chat.completion");
146146
assertThat(llm.getCreated()).isEqualTo(1721224505);
@@ -155,7 +155,7 @@ void testTemplating() throws IOException {
155155
assertThat(usage.getCompletionTokens()).isEqualTo(7);
156156
assertThat(usage.getPromptTokens()).isEqualTo(19);
157157
assertThat(usage.getTotalTokens()).isEqualTo(26);
158-
var orchestrationResult = (LLMModuleResultSynchronous) result.getOrchestrationResult();
158+
var orchestrationResult = (LLMModuleResultSynchronous) response.getOrchestrationResult();
159159
assertThat(orchestrationResult.getId()).isEqualTo("chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj");
160160
assertThat(orchestrationResult.getObject()).isEqualTo("chat.completion");
161161
assertThat(orchestrationResult.getCreated()).isEqualTo(1721224505);
@@ -274,7 +274,8 @@ void messagesHistory() throws IOException {
274274

275275
final var result = client.chatCompletion(prompt, config);
276276

277-
assertThat(result.getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91");
277+
assertThat(result.getOriginalResponse().getRequestId())
278+
.isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91");
278279

279280
// verify that the history is sent correctly
280281
try (var requestInputStream = fileLoader.apply("messagesHistoryRequest.json")) {
@@ -298,13 +299,13 @@ void maskingPseudonymization() throws IOException {
298299
createMaskingConfig(DPIConfig.MethodEnum.PSEUDONYMIZATION, DPIEntities.PHONE);
299300

300301
final var result = client.chatCompletion(prompt, config.withMaskingConfig(maskingConfig));
302+
final var response = result.getOriginalResponse();
301303

302-
assertThat(result).isNotNull();
303-
GenericModuleResult inputMasking = result.getModuleResults().getInputMasking();
304+
assertThat(response).isNotNull();
305+
GenericModuleResult inputMasking = response.getModuleResults().getInputMasking();
304306
assertThat(inputMasking.getMessage()).isEqualTo("Input to LLM is masked successfully.");
305307
assertThat(inputMasking.getData()).isNotNull();
306-
final var choices = ((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices();
307-
assertThat(choices.get(0).getMessage().getContent()).contains("Hi Mallory");
308+
assertThat(result.getContent()).contains("Hi Mallory");
308309

309310
// verify that the request is sent correctly
310311
try (var requestInputStream = fileLoader.apply("maskingRequest.json")) {
@@ -400,4 +401,17 @@ void testErrorHandling() {
400401

401402
softly.assertAll();
402403
}
404+
405+
@Test
406+
void testEmptyChoicesResponse() {
407+
stubFor(
408+
post(urlPathEqualTo("/v2/inference/deployments/abcdef0123456789/completion"))
409+
.willReturn(
410+
aResponse()
411+
.withBodyFile("emptyChoicesResponse.json")
412+
.withHeader("Content-Type", "application/json")));
413+
final var result = client.chatCompletion(prompt, config);
414+
415+
assertThat(result.getContent()).isEmpty();
416+
}
403417
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
{
2+
"request_id": "26ea36b5-c196-4806-a9a6-a686f0c6ad91",
3+
"module_results": {
4+
"templating": [
5+
{
6+
"role": "user",
7+
"content": "Reply with 'Orchestration Service is working!' in German"
8+
}
9+
],
10+
"llm": {
11+
"id": "chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj",
12+
"object": "chat.completion",
13+
"created": 1721224505,
14+
"model": "gpt-35-turbo-16k",
15+
"choices": [],
16+
"usage": {
17+
"completion_tokens": 7,
18+
"prompt_tokens": 19,
19+
"total_tokens": 26
20+
}
21+
}
22+
},
23+
"orchestration_result": {
24+
"id": "chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj",
25+
"object": "chat.completion",
26+
"created": 1721224505,
27+
"model": "gpt-35-turbo-16k",
28+
"choices": [],
29+
"usage": {
30+
"completion_tokens": 7,
31+
"prompt_tokens": 19,
32+
"total_tokens": 26
33+
}
34+
}
35+
}

pom.xml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -724,6 +724,28 @@ https://gitbox.apache.org/repos/asf?p=maven-pmd-plugin.git;a=blob_plain;f=src/ma
724724
</plugins>
725725
</pluginManagement>
726726
<plugins>
727+
<!-- The release artifacts don't contain our custom config files for these plugins -->
728+
<plugin>
729+
<groupId>org.apache.maven.plugins</groupId>
730+
<artifactId>maven-checkstyle-plugin</artifactId>
731+
<configuration>
732+
<skip>true</skip>
733+
</configuration>
734+
</plugin>
735+
<plugin>
736+
<groupId>org.apache.maven.plugins</groupId>
737+
<artifactId>maven-pmd-plugin</artifactId>
738+
<configuration>
739+
<skip>true</skip>
740+
</configuration>
741+
</plugin>
742+
<plugin>
743+
<groupId>com.github.spotbugs</groupId>
744+
<artifactId>spotbugs-maven-plugin</artifactId>
745+
<configuration>
746+
<skip>true</skip>
747+
</configuration>
748+
</plugin>
727749
<plugin>
728750
<groupId>org.apache.maven.plugins</groupId>
729751
<artifactId>maven-source-plugin</artifactId>

sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
package com.sap.ai.sdk.app.controllers;
22

3+
import com.sap.ai.sdk.orchestration.OrchestrationChatResponse;
34
import com.sap.ai.sdk.orchestration.OrchestrationClient;
45
import com.sap.ai.sdk.orchestration.OrchestrationModuleConfig;
56
import com.sap.ai.sdk.orchestration.OrchestrationPrompt;
67
import com.sap.ai.sdk.orchestration.client.model.AzureContentSafety;
78
import com.sap.ai.sdk.orchestration.client.model.AzureThreshold;
89
import com.sap.ai.sdk.orchestration.client.model.ChatMessage;
9-
import com.sap.ai.sdk.orchestration.client.model.CompletionPostResponse;
1010
import com.sap.ai.sdk.orchestration.client.model.DPIConfig;
1111
import com.sap.ai.sdk.orchestration.client.model.DPIEntities;
1212
import com.sap.ai.sdk.orchestration.client.model.DPIEntityConfig;
@@ -40,7 +40,7 @@ class OrchestrationController {
4040
*/
4141
@GetMapping("/completion")
4242
@Nonnull
43-
public CompletionPostResponse completion() {
43+
public OrchestrationChatResponse completion() {
4444
final var prompt = new OrchestrationPrompt("Hello world! Why is this phrase so famous?");
4545

4646
return client.chatCompletion(prompt, config);
@@ -53,7 +53,7 @@ public CompletionPostResponse completion() {
5353
*/
5454
@GetMapping("/template")
5555
@Nonnull
56-
public CompletionPostResponse template() {
56+
public OrchestrationChatResponse template() {
5757
final var template =
5858
new ChatMessage()
5959
.role("user")
@@ -74,7 +74,7 @@ public CompletionPostResponse template() {
7474
*/
7575
@GetMapping("/messagesHistory")
7676
@Nonnull
77-
public CompletionPostResponse messagesHistory() {
77+
public OrchestrationChatResponse messagesHistory() {
7878
final List<ChatMessage> messagesHistory =
7979
List.of(
8080
new ChatMessage().role("user").content("What is the capital of France?"),
@@ -94,7 +94,7 @@ public CompletionPostResponse messagesHistory() {
9494
*/
9595
@GetMapping("/filter/{threshold}")
9696
@Nonnull
97-
public CompletionPostResponse filter(
97+
public OrchestrationChatResponse filter(
9898
@Nonnull @PathVariable("threshold") final AzureThreshold threshold) {
9999
final var prompt =
100100
new OrchestrationPrompt(
@@ -125,7 +125,7 @@ public CompletionPostResponse filter(
125125
*/
126126
@GetMapping("/maskingAnonymization")
127127
@Nonnull
128-
public CompletionPostResponse maskingAnonymization() {
128+
public OrchestrationChatResponse maskingAnonymization() {
129129
final var systemMessage =
130130
new ChatMessage()
131131
.role("system")
@@ -156,7 +156,7 @@ public CompletionPostResponse maskingAnonymization() {
156156
*/
157157
@GetMapping("/maskingPseudonymization")
158158
@Nonnull
159-
public CompletionPostResponse maskingPseudonymization() {
159+
public OrchestrationChatResponse maskingPseudonymization() {
160160
final var systemMessage =
161161
new ChatMessage()
162162
.role("system")

sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,13 @@ void testCompletion() {
2727
final var result = controller.completion();
2828

2929
assertThat(result).isNotNull();
30-
assertThat(
31-
((LLMModuleResultSynchronous) result.getOrchestrationResult())
32-
.getChoices()
33-
.get(0)
34-
.getMessage()
35-
.getContent())
36-
.isNotEmpty();
30+
assertThat(result.getContent()).isNotEmpty();
3731
}
3832

3933
@Test
4034
void testTemplate() {
41-
final var result = controller.template();
35+
final var response = controller.template();
36+
final var result = response.getOriginalResponse();
4237

4338
assertThat(result.getRequestId()).isNotEmpty();
4439
assertThat(result.getModuleResults().getTemplating().get(0).getContent())
@@ -64,20 +59,21 @@ void testTemplate() {
6459
assertThat(orchestrationResult.getCreated()).isGreaterThan(1);
6560
assertThat(orchestrationResult.getModel())
6661
.isEqualTo(OrchestrationController.LLM_CONFIG.getModelName());
67-
choices = ((LLMModuleResultSynchronous) orchestrationResult).getChoices();
62+
choices = orchestrationResult.getChoices();
6863
assertThat(choices.get(0).getIndex()).isZero();
6964
assertThat(choices.get(0).getMessage().getContent()).isNotEmpty();
7065
assertThat(choices.get(0).getMessage().getRole()).isEqualTo("assistant");
7166
assertThat(choices.get(0).getFinishReason()).isEqualTo("stop");
72-
usage = ((LLMModuleResultSynchronous) orchestrationResult).getUsage();
67+
usage = orchestrationResult.getUsage();
7368
assertThat(usage.getCompletionTokens()).isGreaterThan(1);
7469
assertThat(usage.getPromptTokens()).isGreaterThan(1);
7570
assertThat(usage.getTotalTokens()).isGreaterThan(1);
7671
}
7772

7873
@Test
7974
void testLenientContentFilter() {
80-
var result = controller.filter(AzureThreshold.NUMBER_4);
75+
var response = controller.filter(AzureThreshold.NUMBER_4);
76+
var result = response.getOriginalResponse();
8177
var llmChoice =
8278
((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices().get(0);
8379
assertThat(llmChoice.getFinishReason()).isEqualTo("stop");
@@ -97,15 +93,16 @@ void testStrictContentFilter() {
9793

9894
@Test
9995
void testMessagesHistory() {
100-
CompletionPostResponse result = controller.messagesHistory();
96+
CompletionPostResponse result = controller.messagesHistory().getOriginalResponse();
10197
final var choices = ((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices();
10298
assertThat(choices.get(0).getMessage().getContent()).isNotEmpty();
10399
}
104100

105101
@SuppressWarnings("unchecked")
106102
@Test
107103
void testMaskingAnonymization() {
108-
var result = controller.maskingAnonymization();
104+
var response = controller.maskingAnonymization();
105+
var result = response.getOriginalResponse();
109106
var llmChoice =
110107
((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices().get(0);
111108
assertThat(llmChoice.getFinishReason()).isEqualTo("stop");
@@ -124,7 +121,8 @@ void testMaskingAnonymization() {
124121
@SuppressWarnings("unchecked")
125122
@Test
126123
void testMaskingPseudonymization() {
127-
var result = controller.maskingPseudonymization();
124+
var response = controller.maskingPseudonymization();
125+
var result = response.getOriginalResponse();
128126
var llmChoice =
129127
((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices().get(0);
130128
assertThat(llmChoice.getFinishReason()).isEqualTo("stop");

0 commit comments

Comments
 (0)