Skip to content

Commit e0faf1e

Browse files
authored
feat: add support for custom requests (#590)
#580 and #556 introduced support for custom notifications, so this PR takes the next logical step and adds support for custom requests: - Introduces `CustomRequest` and `CustomResult` model types, wires them into the client/server request and result unions, and allows `ClientRequest::method()` to return the dynamic method name. - Implements serde and meta handling for `CustomRequest` so `_meta` is carried through extensions; adds default `on_custom_request` handlers that return `METHOD_NOT_FOUND` unless overridden. - Updates JSON schema fixtures to include the new request/result shapes and `EmptyObject` strictness. - Adds tests for custom request roundtrips and end-to-end client↔server handling. - Focused integration test in `crates/rmcp/tests/test_custom_request.rs`. For additional testing, I used this locally to update Codex to use a custom request instead of a custom notification so that it gets an "ack" from the MCP server to ensure it has processed the update before sending more messages: openai/codex#8142.
1 parent 2e3cc4a commit e0faf1e

File tree

10 files changed

+524
-38
lines changed

10 files changed

+524
-38
lines changed

crates/rmcp/src/handler/client.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ impl<H: ClientHandler> Service<RoleClient> for H {
2626
.create_elicitation(request.params, context)
2727
.await
2828
.map(ClientResult::CreateElicitationResult),
29+
ServerRequest::CustomRequest(request) => self
30+
.on_custom_request(request, context)
31+
.await
32+
.map(ClientResult::CustomResult),
2933
}
3034
}
3135

@@ -123,6 +127,20 @@ pub trait ClientHandler: Sized + Send + Sync + 'static {
123127
}))
124128
}
125129

130+
fn on_custom_request(
131+
&self,
132+
request: CustomRequest,
133+
context: RequestContext<RoleClient>,
134+
) -> impl Future<Output = Result<CustomResult, McpError>> + Send + '_ {
135+
let CustomRequest { method, .. } = request;
136+
let _ = context;
137+
std::future::ready(Err(McpError::new(
138+
ErrorCode::METHOD_NOT_FOUND,
139+
method,
140+
None,
141+
)))
142+
}
143+
126144
fn on_cancelled(
127145
&self,
128146
params: CancelledNotificationParam,

crates/rmcp/src/handler/server.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ impl<H: ServerHandler> Service<RoleServer> for H {
6969
.list_tools(request.params, context)
7070
.await
7171
.map(ServerResult::ListToolsResult),
72+
ClientRequest::CustomRequest(request) => self
73+
.on_custom_request(request, context)
74+
.await
75+
.map(ServerResult::CustomResult),
7276
}
7377
}
7478

@@ -200,6 +204,19 @@ pub trait ServerHandler: Sized + Send + Sync + 'static {
200204
) -> impl Future<Output = Result<ListToolsResult, McpError>> + Send + '_ {
201205
std::future::ready(Ok(ListToolsResult::default()))
202206
}
207+
fn on_custom_request(
208+
&self,
209+
request: CustomRequest,
210+
context: RequestContext<RoleServer>,
211+
) -> impl Future<Output = Result<CustomResult, McpError>> + Send + '_ {
212+
let CustomRequest { method, .. } = request;
213+
let _ = context;
214+
std::future::ready(Err(McpError::new(
215+
ErrorCode::METHOD_NOT_FOUND,
216+
method,
217+
None,
218+
)))
219+
}
203220

204221
fn on_cancelled(
205222
&self,

crates/rmcp/src/model.rs

Lines changed: 99 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ macro_rules! object {
5555
///
5656
/// without returning any specific data.
5757
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Copy, Eq)]
58+
#[serde(deny_unknown_fields)]
5859
#[cfg_attr(feature = "server", derive(schemars::JsonSchema))]
5960
pub struct EmptyObject {}
6061

@@ -606,6 +607,23 @@ impl From<EmptyResult> for () {
606607
fn from(_value: EmptyResult) {}
607608
}
608609

610+
/// A catch-all response either side can use for custom requests.
611+
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
612+
#[serde(transparent)]
613+
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
614+
pub struct CustomResult(pub Value);
615+
616+
impl CustomResult {
617+
pub fn new(result: Value) -> Self {
618+
Self(result)
619+
}
620+
621+
/// Deserialize the result into a strongly-typed structure.
622+
pub fn result_as<T: DeserializeOwned>(&self) -> Result<T, serde_json::Error> {
623+
serde_json::from_value(self.0.clone())
624+
}
625+
}
626+
609627
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
610628
#[serde(rename_all = "camelCase")]
611629
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
@@ -661,6 +679,40 @@ impl CustomNotification {
661679
}
662680
}
663681

