Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/cli/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ pub fn build_router(
}
mas_config::HttpResource::OAuth => router.merge(mas_handlers::api_router::<AppState>()),
mas_config::HttpResource::Compat => {
router.merge(mas_handlers::compat_router::<AppState>())
router.merge(mas_handlers::compat_router::<AppState>(templates.clone()))
}
mas_config::HttpResource::AdminApi => {
let (_, api_router) = mas_handlers::admin_api_router::<AppState>();
Expand Down
14 changes: 7 additions & 7 deletions crates/handlers/src/compat/login_sso_redirect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use axum::{
response::IntoResponse,
};
use hyper::StatusCode;
use mas_axum_utils::record_error;
use mas_axum_utils::{GenericError, InternalError};
use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder};
use mas_storage::{BoxClock, BoxRepository, BoxRng, compat::CompatSsoLoginRepository};
use rand::distributions::{Alphanumeric, DistString};
Expand Down Expand Up @@ -43,12 +43,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);

impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let sentry_event_id = record_error!(self, Self::Internal(_));
let status_code = match &self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::MissingRedirectUrl | Self::InvalidRedirectUrl => StatusCode::BAD_REQUEST,
};
(status_code, sentry_event_id, format!("{self}")).into_response()
match self {
Self::Internal(e) => InternalError::new(e).into_response(),
Self::MissingRedirectUrl | Self::InvalidRedirectUrl => {
GenericError::new(StatusCode::BAD_REQUEST, self).into_response()
}
}
}
}

Expand Down
72 changes: 45 additions & 27 deletions crates/handlers/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ where
}

#[allow(clippy::trait_duplication_in_bounds)]
pub fn compat_router<S>() -> Router<S>
pub fn compat_router<S>(templates: Templates) -> Router<S>
where
S: Clone + Send + Sync + 'static,
UrlBuilder: FromRef<S>,
Expand All @@ -272,7 +272,28 @@ where
BoxClock: FromRequestParts<S>,
BoxRng: FromRequestParts<S>,
{
Router::new()
// A sub-router for human-facing routes with error handling
let human_router = Router::new()
.route(
mas_router::CompatLoginSsoRedirect::route(),
get(self::compat::login_sso_redirect::get),
)
.route(
mas_router::CompatLoginSsoRedirectIdp::route(),
get(self::compat::login_sso_redirect::get),
)
.route(
mas_router::CompatLoginSsoRedirectSlash::route(),
get(self::compat::login_sso_redirect::get),
)
.layer(AndThenLayer::new(
async move |response: axum::response::Response| {
Ok::<_, Infallible>(recover_error(&templates, response))
},
));

// A sub-router for API-facing routes with CORS
let api_router = Router::new()
.route(
mas_router::CompatLogin::route(),
get(self::compat::login::get).post(self::compat::login::post),
Expand All @@ -289,18 +310,6 @@ where
mas_router::CompatRefresh::route(),
post(self::compat::refresh::post),
)
.route(
mas_router::CompatLoginSsoRedirect::route(),
get(self::compat::login_sso_redirect::get),
)
.route(
mas_router::CompatLoginSsoRedirectIdp::route(),
get(self::compat::login_sso_redirect::get),
)
.route(
mas_router::CompatLoginSsoRedirectSlash::route(),
get(self::compat::login_sso_redirect::get),
)
.layer(
CorsLayer::new()
.allow_origin(Any)
Expand All @@ -314,7 +323,9 @@ where
HeaderName::from_static("x-requested-with"),
])
.max_age(Duration::from_secs(60 * 60)),
)
);

Router::new().merge(human_router).merge(api_router)
}

#[allow(clippy::too_many_lines)]
Expand Down Expand Up @@ -454,22 +465,29 @@ where
)
.layer(AndThenLayer::new(
async move |response: axum::response::Response| {
// Error responses should have an ErrorContext attached to them
let ext = response.extensions().get::<ErrorContext>();
if let Some(ctx) = ext {
if let Ok(res) = templates.render_error(ctx) {
let (mut parts, _original_body) = response.into_parts();
parts.headers.remove(CONTENT_TYPE);
parts.headers.remove(CONTENT_LENGTH);
return Ok((parts, Html(res)).into_response());
}
}

Ok::<_, Infallible>(response)
Ok::<_, Infallible>(recover_error(&templates, response))
},
))
}

