Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion crates/axum-utils/src/cookies.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,15 @@ impl CookieOption {
cookie.set_http_only(true);
cookie.set_secure(self.secure());
cookie.set_path(self.path().to_owned());
cookie.set_same_site(SameSite::Lax);

// The `form_post` callback requires that, as it means a 3rd party origin will
// POST to MAS. This is presumably fine, as our forms are protected with a CSRF
// token
cookie.set_same_site(if self.secure() {
SameSite::None
} else {
SameSite::Lax
});
cookie
}
}
Expand Down
3 changes: 2 additions & 1 deletion crates/handlers/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,8 @@ where
)
.route(
mas_router::UpstreamOAuth2Callback::route(),
get(self::upstream_oauth2::callback::get),
get(self::upstream_oauth2::callback::handler)
.post(self::upstream_oauth2::callback::handler),
)
.route(
mas_router::UpstreamOAuth2Link::route(),
Expand Down
41 changes: 36 additions & 5 deletions crates/handlers/src/upstream_oauth2/callback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
use axum::{
extract::{Path, Query, State},
response::IntoResponse,
Form,
};
use hyper::StatusCode;
use mas_axum_utils::{cookies::CookieJar, sentry::SentryEventID};
use mas_data_model::UpstreamOAuthProvider;
use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderResponseMode};
use mas_keystore::{Encrypter, Keystore};
use mas_oidc_client::requests::{
authorization_code::AuthorizationValidationData, jose::JwtVerificationData,
Expand All @@ -35,7 +36,7 @@ use super::{
use crate::{impl_from_error_for_route, upstream_oauth2::cache::MetadataCache};

#[derive(Deserialize)]
pub struct QueryParams {
pub struct Params {
state: String,

#[serde(flatten)]
Expand Down Expand Up @@ -91,6 +92,20 @@ pub(crate) enum RouteError {
#[error("Missing session cookie")]
MissingCookie,

#[error("Missing query parameters")]
MissingQueryParams,

#[error("Missing form parameters")]
MissingFormParams,

#[error("Ambiguous parameters: got both query and form parameters")]
AmbiguousParams,

#[error("Invalid response mode, expected '{expected}'")]
InvalidParamsMode {
expected: UpstreamOAuthProviderResponseMode,
},

#[error(transparent)]
Internal(Box<dyn std::error::Error>),
}
Expand All @@ -117,13 +132,13 @@ impl IntoResponse for RouteError {
}

#[tracing::instrument(
name = "handlers.upstream_oauth2.callback.get",
name = "handlers.upstream_oauth2.callback.handler",
fields(upstream_oauth_provider.id = %provider_id),
skip_all,
err,
)]
#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
pub(crate) async fn get(
pub(crate) async fn handler(
mut rng: BoxRng,
clock: BoxClock,
State(metadata_cache): State<MetadataCache>,
Expand All @@ -134,7 +149,8 @@ pub(crate) async fn get(
State(client): State<reqwest::Client>,
cookie_jar: CookieJar,
Path(provider_id): Path<Ulid>,
Query(params): Query<QueryParams>,
query_params: Option<Query<Params>>,
form_params: Option<Form<Params>>,
) -> Result<impl IntoResponse, RouteError> {
let provider = repo
.upstream_oauth_provider()
Expand All @@ -143,6 +159,21 @@ pub(crate) async fn get(
.filter(UpstreamOAuthProvider::enabled)
.ok_or(RouteError::ProviderNotFound)?;

// Read the parameters from the query or the form, depending on what
// response_mode the provider uses
let params = match (provider.response_mode, query_params, form_params) {
(UpstreamOAuthProviderResponseMode::Query, Some(query_params), None) => query_params.0,
(UpstreamOAuthProviderResponseMode::FormPost, None, Some(form_params)) => form_params.0,
(UpstreamOAuthProviderResponseMode::Query, None, None) => {
return Err(RouteError::MissingQueryParams)
}
(UpstreamOAuthProviderResponseMode::FormPost, None, None) => {
return Err(RouteError::MissingFormParams)
}
(_, Some(_), Some(_)) => return Err(RouteError::AmbiguousParams),
(expected, _, _) => return Err(RouteError::InvalidParamsMode { expected }),
};

let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
let (session_id, _post_auth_action) = sessions_cookie
.find_session(provider_id, &params.state)
Expand Down