4
4
5
5
package io .modelcontextprotocol .server .transport ;
6
6
7
- import java .io .BufferedReader ;
8
- import java .io .IOException ;
9
- import java .io .PrintWriter ;
10
-
11
- import org .slf4j .Logger ;
12
- import org .slf4j .LoggerFactory ;
13
-
14
7
import com .fasterxml .jackson .databind .ObjectMapper ;
15
-
16
- import io .modelcontextprotocol .server .DefaultMcpTransportContext ;
17
8
import io .modelcontextprotocol .server .McpStatelessServerHandler ;
18
9
import io .modelcontextprotocol .server .McpTransportContext ;
19
10
import io .modelcontextprotocol .server .McpTransportContextExtractor ;
11
+ import io .modelcontextprotocol .server .StatelessMcpTransportContext ;
20
12
import io .modelcontextprotocol .spec .McpError ;
21
13
import io .modelcontextprotocol .spec .McpSchema ;
22
14
import io .modelcontextprotocol .spec .McpStatelessServerTransport ;
26
18
import jakarta .servlet .http .HttpServlet ;
27
19
import jakarta .servlet .http .HttpServletRequest ;
28
20
import jakarta .servlet .http .HttpServletResponse ;
21
+ import org .slf4j .Logger ;
22
+ import org .slf4j .LoggerFactory ;
29
23
import reactor .core .publisher .Mono ;
30
24
25
+ import java .io .BufferedReader ;
26
+ import java .io .IOException ;
27
+ import java .io .PrintWriter ;
28
+ import java .util .concurrent .atomic .AtomicBoolean ;
29
+ import java .util .concurrent .atomic .AtomicInteger ;
30
+ import java .util .function .BiConsumer ;
31
+
31
32
/**
32
33
* Implementation of an HttpServlet based {@link McpStatelessServerTransport}.
33
34
*
@@ -123,11 +124,16 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
123
124
return ;
124
125
}
125
126
126
- McpTransportContext transportContext = this .contextExtractor .extract (request , new DefaultMcpTransportContext ());
127
+ AtomicInteger nextId = new AtomicInteger (0 );
128
+ AtomicBoolean upgradedToSse = new AtomicBoolean (false );
129
+ BiConsumer <String , Object > notificationHandler = buildNotificationHandler (response , upgradedToSse , nextId );
130
+ McpTransportContext transportContext = this .contextExtractor .extract (request ,
131
+ new StatelessMcpTransportContext (notificationHandler ));
127
132
128
133
String accept = request .getHeader (ACCEPT );
129
134
if (accept == null || !(accept .contains (APPLICATION_JSON ) && accept .contains (TEXT_EVENT_STREAM ))) {
130
- this .responseError (response , HttpServletResponse .SC_BAD_REQUEST ,
135
+ this .responseError (response , HttpServletResponse .SC_BAD_REQUEST , null , upgradedToSse .get (),
136
+ nextId .getAndIncrement (),
131
137
new McpError ("Both application/json and text/event-stream required in Accept header" ));
132
138
return ;
133
139
}
@@ -149,18 +155,24 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
149
155
.contextWrite (ctx -> ctx .put (McpTransportContext .KEY , transportContext ))
150
156
.block ();
151
157
152
- response .setContentType (APPLICATION_JSON );
153
- response .setCharacterEncoding (UTF_8 );
154
- response .setStatus (HttpServletResponse .SC_OK );
155
-
156
158
String jsonResponseText = objectMapper .writeValueAsString (jsonrpcResponse );
157
- PrintWriter writer = response .getWriter ();
158
- writer .write (jsonResponseText );
159
- writer .flush ();
159
+ if (upgradedToSse .get ()) {
160
+ sendEvent (response .getWriter (), jsonResponseText , nextId .getAndIncrement ());
161
+ }
162
+ else {
163
+ response .setContentType (APPLICATION_JSON );
164
+ response .setCharacterEncoding (UTF_8 );
165
+ response .setStatus (HttpServletResponse .SC_OK );
166
+
167
+ PrintWriter writer = response .getWriter ();
168
+ writer .write (jsonResponseText );
169
+ writer .flush ();
170
+ }
160
171
}
161
172
catch (Exception e ) {
162
173
logger .error ("Failed to handle request: {}" , e .getMessage ());
163
- this .responseError (response , HttpServletResponse .SC_INTERNAL_SERVER_ERROR ,
174
+ this .responseError (response , HttpServletResponse .SC_INTERNAL_SERVER_ERROR , jsonrpcRequest .id (),
175
+ upgradedToSse .get (), nextId .getAndIncrement (),
164
176
new McpError ("Failed to handle request: " + e .getMessage ()));
165
177
}
166
178
}
@@ -173,41 +185,53 @@ else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) {
173
185
}
174
186
catch (Exception e ) {
175
187
logger .error ("Failed to handle notification: {}" , e .getMessage ());
176
- this .responseError (response , HttpServletResponse .SC_INTERNAL_SERVER_ERROR ,
188
+ this .responseError (response , HttpServletResponse .SC_INTERNAL_SERVER_ERROR , null ,
189
+ upgradedToSse .get (), nextId .getAndIncrement (),
177
190
new McpError ("Failed to handle notification: " + e .getMessage ()));
178
191
}
179
192
}
180
193
else {
181
- this .responseError (response , HttpServletResponse .SC_BAD_REQUEST ,
182
- new McpError ("The server accepts either requests or notifications" ));
194
+ this .responseError (response , HttpServletResponse .SC_BAD_REQUEST , null , upgradedToSse . get (),
195
+ nextId . getAndIncrement (), new McpError ("The server accepts either requests or notifications" ));
183
196
}
184
197
}
185
198
catch (IllegalArgumentException | IOException e ) {
186
199
logger .error ("Failed to deserialize message: {}" , e .getMessage ());
187
- this .responseError (response , HttpServletResponse .SC_BAD_REQUEST , new McpError ("Invalid message format" ));
200
+ this .responseError (response , HttpServletResponse .SC_BAD_REQUEST , null , upgradedToSse .get (),
201
+ nextId .getAndIncrement (), new McpError ("Invalid message format" ));
188
202
}
189
203
catch (Exception e ) {
190
204
logger .error ("Unexpected error handling message: {}" , e .getMessage ());
191
- this .responseError (response , HttpServletResponse .SC_INTERNAL_SERVER_ERROR ,
192
- new McpError ("Unexpected error: " + e .getMessage ()));
205
+ this .responseError (response , HttpServletResponse .SC_INTERNAL_SERVER_ERROR , null , upgradedToSse . get (),
206
+ nextId . getAndIncrement (), new McpError ("Unexpected error: " + e .getMessage ()));
193
207
}
194
208
}
195
209
196
210
/**
197
211
* Sends an error response to the client.
198
212
* @param response The HTTP servlet response
199
213
* @param httpCode The HTTP status code
214
+ * @param upgradedToSse true if the response is upgraded to SSE, false otherwise
215
+ * @param eventIdIfNeeded if upgradedToSse, the event ID to use, otherwise ignored
200
216
* @param mcpError The MCP error to send
201
217
* @throws IOException If an I/O error occurs
202
218
*/
203
- private void responseError (HttpServletResponse response , int httpCode , McpError mcpError ) throws IOException {
204
- response .setContentType (APPLICATION_JSON );
205
- response .setCharacterEncoding (UTF_8 );
206
- response .setStatus (httpCode );
207
- String jsonError = objectMapper .writeValueAsString (mcpError );
208
- PrintWriter writer = response .getWriter ();
209
- writer .write (jsonError );
210
- writer .flush ();
219
+ private void responseError (HttpServletResponse response , int httpCode , Object requestId , boolean upgradedToSse ,
220
+ int eventIdIfNeeded , McpError mcpError ) throws IOException {
221
+ if (upgradedToSse ) {
222
+ String jsonError = objectMapper .writeValueAsString (new McpSchema .JSONRPCResponse (McpSchema .JSONRPC_VERSION ,
223
+ requestId , null , mcpError .getJsonRpcError ()));
224
+ sendEvent (response .getWriter (), jsonError , eventIdIfNeeded );
225
+ }
226
+ else {
227
+ response .setContentType (APPLICATION_JSON );
228
+ response .setCharacterEncoding (UTF_8 );
229
+ response .setStatus (httpCode );
230
+ PrintWriter writer = response .getWriter ();
231
+ String jsonError = objectMapper .writeValueAsString (mcpError );
232
+ writer .write (jsonError );
233
+ writer .flush ();
234
+ }
211
235
}
212
236
213
237
/**
@@ -303,4 +327,43 @@ public HttpServletStatelessServerTransport build() {
303
327
304
328
}
305
329
330
+ private BiConsumer <String , Object > buildNotificationHandler (HttpServletResponse response ,
331
+ AtomicBoolean upgradedToSse , AtomicInteger nextId ) {
332
+ AtomicBoolean responseInitialized = new AtomicBoolean (false );
333
+
334
+ return (notificationMethod , params ) -> {
335
+ if (responseInitialized .compareAndSet (false , true )) {
336
+ response .setContentType (TEXT_EVENT_STREAM );
337
+ response .setCharacterEncoding (UTF_8 );
338
+ response .setStatus (HttpServletResponse .SC_OK );
339
+ }
340
+
341
+ upgradedToSse .set (true );
342
+
343
+ McpSchema .JSONRPCNotification notification = new McpSchema .JSONRPCNotification (McpSchema .JSONRPC_VERSION ,
344
+ notificationMethod , params );
345
+ try {
346
+ sendEvent (response .getWriter (), objectMapper .writeValueAsString (notification ),
347
+ nextId .getAndIncrement ());
348
+ }
349
+ catch (IOException e ) {
350
+ logger .error ("Failed to handle notification: {}" , e .getMessage ());
351
+ throw new McpError (new McpSchema .JSONRPCResponse .JSONRPCError (McpSchema .ErrorCodes .INTERNAL_ERROR ,
352
+ e .getMessage (), null ));
353
+ }
354
+ };
355
+ }
356
+
357
+ private void sendEvent (PrintWriter writer , String data , int id ) throws IOException {
358
+ // tested with MCP inspector. Event must consist of these two fields and only
359
+ // these two fields
360
+ writer .write ("id: " + id + "\n " );
361
+ writer .write ("data: " + data + "\n \n " );
362
+ writer .flush ();
363
+
364
+ if (writer .checkError ()) {
365
+ throw new IOException ("Client disconnected" );
366
+ }
367
+ }
368
+
306
369
}
0 commit comments