From 90e2dba58472024d4f2a0ccd9b24b5ab04e8cb83 Mon Sep 17 00:00:00 2001 From: jokemanfire Date: Tue, 11 Nov 2025 10:56:44 +0800 Subject: [PATCH] feat(task): add task support (SEP-1686) Signed-off-by: jokemanfire --- crates/rmcp/Cargo.toml | 5 + crates/rmcp/src/error.rs | 4 + crates/rmcp/src/handler/server.rs | 49 +++- crates/rmcp/src/handler/server/tool.rs | 4 +- crates/rmcp/src/lib.rs | 1 + crates/rmcp/src/model.rs | 52 +++- crates/rmcp/src/model/meta.rs | 3 + crates/rmcp/src/model/task.rs | 270 ++++++++++++++++++ crates/rmcp/src/task_manager.rs | 202 +++++++++++++ crates/rmcp/tests/test_progress_subscriber.rs | 1 + crates/rmcp/tests/test_task.rs | 83 ++++++ crates/rmcp/tests/test_tool_macros.rs | 2 + examples/clients/src/collection.rs | 1 + examples/clients/src/everything_stdio.rs | 2 + examples/clients/src/git_stdio.rs | 1 + examples/clients/src/progress_client.rs | 2 + examples/clients/src/sampling_stdio.rs | 1 + examples/clients/src/streamable_http.rs | 1 + examples/rig-integration/src/mcp_adaptor.rs | 1 + examples/servers/src/common/counter.rs | 205 +++++++++++-- examples/simple-chat-client/src/tool.rs | 1 + examples/transport/src/unix_socket.rs | 1 + 22 files changed, 857 insertions(+), 35 deletions(-) create mode 100644 crates/rmcp/src/model/task.rs create mode 100644 crates/rmcp/src/task_manager.rs create mode 100644 crates/rmcp/tests/test_task.rs diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 098287da..d908f6c3 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -197,3 +197,8 @@ path = "tests/test_progress_subscriber.rs" name = "test_elicitation" required-features = ["elicitation", "client", "server"] path = "tests/test_elicitation.rs" + +[[test]] +name = "test_task" +required-features = ["server", "client", "macros"] +path = "tests/test_task.rs" \ No newline at end of file diff --git a/crates/rmcp/src/error.rs b/crates/rmcp/src/error.rs index e0da2b3d..f51a7158 100644 --- a/crates/rmcp/src/error.rs +++ b/crates/rmcp/src/error.rs @@ -41,6 +41,10 @@ pub enum RmcpError { error: Box, }, // and cancellation shouldn't be an error? + + // TODO: add more error variants as needed + #[error("Task error: {0}")] + TaskError(String), } impl RmcpError { diff --git a/crates/rmcp/src/handler/server.rs b/crates/rmcp/src/handler/server.rs index fd062dbd..774f7165 100644 --- a/crates/rmcp/src/handler/server.rs +++ b/crates/rmcp/src/handler/server.rs @@ -11,6 +11,7 @@ pub mod router; pub mod tool; pub mod tool_name_validation; pub mod wrapper; + impl Service for H { async fn handle_request( &self, @@ -61,14 +62,29 @@ impl Service for H { .unsubscribe(request.params, context) .await .map(ServerResult::empty), - ClientRequest::CallToolRequest(request) => self - .call_tool(request.params, context) - .await - .map(ServerResult::CallToolResult), + ClientRequest::CallToolRequest(request) => { + if request.params.task.is_some() { + tracing::info!("Enqueueing task for tool call: {}", request.params.name); + self.enqueue_task(request.params, context.clone()).await.map(ServerResult::GetTaskInfoResult) + }else{ + self + .call_tool(request.params, context) + .await + .map(ServerResult::CallToolResult) + } + }, ClientRequest::ListToolsRequest(request) => self .list_tools(request.params, context) .await .map(ServerResult::ListToolsResult), + ClientRequest::ListTasksRequest(request) => self + .list_tasks(request.params, context) + .await + .map(ServerResult::ListTasksResult), + ClientRequest::GetTaskInfoRequest(request) => self + .get_task_info(request.params, context) + .await + .map(ServerResult::GetTaskInfoResult), } } @@ -104,6 +120,15 @@ impl Service for H { #[allow(unused_variables)] pub trait ServerHandler: Sized + Send + Sync + 'static { + fn enqueue_task( + &self, + _request: CallToolRequestParam, + _context: RequestContext, + ) -> impl Future> + Send + '_ { + std::future::ready(Err(McpError::internal_error( + "Task processing not implemented".to_string(), + None))) + } fn ping( &self, context: RequestContext, @@ -240,4 +265,20 @@ pub trait ServerHandler: Sized + Send + Sync + 'static { fn get_info(&self) -> ServerInfo { ServerInfo::default() } + + fn list_tasks( + &self, + request: Option, + context: RequestContext, + ) -> impl Future> + Send + '_ { + std::future::ready(Err(McpError::method_not_found::())) + } + + fn get_task_info( + &self, + request: GetTaskInfoParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + std::future::ready(Err(McpError::method_not_found::())) + } } diff --git a/crates/rmcp/src/handler/server/tool.rs b/crates/rmcp/src/handler/server/tool.rs index cf842679..e669a33e 100644 --- a/crates/rmcp/src/handler/server/tool.rs +++ b/crates/rmcp/src/handler/server/tool.rs @@ -33,12 +33,13 @@ pub struct ToolCallContext<'s, S> { pub service: &'s S, pub name: Cow<'static, str>, pub arguments: Option, + pub task:Option, } impl<'s, S> ToolCallContext<'s, S> { pub fn new( service: &'s S, - CallToolRequestParam { name, arguments }: CallToolRequestParam, + CallToolRequestParam { name, arguments, task }: CallToolRequestParam, request_context: RequestContext, ) -> Self { Self { @@ -46,6 +47,7 @@ impl<'s, S> ToolCallContext<'s, S> { service, name, arguments, + task, } } pub fn name(&self) -> &str { diff --git a/crates/rmcp/src/lib.rs b/crates/rmcp/src/lib.rs index 9f81eabe..cba4eeba 100644 --- a/crates/rmcp/src/lib.rs +++ b/crates/rmcp/src/lib.rs @@ -162,6 +162,7 @@ pub use service::{RoleClient, serve_client}; pub use service::{RoleServer, serve_server}; pub mod handler; +pub mod task_manager; pub mod transport; // re-export diff --git a/crates/rmcp/src/model.rs b/crates/rmcp/src/model.rs index 33b507da..113984a6 100644 --- a/crates/rmcp/src/model.rs +++ b/crates/rmcp/src/model.rs @@ -8,6 +8,7 @@ mod meta; mod prompt; mod resource; mod serde_impl; +mod task; mod tool; pub use annotated::*; pub use capabilities::*; @@ -19,6 +20,7 @@ pub use prompt::*; pub use resource::*; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use serde_json::Value; +pub use task::*; pub use tool::*; /// A JSON object type alias for convenient handling of JSON data. @@ -1639,6 +1641,7 @@ paginated_result!( } ); + const_string!(CallToolRequestMethod = "tools/call"); /// Parameters for calling a tool provided by an MCP server. /// @@ -1653,6 +1656,8 @@ pub struct CallToolRequestParam { /// Arguments to pass to the tool (must match the tool's input schema) #[serde(skip_serializing_if = "Option::is_none")] pub arguments: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub task: Option, } /// Request to call a specific tool @@ -1691,6 +1696,23 @@ pub struct GetPromptResult { pub messages: Vec, } +// ============================================================================= +// TASK MANAGEMENT +// ============================================================================= + +const_string!(GetTaskInfoMethod = "tasks/get"); +pub type GetTaskInfoRequest = Request; + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct GetTaskInfoParam { + pub task_id: String, +} + +const_string!(ListTasksMethod = "tasks/list"); +pub type ListTasksRequest = RequestOptionalParam; + // ============================================================================= // MESSAGE TYPE UNIONS // ============================================================================= @@ -1757,7 +1779,9 @@ ts_union!( | SubscribeRequest | UnsubscribeRequest | CallToolRequest - | ListToolsRequest; + | ListToolsRequest + | GetTaskInfoRequest + | ListTasksRequest; ); impl ClientRequest { @@ -1776,6 +1800,8 @@ impl ClientRequest { ClientRequest::UnsubscribeRequest(r) => r.method.as_str(), ClientRequest::CallToolRequest(r) => r.method.as_str(), ClientRequest::ListToolsRequest(r) => r.method.as_str(), + ClientRequest::GetTaskInfoRequest(r) => r.method.as_str(), + ClientRequest::ListTasksRequest(r) => r.method.as_str(), } } } @@ -1832,6 +1858,8 @@ ts_union!( | CallToolResult | ListToolsResult | CreateElicitationResult + | GetTaskInfoResult + | ListTasksResult | EmptyResult ; ); @@ -1842,6 +1870,28 @@ impl ServerResult { } } +// ============================================================================= +// TASK RESULT TYPES (Server responses for task queries) +// ============================================================================= +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct GetTaskInfoResult { + #[serde(skip_serializing_if = "Option::is_none")] + pub task: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct ListTasksResult { + pub tasks: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub next_cursor: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub total: Option, +} + pub type ServerJsonRpcMessage = JsonRpcMessage; impl TryInto for ServerNotification { diff --git a/crates/rmcp/src/model/meta.rs b/crates/rmcp/src/model/meta.rs index a03fc056..022ac090 100644 --- a/crates/rmcp/src/model/meta.rs +++ b/crates/rmcp/src/model/meta.rs @@ -86,6 +86,8 @@ variant_extension! { UnsubscribeRequest CallToolRequest ListToolsRequest + GetTaskInfoRequest + ListTasksRequest } } @@ -154,6 +156,7 @@ impl Meta { }) } + pub fn set_progress_token(&mut self, token: ProgressToken) { match token.0 { NumberOrString::String(ref s) => self.0.insert( diff --git a/crates/rmcp/src/model/task.rs b/crates/rmcp/src/model/task.rs new file mode 100644 index 00000000..6d3ad111 --- /dev/null +++ b/crates/rmcp/src/model/task.rs @@ -0,0 +1,270 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +use super::{JsonObject, Meta}; + +/// Task lifecycle status +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "snake_case")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub enum TaskStatus { + /// Created but not started yet + #[default] + Pending, + /// Currently running + Running, + /// Waiting for dependencies or external input + Waiting, + /// Cancellation requested and in progress + Cancelling, + /// Completed successfully + Succeeded, + /// Completed with failure + Failed, + /// Cancelled before completion + Cancelled, +} + +/// High-level task kind. Exact set may evolve with the SEP. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "snake_case")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub enum TaskKind { + #[default] + Generation, + Retrieval, + Aggregation, + Orchestration, + ToolCall, + /// Custom kind identifier + Custom(String), +} + +/// Progress information for long-running tasks +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct TaskProgress { + /// Percentage progress in the range [0.0, 100.0] + #[serde(skip_serializing_if = "Option::is_none")] + pub percent: Option, + /// Current stage identifier + #[serde(skip_serializing_if = "Option::is_none")] + pub stage: Option, + /// Human-readable status message + #[serde(skip_serializing_if = "Option::is_none")] + pub message: Option, + /// Arbitrary structured details, protocol-specific + #[serde(skip_serializing_if = "Option::is_none")] + pub details: Option, +} + +/// Error information for failed tasks +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct TaskError { + /// Machine-readable error code + pub code: String, + /// Human-readable error message + pub message: String, + /// Whether the operation can be retried safely + #[serde(skip_serializing_if = "Option::is_none")] + pub retryable: Option, + /// Arbitrary error data for debugging + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, +} + +/// Final result for a succeeded task +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct TaskResult { + /// MIME type or custom content-type identifier + pub content_type: String, + /// The actual result payload + pub value: Value, + /// Optional short summary for UI + #[serde(skip_serializing_if = "Option::is_none")] + pub summary: Option, +} + +/// Primary Task object used across client/server +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct Task { + /// Unique task identifier + pub id: String, + /// Task kind/category + pub kind: TaskKind, + /// Current status + pub status: TaskStatus, + /// ISO8601 creation time + pub created_at: String, + /// ISO8601 last update time + #[serde(skip_serializing_if = "Option::is_none")] + pub updated_at: Option, + /// ISO8601 start time + #[serde(skip_serializing_if = "Option::is_none")] + pub started_at: Option, + /// ISO8601 completion time + #[serde(skip_serializing_if = "Option::is_none")] + pub completed_at: Option, + /// Parent task identifier for hierarchical tasks + #[serde(skip_serializing_if = "Option::is_none")] + pub parent_id: Option, + /// List of prerequisite task ids + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub depends_on: Vec, + /// Optional labels for filtering and grouping + #[serde(skip_serializing_if = "Option::is_none")] + pub labels: Option>, + /// Immutable metadata provided at creation + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option, + /// Mutable runtime state exposed to clients + #[serde(skip_serializing_if = "Option::is_none")] + pub runtime_state: Option, + /// Input parameters for this task + #[serde(skip_serializing_if = "Option::is_none")] + pub input: Option, + /// Progress info when running + #[serde(skip_serializing_if = "Option::is_none")] + pub progress: Option, + /// Final result when succeeded + #[serde(skip_serializing_if = "Option::is_none")] + pub result: Option, + /// Error information when failed + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + /// True if a cancellation has been requested + #[serde(default)] + pub cancellation_requested: bool, + /// Scheduling priority; larger means higher priority (convention) + #[serde(skip_serializing_if = "Option::is_none")] + pub priority: Option, + /// Batch/group identifier for bulk operations + #[serde(skip_serializing_if = "Option::is_none")] + pub batch_group: Option, + /// Trace identifier for observability systems + #[serde(skip_serializing_if = "Option::is_none")] + pub trace_id: Option, + /// Protocol-level metadata for the task object + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option, + /// Reserved for future SEP extensions + #[serde(skip_serializing_if = "Option::is_none")] + pub extensions: Option, +} + +/// Query filter for listing tasks +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct TaskQuery { + #[serde(skip_serializing_if = "Option::is_none")] + pub ids: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub status: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub kind: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub parent_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub labels_any: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub labels_all: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub created_after: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub created_before: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub limit: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub cursor: Option, +} + +/// Paginated list of tasks +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct TaskList { + pub tasks: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub next_cursor: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub total: Option, +} + +/// Request payload to create a new task +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct TaskCreateRequest { + pub kind: TaskKind, + #[serde(skip_serializing_if = "Option::is_none")] + pub input: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub parent_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub depends_on: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub labels: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub priority: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub batch_group: Option, + /// Protocol-level metadata for the request object + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option, + /// Reserved for future SEP extensions + #[serde(skip_serializing_if = "Option::is_none")] + pub extensions: Option, +} + +/// Request payload to update a task's runtime fields +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct TaskUpdateRequest { + pub id: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub runtime_state: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub cancellation_requested: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub labels: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub priority: Option, +} + +/// Incremental progress event for streaming updates +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct TaskProgressEvent { + pub id: String, + pub progress: TaskProgress, + /// ISO8601 timestamp for the event + pub timestamp: String, +} + +/// Terminal event signaling task completion +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct TaskCompletionEvent { + pub id: String, + /// Allowed values: Succeeded, Failed, Cancelled + pub status: TaskStatus, + #[serde(skip_serializing_if = "Option::is_none")] + pub result: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + /// ISO8601 timestamp for the event + pub timestamp: String, +} diff --git a/crates/rmcp/src/task_manager.rs b/crates/rmcp/src/task_manager.rs new file mode 100644 index 00000000..c7a8cab5 --- /dev/null +++ b/crates/rmcp/src/task_manager.rs @@ -0,0 +1,202 @@ +use std::{collections::HashMap, pin::Pin}; + +use futures::Future; +use tokio::sync::mpsc; + +use crate::{error::RmcpError as Error, model::ClientRequest, service::RequestContext, RoleServer}; + +/// Boxed future that represents an asynchronous operation managed by the processor. +pub type OperationFuture = Pin< + Box, Error>> + Send>, +>; + +/// Describes metadata associated with an enqueued task. +#[derive(Debug, Clone)] +pub struct OperationDescriptor { + pub operation_id: String, + pub name: String, + pub client_request: Option, + pub context: Option>, + pub timeout_secs: Option, +} + +impl OperationDescriptor { + pub fn new(operation_id: impl Into, name: impl Into) -> Self { + Self { + operation_id: operation_id.into(), + name: name.into(), + client_request: None, + context: None, + timeout_secs: None, + } + } + + pub fn with_client_request(mut self, request: ClientRequest) -> Self { + self.client_request = Some(request); + self + } + + pub fn with_context(mut self, context: RequestContext) -> Self { + self.context = Some(context); + self + } + + pub fn with_timeout(mut self, timeout_secs: u64) -> Self { + self.timeout_secs = Some(timeout_secs); + self + } +} + +/// Operation message describing a unit of asynchronous work. +pub struct OperationMessage { + pub descriptor: OperationDescriptor, + pub future: OperationFuture, +} + +impl OperationMessage { + pub fn new(descriptor: OperationDescriptor, future: OperationFuture) -> Self { + Self { descriptor, future } + } +} + +/// Trait for operation result transport +pub trait OperationResultTransport: Send + Sync + 'static { + fn operation_id(&self) -> &String; + fn as_any(&self) -> &dyn std::any::Any; +} + +// ===== Operation Processor ===== +pub const DEFAULT_TASK_TIMEOUT_SECS: u64 = 300; // 5 minutes +/// Operation processor that coordinates extractors and handlers +pub struct OperationProcessor { + /// Currently running tasks keyed by id + running_tasks: HashMap, + /// Completed results waiting to be collected + completed_results: Vec, + task_result_receiver: Option>, + task_result_sender: mpsc::UnboundedSender, +} + +struct RunningTask { + task_handle: tokio::task::JoinHandle<()>, + started_at: std::time::Instant, + timeout: Option, + descriptor: OperationDescriptor, +} + +pub struct TaskResult { + pub descriptor: OperationDescriptor, + pub result: Result, Error>, +} + +impl OperationProcessor { + pub fn new() -> Self { + let (task_result_sender, task_result_receiver) = mpsc::unbounded_channel(); + Self { + running_tasks: HashMap::new(), + completed_results: Vec::new(), + task_result_receiver: Some(task_result_receiver), + task_result_sender, + } + } + + /// Submit an operation for asynchronous execution. + pub fn submit_operation(&mut self, message: OperationMessage) -> Result<(), Error> { + if self + .running_tasks + .contains_key(&message.descriptor.operation_id) + { + return Err(Error::TaskError(format!( + "Operation with id {} is already running", + message.descriptor.operation_id + ))); + } + self.spawn_async_task(message); + Ok(()) + } + + fn spawn_async_task(&mut self, message: OperationMessage) { + let OperationMessage { descriptor, future } = message; + let task_id = descriptor.operation_id.clone(); + let timeout = descriptor + .timeout_secs + .or(Some(DEFAULT_TASK_TIMEOUT_SECS)); + let sender = self.task_result_sender.clone(); + let descriptor_for_result = descriptor.clone(); + let handle = tokio::spawn(async move { + let result = future.await; + let task_result = TaskResult { + descriptor: descriptor_for_result, + result, + }; + let _ = sender.send(task_result); + }); + let running_task = RunningTask { + task_handle: handle, + started_at: std::time::Instant::now(), + timeout, + descriptor, + }; + self.running_tasks.insert(task_id, running_task); + } + + /// Collect completed results from running tasks and remove them from the running tasks map. + pub fn collect_completed_results(&mut self) -> Vec { + if let Some(receiver) = &mut self.task_result_receiver { + while let Ok(result) = receiver.try_recv() { + self + .running_tasks + .remove(&result.descriptor.operation_id); + self.completed_results.push(result); + } + } + std::mem::take(&mut self.completed_results) + } + + /// Check for tasks that have exceeded their timeout and handle them appropriately. + pub fn check_timeouts(&mut self) { + let now = std::time::Instant::now(); + let mut timed_out_tasks = Vec::new(); + + for (task_id, task) in &self.running_tasks { + if let Some(timeout_duration) = task.timeout { + if now.duration_since(task.started_at).as_secs() > timeout_duration { + task.task_handle.abort(); + timed_out_tasks.push(task_id.clone()); + } + } + } + + for task_id in timed_out_tasks { + if let Some(task) = self.running_tasks.remove(&task_id) { + let timeout_result = TaskResult { + descriptor: task.descriptor, + result: Err(Error::TaskError("Operation timed out".to_string())), + }; + self.completed_results.push(timeout_result); + } + } + } + + /// Get the number of running tasks. + pub fn running_task_count(&self) -> usize { + self.running_tasks.len() + } + + /// Cancel all running tasks. + pub fn cancel_all_tasks(&mut self) { + for (_, task) in self.running_tasks.drain() { + task.task_handle.abort(); + } + self.completed_results.clear(); + } + /// List running task ids. + pub fn list_running(&self) -> Vec { + self.running_tasks.keys().cloned().collect() + } + + /// Note: collectors should call collect_completed_results; this provides a snapshot of queued results. + pub fn peek_completed(&self) -> &[TaskResult] { + &self.completed_results + } +} diff --git a/crates/rmcp/tests/test_progress_subscriber.rs b/crates/rmcp/tests/test_progress_subscriber.rs index 531b1692..b5d185ab 100644 --- a/crates/rmcp/tests/test_progress_subscriber.rs +++ b/crates/rmcp/tests/test_progress_subscriber.rs @@ -110,6 +110,7 @@ async fn test_progress_subscriber() -> anyhow::Result<()> { ClientRequest::CallToolRequest(Request::new(CallToolRequestParam { name: "some_progress".into(), arguments: None, + task: None, })), PeerRequestOptions::no_options(), ) diff --git a/crates/rmcp/tests/test_task.rs b/crates/rmcp/tests/test_task.rs new file mode 100644 index 00000000..fc18da1d --- /dev/null +++ b/crates/rmcp/tests/test_task.rs @@ -0,0 +1,83 @@ +use std::{any::Any, time::Duration}; + +use rmcp::task_manager::{ + OperationDescriptor, OperationMessage, OperationProcessor, OperationResultTransport, +}; + +struct DummyTransport { + id: String, + value: u32, +} + +impl OperationResultTransport for DummyTransport { + fn operation_id(&self) -> &String { + &self.id + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +#[tokio::test] +async fn executes_enqueued_future() { + let mut processor = OperationProcessor::new(); + let descriptor = OperationDescriptor::new("op1", "dummy"); + let future = Box::pin(async { + tokio::time::sleep(Duration::from_millis(10)).await; + Ok( + Box::new(DummyTransport { + id: "op1".to_string(), + value: 42, + }) as Box, + ) + }); + + processor + .submit_operation(OperationMessage::new(descriptor, future)) + .expect("submit operation"); + + tokio::time::sleep(Duration::from_millis(30)).await; + let results = processor.collect_completed_results(); + assert_eq!(results.len(), 1); + let payload = results[0] + .result + .as_ref() + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(payload.value, 42); +} + +#[tokio::test] +async fn rejects_duplicate_operation_ids() { + let mut processor = OperationProcessor::new(); + let descriptor = OperationDescriptor::new("dup", "dummy"); + let future = Box::pin(async { + Ok( + Box::new(DummyTransport { + id: "dup".to_string(), + value: 1, + }) as Box, + ) + }); + processor + .submit_operation(OperationMessage::new(descriptor, future)) + .expect("first submit"); + + let descriptor_dup = OperationDescriptor::new("dup", "dummy"); + let future_dup = Box::pin(async { + Ok( + Box::new(DummyTransport { + id: "dup".to_string(), + value: 2, + }) as Box, + ) + }); + + let err = processor + .submit_operation(OperationMessage::new(descriptor_dup, future_dup)) + .expect_err("duplicate should fail"); + assert!(format!("{err}").contains("already running")); +} diff --git a/crates/rmcp/tests/test_tool_macros.rs b/crates/rmcp/tests/test_tool_macros.rs index db5242b3..763c4f43 100644 --- a/crates/rmcp/tests/test_tool_macros.rs +++ b/crates/rmcp/tests/test_tool_macros.rs @@ -320,6 +320,7 @@ async fn test_optional_i64_field_with_null_input() -> anyhow::Result<()> { .unwrap() .clone(), ), + task: None, }) .await?; @@ -348,6 +349,7 @@ async fn test_optional_i64_field_with_null_input() -> anyhow::Result<()> { .unwrap() .clone(), ), + task: None, }) .await?; diff --git a/examples/clients/src/collection.rs b/examples/clients/src/collection.rs index 67969ae4..c714da54 100644 --- a/examples/clients/src/collection.rs +++ b/examples/clients/src/collection.rs @@ -49,6 +49,7 @@ async fn main() -> Result<()> { .call_tool(CallToolRequestParam { name: "git_status".into(), arguments: serde_json::json!({ "repo_path": "." }).as_object().cloned(), + task: None, }) .await?; } diff --git a/examples/clients/src/everything_stdio.rs b/examples/clients/src/everything_stdio.rs index 107adc07..f1cbcae5 100644 --- a/examples/clients/src/everything_stdio.rs +++ b/examples/clients/src/everything_stdio.rs @@ -40,6 +40,7 @@ async fn main() -> Result<()> { .call_tool(CallToolRequestParam { name: "echo".into(), arguments: Some(object!({ "message": "hi from rmcp" })), + task: None, }) .await?; tracing::info!("Tool result for echo: {tool_result:#?}"); @@ -49,6 +50,7 @@ async fn main() -> Result<()> { .call_tool(CallToolRequestParam { name: "longRunningOperation".into(), arguments: Some(object!({ "duration": 3, "steps": 1 })), + task: None, }) .await?; tracing::info!("Tool result for longRunningOperation: {tool_result:#?}"); diff --git a/examples/clients/src/git_stdio.rs b/examples/clients/src/git_stdio.rs index d1298b36..7b516f38 100644 --- a/examples/clients/src/git_stdio.rs +++ b/examples/clients/src/git_stdio.rs @@ -42,6 +42,7 @@ async fn main() -> Result<(), RmcpError> { .call_tool(CallToolRequestParam { name: "git_status".into(), arguments: serde_json::json!({ "repo_path": "." }).as_object().cloned(), + task: None, }) .await?; tracing::info!("Tool result: {tool_result:#?}"); diff --git a/examples/clients/src/progress_client.rs b/examples/clients/src/progress_client.rs index ddf18b2f..c795ce22 100644 --- a/examples/clients/src/progress_client.rs +++ b/examples/clients/src/progress_client.rs @@ -184,6 +184,7 @@ async fn test_stdio_transport(records: u32) -> Result<()> { .call_tool(CallToolRequestParam { name: "stream_processor".into(), arguments: None, + task: None, }) .await?; @@ -238,6 +239,7 @@ async fn test_http_transport(http_url: &str, records: u32) -> Result<()> { .call_tool(CallToolRequestParam { name: "stream_processor".into(), arguments: None, + task: None, }) .await?; diff --git a/examples/clients/src/sampling_stdio.rs b/examples/clients/src/sampling_stdio.rs index 8f5aba22..b30a3c26 100644 --- a/examples/clients/src/sampling_stdio.rs +++ b/examples/clients/src/sampling_stdio.rs @@ -106,6 +106,7 @@ async fn main() -> Result<()> { arguments: Some(object!({ "question": "Hello world" })), + task: None, }) .await { diff --git a/examples/clients/src/streamable_http.rs b/examples/clients/src/streamable_http.rs index 2f1f1598..cd4b73c4 100644 --- a/examples/clients/src/streamable_http.rs +++ b/examples/clients/src/streamable_http.rs @@ -44,6 +44,7 @@ async fn main() -> Result<()> { .call_tool(CallToolRequestParam { name: "increment".into(), arguments: serde_json::json!({}).as_object().cloned(), + task: None, }) .await?; tracing::info!("Tool result: {tool_result:#?}"); diff --git a/examples/rig-integration/src/mcp_adaptor.rs b/examples/rig-integration/src/mcp_adaptor.rs index 483c6e02..286e58d5 100644 --- a/examples/rig-integration/src/mcp_adaptor.rs +++ b/examples/rig-integration/src/mcp_adaptor.rs @@ -47,6 +47,7 @@ impl RigTool for McpToolAdaptor { name: self.tool.name.clone(), arguments: serde_json::from_str(&args) .map_err(rig::tool::ToolError::JsonError)?, + task: None, }) .await .inspect(|result| tracing::info!(?result)) diff --git a/examples/servers/src/common/counter.rs b/examples/servers/src/common/counter.rs index dc2472bb..2c218b98 100644 --- a/examples/servers/src/common/counter.rs +++ b/examples/servers/src/common/counter.rs @@ -1,6 +1,6 @@ #![allow(dead_code)] -use std::sync::Arc; - +use std::{any::Any, sync::Arc}; +use rmcp::task_manager::{OperationDescriptor, OperationMessage, OperationProcessor, OperationResultTransport}; use rmcp::{ ErrorData as McpError, RoleServer, ServerHandler, handler::server::{ @@ -13,8 +13,24 @@ use rmcp::{ tool, tool_handler, tool_router, }; use serde_json::json; +use chrono::Utc; use tokio::sync::Mutex; +struct ToolCallOperationResult { + id: String, + result: Result, +} + +impl OperationResultTransport for ToolCallOperationResult { + fn operation_id(&self) -> &String { + &self.id + } + + fn as_any(&self) -> &dyn Any { + self + } +} + #[derive(Debug, serde::Deserialize, schemars::JsonSchema)] pub struct StructRequest { pub a: i32, @@ -41,6 +57,7 @@ pub struct Counter { counter: Arc>, tool_router: ToolRouter, prompt_router: PromptRouter, + processor: Arc>, } #[tool_router] @@ -51,6 +68,7 @@ impl Counter { counter: Arc::new(Mutex::new(0)), tool_router: Self::tool_router(), prompt_router: Self::prompt_router(), + processor: Arc::new(Mutex::new(OperationProcessor::new())), } } @@ -84,6 +102,12 @@ impl Counter { )])) } + #[tool(description = "Long running task example")] + async fn long_task(&self) -> Result { + tokio::time::sleep(std::time::Duration::from_secs(10)).await; + Ok(CallToolResult::success(vec![Content::text("Long task completed")])) + } + #[tool(description = "Say hello to the client")] fn say_hello(&self) -> Result { Ok(CallToolResult::success(vec![Content::text("hello")])) @@ -180,6 +204,49 @@ impl ServerHandler for Counter { } } + async fn enqueue_task( + &self, + request: CallToolRequestParam, + context: RequestContext, + ) -> Result { + let operation_id = context.id.to_string(); + let operation_name = request.name.to_string(); + let task_payload = request.task.clone(); + let future_request = request.clone(); + let future_context = context.clone(); + let server = self.clone(); + let descriptor = OperationDescriptor::new(operation_id.clone(), operation_name) + .with_context(context.clone()) + .with_client_request(ClientRequest::CallToolRequest(Request::new(request.clone()))); + let id_clone = operation_id.clone(); + let future = Box::pin(async move { + let result = server.call_tool(future_request, future_context).await; + Ok( + Box::new(ToolCallOperationResult { + id: id_clone, + result, + }) as Box, + ) + }); + + let message = OperationMessage::new(descriptor, future); + self.processor + .lock() + .await + .submit_operation(message) + .map_err(|err| McpError::internal_error(format!("failed to enqueue task: {err}"), None))?; + + let mut task = Task::default(); + task.id = operation_id; + task.kind = TaskKind::ToolCall; + task.status = TaskStatus::Pending; + task.created_at = Utc::now().to_rfc3339(); + task.metadata = task_payload.map(|payload| serde_json::Value::Object(payload)); + task.input = serde_json::to_value(&request).ok(); + + Ok(GetTaskInfoResult { task: Some(task) }) + } + async fn list_resources( &self, _request: Option, @@ -251,6 +318,13 @@ impl ServerHandler for Counter { #[cfg(test)] mod tests { use super::*; + use rmcp::{ClientHandler, ServiceExt}; + use tokio::time::Duration; + + #[derive(Default, Clone)] + struct TestClient; + + impl ClientHandler for TestClient {} #[tokio::test] async fn test_prompt_attributes_generated() { @@ -289,34 +363,107 @@ mod tests { } #[tokio::test] - async fn test_example_prompt_execution() { + async fn test_client_enqueues_long_task() -> anyhow::Result<()> { let counter = Counter::new(); - let context = rmcp::handler::server::prompt::PromptContext::new( - &counter, - "example_prompt".to_string(), - Some({ - let mut map = serde_json::Map::new(); - map.insert( - "message".to_string(), - serde_json::Value::String("Test message".to_string()), - ); - map - }), - RequestContext { - meta: Default::default(), - ct: tokio_util::sync::CancellationToken::new(), - id: rmcp::model::NumberOrString::String("test-1".to_string()), - peer: Default::default(), - extensions: Default::default(), - }, + let processor = counter.processor.clone(); + let client = TestClient::default(); + + let (server_transport, client_transport) = tokio::io::duplex(4096); + let server_handle = tokio::spawn(async move { + let service = counter.serve(server_transport).await?; + service.waiting().await?; + anyhow::Ok(()) + }); + + let client_service = client.serve(client_transport).await?; + let mut task_meta = serde_json::Map::new(); + task_meta.insert( + "source".into(), + serde_json::Value::String("integration-test".into()), ); - - let router = Counter::prompt_router(); - let result = router.get_prompt(context).await; - assert!(result.is_ok()); - - let prompt_result = result.unwrap(); - assert_eq!(prompt_result.messages.len(), 1); - assert_eq!(prompt_result.messages[0].role, PromptMessageRole::User); + let params = CallToolRequestParam { + name: "long_task".into(), + arguments: None, + task: Some(task_meta), + }; + let response = client_service + .send_request(ClientRequest::CallToolRequest(Request::new(params.clone()))) + .await?; + + let ServerResult::GetTaskInfoResult(info) = response else { + panic!("expected task info result, got {response:?}"); + }; + let task = info.task.expect("task payload missing"); + assert_eq!(task.kind, TaskKind::ToolCall); + assert_eq!(task.status, TaskStatus::Pending); + assert!(task.input.is_some()); + assert!(task.metadata.is_some()); + + tokio::time::sleep(Duration::from_millis(50)).await; + let running = processor.lock().await.running_task_count(); + assert_eq!(running, 1); + + client_service.cancel().await?; + let _ = server_handle.await; + Ok(()) } + + // #[tokio::test] + // async fn test_example_prompt_execution() { + // let counter = Counter::new(); + // let context = rmcp::handler::server::prompt::PromptContext::new( + // &counter, + // "example_prompt".to_string(), + // Some({ + // let mut map = serde_json::Map::new(); + // map.insert( + // "message".to_string(), + // serde_json::Value::String("Test message".to_string()), + // ); + // map + // }), + // RequestContext { + // meta: Default::default(), + // ct: tokio_util::sync::CancellationToken::new(), + // id: rmcp::model::NumberOrString::String("test-1".to_string()), + // peer: Default::default(), + // extensions: Default::default(), + // }, + // ); + + // let router = Counter::prompt_router(); + // let result = router.get_prompt(context).await; + // assert!(result.is_ok()); + + // let prompt_result = result.unwrap(); + // assert_eq!(prompt_result.messages.len(), 1); + // assert_eq!(prompt_result.messages[0].role, PromptMessageRole::User); + // } + + + // #[tokio::test] + // async fn test_long_task_enqueue() { + // let counter = Counter::new(); + // let request = CallToolRequestParam { + // name: "long_task".to_string(), + // task: Some(serde_json::Map::new()), + // arguments: None, + // }; + // let context = RequestContext { + // meta: Default::default(), + // ct: tokio_util::sync::CancellationToken::new(), + // id: rmcp::model::NumberOrString::String("long-task-1".to_string()), + // peer: Default::default(), + // extensions: Default::default(), + // }; + + // let result = counter.enqueue_task(request, context).await; + // assert!(result.is_ok()); + + // let task_info = result.unwrap(); + // assert!(task_info.task.is_some()); + // let task = task_info.task.unwrap(); + // assert_eq!(task.id, "long-task-1"); + // assert_eq!(task.status, TaskStatus::Pending); + // } } diff --git a/examples/simple-chat-client/src/tool.rs b/examples/simple-chat-client/src/tool.rs index 771b4e9e..174f4274 100644 --- a/examples/simple-chat-client/src/tool.rs +++ b/examples/simple-chat-client/src/tool.rs @@ -62,6 +62,7 @@ impl Tool for McpToolAdapter { .call_tool(CallToolRequestParam { name: self.tool.name.clone(), arguments, + task: None, }) .await?; diff --git a/examples/transport/src/unix_socket.rs b/examples/transport/src/unix_socket.rs index feeb2b87..0d91dfee 100644 --- a/examples/transport/src/unix_socket.rs +++ b/examples/transport/src/unix_socket.rs @@ -52,6 +52,7 @@ async fn main() -> anyhow::Result<()> { "a": 10, "b": 20 })), + task: None, }) .await?;