22
33import com .fasterxml .jackson .core .type .TypeReference ;
44import com .fasterxml .jackson .databind .ObjectMapper ;
5- import io .modelcontextprotocol .spec .McpClientTransport ;
6- import io .modelcontextprotocol .spec .McpError ;
7- import io .modelcontextprotocol .spec .McpSchema ;
8- import io .modelcontextprotocol .spec .McpSessionNotFoundException ;
5+ import io .modelcontextprotocol .spec .*;
96import org .reactivestreams .Publisher ;
107import org .slf4j .Logger ;
118import org .slf4j .LoggerFactory ;
2825import java .util .concurrent .atomic .AtomicBoolean ;
2926import java .util .concurrent .atomic .AtomicLong ;
3027import java .util .concurrent .atomic .AtomicReference ;
28+ import java .util .function .Consumer ;
3129import java .util .function .Function ;
3230
3331public class WebClientStreamableHttpTransport implements McpClientTransport {
@@ -52,11 +50,9 @@ public class WebClientStreamableHttpTransport implements McpClientTransport {
5250
5351 private AtomicReference <Function <Mono <McpSchema .JSONRPCMessage >, Mono <McpSchema .JSONRPCMessage >>> handler = new AtomicReference <>();
5452
55- private final Disposable . Composite openConnections = Disposables . composite ();
53+ private final AtomicReference < McpTransportSession > activeSession = new AtomicReference <> ();
5654
57- private final AtomicBoolean initialized = new AtomicBoolean ();
58-
59- private final AtomicReference <String > sessionId = new AtomicReference <>();
55+ private final AtomicReference <Consumer <Throwable >> exceptionHandler = new AtomicReference <>();
6056
6157 public WebClientStreamableHttpTransport (ObjectMapper objectMapper , WebClient .Builder webClientBuilder ,
6258 String endpoint , boolean resumableStreams , boolean openConnectionOnStartup ) {
@@ -65,14 +61,12 @@ public WebClientStreamableHttpTransport(ObjectMapper objectMapper, WebClient.Bui
6561 this .endpoint = endpoint ;
6662 this .resumableStreams = resumableStreams ;
6763 this .openConnectionOnStartup = openConnectionOnStartup ;
64+ this .activeSession .set (new McpTransportSession ());
6865 }
6966
7067 @ Override
7168 public Mono <Void > connect (Function <Mono <McpSchema .JSONRPCMessage >, Mono <McpSchema .JSONRPCMessage >> handler ) {
7269 return Mono .deferContextual (ctx -> {
73- if (this .openConnections .isDisposed ()) {
74- return Mono .error (new RuntimeException ("Transport already disposed" ));
75- }
7670 this .handler .set (handler );
7771 if (openConnectionOnStartup ) {
7872 this .reconnect (null , ctx );
@@ -81,9 +75,20 @@ public Mono<Void> connect(Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchem
8175 });
8276 }
8377
78+ @ Override
79+ public void handleException (Consumer <Throwable > handler ) {
80+ this .exceptionHandler .set (handler );
81+ }
82+
8483 @ Override
8584 public Mono <Void > closeGracefully () {
86- return Mono .fromRunnable (this .openConnections ::dispose );
85+ return Mono .defer (() -> {
86+ McpTransportSession currentSession = this .activeSession .get ();
87+ if (currentSession != null ) {
88+ return currentSession .closeGracefully ();
89+ }
90+ return Mono .empty ();
91+ });
8792 }
8893
8994 private void reconnect (McpStream stream , ContextView ctx ) {
@@ -93,12 +98,13 @@ private void reconnect(McpStream stream, ContextView ctx) {
9398 // listen for messages.
9499 // If it doesn't, nothing actually happens here, that's just the way it is...
95100 final AtomicReference <Disposable > disposableRef = new AtomicReference <>();
101+ final McpTransportSession transportSession = this .activeSession .get ();
96102 Disposable connection = webClient .get ()
97103 .uri (this .endpoint )
98104 .accept (MediaType .TEXT_EVENT_STREAM )
99105 .headers (httpHeaders -> {
100- if (sessionId . get () != null ) {
101- httpHeaders .add ("mcp-session-id" , sessionId . get ());
106+ if (transportSession . sessionId () != null ) {
107+ httpHeaders .add ("mcp-session-id" , transportSession . sessionId ());
102108 }
103109 if (stream != null && stream .lastId () != null ) {
104110 httpHeaders .add ("last-event-id" , stream .lastId ());
@@ -123,22 +129,33 @@ else if (response.statusCode().isSameCodeAs(HttpStatus.METHOD_NOT_ALLOWED)) {
123129 logger .info ("The server does not support SSE streams, using request-response mode." );
124130 return Flux .empty ();
125131 }
132+ else if (response .statusCode ().isSameCodeAs (HttpStatus .NOT_FOUND )) {
133+ logger .info ("Session {} was not found on the MCP server" , transportSession .sessionId ());
134+
135+ McpSessionNotFoundException notFoundException = new McpSessionNotFoundException (
136+ "Session " + transportSession .sessionId () + " not found" );
137+ // inform the stream/connection subscriber
138+ return Flux .error (notFoundException );
139+ }
126140 else {
127141 return response .<McpSchema .JSONRPCMessage >createError ().doOnError (e -> {
128142 logger .info ("Opening an SSE stream failed. This can be safely ignored." , e );
129143 }).flux ();
130144 }
131145 })
146+ .doOnError (e -> {
147+ this .exceptionHandler .get ().accept (e );
148+ })
132149 .doFinally (s -> {
133150 Disposable ref = disposableRef .getAndSet (null );
134151 if (ref != null ) {
135- this . openConnections . remove (ref );
152+ transportSession . removeConnection (ref );
136153 }
137154 })
138155 .contextWrite (ctx )
139156 .subscribe ();
140157 disposableRef .set (connection );
141- this . openConnections . add (connection );
158+ transportSession . addConnection (connection );
142159 }
143160
144161 @ Override
@@ -151,20 +168,22 @@ public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
151168 // listen for messages.
152169 // If it doesn't, nothing actually happens here, that's just the way it is...
153170 final AtomicReference <Disposable > disposableRef = new AtomicReference <>();
171+ final McpTransportSession transportSession = this .activeSession .get ();
172+
154173 Disposable connection = webClient .post ()
155174 .uri (this .endpoint )
156175 .accept (MediaType .TEXT_EVENT_STREAM , MediaType .APPLICATION_JSON )
157176 .headers (httpHeaders -> {
158- if (sessionId . get () != null ) {
159- httpHeaders .add ("mcp-session-id" , sessionId . get ());
177+ if (transportSession . sessionId () != null ) {
178+ httpHeaders .add ("mcp-session-id" , transportSession . sessionId ());
160179 }
161180 })
162181 .bodyValue (message )
163182 .exchangeToFlux (response -> {
164- // TODO: this goes into the request phase
165- if (!initialized .compareAndExchange (false , true )) {
183+ if (transportSession .markInitialized ()) {
166184 if (!response .headers ().header ("mcp-session-id" ).isEmpty ()) {
167- sessionId .set (response .headers ().asHttpHeaders ().getFirst ("mcp-session-id" ));
185+ transportSession
186+ .setSessionId (response .headers ().asHttpHeaders ().getFirst ("mcp-session-id" ));
168187 // Once we have a session, we try to open an async stream for
169188 // the server to send notifications and requests out-of-band.
170189 reconnect (null , sink .contextView ());
@@ -176,10 +195,10 @@ public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
176195 // if (!response.statusCode().isSameCodeAs(HttpStatus.ACCEPTED)) {
177196 if (!response .statusCode ().is2xxSuccessful ()) {
178197 if (response .statusCode ().isSameCodeAs (HttpStatus .NOT_FOUND )) {
179- logger .info ("Session {} was not found on the MCP server" , sessionId . get ());
198+ logger .info ("Session {} was not found on the MCP server" , transportSession . sessionId ());
180199
181200 McpSessionNotFoundException notFoundException = new McpSessionNotFoundException (
182- "Session " + sessionId . get () + " not found" );
201+ "Session " + transportSession . sessionId () + " not found" );
183202 // inform the caller of sendMessage
184203 sink .error (notFoundException );
185204 // inform the stream/connection subscriber
@@ -233,8 +252,6 @@ else if (contentType.isCompatibleWith(MediaType.APPLICATION_JSON)) {
233252 }
234253 })
235254 .flatMapIterable (Function .identity ());
236- // .map(Mono::just)
237- // .flatMap(this.handler.get());
238255 }
239256 else {
240257 sink .error (new RuntimeException ("Unknown media type" ));
@@ -246,13 +263,13 @@ else if (contentType.isCompatibleWith(MediaType.APPLICATION_JSON)) {
246263 .doFinally (s -> {
247264 Disposable ref = disposableRef .getAndSet (null );
248265 if (ref != null ) {
249- this . openConnections . remove (ref );
266+ transportSession . removeConnection (ref );
250267 }
251268 })
252269 .contextWrite (sink .contextView ())
253270 .subscribe ();
254271 disposableRef .set (connection );
255- this . openConnections . add (connection );
272+ transportSession . addConnection (connection );
256273 });
257274 }
258275
0 commit comments