Skip to content

Commit 5d92061

Browse files
authored
feat(openapi): add OpenAPI v3 compatibility and test for nullable field schema workaround (#135) (#137)
1 parent 030b6f0 commit 5d92061

File tree

2 files changed

+239
-2
lines changed

2 files changed

+239
-2
lines changed

crates/rmcp/src/handler/server/tool.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ use crate::{
1414
};
1515
/// A shortcut for generating a JSON schema for a type.
1616
pub fn schema_for_type<T: JsonSchema>() -> JsonObject {
17-
let schema = schemars::r#gen::SchemaGenerator::default().into_root_schema_for::<T>();
17+
let settings = schemars::r#gen::SchemaSettings::openapi3();
18+
let generator = settings.into_generator();
19+
let schema = generator.into_root_schema_for::<T>();
1820
let object = serde_json::to_value(schema).expect("failed to serialize schema");
1921
match object {
2022
serde_json::Value::Object(object) => object,

crates/rmcp/tests/test_tool_macros.rs

Lines changed: 236 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
1+
//cargo test --test test_tool_macros --features "client server"
2+
13
use std::sync::Arc;
24

3-
use rmcp::{ServerHandler, handler::server::tool::ToolCallContext, tool};
5+
use rmcp::{
6+
ClientHandler, Peer, RoleClient, ServerHandler, ServiceExt,
7+
handler::server::tool::ToolCallContext,
8+
model::{CallToolRequestParam, ClientInfo},
9+
tool,
10+
};
411
use schemars::JsonSchema;
512
use serde::{Deserialize, Serialize};
613

@@ -36,6 +43,11 @@ impl Server {
3643
}
3744
#[tool(description = "Empty Parameter")]
3845
async fn empty_param(&self) {}
46+
47+
#[tool(description = "Optional Parameter")]
48+
async fn optional_param(&self, #[tool(param)] city: Option<String>) -> String {
49+
city.unwrap_or_default()
50+
}
3951
}
4052

4153
// define generic service trait
@@ -99,4 +111,227 @@ async fn test_tool_macros_with_generics() {
99111
assert_eq!(server.get_data().await, "mock data");
100112
}
101113

114+
#[tokio::test]
115+
async fn test_tool_macros_with_optional_param() {
116+
let _attr = Server::optional_param_tool_attr();
117+
// println!("{_attr:?}");
118+
let attr_type = _attr
119+
.input_schema
120+
.get("properties")
121+
.unwrap()
122+
.get("city")
123+
.unwrap()
124+
.get("type")
125+
.unwrap();
126+
println!("_attr.input_schema: {:?}", attr_type);
127+
assert_eq!(attr_type.as_str().unwrap(), "string");
128+
}
129+
102130
impl GetWeatherRequest {}
131+
132+
// Struct defined for testing optional field schema generation
133+
#[derive(Debug, Deserialize, Serialize, JsonSchema)]
134+
pub struct OptionalFieldTestSchema {
135+
#[schemars(description = "An optional description field")]
136+
pub description: Option<String>,
137+
}
138+
139+
// Struct defined for testing optional i64 field schema generation and null handling
140+
#[derive(Debug, Deserialize, Serialize, JsonSchema)]
141+
pub struct OptionalI64TestSchema {
142+
#[schemars(description = "An optional i64 field")]
143+
pub count: Option<i64>,
144+
pub mandatory_field: String, // Added to ensure non-empty object schema
145+
}
146+
147+
// Dummy struct to host the test tool method
148+
#[derive(Debug, Clone, Default)]
149+
pub struct OptionalSchemaTester {}
150+
151+
impl OptionalSchemaTester {
152+
// Dummy tool function using the test schema as an aggregated parameter
153+
#[tool(description = "A tool to test optional schema generation")]
154+
async fn test_optional_aggr(&self, #[tool(aggr)] _req: OptionalFieldTestSchema) {
155+
// Implementation doesn't matter for schema testing
156+
// Return type changed to () to satisfy IntoCallToolResult
157+
}
158+
159+
// Tool function to test optional i64 handling
160+
#[tool(description = "A tool to test optional i64 schema generation")]
161+
async fn test_optional_i64_aggr(&self, #[tool(aggr)] req: OptionalI64TestSchema) -> String {
162+
match req.count {
163+
Some(c) => format!("Received count: {}", c),
164+
None => "Received null count".to_string(),
165+
}
166+
}
167+
}
168+
169+
// Implement ServerHandler to route tool calls for OptionalSchemaTester
170+
impl ServerHandler for OptionalSchemaTester {
171+
async fn call_tool(
172+
&self,
173+
request: rmcp::model::CallToolRequestParam,
174+
context: rmcp::service::RequestContext<rmcp::RoleServer>,
175+
) -> Result<rmcp::model::CallToolResult, rmcp::Error> {
176+
let tcc = ToolCallContext::new(self, request, context);
177+
match tcc.name() {
178+
"test_optional_aggr" => Self::test_optional_aggr_tool_call(tcc).await,
179+
"test_optional_i64_aggr" => Self::test_optional_i64_aggr_tool_call(tcc).await,
180+
_ => Err(rmcp::Error::invalid_params("method not found", None)),
181+
}
182+
}
183+
}
184+
185+
#[test]
186+
fn test_optional_field_schema_generation_via_macro() {
187+
// tests https://github.com/modelcontextprotocol/rust-sdk/issues/135
188+
189+
// Get the attributes generated by the #[tool] macro helper
190+
let tool_attr = OptionalSchemaTester::test_optional_aggr_tool_attr();
191+
192+
// Print the actual generated schema for debugging
193+
println!(
194+
"Actual input schema generated by macro: {:#?}",
195+
tool_attr.input_schema
196+
);
197+
198+
// Verify the schema generated for the aggregated OptionalFieldTestSchema
199+
// by the macro infrastructure (which should now use OpenAPI 3 settings)
200+
let input_schema_map = &*tool_attr.input_schema; // Dereference Arc<JsonObject>
201+
202+
// Check the schema for the 'description' property within the input schema
203+
let properties = input_schema_map
204+
.get("properties")
205+
.expect("Schema should have properties")
206+
.as_object()
207+
.unwrap();
208+
let description_schema = properties
209+
.get("description")
210+
.expect("Properties should include description")
211+
.as_object()
212+
.unwrap();
213+
214+
// Assert that the format is now `type: "string", nullable: true`
215+
assert_eq!(
216+
description_schema.get("type").map(|v| v.as_str().unwrap()),
217+
Some("string"),
218+
"Schema for Option<String> generated by macro should be type: \"string\""
219+
);
220+
assert_eq!(
221+
description_schema
222+
.get("nullable")
223+
.map(|v| v.as_bool().unwrap()),
224+
Some(true),
225+
"Schema for Option<String> generated by macro should have nullable: true"
226+
);
227+
// We still check the description is correct
228+
assert_eq!(
229+
description_schema
230+
.get("description")
231+
.map(|v| v.as_str().unwrap()),
232+
Some("An optional description field")
233+
);
234+
235+
// Ensure the old 'type: [T, null]' format is NOT used
236+
let type_value = description_schema.get("type").unwrap();
237+
assert!(
238+
!type_value.is_array(),
239+
"Schema type should not be an array [T, null]"
240+
);
241+
}
242+
243+
// Define a dummy client handler
244+
#[derive(Debug, Clone, Default)]
245+
struct DummyClientHandler {
246+
peer: Option<Peer<RoleClient>>,
247+
}
248+
249+
impl ClientHandler for DummyClientHandler {
250+
fn get_info(&self) -> ClientInfo {
251+
ClientInfo::default()
252+
}
253+
254+
fn set_peer(&mut self, peer: Peer<RoleClient>) {
255+
self.peer = Some(peer);
256+
}
257+
258+
fn get_peer(&self) -> Option<Peer<RoleClient>> {
259+
self.peer.clone()
260+
}
261+
}
262+
263+
#[tokio::test]
264+
async fn test_optional_i64_field_with_null_input() -> anyhow::Result<()> {
265+
let (server_transport, client_transport) = tokio::io::duplex(4096);
266+
267+
// Server setup
268+
let server = OptionalSchemaTester::default();
269+
let server_handle = tokio::spawn(async move {
270+
server.serve(server_transport).await?.waiting().await?;
271+
anyhow::Ok(())
272+
});
273+
274+
// Create a simple client handler that just forwards tool calls
275+
let client_handler = DummyClientHandler::default();
276+
let client = client_handler.serve(client_transport).await?;
277+
278+
// Test null case
279+
let result = client
280+
.call_tool(CallToolRequestParam {
281+
name: "test_optional_i64_aggr".into(),
282+
arguments: Some(
283+
serde_json::json!({
284+
"count": null,
285+
"mandatory_field": "test_null"
286+
})
287+
.as_object()
288+
.unwrap()
289+
.clone(),
290+
),
291+
})
292+
.await?;
293+
294+
let result_text = result
295+
.content
296+
.first()
297+
.and_then(|content| content.raw.as_text())
298+
.map(|text| text.text.as_str())
299+
.expect("Expected text content");
300+
301+
assert_eq!(
302+
result_text, "Received null count",
303+
"Null case should return expected message"
304+
);
305+
306+
// Test Some case
307+
let some_result = client
308+
.call_tool(CallToolRequestParam {
309+
name: "test_optional_i64_aggr".into(),
310+
arguments: Some(
311+
serde_json::json!({
312+
"count": 42,
313+
"mandatory_field": "test_some"
314+
})
315+
.as_object()
316+
.unwrap()
317+
.clone(),
318+
),
319+
})
320+
.await?;
321+
322+
let some_result_text = some_result
323+
.content
324+
.first()
325+
.and_then(|content| content.raw.as_text())
326+
.map(|text| text.text.as_str())
327+
.expect("Expected text content");
328+
329+
assert_eq!(
330+
some_result_text, "Received count: 42",
331+
"Some case should return expected message"
332+
);
333+
334+
client.cancel().await?;
335+
server_handle.await??;
336+
Ok(())
337+
}

0 commit comments

Comments
 (0)