-
Notifications
You must be signed in to change notification settings - Fork 299
feat: Add MCP Elicitation support #332
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
4acf13c
ce4a03a
e1a995c
174af7e
f4819a5
c5fb31c
242a8ab
c080d40
6867d5c
01dd3e8
c649de1
0be52a1
3b18efc
8296242
3d6a9d5
5be0f43
edc5bed
a8c9b5e
3c98177
fd54781
a7211f3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,34 @@ use std::borrow::Cow; | |
use thiserror::Error; | ||
|
||
use super::*; | ||
|
||
/// Validates timeout values to prevent DoS attacks and ensure reasonable limits | ||
fn validate_timeout(timeout: Option<std::time::Duration>) -> Result<(), ServiceError> { | ||
if let Some(duration) = timeout { | ||
const MAX_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300); // 5 minutes max | ||
const MIN_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(1); // 1ms min | ||
|
||
if duration > MAX_TIMEOUT { | ||
return Err(ServiceError::InvalidTimeout { | ||
timeout: duration, | ||
reason: "Timeout exceeds maximum allowed duration (300 seconds)".to_string(), | ||
}); | ||
} | ||
if duration < MIN_TIMEOUT { | ||
return Err(ServiceError::InvalidTimeout { | ||
timeout: duration, | ||
reason: "Timeout must be at least 1 millisecond".to_string(), | ||
}); | ||
} | ||
if duration.is_zero() { | ||
return Err(ServiceError::InvalidTimeout { | ||
timeout: duration, | ||
reason: "Timeout cannot be zero".to_string(), | ||
}); | ||
} | ||
} | ||
Ok(()) | ||
} | ||
#[cfg(feature = "elicitation")] | ||
use crate::model::{ | ||
CreateElicitationRequest, CreateElicitationRequestParam, CreateElicitationResult, | ||
|
@@ -335,6 +363,9 @@ macro_rules! method { | |
&self, | ||
timeout: Option<std::time::Duration>, | ||
) -> Result<$Resp, ServiceError> { | ||
// Validate timeout to prevent DoS attacks | ||
validate_timeout(timeout)?; | ||
|
||
let request = ServerRequest::$Req($Req { | ||
method: Default::default(), | ||
extensions: Default::default(), | ||
|
@@ -361,6 +392,9 @@ macro_rules! method { | |
params: $Param, | ||
timeout: Option<std::time::Duration>, | ||
) -> Result<$Resp, ServiceError> { | ||
// Validate timeout to prevent DoS attacks | ||
validate_timeout(timeout)?; | ||
|
||
let request = ServerRequest::$Req($Req { | ||
method: Default::default(), | ||
params, | ||
|
@@ -413,10 +447,18 @@ pub enum ElicitationError { | |
#[error("Service error: {0}")] | ||
Service(#[from] ServiceError), | ||
|
||
/// User declined to provide input or cancelled the request | ||
#[error("User declined or cancelled the request")] | ||
/// User explicitly declined to provide the requested information | ||
/// This indicates a conscious decision by the user to reject the request | ||
/// (e.g., clicked "Reject", "Decline", "No", etc.) | ||
#[error("User explicitly declined the request")] | ||
UserDeclined, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think declined and canceled are materially, if subtly, different. The spec says:
so probably they should be separated here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds reasonable There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
|
||
/// User dismissed the request without making an explicit choice | ||
/// This indicates the user cancelled without explicitly declining | ||
/// (e.g., closed dialog, clicked outside, pressed Escape, etc.) | ||
#[error("User cancelled/dismissed the request")] | ||
UserCancelled, | ||
|
||
/// The response data could not be parsed into the requested type | ||
#[error("Failed to parse response data: {error}\nReceived data: {data}")] | ||
ParseError { | ||
|
@@ -433,6 +475,59 @@ pub enum ElicitationError { | |
CapabilityNotSupported, | ||
} | ||
|
||
/// Marker trait to ensure that elicitation types generate object-type JSON schemas. | ||
/// | ||
/// This trait provides compile-time safety to ensure that types used with | ||
/// `elicit<T>()` methods will generate JSON schemas of type "object", which | ||
/// aligns with MCP client expectations for structured data input. | ||
/// | ||
/// # Type Safety Rationale | ||
/// | ||
/// MCP clients typically expect JSON objects for elicitation schemas to | ||
/// provide structured forms and validation. This trait prevents common | ||
/// mistakes like: | ||
/// | ||
/// ```compile_fail | ||
/// // These would not compile due to missing ElicitationSafe bound: | ||
/// let name: String = server.elicit("Enter name").await?; // Primitive | ||
/// let items: Vec<i32> = server.elicit("Enter items").await?; // Array | ||
/// ``` | ||
#[cfg(feature = "elicitation")] | ||
pub trait ElicitationSafe: schemars::JsonSchema {} | ||
|
||
/// Macro to mark types as safe for elicitation by verifying they generate object schemas. | ||
/// | ||
/// This macro automatically implements the `ElicitationSafe` trait for struct types | ||
/// that should be used with `elicit<T>()` methods. | ||
/// | ||
/// # Example | ||
/// | ||
/// ```rust | ||
/// use rmcp::elicit_safe; | ||
/// use schemars::JsonSchema; | ||
/// use serde::{Deserialize, Serialize}; | ||
/// | ||
/// #[derive(Serialize, Deserialize, JsonSchema)] | ||
/// struct UserProfile { | ||
/// name: String, | ||
/// email: String, | ||
/// } | ||
/// | ||
/// elicit_safe!(UserProfile); | ||
/// | ||
/// // Now safe to use: | ||
/// let profile: UserProfile = server.elicit("Enter profile").await?; | ||
/// ``` | ||
#[cfg(feature = "elicitation")] | ||
#[macro_export] | ||
macro_rules! elicit_safe { | ||
($($t:ty),* $(,)?) => { | ||
$( | ||
impl $crate::service::ElicitationSafe for $t {} | ||
)* | ||
}; | ||
} | ||
|
||
#[cfg(feature = "elicitation")] | ||
impl Peer<RoleServer> { | ||
/// Check if the client supports elicitation capability | ||
|
@@ -466,7 +561,8 @@ impl Peer<RoleServer> { | |
/// | ||
/// # Returns | ||
/// * `Ok(Some(data))` if user provided valid data that matches type T | ||
/// * `Err(ElicitationError::UserDeclined)` if user declined or cancelled the request | ||
/// * `Err(ElicitationError::UserDeclined)` if user explicitly declined the request | ||
/// * `Err(ElicitationError::UserCancelled)` if user cancelled/dismissed the request | ||
/// * `Err(ElicitationError::ParseError { .. })` if response data couldn't be parsed into type T | ||
/// * `Err(ElicitationError::NoContent)` if no response content was provided | ||
/// * `Err(ElicitationError::Service(_))` if the underlying service call failed | ||
|
@@ -497,16 +593,24 @@ impl Peer<RoleServer> { | |
/// age: u8, | ||
/// } | ||
/// | ||
/// // Mark as safe for elicitation (generates object schema) | ||
/// rmcp::elicit_safe!(UserProfile); | ||
/// | ||
/// # async fn example(peer: Peer<RoleServer>) -> Result<(), Box<dyn std::error::Error>> { | ||
/// match peer.elicit::<UserProfile>("Please enter your profile information").await { | ||
/// Ok(Some(profile)) => { | ||
/// println!("Name: {}, Email: {}, Age: {}", profile.name, profile.email, profile.age); | ||
/// } | ||
/// Ok(None) => { | ||
/// println!("User declined to provide information"); | ||
/// println!("User provided no content"); | ||
/// } | ||
/// Err(ElicitationError::UserDeclined) => { | ||
/// println!("User declined to provide information"); | ||
/// println!("User explicitly declined to provide information"); | ||
/// // Handle explicit decline - perhaps offer alternatives | ||
/// } | ||
/// Err(ElicitationError::UserCancelled) => { | ||
/// println!("User cancelled the request"); | ||
/// // Handle cancellation - perhaps prompt again later | ||
/// } | ||
/// Err(ElicitationError::ParseError { error, data }) => { | ||
/// println!("Failed to parse response: {}\nData: {}", error, data); | ||
|
@@ -519,7 +623,7 @@ impl Peer<RoleServer> { | |
#[cfg(all(feature = "schemars", feature = "elicitation"))] | ||
pub async fn elicit<T>(&self, message: impl Into<String>) -> Result<Option<T>, ElicitationError> | ||
where | ||
T: schemars::JsonSchema + for<'de> serde::Deserialize<'de>, | ||
T: ElicitationSafe + for<'de> serde::Deserialize<'de>, | ||
{ | ||
self.elicit_with_timeout(message, None).await | ||
} | ||
|
@@ -549,6 +653,9 @@ impl Peer<RoleServer> { | |
/// answer: String, | ||
/// } | ||
/// | ||
/// // Mark as safe for elicitation | ||
/// rmcp::elicit_safe!(QuickResponse); | ||
/// | ||
/// # async fn example(peer: Peer<RoleServer>) -> Result<(), Box<dyn std::error::Error>> { | ||
/// // Give user 30 seconds to respond | ||
/// let timeout = Some(Duration::from_secs(30)); | ||
|
@@ -557,7 +664,15 @@ impl Peer<RoleServer> { | |
/// timeout | ||
/// ).await { | ||
/// Ok(Some(response)) => println!("Got answer: {}", response.answer), | ||
/// Ok(None) => println!("User declined"), | ||
/// Ok(None) => println!("User provided no content"), | ||
/// Err(ElicitationError::UserDeclined) => { | ||
/// println!("User explicitly declined"); | ||
/// // Handle explicit decline | ||
/// } | ||
/// Err(ElicitationError::UserCancelled) => { | ||
/// println!("User cancelled/dismissed"); | ||
/// // Handle cancellation | ||
/// } | ||
/// Err(ElicitationError::Service(ServiceError::Timeout { .. })) => { | ||
/// println!("User didn't respond in time"); | ||
/// } | ||
|
@@ -573,7 +688,7 @@ impl Peer<RoleServer> { | |
timeout: Option<std::time::Duration>, | ||
) -> Result<Option<T>, ElicitationError> | ||
where | ||
T: schemars::JsonSchema + for<'de> serde::Deserialize<'de>, | ||
T: ElicitationSafe + for<'de> serde::Deserialize<'de>, | ||
{ | ||
// Check if client supports elicitation capability | ||
if !self.supports_elicitation() { | ||
|
@@ -604,7 +719,8 @@ impl Peer<RoleServer> { | |
Err(ElicitationError::NoContent) | ||
} | ||
} | ||
_ => Err(ElicitationError::UserDeclined), | ||
crate::model::ElicitationAction::Decline => Err(ElicitationError::UserDeclined), | ||
crate::model::ElicitationAction::Cancel => Err(ElicitationError::UserCancelled), | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm is this necessary? Why does it need to be limited to 1ms-5min? Values outside of that might be very uncommon but I'd think that should be left to the server developer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jamadeo You're absolutely right. Hard-coding these limits was an oversight - server developers should have control over timeout policies for their specific use cases.
I'll remove the validation and let developers implement their own timeout constraints as needed. The library shouldn't enforce arbitrary business logic decisions.
Thanks for catching this!