682+
/// A catch-all request either side can use to send custom messages to its peer.
683+
///
684+
/// This preserves the raw `method` name and `params` payload so handlers can
685+
/// deserialize them into domain-specific types.
686+
#[derive(Debug, Clone)]
687+
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
688+
pub struct CustomRequest {
689+
pub method: String,
690+
pub params: Option<Value>,
691+
/// extensions will carry anything possible in the context, including [`Meta`]
692+
///
693+
/// this is similar with the Extensions in `http` crate
694+
#[cfg_attr(feature = "schemars", schemars(skip))]
695+
pub extensions: Extensions,
696+
}
697+
698+
impl CustomRequest {
699+
pub fn new(method: impl Into<String>, params: Option<Value>) -> Self {
700+
Self {
701+
method: method.into(),
702+
params,
703+
extensions: Extensions::default(),
704+
}
705+
}
706+
707+
/// Deserialize `params` into a strongly-typed structure.
708+
pub fn params_as<T: DeserializeOwned>(&self) -> Result<Option<T>, serde_json::Error> {
709+
self.params
710+
.as_ref()
711+
.map(|params| serde_json::from_value(params.clone()))
712+
.transpose()
713+
}
714+
}
715+
664716
const_string!(InitializeResultMethod = "initialize");
665717
/// # Initialization
666718
/// This request is sent from the client to the server when it first connects, asking it to begin initialization.
@@ -1757,11 +1809,12 @@ ts_union!(
17571809
| SubscribeRequest
17581810
| UnsubscribeRequest
17591811
| CallToolRequest
1760-
| ListToolsRequest;
1812+
| ListToolsRequest
1813+
| CustomRequest;
17611814
);
17621815

