Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions poem-mcpserver/src/protocol/resources.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ pub struct ResourcesListRequest {
pub cursor: Option<String>,
}

/// A request to list resource templates.
#[derive(Debug, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct ResourcesTemplatesListRequest {
/// The cursor to continue listing resource templates.
pub cursor: Option<String>,
}

/// Resource information.
#[derive(Debug, Serialize, Clone)]
#[serde(rename_all = "camelCase")]
Expand Down Expand Up @@ -56,6 +64,28 @@ pub struct ResourcesListResponse {
pub resources: Vec<Resource>,
}

/// Resource template information.
#[derive(Debug, Serialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct ResourceTemplate {
/// The URI template of the resource.
pub uri_template: String,
/// The display name of the resource template.
pub name: String,
/// A short description of the template.
pub description: String,
/// The mime type of the resource represented by the template.
pub mime_type: String,
}

/// A response to a resources/templates/list request.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourcesTemplatesListResponse {
/// Resource templates list.
pub resource_templates: Vec<ResourceTemplate>,
}

/// A response to a resources/read request.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
Expand Down
159 changes: 155 additions & 4 deletions poem-mcpserver/src/protocol/rpc.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
//! JSON-RPC protocol types.

use itertools::Either;
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize, de::Error as _};
use serde_json::Value;

use crate::protocol::{
initialize::InitializeRequest,
prompts::{PromptsGetRequest, PromptsListRequest},
resources::{ResourcesListRequest, ResourcesReadRequest},
resources::{ResourcesListRequest, ResourcesReadRequest, ResourcesTemplatesListRequest},
tool::{ToolsCallRequest, ToolsListRequest},
};

