Skip to content

Commit 3dee024

Browse files
authored
fix(streamable-http): gracefully shutdown while client connected (#494)
* fix(streamable-http): gracefully shutdown while client connected * fix: adviced comments * fix: windows test build
1 parent 57d1ac9 commit 3dee024

File tree

5 files changed

+52
-15
lines changed

5 files changed

+52
-15
lines changed

crates/rmcp/src/transport/common/server_side_http.rs

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use http::Response;
66
use http_body::Body;
77
use http_body_util::{BodyExt, Empty, Full, combinators::BoxBody};
88
use sse_stream::{KeepAlive, Sse, SseBody};
9+
use tokio_util::sync::CancellationToken;
910

1011
use super::http_header::EVENT_STREAM_MIME_TYPE;
1112
use crate::model::{ClientJsonRpcMessage, ServerJsonRpcMessage};
@@ -65,20 +66,26 @@ pub struct ServerSseMessage {
6566
pub(crate) fn sse_stream_response(
6667
stream: impl futures::Stream<Item = ServerSseMessage> + Send + Sync + 'static,
6768
keep_alive: Option<Duration>,
69+
ct: CancellationToken,
6870
) -> Response<BoxBody<Bytes, Infallible>> {
6971
use futures::StreamExt;
70-
let stream = SseBody::new(stream.map(|message| {
71-
let data = serde_json::to_string(&message.message).expect("valid message");
72-
let mut sse = Sse::default().data(data);
73-
sse.id = message.event_id;
74-
Result::<Sse, Infallible>::Ok(sse)
75-
}));
72+
let stream = stream
73+
.map(|message| {
74+
let data = serde_json::to_string(&message.message).expect("valid message");
75+
let mut sse = Sse::default().data(data);
76+
sse.id = message.event_id;
77+
Result::<Sse, Infallible>::Ok(sse)
78+
})
79+
.take_until(async move { ct.cancelled().await });
80+
let stream = SseBody::new(stream);
81+
7682
let stream = match keep_alive {
7783
Some(duration) => stream
7884
.with_keep_alive::<TokioTimer>(KeepAlive::new().interval(duration))
7985
.boxed(),
8086
None => stream.boxed(),
8187
};
88+
8289
Response::builder()
8390
.status(http::StatusCode::OK)
8491
.header(http::header::CONTENT_TYPE, EVENT_STREAM_MIME_TYPE)

crates/rmcp/src/transport/streamable_http_server/tower.rs

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use http::{Method, Request, Response, header::ALLOW};
66
use http_body::Body;
77
use http_body_util::{BodyExt, Full, combinators::BoxBody};
88
use tokio_stream::wrappers::ReceiverStream;
9+
use tokio_util::sync::CancellationToken;
910

1011
use super::session::SessionManager;
1112
use crate::{
@@ -33,13 +34,19 @@ pub struct StreamableHttpServerConfig {
3334
pub sse_keep_alive: Option<Duration>,
3435
/// If true, the server will create a session for each request and keep it alive.
3536
pub stateful_mode: bool,
37+
/// Cancellation token for the Streamable HTTP server.
38+
///
39+
/// When this token is cancelled, all active sessions are terminated and
40+
/// the server stops accepting new requests.
41+
pub cancellation_token: CancellationToken,
3642
}
3743

3844
impl Default for StreamableHttpServerConfig {
3945
fn default() -> Self {
4046
Self {
4147
sse_keep_alive: Some(Duration::from_secs(15)),
4248
stateful_mode: true,
49+
cancellation_token: CancellationToken::new(),
4350
}
4451
}
4552
}
@@ -209,15 +216,23 @@ where
209216
.resume(&session_id, last_event_id)
210217
.await
211218
.map_err(internal_error_response("resume session"))?;
212-
Ok(sse_stream_response(stream, self.config.sse_keep_alive))
219+
Ok(sse_stream_response(
220+
stream,
221+
self.config.sse_keep_alive,
222+
self.config.cancellation_token.child_token(),
223+
))
213224
} else {
214225
// create standalone stream
215226
let stream = self
216227
.session_manager
217228
.create_standalone_stream(&session_id)
218229
.await
219230
.map_err(internal_error_response("create standalone stream"))?;
220-
Ok(sse_stream_response(stream, self.config.sse_keep_alive))
231+
Ok(sse_stream_response(
232+
stream,
233+
self.config.sse_keep_alive,
234+
self.config.cancellation_token.child_token(),
235+
))
221236
}
222237
}
223238

@@ -307,7 +322,11 @@ where
307322
.create_stream(&session_id, message)
308323
.await
309324
.map_err(internal_error_response("get session"))?;
310-
Ok(sse_stream_response(stream, self.config.sse_keep_alive))
325+
Ok(sse_stream_response(
326+
stream,
327+
self.config.sse_keep_alive,
328+
self.config.cancellation_token.child_token(),
329+
))
311330
}
312331
ClientJsonRpcMessage::Notification(_)
313332
| ClientJsonRpcMessage::Response(_)
@@ -380,6 +399,7 @@ where
380399
}
381400
}),
382401
self.config.sse_keep_alive,
402+
self.config.cancellation_token.child_token(),
383403
);
384404