fn recover_error(
templates: &Templates,
response: axum::response::Response,
) -> axum::response::Response {
// Error responses should have an ErrorContext attached to them
let ext = response.extensions().get::<ErrorContext>();
if let Some(ctx) = ext {
if let Ok(res) = templates.render_error(ctx) {
let (mut parts, _original_body) = response.into_parts();
parts.headers.remove(CONTENT_TYPE);
parts.headers.remove(CONTENT_LENGTH);
return (parts, Html(res)).into_response();
}
}

response
}

/// The fallback handler for all routes that don't match anything else.
///
/// # Errors
Expand Down
18 changes: 10 additions & 8 deletions crates/handlers/src/oauth2/authorization/consent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ use axum::{
use axum_extra::TypedHeader;
use hyper::StatusCode;
use mas_axum_utils::{
GenericError, InternalError,
cookies::CookieJar,
csrf::{CsrfExt, ProtectedForm},
record_error,
};
use mas_data_model::AuthorizationGrantStage;
use mas_keystore::Keystore;
Expand Down Expand Up @@ -64,13 +64,15 @@ impl_from_error_for_route!(super::callback::CallbackDestinationError);

impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let sentry_event_id = record_error!(self, Self::Internal(_) | Self::NoSuchClient(_));
(
StatusCode::INTERNAL_SERVER_ERROR,
sentry_event_id,
self.to_string(),
)
.into_response()
match self {
Self::Internal(e) => InternalError::new(e).into_response(),
e @ Self::NoSuchClient(_) => InternalError::new(Box::new(e)).into_response(),
e @ Self::GrantNotFound => GenericError::new(StatusCode::NOT_FOUND, e).into_response(),
e @ Self::GrantNotPending(_) => {
GenericError::new(StatusCode::CONFLICT, e).into_response()
}
e @ Self::Csrf(_) => GenericError::new(StatusCode::BAD_REQUEST, e).into_response(),
}
}
}

