Skip to content

Commit 837b23e

Browse files
Added tests
1 parent 61198e1 commit 837b23e

File tree

8 files changed

+102
-17
lines changed

8 files changed

+102
-17
lines changed

orchestration/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,11 @@
112112
<artifactId>mockito-core</artifactId>
113113
<scope>test</scope>
114114
</dependency>
115+
<dependency>
116+
<groupId>org.junit.jupiter</groupId>
117+
<artifactId>junit-jupiter-params</artifactId>
118+
<scope>test</scope>
119+
</dependency>
115120
</dependencies>
116121

117122
<profiles>

orchestration/src/main/java/com/sap/ai/sdk/orchestration/IterableStreamConverter.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ public T next() {
9999
static Stream<String> lines(@Nullable final HttpEntity entity)
100100
throws OrchestrationClientException {
101101
if (entity == null) {
102-
throw new OrchestrationClientException("OpenAI response was empty.");
102+
throw new OrchestrationClientException("Orchestration service response was empty.");
103103
}
104104

105105
final InputStream inputStream;

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationStreamingHandler.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@ Stream<D> handleResponse(@Nonnull final ClassicHttpResponse response)
4545
try {
4646
return JACKSON.readValue(data, deltaType);
4747
} catch (final IOException e) { // exception message e gets lost
48-
log.error("Failed to parse the following response from OpenAI model: {}", line);
48+
log.error(
49+
"Failed to parse the following response from the Orchestration service: {}",
50+
line);
4951
throw new OrchestrationClientException("Failed to parse delta message: " + line, e);
5052
}
5153
});

orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java

Lines changed: 63 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@
3333
import com.github.tomakehurst.wiremock.junit5.WireMockTest;
3434
import com.github.tomakehurst.wiremock.stubbing.Scenario;
3535
import com.sap.ai.sdk.orchestration.model.ChatMessage;
36-
import com.sap.ai.sdk.orchestration.model.CompletionPostRequest;
3736
import com.sap.ai.sdk.orchestration.model.DPIEntities;
3837
import com.sap.ai.sdk.orchestration.model.GenericModuleResult;
3938
import com.sap.ai.sdk.orchestration.model.LLMModuleResultSynchronous;
4039
import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor;
40+
import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Cache;
4141
import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination;
4242
import java.io.IOException;
4343
import java.io.InputStream;
@@ -46,13 +46,17 @@
4646
import java.util.Objects;
4747
import java.util.function.Function;
4848
import java.util.stream.Stream;
49+
import javax.annotation.Nonnull;
4950
import org.apache.hc.client5.http.classic.HttpClient;
5051
import org.apache.hc.core5.http.ContentType;
5152
import org.apache.hc.core5.http.io.entity.InputStreamEntity;
5253
import org.apache.hc.core5.http.message.BasicClassicHttpResponse;
5354
import org.assertj.core.api.SoftAssertions;
55+
import org.junit.jupiter.api.AfterEach;
5456
import org.junit.jupiter.api.BeforeEach;
5557
import org.junit.jupiter.api.Test;
58+
import org.junit.jupiter.params.ParameterizedTest;
59+
import org.junit.jupiter.params.provider.MethodSource;
5660
import org.mockito.Mockito;
5761

