Skip to content

Commit 209be7b

Browse files
authored
refactor: provide http server as tower service (#228)
* refactor: streamable http server as tower service
1 parent 915bc3f commit 209be7b

File tree

22 files changed

+1828
-1180
lines changed

22 files changed

+1828
-1180
lines changed

crates/rmcp/Cargo.toml

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ reqwest = { version = "0.12", default-features = false, features = [
3939
"json",
4040
"stream",
4141
], optional = true }
42-
sse-stream = { version = "0.2.0", optional = true }
42+
43+
sse-stream = { version = "0.2", optional = true }
44+
4345
http = { version = "1", optional = true }
4446
url = { version = "2.4", optional = true }
4547

@@ -57,7 +59,9 @@ axum = { version = "0.8", features = [], optional = true }
5759
rand = { version = "0.9", optional = true }
5860
tokio-stream = { version = "0.1", optional = true }
5961
uuid = { version = "1", features = ["v4"], optional = true }
60-
62+
http-body = { version = "1", optional = true }
63+
http-body-util = { version = "0.1", optional = true }
64+
bytes = { version = "1", optional = true }
6165
# macro
6266
rmcp-macros = { version = "0.1", workspace = true, optional = true }
6367

@@ -74,7 +78,17 @@ reqwest = ["__reqwest", "reqwest?/rustls-tls"]
7478

7579
reqwest-tls-no-provider = ["__reqwest", "reqwest?/rustls-tls-no-provider"]
7680

77-
axum = ["dep:axum"]
81+
server-side-http = [
82+
"uuid",
83+
"dep:rand",
84+
"dep:tokio-stream",
85+
"dep:http",
86+
"dep:http-body",
87+
"dep:http-body-util",
88+
"dep:bytes",
89+
"dep:sse-stream",
90+
"tower",
91+
]
7892
# SSE client
7993
client-side-sse = ["dep:sse-stream", "dep:http"]
8094

@@ -97,15 +111,12 @@ transport-child-process = [
97111
transport-sse-server = [
98112
"transport-async-rw",
99113
"transport-worker",
100-
"axum",
101-
"dep:rand",
102-
"dep:tokio-stream",
103-
"uuid",
114+
"server-side-http",
115+
"dep:axum",
104116
]
105117
transport-streamable-http-server = [
106118
"transport-streamable-http-server-session",
107-
"axum",
108-
"uuid",
119+
"server-side-http",
109120
]
110121
transport-streamable-http-server-session = [
111122
"transport-async-rw",

crates/rmcp/src/service.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@ pub struct RequestContext<R: ServiceRole> {
488488
}
489489

490490
/// Use this function to skip initialization process
491-
pub async fn serve_directly<R, S, T, E, A>(
491+
pub fn serve_directly<R, S, T, E, A>(
492492
service: S,
493493
transport: T,
494494
peer_info: Option<R::PeerInfo>,
@@ -499,11 +499,11 @@ where
499499
T: IntoTransport<R, E, A>,
500500
E: std::error::Error + Send + Sync + 'static,
501501
{
502-
serve_directly_with_ct(service, transport, peer_info, Default::default()).await
502+
serve_directly_with_ct(service, transport, peer_info, Default::default())
503503
}
504504

505505
/// Use this function to skip initialization process
506-
pub async fn serve_directly_with_ct<R, S, T, E, A>(
506+
pub fn serve_directly_with_ct<R, S, T, E, A>(
507507
service: S,
508508
transport: T,
509509
peer_info: Option<R::PeerInfo>,
@@ -516,11 +516,11 @@ where
516516
E: std::error::Error + Send + Sync + 'static,
517517
{
518518
let (peer, peer_rx) = Peer::new(Arc::new(AtomicU32RequestIdProvider::default()), peer_info);
519-
serve_inner(service, transport, peer, peer_rx, ct).await
519+
serve_inner(service, transport, peer, peer_rx, ct)
520520
}
521521

522522
#[instrument(skip_all)]
523-
async fn serve_inner<R, S, T, E, A>(
523+
fn serve_inner<R, S, T, E, A>(
524524
service: S,
525525
transport: T,
526526
peer: Peer<R>,

crates/rmcp/src/service/client.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ pub async fn serve_client<S, T, E, A>(
111111
where
112112
S: Service<RoleClient>,
113113
T: IntoTransport<RoleClient, E, A>,
114-
E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
114+
E: std::error::Error + Send + Sync + 'static,
115115
{
116116
serve_client_with_ct(service, transport, Default::default()).await
117117
}
@@ -124,7 +124,7 @@ pub async fn serve_client_with_ct<S, T, E, A>(
124124
where
125125
S: Service<RoleClient>,
126126
T: IntoTransport<RoleClient, E, A>,
127-
E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
127+
E: std::error::Error + Send + Sync + 'static,
128128
{
129129
let mut transport = transport.into_transport();
130130
let id_provider = <Arc<AtomicU32RequestIdProvider>>::default();
@@ -175,7 +175,7 @@ where
175175
context: "send initialized notification".into(),
176176
})?;
177177
let (peer, peer_rx) = Peer::new(id_provider, Some(initialize_result));
178-
Ok(serve_inner(service, transport, peer, peer_rx, ct).await)
178+
Ok(serve_inner(service, transport, peer, peer_rx, ct))
179179
}
180180

181181
macro_rules! method {

crates/rmcp/src/service/server.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ impl<S: Service<RoleServer>> ServiceExt<RoleServer> for S {
7070
) -> impl Future<Output = Result<RunningService<RoleServer, Self>, ServerInitializeError<E>>> + Send
7171
where
7272
T: IntoTransport<RoleServer, E, A>,
73-
E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
73+
E: std::error::Error + Send + Sync + 'static,
7474
Self: Sized,
7575
{
7676
serve_server_with_ct(self, transport, ct)
@@ -84,7 +84,7 @@ pub async fn serve_server<S, T, E, A>(
8484
where
8585
S: Service<RoleServer>,
8686
T: IntoTransport<RoleServer, E, A>,
87-
E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
87+
E: std::error::Error + Send + Sync + 'static,
8888
{
8989
serve_server_with_ct(service, transport, CancellationToken::new()).await
9090
}
@@ -143,7 +143,7 @@ pub async fn serve_server_with_ct<S, T, E, A>(
143143
where
144144
S: Service<RoleServer>,
145145
T: IntoTransport<RoleServer, E, A>,
146-
E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
146+
E: std::error::Error + Send + Sync + 'static,
147147
{
148148
let mut transport = transport.into_transport();
149149
let id_provider = <Arc<AtomicU32RequestIdProvider>>::default();
@@ -212,7 +212,7 @@ where
212212
};
213213
let _ = service.handle_notification(notification).await;
214214
// Continue processing service
215-
Ok(serve_inner(service, transport, peer, peer_rx, ct).await)
215+
Ok(serve_inner(service, transport, peer, peer_rx, ct))
216216
}
217217

218218
macro_rules! method {

crates/rmcp/src/transport.rs

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
//! | transport | client | server |
88
//! |:-: |:-: |:-: |
99
//! | std IO | [`child_process::TokioChildProcess`] | [`io::stdio`] |
10-
//! | streamable http | [`streamable_http_client::StreamableHttpClientTransport`] | [`streamable_http_server::session::create_session`] |
10+
//! | streamable http | [`streamable_http_client::StreamableHttpClientTransport`] | [`streamable_http_server::StreamableHttpService`] |
1111
//! | sse | [`sse_client::SseClientTransport`] | [`sse_server::SseServer`] |
1212
//!
1313
//!## Helper Transport Types
@@ -64,6 +64,8 @@
6464
//! }
6565
//! ```
6666
67+
use std::sync::Arc;
68+
6769
use crate::service::{RxJsonRpcMessage, ServiceRole, TxJsonRpcMessage};
6870

6971
pub mod sink_stream;
@@ -122,7 +124,7 @@ pub use auth::{AuthError, AuthorizationManager, AuthorizationSession, Authorized
122124
pub mod streamable_http_server;
123125
#[cfg(feature = "transport-streamable-http-server")]
124126
#[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-server")))]
125-
pub use streamable_http_server::axum::StreamableHttpServer;
127+
pub use streamable_http_server::tower::{StreamableHttpServerConfig, StreamableHttpService};
126128

127129
#[cfg(feature = "transport-streamable-http-client")]
128130
#[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-client")))]
@@ -138,7 +140,7 @@ pub trait Transport<R>: Send
138140
where
139141
R: ServiceRole,
140142
{
141-
type Error;
143+
type Error: std::error::Error + Send + Sync + 'static;
142144
/// Send a message to the transport
143145
///
144146
/// Notice that the future returned by this function should be `Send` and `'static`.
@@ -169,9 +171,73 @@ impl<R, T, E> IntoTransport<R, E, TransportAdapterIdentity> for T
169171
where
170172
T: Transport<R, Error = E> + Send + 'static,
171173
R: ServiceRole,
172-
E: std::error::Error + Send + 'static,
174+
E: std::error::Error + Send + Sync + 'static,
173175
{
174176
fn into_transport(self) -> impl Transport<R, Error = E> + 'static {
175177
self
176178
}
177179
}
180+
181+
/// A transport that can send a single message and then close itself
182+
pub struct OneshotTransport<R>
183+
where
184+
R: ServiceRole,
185+
{
186+
message: Option<RxJsonRpcMessage<R>>,
187+
sender: tokio::sync::mpsc::Sender<TxJsonRpcMessage<R>>,
188+
finished_signal: Arc<tokio::sync::Notify>,
189+
}
190+
191+
impl<R> OneshotTransport<R>
192+
where
193+
R: ServiceRole,
194+
{
195+
pub fn new(
196+
message: RxJsonRpcMessage<R>,
197+
) -> (Self, tokio::sync::mpsc::Receiver<TxJsonRpcMessage<R>>) {
198+
let (sender, receiver) = tokio::sync::mpsc::channel(16);
199+
(
200+
Self {
201+
message: Some(message),
202+
sender,
203+
finished_signal: Arc::new(tokio::sync::Notify::new()),
204+
},
205+
receiver,
206+
)
207+
}
208+
}
209+
210+
impl<R> Transport<R> for OneshotTransport<R>
211+
where
212+
R: ServiceRole,
213+
{
214+
type Error = tokio::sync::mpsc::error::SendError<TxJsonRpcMessage<R>>;
215+
216+
fn send(
217+
&mut self,
218+
item: TxJsonRpcMessage<R>,
219+
) -> impl Future<Output = Result<(), Self::Error>> + Send + 'static {
220+
let sender = self.sender.clone();
221+
let terminate = matches!(item, TxJsonRpcMessage::<R>::Response(_));
222+
let signal = self.finished_signal.clone();
223+
async move {
224+
sender.send(item).await?;
225+
if terminate {
226+
signal.notify_waiters();
227+
}
228+
Ok(())
229+
}
230+
}
231+
232+
async fn receive(&mut self) -> Option<RxJsonRpcMessage<R>> {
233+
if self.message.is_none() {
234+
self.finished_signal.notified().await;
235+
}
236+
self.message.take()
237+
}
238+
239+
fn close(&mut self) -> impl Future<Output = Result<(), Self::Error>> + Send {
240+
self.message.take();
241+
std::future::ready(Ok(()))
242+
}
243+
}

crates/rmcp/src/transport/common.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
feature = "transport-streamable-http-server",
33
feature = "transport-sse-server"
44
))]
5-
pub mod axum;
5+
pub mod server_side_http;
66

77
pub mod http_header;
88

crates/rmcp/src/transport/common/axum.rs

Lines changed: 0 additions & 9 deletions
This file was deleted.

0 commit comments

Comments
 (0)