88import java .util .List ;
99import java .util .Map ;
1010import java .util .concurrent .ConcurrentHashMap ;
11+ import java .util .concurrent .TimeUnit ;
1112import java .util .concurrent .atomic .AtomicReference ;
13+ import java .util .function .BiFunction ;
1214import java .util .function .Function ;
1315import java .util .stream .Collectors ;
1416
1820import io .modelcontextprotocol .client .transport .WebFluxSseClientTransport ;
1921import io .modelcontextprotocol .server .McpServer ;
2022import io .modelcontextprotocol .server .McpServerFeatures ;
23+ import io .modelcontextprotocol .server .TestUtil ;
24+ import io .modelcontextprotocol .server .McpSyncServerExchange ;
2125import io .modelcontextprotocol .server .transport .WebFluxSseServerTransportProvider ;
2226import io .modelcontextprotocol .spec .McpError ;
2327import io .modelcontextprotocol .spec .McpSchema ;
24- import io .modelcontextprotocol .spec .McpSchema .CallToolResult ;
25- import io .modelcontextprotocol .spec .McpSchema .ClientCapabilities ;
26- import io .modelcontextprotocol .spec .McpSchema .CreateMessageRequest ;
27- import io .modelcontextprotocol .spec .McpSchema .CreateMessageResult ;
28- import io .modelcontextprotocol .spec .McpSchema .InitializeResult ;
29- import io .modelcontextprotocol .spec .McpSchema .ModelPreferences ;
30- import io .modelcontextprotocol .spec .McpSchema .Role ;
31- import io .modelcontextprotocol .spec .McpSchema .Root ;
32- import io .modelcontextprotocol .spec .McpSchema .ServerCapabilities ;
33- import io .modelcontextprotocol .spec .McpSchema .Tool ;
28+ import io .modelcontextprotocol .spec .McpSchema .*;
29+ import io .modelcontextprotocol .spec .McpSchema .ServerCapabilities .CompletionCapabilities ;
3430import org .junit .jupiter .api .AfterEach ;
3531import org .junit .jupiter .api .BeforeEach ;
3632import org .junit .jupiter .params .ParameterizedTest ;
3733import org .junit .jupiter .params .provider .ValueSource ;
38- import reactor .core .publisher .Mono ;
3934import reactor .netty .DisposableServer ;
4035import reactor .netty .http .server .HttpServer ;
41- import reactor .test .StepVerifier ;
4236
4337import org .springframework .http .server .reactive .HttpHandler ;
4438import org .springframework .http .server .reactive .ReactorHttpHandlerAdapter ;
4741import org .springframework .web .reactive .function .server .RouterFunctions ;
4842
4943import static org .assertj .core .api .Assertions .assertThat ;
44+ import static org .assertj .core .api .Assertions .assertThatExceptionOfType ;
45+ import static org .assertj .core .api .Assertions .assertWith ;
5046import static org .awaitility .Awaitility .await ;
5147import static org .mockito .Mockito .mock ;
5248
53- public class WebFluxSseIntegrationTests {
49+ class WebFluxSseIntegrationTests {
5450
55- private static final int PORT = 8182 ;
51+ private static final int PORT = TestUtil . findAvailablePort () ;
5652
5753 private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse" ;
5854
@@ -106,12 +102,9 @@ void testCreateMessageWithoutSamplingCapabilities(String clientType) {
106102 var clientBuilder = clientBuilders .get (clientType );
107103
108104 McpServerFeatures .AsyncToolSpecification tool = new McpServerFeatures .AsyncToolSpecification (
109- new McpSchema .Tool ("tool1" , "tool1 description" , emptyJsonSchema ), (exchange , request ) -> {
110-
111- exchange .createMessage (mock (McpSchema .CreateMessageRequest .class )).block ();
112-
113- return Mono .just (mock (CallToolResult .class ));
114- });
105+ new McpSchema .Tool ("tool1" , "tool1 description" , emptyJsonSchema ),
106+ (exchange , request ) -> exchange .createMessage (mock (CreateMessageRequest .class ))
107+ .thenReturn (mock (CallToolResult .class )));
115108
116109 var server = McpServer .async (mcpServerTransportProvider ).serverInfo ("test-server" , "1.0.0" ).tools (tool ).build ();
117110
@@ -133,7 +126,7 @@ void testCreateMessageWithoutSamplingCapabilities(String clientType) {
133126
134127 @ ParameterizedTest (name = "{0} : {displayName} " )
135128 @ ValueSource (strings = { "httpclient" , "webflux" })
136- void testCreateMessageSuccess (String clientType ) throws InterruptedException {
129+ void testCreateMessageSuccess (String clientType ) {
137130
138131 var clientBuilder = clientBuilders .get (clientType );
139132
@@ -148,10 +141,12 @@ void testCreateMessageSuccess(String clientType) throws InterruptedException {
148141 CallToolResult callResponse = new McpSchema .CallToolResult (List .of (new McpSchema .TextContent ("CALL RESPONSE" )),
149142 null );
150143
144+ AtomicReference <CreateMessageResult > samplingResult = new AtomicReference <>();
145+
151146 McpServerFeatures .AsyncToolSpecification tool = new McpServerFeatures .AsyncToolSpecification (
152147 new McpSchema .Tool ("tool1" , "tool1 description" , emptyJsonSchema ), (exchange , request ) -> {
153148
154- var craeteMessageRequest = McpSchema .CreateMessageRequest .builder ()
149+ var createMessageRequest = McpSchema .CreateMessageRequest .builder ()
155150 .messages (List .of (new McpSchema .SamplingMessage (McpSchema .Role .USER ,
156151 new McpSchema .TextContent ("Test message" ))))
157152 .modelPreferences (ModelPreferences .builder ()
@@ -162,19 +157,89 @@ void testCreateMessageSuccess(String clientType) throws InterruptedException {
162157 .build ())
163158 .build ();
164159
165- StepVerifier .create (exchange .createMessage (craeteMessageRequest )).consumeNextWith (result -> {
166- assertThat (result ).isNotNull ();
167- assertThat (result .role ()).isEqualTo (Role .USER );
168- assertThat (result .content ()).isInstanceOf (McpSchema .TextContent .class );
169- assertThat (((McpSchema .TextContent ) result .content ()).text ()).isEqualTo ("Test message" );
170- assertThat (result .model ()).isEqualTo ("MockModelName" );
171- assertThat (result .stopReason ()).isEqualTo (CreateMessageResult .StopReason .STOP_SEQUENCE );
172- }).verifyComplete ();
160+ return exchange .createMessage (createMessageRequest )
161+ .doOnNext (samplingResult ::set )
162+ .thenReturn (callResponse );
163+ });
164+
165+ var mcpServer = McpServer .async (mcpServerTransportProvider )
166+ .serverInfo ("test-server" , "1.0.0" )
167+ .tools (tool )
168+ .build ();
169+
170+ try (var mcpClient = clientBuilder .clientInfo (new McpSchema .Implementation ("Sample client" , "0.0.0" ))
171+ .capabilities (ClientCapabilities .builder ().sampling ().build ())
172+ .sampling (samplingHandler )
173+ .build ()) {
174+
175+ InitializeResult initResult = mcpClient .initialize ();
176+ assertThat (initResult ).isNotNull ();
177+
178+ CallToolResult response = mcpClient .callTool (new McpSchema .CallToolRequest ("tool1" , Map .of ()));
179+
180+ assertThat (response ).isNotNull ();
181+ assertThat (response ).isEqualTo (callResponse );
182+
183+ assertWith (samplingResult .get (), result -> {
184+ assertThat (result ).isNotNull ();
185+ assertThat (result .role ()).isEqualTo (Role .USER );
186+ assertThat (result .content ()).isInstanceOf (McpSchema .TextContent .class );
187+ assertThat (((McpSchema .TextContent ) result .content ()).text ()).isEqualTo ("Test message" );
188+ assertThat (result .model ()).isEqualTo ("MockModelName" );
189+ assertThat (result .stopReason ()).isEqualTo (CreateMessageResult .StopReason .STOP_SEQUENCE );
190+ });
191+ }
192+ mcpServer .closeGracefully ().block ();
193+ }
194+
195+ @ ParameterizedTest (name = "{0} : {displayName} " )
196+ @ ValueSource (strings = { "httpclient" , "webflux" })
197+ void testCreateMessageWithRequestTimeoutSuccess (String clientType ) throws InterruptedException {
173198
174- return Mono .just (callResponse );
199+ // Client
200+ var clientBuilder = clientBuilders .get (clientType );
201+
202+ Function <CreateMessageRequest , CreateMessageResult > samplingHandler = request -> {
203+ assertThat (request .messages ()).hasSize (1 );
204+ assertThat (request .messages ().get (0 ).content ()).isInstanceOf (McpSchema .TextContent .class );
205+ try {
206+ TimeUnit .SECONDS .sleep (2 );
207+ }
208+ catch (InterruptedException e ) {
209+ throw new RuntimeException (e );
210+ }
211+ return new CreateMessageResult (Role .USER , new McpSchema .TextContent ("Test message" ), "MockModelName" ,
212+ CreateMessageResult .StopReason .STOP_SEQUENCE );
213+ };
214+
215+ // Server
216+
217+ CallToolResult callResponse = new McpSchema .CallToolResult (List .of (new McpSchema .TextContent ("CALL RESPONSE" )),
218+ null );
219+
220+ AtomicReference <CreateMessageResult > samplingResult = new AtomicReference <>();
221+
222+ McpServerFeatures .AsyncToolSpecification tool = new McpServerFeatures .AsyncToolSpecification (
223+ new McpSchema .Tool ("tool1" , "tool1 description" , emptyJsonSchema ), (exchange , request ) -> {
224+
225+ var craeteMessageRequest = McpSchema .CreateMessageRequest .builder ()
226+ .messages (List .of (new McpSchema .SamplingMessage (McpSchema .Role .USER ,
227+ new McpSchema .TextContent ("Test message" ))))
228+ .modelPreferences (ModelPreferences .builder ()
229+ .hints (List .of ())
230+ .costPriority (1.0 )
231+ .speedPriority (1.0 )
232+ .intelligencePriority (1.0 )
233+ .build ())
234+ .build ();
235+
236+ return exchange .createMessage (craeteMessageRequest )
237+ .doOnNext (samplingResult ::set )
238+ .thenReturn (callResponse );
175239 });
176240
177241 var mcpServer = McpServer .async (mcpServerTransportProvider )
242+ .requestTimeout (Duration .ofSeconds (4 ))
178243 .serverInfo ("test-server" , "1.0.0" )
179244 .tools (tool )
180245 .build ();
@@ -191,8 +256,77 @@ void testCreateMessageSuccess(String clientType) throws InterruptedException {
191256
192257 assertThat (response ).isNotNull ();
193258 assertThat (response ).isEqualTo (callResponse );
259+
260+ assertWith (samplingResult .get (), result -> {
261+ assertThat (result ).isNotNull ();
262+ assertThat (result .role ()).isEqualTo (Role .USER );
263+ assertThat (result .content ()).isInstanceOf (McpSchema .TextContent .class );
264+ assertThat (((McpSchema .TextContent ) result .content ()).text ()).isEqualTo ("Test message" );
265+ assertThat (result .model ()).isEqualTo ("MockModelName" );
266+ assertThat (result .stopReason ()).isEqualTo (CreateMessageResult .StopReason .STOP_SEQUENCE );
267+ });
194268 }
195- mcpServer .close ();
269+
270+ mcpServer .closeGracefully ().block ();
271+ }
272+
273+ @ ParameterizedTest (name = "{0} : {displayName} " )
274+ @ ValueSource (strings = { "httpclient" , "webflux" })
275+ void testCreateMessageWithRequestTimeoutFail (String clientType ) throws InterruptedException {
276+
277+ // Client
278+ var clientBuilder = clientBuilders .get (clientType );
279+
280+ Function <CreateMessageRequest , CreateMessageResult > samplingHandler = request -> {
281+ assertThat (request .messages ()).hasSize (1 );
282+ assertThat (request .messages ().get (0 ).content ()).isInstanceOf (McpSchema .TextContent .class );
283+ try {
284+ TimeUnit .SECONDS .sleep (2 );
285+ }
286+ catch (InterruptedException e ) {
287+ throw new RuntimeException (e );
288+ }
289+ return new CreateMessageResult (Role .USER , new McpSchema .TextContent ("Test message" ), "MockModelName" ,
290+ CreateMessageResult .StopReason .STOP_SEQUENCE );
291+ };
292+
293+ // Server
294+
295+ CallToolResult callResponse = new McpSchema .CallToolResult (List .of (new McpSchema .TextContent ("CALL RESPONSE" )),
296+ null );
297+
298+ McpServerFeatures .AsyncToolSpecification tool = new McpServerFeatures .AsyncToolSpecification (
299+ new McpSchema .Tool ("tool1" , "tool1 description" , emptyJsonSchema ), (exchange , request ) -> {
300+
301+ var craeteMessageRequest = McpSchema .CreateMessageRequest .builder ()
302+ .messages (List .of (new McpSchema .SamplingMessage (McpSchema .Role .USER ,
303+ new McpSchema .TextContent ("Test message" ))))
304+ .build ();
305+
306+ return exchange .createMessage (craeteMessageRequest ).thenReturn (callResponse );
307+ });
308+
309+ var mcpServer = McpServer .async (mcpServerTransportProvider )
310+ .requestTimeout (Duration .ofSeconds (1 ))
311+ .serverInfo ("test-server" , "1.0.0" )
312+ .tools (tool )
313+ .build ();
314+
315+ try (var mcpClient = clientBuilder .clientInfo (new McpSchema .Implementation ("Sample client" , "0.0.0" ))
316+ .capabilities (ClientCapabilities .builder ().sampling ().build ())
317+ .sampling (samplingHandler )
318+ .build ()) {
319+
320+ InitializeResult initResult = mcpClient .initialize ();
321+ assertThat (initResult ).isNotNull ();
322+
323+ assertThatExceptionOfType (McpError .class ).isThrownBy (() -> {
324+ mcpClient .callTool (new McpSchema .CallToolRequest ("tool1" , Map .of ()));
325+ }).withMessageContaining ("within 1000ms" );
326+
327+ }
328+
329+ mcpServer .closeGracefully ().block ();
196330 }
197331
198332 // ---------------------------------------
@@ -262,9 +396,8 @@ void testRootsWithoutCapability(String clientType) {
262396 var mcpServer = McpServer .sync (mcpServerTransportProvider ).rootsChangeHandler ((exchange , rootsUpdate ) -> {
263397 }).tools (tool ).build ();
264398
265- try (
266- // Create client without roots capability
267- var mcpClient = clientBuilder .capabilities (ClientCapabilities .builder ().build ()).build ()) {
399+ // Create client without roots capability
400+ try (var mcpClient = clientBuilder .capabilities (ClientCapabilities .builder ().build ()).build ()) {
268401
269402 assertThat (mcpClient .initialize ()).isNotNull ();
270403
@@ -282,7 +415,7 @@ void testRootsWithoutCapability(String clientType) {
282415
283416 @ ParameterizedTest (name = "{0} : {displayName} " )
284417 @ ValueSource (strings = { "httpclient" , "webflux" })
285- void testRootsNotifciationWithEmptyRootsList (String clientType ) {
418+ void testRootsNotificationWithEmptyRootsList (String clientType ) {
286419 var clientBuilder = clientBuilders .get (clientType );
287420
288421 AtomicReference <List <Root >> rootsRef = new AtomicReference <>();
@@ -620,4 +753,54 @@ void testLoggingNotification(String clientType) {
620753 mcpServer .close ();
621754 }
622755
623- }
756+ // ---------------------------------------
757+ // Completion Tests
758+ // ---------------------------------------
759+ @ ParameterizedTest (name = "{0} : Completion call" )
760+ @ ValueSource (strings = { "httpclient" , "webflux" })
761+ void testCompletionShouldReturnExpectedSuggestions (String clientType ) {
762+ var clientBuilder = clientBuilders .get (clientType );
763+
764+ var expectedValues = List .of ("python" , "pytorch" , "pyside" );
765+ var completionResponse = new McpSchema .CompleteResult (new CompleteResult .CompleteCompletion (expectedValues , 10 , // total
766+ true // hasMore
767+ ));
768+
769+ AtomicReference <CompleteRequest > samplingRequest = new AtomicReference <>();
770+ BiFunction <McpSyncServerExchange , CompleteRequest , CompleteResult > completionHandler = (mcpSyncServerExchange ,
771+ request ) -> {
772+ samplingRequest .set (request );
773+ return completionResponse ;
774+ };
775+
776+ var mcpServer = McpServer .sync (mcpServerTransportProvider )
777+ .capabilities (ServerCapabilities .builder ().completions ().build ())
778+ .prompts (new McpServerFeatures .SyncPromptSpecification (
779+ new Prompt ("code_review" , "this is code review prompt" ,
780+ List .of (new PromptArgument ("language" , "string" , false ))),
781+ (mcpSyncServerExchange , getPromptRequest ) -> null ))
782+ .completions (new McpServerFeatures .SyncCompletionSpecification (
783+ new McpSchema .PromptReference ("ref/prompt" , "code_review" ), completionHandler ))
784+ .build ();
785+
786+ try (var mcpClient = clientBuilder .build ()) {
787+
788+ InitializeResult initResult = mcpClient .initialize ();
789+ assertThat (initResult ).isNotNull ();
790+
791+ CompleteRequest request = new CompleteRequest (new PromptReference ("ref/prompt" , "code_review" ),
792+ new CompleteRequest .CompleteArgument ("language" , "py" ));
793+
794+ CompleteResult result = mcpClient .completeCompletion (request );
795+
796+ assertThat (result ).isNotNull ();
797+
798+ assertThat (samplingRequest .get ().argument ().name ()).isEqualTo ("language" );
799+ assertThat (samplingRequest .get ().argument ().value ()).isEqualTo ("py" );
800+ assertThat (samplingRequest .get ().ref ().type ()).isEqualTo ("ref/prompt" );
801+ }
802+
803+ mcpServer .close ();
804+ }
805+
806+ }
0 commit comments