Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 86 additions & 26 deletions crates/axum-utils/src/client_authorization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@ use std::collections::HashMap;
use axum::{
BoxError, Json,
extract::{
Form, FromRequest, FromRequestParts,
Form, FromRequest,
rejection::{FailedToDeserializeForm, FormRejection},
},
response::IntoResponse,
};
use axum_extra::typed_header::{TypedHeader, TypedHeaderRejectionReason};
use headers::{Authorization, authorization::Basic};
use headers::authorization::{Basic, Bearer, Credentials as _};
use http::{Request, StatusCode};
use mas_data_model::{Client, JwksOrJwksUri};
use mas_http::RequestBuilderExt;
Expand Down Expand Up @@ -60,17 +59,30 @@ pub enum Credentials {
client_id: String,
jwt: Box<Jwt<'static, HashMap<String, serde_json::Value>>>,
},
BearerToken {
token: String,
},
}

impl Credentials {
/// Get the `client_id` of the credentials
#[must_use]
pub fn client_id(&self) -> &str {
pub fn client_id(&self) -> Option<&str> {
match self {
Credentials::None { client_id }
| Credentials::ClientSecretBasic { client_id, .. }
| Credentials::ClientSecretPost { client_id, .. }
| Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id,
| Credentials::ClientAssertionJwtBearer { client_id, .. } => Some(client_id),
Credentials::BearerToken { .. } => None,
}
}

/// Get the bearer token from the credentials.
#[must_use]
pub fn bearer_token(&self) -> Option<&str> {
match self {
Credentials::BearerToken { token } => Some(token),
_ => None,
}
}

Expand All @@ -89,6 +101,7 @@ impl Credentials {
| Credentials::ClientSecretBasic { client_id, .. }
| Credentials::ClientSecretPost { client_id, .. }
| Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id,
Credentials::BearerToken { .. } => return Ok(None),
};

repo.oauth2_client().find_by_client_id(client_id).await
Expand Down Expand Up @@ -239,7 +252,7 @@ pub struct ClientAuthorization<F = ()> {
impl<F> ClientAuthorization<F> {
/// Get the `client_id` from the credentials.
#[must_use]
pub fn client_id(&self) -> &str {
pub fn client_id(&self) -> Option<&str> {
self.credentials.client_id()
}
}
Expand Down Expand Up @@ -360,25 +373,36 @@ where
req: Request<axum::body::Body>,
state: &S,
) -> Result<Self, Self::Rejection> {
// Split the request into parts so we can extract some headers
let (mut parts, body) = req.into_parts();

let header =
TypedHeader::<Authorization<Basic>>::from_request_parts(&mut parts, state).await;

// Take the Authorization header
let credentials_from_header = match header {
Ok(header) => Some((header.username().to_owned(), header.password().to_owned())),
Err(err) => match err.reason() {
// If it's missing it is fine
TypedHeaderRejectionReason::Missing => None,
// If the header could not be parsed, return the error
_ => return Err(ClientAuthorizationError::InvalidHeader),
},
};
enum Authorization {
Basic(String, String),
Bearer(String),
}

// Sadly, the typed-header 'Authorization' doesn't let us check for both
// Basic and Bearer at the same time, so we need to parse them manually
let authorization = if let Some(header) = req.headers().get(http::header::AUTHORIZATION) {
let bytes = header.as_bytes();
if bytes.len() >= 6 && bytes[..6].eq_ignore_ascii_case(b"Basic ") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if bytes.len() >= 6 && bytes[..6].eq_ignore_ascii_case(b"Basic ") {
if bytes.starts_with(b"Basic ") {

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want to make a case-insensitive comparison, hence the eq_ignore_ascii_case

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah derp, fine :)

let Some(decoded) = Basic::decode(header) else {
return Err(ClientAuthorizationError::InvalidHeader);
};

// Reconstruct the request from the parts
let req = Request::from_parts(parts, body);
Some(Authorization::Basic(
decoded.username().to_owned(),
decoded.password().to_owned(),
))
} else if bytes.len() >= 7 && bytes[..7].eq_ignore_ascii_case(b"Bearer ") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
} else if bytes.len() >= 7 && bytes[..7].eq_ignore_ascii_case(b"Bearer ") {
} else if bytes.starts_with(b"Bearer ") {

let Some(decoded) = Bearer::decode(header) else {
return Err(ClientAuthorizationError::InvalidHeader);
};

Some(Authorization::Bearer(decoded.token().to_owned()))
} else {
return Err(ClientAuthorizationError::InvalidHeader);
}
} else {
None
};

// Take the form value
let (
Expand Down Expand Up @@ -407,13 +431,19 @@ where

// And now, figure out the actual auth method
let credentials = match (
credentials_from_header,
authorization,
client_id_from_form,
client_secret_from_form,
client_assertion_type,
client_assertion,
) {
(Some((client_id, client_secret)), client_id_from_form, None, None, None) => {
(
Some(Authorization::Basic(client_id, client_secret)),
client_id_from_form,
None,
None,
None,
) => {
if let Some(client_id_from_form) = client_id_from_form {
// If the client_id was in the body, verify it matches with the header
if client_id != client_id_from_form {
Expand Down Expand Up @@ -483,6 +513,11 @@ where
});
}

(Some(Authorization::Bearer(token)), None, None, None, None) => {
// Got a bearer token
Credentials::BearerToken { token }
}

(None, None, None, None, None) => {
// Special case when there are no credentials anywhere
return Err(ClientAuthorizationError::MissingCredentials);
Expand Down Expand Up @@ -677,4 +712,29 @@ mod tests {
jwt.verify_with_shared_secret(b"client-secret".to_vec())
.unwrap();
}

#[tokio::test]
async fn bearer_token_test() {
let req = Request::builder()
.method(Method::POST)
.header(
http::header::CONTENT_TYPE,
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
)
.header(http::header::AUTHORIZATION, "Bearer token")
.body(Body::new("foo=bar".to_owned()))
.unwrap();

assert_eq!(
ClientAuthorization::<serde_json::Value>::from_request(req, &())
.await
.unwrap(),
ClientAuthorization {
credentials: Credentials::BearerToken {
token: "token".to_owned(),
},
form: Some(serde_json::json!({"foo": "bar"})),
}
);
}
}
98 changes: 77 additions & 21 deletions crates/handlers/src/oauth2/introspection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.

use std::sync::LazyLock;
use std::sync::{Arc, LazyLock};

use axum::{Json, extract::State, http::HeaderValue, response::IntoResponse};
use hyper::{HeaderMap, StatusCode};
Expand All @@ -15,6 +15,7 @@ use mas_axum_utils::{
use mas_data_model::{Device, TokenFormatError, TokenType};
use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint};
use mas_keystore::Encrypter;
use mas_matrix::HomeserverConnection;
use mas_storage::{
BoxClock, BoxRepository, Clock,
compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository},
Expand Down Expand Up @@ -102,8 +103,14 @@ pub enum RouteError {
#[error("bad request")]
BadRequest,

#[error("failed to verify token")]
FailedToVerifyToken(#[source] anyhow::Error),

#[error(transparent)]
ClientCredentialsVerification(#[from] CredentialsVerificationError),

#[error("bearer token presented is invalid")]
InvalidBearerToken,
}

impl IntoResponse for RouteError {
Expand All @@ -114,13 +121,15 @@ impl IntoResponse for RouteError {
| Self::CantLoadCompatSession(_)
| Self::CantLoadOAuthSession(_)
| Self::CantLoadUser(_)
| Self::FailedToVerifyToken(_)
);

let response = match self {
e @ (Self::Internal(_)
| Self::CantLoadCompatSession(_)
| Self::CantLoadOAuthSession(_)
| Self::CantLoadUser(_)) => (
| Self::CantLoadUser(_)
| Self::FailedToVerifyToken(_)) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(
ClientError::from(ClientErrorCode::ServerError).with_description(e.to_string()),
Expand All @@ -140,6 +149,14 @@ impl IntoResponse for RouteError {
),
)
.into_response(),
e @ Self::InvalidBearerToken => (
StatusCode::UNAUTHORIZED,
Json(
ClientError::from(ClientErrorCode::AccessDenied)
.with_description(e.to_string()),
),
)
.into_response(),

Self::UnknownToken(_)
| Self::UnexpectedTokenType
Expand Down Expand Up @@ -195,7 +212,7 @@ const SYNAPSE_ADMIN_SCOPE: ScopeToken = ScopeToken::from_static("urn:synapse:adm

#[tracing::instrument(
name = "handlers.oauth2.introspection.post",
fields(client.id = client_authorization.client_id()),
fields(client.id = credentials.client_id()),
skip_all,
)]
#[allow(clippy::too_many_lines)]
Expand All @@ -205,28 +222,41 @@ pub(crate) async fn post(
mut repo: BoxRepository,
activity_tracker: ActivityTracker,
State(encrypter): State<Encrypter>,
State(homeserver): State<Arc<dyn HomeserverConnection>>,
headers: HeaderMap,
client_authorization: ClientAuthorization<IntrospectionRequest>,
ClientAuthorization { credentials, form }: ClientAuthorization<IntrospectionRequest>,
) -> Result<impl IntoResponse, RouteError> {
let client = client_authorization
.credentials
.fetch(&mut repo)
.await?
.ok_or(RouteError::ClientNotFound)?;

let method = match &client.token_endpoint_auth_method {
None | Some(OAuthClientAuthenticationMethod::None) => {
return Err(RouteError::NotAllowed(client.id));
if let Some(token) = credentials.bearer_token() {
// If the client presented a bearer token, we check with the homeserver
// connection if it is allowed to use the introspection endpoint
if !homeserver
.verify_token(token)
.await
.map_err(RouteError::FailedToVerifyToken)?
{
return Err(RouteError::InvalidBearerToken);
}
Some(c) => c,
};
} else {
// Otherwise, it presented regular client credentials, so we verify them
let client = credentials
.fetch(&mut repo)
.await?
.ok_or(RouteError::ClientNotFound)?;

// Only confidential clients are allowed to introspect
let method = match &client.token_endpoint_auth_method {
None | Some(OAuthClientAuthenticationMethod::None) => {
return Err(RouteError::NotAllowed(client.id));
}
Some(c) => c,
};

client_authorization
.credentials
.verify(&http_client, &encrypter, method, &client)
.await?;
credentials
.verify(&http_client, &encrypter, method, &client)
.await?;
}

let Some(form) = client_authorization.form else {
let Some(form) = form else {
return Err(RouteError::BadRequest);
};

Expand Down Expand Up @@ -578,10 +608,11 @@ mod tests {
use hyper::{Request, StatusCode};
use mas_data_model::{AccessToken, RefreshToken};
use mas_iana::oauth::OAuthTokenTypeHint;
use mas_matrix::{HomeserverConnection, ProvisionRequest};
use mas_matrix::{HomeserverConnection, MockHomeserverConnection, ProvisionRequest};
use mas_router::{OAuth2Introspection, OAuth2RegistrationEndpoint, SimpleRoute};
use mas_storage::Clock;
use oauth2_types::{
errors::{ClientError, ClientErrorCode},
registration::ClientRegistrationResponse,
requests::IntrospectionResponse,
scope::{OPENID, Scope},
Expand Down Expand Up @@ -984,4 +1015,29 @@ mod tests {
let response: IntrospectionResponse = response.json();
assert!(response.active);
}

#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_introspect_with_bearer_token(pool: PgPool) {
setup();
let state = TestState::from_pool(pool).await.unwrap();

// Check that talking to the introspection endpoint with the bearer token from
// the MockHomeserverConnection doens't error out
let request = Request::post(OAuth2Introspection::PATH)
.bearer(MockHomeserverConnection::VALID_BEARER_TOKEN)
.form(json!({ "token": "some_token" }));
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let response: IntrospectionResponse = response.json();
assert!(!response.active);

// Check with another token, we should get a 401
let request = Request::post(OAuth2Introspection::PATH)
.bearer("another_token")
.form(json!({ "token": "some_token" }));
let response = state.request(request).await;
response.assert_status(StatusCode::UNAUTHORIZED);
let response: ClientError = response.json();
assert_eq!(response.error, ClientErrorCode::AccessDenied);
}
}
5 changes: 5 additions & 0 deletions crates/matrix-synapse/src/legacy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,11 @@ impl HomeserverConnection for SynapseConnection {
&self.homeserver
}

#[tracing::instrument(name = "homeserver.verify_token", skip_all, err(Debug))]
async fn verify_token(&self, token: &str) -> Result<bool, anyhow::Error> {
Ok(self.access_token == token)
}

#[tracing::instrument(
name = "homeserver.query_user",
skip_all,
Expand Down
5 changes: 5 additions & 0 deletions crates/matrix-synapse/src/modern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ impl HomeserverConnection for SynapseConnection {
&self.homeserver
}

#[tracing::instrument(name = "homeserver.verify_token", skip_all, err(Debug))]
async fn verify_token(&self, token: &str) -> Result<bool, anyhow::Error> {
Ok(self.access_token == token)
}

#[tracing::instrument(
name = "homeserver.query_user",
skip_all,
Expand Down
Loading
Loading