From 8dac005678e11bd761181dc719b24e2e91603b89 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 11 Feb 2025 15:13:43 +0100 Subject: [PATCH] Fix the upstream OAuth 2.0 callback form deserialisation --- .../handlers/src/upstream_oauth2/callback.rs | 81 ++++++++++--------- 1 file changed, 43 insertions(+), 38 deletions(-) diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index 935014448..1042652ac 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -41,32 +41,33 @@ use crate::{impl_from_error_for_route, upstream_oauth2::cache::MetadataCache, Pr #[derive(Serialize, Deserialize)] pub struct Params { - state: String, + state: Option, /// An extra parameter to track whether the POST request was re-made by us /// to the same URL to escape Same-Site cookies restrictions #[serde(default)] did_mas_repost_to_itself: bool, + code: Option, + + error: Option, + error_description: Option, + #[allow(dead_code)] + error_uri: Option, + #[serde(flatten)] - code_or_error: CodeOrError, + extra_callback_parameters: Option, } -#[derive(Serialize, Deserialize)] -#[serde(untagged)] -enum CodeOrError { - Code { - code: String, - - #[serde(flatten)] - extra_callback_parameters: Option, - }, - Error { - error: ClientErrorCode, - error_description: Option, - #[allow(dead_code)] - error_uri: Option, - }, +impl Params { + /// Returns true if none of the fields are set + pub fn is_empty(&self) -> bool { + self.state.is_none() + && self.code.is_none() + && self.error.is_none() + && self.error_description.is_none() + && self.error_uri.is_none() + } } #[derive(Debug, Error)] @@ -86,6 +87,12 @@ pub(crate) enum RouteError { #[error("State parameter mismatch")] StateMismatch, + #[error("Missing state parameter")] + MissingState, + + #[error("Missing code parameter")] + MissingCode, + #[error("Could not extract subject from ID token")] ExtractSubject(#[source] minijinja::Error), @@ -161,7 +168,7 @@ pub(crate) async fn handler( PreferredLanguage(locale): PreferredLanguage, cookie_jar: CookieJar, Path(provider_id): Path, - Form(params): Form>, + Form(params): Form, ) -> Result { let provider = repo .upstream_oauth_provider() @@ -172,7 +179,7 @@ pub(crate) async fn handler( let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar); - let Some(params) = params else { + if params.is_empty() { if let Method::GET = method { return Err(RouteError::MissingQueryParams); } @@ -204,8 +211,19 @@ pub(crate) async fn handler( (Some(expected), _) => return Err(RouteError::InvalidResponseMode { expected }), } + if let Some(error) = params.error { + return Err(RouteError::ClientError { + error, + error_description: params.error_description.clone(), + }); + } + + let Some(state) = params.state else { + return Err(RouteError::MissingState); + }; + let (session_id, _post_auth_action) = sessions_cookie - .find_session(provider_id, ¶ms.state) + .find_session(provider_id, &state) .map_err(|_| RouteError::MissingCookie)?; let session = repo @@ -219,7 +237,7 @@ pub(crate) async fn handler( return Err(RouteError::ProviderMismatch); } - if params.state != session.state_str { + if state != session.state_str { // The state in the session cookie should match the one from the params return Err(RouteError::StateMismatch); } @@ -230,21 +248,8 @@ pub(crate) async fn handler( } // Let's extract the code from the params, and return if there was an error - let (code, extra_callback_parameters) = match params.code_or_error { - CodeOrError::Error { - error, - error_description, - .. - } => { - return Err(RouteError::ClientError { - error, - error_description, - }) - } - CodeOrError::Code { - code, - extra_callback_parameters, - } => (code, extra_callback_parameters), + let Some(code) = params.code else { + return Err(RouteError::MissingCode); }; let mut lazy_metadata = LazyProviderInfos::new(&metadata_cache, &provider, &client); @@ -326,7 +331,7 @@ pub(crate) async fn handler( context = context.with_id_token_claims(claims); } - if let Some(extra_callback_parameters) = extra_callback_parameters.clone() { + if let Some(extra_callback_parameters) = params.extra_callback_parameters.clone() { context = context.with_extra_callback_parameters(extra_callback_parameters); } @@ -432,7 +437,7 @@ pub(crate) async fn handler( session, &link, token_response.id_token, - extra_callback_parameters, + params.extra_callback_parameters, userinfo, ) .await?;