385405
response.headers_mut().insert(
@@ -413,6 +433,7 @@ where
413433
}
414434
}),
415435
self.config.sse_keep_alive,
436+
self.config.cancellation_token.child_token(),
416437
))
417438
}
418439
ClientJsonRpcMessage::Notification(_notification) => {

crates/rmcp/tests/test_with_js.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,18 +94,20 @@ async fn test_with_js_streamable_http_client() -> anyhow::Result<()> {
9494
.wait()
9595
.await?;
9696

97+
let ct = CancellationToken::new();
9798
let service: StreamableHttpService<Calculator, LocalSessionManager> =
9899
StreamableHttpService::new(
99100
|| Ok(Calculator::new()),
100101
Default::default(),
101102
StreamableHttpServerConfig {
102103
stateful_mode: true,
103104
sse_keep_alive: None,
105+
cancellation_token: ct.child_token(),
104106
},
105107
);
106108
let router = axum::Router::new().nest_service("/mcp", service);
107109
let tcp_listener = tokio::net::TcpListener::bind(STREAMABLE_HTTP_BIND_ADDRESS).await?;
108-
let ct = CancellationToken::new();
110+
109111
let handle = tokio::spawn({
110112
let ct = ct.clone();
111113
async move {

examples/servers/src/counter_streamhttp.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use rmcp::transport::streamable_http_server::{
2-
StreamableHttpService, session::local::LocalSessionManager,
2+
StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
33
};
44
use tracing_subscriber::{
55
layer::SubscriberExt,
@@ -20,17 +20,24 @@ async fn main() -> anyhow::Result<()> {
2020
)
2121
.with(tracing_subscriber::fmt::layer())
2222
.init();
23+
let ct = tokio_util::sync::CancellationToken::new();
2324

2425
let service = StreamableHttpService::new(
2526
|| Ok(Counter::new()),
2627
LocalSessionManager::default().into(),
27-
Default::default(),
28+
StreamableHttpServerConfig {
29+
cancellation_token: ct.child_token(),
30+
..Default::default()
31+
},
2832
);
2933

3034
let router = axum::Router::new().nest_service("/mcp", service);
3135
let tcp_listener = tokio::net::TcpListener::bind(BIND_ADDRESS).await?;
3236
let _ = axum::serve(tcp_listener, router)
33-
.with_graceful_shutdown(async { tokio::signal::ctrl_c().await.unwrap() })
37+
.with_graceful_shutdown(async move {
38+
tokio::signal::ctrl_c().await.unwrap();
39+
ct.cancel();
40+
})
3441
.await;
3542
Ok(())
3643
}

examples/transport/src/named-pipe.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@ async fn main() -> anyhow::Result<()> {
1212
let mut server = ServerOptions::new()
1313
.first_pipe_instance(true)
1414
.create(name)?;
15-
while let Ok(_) = server.connect().await {
15+
while server.connect().await.is_ok() {
1616
let stream = server;
1717
server = ServerOptions::new().create(name)?;
1818
tokio::spawn(async move {
19-
match serve_server(Calculator, stream).await {
19+
match serve_server(Calculator::new(), stream).await {
2020
Ok(server) => {
2121
println!("Server initialized successfully");
2222
if let Err(e) = server.waiting().await {

0 commit comments

Comments
 (0)