Skip to content

Commit 47518b3

Browse files
authored
fix: allow SSE server router to be nested (#240)
* fix: allow SSE server router to be nested
1 parent 97bb479 commit 47518b3

File tree

4 files changed

+77
-27
lines changed

4 files changed

+77
-27
lines changed

crates/rmcp/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ path = "tests/test_tool_macros.rs"
151151
[[test]]
152152
name = "test_with_python"
153153
required-features = [
154+
"reqwest",
154155
"server",
155156
"client",
156157
"transport-sse-server",

crates/rmcp/src/transport/sse_server.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
use std::{collections::HashMap, io, net::SocketAddr, sync::Arc, time::Duration};
22

33
use axum::{
4-
Json, Router,
5-
extract::{Query, State},
4+
Extension, Json, Router,
5+
extract::{NestedPath, Query, State},
66
http::{StatusCode, request::Parts},
77
response::{
88
Response,
@@ -84,6 +84,7 @@ async fn post_event_handler(
8484

8585
async fn sse_handler(
8686
State(app): State<App>,
87+
nested_path: Option<Extension<NestedPath>>,
8788
parts: Parts,
8889
) -> Result<Sse<impl Stream<Item = Result<Event, io::Error>>>, Response<String>> {
8990
let session = session_id();
@@ -115,12 +116,13 @@ async fn sse_handler(
115116
*response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
116117
return Err(response);
117118
}
119+
let nested_path = nested_path.as_deref().map(NestedPath::as_str).unwrap_or("");
118120
let post_path = app.post_path.as_ref();
119121
let ping_interval = app.sse_ping_interval;
120122
let stream = futures::stream::once(futures::future::ok(
121123
Event::default()
122124
.event("endpoint")
123-
.data(format!("{post_path}?sessionId={session}")),
125+
.data(format!("{nested_path}{post_path}?sessionId={session}")),
124126
))
125127
.chain(ReceiverStream::new(to_client_rx).map(|message| {
126128
match serde_json::to_string(&message) {
@@ -257,8 +259,6 @@ impl SseServer {
257259
Ok(sse_server)
258260
}
259261

260-
/// Warning: This function creates a new SseServer instance with the provided configuration.
261-
/// `App.post_path` may be incorrect if using `Router` as an embedded router.
262262
pub fn new(config: SseServerConfig) -> (SseServer, Router) {
263263
let (app, transport_rx) = App::new(
264264
config.post_path.clone(),

crates/rmcp/tests/test_with_python.rs

Lines changed: 67 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1+
use axum::Router;
12
use rmcp::{
23
ServiceExt,
3-
transport::{ConfigureCommandExt, SseServer, TokioChildProcess},
4+
transport::{ConfigureCommandExt, SseServer, TokioChildProcess, sse_server::SseServerConfig},
45
};
6+
use tokio::time::timeout;
7+
use tokio_util::sync::CancellationToken;
58
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
69
mod common;
710
use common::calculator::Calculator;
811

9-
const BIND_ADDRESS: &str = "127.0.0.1:8000";
10-
11-
#[tokio::test]
12-
async fn test_with_python_client() -> anyhow::Result<()> {
12+
async fn init() -> anyhow::Result<()> {
1313
let _ = tracing_subscriber::registry()
1414
.with(
1515
tracing_subscriber::EnvFilter::try_from_default_env()
@@ -23,6 +23,14 @@ async fn test_with_python_client() -> anyhow::Result<()> {
2323
.spawn()?
2424
.wait()
2525
.await?;
26+
Ok(())
27+
}
28+
29+
#[tokio::test]
30+
async fn test_with_python_client() -> anyhow::Result<()> {
31+
init().await?;
32+
33+
const BIND_ADDRESS: &str = "127.0.0.1:8000";
2634

2735
let ct = SseServer::serve(BIND_ADDRESS.parse()?)
2836
.await?
@@ -31,6 +39,7 @@ async fn test_with_python_client() -> anyhow::Result<()> {
3139
let status = tokio::process::Command::new("uv")
3240
.arg("run")
3341
.arg("tests/test_with_python/client.py")
42+
.arg(format!("http://{BIND_ADDRESS}/sse"))
3443
.spawn()?
3544
.wait()
3645
.await?;
@@ -39,21 +48,61 @@ async fn test_with_python_client() -> anyhow::Result<()> {
3948
Ok(())
4049
}
4150

51+
/// Test the SSE server in a nested Axum router.
52+
#[tokio::test]
53+
async fn test_nested_with_python_client() -> anyhow::Result<()> {
54+
init().await?;
55+
56+
const BIND_ADDRESS: &str = "127.0.0.1:8001";
57+
58+
// Create an SSE router
59+
let sse_config = SseServerConfig {
60+
bind: BIND_ADDRESS.parse()?,
61+
sse_path: "/sse".to_string(),
62+
post_path: "/message".to_string(),
63+
ct: CancellationToken::new(),
64+
sse_keep_alive: None,
65+
};
66+
67+
let listener = tokio::net::TcpListener::bind(&sse_config.bind).await?;
68+
69+
let (sse_server, sse_router) = SseServer::new(sse_config);
70+
let ct = sse_server.with_service(Calculator::default);
71+
72+
let main_router = Router::new().nest("/nested", sse_router);
73+
74+
let server_ct = ct.clone();
75+
let server = axum::serve(listener, main_router).with_graceful_shutdown(async move {
76+
server_ct.cancelled().await;
77+
tracing::info!("sse server cancelled");
78+
});
79+
80+
tokio::spawn(async move {
81+
let _ = server.await;
82+
tracing::info!("sse server shutting down");
83+
});
84+
85+
// Spawn the process with timeout, as failure to access the '/message' URL
86+
// causes the client to never exit.
87+
let status = timeout(
88+
tokio::time::Duration::from_secs(5),
89+
tokio::process::Command::new("uv")
90+
.arg("run")
91+
.arg("tests/test_with_python/client.py")
92+
.arg(format!("http://{BIND_ADDRESS}/nested/sse"))
93+
.spawn()?
94+
.wait(),
95+
)
96+
.await?;
97+
assert!(status?.success());
98+
ct.cancel();
99+
Ok(())
100+
}
101+
42102
#[tokio::test]
43103
async fn test_with_python_server() -> anyhow::Result<()> {
44-
let _ = tracing_subscriber::registry()
45-
.with(
46-
tracing_subscriber::EnvFilter::try_from_default_env()
47-
.unwrap_or_else(|_| "debug".to_string().into()),
48-
)
49-
.with(tracing_subscriber::fmt::layer())
50-
.try_init();
51-
tokio::process::Command::new("uv")
52-
.args(["pip", "install", "-r", "pyproject.toml"])
53-
.current_dir("tests/test_with_python")
54-
.spawn()?
55-
.wait()
56-
.await?;
104+
init().await?;
105+
57106
let transport = TokioChildProcess::new(tokio::process::Command::new("uv").configure(|cmd| {
58107
cmd.arg("run").arg("tests/test_with_python/server.py");
59108
}))?;

crates/rmcp/tests/test_with_python/client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from mcp import ClientSession, StdioServerParameters, types
22
from mcp.client.sse import sse_client
3-
4-
3+
import sys
54

65
async def run():
7-
async with sse_client("http://localhost:8000/sse") as (read, write):
6+
url = sys.argv[1]
7+
async with sse_client(url) as (read, write):
88
async with ClientSession(
99
read, write
1010
) as session:
@@ -25,4 +25,4 @@ async def run():
2525
if __name__ == "__main__":
2626
import asyncio
2727

28-
asyncio.run(run())
28+
asyncio.run(run())

0 commit comments

Comments
 (0)