17631816
impl ClientRequest {
1764-
pub fn method(&self) -> &'static str {
1817+
pub fn method(&self) -> &str {
17651818
match &self {
17661819
ClientRequest::PingRequest(r) => r.method.as_str(),
17671820
ClientRequest::InitializeRequest(r) => r.method.as_str(),
@@ -1776,6 +1829,7 @@ impl ClientRequest {
17761829
ClientRequest::UnsubscribeRequest(r) => r.method.as_str(),
17771830
ClientRequest::CallToolRequest(r) => r.method.as_str(),
17781831
ClientRequest::ListToolsRequest(r) => r.method.as_str(),
1832+
ClientRequest::CustomRequest(r) => r.method.as_str(),
17791833
}
17801834
}
17811835
}
@@ -1790,7 +1844,12 @@ ts_union!(
17901844
);
17911845

17921846
ts_union!(
1793-
export type ClientResult = box CreateMessageResult | ListRootsResult | CreateElicitationResult | EmptyResult;
1847+
export type ClientResult =
1848+
box CreateMessageResult
1849+
| ListRootsResult
1850+
| CreateElicitationResult
1851+
| EmptyResult
1852+
| CustomResult;
17941853
);
17951854

17961855
impl ClientResult {
@@ -1806,7 +1865,8 @@ ts_union!(
18061865
| PingRequest
18071866
| CreateMessageRequest
18081867
| ListRootsRequest
1809-
| CreateElicitationRequest;
1868+
| CreateElicitationRequest
1869+
| CustomRequest;
18101870
);
18111871

18121872
ts_union!(
@@ -1834,6 +1894,7 @@ ts_union!(
18341894
| ListToolsResult
18351895
| CreateElicitationResult
18361896
| EmptyResult
1897+
| CustomResult
18371898
;
18381899
);
18391900

@@ -1960,6 +2021,40 @@ mod tests {
19602021
assert_eq!(json, raw);
19612022
}
19622023

2024+
#[test]
2025+
fn test_custom_request_roundtrip() {
2026+
let raw = json!( {
2027+
"jsonrpc": JsonRpcVersion2_0,
2028+
"id": 42,
2029+
"method": "requests/custom",
2030+
"params": {"foo": "bar"},
2031+
});
2032+
2033+
let message: ClientJsonRpcMessage =
2034+
serde_json::from_value(raw.clone()).expect("invalid request");
2035+
match &message {
2036+
ClientJsonRpcMessage::Request(JsonRpcRequest { id, request, .. }) => {
2037+
assert_eq!(id, &RequestId::Number(42));
2038+
match request {
2039+
ClientRequest::CustomRequest(custom) => {
2040+
let expected_request = json!({
2041+
"method": "requests/custom",
2042+
"params": {"foo": "bar"},
2043+
});
2044+
let actual_request =
2045+
serde_json::to_value(custom).expect("serialize custom request");
2046+
assert_eq!(actual_request, expected_request);
2047+
}
2048+
other => panic!("Expected custom request, got: {other:?}"),
2049+
}
2050+
}
2051+
other => panic!("Expected request, got: {other:?}"),
2052+
}
2053+
2054+
let json = serde_json::to_value(message).expect("valid json");
2055+
assert_eq!(json, raw);
2056+
}
2057+
19632058
#[test]
19642059
fn test_request_conversion() {
19652060
let raw = json!( {

crates/rmcp/src/model/meta.rs

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ use serde::{Deserialize, Serialize};
44
use serde_json::Value;
55

66
use super::{
7-
ClientNotification, ClientRequest, CustomNotification, Extensions, JsonObject, JsonRpcMessage,
8-
NumberOrString, ProgressToken, ServerNotification, ServerRequest,
7+
ClientNotification, ClientRequest, CustomNotification, CustomRequest, Extensions, JsonObject,
8+
JsonRpcMessage, NumberOrString, ProgressToken, ServerNotification, ServerRequest,
99
};
1010

1111
pub trait GetMeta {
@@ -38,6 +38,26 @@ impl GetMeta for CustomNotification {
3838
}
3939
}
4040

41+
impl GetExtensions for CustomRequest {
42+
fn extensions(&self) -> &Extensions {
43+
&self.extensions
44+
}
45+
fn extensions_mut(&mut self) -> &mut Extensions {
46+
&mut self.extensions
47+
}
48+
}
49+
50+
impl GetMeta for CustomRequest {
51+
fn get_meta_mut(&mut self) -> &mut Meta {
52+
self.extensions_mut().get_or_insert_default()
53+
}
54+
fn get_meta(&self) -> &Meta {
55+
self.extensions()
56+
.get::<Meta>()
57+
.unwrap_or(Meta::static_empty())
58+
}
59+
}
60+
4161
macro_rules! variant_extension {
4262
(
4363
$Enum: ident {
@@ -86,6 +106,7 @@ variant_extension! {
86106
UnsubscribeRequest
87107
CallToolRequest
88108
ListToolsRequest
109+
CustomRequest
89110
}
90111
}
91112

@@ -95,6 +116,7 @@ variant_extension! {
95116
CreateMessageRequest
96117
ListRootsRequest
97118
CreateElicitationRequest
119+
CustomRequest
98120
}
99121
}
100122

crates/rmcp/src/model/serde_impl.rs

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ use std::borrow::Cow;
33
use serde::{Deserialize, Serialize};
44

55
use super::{
6-
CustomNotification, Extensions, Meta, Notification, NotificationNoParam, Request,
7-
RequestNoParam, RequestOptionalParam,
6+
CustomNotification, CustomRequest, Extensions, Meta, Notification, NotificationNoParam,
7+
Request, RequestNoParam, RequestOptionalParam,
88
};
99
#[derive(Serialize, Deserialize)]
1010
struct WithMeta<'a, P> {
@@ -249,6 +249,59 @@ where
249249
}
250250
}
251251

252+
impl Serialize for CustomRequest {
253+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
254+
where
255+
S: serde::Serializer,
256+
{
257+
let extensions = &self.extensions;
258+
let _meta = extensions.get::<Meta>().map(Cow::Borrowed);
259+
let params = self.params.as_ref();
260+
261+
let params = if _meta.is_some() || params.is_some() {
262+
Some(WithMeta {
263+
_meta,
264+
_rest: &self.params,
265+
})
266+
} else {
267+
None
268+
};
269+
270+
ProxyOptionalParam::serialize(
271+
&ProxyOptionalParam {
272+
method: &self.method,
273+
params,
274+
},
275+
serializer,
276+
)
277+
}
278+
}
279+
280+
impl<'de> Deserialize<'de> for CustomRequest {
281+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
282+
where
283+
D: serde::Deserializer<'de>,
284+
{
285+
let body =
286+
ProxyOptionalParam::<'_, _, Option<serde_json::Value>>::deserialize(deserializer)?;
287+
let mut params = None;
288+
let mut _meta = None;
289+
if let Some(body_params) = body.params {
290+
params = body_params._rest;
291+
_meta = body_params._meta.map(|m| m.into_owned());
292+
}
293+
let mut extensions = Extensions::new();
294+
if let Some(meta) = _meta {
295+
extensions.insert(meta);
296+
}
297+
Ok(CustomRequest {
298+
extensions,
299+
method: body.method,
300+
params,
301+
})
302+
}
303+
}
304+
252305
impl Serialize for CustomNotification {
253306
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
254307
where

0 commit comments

Comments
 (0)