5862
/**
@@ -71,9 +75,9 @@ class OrchestrationUnitTest {
7175
private final Function<String, InputStream> fileLoader =
7276
filename -> Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream(filename));
7377

74-
private OrchestrationClient client;
75-
private OrchestrationModuleConfig config;
76-
private OrchestrationPrompt prompt;
78+
private static OrchestrationClient client;
79+
private static OrchestrationModuleConfig config;
80+
private static OrchestrationPrompt prompt;
7781

7882
@BeforeEach
7983
void setup(WireMockRuntimeInfo server) {
@@ -82,6 +86,13 @@ void setup(WireMockRuntimeInfo server) {
8286
client = new OrchestrationClient(destination);
8387
config = new OrchestrationModuleConfig().withLlmConfig(CUSTOM_GPT_35);
8488
prompt = new OrchestrationPrompt("Hello World! Why is this phrase so famous?");
89+
ApacheHttpClient5Accessor.setHttpClientCache(ApacheHttpClient5Cache.DISABLED);
90+
}
91+
92+
@AfterEach
93+
void reset() {
94+
ApacheHttpClient5Accessor.setHttpClientCache(null);
95+
ApacheHttpClient5Accessor.setHttpClientFactory(null);
8596
}
8697

8798
@Test
@@ -286,8 +297,20 @@ void maskingPseudonymization() throws IOException {
286297
}
287298
}
288299

289-
@Test
290-
void testErrorHandling() {
300+
private static Runnable[] errorHandlingCalls() {
301+
return new Runnable[] {
302+
() -> client.chatCompletion(new OrchestrationPrompt(""), config),
303+
() ->
304+
client
305+
.streamChatCompletion(new OrchestrationPrompt(""), config)
306+
// the stream needs to be consumed to parse the response
307+
.forEach(System.out::println)
308+
};
309+
}
310+
311+
@ParameterizedTest
312+
@MethodSource("errorHandlingCalls")
313+
void testErrorHandling(@Nonnull final Runnable request) {
291314
stubFor(
292315
post(anyUrl())
293316
.inScenario("Errors")
@@ -321,7 +344,6 @@ void testErrorHandling() {
321344
stubFor(post(anyUrl()).inScenario("Errors").whenScenarioStateIs("4").willReturn(noContent()));
322345

323346
final var softly = new SoftAssertions();
324-
final Runnable request = () -> client.executeRequest(mock(CompletionPostRequest.class));
325347

326348
softly
327349
.assertThatThrownBy(request::run)
@@ -432,6 +454,32 @@ void testThrowsOnContentFilter() {
432454
.hasMessageContaining("Content filter");
433455
}
434456

457+
@Test
458+
void streamChatCompletionOutputFilterErrorHandling() throws IOException {
459+
try (var inputStream = spy(fileLoader.apply("streamChatCompletionOutputFilter.txt"))) {
460+
461+
final var httpClient = mock(HttpClient.class);
462+
ApacheHttpClient5Accessor.setHttpClientFactory(destination -> httpClient);
463+
464+
// Create a mock response
465+
final var mockResponse = new BasicClassicHttpResponse(200, "OK");
466+
final var inputStreamEntity = new InputStreamEntity(inputStream, ContentType.TEXT_PLAIN);
467+
mockResponse.setEntity(inputStreamEntity);
468+
mockResponse.setHeader("Content-Type", "text/event-stream");
469+
470+
// Configure the HttpClient mock to return the mock response
471+
doReturn(mockResponse).when(httpClient).executeOpen(any(), any(), any());
472+
473+
try (Stream<String> stream = client.streamChatCompletion(prompt, config)) {
474+
assertThatThrownBy(() -> stream.forEach(System.out::println))
475+
.isInstanceOf(OrchestrationClientException.class)
476+
.hasMessage("Content filter filtered the output.");
477+
}
478+
479+
Mockito.verify(inputStream, times(1)).close();
480+
}
481+
}
482+
435483
@Test
436484
void streamChatCompletionDeltas() throws IOException {
437485
try (var inputStream = spy(fileLoader.apply("streamChatCompletion.txt"))) {
@@ -470,6 +518,10 @@ void streamChatCompletionDeltas() throws IOException {
470518
assertThat(deltaList.get(2).getRequestId())
471519
.isEqualTo("5bd87b41-6368-4c18-aaae-47ab82e9475b");
472520

521+
assertThat(deltaList.get(0).getFinishReason()).isEqualTo("");
522+
assertThat(deltaList.get(1).getFinishReason()).isEqualTo("");
523+
assertThat(deltaList.get(2).getFinishReason()).isEqualTo("stop");
524+
473525
// should be of type LLMModuleResultStreaming, will be fixed with a discriminator
474526
var result0 = (LLMModuleResultSynchronous) deltaList.get(0).getOrchestrationResult();
475527
var result1 = (LLMModuleResultSynchronous) deltaList.get(1).getOrchestrationResult();
@@ -486,6 +538,8 @@ void streamChatCompletionDeltas() throws IOException {
486538
final var choices0 = result0.getChoices().get(0);
487539
assertThat(choices0.getIndex()).isEqualTo(0);
488540
assertThat(choices0.getFinishReason()).isEmpty();
541+
assertThat(choices0.getCustomField("delta")).isNotNull();
542+
// this should be getDelta(), only when the result is of type LLMModuleResultStreaming
489543
final var message0 = (Map<String, Object>) choices0.getCustomField("delta");
490544
assertThat(message0.get("role")).isEqualTo("");
491545
assertThat(message0.get("content")).isEqualTo("");
@@ -501,10 +555,10 @@ void streamChatCompletionDeltas() throws IOException {
501555
assertThat(result1.getModel()).isEqualTo("gpt-35-turbo");
502556
assertThat(result1.getObject()).isEqualTo("chat.completion.chunk");
503557
assertThat(result1.getUsage()).isNull();
558+
assertThat(result1.getChoices()).hasSize(1);
504559
final var choices1 = result1.getChoices().get(0);
505560
assertThat(choices1.getIndex()).isEqualTo(0);
506561
assertThat(choices1.getFinishReason()).isEmpty();
507-
// this should be getDelta(), only when the result is of type LLMModuleResultStreaming
508562
assertThat(choices1.getCustomField("delta")).isNotNull();
509563
final var message1 = (Map<String, Object>) choices1.getCustomField("delta");
510564
assertThat(message1.get("role")).isEqualTo("assistant");
@@ -516,6 +570,7 @@ void streamChatCompletionDeltas() throws IOException {
516570
assertThat(result2.getModel()).isEqualTo("gpt-35-turbo");
517571
assertThat(result2.getObject()).isEqualTo("chat.completion.chunk");
518572
assertThat(result2.getUsage()).isNull();
573+
assertThat(result2.getChoices()).hasSize(1);
519574
final var choices2 = result2.getChoices().get(0);
520575
assertThat(choices2.getIndex()).isEqualTo(0);
521576
assertThat(choices2.getFinishReason()).isEqualTo("stop");
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
{
2+
"request_id": "b589de57-512e-4e11-9b69-8601453b3296",
3+
"code": 400,
4+
"message": "Content filtered due to safety violations. Please modify the prompt and try again.",
5+
"location": "Filtering Module - Input Filter",
6+
"module_results": {
7+
"templating": [
8+
{
9+
"role": "user",
10+
"content": "Fuck you"
11+
}
12+
],
13+
"input_filtering": {
14+
"message": "Content filtered due to safety violations. Please modify the prompt and try again.",
15+
"data": {
16+
"azure_content_safety": {
17+
"Hate": 2
18+
}
19+
}
20+
}
21+
}
22+
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
data: {"request_id": "eec90bca-a43e-43fa-864e-1d8962341350", "module_results": {"templating": [{"role": "user", "content": "Create 3 paraphrases of the following text: 'I hate you.'"}]}, "orchestration_result": {"id": "", "object": "", "created": 0, "model": "", "system_fingerprint": "", "choices": [{"index": 0, "delta": {"role": "", "content": ""}, "finish_reason": ""}]}}
2+
data: {"request_id": "eec90bca-a43e-43fa-864e-1d8962341350", "module_results": {"llm": {"id": "chatcmpl-Ab4mSDp5DXFu7hfbs2DkCsVJaM4IP", "object": "chat.completion.chunk", "created": 1733399876, "model": "gpt-35-turbo", "system_fingerprint": "fp_808245b034", "choices": [{"index": 0, "delta": {"role": "assistant", "content": "1. I can't stand you.\n2. You are detestable to me.\n3. I have a strong aversion towards you."}, "finish_reason": "stop"}]}, "output_filtering": {"message": "Content filtered due to safety violations. Model returned a result violating the safety threshold. Please modify the prompt and try again.", "data": {"original_service_response": {"azure_content_safety": {"content_allowed": false, "original_service_response": {"Hate": 2}, "checked_text": "1. I can't stand you. 2. You are detestable to me. 3. I have a strong aversion towards you."}}}}}, "orchestration_result": {"id": "chatcmpl-Ab4mSDp5DXFu7hfbs2DkCsVJaM4IP", "object": "chat.completion.chunk", "created": 1733399876, "model": "gpt-35-turbo", "system_fingerprint": "fp_808245b034", "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": "content_filter"}]}}

pom.xml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,11 @@
7676
<enforcer.skipEnforceScopeLombok>false</enforcer.skipEnforceScopeLombok>
7777
<enforcer.skipBanGeneratedModulesReference>false</enforcer.skipBanGeneratedModulesReference>
7878
<!-- Test coverage -->
79-
<coverage.instruction>75%</coverage.instruction>
80-
<coverage.branch>67%</coverage.branch>
81-
<coverage.complexity>69%</coverage.complexity>
82-
<coverage.line>76%</coverage.line>
83-
<coverage.method>85%</coverage.method>
79+
<coverage.instruction>77%</coverage.instruction>
80+
<coverage.branch>68%</coverage.branch>
81+
<coverage.complexity>71%</coverage.complexity>
82+
<coverage.line>79%</coverage.line>
83+
<coverage.method>100%</coverage.method>
8484
<coverage.class>85%</coverage.class>
8585
</properties>
8686

sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,8 @@ void testStreamChatCompletion() {
4444
// foreach consumes all elements, closing the stream at the end
4545
.forEach(
4646
delta -> {
47-
final String deltaContent = delta.getDeltaContent();
4847
log.info("delta: {}", delta);
49-
if (!deltaContent.isEmpty()) {
48+
if (!delta.isEmpty()) {
5049
filledDeltaCount.incrementAndGet();
5150
}
5251
});

0 commit comments

Comments
 (0)