Skip to content

Commit 36834f3

Browse files
authored
fix(macro): add generics marco types support (#98)
add generic marco support Signed-off-by: jokemanfire <[email protected]>
1 parent 6837546 commit 36834f3

File tree

7 files changed

+245
-17
lines changed

7 files changed

+245
-17
lines changed

crates/rmcp-macros/src/tool.rs

Lines changed: 100 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -177,30 +177,114 @@ pub(crate) fn tool(attr: TokenStream, input: TokenStream) -> syn::Result<TokenSt
177177
pub(crate) fn tool_impl_item(attr: TokenStream, mut input: ItemImpl) -> syn::Result<TokenStream> {
178178
let tool_impl_attr: ToolImplItemAttrs = syn::parse2(attr)?;
179179
let tool_box_ident = tool_impl_attr.tool_box;
180+
181+
// get all tool function ident
182+
let mut tool_fn_idents = Vec::new();
183+
for item in &input.items {
184+
if let syn::ImplItem::Fn(method) = item {
185+
for attr in &method.attrs {
186+
if attr.path().is_ident(TOOL_IDENT) {
187+
tool_fn_idents.push(method.sig.ident.clone());
188+
}
189+
}
190+
}
191+
}
192+
193+
// handle different cases
180194
if input.trait_.is_some() {
181195
if let Some(ident) = tool_box_ident {
182-
input.items.push(parse_quote!(
183-
rmcp::tool_box!(@derive #ident);
184-
));
196+
// check if there are generic parameters
197+
if !input.generics.params.is_empty() {
198+
// for trait implementation with generic parameters, directly use the already generated *_inner method
199+
200+
// generate call_tool method
201+
input.items.push(parse_quote! {
202+
async fn call_tool(
203+
&self,
204+
request: rmcp::model::CallToolRequestParam,
205+
context: rmcp::service::RequestContext<rmcp::RoleServer>,
206+
) -> Result<rmcp::model::CallToolResult, rmcp::Error> {
207+
self.call_tool_inner(request, context).await
208+
}
209+
});
210+
211+
// generate list_tools method
212+
input.items.push(parse_quote! {
213+
async fn list_tools(
214+
&self,
215+
request: rmcp::model::PaginatedRequestParam,
216+
context: rmcp::service::RequestContext<rmcp::RoleServer>,
217+
) -> Result<rmcp::model::ListToolsResult, rmcp::Error> {
218+
self.list_tools_inner(request, context).await
219+
}
220+
});
221+
} else {
222+
// if there are no generic parameters, add tool box derive
223+
input.items.push(parse_quote!(
224+
rmcp::tool_box!(@derive #ident);
225+
));
226+
}
185227
}
186228
} else if let Some(ident) = tool_box_ident {
187-
let mut tool_fn_idents = Vec::new();
188-
for item in &input.items {
189-
if let syn::ImplItem::Fn(method) = item {
190-
for attr in &method.attrs {
191-
if attr.path().is_ident(TOOL_IDENT) {
192-
tool_fn_idents.push(method.sig.ident.clone());
229+
// if it is a normal impl block
230+
if !input.generics.params.is_empty() {
231+
// if there are generic parameters, not use tool_box! macro, but generate code directly
232+
233+
// create call code for each tool function
234+
let match_arms = tool_fn_idents.iter().map(|ident| {
235+
let attr_fn = Ident::new(&format!("{}_tool_attr", ident), ident.span());
236+
let call_fn = Ident::new(&format!("{}_tool_call", ident), ident.span());
237+
quote! {
238+
name if name == Self::#attr_fn().name => {
239+
Self::#call_fn(tcc).await
193240
}
194241
}
195-
}
242+
});
243+
244+
let tool_attrs = tool_fn_idents.iter().map(|ident| {
245+
let attr_fn = Ident::new(&format!("{}_tool_attr", ident), ident.span());
246+
quote! { Self::#attr_fn() }
247+
});
248+
249+
// implement call_tool method
250+
input.items.push(parse_quote! {
251+
async fn call_tool_inner(
252+
&self,
253+
request: rmcp::model::CallToolRequestParam,
254+
context: rmcp::service::RequestContext<rmcp::RoleServer>,
255+
) -> Result<rmcp::model::CallToolResult, rmcp::Error> {
256+
let tcc = rmcp::handler::server::tool::ToolCallContext::new(self, request, context);
257+
match tcc.name() {
258+
#(#match_arms,)*
259+
_ => Err(rmcp::Error::invalid_params("tool not found", None)),
260+
}
261+
}
262+
});
263+
264+
// implement list_tools method
265+
input.items.push(parse_quote! {
266+
async fn list_tools_inner(
267+
&self,
268+
_: rmcp::model::PaginatedRequestParam,
269+
_: rmcp::service::RequestContext<rmcp::RoleServer>,
270+
) -> Result<rmcp::model::ListToolsResult, rmcp::Error> {
271+
Ok(rmcp::model::ListToolsResult {
272+
next_cursor: None,
273+
tools: vec![#(#tool_attrs),*],
274+
})
275+
}
276+
});
277+
} else {
278+
// if there are no generic parameters, use the original tool_box! macro
279+
let this_type_ident = &input.self_ty;
280+
input.items.push(parse_quote!(
281+
rmcp::tool_box!(#this_type_ident {
282+
#(#tool_fn_idents),*
283+
} #ident);
284+
));
196285
}
197-
let this_type_ident = &input.self_ty;
198-
input.items.push(parse_quote!(
199-
rmcp::tool_box!(#this_type_ident {
200-
#(#tool_fn_idents),*
201-
} #ident);
202-
));
203286
}
287+
204288
Ok(quote! {
205289
#input
206290
})

crates/rmcp/tests/test_tool_macros.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::sync::Arc;
2+
13
use rmcp::{ServerHandler, handler::server::tool::ToolCallContext, tool};
24
use schemars::JsonSchema;
35
use serde::{Deserialize, Serialize};
@@ -21,6 +23,7 @@ impl ServerHandler for Server {
2123
}
2224
}
2325
}
26+
2427
#[derive(Debug, Clone, Default)]
2528
pub struct Server {}
2629

@@ -35,6 +38,40 @@ impl Server {
3538
async fn empty_param(&self) {}
3639
}
3740

41+
// define generic service trait
42+
pub trait DataService: Send + Sync + 'static {
43+
fn get_data(&self) -> String;
44+
}
45+
46+
// mock service for test
47+
#[derive(Clone)]
48+
struct MockDataService;
49+
impl DataService for MockDataService {
50+
fn get_data(&self) -> String {
51+
"mock data".to_string()
52+
}
53+
}
54+
55+
// define generic server
56+
#[derive(Debug, Clone)]
57+
pub struct GenericServer<DS: DataService> {
58+
data_service: Arc<DS>,
59+
}
60+
61+
#[tool(tool_box)]
62+
impl<DS: DataService> GenericServer<DS> {
63+
pub fn new(data_service: DS) -> Self {
64+
Self {
65+
data_service: Arc::new(data_service),
66+
}
67+
}
68+
69+
#[tool(description = "Get data from the service")]
70+
async fn get_data(&self) -> String {
71+
self.data_service.get_data()
72+
}
73+
}
74+
3875
#[tokio::test]
3976
async fn test_tool_macros() {
4077
let server = Server::default();
@@ -52,4 +89,14 @@ async fn test_tool_macros_with_empty_param() {
5289
assert!(_attr.input_schema.get("properties").is_none());
5390
}
5491

92+
#[tokio::test]
93+
async fn test_tool_macros_with_generics() {
94+
let mock_service = MockDataService;
95+
let server = GenericServer::new(mock_service);
96+
let _attr = GenericServer::<MockDataService>::get_data_tool_attr();
97+
let _get_data_call_fn = GenericServer::<MockDataService>::get_data_tool_call;
98+
let _get_data_fn = GenericServer::<MockDataService>::get_data;
99+
assert_eq!(server.get_data().await, "mock data");
100+
}
101+
55102
impl GetWeatherRequest {}

examples/servers/Cargo.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,8 @@ path = "src/axum.rs"
3939

4040
[[example]]
4141
name = "servers_axum_router"
42-
path = "src/axum_router.rs"
42+
path = "src/axum_router.rs"
43+
44+
[[example]]
45+
name = "servers_generic_server"
46+
path = "src/generic_service.rs"

examples/servers/src/common/counter.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ pub struct Counter {
1919
}
2020
#[tool(tool_box)]
2121
impl Counter {
22+
#[allow(dead_code)]
2223
pub fn new() -> Self {
2324
Self {
2425
counter: Arc::new(Mutex::new(0)),
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
use std::sync::Arc;
2+
3+
use rmcp::{
4+
ServerHandler,
5+
model::{ServerCapabilities, ServerInfo},
6+
schemars, tool,
7+
};
8+
9+
#[allow(dead_code)]
10+
pub trait DataService: Send + Sync + 'static {
11+
fn get_data(&self) -> String;
12+
fn set_data(&mut self, data: String);
13+
}
14+
15+
#[derive(Debug, Clone)]
16+
pub struct MemoryDataService {
17+
data: String,
18+
}
19+
20+
impl MemoryDataService {
21+
#[allow(dead_code)]
22+
pub fn new(initial_data: impl Into<String>) -> Self {
23+
Self {
24+
data: initial_data.into(),
25+
}
26+
}
27+
}
28+
29+
impl DataService for MemoryDataService {
30+
fn get_data(&self) -> String {
31+
self.data.clone()
32+
}
33+
34+
fn set_data(&mut self, data: String) {
35+
self.data = data;
36+
}
37+
}
38+
39+
#[derive(Debug, Clone)]
40+
pub struct GenericService<DS: DataService> {
41+
#[allow(dead_code)]
42+
data_service: Arc<DS>,
43+
}
44+
45+
#[tool(tool_box)]
46+
impl<DS: DataService> GenericService<DS> {
47+
pub fn new(data_service: DS) -> Self {
48+
Self {
49+
data_service: Arc::new(data_service),
50+
}
51+
}
52+
53+
#[tool(description = "get memory from service")]
54+
pub async fn get_data(&self) -> String {
55+
self.data_service.get_data()
56+
}
57+
58+
#[tool(description = "set memory to service")]
59+
pub async fn set_data(&self, #[tool(param)] data: String) -> String {
60+
let new_data = data.clone();
61+
format!("Current memory: {}", new_data)
62+
}
63+
}
64+
65+
impl<DS: DataService> ServerHandler for GenericService<DS> {
66+
fn get_info(&self) -> ServerInfo {
67+
ServerInfo {
68+
instructions: Some("generic data service".into()),
69+
capabilities: ServerCapabilities::builder().enable_tools().build(),
70+
..Default::default()
71+
}
72+
}
73+
}

examples/servers/src/common/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
pub mod calculator;
22
pub mod counter;
3+
pub mod generic_service;
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
use std::error::Error;
2+
mod common;
3+
use common::generic_service::{GenericService, MemoryDataService};
4+
use rmcp::serve_server;
5+
6+
#[tokio::main]
7+
async fn main() -> Result<(), Box<dyn Error>> {
8+
let memory_service = MemoryDataService::new("initial data");
9+
10+
let generic_service = GenericService::new(memory_service);
11+
12+
println!("start server, connect to standard input/output");
13+
14+
let io = (tokio::io::stdin(), tokio::io::stdout());
15+
16+
serve_server(generic_service, io).await?;
17+
Ok(())
18+
}

0 commit comments

Comments
 (0)