Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions crates/rmcp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,8 @@ path = "tests/test_progress_subscriber.rs"
name = "test_elicitation"
required-features = ["elicitation", "client", "server"]
path = "tests/test_elicitation.rs"

[[test]]
name = "test_task"
required-features = ["server", "client", "macros"]
path = "tests/test_task.rs"
4 changes: 4 additions & 0 deletions crates/rmcp/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ pub enum RmcpError {
error: Box<dyn std::error::Error + Send + Sync>,
},
// and cancellation shouldn't be an error?

// TODO: add more error variants as needed
#[error("Task error: {0}")]
TaskError(String),
}

impl RmcpError {
Expand Down
49 changes: 45 additions & 4 deletions crates/rmcp/src/handler/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub mod router;
pub mod tool;
pub mod tool_name_validation;
pub mod wrapper;

impl<H: ServerHandler> Service<RoleServer> for H {
async fn handle_request(
&self,
Expand Down Expand Up @@ -61,14 +62,29 @@ impl<H: ServerHandler> Service<RoleServer> for H {
.unsubscribe(request.params, context)
.await
.map(ServerResult::empty),
ClientRequest::CallToolRequest(request) => self
.call_tool(request.params, context)
.await
.map(ServerResult::CallToolResult),
ClientRequest::CallToolRequest(request) => {
if request.params.task.is_some() {
tracing::info!("Enqueueing task for tool call: {}", request.params.name);
self.enqueue_task(request.params, context.clone()).await.map(ServerResult::GetTaskInfoResult)
}else{
self
.call_tool(request.params, context)
.await
.map(ServerResult::CallToolResult)
}
},
ClientRequest::ListToolsRequest(request) => self
.list_tools(request.params, context)
.await
.map(ServerResult::ListToolsResult),
ClientRequest::ListTasksRequest(request) => self
.list_tasks(request.params, context)
.await
.map(ServerResult::ListTasksResult),
ClientRequest::GetTaskInfoRequest(request) => self
.get_task_info(request.params, context)
.await
.map(ServerResult::GetTaskInfoResult),
}
}

Expand Down Expand Up @@ -104,6 +120,15 @@ impl<H: ServerHandler> Service<RoleServer> for H {

#[allow(unused_variables)]
pub trait ServerHandler: Sized + Send + Sync + 'static {
fn enqueue_task(
&self,
_request: CallToolRequestParam,
_context: RequestContext<RoleServer>,
) -> impl Future<Output = Result<GetTaskInfoResult, McpError>> + Send + '_ {
std::future::ready(Err(McpError::internal_error(
"Task processing not implemented".to_string(),
None)))
}
fn ping(
&self,
context: RequestContext<RoleServer>,
Expand Down Expand Up @@ -240,4 +265,20 @@ pub trait ServerHandler: Sized + Send + Sync + 'static {
fn get_info(&self) -> ServerInfo {
ServerInfo::default()
}

fn list_tasks(
&self,
request: Option<PaginatedRequestParam>,
context: RequestContext<RoleServer>,
) -> impl Future<Output = Result<ListTasksResult, McpError>> + Send + '_ {
std::future::ready(Err(McpError::method_not_found::<ListTasksMethod>()))
}

fn get_task_info(
&self,
request: GetTaskInfoParam,
context: RequestContext<RoleServer>,
) -> impl Future<Output = Result<GetTaskInfoResult, McpError>> + Send + '_ {
std::future::ready(Err(McpError::method_not_found::<GetTaskInfoMethod>()))
}
}
4 changes: 3 additions & 1 deletion crates/rmcp/src/handler/server/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,21 @@ pub struct ToolCallContext<'s, S> {
pub service: &'s S,
pub name: Cow<'static, str>,
pub arguments: Option<JsonObject>,
pub task:Option<JsonObject>,
}

