diff --git a/crates/rmcp/src/transport/auth.rs b/crates/rmcp/src/transport/auth.rs index ac55e5e4..fc7015e0 100644 --- a/crates/rmcp/src/transport/auth.rs +++ b/crates/rmcp/src/transport/auth.rs @@ -6,11 +6,14 @@ use std::{ use oauth2::{ AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields, - PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, Scope, StandardTokenResponse, - TokenResponse, TokenUrl, + PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, RequestTokenError, Scope, + StandardTokenResponse, TokenResponse, TokenUrl, basic::{BasicClient, BasicTokenType}, }; -use reqwest::{Client as HttpClient, IntoUrl, StatusCode, Url, header::AUTHORIZATION}; +use reqwest::{ + Client as HttpClient, IntoUrl, StatusCode, Url, + header::{AUTHORIZATION, WWW_AUTHENTICATE}, +}; use serde::{Deserialize, Serialize}; use thiserror::Error; use tokio::sync::{Mutex, RwLock}; @@ -111,6 +114,12 @@ pub struct AuthorizationMetadata { pub additional_fields: HashMap, } +#[derive(Debug, Clone, Deserialize)] +struct ResourceServerMetadata { + authorization_server: Option, + authorization_servers: Option>, +} + /// oauth2 client config #[derive(Debug, Clone)] pub struct OAuthClientConfig { @@ -165,7 +174,7 @@ pub struct ClientRegistrationRequest { pub struct ClientRegistrationResponse { pub client_id: String, pub client_secret: Option, - pub client_name: String, + pub client_name: Option, pub redirect_uris: Vec, // allow additional fields #[serde(flatten)] @@ -234,46 +243,15 @@ impl AuthorizationManager { /// discover oauth2 metadata pub async fn discover_metadata(&self) -> Result { - for candidate_path in - Self::well_known_paths(self.base_url.path(), "oauth-authorization-server") - { - let mut discovery_url = self.base_url.clone(); - discovery_url.set_path(&candidate_path); - debug!("discovery url: {:?}", discovery_url); - - let response = match self - .http_client - .get(discovery_url) - .header("MCP-Protocol-Version", "2024-11-05") - .send() - .await - { - Ok(r) => r, - Err(e) => { - debug!("discovery request failed: {}", e); - continue; // try next candidate if request fails - } - }; + if let Some(metadata) = self.try_discover_oauth_server(&self.base_url).await? { + return Ok(metadata); + } - if response.status() != StatusCode::OK { - debug!("discovery returned non-200: {}", response.status()); - continue; // try next candidate if response is not OK - } - - // parse metadata - let metadata = response - .json::() - .await - .map_err(|e| { - // Fail the discovery if we get a 200 but cannot parse the response - // This indicates a misconfiguration on the server side - AuthError::MetadataError(format!("Failed to parse metadata: {}", e)) - })?; - debug!("metadata: {:?}", metadata); + if let Some(metadata) = self.discover_oauth_server_via_resource_metadata().await? { return Ok(metadata); } - warn!("No valid .well-known endpoint found, falling back to default endpoints"); + warn!("No valid authorization metadata found, falling back to default endpoints"); // fallback to default endpoints let mut auth_base = self.base_url.clone(); @@ -501,12 +479,30 @@ impl AuthorizationManager { debug!("client_id: {:?}", oauth_client.client_id()); // exchange token - let token_result = oauth_client + let token_result = match oauth_client .exchange_code(AuthorizationCode::new(code.to_string())) .set_pkce_verifier(pkce_verifier) .request_async(&http_client) .await - .map_err(|e| AuthError::TokenExchangeFailed(e.to_string()))?; + { + Ok(token) => token, + Err(RequestTokenError::Parse(_, body)) => { + match serde_json::from_slice::(&body) { + Ok(parsed) => { + warn!( + "token exchange failed to parse completely but included a valid token response. Accepting it." + ); + parsed + } + Err(parse_err) => { + return Err(AuthError::TokenExchangeFailed(parse_err.to_string())); + } + } + } + Err(e) => { + return Err(AuthError::TokenExchangeFailed(e.to_string())); + } + }; // get expires_in from token response let expires_in = token_result.expires_in(); @@ -602,6 +598,254 @@ impl AuthorizationManager { Ok(response) } } + + async fn try_discover_oauth_server( + &self, + base_url: &Url, + ) -> Result, AuthError> { + for candidate_path in Self::well_known_paths(base_url.path(), "oauth-authorization-server") + { + let mut discovery_url = base_url.clone(); + discovery_url.set_query(None); + discovery_url.set_fragment(None); + discovery_url.set_path(&candidate_path); + if let Some(metadata) = self.fetch_authorization_metadata(&discovery_url).await? { + return Ok(Some(metadata)); + } + } + Ok(None) + } + + async fn fetch_authorization_metadata( + &self, + discovery_url: &Url, + ) -> Result, AuthError> { + debug!("discovery url: {:?}", discovery_url); + let response = match self + .http_client + .get(discovery_url.clone()) + .header("MCP-Protocol-Version", "2024-11-05") + .send() + .await + { + Ok(r) => r, + Err(e) => { + debug!("discovery request failed: {}", e); + return Ok(None); + } + }; + + if response.status() != StatusCode::OK { + debug!("discovery returned non-200: {}", response.status()); + return Ok(None); + } + + let metadata = response + .json::() + .await + .map_err(|e| AuthError::MetadataError(format!("Failed to parse metadata: {}", e)))?; + debug!("metadata: {:?}", metadata); + Ok(Some(metadata)) + } + + async fn discover_oauth_server_via_resource_metadata( + &self, + ) -> Result, AuthError> { + let Some(resource_metadata_url) = self.fetch_resource_metadata_url().await? else { + return Ok(None); + }; + + let Some(resource_metadata) = self + .fetch_resource_metadata_from_url(&resource_metadata_url) + .await? + else { + return Ok(None); + }; + + let mut candidates = Vec::new(); + + if let Some(single) = resource_metadata.authorization_server { + candidates.push(single); + } + if let Some(list) = resource_metadata.authorization_servers { + candidates.extend(list); + } + + for candidate in candidates { + let candidate = candidate.trim(); + if candidate.is_empty() { + continue; + } + + let candidate_url = match Url::parse(candidate) { + Ok(url) => url, + Err(_) => match resource_metadata_url.join(candidate) { + Ok(url) => url, + Err(e) => { + debug!("Failed to resolve authorization server URL `{candidate}`: {e}"); + continue; + } + }, + }; + + if candidate_url.path().contains("/.well-known/") { + if let Some(metadata) = self.fetch_authorization_metadata(&candidate_url).await? { + return Ok(Some(metadata)); + } + continue; + } + + if let Some(metadata) = self.try_discover_oauth_server(&candidate_url).await? { + return Ok(Some(metadata)); + } + } + + Ok(None) + } + + /// Extract the resource metadata url from the WWW-Authenticate header value. + /// https://www.rfc-editor.org/rfc/rfc9728.html#name-use-of-www-authenticate-for + async fn fetch_resource_metadata_url(&self) -> Result, AuthError> { + let response = match self + .http_client + .get(self.base_url.clone()) + .header("MCP-Protocol-Version", "2024-11-05") + .send() + .await + { + Ok(r) => r, + Err(e) => { + debug!("resource metadata probe failed: {}", e); + return Ok(None); + } + }; + + if response.status() != StatusCode::UNAUTHORIZED { + debug!( + "resource metadata probe returned unexpected status: {}", + response.status() + ); + return Ok(None); + } + + let mut parsed_url = None; + for value in response.headers().get_all(WWW_AUTHENTICATE).iter() { + let Ok(value_str) = value.to_str() else { + continue; + }; + if let Some(url) = + Self::extract_resource_metadata_url_from_header(value_str, &self.base_url) + { + parsed_url = Some(url); + break; + } + } + + Ok(parsed_url) + } + + async fn fetch_resource_metadata_from_url( + &self, + resource_metadata_url: &Url, + ) -> Result, AuthError> { + debug!( + "resource metadata discovery url: {:?}", + resource_metadata_url + ); + let response = match self + .http_client + .get(resource_metadata_url.clone()) + .header("MCP-Protocol-Version", "2024-11-05") + .send() + .await + { + Ok(r) => r, + Err(e) => { + debug!("resource metadata request failed: {}", e); + return Ok(None); + } + }; + + if response.status() != StatusCode::OK { + debug!( + "resource metadata request returned non-200: {}", + response.status() + ); + return Ok(None); + } + + let metadata = response + .json::() + .await + .map_err(|e| { + AuthError::MetadataError(format!("Failed to parse resource metadata: {}", e)) + })?; + Ok(Some(metadata)) + } + + /// Extracts a url following `resource_metadata=` in a header value + fn extract_resource_metadata_url_from_header(header: &str, base_url: &Url) -> Option { + let header_lowercase = header.to_ascii_lowercase(); + let fragment_key = "resource_metadata="; + let mut search_offset = 0; + + while let Some(pos) = header_lowercase[search_offset..].find(fragment_key) { + let global_pos = search_offset + pos + fragment_key.len(); + let value_slice = &header[global_pos..]; + if let Some((value, consumed)) = Self::parse_next_header_value(value_slice) { + if let Ok(url) = Url::parse(&value) { + return Some(url); + } + if let Ok(url) = base_url.join(&value) { + return Some(url); + } + debug!("failed to parse resource metadata value `{value}` as URL"); + search_offset = global_pos + consumed; + continue; + } else { + break; + } + } + + None + } + + /// Parses an authentication parameter value from a `WWW-Authenticate` header fragment. + /// The header fragment should start with the header value after the `=` character and then + /// reads until the value ends. + /// + /// Returns the extracted value together with the number of bytes consumed from the provided + /// fragment. Quoted values support escaped characters (e.g. `\"`). The parser skips leading + /// whitespace before reading either a quoted or token value. If no well-formed value is found, + /// `None` is returned. + fn parse_next_header_value(header_fragment: &str) -> Option<(String, usize)> { + let trimmed = header_fragment.trim_start(); + let leading_ws = header_fragment.len() - trimmed.len(); + + if let Some(stripped) = trimmed.strip_prefix('"') { + let mut escaped = false; + let mut result = String::new(); + #[allow(clippy::manual_strip)] + for (idx, ch) in stripped.char_indices() { + if escaped { + result.push(ch); + escaped = false; + continue; + } + match ch { + '\\' => escaped = true, + '"' => return Some((result, leading_ws + idx + 2)), + _ => result.push(ch), + } + } + None + } else { + let end = trimmed + .find(|c: char| c == ',' || c == ';' || c.is_whitespace()) + .unwrap_or(trimmed.len()); + Some((trimmed[..end].to_string(), leading_ws + end)) + } + } } /// oauth2 authorization session, for guiding user to complete the authorization process @@ -921,8 +1165,78 @@ impl OAuthState { #[cfg(test)] mod tests { + use url::Url; + use super::AuthorizationManager; + #[test] + fn parses_resource_metadata_parameter() { + let header = r#"Bearer error="invalid_request", error_description="missing token", resource_metadata="https://example.com/.well-known/oauth-protected-resource/api""#; + let base = Url::parse("https://example.com/api").unwrap(); + let parsed = AuthorizationManager::extract_resource_metadata_url_from_header(header, &base); + assert_eq!( + parsed.unwrap().as_str(), + "https://example.com/.well-known/oauth-protected-resource/api" + ); + } + + #[test] + fn parses_relative_resource_metadata_parameter() { + let header = r#"Bearer error="invalid_request", resource_metadata="/.well-known/oauth-protected-resource/api""#; + let base = Url::parse("https://example.com/api").unwrap(); + let parsed = AuthorizationManager::extract_resource_metadata_url_from_header(header, &base); + assert_eq!( + parsed.unwrap().as_str(), + "https://example.com/.well-known/oauth-protected-resource/api" + ); + } + + #[test] + fn parse_auth_param_value_handles_quoted_string() { + let fragment = r#""example", realm="foo""#; + let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap(); + assert_eq!(parsed.0, "example"); + assert_eq!(parsed.1, 9); + } + + #[test] + fn parse_auth_param_value_handles_escaped_quotes_and_whitespace() { + let fragment = r#" "a\"b\\c" ,next=value"#; + let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap(); + assert_eq!(parsed.0, r#"a"b\c"#); + assert_eq!(parsed.1, 12); + } + + #[test] + fn parse_auth_param_value_handles_token_values() { + let fragment = " token,next"; + let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap(); + assert_eq!(parsed.0, "token"); + assert_eq!(parsed.1, 7); + } + + #[test] + fn parse_auth_param_value_handles_semicolon_separated_tokens() { + let fragment = r#" https://example.com/meta; error="invalid_token""#; + let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap(); + assert_eq!(parsed.0, "https://example.com/meta"); + assert_eq!(&fragment[..parsed.1], " https://example.com/meta"); + } + + #[test] + fn parse_auth_param_value_handles_semicolon_after_quoted_value() { + let fragment = r#" "https://example.com/meta"; error="invalid_token""#; + let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap(); + assert_eq!(parsed.0, "https://example.com/meta"); + assert_eq!(&fragment[..parsed.1], r#" "https://example.com/meta""#); + } + + #[test] + fn parse_auth_param_value_returns_none_for_unterminated_quotes() { + let fragment = r#""unterminated,value"#; + assert!(AuthorizationManager::parse_next_header_value(fragment).is_none()); + } + #[test] fn well_known_paths_root() { let paths = AuthorizationManager::well_known_paths("/", "oauth-authorization-server"); diff --git a/examples/servers/src/complex_auth_sse.rs b/examples/servers/src/complex_auth_sse.rs index b3919aa9..c3f9fc36 100644 --- a/examples/servers/src/complex_auth_sse.rs +++ b/examples/servers/src/complex_auth_sse.rs @@ -572,7 +572,7 @@ async fn oauth_register( let response = ClientRegistrationResponse { client_id, client_secret: Some(client_secret), - client_name: req.client_name, + client_name: Some(req.client_name), redirect_uris: req.redirect_uris, additional_fields: HashMap::new(), };