Expand Down
32 changes: 9 additions & 23 deletions crates/handlers/src/oauth2/authorization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use axum::{
response::{IntoResponse, Response},
};
use hyper::StatusCode;
use mas_axum_utils::{SessionInfoExt, cookies::CookieJar, record_error};
use mas_axum_utils::{GenericError, InternalError, SessionInfoExt, cookies::CookieJar};
use mas_data_model::{AuthorizationCode, Pkce};
use mas_router::{PostAuthAction, UrlBuilder};
use mas_storage::{
Expand Down Expand Up @@ -53,29 +53,15 @@ pub enum RouteError {

impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let sentry_event_id = record_error!(self, Self::Internal(_));
// TODO: better error pages
let response = match self {
RouteError::Internal(e) => {
(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
match self {
Self::Internal(e) => InternalError::new(e).into_response(),
e @ (Self::ClientNotFound
| Self::InvalidResponseMode
| Self::IntoCallbackDestination(_)
| Self::UnknownRedirectUri(_)) => {
GenericError::new(StatusCode::BAD_REQUEST, e).into_response()
}
RouteError::ClientNotFound => {
(StatusCode::BAD_REQUEST, "could not find client").into_response()
}
RouteError::InvalidResponseMode => {
(StatusCode::BAD_REQUEST, "invalid response mode").into_response()
}
RouteError::IntoCallbackDestination(e) => {
(StatusCode::BAD_REQUEST, e.to_string()).into_response()
}
RouteError::UnknownRedirectUri(e) => (
StatusCode::BAD_REQUEST,
format!("Invalid redirect URI ({e})"),
)
.into_response(),
};

(sentry_event_id, response).into_response()
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/handlers/src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ impl TestState {
let app = crate::healthcheck_router()
.merge(crate::discovery_router())
.merge(crate::api_router())
.merge(crate::compat_router())
.merge(crate::compat_router(self.templates.clone()))
.merge(crate::human_router(self.templates.clone()))
// We enable undocumented_oauth2_access for the tests, as it is easier to query the API
// with it
Expand Down
15 changes: 7 additions & 8 deletions crates/handlers/src/upstream_oauth2/authorize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use axum::{
response::{IntoResponse, Redirect},
};
use hyper::StatusCode;
use mas_axum_utils::{cookies::CookieJar, record_error};
use mas_axum_utils::{GenericError, InternalError, cookies::CookieJar};
use mas_data_model::UpstreamOAuthProvider;
use mas_oidc_client::requests::authorization_code::AuthorizationRequestData;
use mas_router::{PostAuthAction, UrlBuilder};
Expand Down Expand Up @@ -41,13 +41,12 @@ impl_from_error_for_route!(mas_storage::RepositoryError);

impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let sentry_event_id = record_error!(self, Self::Internal(_));
let response = match self {
Self::ProviderNotFound => (StatusCode::NOT_FOUND, "Provider not found").into_response(),
Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
};

(sentry_event_id, response).into_response()
match self {
e @ Self::ProviderNotFound => {
GenericError::new(StatusCode::NOT_FOUND, e).into_response()
}
Self::Internal(e) => InternalError::new(e).into_response(),
}
}
}

Expand Down
18 changes: 8 additions & 10 deletions crates/handlers/src/upstream_oauth2/callback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use axum::{
response::{Html, IntoResponse, Response},
};
use hyper::StatusCode;
use mas_axum_utils::{cookies::CookieJar, record_error};
use mas_axum_utils::{GenericError, InternalError, cookies::CookieJar};
use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderResponseMode};
use mas_jose::claims::TokenHash;
use mas_keystore::{Encrypter, Keystore};
Expand Down Expand Up @@ -153,15 +153,13 @@ impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);

impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let sentry_event_id = record_error!(self, Self::Internal(_));
let response = match self {
Self::ProviderNotFound => (StatusCode::NOT_FOUND, "Provider not found").into_response(),
Self::SessionNotFound => (StatusCode::NOT_FOUND, "Session not found").into_response(),
Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
e => (StatusCode::BAD_REQUEST, e.to_string()).into_response(),
};

(sentry_event_id, response).into_response()
match self {
Self::Internal(e) => InternalError::new(e).into_response(),
e @ (Self::ProviderNotFound | Self::SessionNotFound) => {
GenericError::new(StatusCode::NOT_FOUND, e).into_response()
}
e => GenericError::new(StatusCode::BAD_REQUEST, e).into_response(),
}
}
}

Expand Down
16 changes: 12 additions & 4 deletions crates/oauth2-types/src/oidc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,10 @@ impl ProviderMetadata {
let metadata = self.insecure_verify_metadata()?;

if metadata.issuer() != issuer {
return Err(ProviderMetadataVerificationError::IssuerUrlsDontMatch);
return Err(ProviderMetadataVerificationError::IssuerUrlsDontMatch {
expected: issuer.to_owned(),
actual: metadata.issuer().to_owned(),
});
}

validate_url(
Expand Down Expand Up @@ -1064,8 +1067,13 @@ pub enum ProviderMetadataVerificationError {
UrlWithFragment(&'static str, Url),

/// The issuer URL doesn't match the one that was discovered.
#[error("issuer URLs don't match")]
IssuerUrlsDontMatch,
#[error("issuer URLs don't match: expected {expected:?}, got {actual:?}")]
IssuerUrlsDontMatch {
/// The expected issuer URL.
expected: String,
/// The issuer URL that was discovered.
actual: String,
},

/// `openid` is missing from the supported scopes.
#[error("missing openid scope")]
Expand Down Expand Up @@ -1314,7 +1322,7 @@ mod tests {
metadata.issuer = Some("https://example.com/".to_owned());
assert_matches!(
metadata.clone().validate(&issuer),
Err(ProviderMetadataVerificationError::IssuerUrlsDontMatch)
Err(ProviderMetadataVerificationError::IssuerUrlsDontMatch { .. })
);

// Err - Not https
Expand Down
Loading