Skip to content

Commit 494a2b8

Browse files
committed
feat: implement output schema validation
1 parent 94428d5 commit 494a2b8

File tree

4 files changed

+92
-4
lines changed

4 files changed

+92
-4
lines changed

crates/rmcp-macros/src/tool.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,14 @@ fn extract_schema_from_return_type(ret_type: &syn::Type) -> Option<Expr> {
2727
// First, try direct Json<T>
2828
if let Some(inner_type) = extract_json_inner_type(ret_type) {
2929
return syn::parse2::<Expr>(quote! {
30-
rmcp::handler::server::tool::cached_schema_for_type::<#inner_type>()
30+
rmcp::handler::server::tool::cached_schema_for_output::<#inner_type>()
31+
.unwrap_or_else(|e| {
32+
panic!(
33+
"Invalid output schema for Json<{}>: {}",
34+
std::any::type_name::<#inner_type>(),
35+
e
36+
)
37+
})
3138
})
3239
.ok();
3340
}
@@ -57,7 +64,14 @@ fn extract_schema_from_return_type(ret_type: &syn::Type) -> Option<Expr> {
5764
let inner_type = extract_json_inner_type(ok_type)?;
5865

5966
syn::parse2::<Expr>(quote! {
60-
rmcp::handler::server::tool::cached_schema_for_type::<#inner_type>()
67+
rmcp::handler::server::tool::cached_schema_for_output::<#inner_type>()
68+
.unwrap_or_else(|e| {
69+
panic!(
70+
"Invalid output schema for Result<Json<{}>, E>: {}",
71+
std::any::type_name::<#inner_type>(),
72+
e
73+
)
74+
})
6175
})
6276
.ok()
6377
}

crates/rmcp/src/handler/server/common.rs

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,50 @@ pub fn cached_schema_for_type<T: JsonSchema + std::any::Any>() -> Arc<JsonObject
5050
})
5151
}
5252

53+
/// Generate and validate a JSON schema for outputSchema (must have root type "object").
54+
pub fn schema_for_output<T: JsonSchema>() -> Result<JsonObject, String> {
55+
let schema = schema_for_type::<T>();
56+
57+
match schema.get("type") {
58+
Some(serde_json::Value::String(t)) if t == "object" => Ok(schema),
59+
Some(serde_json::Value::String(t)) => Err(format!(
60+
"MCP specification requires tool outputSchema to have root type 'object', but found '{}'.",
61+
t
62+
)),
63+
None => Err(
64+
"Schema is missing 'type' field. MCP specification requires outputSchema to have root type 'object'.".to_string()
65+
),
66+
Some(other) => Err(format!(
67+
"Schema 'type' field has unexpected format: {:?}. Expected \"object\".",
68+
other
69+
)),
70+
}
71+
}
72+
73+
/// Call [`schema_for_output`] with a cache.
74+
pub fn cached_schema_for_output<T: JsonSchema + std::any::Any>() -> Result<Arc<JsonObject>, String>
75+
{
76+
thread_local! {
77+
static CACHE_FOR_OUTPUT: std::sync::RwLock<HashMap<TypeId, Result<Arc<JsonObject>, String>>> = Default::default();
78+
};
79+
CACHE_FOR_OUTPUT.with(|cache| {
80+
if let Some(result) = cache
81+
.read()
82+
.expect("output schema cache lock poisoned")
83+
.get(&TypeId::of::<T>())
84+
{
85+
result.clone()
86+
} else {
87+
let result = schema_for_output::<T>().map(Arc::new);
88+
cache
89+
.write()
90+
.expect("output schema cache lock poisoned")
91+
.insert(TypeId::of::<T>(), result.clone());
92+
result
93+
}
94+
})
95+
}
96+
5397
/// Trait for extracting parts from a context, unifying tool and prompt extraction
5498
pub trait FromContextPart<C>: Sized {
5599
fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData>;
@@ -143,3 +187,25 @@ pub trait AsRequestContext {
143187
fn as_request_context(&self) -> &RequestContext<RoleServer>;
144188
fn as_request_context_mut(&mut self) -> &mut RequestContext<RoleServer>;
145189
}
190+
191+
#[cfg(test)]
192+
mod tests {
193+
use super::*;
194+
195+
#[derive(serde::Serialize, serde::Deserialize, JsonSchema)]
196+
struct TestObject {
197+
value: i32,
198+
}
199+
200+
#[test]
201+
fn test_schema_for_output_rejects_primitive() {
202+
let result = schema_for_output::<i32>();
203+
assert!(result.is_err(),);
204+
}
205+
206+
#[test]
207+
fn test_schema_for_output_accepts_object() {
208+
let result = schema_for_output::<TestObject>();
209+
assert!(result.is_ok(),);
210+
}
211+
}

crates/rmcp/src/handler/server/tool.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ use serde::de::DeserializeOwned;
99

1010
use super::common::{AsRequestContext, FromContextPart};
1111
pub use super::{
12-
common::{Extension, RequestId, cached_schema_for_type, schema_for_type},
12+
common::{
13+
Extension, RequestId, cached_schema_for_output, cached_schema_for_type, schema_for_type,
14+
},
1315
router::tool::{ToolRoute, ToolRouter},
1416
};
1517
use crate::{

crates/rmcp/src/model/tool.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,14 @@ impl Tool {
165165
}
166166

167167
/// Set the output schema using a type that implements JsonSchema
168+
///
169+
/// # Panics
170+
///
171+
/// Panics if the generated schema does not have root type "object" as required by MCP specification.
168172
pub fn with_output_schema<T: JsonSchema + 'static>(mut self) -> Self {
169-
self.output_schema = Some(crate::handler::server::tool::cached_schema_for_type::<T>());
173+
let schema = crate::handler::server::tool::cached_schema_for_output::<T>()
174+
.unwrap_or_else(|e| panic!("Invalid output schema for tool '{}': {}", self.name, e));
175+
self.output_schema = Some(schema);
170176
self
171177
}
172178

0 commit comments

Comments
 (0)