diff --git a/crates/axum-utils/src/client_authorization.rs b/crates/axum-utils/src/client_authorization.rs index edc660431..ed842590d 100644 --- a/crates/axum-utils/src/client_authorization.rs +++ b/crates/axum-utils/src/client_authorization.rs @@ -9,13 +9,12 @@ use std::collections::HashMap; use axum::{ BoxError, Json, extract::{ - Form, FromRequest, FromRequestParts, + Form, FromRequest, rejection::{FailedToDeserializeForm, FormRejection}, }, response::IntoResponse, }; -use axum_extra::typed_header::{TypedHeader, TypedHeaderRejectionReason}; -use headers::{Authorization, authorization::Basic}; +use headers::authorization::{Basic, Bearer, Credentials as _}; use http::{Request, StatusCode}; use mas_data_model::{Client, JwksOrJwksUri}; use mas_http::RequestBuilderExt; @@ -60,17 +59,30 @@ pub enum Credentials { client_id: String, jwt: Box>>, }, + BearerToken { + token: String, + }, } impl Credentials { /// Get the `client_id` of the credentials #[must_use] - pub fn client_id(&self) -> &str { + pub fn client_id(&self) -> Option<&str> { match self { Credentials::None { client_id } | Credentials::ClientSecretBasic { client_id, .. } | Credentials::ClientSecretPost { client_id, .. } - | Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id, + | Credentials::ClientAssertionJwtBearer { client_id, .. } => Some(client_id), + Credentials::BearerToken { .. } => None, + } + } + + /// Get the bearer token from the credentials. + #[must_use] + pub fn bearer_token(&self) -> Option<&str> { + match self { + Credentials::BearerToken { token } => Some(token), + _ => None, } } @@ -89,6 +101,7 @@ impl Credentials { | Credentials::ClientSecretBasic { client_id, .. } | Credentials::ClientSecretPost { client_id, .. } | Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id, + Credentials::BearerToken { .. } => return Ok(None), }; repo.oauth2_client().find_by_client_id(client_id).await @@ -239,7 +252,7 @@ pub struct ClientAuthorization { impl ClientAuthorization { /// Get the `client_id` from the credentials. #[must_use] - pub fn client_id(&self) -> &str { + pub fn client_id(&self) -> Option<&str> { self.credentials.client_id() } } @@ -360,25 +373,36 @@ where req: Request, state: &S, ) -> Result { - // Split the request into parts so we can extract some headers - let (mut parts, body) = req.into_parts(); - - let header = - TypedHeader::>::from_request_parts(&mut parts, state).await; - - // Take the Authorization header - let credentials_from_header = match header { - Ok(header) => Some((header.username().to_owned(), header.password().to_owned())), - Err(err) => match err.reason() { - // If it's missing it is fine - TypedHeaderRejectionReason::Missing => None, - // If the header could not be parsed, return the error - _ => return Err(ClientAuthorizationError::InvalidHeader), - }, - }; + enum Authorization { + Basic(String, String), + Bearer(String), + } + + // Sadly, the typed-header 'Authorization' doesn't let us check for both + // Basic and Bearer at the same time, so we need to parse them manually + let authorization = if let Some(header) = req.headers().get(http::header::AUTHORIZATION) { + let bytes = header.as_bytes(); + if bytes.len() >= 6 && bytes[..6].eq_ignore_ascii_case(b"Basic ") { + let Some(decoded) = Basic::decode(header) else { + return Err(ClientAuthorizationError::InvalidHeader); + }; - // Reconstruct the request from the parts - let req = Request::from_parts(parts, body); + Some(Authorization::Basic( + decoded.username().to_owned(), + decoded.password().to_owned(), + )) + } else if bytes.len() >= 7 && bytes[..7].eq_ignore_ascii_case(b"Bearer ") { + let Some(decoded) = Bearer::decode(header) else { + return Err(ClientAuthorizationError::InvalidHeader); + }; + + Some(Authorization::Bearer(decoded.token().to_owned())) + } else { + return Err(ClientAuthorizationError::InvalidHeader); + } + } else { + None + }; // Take the form value let ( @@ -407,13 +431,19 @@ where // And now, figure out the actual auth method let credentials = match ( - credentials_from_header, + authorization, client_id_from_form, client_secret_from_form, client_assertion_type, client_assertion, ) { - (Some((client_id, client_secret)), client_id_from_form, None, None, None) => { + ( + Some(Authorization::Basic(client_id, client_secret)), + client_id_from_form, + None, + None, + None, + ) => { if let Some(client_id_from_form) = client_id_from_form { // If the client_id was in the body, verify it matches with the header if client_id != client_id_from_form { @@ -483,6 +513,11 @@ where }); } + (Some(Authorization::Bearer(token)), None, None, None, None) => { + // Got a bearer token + Credentials::BearerToken { token } + } + (None, None, None, None, None) => { // Special case when there are no credentials anywhere return Err(ClientAuthorizationError::MissingCredentials); @@ -677,4 +712,29 @@ mod tests { jwt.verify_with_shared_secret(b"client-secret".to_vec()) .unwrap(); } + + #[tokio::test] + async fn bearer_token_test() { + let req = Request::builder() + .method(Method::POST) + .header( + http::header::CONTENT_TYPE, + mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(), + ) + .header(http::header::AUTHORIZATION, "Bearer token") + .body(Body::new("foo=bar".to_owned())) + .unwrap(); + + assert_eq!( + ClientAuthorization::::from_request(req, &()) + .await + .unwrap(), + ClientAuthorization { + credentials: Credentials::BearerToken { + token: "token".to_owned(), + }, + form: Some(serde_json::json!({"foo": "bar"})), + } + ); + } } diff --git a/crates/handlers/src/oauth2/introspection.rs b/crates/handlers/src/oauth2/introspection.rs index 50c043b04..f2f225475 100644 --- a/crates/handlers/src/oauth2/introspection.rs +++ b/crates/handlers/src/oauth2/introspection.rs @@ -4,7 +4,7 @@ // SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial // Please see LICENSE files in the repository root for full details. -use std::sync::LazyLock; +use std::sync::{Arc, LazyLock}; use axum::{Json, extract::State, http::HeaderValue, response::IntoResponse}; use hyper::{HeaderMap, StatusCode}; @@ -15,6 +15,7 @@ use mas_axum_utils::{ use mas_data_model::{Device, TokenFormatError, TokenType}; use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint}; use mas_keystore::Encrypter; +use mas_matrix::HomeserverConnection; use mas_storage::{ BoxClock, BoxRepository, Clock, compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository}, @@ -102,8 +103,14 @@ pub enum RouteError { #[error("bad request")] BadRequest, + #[error("failed to verify token")] + FailedToVerifyToken(#[source] anyhow::Error), + #[error(transparent)] ClientCredentialsVerification(#[from] CredentialsVerificationError), + + #[error("bearer token presented is invalid")] + InvalidBearerToken, } impl IntoResponse for RouteError { @@ -114,13 +121,15 @@ impl IntoResponse for RouteError { | Self::CantLoadCompatSession(_) | Self::CantLoadOAuthSession(_) | Self::CantLoadUser(_) + | Self::FailedToVerifyToken(_) ); let response = match self { e @ (Self::Internal(_) | Self::CantLoadCompatSession(_) | Self::CantLoadOAuthSession(_) - | Self::CantLoadUser(_)) => ( + | Self::CantLoadUser(_) + | Self::FailedToVerifyToken(_)) => ( StatusCode::INTERNAL_SERVER_ERROR, Json( ClientError::from(ClientErrorCode::ServerError).with_description(e.to_string()), @@ -140,6 +149,14 @@ impl IntoResponse for RouteError { ), ) .into_response(), + e @ Self::InvalidBearerToken => ( + StatusCode::UNAUTHORIZED, + Json( + ClientError::from(ClientErrorCode::AccessDenied) + .with_description(e.to_string()), + ), + ) + .into_response(), Self::UnknownToken(_) | Self::UnexpectedTokenType @@ -195,7 +212,7 @@ const SYNAPSE_ADMIN_SCOPE: ScopeToken = ScopeToken::from_static("urn:synapse:adm #[tracing::instrument( name = "handlers.oauth2.introspection.post", - fields(client.id = client_authorization.client_id()), + fields(client.id = credentials.client_id()), skip_all, )] #[allow(clippy::too_many_lines)] @@ -205,28 +222,41 @@ pub(crate) async fn post( mut repo: BoxRepository, activity_tracker: ActivityTracker, State(encrypter): State, + State(homeserver): State>, headers: HeaderMap, - client_authorization: ClientAuthorization, + ClientAuthorization { credentials, form }: ClientAuthorization, ) -> Result { - let client = client_authorization - .credentials - .fetch(&mut repo) - .await? - .ok_or(RouteError::ClientNotFound)?; - - let method = match &client.token_endpoint_auth_method { - None | Some(OAuthClientAuthenticationMethod::None) => { - return Err(RouteError::NotAllowed(client.id)); + if let Some(token) = credentials.bearer_token() { + // If the client presented a bearer token, we check with the homeserver + // configuration if it is allowed to use the introspection endpoint + if !homeserver + .verify_token(token) + .await + .map_err(RouteError::FailedToVerifyToken)? + { + return Err(RouteError::InvalidBearerToken); } - Some(c) => c, - }; + } else { + // Otherwise, it presented regular client credentials, so we verify them + let client = credentials + .fetch(&mut repo) + .await? + .ok_or(RouteError::ClientNotFound)?; + + // Only confidential clients are allowed to introspect + let method = match &client.token_endpoint_auth_method { + None | Some(OAuthClientAuthenticationMethod::None) => { + return Err(RouteError::NotAllowed(client.id)); + } + Some(c) => c, + }; - client_authorization - .credentials - .verify(&http_client, &encrypter, method, &client) - .await?; + credentials + .verify(&http_client, &encrypter, method, &client) + .await?; + } - let Some(form) = client_authorization.form else { + let Some(form) = form else { return Err(RouteError::BadRequest); }; @@ -578,10 +608,11 @@ mod tests { use hyper::{Request, StatusCode}; use mas_data_model::{AccessToken, RefreshToken}; use mas_iana::oauth::OAuthTokenTypeHint; - use mas_matrix::{HomeserverConnection, ProvisionRequest}; + use mas_matrix::{HomeserverConnection, MockHomeserverConnection, ProvisionRequest}; use mas_router::{OAuth2Introspection, OAuth2RegistrationEndpoint, SimpleRoute}; use mas_storage::Clock; use oauth2_types::{ + errors::{ClientError, ClientErrorCode}, registration::ClientRegistrationResponse, requests::IntrospectionResponse, scope::{OPENID, Scope}, @@ -984,4 +1015,29 @@ mod tests { let response: IntrospectionResponse = response.json(); assert!(response.active); } + + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_introspect_with_bearer_token(pool: PgPool) { + setup(); + let state = TestState::from_pool(pool).await.unwrap(); + + // Check that talking to the introspection endpoint with the bearer token from + // the MockHomeserverConnection doens't error out + let request = Request::post(OAuth2Introspection::PATH) + .bearer(MockHomeserverConnection::VALID_BEARER_TOKEN) + .form(json!({ "token": "some_token" })); + let response = state.request(request).await; + response.assert_status(StatusCode::OK); + let response: IntrospectionResponse = response.json(); + assert!(!response.active); + + // Check with another token, we should get a 401 + let request = Request::post(OAuth2Introspection::PATH) + .bearer("another_token") + .form(json!({ "token": "some_token" })); + let response = state.request(request).await; + response.assert_status(StatusCode::UNAUTHORIZED); + let response: ClientError = response.json(); + assert_eq!(response.error, ClientErrorCode::AccessDenied); + } } diff --git a/crates/matrix-synapse/src/legacy.rs b/crates/matrix-synapse/src/legacy.rs index d07e6b5d5..b93298ceb 100644 --- a/crates/matrix-synapse/src/legacy.rs +++ b/crates/matrix-synapse/src/legacy.rs @@ -160,6 +160,11 @@ impl HomeserverConnection for SynapseConnection { &self.homeserver } + #[tracing::instrument(name = "homeserver.verify_token", skip_all, err(Debug))] + async fn verify_token(&self, token: &str) -> Result { + Ok(self.access_token == token) + } + #[tracing::instrument( name = "homeserver.query_user", skip_all, diff --git a/crates/matrix-synapse/src/modern.rs b/crates/matrix-synapse/src/modern.rs index 26c8e21a1..3d70f52de 100644 --- a/crates/matrix-synapse/src/modern.rs +++ b/crates/matrix-synapse/src/modern.rs @@ -66,6 +66,11 @@ impl HomeserverConnection for SynapseConnection { &self.homeserver } + #[tracing::instrument(name = "homeserver.verify_token", skip_all, err(Debug))] + async fn verify_token(&self, token: &str) -> Result { + Ok(self.access_token == token) + } + #[tracing::instrument( name = "homeserver.query_user", skip_all, diff --git a/crates/matrix/src/lib.rs b/crates/matrix/src/lib.rs index 34c502f81..f1fbe9c83 100644 --- a/crates/matrix/src/lib.rs +++ b/crates/matrix/src/lib.rs @@ -207,6 +207,20 @@ pub trait HomeserverConnection: Send + Sync { Some(mxid.localpart()) } + /// Verify a bearer token coming from the homeserver for homeserver to MAS + /// interactions + /// + /// Returns `true` if the token is valid, `false` otherwise. + /// + /// # Parameters + /// + /// * `token` - The token to verify. + /// + /// # Errors + /// + /// Returns an error if the token failed to verify. + async fn verify_token(&self, token: &str) -> Result; + /// Query the state of a user on the homeserver. /// /// # Parameters @@ -384,6 +398,10 @@ impl HomeserverConnection for &T (**self).homeserver() } + async fn verify_token(&self, token: &str) -> Result { + (**self).verify_token(token).await + } + async fn query_user(&self, localpart: &str) -> Result { (**self).query_user(localpart).await } @@ -462,6 +480,10 @@ impl HomeserverConnection for Arc { (**self).homeserver() } + async fn verify_token(&self, token: &str) -> Result { + (**self).verify_token(token).await + } + async fn query_user(&self, localpart: &str) -> Result { (**self).query_user(localpart).await } diff --git a/crates/matrix/src/mock.rs b/crates/matrix/src/mock.rs index 0c315ff97..4180597e2 100644 --- a/crates/matrix/src/mock.rs +++ b/crates/matrix/src/mock.rs @@ -31,6 +31,10 @@ pub struct HomeserverConnection { } impl HomeserverConnection { + /// A valid bearer token that will be accepted by + /// [`crate::HomeserverConnection::verify_token`]. + pub const VALID_BEARER_TOKEN: &str = "mock_homeserver_bearer_token"; + /// Create a new mock connection. pub fn new(homeserver: H) -> Self where @@ -54,6 +58,10 @@ impl crate::HomeserverConnection for HomeserverConnection { &self.homeserver } + async fn verify_token(&self, token: &str) -> Result { + Ok(token == Self::VALID_BEARER_TOKEN) + } + async fn query_user(&self, localpart: &str) -> Result { let mxid = self.mxid(localpart); let users = self.users.read().await; diff --git a/crates/matrix/src/readonly.rs b/crates/matrix/src/readonly.rs index 2efa935f9..590583bf8 100644 --- a/crates/matrix/src/readonly.rs +++ b/crates/matrix/src/readonly.rs @@ -28,6 +28,10 @@ impl HomeserverConnection for ReadOnlyHomeserverConnect self.inner.homeserver() } + async fn verify_token(&self, token: &str) -> Result { + self.inner.verify_token(token).await + } + async fn query_user(&self, localpart: &str) -> Result { self.inner.query_user(localpart).await }