3333import com .github .tomakehurst .wiremock .junit5 .WireMockTest ;
3434import com .github .tomakehurst .wiremock .stubbing .Scenario ;
3535import com .sap .ai .sdk .orchestration .model .ChatMessage ;
36- import com .sap .ai .sdk .orchestration .model .CompletionPostRequest ;
3736import com .sap .ai .sdk .orchestration .model .DPIEntities ;
3837import com .sap .ai .sdk .orchestration .model .GenericModuleResult ;
3938import com .sap .ai .sdk .orchestration .model .LLMModuleResultSynchronous ;
4039import com .sap .cloud .sdk .cloudplatform .connectivity .ApacheHttpClient5Accessor ;
40+ import com .sap .cloud .sdk .cloudplatform .connectivity .ApacheHttpClient5Cache ;
4141import com .sap .cloud .sdk .cloudplatform .connectivity .DefaultHttpDestination ;
4242import java .io .IOException ;
4343import java .io .InputStream ;
4646import java .util .Objects ;
4747import java .util .function .Function ;
4848import java .util .stream .Stream ;
49+ import javax .annotation .Nonnull ;
4950import org .apache .hc .client5 .http .classic .HttpClient ;
5051import org .apache .hc .core5 .http .ContentType ;
5152import org .apache .hc .core5 .http .io .entity .InputStreamEntity ;
5253import org .apache .hc .core5 .http .message .BasicClassicHttpResponse ;
5354import org .assertj .core .api .SoftAssertions ;
55+ import org .junit .jupiter .api .AfterEach ;
5456import org .junit .jupiter .api .BeforeEach ;
5557import org .junit .jupiter .api .Test ;
58+ import org .junit .jupiter .params .ParameterizedTest ;
59+ import org .junit .jupiter .params .provider .MethodSource ;
5660import org .mockito .Mockito ;
5761
5862/**
@@ -71,9 +75,9 @@ class OrchestrationUnitTest {
7175 private final Function <String , InputStream > fileLoader =
7276 filename -> Objects .requireNonNull (getClass ().getClassLoader ().getResourceAsStream (filename ));
7377
74- private OrchestrationClient client ;
75- private OrchestrationModuleConfig config ;
76- private OrchestrationPrompt prompt ;
78+ private static OrchestrationClient client ;
79+ private static OrchestrationModuleConfig config ;
80+ private static OrchestrationPrompt prompt ;
7781
7882 @ BeforeEach
7983 void setup (WireMockRuntimeInfo server ) {
@@ -82,6 +86,13 @@ void setup(WireMockRuntimeInfo server) {
8286 client = new OrchestrationClient (destination );
8387 config = new OrchestrationModuleConfig ().withLlmConfig (CUSTOM_GPT_35 );
8488 prompt = new OrchestrationPrompt ("Hello World! Why is this phrase so famous?" );
89+ ApacheHttpClient5Accessor .setHttpClientCache (ApacheHttpClient5Cache .DISABLED );
90+ }
91+
92+ @ AfterEach
93+ void reset () {
94+ ApacheHttpClient5Accessor .setHttpClientCache (null );
95+ ApacheHttpClient5Accessor .setHttpClientFactory (null );
8596 }
8697
8798 @ Test
@@ -286,8 +297,20 @@ void maskingPseudonymization() throws IOException {
286297 }
287298 }
288299
289- @ Test
290- void testErrorHandling () {
300+ private static Runnable [] errorHandlingCalls () {
301+ return new Runnable [] {
302+ () -> client .chatCompletion (new OrchestrationPrompt ("" ), config ),
303+ () ->
304+ client
305+ .streamChatCompletion (new OrchestrationPrompt ("" ), config )
306+ // the stream needs to be consumed to parse the response
307+ .forEach (System .out ::println )
308+ };
309+ }
310+
311+ @ ParameterizedTest
312+ @ MethodSource ("errorHandlingCalls" )
313+ void testErrorHandling (@ Nonnull final Runnable request ) {
291314 stubFor (
292315 post (anyUrl ())
293316 .inScenario ("Errors" )
@@ -321,7 +344,6 @@ void testErrorHandling() {
321344 stubFor (post (anyUrl ()).inScenario ("Errors" ).whenScenarioStateIs ("4" ).willReturn (noContent ()));
322345
323346 final var softly = new SoftAssertions ();
324- final Runnable request = () -> client .executeRequest (mock (CompletionPostRequest .class ));
325347
326348 softly
327349 .assertThatThrownBy (request ::run )
@@ -432,6 +454,32 @@ void testThrowsOnContentFilter() {
432454 .hasMessageContaining ("Content filter" );
433455 }
434456
457+ @ Test
458+ void streamChatCompletionOutputFilterErrorHandling () throws IOException {
459+ try (var inputStream = spy (fileLoader .apply ("streamChatCompletionOutputFilter.txt" ))) {
460+
461+ final var httpClient = mock (HttpClient .class );
462+ ApacheHttpClient5Accessor .setHttpClientFactory (destination -> httpClient );
463+
464+ // Create a mock response
465+ final var mockResponse = new BasicClassicHttpResponse (200 , "OK" );
466+ final var inputStreamEntity = new InputStreamEntity (inputStream , ContentType .TEXT_PLAIN );
467+ mockResponse .setEntity (inputStreamEntity );
468+ mockResponse .setHeader ("Content-Type" , "text/event-stream" );
469+
470+ // Configure the HttpClient mock to return the mock response
471+ doReturn (mockResponse ).when (httpClient ).executeOpen (any (), any (), any ());
472+
473+ try (Stream <String > stream = client .streamChatCompletion (prompt , config )) {
474+ assertThatThrownBy (() -> stream .forEach (System .out ::println ))
475+ .isInstanceOf (OrchestrationClientException .class )
476+ .hasMessage ("Content filter filtered the output." );
477+ }
478+
479+ Mockito .verify (inputStream , times (1 )).close ();
480+ }
481+ }
482+
435483 @ Test
436484 void streamChatCompletionDeltas () throws IOException {
437485 try (var inputStream = spy (fileLoader .apply ("streamChatCompletion.txt" ))) {
@@ -470,6 +518,10 @@ void streamChatCompletionDeltas() throws IOException {
470518 assertThat (deltaList .get (2 ).getRequestId ())
471519 .isEqualTo ("5bd87b41-6368-4c18-aaae-47ab82e9475b" );
472520
521+ assertThat (deltaList .get (0 ).getFinishReason ()).isEqualTo ("" );
522+ assertThat (deltaList .get (1 ).getFinishReason ()).isEqualTo ("" );
523+ assertThat (deltaList .get (2 ).getFinishReason ()).isEqualTo ("stop" );
524+
473525 // should be of type LLMModuleResultStreaming, will be fixed with a discriminator
474526 var result0 = (LLMModuleResultSynchronous ) deltaList .get (0 ).getOrchestrationResult ();
475527 var result1 = (LLMModuleResultSynchronous ) deltaList .get (1 ).getOrchestrationResult ();
@@ -486,6 +538,8 @@ void streamChatCompletionDeltas() throws IOException {
486538 final var choices0 = result0 .getChoices ().get (0 );
487539 assertThat (choices0 .getIndex ()).isEqualTo (0 );
488540 assertThat (choices0 .getFinishReason ()).isEmpty ();
541+ assertThat (choices0 .getCustomField ("delta" )).isNotNull ();
542+ // this should be getDelta(), only when the result is of type LLMModuleResultStreaming
489543 final var message0 = (Map <String , Object >) choices0 .getCustomField ("delta" );
490544 assertThat (message0 .get ("role" )).isEqualTo ("" );
491545 assertThat (message0 .get ("content" )).isEqualTo ("" );
@@ -501,10 +555,10 @@ void streamChatCompletionDeltas() throws IOException {
501555 assertThat (result1 .getModel ()).isEqualTo ("gpt-35-turbo" );
502556 assertThat (result1 .getObject ()).isEqualTo ("chat.completion.chunk" );
503557 assertThat (result1 .getUsage ()).isNull ();
558+ assertThat (result1 .getChoices ()).hasSize (1 );
504559 final var choices1 = result1 .getChoices ().get (0 );
505560 assertThat (choices1 .getIndex ()).isEqualTo (0 );
506561 assertThat (choices1 .getFinishReason ()).isEmpty ();
507- // this should be getDelta(), only when the result is of type LLMModuleResultStreaming
508562 assertThat (choices1 .getCustomField ("delta" )).isNotNull ();
509563 final var message1 = (Map <String , Object >) choices1 .getCustomField ("delta" );
510564 assertThat (message1 .get ("role" )).isEqualTo ("assistant" );
@@ -516,6 +570,7 @@ void streamChatCompletionDeltas() throws IOException {
516570 assertThat (result2 .getModel ()).isEqualTo ("gpt-35-turbo" );
517571 assertThat (result2 .getObject ()).isEqualTo ("chat.completion.chunk" );
518572 assertThat (result2 .getUsage ()).isNull ();
573+ assertThat (result2 .getChoices ()).hasSize (1 );
519574 final var choices2 = result2 .getChoices ().get (0 );
520575 assertThat (choices2 .getIndex ()).isEqualTo (0 );
521576 assertThat (choices2 .getFinishReason ()).isEqualTo ("stop" );
0 commit comments