Skip to content

Commit dc2e989

Browse files
committed
fix(streamable-http): gracefully shutdown while client connected
1 parent 44129e4 commit dc2e989

File tree

3 files changed

+45
-13
lines changed

3 files changed

+45
-13
lines changed

crates/rmcp/src/transport/common/server_side_http.rs

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use http::Response;
66
use http_body::Body;
77
use http_body_util::{BodyExt, Empty, Full, combinators::BoxBody};
88
use sse_stream::{KeepAlive, Sse, SseBody};
9+
use tokio_util::sync::CancellationToken;
910

1011
use super::http_header::EVENT_STREAM_MIME_TYPE;
1112
use crate::model::{ClientJsonRpcMessage, ServerJsonRpcMessage};
@@ -65,20 +66,26 @@ pub struct ServerSseMessage {
6566
pub(crate) fn sse_stream_response(
6667
stream: impl futures::Stream<Item = ServerSseMessage> + Send + Sync + 'static,
6768
keep_alive: Option<Duration>,
69+
ct: CancellationToken,
6870
) -> Response<BoxBody<Bytes, Infallible>> {
6971
use futures::StreamExt;
70-
let stream = SseBody::new(stream.map(|message| {
71-
let data = serde_json::to_string(&message.message).expect("valid message");
72-
let mut sse = Sse::default().data(data);
73-
sse.id = message.event_id;
74-
Result::<Sse, Infallible>::Ok(sse)
75-
}));
72+
let stream = stream
73+
.map(|message| {
74+
let data = serde_json::to_string(&message.message).expect("valid message");
75+
let mut sse = Sse::default().data(data);
76+
sse.id = message.event_id;
77+
Result::<Sse, Infallible>::Ok(sse)
78+
})
79+
.take_until(async move { ct.cancelled().await });
80+
let stream = SseBody::new(stream);
81+
7682
let stream = match keep_alive {
7783
Some(duration) => stream
7884
.with_keep_alive::<TokioTimer>(KeepAlive::new().interval(duration))
7985
.boxed(),
8086
None => stream.boxed(),
8187
};
88+
8289
Response::builder()
8390
.status(http::StatusCode::OK)
8491
.header(http::header::CONTENT_TYPE, EVENT_STREAM_MIME_TYPE)

crates/rmcp/src/transport/streamable_http_server/tower.rs

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use http::{Method, Request, Response, header::ALLOW};
66
use http_body::Body;
77
use http_body_util::{BodyExt, Full, combinators::BoxBody};
88
use tokio_stream::wrappers::ReceiverStream;
9+
use tokio_util::sync::CancellationToken;
910

1011
use super::session::SessionManager;
1112
use crate::{
@@ -33,13 +34,15 @@ pub struct StreamableHttpServerConfig {
3334
pub sse_keep_alive: Option<Duration>,
3435
/// If true, the server will create a session for each request and keep it alive.
3536
pub stateful_mode: bool,
37+
pub cancellation_token: CancellationToken,
3638
}
3739

3840
impl Default for StreamableHttpServerConfig {
3941
fn default() -> Self {
4042
Self {
4143
sse_keep_alive: Some(Duration::from_secs(15)),
4244
stateful_mode: true,
45+
cancellation_token: CancellationToken::new(),
4346
}
4447
}
4548
}
@@ -209,15 +212,23 @@ where
209212
.resume(&session_id, last_event_id)
210213
.await
211214
.map_err(internal_error_response("resume session"))?;
212-
Ok(sse_stream_response(stream, self.config.sse_keep_alive))
215+
Ok(sse_stream_response(
216+
stream,
217+
self.config.sse_keep_alive,
218+
self.config.cancellation_token.child_token(),
219+
))
213220
} else {
214221
// create standalone stream
215222
let stream = self
216223
.session_manager
217224
.create_standalone_stream(&session_id)
218225
.await
219226
.map_err(internal_error_response("create standalone stream"))?;
220-
Ok(sse_stream_response(stream, self.config.sse_keep_alive))
227+
Ok(sse_stream_response(
228+
stream,
229+
self.config.sse_keep_alive,
230+
self.config.cancellation_token.child_token(),
231+
))
221232
}
222233
}
223234

@@ -307,7 +318,11 @@ where
307318
.create_stream(&session_id, message)
308319
.await
309320
.map_err(internal_error_response("get session"))?;
310-
Ok(sse_stream_response(stream, self.config.sse_keep_alive))
321+
Ok(sse_stream_response(
322+
stream,
323+
self.config.sse_keep_alive,
324+
self.config.cancellation_token.child_token(),
325+
))
311326
}
312327
ClientJsonRpcMessage::Notification(_)
313328
| ClientJsonRpcMessage::Response(_)
@@ -380,6 +395,7 @@ where
380395
}
381396
}),
382397
self.config.sse_keep_alive,
398+
self.config.cancellation_token.child_token(),
383399
);
384400

385401
response.headers_mut().insert(
@@ -413,6 +429,7 @@ where
413429
}
414430
}),
415431
self.config.sse_keep_alive,
432+
self.config.cancellation_token.child_token(),
416433
))
417434
}
418435
ClientJsonRpcMessage::Notification(_notification) => {

examples/servers/src/counter_streamhttp.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
use rmcp::transport::streamable_http_server::{
2-
StreamableHttpService, session::local::LocalSessionManager,
1+
use rmcp::transport::{
2+
StreamableHttpServerConfig,
3+
streamable_http_server::{StreamableHttpService, session::local::LocalSessionManager},
34
};
45
use tracing_subscriber::{
56
layer::SubscriberExt,
@@ -20,17 +21,24 @@ async fn main() -> anyhow::Result<()> {
2021
)
2122
.with(tracing_subscriber::fmt::layer())
2223
.init();
24+
let ct = tokio_util::sync::CancellationToken::new();
2325

2426
let service = StreamableHttpService::new(
2527
|| Ok(Counter::new()),
2628
LocalSessionManager::default().into(),
27-
Default::default(),
29+
StreamableHttpServerConfig {
30+
cancellation_token: ct.child_token(),
31+
..Default::default()
32+
},
2833
);
2934

3035
let router = axum::Router::new().nest_service("/mcp", service);
3136
let tcp_listener = tokio::net::TcpListener::bind(BIND_ADDRESS).await?;
3237
let _ = axum::serve(tcp_listener, router)
33-
.with_graceful_shutdown(async { tokio::signal::ctrl_c().await.unwrap() })
38+
.with_graceful_shutdown(async move {
39+
tokio::signal::ctrl_c().await.unwrap();
40+
ct.cancel();
41+
})
3442
.await;
3543
Ok(())
3644
}

0 commit comments

Comments
 (0)