Skip to content

Commit 90e2dba

Browse files
committed
feat(task): add task support (SEP-1686)
Signed-off-by: jokemanfire <[email protected]>
1 parent 3c62ee8 commit 90e2dba

File tree

22 files changed

+857
-35
lines changed

22 files changed

+857
-35
lines changed

crates/rmcp/Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,3 +197,8 @@ path = "tests/test_progress_subscriber.rs"
197197
name = "test_elicitation"
198198
required-features = ["elicitation", "client", "server"]
199199
path = "tests/test_elicitation.rs"
200+
201+
[[test]]
202+
name = "test_task"
203+
required-features = ["server", "client", "macros"]
204+
path = "tests/test_task.rs"

crates/rmcp/src/error.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ pub enum RmcpError {
4141
error: Box<dyn std::error::Error + Send + Sync>,
4242
},
4343
// and cancellation shouldn't be an error?
44+
45+
// TODO: add more error variants as needed
46+
#[error("Task error: {0}")]
47+
TaskError(String),
4448
}
4549

4650
impl RmcpError {

crates/rmcp/src/handler/server.rs

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ pub mod router;
1111
pub mod tool;
1212
pub mod tool_name_validation;
1313
pub mod wrapper;
14+
1415
impl<H: ServerHandler> Service<RoleServer> for H {
1516
async fn handle_request(
1617
&self,
@@ -61,14 +62,29 @@ impl<H: ServerHandler> Service<RoleServer> for H {
6162
.unsubscribe(request.params, context)
6263
.await
6364
.map(ServerResult::empty),
64-
ClientRequest::CallToolRequest(request) => self
65-
.call_tool(request.params, context)
66-
.await
67-
.map(ServerResult::CallToolResult),
65+
ClientRequest::CallToolRequest(request) => {
66+
if request.params.task.is_some() {
67+
tracing::info!("Enqueueing task for tool call: {}", request.params.name);
68+
self.enqueue_task(request.params, context.clone()).await.map(ServerResult::GetTaskInfoResult)
69+
}else{
70+
self
71+
.call_tool(request.params, context)
72+
.await
73+
.map(ServerResult::CallToolResult)
74+
}
75+
},
6876
ClientRequest::ListToolsRequest(request) => self
6977
.list_tools(request.params, context)
7078
.await
7179
.map(ServerResult::ListToolsResult),
80+
ClientRequest::ListTasksRequest(request) => self
81+
.list_tasks(request.params, context)
82+
.await
83+
.map(ServerResult::ListTasksResult),
84+
ClientRequest::GetTaskInfoRequest(request) => self
85+
.get_task_info(request.params, context)
86+
.await
87+
.map(ServerResult::GetTaskInfoResult),
7288
}
7389
}
7490

@@ -104,6 +120,15 @@ impl<H: ServerHandler> Service<RoleServer> for H {
104120

105121
#[allow(unused_variables)]
106122
pub trait ServerHandler: Sized + Send + Sync + 'static {
123+
fn enqueue_task(
124+
&self,
125+
_request: CallToolRequestParam,
126+
_context: RequestContext<RoleServer>,
127+
) -> impl Future<Output = Result<GetTaskInfoResult, McpError>> + Send + '_ {
128+
std::future::ready(Err(McpError::internal_error(
129+
"Task processing not implemented".to_string(),
130+
None)))
131+
}
107132
fn ping(
108133
&self,
109134
context: RequestContext<RoleServer>,
@@ -240,4 +265,20 @@ pub trait ServerHandler: Sized + Send + Sync + 'static {
240265
fn get_info(&self) -> ServerInfo {
241266
ServerInfo::default()
242267
}
268+
269+
fn list_tasks(
270+
&self,
271+
request: Option<PaginatedRequestParam>,
272+
context: RequestContext<RoleServer>,
273+
) -> impl Future<Output = Result<ListTasksResult, McpError>> + Send + '_ {
274+
std::future::ready(Err(McpError::method_not_found::<ListTasksMethod>()))
275+
}
276+
277+
fn get_task_info(
278+
&self,
279+
request: GetTaskInfoParam,
280+
context: RequestContext<RoleServer>,
281+
) -> impl Future<Output = Result<GetTaskInfoResult, McpError>> + Send + '_ {
282+
std::future::ready(Err(McpError::method_not_found::<GetTaskInfoMethod>()))
283+
}
243284
}

crates/rmcp/src/handler/server/tool.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,19 +33,21 @@ pub struct ToolCallContext<'s, S> {
3333
pub service: &'s S,
3434
pub name: Cow<'static, str>,
3535
pub arguments: Option<JsonObject>,
36+
pub task:Option<JsonObject>,
3637
}
3738

