|
21 | 21 | import static org.apache.hc.core5.http.HttpStatus.SC_BAD_REQUEST; |
22 | 22 | import static org.assertj.core.api.Assertions.assertThat; |
23 | 23 | import static org.assertj.core.api.Assertions.assertThatThrownBy; |
| 24 | +import static org.mockito.ArgumentMatchers.any; |
| 25 | +import static org.mockito.Mockito.doReturn; |
24 | 26 | import static org.mockito.Mockito.mock; |
| 27 | +import static org.mockito.Mockito.spy; |
| 28 | +import static org.mockito.Mockito.times; |
| 29 | +import static org.mockito.Mockito.when; |
25 | 30 |
|
26 | 31 | import com.fasterxml.jackson.core.JsonParseException; |
27 | 32 | import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; |
28 | 33 | import com.github.tomakehurst.wiremock.junit5.WireMockTest; |
29 | 34 | import com.github.tomakehurst.wiremock.stubbing.Scenario; |
| 35 | +import com.sap.ai.sdk.orchestration.model.ChatMessage; |
30 | 36 | import com.sap.ai.sdk.orchestration.model.CompletionPostRequest; |
31 | 37 | import com.sap.ai.sdk.orchestration.model.DPIEntities; |
32 | 38 | import com.sap.ai.sdk.orchestration.model.GenericModuleResult; |
33 | 39 | import com.sap.ai.sdk.orchestration.model.LLMModuleResultSynchronous; |
| 40 | +import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor; |
34 | 41 | import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination; |
35 | 42 | import java.io.IOException; |
36 | 43 | import java.io.InputStream; |
37 | 44 | import java.util.List; |
38 | 45 | import java.util.Map; |
39 | 46 | import java.util.Objects; |
40 | 47 | import java.util.function.Function; |
| 48 | +import java.util.stream.Stream; |
| 49 | +import org.apache.hc.client5.http.classic.HttpClient; |
| 50 | +import org.apache.hc.core5.http.ContentType; |
| 51 | +import org.apache.hc.core5.http.io.entity.InputStreamEntity; |
| 52 | +import org.apache.hc.core5.http.message.BasicClassicHttpResponse; |
41 | 53 | import org.assertj.core.api.SoftAssertions; |
42 | 54 | import org.junit.jupiter.api.BeforeEach; |
43 | 55 | import org.junit.jupiter.api.Test; |
| 56 | +import org.mockito.Mockito; |
44 | 57 |
|
45 | 58 | /** |
46 | 59 | * Test that queries are on the right URL, with the right headers. Also check that the received |
@@ -402,4 +415,113 @@ void testExecuteRequestFromJsonThrows() { |
402 | 415 | .isInstanceOf(IllegalArgumentException.class) |
403 | 416 | .hasMessageContaining("not valid JSON"); |
404 | 417 | } |
| 418 | + |
| 419 | + @Test |
| 420 | + void testThrowsOnContentFilter() { |
| 421 | + var mock = mock(OrchestrationClient.class); |
| 422 | + when(mock.streamChatCompletion(any(), any())).thenCallRealMethod(); |
| 423 | + |
| 424 | + var deltaWithContentFilter = mock(OrchestrationChatCompletionDelta.class); |
| 425 | + when(deltaWithContentFilter.getFinishReason()).thenReturn("content_filter"); |
| 426 | + when(mock.streamChatCompletionDeltas(any())).thenReturn(Stream.of(deltaWithContentFilter)); |
| 427 | + |
| 428 | + // this must not throw, since the stream is lazily evaluated |
| 429 | + var stream = mock.streamChatCompletion(new OrchestrationPrompt(""), config); |
| 430 | + assertThatThrownBy(stream::toList) |
| 431 | + .isInstanceOf(OrchestrationClientException.class) |
| 432 | + .hasMessageContaining("Content filter"); |
| 433 | + } |
| 434 | + |
| 435 | + @Test |
| 436 | + void streamChatCompletionDeltas() throws IOException { |
| 437 | + try (var inputStream = spy(fileLoader.apply("streamChatCompletion.txt"))) { |
| 438 | + |
| 439 | + final var httpClient = mock(HttpClient.class); |
| 440 | + ApacheHttpClient5Accessor.setHttpClientFactory(destination -> httpClient); |
| 441 | + |
| 442 | + // Create a mock response |
| 443 | + final var mockResponse = new BasicClassicHttpResponse(200, "OK"); |
| 444 | + final var inputStreamEntity = new InputStreamEntity(inputStream, ContentType.TEXT_PLAIN); |
| 445 | + mockResponse.setEntity(inputStreamEntity); |
| 446 | + mockResponse.setHeader("Content-Type", "text/event-stream"); |
| 447 | + |
| 448 | + // Configure the HttpClient mock to return the mock response |
| 449 | + doReturn(mockResponse).when(httpClient).executeOpen(any(), any(), any()); |
| 450 | + |
| 451 | + var prompt = |
| 452 | + new OrchestrationPrompt( |
| 453 | + "Can you give me the first 100 numbers of the Fibonacci sequence?"); |
| 454 | + var request = OrchestrationClient.toCompletionPostRequest(prompt, config); |
| 455 | + |
| 456 | + try (Stream<OrchestrationChatCompletionDelta> stream = |
| 457 | + client.streamChatCompletionDeltas(request)) { |
| 458 | + var deltaList = stream.toList(); |
| 459 | + |
| 460 | + assertThat(deltaList).hasSize(3); |
| 461 | + // the first delta doesn't have any content |
| 462 | + assertThat(deltaList.get(0).getDeltaContent()).isEqualTo(""); |
| 463 | + assertThat(deltaList.get(1).getDeltaContent()).isEqualTo("Sure"); |
| 464 | + assertThat(deltaList.get(2).getDeltaContent()).isEqualTo("!"); |
| 465 | + |
| 466 | + assertThat(deltaList.get(0).getRequestId()).isEqualTo("5bd87b41-6368-4c18-aaae-47ab82e9475b"); |
| 467 | + assertThat(deltaList.get(1).getRequestId()).isEqualTo("5bd87b41-6368-4c18-aaae-47ab82e9475b"); |
| 468 | + assertThat(deltaList.get(2).getRequestId()).isEqualTo("5bd87b41-6368-4c18-aaae-47ab82e9475b"); |
| 469 | + |
| 470 | + // should be of type LLMModuleResultStreaming, will be fixed with a discriminator |
| 471 | + var result0 = (LLMModuleResultSynchronous) deltaList.get(0).getOrchestrationResult(); |
| 472 | + var result1 = (LLMModuleResultSynchronous) deltaList.get(1).getOrchestrationResult(); |
| 473 | + var result2 = (LLMModuleResultSynchronous) deltaList.get(2).getOrchestrationResult(); |
| 474 | + |
| 475 | + assertThat(result0.getSystemFingerprint()).isEmpty(); |
| 476 | + assertThat(result0.getId()).isEmpty(); |
| 477 | + assertThat(result0.getCreated()).isEqualTo(0); |
| 478 | + assertThat(result0.getModel()).isEmpty(); |
| 479 | + assertThat(result0.getObject()).isEmpty(); |
| 480 | + // BUG: usage is absent from the request |
| 481 | + assertThat(result0.getUsage()).isNull(); |
| 482 | + assertThat(result0.getChoices()).hasSize(1); |
| 483 | + final var choices0 = result0.getChoices().get(0); |
| 484 | + assertThat(choices0.getIndex()).isEqualTo(0); |
| 485 | + assertThat(choices0.getFinishReason()).isEmpty(); |
| 486 | + final var message0 = (Map<String, Object>) choices0.getCustomField("delta"); |
| 487 | + assertThat(message0.get("role")).isEqualTo(""); |
| 488 | + assertThat(message0.get("content")).isEqualTo(""); |
| 489 | + List<ChatMessage> templating = deltaList.get(0).getModuleResults().getTemplating(); |
| 490 | + assertThat(templating).hasSize(1); |
| 491 | + assertThat(templating.get(0).getRole()).isEqualTo("user"); |
| 492 | + assertThat(templating.get(0).getContent()).isEqualTo("Hello world! Why is this phrase so famous?"); |
| 493 | + |
| 494 | + assertThat(result1.getSystemFingerprint()).isEqualTo("fp_808245b034"); |
| 495 | + assertThat(result1.getId()).isEqualTo("chatcmpl-AYZSQQwWv7ajJsyDBpMG4X01BBJxq"); |
| 496 | + assertThat(result1.getCreated()).isEqualTo(1732802814); |
| 497 | + assertThat(result1.getModel()).isEqualTo("gpt-35-turbo"); |
| 498 | + assertThat(result1.getObject()).isEqualTo("chat.completion.chunk"); |
| 499 | + assertThat(result1.getUsage()).isNull(); |
| 500 | + final var choices1 = result1.getChoices().get(0); |
| 501 | + assertThat(choices1.getIndex()).isEqualTo(0); |
| 502 | + assertThat(choices1.getFinishReason()).isEmpty(); |
| 503 | + // this should be getDelta(), only when the result is of type LLMModuleResultStreaming |
| 504 | + assertThat(choices1.getCustomField("delta")).isNotNull(); |
| 505 | + final var message1 = (Map<String, Object>) choices1.getCustomField("delta"); |
| 506 | + assertThat(message1.get("role")).isEqualTo("assistant"); |
| 507 | + assertThat(message1.get("content")).isEqualTo("Sure"); |
| 508 | + |
| 509 | + assertThat(result2.getSystemFingerprint()).isEqualTo("fp_808245b034"); |
| 510 | + assertThat(result2.getId()).isEqualTo("chatcmpl-AYZSQQwWv7ajJsyDBpMG4X01BBJxq"); |
| 511 | + assertThat(result2.getCreated()).isEqualTo(1732802814); |
| 512 | + assertThat(result2.getModel()).isEqualTo("gpt-35-turbo"); |
| 513 | + assertThat(result2.getObject()).isEqualTo("chat.completion.chunk"); |
| 514 | + assertThat(result2.getUsage()).isNull(); |
| 515 | + final var choices2 = result2.getChoices().get(0); |
| 516 | + assertThat(choices2.getIndex()).isEqualTo(0); |
| 517 | + assertThat(choices2.getFinishReason()).isEqualTo("stop"); |
| 518 | + // this should be getDelta(), only when the result is of type LLMModuleResultStreaming |
| 519 | + assertThat(choices2.getCustomField("delta")).isNotNull(); |
| 520 | + final var message2 = (Map<String, Object>) choices2.getCustomField("delta"); |
| 521 | + assertThat(message2.get("role")).isEqualTo("assistant"); |
| 522 | + assertThat(message2.get("content")).isEqualTo("!"); |
| 523 | + } |
| 524 | + Mockito.verify(inputStream, times(1)).close(); |
| 525 | + } |
| 526 | + } |
405 | 527 | } |
0 commit comments