Expand Down Expand Up @@ -39,6 +39,7 @@ pub enum Requests {
#[serde(rename = "notifications/cancelled")]
Cancelled {
/// The ID of the request to cancel
#[serde(alias = "requestId")]
request_id: RequestId,
/// An optional reason string that can be logged or displayed
reason: Option<String>,
Expand Down Expand Up @@ -76,6 +77,13 @@ pub enum Requests {
#[serde(default)]
params: ResourcesListRequest,
},
/// Resource templates list.
#[serde(rename = "resources/templates/list")]
ResourcesTemplatesList {
/// Resource templates list request parameters.
#[serde(default)]
params: ResourcesTemplatesListRequest,
},
/// Read a resource.
#[serde(rename = "resources/read")]
ResourcesRead {
Expand All @@ -85,15 +93,79 @@ pub enum Requests {
}

/// A JSON-RPC batch request.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
#[derive(Debug)]
pub enum BatchRequest {
/// A single request.
Single(Request),
/// A batch of requests.
Batch(Vec<Request>),
}

impl<'de> Deserialize<'de> for BatchRequest {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
fn normalize_request(value: &mut Value) {
let Some(obj) = value.as_object_mut() else {
return;
};

let method = obj
.get("method")
.and_then(Value::as_str)
.unwrap_or_default();

match method {
"notifications/initialized" => {
if obj.get("params").is_some_and(|value| {
value.as_object().is_some_and(serde_json::Map::is_empty)
}) {
obj.remove("params");
}
}
"notifications/cancelled" => {
let Some(params) = obj.get("params").and_then(Value::as_object).cloned() else {
return;
};
if !obj.contains_key("request_id") && !obj.contains_key("requestId") {
if let Some(request_id) =
params.get("request_id").or_else(|| params.get("requestId"))
{
obj.insert("request_id".to_string(), request_id.clone());
}
}
if !obj.contains_key("reason") {
if let Some(reason) = params.get("reason") {
obj.insert("reason".to_string(), reason.clone());
}
}
}
_ => {}
}
}

let mut value = Value::deserialize(deserializer)?;
match &mut value {
Value::Object(_) => {
normalize_request(&mut value);
let request = serde_json::from_value(value).map_err(D::Error::custom)?;
Ok(BatchRequest::Single(request))
}
Value::Array(values) => {
for request in values {
normalize_request(request);
}
let requests = serde_json::from_value(value).map_err(D::Error::custom)?;
Ok(BatchRequest::Batch(requests))
}
_ => Err(D::Error::custom(
"data didnot match any variant of untagged enum BatchRequest",
)),
}
}
}

impl IntoIterator for BatchRequest {
type Item = Request;
type IntoIter = Either<std::iter::Once<Self::Item>, std::vec::IntoIter<Self::Item>>;
Expand Down Expand Up @@ -278,3 +350,82 @@ impl<E> RpcError<E> {
RpcError::new(INTERNAL_ERROR, message)
}
}

#[cfg(test)]
mod tests {
use serde_json::json;

use super::{BatchRequest, RequestId, Requests};

#[test]
fn parse_initialized_with_empty_params() {
let request: BatchRequest = serde_json::from_value(json!({
"jsonrpc": "2.0",
"method": "notifications/initialized",
"params": {}
}))
.expect("parse initialized notification");

let request = request.requests().first().expect("single request");
assert!(matches!(request.body, Requests::Initialized));
}

#[test]
fn parse_cancelled_with_params_object() {
let request: BatchRequest = serde_json::from_value(json!({
"jsonrpc": "2.0",
"method": "notifications/cancelled",
"params": {
"requestId": 42,
"reason": "user cancelled"
}
}))
.expect("parse cancelled notification");

let request = request.requests().first().expect("single request");
assert!(matches!(
request.body,
Requests::Cancelled {
request_id: RequestId::Int(42),
reason: Some(ref reason),
} if reason == "user cancelled"
));
}

#[test]
fn parse_cancelled_top_level_fields() {
let request: BatchRequest = serde_json::from_value(json!({
"jsonrpc": "2.0",
"method": "notifications/cancelled",
"requestId": "abc",
"reason": "timeout"
}))
.expect("parse cancelled notification");

let request = request.requests().first().expect("single request");
assert!(matches!(
request.body,
Requests::Cancelled {
request_id: RequestId::String(ref request_id),
reason: Some(ref reason),
} if request_id == "abc" && reason == "timeout"
));
}

#[test]
fn parse_resources_templates_list() {
let request: BatchRequest = serde_json::from_value(json!({
"jsonrpc": "2.0",
"id": 1,
"method": "resources/templates/list",
"params": {}
}))
.expect("parse resources/templates/list");

let request = request.requests().first().expect("single request");
assert!(matches!(
request.body,
Requests::ResourcesTemplatesList { .. }
));
}
}
17 changes: 16 additions & 1 deletion poem-mcpserver/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::{
prompts::{PromptsGetRequest, PromptsListResponse},
resources::{
Resource, ResourceContent, ResourcesListResponse, ResourcesReadRequest,
ResourcesReadResponse,
ResourcesReadResponse, ResourcesTemplatesListResponse,
},
rpc::{Request, RequestId, Requests, Response},
tool::{ToolsCallRequest, ToolsListResponse},
Expand Down Expand Up @@ -276,6 +276,18 @@ where
.map_result_to_value()
}

fn handle_resources_templates_list(&self, id: Option<RequestId>) -> Response<Value> {
Response {
jsonrpc: JSON_RPC_VERSION.to_string(),
id,
result: Some(ResourcesTemplatesListResponse {
resource_templates: vec![],
}),
error: None,
}
.map_result_to_value()
}

fn handle_resources_read(
&self,
request: ResourcesReadRequest,
Expand Down Expand Up @@ -320,6 +332,9 @@ where
Some(self.handle_prompts_get(params, request.id).await)
}
Requests::ResourcesList { .. } => Some(self.handle_resources_list(request.id)),
Requests::ResourcesTemplatesList { .. } => {
Some(self.handle_resources_templates_list(request.id))
}
Requests::ResourcesRead { params } => {
Some(self.handle_resources_read(params, request.id))
}
Expand Down
26 changes: 26 additions & 0 deletions poem-mcpserver/tests/resources.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,29 @@ async fn resources_list_and_read() {
})
);
}

#[tokio::test]
async fn resources_templates_list() {
let mut server = McpServer::new();

let resp = server
.handle_request(Request {
jsonrpc: JSON_RPC_VERSION.to_string(),
id: Some(RequestId::Int(1)),
body: Requests::ResourcesTemplatesList {
params: Default::default(),
},
})
.await;

assert_eq!(
serde_json::to_value(&resp).unwrap(),
serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"result": {
"resourceTemplates": []
}
})
);
}