3839
impl<'s, S> ToolCallContext<'s, S> {
3940
pub fn new(
4041
service: &'s S,
41-
CallToolRequestParam { name, arguments }: CallToolRequestParam,
42+
CallToolRequestParam { name, arguments, task }: CallToolRequestParam,
4243
request_context: RequestContext<RoleServer>,
4344
) -> Self {
4445
Self {
4546
request_context,
4647
service,
4748
name,
4849
arguments,
50+
task,
4951
}
5052
}
5153
pub fn name(&self) -> &str {

crates/rmcp/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ pub use service::{RoleClient, serve_client};
162162
pub use service::{RoleServer, serve_server};
163163

164164
pub mod handler;
165+
pub mod task_manager;
165166
pub mod transport;
166167

167168
// re-export

crates/rmcp/src/model.rs

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ mod meta;
88
mod prompt;
99
mod resource;
1010
mod serde_impl;
11+
mod task;
1112
mod tool;
1213
pub use annotated::*;
1314
pub use capabilities::*;
@@ -19,6 +20,7 @@ pub use prompt::*;
1920
pub use resource::*;
2021
use serde::{Deserialize, Serialize, de::DeserializeOwned};
2122
use serde_json::Value;
23+
pub use task::*;
2224
pub use tool::*;
2325

2426
/// A JSON object type alias for convenient handling of JSON data.
@@ -1639,6 +1641,7 @@ paginated_result!(
16391641
}
16401642
);
16411643

1644+
16421645
const_string!(CallToolRequestMethod = "tools/call");
16431646
/// Parameters for calling a tool provided by an MCP server.
16441647
///
@@ -1653,6 +1656,8 @@ pub struct CallToolRequestParam {
16531656
/// Arguments to pass to the tool (must match the tool's input schema)
16541657
#[serde(skip_serializing_if = "Option::is_none")]
16551658
pub arguments: Option<JsonObject>,
1659+
#[serde(skip_serializing_if = "Option::is_none")]
1660+
pub task: Option<JsonObject>,
16561661
}
16571662

16581663
/// Request to call a specific tool
@@ -1691,6 +1696,23 @@ pub struct GetPromptResult {
16911696
pub messages: Vec<PromptMessage>,
16921697
}
16931698

1699+
// =============================================================================
1700+
// TASK MANAGEMENT
1701+
// =============================================================================
1702+
1703+
const_string!(GetTaskInfoMethod = "tasks/get");
1704+
pub type GetTaskInfoRequest = Request<GetTaskInfoMethod, GetTaskInfoParam>;
1705+
1706+
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
1707+
#[serde(rename_all = "camelCase")]
1708+
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
1709+
pub struct GetTaskInfoParam {
1710+
pub task_id: String,
1711+
}
1712+
1713+
const_string!(ListTasksMethod = "tasks/list");
1714+
pub type ListTasksRequest = RequestOptionalParam<ListTasksMethod, PaginatedRequestParam>;
1715+
16941716
// =============================================================================
16951717
// MESSAGE TYPE UNIONS
16961718
// =============================================================================
@@ -1757,7 +1779,9 @@ ts_union!(
17571779
| SubscribeRequest
17581780
| UnsubscribeRequest
17591781
| CallToolRequest
1760-
| ListToolsRequest;
1782+
| ListToolsRequest
1783+
| GetTaskInfoRequest
1784+
| ListTasksRequest;
17611785
);
17621786

17631787
impl ClientRequest {
@@ -1776,6 +1800,8 @@ impl ClientRequest {
17761800
ClientRequest::UnsubscribeRequest(r) => r.method.as_str(),
17771801
ClientRequest::CallToolRequest(r) => r.method.as_str(),
17781802
ClientRequest::ListToolsRequest(r) => r.method.as_str(),
1803+
ClientRequest::GetTaskInfoRequest(r) => r.method.as_str(),
1804+
ClientRequest::ListTasksRequest(r) => r.method.as_str(),
17791805
}
17801806
}
17811807
}
@@ -1832,6 +1858,8 @@ ts_union!(
18321858
| CallToolResult
18331859
| ListToolsResult
18341860
| CreateElicitationResult
1861+
| GetTaskInfoResult
1862+
| ListTasksResult
18351863
| EmptyResult
18361864
;
18371865
);
@@ -1842,6 +1870,28 @@ impl ServerResult {
18421870
}
18431871
}
18441872

1873+
// =============================================================================
1874+
// TASK RESULT TYPES (Server responses for task queries)
1875+
// =============================================================================
1876+
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
1877+
#[serde(rename_all = "camelCase")]
1878+
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
1879+
pub struct GetTaskInfoResult {
1880+
#[serde(skip_serializing_if = "Option::is_none")]
1881+
pub task: Option<crate::model::Task>,
1882+
}
1883+
1884+
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
1885+
#[serde(rename_all = "camelCase")]
1886+
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
1887+
pub struct ListTasksResult {
1888+
pub tasks: Vec<crate::model::Task>,
1889+
#[serde(skip_serializing_if = "Option::is_none")]
1890+
pub next_cursor: Option<String>,
1891+
#[serde(skip_serializing_if = "Option::is_none")]
1892+
pub total: Option<u64>,
1893+
}
1894+
18451895
pub type ServerJsonRpcMessage = JsonRpcMessage<ServerRequest, ServerResult, ServerNotification>;
18461896

18471897
impl TryInto<CancelledNotification> for ServerNotification {

crates/rmcp/src/model/meta.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ variant_extension! {
8686
UnsubscribeRequest
8787
CallToolRequest
8888
ListToolsRequest
89+
GetTaskInfoRequest
90+
ListTasksRequest
8991
}
9092
}
9193

@@ -154,6 +156,7 @@ impl Meta {
154156
})
155157
}
156158

159+
157160
pub fn set_progress_token(&mut self, token: ProgressToken) {
158161
match token.0 {
159162
NumberOrString::String(ref s) => self.0.insert(

0 commit comments

Comments
 (0)