Skip to content

Commit 73188ef

Browse files
authored
feat: allow failable service creation in streamable HTTP tower service (#244)
* feat: allow failable service creation in streamable HTTP tower service In our setup we need to lookup some info that may fail before we serve the request. * Fix examples build
1 parent c4f252a commit 73188ef

File tree

4 files changed

+12
-8
lines changed

4 files changed

+12
-8
lines changed

crates/rmcp/src/transport/streamable_http_server/tower.rs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ impl Default for StreamableHttpServerConfig {
4747
pub struct StreamableHttpService<S, M = super::session::local::LocalSessionManager> {
4848
pub config: StreamableHttpServerConfig,
4949
session_manager: Arc<M>,
50-
service_factory: Arc<dyn Fn() -> S + Send + Sync>,
50+
service_factory: Arc<dyn Fn() -> Result<S, std::io::Error> + Send + Sync>,
5151
}
5252

5353
impl<S, M> Clone for StreamableHttpService<S, M> {
@@ -92,7 +92,7 @@ where
9292
M: SessionManager,
9393
{
9494
pub fn new(
95-
service_factory: impl Fn() -> S + Send + Sync + 'static,
95+
service_factory: impl Fn() -> Result<S, std::io::Error> + Send + Sync + 'static,
9696
session_manager: Arc<M>,
9797
config: StreamableHttpServerConfig,
9898
) -> Self {
@@ -102,7 +102,7 @@ where
102102
service_factory: Arc::new(service_factory),
103103
}
104104
}
105-
fn get_service(&self) -> S {
105+
fn get_service(&self) -> Result<S, std::io::Error> {
106106
(self.service_factory)()
107107
}
108108
pub async fn handle<B>(&self, request: Request<B>) -> Response<UnsyncBoxBody<Bytes, Infallible>>
@@ -318,7 +318,9 @@ where
318318
.create_session()
319319
.await
320320
.map_err(internal_error_response("create session"))?;
321-
let service = self.get_service();
321+
let service = self
322+
.get_service()
323+
.map_err(internal_error_response("get service"))?;
322324
// spawn a task to serve the session
323325
tokio::spawn({
324326
let session_manager = self.session_manager.clone();
@@ -372,7 +374,9 @@ where
372374
Ok(response)
373375
}
374376
} else {
375-
let service = self.get_service();
377+
let service = self
378+
.get_service()
379+
.map_err(internal_error_response("get service"))?;
376380
match message {
377381
ClientJsonRpcMessage::Request(request) => {
378382
let (transport, receiver) =

crates/rmcp/tests/test_with_js.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ async fn test_with_js_streamable_http_client() -> anyhow::Result<()> {
9696

9797
let service: StreamableHttpService<Calculator, LocalSessionManager> =
9898
StreamableHttpService::new(
99-
Calculator::default,
99+
|| Ok(Calculator),
100100
Default::default(),
101101
StreamableHttpServerConfig {
102102
stateful_mode: true,

examples/servers/src/counter_hyper_streamable_http.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use rmcp::transport::streamable_http_server::{
1212
#[tokio::main]
1313
async fn main() -> anyhow::Result<()> {
1414
let service = TowerToHyperService::new(StreamableHttpService::new(
15-
Counter::new,
15+
|| Ok(Counter::new()),
1616
LocalSessionManager::default().into(),
1717
Default::default(),
1818
));

examples/servers/src/counter_streamhttp.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ async fn main() -> anyhow::Result<()> {
2222
.init();
2323

2424
let service = StreamableHttpService::new(
25-
Counter::new,
25+
|| Ok(Counter::new()),
2626
LocalSessionManager::default().into(),
2727
Default::default(),
2828
);

0 commit comments

Comments
 (0)