Skip to content

Commit 817e62d

Browse files
authored
feat: throw initialize error detail (#192)
1 parent 4dc0cad commit 817e62d

File tree

6 files changed

+115
-93
lines changed

6 files changed

+115
-93
lines changed

crates/rmcp/src/model.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,13 @@ impl Default for ProtocolVersion {
119119
Self::LATEST
120120
}
121121
}
122+
123+
impl std::fmt::Display for ProtocolVersion {
124+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
125+
self.0.fmt(f)
126+
}
127+
}
128+
122129
impl ProtocolVersion {
123130
pub const V_2025_03_26: Self = Self(Cow::Borrowed("2025-03-26"));
124131
pub const V_2024_11_05: Self = Self(Cow::Borrowed("2024-11-05"));

crates/rmcp/src/service.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ pub trait ServiceRole: std::fmt::Debug + Send + Sync + 'static + Copy + Clone {
7777
type PeerNot: TryInto<CancelledNotification, Error = Self::PeerNot>
7878
+ From<CancelledNotification>
7979
+ TransferObject;
80+
type InitializeError<E>;
8081
const IS_CLIENT: bool;
8182
type Info: TransferObject;
8283
type PeerInfo: TransferObject;
@@ -113,7 +114,7 @@ pub trait ServiceExt<R: ServiceRole>: Service<R> + Sized {
113114
fn serve<T, E, A>(
114115
self,
115116
transport: T,
116-
) -> impl Future<Output = Result<RunningService<R, Self>, E>> + Send
117+
) -> impl Future<Output = Result<RunningService<R, Self>, R::InitializeError<E>>> + Send
117118
where
118119
T: IntoTransport<R, E, A>,
119120
E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
@@ -125,7 +126,7 @@ pub trait ServiceExt<R: ServiceRole>: Service<R> + Sized {
125126
self,
126127
transport: T,
127128
ct: CancellationToken,
128-
) -> impl Future<Output = Result<RunningService<R, Self>, E>> + Send
129+
) -> impl Future<Output = Result<RunningService<R, Self>, R::InitializeError<E>>> + Send
129130
where
130131
T: IntoTransport<R, E, A>,
131132
E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
@@ -469,7 +470,7 @@ pub async fn serve_directly<R, S, T, E, A>(
469470
service: S,
470471
transport: T,
471472
peer_info: R::PeerInfo,
472-
) -> Result<RunningService<R, S>, E>
473+
) -> RunningService<R, S>
473474
where
474475
R: ServiceRole,
475476
S: Service<R>,
@@ -485,7 +486,7 @@ pub async fn serve_directly_with_ct<R, S, T, E, A>(
485486
transport: T,
486487
peer_info: R::PeerInfo,
487488
ct: CancellationToken,
488-
) -> Result<RunningService<R, S>, E>
489+
) -> RunningService<R, S>
489490
where
490491
R: ServiceRole,
491492
S: Service<R>,
@@ -503,7 +504,7 @@ async fn serve_inner<R, S, T, E, A>(
503504
peer: Peer<R>,
504505
mut peer_rx: tokio::sync::mpsc::Receiver<PeerSinkMessage<R>>,
505506
ct: CancellationToken,
506-
) -> Result<RunningService<R, S>, E>
507+
) -> RunningService<R, S>
507508
where
508509
R: ServiceRole,
509510
S: Service<R>,
@@ -788,10 +789,10 @@ where
788789
tracing::info!(?quit_reason, "serve finished");
789790
quit_reason
790791
});
791-
Ok(RunningService {
792+
RunningService {
792793
service,
793794
peer: peer_return,
794795
handle,
795796
dg: ct.drop_guard(),
796-
})
797+
}
797798
}

crates/rmcp/src/service/client.rs

Lines changed: 35 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::borrow::Cow;
2+
13
use thiserror::Error;
24

