diff --git a/crates/rmcp/src/transport/sse_server.rs b/crates/rmcp/src/transport/sse_server.rs index 79f0b66b..ef0b0309 100644 --- a/crates/rmcp/src/transport/sse_server.rs +++ b/crates/rmcp/src/transport/sse_server.rs @@ -219,13 +219,9 @@ impl SseServer { .await } pub async fn serve_with_config(config: SseServerConfig) -> io::Result { - let (app, transport_rx) = App::new(config.post_path.clone()); - let listener = tokio::net::TcpListener::bind(config.bind).await?; - let service = Router::new() - .route(&config.sse_path, get(sse_handler)) - .route(&config.post_path, post(post_event_handler)) - .with_state(app); - let ct = config.ct.child_token(); + let (sse_server, service) = Self::new(config); + let listener = tokio::net::TcpListener::bind(sse_server.config.bind).await?; + let ct = sse_server.config.ct.child_token(); let server = axum::serve(listener, service).with_graceful_shutdown(async move { ct.cancelled().await; tracing::info!("sse server cancelled"); @@ -236,13 +232,28 @@ impl SseServer { tracing::error!(error = %e, "sse server shutdown with error"); } } - .instrument(tracing::info_span!("sse-server", bind_address = %config.bind)), + .instrument(tracing::info_span!("sse-server", bind_address = %sse_server.config.bind)), ); - Ok(Self { + Ok(sse_server) + } + + /// Warning: This function creates a new SseServer instance with the provided configuration. + /// `App.post_path` may be incorrect if using `Router` as an embedded router. + pub fn new(config: SseServerConfig) -> (SseServer, Router) { + let (app, transport_rx) = App::new(config.post_path.clone()); + let router = Router::new() + .route(&config.sse_path, get(sse_handler)) + .route(&config.post_path, post(post_event_handler)) + .with_state(app); + + let server = SseServer { transport_rx, config, - }) + }; + + (server, router) } + pub fn with_service(mut self, service_provider: F) -> CancellationToken where S: Service, diff --git a/examples/servers/Cargo.toml b/examples/servers/Cargo.toml index 80407d5c..25008b7b 100644 --- a/examples/servers/Cargo.toml +++ b/examples/servers/Cargo.toml @@ -33,4 +33,8 @@ path = "src/std_io.rs" [[example]] name = "axum" -path = "src/axum.rs" \ No newline at end of file +path = "src/axum.rs" + +[[example]] +name = "axum_router" +path = "src/axum_router.rs" \ No newline at end of file diff --git a/examples/servers/src/axum_router.rs b/examples/servers/src/axum_router.rs new file mode 100644 index 00000000..d4813b85 --- /dev/null +++ b/examples/servers/src/axum_router.rs @@ -0,0 +1,51 @@ +use rmcp::transport::sse_server::{SseServer, SseServerConfig}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +use tracing_subscriber::{self}; +mod common; +use common::counter::Counter; + +const BIND_ADDRESS: &str = "127.0.0.1:8000"; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".to_string().into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let config = SseServerConfig { + bind: BIND_ADDRESS.parse()?, + sse_path: "/sse".to_string(), + post_path: "/message".to_string(), + ct: tokio_util::sync::CancellationToken::new(), + }; + + let (sse_server, router) = SseServer::new(config); + + // Do something with the router, e.g., add routes or middleware + + let listener = tokio::net::TcpListener::bind(sse_server.config.bind).await?; + + let ct = sse_server.config.ct.child_token(); + + let server = axum::serve(listener, router).with_graceful_shutdown(async move { + ct.cancelled().await; + tracing::info!("sse server cancelled"); + }); + + tokio::spawn(async move { + if let Err(e) = server.await { + tracing::error!(error = %e, "sse server shutdown with error"); + } + }); + + let ct = sse_server.with_service(Counter::new); + + tokio::signal::ctrl_c().await?; + ct.cancel(); + Ok(()) +}