Skip to content

Commit bc899a1

Browse files
Added unit tests
1 parent 40e56b3 commit bc899a1

File tree

4 files changed

+138
-3
lines changed

4 files changed

+138
-3
lines changed

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,12 +122,21 @@ public OrchestrationChatResponse chatCompletion(
122122
* content_filter
123123
*/
124124
@Nonnull
125-
public Stream<OrchestrationChatCompletionDelta> streamChatCompletion(
125+
public Stream<String> streamChatCompletion(
126126
@Nonnull final OrchestrationPrompt prompt, @Nonnull final OrchestrationModuleConfig config)
127127
throws OrchestrationClientException {
128128

129129
val request = toCompletionPostRequest(prompt, config);
130-
return streamChatCompletionDeltas(request);
130+
return streamChatCompletionDeltas(request)
131+
.peek(OrchestrationClient::throwOnContentFilter)
132+
.map(OrchestrationChatCompletionDelta::getDeltaContent);
133+
}
134+
135+
private static void throwOnContentFilter(@Nonnull final OrchestrationChatCompletionDelta delta) {
136+
final String finishReason = delta.getFinishReason();
137+
if (finishReason != null && finishReason.equals("content_filter")) {
138+
throw new OrchestrationClientException("Content filter filtered the output.");
139+
}
131140
}
132141

133142
/**

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationStreamingHandler.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ Stream<D> handleResponse(@Nonnull final ClassicHttpResponse response)
3535
.peek(
3636
line -> {
3737
if (!line.startsWith("data: ")) {
38-
final String msg = "Failed to parse response from OpenAI model";
38+
final String msg = "Failed to parse response from the Orchestration service";
3939
parseErrorAndThrow(line, new OrchestrationClientException(msg));
4040
}
4141
})

orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,26 +21,39 @@
2121
import static org.apache.hc.core5.http.HttpStatus.SC_BAD_REQUEST;
2222
import static org.assertj.core.api.Assertions.assertThat;
2323
import static org.assertj.core.api.Assertions.assertThatThrownBy;
24+
import static org.mockito.ArgumentMatchers.any;
25+
import static org.mockito.Mockito.doReturn;
2426
import static org.mockito.Mockito.mock;
27+
import static org.mockito.Mockito.spy;
28+
import static org.mockito.Mockito.times;
29+
import static org.mockito.Mockito.when;
2530

2631
import com.fasterxml.jackson.core.JsonParseException;
2732
import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo;
2833
import com.github.tomakehurst.wiremock.junit5.WireMockTest;
2934
import com.github.tomakehurst.wiremock.stubbing.Scenario;
35+
import com.sap.ai.sdk.orchestration.model.ChatMessage;
3036
import com.sap.ai.sdk.orchestration.model.CompletionPostRequest;
3137
import com.sap.ai.sdk.orchestration.model.DPIEntities;
3238
import com.sap.ai.sdk.orchestration.model.GenericModuleResult;
3339
import com.sap.ai.sdk.orchestration.model.LLMModuleResultSynchronous;
40+
import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor;
3441
import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination;
3542
import java.io.IOException;
3643
import java.io.InputStream;
3744
import java.util.List;
3845
import java.util.Map;
3946
import java.util.Objects;
4047
import java.util.function.Function;
48+
import java.util.stream.Stream;
49+
import org.apache.hc.client5.http.classic.HttpClient;
50+
import org.apache.hc.core5.http.ContentType;
51+
import org.apache.hc.core5.http.io.entity.InputStreamEntity;
52+
import org.apache.hc.core5.http.message.BasicClassicHttpResponse;
4153
import org.assertj.core.api.SoftAssertions;
4254
import org.junit.jupiter.api.BeforeEach;
4355
import org.junit.jupiter.api.Test;
56+
import org.mockito.Mockito;
4457

4558
/**
4659
* Test that queries are on the right URL, with the right headers. Also check that the received
@@ -402,4 +415,113 @@ void testExecuteRequestFromJsonThrows() {
402415
.isInstanceOf(IllegalArgumentException.class)
403416
.hasMessageContaining("not valid JSON");
404417
}
418+
419+
@Test
420+
void testThrowsOnContentFilter() {
421+
var mock = mock(OrchestrationClient.class);
422+
when(mock.streamChatCompletion(any(), any())).thenCallRealMethod();
423+
424+
var deltaWithContentFilter = mock(OrchestrationChatCompletionDelta.class);
425+
when(deltaWithContentFilter.getFinishReason()).thenReturn("content_filter");
426+
when(mock.streamChatCompletionDeltas(any())).thenReturn(Stream.of(deltaWithContentFilter));
427+
428+
// this must not throw, since the stream is lazily evaluated
429+
var stream = mock.streamChatCompletion(new OrchestrationPrompt(""), config);
430+
assertThatThrownBy(stream::toList)
431+
.isInstanceOf(OrchestrationClientException.class)
432+
.hasMessageContaining("Content filter");
433+
}
434+
435+
@Test
436+
void streamChatCompletionDeltas() throws IOException {
437+
try (var inputStream = spy(fileLoader.apply("streamChatCompletion.txt"))) {
438+
439+
final var httpClient = mock(HttpClient.class);
440+
ApacheHttpClient5Accessor.setHttpClientFactory(destination -> httpClient);
441+
442+
// Create a mock response
443+
final var mockResponse = new BasicClassicHttpResponse(200, "OK");
444+
final var inputStreamEntity = new InputStreamEntity(inputStream, ContentType.TEXT_PLAIN);
445+
mockResponse.setEntity(inputStreamEntity);
446+
mockResponse.setHeader("Content-Type", "text/event-stream");
447+
448+
// Configure the HttpClient mock to return the mock response
449+
doReturn(mockResponse).when(httpClient).executeOpen(any(), any(), any());
450+
451+
var prompt =
452+
new OrchestrationPrompt(
453+
"Can you give me the first 100 numbers of the Fibonacci sequence?");
454+
var request = OrchestrationClient.toCompletionPostRequest(prompt, config);
455+
456+
try (Stream<OrchestrationChatCompletionDelta> stream =
457+
client.streamChatCompletionDeltas(request)) {
458+
var deltaList = stream.toList();
459+
460+
assertThat(deltaList).hasSize(3);
461+
// the first delta doesn't have any content
462+
assertThat(deltaList.get(0).getDeltaContent()).isEqualTo("");
463+
assertThat(deltaList.get(1).getDeltaContent()).isEqualTo("Sure");
464+
assertThat(deltaList.get(2).getDeltaContent()).isEqualTo("!");
465+
466+
assertThat(deltaList.get(0).getRequestId()).isEqualTo("5bd87b41-6368-4c18-aaae-47ab82e9475b");
467+
assertThat(deltaList.get(1).getRequestId()).isEqualTo("5bd87b41-6368-4c18-aaae-47ab82e9475b");
468+
assertThat(deltaList.get(2).getRequestId()).isEqualTo("5bd87b41-6368-4c18-aaae-47ab82e9475b");
469+
470+
// should be of type LLMModuleResultStreaming, will be fixed with a discriminator
471+
var result0 = (LLMModuleResultSynchronous) deltaList.get(0).getOrchestrationResult();
472+
var result1 = (LLMModuleResultSynchronous) deltaList.get(1).getOrchestrationResult();
473+
var result2 = (LLMModuleResultSynchronous) deltaList.get(2).getOrchestrationResult();
474+
475+
assertThat(result0.getSystemFingerprint()).isEmpty();
476+
assertThat(result0.getId()).isEmpty();
477+
assertThat(result0.getCreated()).isEqualTo(0);
478+
assertThat(result0.getModel()).isEmpty();
479+
assertThat(result0.getObject()).isEmpty();
480+
// BUG: usage is absent from the request
481+
assertThat(result0.getUsage()).isNull();
482+
assertThat(result0.getChoices()).hasSize(1);
483+
final var choices0 = result0.getChoices().get(0);
484+
assertThat(choices0.getIndex()).isEqualTo(0);
485+
assertThat(choices0.getFinishReason()).isEmpty();
486+
final var message0 = (Map<String, Object>) choices0.getCustomField("delta");
487+
assertThat(message0.get("role")).isEqualTo("");
488+
assertThat(message0.get("content")).isEqualTo("");
489+
List<ChatMessage> templating = deltaList.get(0).getModuleResults().getTemplating();
490+
assertThat(templating).hasSize(1);
491+
assertThat(templating.get(0).getRole()).isEqualTo("user");
492+
assertThat(templating.get(0).getContent()).isEqualTo("Hello world! Why is this phrase so famous?");
493+
494+
assertThat(result1.getSystemFingerprint()).isEqualTo("fp_808245b034");
495+
assertThat(result1.getId()).isEqualTo("chatcmpl-AYZSQQwWv7ajJsyDBpMG4X01BBJxq");
496+
assertThat(result1.getCreated()).isEqualTo(1732802814);
497+
assertThat(result1.getModel()).isEqualTo("gpt-35-turbo");
498+
assertThat(result1.getObject()).isEqualTo("chat.completion.chunk");
499+
assertThat(result1.getUsage()).isNull();
500+
final var choices1 = result1.getChoices().get(0);
501+
assertThat(choices1.getIndex()).isEqualTo(0);
502+
assertThat(choices1.getFinishReason()).isEmpty();
503+
// this should be getDelta(), only when the result is of type LLMModuleResultStreaming
504+
assertThat(choices1.getCustomField("delta")).isNotNull();
505+
final var message1 = (Map<String, Object>) choices1.getCustomField("delta");
506+
assertThat(message1.get("role")).isEqualTo("assistant");
507+
assertThat(message1.get("content")).isEqualTo("Sure");
508+
509+
assertThat(result2.getSystemFingerprint()).isEqualTo("fp_808245b034");
510+
assertThat(result2.getId()).isEqualTo("chatcmpl-AYZSQQwWv7ajJsyDBpMG4X01BBJxq");
511+
assertThat(result2.getCreated()).isEqualTo(1732802814);
512+
assertThat(result2.getModel()).isEqualTo("gpt-35-turbo");
513+
assertThat(result2.getObject()).isEqualTo("chat.completion.chunk");
514+
assertThat(result2.getUsage()).isNull();
515+
final var choices2 = result2.getChoices().get(0);
516+
assertThat(choices2.getIndex()).isEqualTo(0);
517+
assertThat(choices2.getFinishReason()).isEqualTo("stop");
518+
// this should be getDelta(), only when the result is of type LLMModuleResultStreaming
519+
assertThat(choices2.getCustomField("delta")).isNotNull();
520+
final var message2 = (Map<String, Object>) choices2.getCustomField("delta");
521+
assertThat(message2.get("role")).isEqualTo("assistant");
522+
assertThat(message2.get("content")).isEqualTo("!");
523+
}
524+
Mockito.verify(inputStream, times(1)).close();
525+
}
526+
}
405527
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
data: {"request_id": "5bd87b41-6368-4c18-aaae-47ab82e9475b", "module_results": {"templating": [{"role": "user", "content": "Hello world! Why is this phrase so famous?"}]}, "orchestration_result": {"id": "", "object": "", "created": 0, "model": "", "system_fingerprint": "", "choices": [{"index": 0, "delta": {"role": "", "content": ""}, "finish_reason": ""}]}}
2+
data: {"request_id": "5bd87b41-6368-4c18-aaae-47ab82e9475b", "module_results": {"llm": {"id": "chatcmpl-AYZSQQwWv7ajJsyDBpMG4X01BBJxq", "object": "chat.completion.chunk", "created": 1732802814, "model": "gpt-35-turbo", "system_fingerprint": "fp_808245b034", "choices": [{"index": 0, "delta": {"role": "assistant", "content": "Sure"}, "finish_reason": ""}]}}, "orchestration_result": {"id": "chatcmpl-AYZSQQwWv7ajJsyDBpMG4X01BBJxq", "object": "chat.completion.chunk", "created": 1732802814, "model": "gpt-35-turbo", "system_fingerprint": "fp_808245b034", "choices": [{"index": 0, "delta": {"role": "assistant", "content": "Sure"}, "finish_reason": ""}]}}
3+
data: {"request_id": "5bd87b41-6368-4c18-aaae-47ab82e9475b", "module_results": {"llm": {"id": "chatcmpl-AYZSQQwWv7ajJsyDBpMG4X01BBJxq", "object": "chat.completion.chunk", "created": 1732802814, "model": "gpt-35-turbo", "system_fingerprint": "fp_808245b034", "choices": [{"index": 0, "delta": {"role": "assistant", "content": "!"}, "finish_reason": "stop"}]}}, "orchestration_result": {"id": "chatcmpl-AYZSQQwWv7ajJsyDBpMG4X01BBJxq", "object": "chat.completion.chunk", "created": 1732802814, "model": "gpt-35-turbo", "system_fingerprint": "fp_808245b034", "choices": [{"index": 0, "delta": {"role": "assistant", "content": "!"}, "finish_reason": "stop"}]}}
4+
data: [DONE]

0 commit comments

Comments
 (0)