4
4
5
5
package io .modelcontextprotocol .server ;
6
6
7
- import static net .javacrumbs .jsonunit .assertj .JsonAssertions .assertThatJson ;
8
- import static net .javacrumbs .jsonunit .assertj .JsonAssertions .json ;
9
- import static org .assertj .core .api .Assertions .assertThat ;
10
- import static org .awaitility .Awaitility .await ;
11
-
12
- import java .time .Duration ;
13
- import java .util .List ;
14
- import java .util .Map ;
15
- import java .util .concurrent .ConcurrentHashMap ;
16
- import java .util .concurrent .atomic .AtomicReference ;
17
- import java .util .function .BiFunction ;
18
-
19
- import org .apache .catalina .LifecycleException ;
20
- import org .apache .catalina .LifecycleState ;
21
- import org .apache .catalina .startup .Tomcat ;
22
- import org .junit .jupiter .api .AfterEach ;
23
- import org .junit .jupiter .api .BeforeEach ;
24
- import org .junit .jupiter .params .ParameterizedTest ;
25
- import org .junit .jupiter .params .provider .ValueSource ;
26
- import org .springframework .web .client .RestClient ;
27
-
28
7
import com .fasterxml .jackson .databind .ObjectMapper ;
29
-
30
8
import io .modelcontextprotocol .client .McpClient ;
31
9
import io .modelcontextprotocol .client .transport .HttpClientStreamableHttpTransport ;
32
10
import io .modelcontextprotocol .server .transport .HttpServletStatelessServerTransport ;
42
20
import io .modelcontextprotocol .spec .McpSchema .ServerCapabilities ;
43
21
import io .modelcontextprotocol .spec .McpSchema .Tool ;
44
22
import net .javacrumbs .jsonunit .core .Option ;
23
+ import org .apache .catalina .LifecycleException ;
24
+ import org .apache .catalina .LifecycleState ;
25
+ import org .apache .catalina .startup .Tomcat ;
26
+ import org .junit .jupiter .api .AfterEach ;
27
+ import org .junit .jupiter .api .BeforeEach ;
28
+ import org .junit .jupiter .api .Test ;
29
+ import org .junit .jupiter .params .ParameterizedTest ;
30
+ import org .junit .jupiter .params .provider .ValueSource ;
31
+ import org .springframework .web .client .RestClient ;
32
+
33
+ import java .net .URI ;
34
+ import java .net .http .HttpClient ;
35
+ import java .net .http .HttpRequest ;
36
+ import java .net .http .HttpResponse ;
37
+ import java .time .Duration ;
38
+ import java .util .Iterator ;
39
+ import java .util .List ;
40
+ import java .util .Map ;
41
+ import java .util .UUID ;
42
+ import java .util .concurrent .ConcurrentHashMap ;
43
+ import java .util .concurrent .atomic .AtomicReference ;
44
+ import java .util .function .BiFunction ;
45
+ import java .util .stream .Stream ;
46
+
47
+ import static io .modelcontextprotocol .server .transport .HttpServletStatelessServerTransport .APPLICATION_JSON ;
48
+ import static io .modelcontextprotocol .server .transport .HttpServletStatelessServerTransport .TEXT_EVENT_STREAM ;
49
+ import static net .javacrumbs .jsonunit .assertj .JsonAssertions .assertThatJson ;
50
+ import static net .javacrumbs .jsonunit .assertj .JsonAssertions .json ;
51
+ import static org .assertj .core .api .Assertions .assertThat ;
52
+ import static org .awaitility .Awaitility .await ;
45
53
46
54
class HttpServletStatelessIntegrationTests {
47
55
@@ -55,10 +63,13 @@ class HttpServletStatelessIntegrationTests {
55
63
56
64
private Tomcat tomcat ;
57
65
66
+ private ObjectMapper objectMapper ;
67
+
58
68
@ BeforeEach
59
69
public void before () {
70
+ objectMapper = new ObjectMapper ();
60
71
this .mcpStatelessServerTransport = HttpServletStatelessServerTransport .builder ()
61
- .objectMapper (new ObjectMapper () )
72
+ .objectMapper (objectMapper )
62
73
.messageEndpoint (CUSTOM_MESSAGE_ENDPOINT )
63
74
.build ();
64
75
@@ -213,6 +224,87 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) {
213
224
mcpServer .close ();
214
225
}
215
226
227
+ @ Test
228
+ void testNotifications () throws Exception {
229
+
230
+ Tool tool = Tool .builder ().name ("test" ).build ();
231
+
232
+ final int PROGRESS_QTY = 1000 ;
233
+ final String progressMessage = "We're working on it..." ;
234
+
235
+ var progressToken = UUID .randomUUID ().toString ();
236
+ var callResponse = new CallToolResult (List .of (), null , null , Map .of ("progressToken" , progressToken ));
237
+ McpStatelessServerFeatures .SyncToolSpecification toolSpecification = new McpStatelessServerFeatures .SyncToolSpecification (
238
+ tool , (transportContext , request ) -> {
239
+ // Simulate sending progress notifications - send enough to ensure
240
+ // that cunked transfer encoding is used
241
+ for (int i = 0 ; i < PROGRESS_QTY ; i ++) {
242
+ transportContext .sendNotification (McpSchema .METHOD_NOTIFICATION_PROGRESS ,
243
+ new McpSchema .ProgressNotification (progressToken , i , 5.0 , progressMessage ));
244
+ }
245
+ return callResponse ;
246
+ });
247
+
248
+ var mcpServer = McpServer .sync (mcpStatelessServerTransport )
249
+ .capabilities (ServerCapabilities .builder ().tools (true ).build ())
250
+ .tools (toolSpecification )
251
+ .build ();
252
+
253
+ HttpClient client = HttpClient .newBuilder ().version (HttpClient .Version .HTTP_1_1 ).build ();
254
+ HttpRequest request = HttpRequest .newBuilder ()
255
+ .method ("POST" ,
256
+ HttpRequest .BodyPublishers .ofString (
257
+ objectMapper .writeValueAsString (new McpSchema .JSONRPCRequest (McpSchema .JSONRPC_VERSION ,
258
+ "tools/call" , "1" , new McpSchema .CallToolRequest ("test" , Map .of ())))))
259
+ .header ("Content-Type" , APPLICATION_JSON )
260
+ .header ("Accept" , APPLICATION_JSON + "," + TEXT_EVENT_STREAM )
261
+ .uri (URI .create ("http://localhost:" + PORT + CUSTOM_MESSAGE_ENDPOINT ))
262
+ .build ();
263
+
264
+ HttpResponse <Stream <String >> response = client .send (request , HttpResponse .BodyHandlers .ofLines ());
265
+ assertThat (response .headers ().firstValue ("Transfer-Encoding" )).contains ("chunked" );
266
+
267
+ List <String > responseBody = response .body ().toList ();
268
+
269
+ assertThat (responseBody ).hasSize ((PROGRESS_QTY + 1 ) * 4 ); // 4 lines per progress
270
+ // notification + 4
271
+ // for
272
+ // the call result
273
+
274
+ Iterator <String > iterator = responseBody .iterator ();
275
+ for (int i = 0 ; i < PROGRESS_QTY ; ++i ) {
276
+ String eventLine = iterator .next ();
277
+ String idLine = iterator .next ();
278
+ String dataLine = iterator .next ();
279
+ String blankLine = iterator .next ();
280
+
281
+ McpSchema .ProgressNotification expectedNotification = new McpSchema .ProgressNotification (progressToken , i ,
282
+ 5.0 , progressMessage );
283
+ McpSchema .JSONRPCNotification expectedJsonRpcNotification = new McpSchema .JSONRPCNotification (
284
+ McpSchema .JSONRPC_VERSION , McpSchema .METHOD_NOTIFICATION_PROGRESS , expectedNotification );
285
+
286
+ assertThat (eventLine ).isEqualTo ("event: notification" );
287
+ assertThat (idLine ).isEqualTo ("id: " + i );
288
+ assertThat (dataLine ).isEqualTo ("data: " + objectMapper .writeValueAsString (expectedJsonRpcNotification ));
289
+ assertThat (blankLine ).isBlank ();
290
+ }
291
+
292
+ String eventLine = iterator .next ();
293
+ String idLine = iterator .next ();
294
+ String dataLine = iterator .next ();
295
+ String blankLine = iterator .next ();
296
+
297
+ assertThat (eventLine ).isEqualTo ("event: result" );
298
+ assertThat (idLine ).isEqualTo ("id: " + PROGRESS_QTY );
299
+ assertThat (dataLine ).isEqualTo ("data: " + objectMapper
300
+ .writeValueAsString (new McpSchema .JSONRPCResponse (McpSchema .JSONRPC_VERSION , "1" , callResponse , null )));
301
+ assertThat (blankLine ).isBlank ();
302
+
303
+ assertThat (iterator .hasNext ()).isFalse ();
304
+
305
+ mcpServer .close ();
306
+ }
307
+
216
308
// ---------------------------------------
217
309
// Tool Structured Output Schema Tests
218
310
// ---------------------------------------
0 commit comments