Skip to content

Commit 9e803d5

Browse files
author
=
committed
feat: support version negotiation
1 parent 9f21092 commit 9e803d5

File tree

2 files changed

+44
-6
lines changed

2 files changed

+44
-6
lines changed

crates/rmcp/src/model.rs

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ pub fn object(value: serde_json::Value) -> JsonObject {
3030
}
3131
}
3232

33-
3433
/// Use this macro just like [`serde_json::json!`]
3534
#[cfg(feature = "macros")]
3635
#[macro_export]
@@ -103,6 +102,23 @@ impl ProtocolVersion {
103102
pub const LATEST: Self = Self::V_2025_03_26;
104103
}
105104

105+
impl PartialOrd for ProtocolVersion {
106+
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
107+
fn parse(s: &str) -> Option<(u16, u16, u16)> {
108+
let (s_year, rest) = s.split_once('-')?;
109+
let (s_month, rest) = rest.split_once('-')?;
110+
let s_day = rest;
111+
let year = s_year.parse::<u16>().ok()?;
112+
let month = s_month.parse::<u16>().ok()?;
113+
let day = s_day.parse::<u16>().ok()?;
114+
Some((year, month, day))
115+
}
116+
let self_date = parse(self.0.as_ref())?;
117+
let other_date = parse(other.0.as_ref())?;
118+
Some(self_date.cmp(&other_date))
119+
}
120+
}
121+
106122
impl Serialize for ProtocolVersion {
107123
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
108124
where
@@ -1135,4 +1151,11 @@ mod tests {
11351151

11361152
assert_eq!(server_response_json, raw_response_json);
11371153
}
1154+
1155+
#[test]
1156+
fn test_protocol_version_order() {
1157+
let v1 = ProtocolVersion::V_2024_11_05;
1158+
let v2 = ProtocolVersion::V_2025_03_26;
1159+
assert!(v1 < v2);
1160+
}
11381161
}

crates/rmcp/src/service/server.rs

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@ use crate::model::{
66
ClientRequest, ClientResult, CreateMessageRequest, CreateMessageRequestParam,
77
CreateMessageResult, ListRootsRequest, ListRootsResult, LoggingMessageNotification,
88
LoggingMessageNotificationParam, ProgressNotification, ProgressNotificationParam,
9-
PromptListChangedNotification, ResourceListChangedNotification, ResourceUpdatedNotification,
10-
ResourceUpdatedNotificationParam, ServerInfo, ServerJsonRpcMessage, ServerNotification,
11-
ServerRequest, ServerResult, ToolListChangedNotification,
9+
PromptListChangedNotification, ProtocolVersion, ResourceListChangedNotification,
10+
ResourceUpdatedNotification, ResourceUpdatedNotificationParam, ServerInfo,
11+
ServerJsonRpcMessage, ServerNotification, ServerRequest, ServerResult,
12+
ToolListChangedNotification,
1213
};
1314

1415
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
@@ -65,11 +66,11 @@ where
6566
T: IntoTransport<RoleServer, E, A>,
6667
E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
6768
{
69+
const SUPPORTED_HIGHEST_VERSION: ProtocolVersion = ProtocolVersion::LATEST;
6870
let (sink, stream) = transport.into_transport();
6971
let mut sink = Box::pin(sink);
7072
let mut stream = Box::pin(stream);
7173
let id_provider = <Arc<AtomicU32RequestIdProvider>>::default();
72-
7374
// service
7475
let (request, id) = stream
7576
.next()
@@ -90,7 +91,21 @@ where
9091
)
9192
.into());
9293
};
93-
let init_response = service.get_info();
94+
95+
let protocol_version = match peer_info
96+
.params
97+
.protocol_version
98+
.partial_cmp(&SUPPORTED_HIGHEST_VERSION)
99+
.ok_or(std::io::Error::new(
100+
std::io::ErrorKind::InvalidData,
101+
"unsupported protocol version",
102+
))? {
103+
std::cmp::Ordering::Less => peer_info.params.protocol_version.clone(),
104+
_ => SUPPORTED_HIGHEST_VERSION,
105+
};
106+
107+
let mut init_response = service.get_info();
108+
init_response.protocol_version = protocol_version;
94109
sink.send(ServerJsonRpcMessage::response(
95110
ServerResult::InitializeResult(init_response),
96111
id,

0 commit comments

Comments
 (0)