Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 43 additions & 38 deletions crates/handlers/src/upstream_oauth2/callback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,32 +41,33 @@ use crate::{impl_from_error_for_route, upstream_oauth2::cache::MetadataCache, Pr

#[derive(Serialize, Deserialize)]
pub struct Params {
state: String,
state: Option<String>,

/// An extra parameter to track whether the POST request was re-made by us
/// to the same URL to escape Same-Site cookies restrictions
#[serde(default)]
did_mas_repost_to_itself: bool,

code: Option<String>,

error: Option<ClientErrorCode>,
error_description: Option<String>,
#[allow(dead_code)]
error_uri: Option<String>,

#[serde(flatten)]
code_or_error: CodeOrError,
extra_callback_parameters: Option<serde_json::Value>,
}

#[derive(Serialize, Deserialize)]
#[serde(untagged)]
enum CodeOrError {
Code {
code: String,

#[serde(flatten)]
extra_callback_parameters: Option<serde_json::Value>,
},
Error {
error: ClientErrorCode,
error_description: Option<String>,
#[allow(dead_code)]
error_uri: Option<String>,
},
impl Params {
/// Returns true if none of the fields are set
pub fn is_empty(&self) -> bool {
self.state.is_none()
&& self.code.is_none()
&& self.error.is_none()
&& self.error_description.is_none()
&& self.error_uri.is_none()
}
}

#[derive(Debug, Error)]
Expand All @@ -86,6 +87,12 @@ pub(crate) enum RouteError {
#[error("State parameter mismatch")]
StateMismatch,

#[error("Missing state parameter")]
MissingState,

#[error("Missing code parameter")]
MissingCode,

#[error("Could not extract subject from ID token")]
ExtractSubject(#[source] minijinja::Error),

Expand Down Expand Up @@ -161,7 +168,7 @@ pub(crate) async fn handler(
PreferredLanguage(locale): PreferredLanguage,
cookie_jar: CookieJar,
Path(provider_id): Path<Ulid>,
Form(params): Form<Option<Params>>,
Form(params): Form<Params>,
) -> Result<Response, RouteError> {
let provider = repo
.upstream_oauth_provider()
Expand All @@ -172,7 +179,7 @@ pub(crate) async fn handler(

let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);

let Some(params) = params else {
if params.is_empty() {
if let Method::GET = method {
return Err(RouteError::MissingQueryParams);
}
Expand Down Expand Up @@ -204,8 +211,19 @@ pub(crate) async fn handler(
(Some(expected), _) => return Err(RouteError::InvalidResponseMode { expected }),
}

if let Some(error) = params.error {
return Err(RouteError::ClientError {
error,
error_description: params.error_description.clone(),
});
}

let Some(state) = params.state else {
return Err(RouteError::MissingState);
};

let (session_id, _post_auth_action) = sessions_cookie
.find_session(provider_id, &params.state)
.find_session(provider_id, &state)
.map_err(|_| RouteError::MissingCookie)?;

let session = repo
Expand All @@ -219,7 +237,7 @@ pub(crate) async fn handler(
return Err(RouteError::ProviderMismatch);
}

if params.state != session.state_str {
if state != session.state_str {
// The state in the session cookie should match the one from the params
return Err(RouteError::StateMismatch);
}
Expand All @@ -230,21 +248,8 @@ pub(crate) async fn handler(
}

// Let's extract the code from the params, and return if there was an error
let (code, extra_callback_parameters) = match params.code_or_error {
CodeOrError::Error {
error,
error_description,
..
} => {
return Err(RouteError::ClientError {
error,
error_description,
})
}
CodeOrError::Code {
code,
extra_callback_parameters,
} => (code, extra_callback_parameters),
let Some(code) = params.code else {
return Err(RouteError::MissingCode);
};

let mut lazy_metadata = LazyProviderInfos::new(&metadata_cache, &provider, &client);
Expand Down Expand Up @@ -326,7 +331,7 @@ pub(crate) async fn handler(
context = context.with_id_token_claims(claims);
}

if let Some(extra_callback_parameters) = extra_callback_parameters.clone() {
if let Some(extra_callback_parameters) = params.extra_callback_parameters.clone() {
context = context.with_extra_callback_parameters(extra_callback_parameters);
}

Expand Down Expand Up @@ -432,7 +437,7 @@ pub(crate) async fn handler(
session,
&link,
token_response.id_token,
extra_callback_parameters,
params.extra_callback_parameters,
userinfo,
)
.await?;
Expand Down
Loading