3737import org .junit .jupiter .api .BeforeEach ;
3838import org .junit .jupiter .params .ParameterizedTest ;
3939import org .junit .jupiter .params .provider .ValueSource ;
40- import reactor .core .publisher .Mono ;
4140import reactor .netty .DisposableServer ;
4241import reactor .netty .http .server .HttpServer ;
43- import reactor .test .StepVerifier ;
4442
4543import org .springframework .http .server .reactive .HttpHandler ;
4644import org .springframework .http .server .reactive .ReactorHttpHandlerAdapter ;
5048
5149import static org .assertj .core .api .Assertions .assertThat ;
5250import static org .assertj .core .api .Assertions .assertThatExceptionOfType ;
51+ import static org .assertj .core .api .Assertions .assertWith ;
5352import static org .awaitility .Awaitility .await ;
5453import static org .mockito .Mockito .mock ;
5554
@@ -109,12 +108,9 @@ void testCreateMessageWithoutSamplingCapabilities(String clientType) {
109108 var clientBuilder = clientBuilders .get (clientType );
110109
111110 McpServerFeatures .AsyncToolSpecification tool = new McpServerFeatures .AsyncToolSpecification (
112- new McpSchema .Tool ("tool1" , "tool1 description" , emptyJsonSchema ), (exchange , request ) -> {
113-
114- exchange .createMessage (mock (McpSchema .CreateMessageRequest .class )).block ();
115-
116- return Mono .just (mock (CallToolResult .class ));
117- });
111+ new McpSchema .Tool ("tool1" , "tool1 description" , emptyJsonSchema ),
112+ (exchange , request ) -> exchange .createMessage (mock (CreateMessageRequest .class ))
113+ .thenReturn (mock (CallToolResult .class )));
118114
119115 var server = McpServer .async (mcpServerTransportProvider ).serverInfo ("test-server" , "1.0.0" ).tools (tool ).build ();
120116
@@ -151,6 +147,8 @@ void testCreateMessageSuccess(String clientType) {
151147 CallToolResult callResponse = new McpSchema .CallToolResult (List .of (new McpSchema .TextContent ("CALL RESPONSE" )),
152148 null );
153149
150+ AtomicReference <CreateMessageResult > samplingResult = new AtomicReference <>();
151+
154152 McpServerFeatures .AsyncToolSpecification tool = new McpServerFeatures .AsyncToolSpecification (
155153 new McpSchema .Tool ("tool1" , "tool1 description" , emptyJsonSchema ), (exchange , request ) -> {
156154
@@ -165,16 +163,9 @@ void testCreateMessageSuccess(String clientType) {
165163 .build ())
166164 .build ();
167165
168- StepVerifier .create (exchange .createMessage (craeteMessageRequest )).consumeNextWith (result -> {
169- assertThat (result ).isNotNull ();
170- assertThat (result .role ()).isEqualTo (Role .USER );
171- assertThat (result .content ()).isInstanceOf (McpSchema .TextContent .class );
172- assertThat (((McpSchema .TextContent ) result .content ()).text ()).isEqualTo ("Test message" );
173- assertThat (result .model ()).isEqualTo ("MockModelName" );
174- assertThat (result .stopReason ()).isEqualTo (CreateMessageResult .StopReason .STOP_SEQUENCE );
175- }).verifyComplete ();
176-
177- return Mono .just (callResponse );
166+ return exchange .createMessage (craeteMessageRequest )
167+ .doOnNext (samplingResult ::set )
168+ .thenReturn (callResponse );
178169 });
179170
180171 var mcpServer = McpServer .async (mcpServerTransportProvider )
@@ -194,8 +185,17 @@ void testCreateMessageSuccess(String clientType) {
194185
195186 assertThat (response ).isNotNull ();
196187 assertThat (response ).isEqualTo (callResponse );
188+
189+ assertWith (samplingResult .get (), result -> {
190+ assertThat (result ).isNotNull ();
191+ assertThat (result .role ()).isEqualTo (Role .USER );
192+ assertThat (result .content ()).isInstanceOf (McpSchema .TextContent .class );
193+ assertThat (((McpSchema .TextContent ) result .content ()).text ()).isEqualTo ("Test message" );
194+ assertThat (result .model ()).isEqualTo ("MockModelName" );
195+ assertThat (result .stopReason ()).isEqualTo (CreateMessageResult .StopReason .STOP_SEQUENCE );
196+ });
197197 }
198- mcpServer .close ();
198+ mcpServer .closeGracefully (). block ();
199199 }
200200
201201 @ ParameterizedTest (name = "{0} : {displayName} " )
@@ -218,16 +218,13 @@ void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws Interr
218218 CreateMessageResult .StopReason .STOP_SEQUENCE );
219219 };
220220
221- var mcpClient = clientBuilder .clientInfo (new McpSchema .Implementation ("Sample client" , "0.0.0" ))
222- .capabilities (ClientCapabilities .builder ().sampling ().build ())
223- .sampling (samplingHandler )
224- .build ();
225-
226221 // Server
227222
228223 CallToolResult callResponse = new McpSchema .CallToolResult (List .of (new McpSchema .TextContent ("CALL RESPONSE" )),
229224 null );
230225
226+ AtomicReference <CreateMessageResult > samplingResult = new AtomicReference <>();
227+
231228 McpServerFeatures .AsyncToolSpecification tool = new McpServerFeatures .AsyncToolSpecification (
232229 new McpSchema .Tool ("tool1" , "tool1 description" , emptyJsonSchema ), (exchange , request ) -> {
233230
@@ -242,16 +239,9 @@ void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws Interr
242239 .build ())
243240 .build ();
244241
245- StepVerifier .create (exchange .createMessage (craeteMessageRequest )).consumeNextWith (result -> {
246- assertThat (result ).isNotNull ();
247- assertThat (result .role ()).isEqualTo (Role .USER );
248- assertThat (result .content ()).isInstanceOf (McpSchema .TextContent .class );
249- assertThat (((McpSchema .TextContent ) result .content ()).text ()).isEqualTo ("Test message" );
250- assertThat (result .model ()).isEqualTo ("MockModelName" );
251- assertThat (result .stopReason ()).isEqualTo (CreateMessageResult .StopReason .STOP_SEQUENCE );
252- }).verifyComplete ();
253-
254- return Mono .just (callResponse );
242+ return exchange .createMessage (craeteMessageRequest )
243+ .doOnNext (samplingResult ::set )
244+ .thenReturn (callResponse );
255245 });
256246
257247 var mcpServer = McpServer .async (mcpServerTransportProvider )
@@ -260,16 +250,30 @@ void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws Interr
260250 .tools (tool )
261251 .build ();
262252
263- InitializeResult initResult = mcpClient .initialize ();
264- assertThat (initResult ).isNotNull ();
253+ try (var mcpClient = clientBuilder .clientInfo (new McpSchema .Implementation ("Sample client" , "0.0.0" ))
254+ .capabilities (ClientCapabilities .builder ().sampling ().build ())
255+ .sampling (samplingHandler )
256+ .build ()) {
265257
266- CallToolResult response = mcpClient .callTool (new McpSchema .CallToolRequest ("tool1" , Map .of ()));
258+ InitializeResult initResult = mcpClient .initialize ();
259+ assertThat (initResult ).isNotNull ();
267260
268- assertThat (response ).isNotNull ();
269- assertThat (response ).isEqualTo (callResponse );
261+ CallToolResult response = mcpClient .callTool (new McpSchema .CallToolRequest ("tool1" , Map .of ()));
270262
271- mcpClient .close ();
272- mcpServer .close ();
263+ assertThat (response ).isNotNull ();
264+ assertThat (response ).isEqualTo (callResponse );
265+
266+ assertWith (samplingResult .get (), result -> {
267+ assertThat (result ).isNotNull ();
268+ assertThat (result .role ()).isEqualTo (Role .USER );
269+ assertThat (result .content ()).isInstanceOf (McpSchema .TextContent .class );
270+ assertThat (((McpSchema .TextContent ) result .content ()).text ()).isEqualTo ("Test message" );
271+ assertThat (result .model ()).isEqualTo ("MockModelName" );
272+ assertThat (result .stopReason ()).isEqualTo (CreateMessageResult .StopReason .STOP_SEQUENCE );
273+ });
274+ }
275+
276+ mcpServer .closeGracefully ().block ();
273277 }
274278
275279 @ ParameterizedTest (name = "{0} : {displayName} " )
@@ -283,7 +287,7 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt
283287 assertThat (request .messages ()).hasSize (1 );
284288 assertThat (request .messages ().get (0 ).content ()).isInstanceOf (McpSchema .TextContent .class );
285289 try {
286- TimeUnit .SECONDS .sleep (3 );
290+ TimeUnit .SECONDS .sleep (2 );
287291 }
288292 catch (InterruptedException e ) {
289293 throw new RuntimeException (e );
@@ -292,11 +296,6 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt
292296 CreateMessageResult .StopReason .STOP_SEQUENCE );
293297 };
294298
295- var mcpClient = clientBuilder .clientInfo (new McpSchema .Implementation ("Sample client" , "0.0.0" ))
296- .capabilities (ClientCapabilities .builder ().sampling ().build ())
297- .sampling (samplingHandler )
298- .build ();
299-
300299 // Server
301300
302301 CallToolResult callResponse = new McpSchema .CallToolResult (List .of (new McpSchema .TextContent ("CALL RESPONSE" )),
@@ -308,24 +307,9 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt
308307 var craeteMessageRequest = McpSchema .CreateMessageRequest .builder ()
309308 .messages (List .of (new McpSchema .SamplingMessage (McpSchema .Role .USER ,
310309 new McpSchema .TextContent ("Test message" ))))
311- .modelPreferences (ModelPreferences .builder ()
312- .hints (List .of ())
313- .costPriority (1.0 )
314- .speedPriority (1.0 )
315- .intelligencePriority (1.0 )
316- .build ())
317310 .build ();
318311
319- StepVerifier .create (exchange .createMessage (craeteMessageRequest )).consumeNextWith (result -> {
320- assertThat (result ).isNotNull ();
321- assertThat (result .role ()).isEqualTo (Role .USER );
322- assertThat (result .content ()).isInstanceOf (McpSchema .TextContent .class );
323- assertThat (((McpSchema .TextContent ) result .content ()).text ()).isEqualTo ("Test message" );
324- assertThat (result .model ()).isEqualTo ("MockModelName" );
325- assertThat (result .stopReason ()).isEqualTo (CreateMessageResult .StopReason .STOP_SEQUENCE );
326- }).verifyComplete ();
327-
328- return Mono .just (callResponse );
312+ return exchange .createMessage (craeteMessageRequest ).thenReturn (callResponse );
329313 });
330314
331315 var mcpServer = McpServer .async (mcpServerTransportProvider )
@@ -334,15 +318,21 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt
334318 .tools (tool )
335319 .build ();
336320
337- InitializeResult initResult = mcpClient .initialize ();
338- assertThat (initResult ).isNotNull ();
321+ try (var mcpClient = clientBuilder .clientInfo (new McpSchema .Implementation ("Sample client" , "0.0.0" ))
322+ .capabilities (ClientCapabilities .builder ().sampling ().build ())
323+ .sampling (samplingHandler )
324+ .build ()) {
339325
340- assertThatExceptionOfType (McpError .class ).isThrownBy (() -> {
341- mcpClient .callTool (new McpSchema .CallToolRequest ("tool1" , Map .of ()));
342- }).withMessageContaining ("Timeout" );
326+ InitializeResult initResult = mcpClient .initialize ();
327+ assertThat (initResult ).isNotNull ();
343328
344- mcpClient .close ();
345- mcpServer .close ();
329+ assertThatExceptionOfType (McpError .class ).isThrownBy (() -> {
330+ mcpClient .callTool (new McpSchema .CallToolRequest ("tool1" , Map .of ()));
331+ }).withMessageContaining ("within 1000ms" );
332+
333+ }
334+
335+ mcpServer .closeGracefully ().block ();
346336 }
347337
348338 // ---------------------------------------
@@ -412,9 +402,8 @@ void testRootsWithoutCapability(String clientType) {
412402 var mcpServer = McpServer .sync (mcpServerTransportProvider ).rootsChangeHandler ((exchange , rootsUpdate ) -> {
413403 }).tools (tool ).build ();
414404
415- try (
416- // Create client without roots capability
417- var mcpClient = clientBuilder .capabilities (ClientCapabilities .builder ().build ()).build ()) {
405+ // Create client without roots capability
406+ try (var mcpClient = clientBuilder .capabilities (ClientCapabilities .builder ().build ()).build ()) {
418407
419408 assertThat (mcpClient .initialize ()).isNotNull ();
420409
0 commit comments