35
use super::*;
@@ -19,7 +21,7 @@ use crate::model::{
1921
///
2022
/// if you want to handle the error, you can use `serve_client_with_ct` or `serve_client` with `Result<RunningService<RoleClient, S>, ClientError>`
2123
#[derive(Error, Debug)]
22-
pub enum ClientError {
24+
pub enum ClientInitializeError<E> {
2325
#[error("expect initialized response, but received: {0:?}")]
2426
ExpectedInitResponse(Option<ServerJsonRpcMessage>),
2527

@@ -32,38 +34,40 @@ pub enum ClientError {
3234
#[error("connection closed: {0}")]
3335
ConnectionClosed(String),
3436

35-
#[error("IO error: {0}")]
36-
Io(#[from] std::io::Error),
37+
#[error("Send message error {error}, when {context}")]
38+
TransportError {
39+
error: E,
40+
context: Cow<'static, str>,
41+
},
3742
}
3843

3944
/// Helper function to get the next message from the stream
40-
async fn expect_next_message<T>(
45+
async fn expect_next_message<T, E>(
4146
transport: &mut T,
4247
context: &str,
43-
) -> Result<ServerJsonRpcMessage, ClientError>
48+
) -> Result<ServerJsonRpcMessage, ClientInitializeError<E>>
4449
where
4550
T: Transport<RoleClient>,
4651
{
4752
transport
4853
.receive()
4954
.await
50-
.ok_or_else(|| ClientError::ConnectionClosed(context.to_string()))
51-
.map_err(|e| ClientError::Io(std::io::Error::new(std::io::ErrorKind::Other, e)))
55+
.ok_or_else(|| ClientInitializeError::ConnectionClosed(context.to_string()))
5256
}
5357

5458
/// Helper function to expect a response from the stream
55-
async fn expect_response<T>(
59+
async fn expect_response<T, E>(
5660
transport: &mut T,
5761
context: &str,
58-
) -> Result<(ServerResult, RequestId), ClientError>
62+
) -> Result<(ServerResult, RequestId), ClientInitializeError<E>>
5963
where
6064
T: Transport<RoleClient>,
6165
{
6266
let msg = expect_next_message(transport, context).await?;
6367

6468
match msg {
6569
ServerJsonRpcMessage::Response(JsonRpcResponse { id, result, .. }) => Ok((result, id)),
66-
_ => Err(ClientError::ExpectedInitResponse(Some(msg))),
70+
_ => Err(ClientInitializeError::ExpectedInitResponse(Some(msg))),
6771
}
6872
}
6973

@@ -79,7 +83,7 @@ impl ServiceRole for RoleClient {
7983
type PeerNot = ServerNotification;
8084
type Info = ClientInfo;
8185
type PeerInfo = ServerInfo;
82-
86+
type InitializeError<E> = ClientInitializeError<E>;
8387
const IS_CLIENT: bool = true;
8488
}
8589

@@ -90,7 +94,7 @@ impl<S: Service<RoleClient>> ServiceExt<RoleClient> for S {
9094
self,
9195
transport: T,
9296
ct: CancellationToken,
93-
) -> impl Future<Output = Result<RunningService<RoleClient, Self>, E>> + Send
97+
) -> impl Future<Output = Result<RunningService<RoleClient, Self>, ClientInitializeError<E>>> + Send
9498
where
9599
T: IntoTransport<RoleClient, E, A>,
96100
E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
@@ -103,7 +107,7 @@ impl<S: Service<RoleClient>> ServiceExt<RoleClient> for S {
103107
pub async fn serve_client<S, T, E, A>(
104108
service: S,
105109
transport: T,
106-
) -> Result<RunningService<RoleClient, S>, E>
110+
) -> Result<RunningService<RoleClient, S>, ClientInitializeError<E>>
107111
where
108112
S: Service<RoleClient>,
109113
T: IntoTransport<RoleClient, E, A>,
@@ -116,7 +120,7 @@ pub async fn serve_client_with_ct<S, T, E, A>(
116120
service: S,
117121
transport: T,
118122
ct: CancellationToken,
119-
) -> Result<RunningService<RoleClient, S>, E>
123+
) -> Result<RunningService<RoleClient, S>, ClientInitializeError<E>>
120124
where
121125
S: Service<RoleClient>,
122126
T: IntoTransport<RoleClient, E, A>,
@@ -125,14 +129,6 @@ where
125129
let mut transport = transport.into_transport();
126130
let id_provider = <Arc<AtomicU32RequestIdProvider>>::default();
127131

128-
// Convert ClientError to std::io::Error, then to E
129-
let handle_client_error = |e: ClientError| -> E {
130-
match e {
131-
ClientError::Io(io_err) => io_err.into(),
132-
other => std::io::Error::new(std::io::ErrorKind::Other, format!("{}", other)).into(),
133-
}
134-
};
135-
136132
// service
137133
let id = id_provider.next_request_id();
138134
let init_request = InitializeRequest {
@@ -145,23 +141,23 @@ where
145141
ClientRequest::InitializeRequest(init_request),
146142
id.clone(),
147143
))
148-
.await?;
149-
150-
let (response, response_id) = expect_response(&mut transport, "initialize response")
151144
.await
152-
.map_err(handle_client_error)?;
145+
.map_err(|error| ClientInitializeError::TransportError {
146+
error,
147+
context: "send initialize request".into(),
148+
})?;
149+
150+
let (response, response_id) = expect_response(&mut transport, "initialize response").await?;
153151

154152
if id != response_id {
155-
return Err(handle_client_error(ClientError::ConflictInitResponseId(
153+
return Err(ClientInitializeError::ConflictInitResponseId(
156154
id,
157155
response_id,
158-
)));
156+
));
159157
}
160158

161159
let ServerResult::InitializeResult(initialize_result) = response else {
162-
return Err(handle_client_error(ClientError::ExpectedInitResult(Some(
163-
response,
164-
))));
160+
return Err(ClientInitializeError::ExpectedInitResult(Some(response)));
165161
};
166162

167163
// send notification
@@ -171,9 +167,15 @@ where
171167
extensions: Default::default(),
172168
}),
173169
);
174-
transport.send(notification).await?;
170+
transport
171+
.send(notification)
172+
.await
173+
.map_err(|error| ClientInitializeError::TransportError {
174+
error,
175+
context: "send initialized notification".into(),
176+
})?;
175177
let (peer, peer_rx) = Peer::new(id_provider, initialize_result);
176-
serve_inner(service, transport, peer, peer_rx, ct).await
178+
Ok(serve_inner(service, transport, peer, peer_rx, ct).await)
177179
}
178180

179181
macro_rules! method {

0 commit comments

Comments
 (0)