From b4cda2f6d2c696adbed17b8c3aa289d00e18b888 Mon Sep 17 00:00:00 2001 From: Eitan Yarmush Date: Mon, 7 Apr 2025 16:49:50 +0000 Subject: [PATCH] feat: extensions to context --- crates/rmcp/src/model.rs | 17 +++++++++++++++++ crates/rmcp/src/model/meta.rs | 11 ++++++++--- crates/rmcp/src/service.rs | 14 ++++++++------ crates/rmcp/tests/test_message_protocol.rs | 10 ++++++++++ 4 files changed, 43 insertions(+), 9 deletions(-) diff --git a/crates/rmcp/src/model.rs b/crates/rmcp/src/model.rs index 34df47fb..7ff035ee 100644 --- a/crates/rmcp/src/model.rs +++ b/crates/rmcp/src/model.rs @@ -200,6 +200,15 @@ pub struct Request { pub extensions: Extensions, } +impl GetExtensions for Request { + fn extensions(&self) -> &Extensions { + &self.extensions + } + fn extensions_mut(&mut self) -> &mut Extensions { + &mut self.extensions + } +} + #[derive(Debug, Clone)] pub struct RequestOptionalParam { pub method: M, @@ -220,6 +229,14 @@ pub struct RequestNoParam { pub extensions: Extensions, } +impl GetExtensions for RequestNoParam { + fn extensions(&self) -> &Extensions { + &self.extensions + } + fn extensions_mut(&mut self) -> &mut Extensions { + &mut self.extensions + } +} #[derive(Debug, Clone)] pub struct Notification { pub method: M, diff --git a/crates/rmcp/src/model/meta.rs b/crates/rmcp/src/model/meta.rs index 5c978896..010dd8ca 100644 --- a/crates/rmcp/src/model/meta.rs +++ b/crates/rmcp/src/model/meta.rs @@ -13,21 +13,26 @@ pub trait GetMeta { fn get_meta(&self) -> &Meta; } +pub trait GetExtensions { + fn extensions(&self) -> &Extensions; + fn extensions_mut(&mut self) -> &mut Extensions; +} + macro_rules! variant_extension { ( $Enum: ident { $($variant: ident)* } ) => { - impl $Enum { - pub fn extensions(&self) -> &Extensions { + impl GetExtensions for $Enum { + fn extensions(&self) -> &Extensions { match self { $( $Enum::$variant(v) => &v.extensions, )* } } - pub fn extensions_mut(&mut self) -> &mut Extensions { + fn extensions_mut(&mut self) -> &mut Extensions { match self { $( $Enum::$variant(v) => &mut v.extensions, diff --git a/crates/rmcp/src/service.rs b/crates/rmcp/src/service.rs index b23c61fa..696947b6 100644 --- a/crates/rmcp/src/service.rs +++ b/crates/rmcp/src/service.rs @@ -4,10 +4,10 @@ use thiserror::Error; use crate::{ error::Error as McpError, model::{ - CancelledNotification, CancelledNotificationParam, GetMeta, JsonRpcBatchRequestItem, - JsonRpcBatchResponseItem, JsonRpcError, JsonRpcMessage, JsonRpcNotification, - JsonRpcRequest, JsonRpcResponse, Meta, NumberOrString, ProgressToken, RequestId, - ServerJsonRpcMessage, + CancelledNotification, CancelledNotificationParam, Extensions, GetExtensions, GetMeta, + JsonRpcBatchRequestItem, JsonRpcBatchResponseItem, JsonRpcError, JsonRpcMessage, + JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, Meta, NumberOrString, ProgressToken, + RequestId, ServerJsonRpcMessage, }, transport::IntoTransport, }; @@ -59,12 +59,12 @@ impl TransferObject for T where #[allow(private_bounds, reason = "there's no the third implementation")] pub trait ServiceRole: std::fmt::Debug + Send + Sync + 'static + Copy + Clone { - type Req: TransferObject + GetMeta; + type Req: TransferObject + GetMeta + GetExtensions; type Resp: TransferObject; type Not: TryInto + From + TransferObject; - type PeerReq: TransferObject + GetMeta; + type PeerReq: TransferObject + GetMeta + GetExtensions; type PeerResp: TransferObject; type PeerNot: TryInto + From @@ -471,6 +471,7 @@ pub struct RequestContext { pub ct: CancellationToken, pub id: RequestId, pub meta: Meta, + pub extensions: Extensions, /// An interface to fetch the remote client or server pub peer: Peer, } @@ -667,6 +668,7 @@ where id: id.clone(), peer: peer.clone(), meta: request.get_meta().clone(), + extensions: request.extensions().clone(), }; tokio::spawn(async move { let result = service.handle_request(request, context).await; diff --git a/crates/rmcp/tests/test_message_protocol.rs b/crates/rmcp/tests/test_message_protocol.rs index d59a5563..602f93da 100644 --- a/crates/rmcp/tests/test_message_protocol.rs +++ b/crates/rmcp/tests/test_message_protocol.rs @@ -71,6 +71,7 @@ async fn test_context_inclusion_integration() -> anyhow::Result<()> { ct: CancellationToken::new(), id: NumberOrString::Number(1), meta: Default::default(), + extensions: Default::default(), }, ) .await?; @@ -112,6 +113,7 @@ async fn test_context_inclusion_integration() -> anyhow::Result<()> { ct: CancellationToken::new(), id: NumberOrString::Number(2), meta: Default::default(), + extensions: Default::default(), }, ) .await?; @@ -153,6 +155,7 @@ async fn test_context_inclusion_integration() -> anyhow::Result<()> { ct: CancellationToken::new(), id: NumberOrString::Number(3), meta: Default::default(), + extensions: Default::default(), }, ) .await?; @@ -214,6 +217,7 @@ async fn test_context_inclusion_ignored_integration() -> anyhow::Result<()> { ct: CancellationToken::new(), id: NumberOrString::Number(1), meta: Meta::default(), + extensions: Default::default(), }, ) .await?; @@ -280,6 +284,7 @@ async fn test_message_sequence_integration() -> anyhow::Result<()> { ct: CancellationToken::new(), id: NumberOrString::Number(1), meta: Meta::default(), + extensions: Default::default(), }, ) .await?; @@ -354,6 +359,7 @@ async fn test_message_sequence_validation_integration() -> anyhow::Result<()> { ct: CancellationToken::new(), id: NumberOrString::Number(1), meta: Meta::default(), + extensions: Default::default(), }, ) .await?; @@ -387,6 +393,7 @@ async fn test_message_sequence_validation_integration() -> anyhow::Result<()> { ct: CancellationToken::new(), id: NumberOrString::Number(2), meta: Meta::default(), + extensions: Default::default(), }, ) .await; @@ -439,6 +446,7 @@ async fn test_selective_context_handling_integration() -> anyhow::Result<()> { ct: CancellationToken::new(), id: NumberOrString::Number(1), meta: Meta::default(), + extensions: Default::default(), }, ) .await?; @@ -478,6 +486,7 @@ async fn test_selective_context_handling_integration() -> anyhow::Result<()> { ct: CancellationToken::new(), id: NumberOrString::Number(2), meta: Meta::default(), + extensions: Default::default(), }, ) .await?; @@ -534,6 +543,7 @@ async fn test_context_inclusion() -> anyhow::Result<()> { ct: CancellationToken::new(), id: NumberOrString::Number(1), meta: Meta::default(), + extensions: Default::default(), }, ) .await?;