impl<'s, S> ToolCallContext<'s, S> {
pub fn new(
service: &'s S,
CallToolRequestParam { name, arguments }: CallToolRequestParam,
CallToolRequestParam { name, arguments, task }: CallToolRequestParam,
request_context: RequestContext<RoleServer>,
) -> Self {
Self {
request_context,
service,
name,
arguments,
task,
}
}
pub fn name(&self) -> &str {
Expand Down
1 change: 1 addition & 0 deletions crates/rmcp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ pub use service::{RoleClient, serve_client};
pub use service::{RoleServer, serve_server};

pub mod handler;
pub mod task_manager;
pub mod transport;

// re-export
Expand Down
52 changes: 51 additions & 1 deletion crates/rmcp/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ mod meta;
mod prompt;
mod resource;
mod serde_impl;
mod task;
mod tool;
pub use annotated::*;
pub use capabilities::*;
Expand All @@ -19,6 +20,7 @@ pub use prompt::*;
pub use resource::*;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use serde_json::Value;
pub use task::*;
pub use tool::*;

/// A JSON object type alias for convenient handling of JSON data.
Expand Down Expand Up @@ -1639,6 +1641,7 @@ paginated_result!(
}
);


const_string!(CallToolRequestMethod = "tools/call");
/// Parameters for calling a tool provided by an MCP server.
///
Expand All @@ -1653,6 +1656,8 @@ pub struct CallToolRequestParam {
/// Arguments to pass to the tool (must match the tool's input schema)
#[serde(skip_serializing_if = "Option::is_none")]
pub arguments: Option<JsonObject>,
#[serde(skip_serializing_if = "Option::is_none")]
pub task: Option<JsonObject>,
}

/// Request to call a specific tool
Expand Down Expand Up @@ -1691,6 +1696,23 @@ pub struct GetPromptResult {
pub messages: Vec<PromptMessage>,
}

// =============================================================================
// TASK MANAGEMENT
// =============================================================================

const_string!(GetTaskInfoMethod = "tasks/get");
pub type GetTaskInfoRequest = Request<GetTaskInfoMethod, GetTaskInfoParam>;

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct GetTaskInfoParam {
pub task_id: String,
}

const_string!(ListTasksMethod = "tasks/list");
pub type ListTasksRequest = RequestOptionalParam<ListTasksMethod, PaginatedRequestParam>;

// =============================================================================
// MESSAGE TYPE UNIONS
// =============================================================================
Expand Down Expand Up @@ -1757,7 +1779,9 @@ ts_union!(
| SubscribeRequest
| UnsubscribeRequest
| CallToolRequest
| ListToolsRequest;
| ListToolsRequest
| GetTaskInfoRequest
| ListTasksRequest;
);

impl ClientRequest {
Expand All @@ -1776,6 +1800,8 @@ impl ClientRequest {
ClientRequest::UnsubscribeRequest(r) => r.method.as_str(),
ClientRequest::CallToolRequest(r) => r.method.as_str(),
ClientRequest::ListToolsRequest(r) => r.method.as_str(),
ClientRequest::GetTaskInfoRequest(r) => r.method.as_str(),
ClientRequest::ListTasksRequest(r) => r.method.as_str(),
}
}
}
Expand Down Expand Up @@ -1832,6 +1858,8 @@ ts_union!(
| CallToolResult
| ListToolsResult
| CreateElicitationResult
| GetTaskInfoResult
| ListTasksResult
| EmptyResult
;
);
Expand All @@ -1842,6 +1870,28 @@ impl ServerResult {
}
}

// =============================================================================
// TASK RESULT TYPES (Server responses for task queries)
// =============================================================================
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct GetTaskInfoResult {
#[serde(skip_serializing_if = "Option::is_none")]
pub task: Option<crate::model::Task>,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
#[serde(rename_all = "camelCase")]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct ListTasksResult {
pub tasks: Vec<crate::model::Task>,
#[serde(skip_serializing_if = "Option::is_none")]
pub next_cursor: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub total: Option<u64>,
}

pub type ServerJsonRpcMessage = JsonRpcMessage<ServerRequest, ServerResult, ServerNotification>;

impl TryInto<CancelledNotification> for ServerNotification {
Expand Down
3 changes: 3 additions & 0 deletions crates/rmcp/src/model/meta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ variant_extension! {
UnsubscribeRequest
CallToolRequest
ListToolsRequest
GetTaskInfoRequest
ListTasksRequest
}
}

Expand Down Expand Up @@ -154,6 +156,7 @@ impl Meta {
})
}


pub fn set_progress_token(&mut self, token: ProgressToken) {
match token.0 {
NumberOrString::String(ref s) => self.0.insert(
Expand Down
Loading
Loading