Skip to content

Commit ab3f5e7

Browse files
author
=
committed
trait WithMeta
1 parent 712f3d6 commit ab3f5e7

File tree

9 files changed

+271
-39
lines changed

9 files changed

+271
-39
lines changed

crates/rmcp/src/handler/server/tool.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,11 @@ pub struct ToolCallContext<'service, S> {
6565
impl<'service, S> ToolCallContext<'service, S> {
6666
pub fn new(
6767
service: &'service S,
68-
CallToolRequestParam { name, arguments }: CallToolRequestParam,
68+
CallToolRequestParam {
69+
name,
70+
arguments,
71+
_meta,
72+
}: CallToolRequestParam,
6973
request_context: RequestContext<RoleServer>,
7074
) -> Self {
7175
Self {

crates/rmcp/src/model.rs

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@ use std::{borrow::Cow, sync::Arc};
22
mod annotated;
33
mod capabilities;
44
mod content;
5+
mod meta;
56
mod prompt;
67
mod resource;
78
mod tool;
89

910
pub use annotated::*;
1011
pub use capabilities::*;
1112
pub use content::*;
13+
pub use meta::*;
1214
pub use prompt::*;
1315
pub use resource::*;
1416
use serde::{Deserialize, Serialize};
@@ -190,22 +192,15 @@ impl<'de> Deserialize<'de> for NumberOrString {
190192

191193
pub type RequestId = NumberOrString;
192194
pub type ProgressToken = NumberOrString;
193-
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
194-
pub struct WithMeta<P = JsonObject, M = ()> {
195-
#[serde(skip_serializing_if = "Option::is_none")]
196-
pub _meta: Option<M>,
197-
#[serde(flatten)]
198-
pub inner: P,
199-
}
200195

201196
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
202197
#[serde(rename_all = "camelCase")]
203198
pub struct RequestMeta {
204-
progress_token: ProgressToken,
199+
pub progress_token: ProgressToken,
205200
}
206201

207202
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
208-
pub struct Request<M = String, P = Option<WithMeta<JsonObject, RequestMeta>>> {
203+
pub struct Request<M = String, P = JsonObject> {
209204
pub method: M,
210205
// #[serde(skip_serializing_if = "Option::is_none")]
211206
pub params: P,
@@ -216,7 +211,7 @@ pub struct RequestNoParam<M = String> {
216211
}
217212

218213
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
219-
pub struct Notification<M = String, P = Option<WithMeta<JsonObject, JsonObject>>> {
214+
pub struct Notification<M = String, P = JsonObject> {
220215
pub method: M,
221216
pub params: P,
222217
}
@@ -233,9 +228,9 @@ pub struct JsonRpcRequest<R = Request> {
233228
#[serde(flatten)]
234229
pub request: R,
235230
}
236-
type DefaultResponse = WithMeta<JsonObject, JsonObject>;
231+
type DefaultResponse = JsonObject;
237232
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
238-
pub struct JsonRpcResponse<R = DefaultResponse> {
233+
pub struct JsonRpcResponse<R = JsonObject> {
239234
pub jsonrpc: JsonRpcVersion2_0,
240235
pub id: RequestId,
241236
pub result: R,
@@ -846,6 +841,8 @@ const_string!(CallToolRequestMethod = "tools/call");
846841
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
847842
#[serde(rename_all = "camelCase")]
848843
pub struct CallToolRequestParam {
844+
#[serde(skip_serializing_if = "Option::is_none")]
845+
pub _meta: Option<RequestMeta>,
849846
pub name: Cow<'static, str>,
850847
#[serde(skip_serializing_if = "Option::is_none")]
851848
pub arguments: Option<JsonObject>,
@@ -1041,7 +1038,7 @@ mod tests {
10411038
assert_eq!(r.id, RequestId::Number(1));
10421039
assert_eq!(r.request.method, "request");
10431040
assert_eq!(
1044-
&r.request.params.as_ref().unwrap().inner,
1041+
&r.request.params,
10451042
json!({"key": "value"})
10461043
.as_object()
10471044
.expect("should be an object")

crates/rmcp/src/model/meta.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
use super::{ClientRequest, RequestMeta, ServerRequest};
2+
3+
pub trait WithMeta<M> {
4+
fn set_meta(&mut self, meta: Option<M>);
5+
fn get_meta(&self) -> Option<&M>;
6+
}
7+
8+
impl WithMeta<RequestMeta> for ClientRequest {
9+
fn set_meta(&mut self, meta: Option<RequestMeta>) {
10+
#[allow(clippy::single_match)]
11+
match self {
12+
ClientRequest::CallToolRequest(req) => {
13+
req.params._meta = meta;
14+
}
15+
_ => {}
16+
}
17+
}
18+
19+
fn get_meta(&self) -> Option<&RequestMeta> {
20+
#[allow(clippy::single_match)]
21+
match self {
22+
ClientRequest::CallToolRequest(req) => req.params._meta.as_ref(),
23+
_ => None,
24+
}
25+
}
26+
}
27+
28+
impl WithMeta<RequestMeta> for ServerRequest {
29+
fn set_meta(&mut self, _meta: Option<RequestMeta>) {}
30+
31+
fn get_meta(&self) -> Option<&RequestMeta> {
32+
None
33+
}
34+
}

crates/rmcp/src/service.rs

Lines changed: 59 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use crate::{
66
model::{
77
CancelledNotification, CancelledNotificationParam, JsonRpcBatchRequestItem,
88
JsonRpcBatchResponseItem, JsonRpcError, JsonRpcMessage, JsonRpcNotification,
9-
JsonRpcRequest, JsonRpcResponse, RequestId,
9+
JsonRpcRequest, JsonRpcResponse, ProgressToken, RequestId, RequestMeta, WithMeta,
1010
},
1111
transport::IntoTransport,
1212
};
@@ -58,12 +58,12 @@ impl<T> TransferObject for T where
5858

5959
#[allow(private_bounds, reason = "there's no the third implementation")]
6060
pub trait ServiceRole: std::fmt::Debug + Send + Sync + 'static + Copy + Clone {
61-
type Req: TransferObject;
61+
type Req: TransferObject + WithMeta<RequestMeta>;
6262
type Resp: TransferObject;
6363
type Not: TryInto<CancelledNotification, Error = Self::Not>
6464
+ From<CancelledNotification>
6565
+ TransferObject;
66-
type PeerReq: TransferObject;
66+
type PeerReq: TransferObject + WithMeta<RequestMeta>;
6767
type PeerResp: TransferObject;
6868
type PeerNot: TryInto<CancelledNotification, Error = Self::PeerNot>
6969
+ From<CancelledNotification>
@@ -201,17 +201,30 @@ pub trait RequestIdProvider: Send + Sync + 'static {
201201
fn next_request_id(&self) -> RequestId;
202202
}
203203

204+
pub trait ProgressTokenProvider: Send + Sync + 'static {
205+
fn next_progress_token(&self) -> ProgressToken;
206+
}
207+
208+
pub type AtomicU32RequestIdProvider = AtomicU32Provider;
209+
pub type AtomicU32ProgressTokenProvider = AtomicU32Provider;
210+
204211
#[derive(Debug, Default)]
205-
pub struct AtomicU32RequestIdProvider {
212+
pub struct AtomicU32Provider {
206213
id: AtomicU32,
207214
}
208215

209-
impl RequestIdProvider for AtomicU32RequestIdProvider {
216+
impl RequestIdProvider for AtomicU32Provider {
210217
fn next_request_id(&self) -> RequestId {
211218
RequestId::Number(self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst))
212219
}
213220
}
214221

222+
impl ProgressTokenProvider for AtomicU32Provider {
223+
fn next_progress_token(&self) -> RequestId {
224+
RequestId::Number(self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst))
225+
}
226+
}
227+
215228
type Responder<T> = tokio::sync::oneshot::Sender<T>;
216229

217230
/// A handle to a remote request
@@ -225,6 +238,7 @@ pub struct RequestHandle<R: ServiceRole> {
225238
pub options: PeerRequestOptions,
226239
pub peer: Peer<R>,
227240
pub id: RequestId,
241+
pub progress_token: ProgressToken,
228242
}
229243

230244
impl<R: ServiceRole> RequestHandle<R> {
@@ -275,13 +289,16 @@ impl<R: ServiceRole> RequestHandle<R> {
275289
}
276290

277291
#[derive(Debug)]
278-
pub enum PeerSinkMessage<R: ServiceRole> {
279-
Request(
280-
R::Req,
281-
RequestId,
282-
Responder<Result<R::PeerResp, ServiceError>>,
283-
),
284-
Notification(R::Not, Responder<Result<(), ServiceError>>),
292+
pub(crate) enum PeerSinkMessage<R: ServiceRole> {
293+
Request {
294+
request: R::Req,
295+
id: RequestId,
296+
responder: Responder<Result<R::PeerResp, ServiceError>>,
297+
},
298+
Notification {
299+
notification: R::Not,
300+
responder: Responder<Result<(), ServiceError>>,
301+
},
285302
}
286303

287304
/// An interface to fetch the remote client or server
@@ -293,6 +310,7 @@ pub enum PeerSinkMessage<R: ServiceRole> {
293310
pub struct Peer<R: ServiceRole> {
294311
tx: mpsc::Sender<PeerSinkMessage<R>>,
295312
request_id_provider: Arc<dyn RequestIdProvider>,
313+
progress_token_provider: Arc<dyn ProgressTokenProvider>,
296314
info: Arc<R::PeerInfo>,
297315
}
298316

@@ -320,7 +338,7 @@ impl PeerRequestOptions {
320338

321339
impl<R: ServiceRole> Peer<R> {
322340
const CLIENT_CHANNEL_BUFFER_SIZE: usize = 1024;
323-
pub fn new(
341+
pub(crate) fn new(
324342
request_id_provider: Arc<dyn RequestIdProvider>,
325343
peer_info: R::PeerInfo,
326344
) -> (Peer<R>, ProxyOutbound<R>) {
@@ -329,6 +347,7 @@ impl<R: ServiceRole> Peer<R> {
329347
Self {
330348
tx,
331349
request_id_provider,
350+
progress_token_provider: Arc::new(AtomicU32ProgressTokenProvider::default()),
332351
info: peer_info.into(),
333352
},
334353
rx,
@@ -337,7 +356,10 @@ impl<R: ServiceRole> Peer<R> {
337356
pub async fn send_notification(&self, notification: R::Not) -> Result<(), ServiceError> {
338357
let (responder, receiver) = tokio::sync::oneshot::channel();
339358
self.tx
340-
.send(PeerSinkMessage::Notification(notification, responder))
359+
.send(PeerSinkMessage::Notification {
360+
notification,
361+
responder,
362+
})
341363
.await
342364
.map_err(|_m| ServiceError::Transport(std::io::Error::other("disconnected")))?;
343365
receiver
@@ -352,18 +374,27 @@ impl<R: ServiceRole> Peer<R> {
352374
}
353375
pub async fn send_cancellable_request(
354376
&self,
355-
request: R::Req,
377+
mut request: R::Req,
356378
options: PeerRequestOptions,
357379
) -> Result<RequestHandle<R>, ServiceError> {
358380
let id = self.request_id_provider.next_request_id();
381+
let progress_token = self.progress_token_provider.next_progress_token();
382+
request.set_meta(Some(RequestMeta {
383+
progress_token: progress_token.clone(),
384+
}));
359385
let (responder, receiver) = tokio::sync::oneshot::channel();
360386
self.tx
361-
.send(PeerSinkMessage::Request(request, id.clone(), responder))
387+
.send(PeerSinkMessage::Request {
388+
request,
389+
id: id.clone(),
390+
responder,
391+
})
362392
.await
363393
.map_err(|_m| ServiceError::Transport(std::io::Error::other("disconnected")))?;
364394
Ok(RequestHandle {
365395
id,
366396
rx: receiver,
397+
progress_token,
367398
options,
368399
peer: self.clone(),
369400
})
@@ -419,6 +450,7 @@ pub struct RequestContext<R: ServiceRole> {
419450
/// this token will be cancelled when the [`CancelledNotification`] is received.
420451
pub ct: CancellationToken,
421452
pub id: RequestId,
453+
pub meta: Option<RequestMeta>,
422454
/// An interface to fetch the remote client or server
423455
pub peer: Peer<R>,
424456
}
@@ -459,7 +491,7 @@ async fn serve_inner<R, S, T, E, A>(
459491
mut service: S,
460492
transport: T,
461493
peer_info: R::PeerInfo,
462-
id_provider: Arc<AtomicU32RequestIdProvider>,
494+
id_provider: Arc<AtomicU32Provider>,
463495
ct: CancellationToken,
464496
) -> Result<RunningService<R, S>, E>
465497
where
@@ -555,7 +587,11 @@ where
555587
}
556588
}
557589
}
558-
Event::ProxyMessage(PeerSinkMessage::Request(request, id, responder)) => {
590+
Event::ProxyMessage(PeerSinkMessage::Request {
591+
request,
592+
id,
593+
responder,
594+
}) => {
559595
local_responder_pool.insert(id.clone(), responder);
560596
let send_result = sink
561597
.send(JsonRpcMessage::request(request, id.clone()))
@@ -567,7 +603,10 @@ where
567603
}
568604
}
569605
}
570-
Event::ProxyMessage(PeerSinkMessage::Notification(notification, responder)) => {
606+
Event::ProxyMessage(PeerSinkMessage::Notification {
607+
notification,
608+
responder,
609+
}) => {
571610
// catch cancellation notification
572611
let mut cancellation_param = None;
573612
let notification = match notification.try_into() {
@@ -605,6 +644,7 @@ where
605644
ct: context_ct,
606645
id: id.clone(),
607646
peer: peer.clone(),
647+
meta: request.get_meta().cloned(),
608648
};
609649
tokio::spawn(async move {
610650
let result = service.handle_request(request, context).await;

crates/rmcp/src/service/client.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ where
7373
let (sink, stream) = transport.into_transport();
7474
let mut sink = Box::pin(sink);
7575
let mut stream = Box::pin(stream);
76-
let id_provider = <Arc<AtomicU32RequestIdProvider>>::default();
76+
let id_provider = <Arc<AtomicU32Provider>>::default();
7777
// service
7878
let id = id_provider.next_request_id();
7979
let init_request = InitializeRequest {

crates/rmcp/src/service/server.rs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@ use crate::model::{
66
ClientRequest, ClientResult, CreateMessageRequest, CreateMessageRequestParam,
77
CreateMessageResult, ListRootsRequest, ListRootsResult, LoggingMessageNotification,
88
LoggingMessageNotificationParam, ProgressNotification, ProgressNotificationParam,
9-
PromptListChangedNotification, ProtocolVersion, ResourceListChangedNotification,
10-
ResourceUpdatedNotification, ResourceUpdatedNotificationParam, ServerInfo,
11-
ServerJsonRpcMessage, ServerNotification, ServerRequest, ServerResult,
12-
ToolListChangedNotification,
9+
PromptListChangedNotification, ResourceListChangedNotification, ResourceUpdatedNotification,
10+
ResourceUpdatedNotificationParam, ServerInfo, ServerJsonRpcMessage, ServerNotification,
11+
ServerRequest, ServerResult, ToolListChangedNotification,
1312
};
1413

1514
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
@@ -69,7 +68,7 @@ where
6968
let (sink, stream) = transport.into_transport();
7069
let mut sink = Box::pin(sink);
7170
let mut stream = Box::pin(stream);
72-
let id_provider = <Arc<AtomicU32RequestIdProvider>>::default();
71+
let id_provider = <Arc<AtomicU32Provider>>::default();
7372
// service
7473
let (request, id) = stream
7574
.next()

crates/rmcp/src/transport.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ pub use sse_server::SseServer;
6767
// #[cfg(feature = "transport-ws")]
6868
// pub mod ws;
6969

70+
pub mod streamable_http;
7071
pub trait IntoTransport<R, E, A>: Send + 'static
7172
where
7273
R: ServiceRole,
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pub mod session;

0 commit comments

Comments
 (0)