Skip to content

Commit 0e2b115

Browse files
authored
fix(server): add error enum while deal server info (#51)
1. wrap the error type for more standardized 2. add more information in error for debug trace 3. wrap helper func for more user-friendly code Signed-off-by: jokemanfire <[email protected]>
1 parent 50fadfb commit 0e2b115

File tree

1 file changed

+93
-43
lines changed

1 file changed

+93
-43
lines changed

crates/rmcp/src/service/server.rs

Lines changed: 93 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
use futures::{SinkExt, StreamExt};
2+
use thiserror::Error;
23

34
use super::*;
45
use crate::model::{
5-
CancelledNotification, CancelledNotificationParam, ClientInfo, ClientNotification,
6-
ClientRequest, ClientResult, CreateMessageRequest, CreateMessageRequestParam,
7-
CreateMessageResult, ListRootsRequest, ListRootsResult, LoggingMessageNotification,
8-
LoggingMessageNotificationParam, ProgressNotification, ProgressNotificationParam,
9-
PromptListChangedNotification, ResourceListChangedNotification, ResourceUpdatedNotification,
10-
ResourceUpdatedNotificationParam, ServerInfo, ServerMessage, ServerNotification, ServerRequest,
11-
ServerResult, ToolListChangedNotification,
6+
CancelledNotification, CancelledNotificationParam, ClientInfo, ClientJsonRpcMessage,
7+
ClientMessage, ClientNotification, ClientRequest, ClientResult, CreateMessageRequest,
8+
CreateMessageRequestParam, CreateMessageResult, ListRootsRequest, ListRootsResult,
9+
LoggingMessageNotification, LoggingMessageNotificationParam, ProgressNotification,
10+
ProgressNotificationParam, PromptListChangedNotification, ResourceListChangedNotification,
11+
ResourceUpdatedNotification, ResourceUpdatedNotificationParam, ServerInfo, ServerMessage,
12+
ServerNotification, ServerRequest, ServerResult, ToolListChangedNotification,
1213
};
1314

1415
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
@@ -26,6 +27,24 @@ impl ServiceRole for RoleServer {
2627
const IS_CLIENT: bool = false;
2728
}
2829

30+
/// It represents the error that may occur when serving the server.
31+
///
32+
/// if you want to handle the error, you can use `serve_server_with_ct` or `serve_server` with `Result<RunningService<RoleServer, S>, ServerError>`
33+
#[derive(Error, Debug)]
34+
pub enum ServerError {
35+
#[error("expect initialized request, but received: {0:?}")]
36+
ExpectedInitRequest(Option<ClientMessage>),
37+
38+
#[error("expect initialized notification, but received: {0:?}")]
39+
ExpectedInitNotification(Option<ClientMessage>),
40+
41+
#[error("connection closed: {0}")]
42+
ConnectionClosed(String),
43+
44+
#[error("IO error: {0}")]
45+
Io(#[from] std::io::Error),
46+
}
47+
2948
pub type ClientSink = Peer<RoleServer>;
3049

3150
impl<S: Service<RoleServer>> ServiceExt<RoleServer> for S {
@@ -55,6 +74,46 @@ where
5574
serve_server_with_ct(service, transport, CancellationToken::new()).await
5675
}
5776

77+
/// Helper function to get the next message from the stream
78+
async fn expect_next_message<S>(stream: &mut S, context: &str) -> Result<ClientMessage, ServerError>
79+
where
80+
S: StreamExt<Item = ClientJsonRpcMessage> + Unpin,
81+
{
82+
Ok(stream
83+
.next()
84+
.await
85+
.ok_or_else(|| ServerError::ConnectionClosed(context.to_string()))?
86+
.into_message())
87+
}
88+
89+
/// Helper function to expect a request from the stream
90+
async fn expect_request<S>(
91+
stream: &mut S,
92+
context: &str,
93+
) -> Result<(ClientRequest, RequestId), ServerError>
94+
where
95+
S: StreamExt<Item = ClientJsonRpcMessage> + Unpin,
96+
{
97+
let msg = expect_next_message(stream, context).await?;
98+
let msg_clone = msg.clone();
99+
msg.into_request()
100+
.ok_or(ServerError::ExpectedInitRequest(Some(msg_clone)))
101+
}
102+
103+
/// Helper function to expect a notification from the stream
104+
async fn expect_notification<S>(
105+
stream: &mut S,
106+
context: &str,
107+
) -> Result<ClientNotification, ServerError>
108+
where
109+
S: StreamExt<Item = ClientJsonRpcMessage> + Unpin,
110+
{
111+
let msg = expect_next_message(stream, context).await?;
112+
let msg_clone = msg.clone();
113+
msg.into_notification()
114+
.ok_or(ServerError::ExpectedInitNotification(Some(msg_clone)))
115+
}
116+
58117
pub async fn serve_server_with_ct<S, T, E, A>(
59118
service: S,
60119
transport: T,
@@ -70,54 +129,45 @@ where
70129
let mut stream = Box::pin(stream);
71130
let id_provider = <Arc<AtomicU32RequestIdProvider>>::default();
72131

73-
// service
74-
let (request, id) = stream
75-
.next()
132+
// Convert ServerError to std::io::Error, then to E
133+
let handle_server_error = |e: ServerError| -> E {
134+
match e {
135+
ServerError::Io(io_err) => io_err.into(),
136+
other => std::io::Error::new(std::io::ErrorKind::Other, format!("{}", other)).into(),
137+
}
138+
};
139+
140+
// Get initialize request
141+
let (request, id) = expect_request(&mut stream, "initialized request")
76142
.await
77-
.ok_or(std::io::Error::new(
78-
std::io::ErrorKind::UnexpectedEof,
79-
"expect initialize request",
80-
))?
81-
.into_message()
82-
.into_request()
83-
.ok_or(std::io::Error::new(
84-
std::io::ErrorKind::InvalidData,
85-
"expect initialize request",
86-
))?;
143+
.map_err(handle_server_error)?;
144+
87145
let ClientRequest::InitializeRequest(peer_info) = request else {
88-
return Err(std::io::Error::new(
89-
std::io::ErrorKind::InvalidData,
90-
"expect initialize request",
91-
)
92-
.into());
146+
return Err(handle_server_error(ServerError::ExpectedInitRequest(Some(
147+
ClientMessage::Request(request, id),
148+
))));
93149
};
150+
151+
// Send initialize response
94152
let init_response = service.get_info();
95153
sink.send(
96154
ServerMessage::Response(ServerResult::InitializeResult(init_response), id)
97155
.into_json_rpc_message(),
98156
)
99157
.await?;
100-
// waiting for notification
101-
let notification = stream
102-
.next()
158+
159+
// Wait for initialize notification
160+
let notification = expect_notification(&mut stream, "initialize notification")
103161
.await
104-
.ok_or(std::io::Error::new(
105-
std::io::ErrorKind::UnexpectedEof,
106-
"expect initialize notification",
107-
))?
108-
.into_message()
109-
.into_notification()
110-
.ok_or(std::io::Error::new(
111-
std::io::ErrorKind::InvalidData,
112-
"expect initialize notification",
113-
))?;
162+
.map_err(handle_server_error)?;
163+
114164
let ClientNotification::InitializedNotification(_) = notification else {
115-
return Err(std::io::Error::new(
116-
std::io::ErrorKind::InvalidData,
117-
"expect initialize notification",
118-
)
119-
.into());
165+
return Err(handle_server_error(ServerError::ExpectedInitNotification(
166+
Some(ClientMessage::Notification(notification)),
167+
)));
120168
};
169+
170+
// Continue processing service
121171
serve_inner(service, (sink, stream), peer_info.params, id_provider, ct).await
122172
}
123173

0 commit comments

Comments
 (0)