Skip to content

Commit 7a9150f

Browse files
committed
check onconflict value in post request
1 parent 21dae3c commit 7a9150f

File tree

1 file changed

+67
-43
lines changed
  • crates/handlers/src/upstream_oauth2

1 file changed

+67
-43
lines changed

crates/handlers/src/upstream_oauth2/link.rs

Lines changed: 67 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ const PROVIDER: Key = Key::from_static_str("provider");
6868
const DEFAULT_LOCALPART_TEMPLATE: &str = "{{ user.preferred_username }}";
6969
const DEFAULT_DISPLAYNAME_TEMPLATE: &str = "{{ user.name }}";
7070
const DEFAULT_EMAIL_TEMPLATE: &str = "{{ user.email }}";
71-
const DEFAULT_ON_CONFLICT: UpstreamOAuthProviderOnConflict = UpstreamOAuthProviderOnConflict::Fail;
7271

7372
#[derive(Debug, Error)]
7473
pub(crate) enum RouteError {
@@ -481,34 +480,38 @@ pub(crate) async fn get(
481480
.claims_imports
482481
.localpart
483482
.on_conflict
484-
.unwrap_or(DEFAULT_ON_CONFLICT);
485-
486-
if on_conflict.is_add() {
487-
// new oauth link is allowed
488-
let ctx = UpstreamExistingLinkContext::new(existing_user)
489-
.with_csrf(csrf_token.form_value())
490-
.with_language(locale);
491-
492-
return Ok((
493-
cookie_jar,
494-
Html(templates.render_upstream_oauth2_login_link(&ctx)?)
495-
.into_response(),
496-
));
483+
.unwrap_or_default();
484+
485+
match on_conflict {
486+
UpstreamOAuthProviderOnConflict::Fail => {
487+
// TODO: translate
488+
let ctx = ErrorContext::new()
489+
.with_code("User exists")
490+
.with_description(format!(
491+
r"Upstream account provider returned {localpart:?} as username,
492+
which is not linked to that upstream account. Your homeserver does not allow
493+
linking an upstream account to an existing account"
494+
))
495+
.with_language(&locale);
496+
497+
return Ok((
498+
cookie_jar,
499+
Html(templates.render_error(&ctx)?).into_response(),
500+
));
501+
}
502+
UpstreamOAuthProviderOnConflict::Add => {
503+
// new oauth link is allowed
504+
let ctx = UpstreamExistingLinkContext::new(existing_user)
505+
.with_csrf(csrf_token.form_value())
506+
.with_language(locale);
507+
508+
return Ok((
509+
cookie_jar,
510+
Html(templates.render_upstream_oauth2_login_link(&ctx)?)
511+
.into_response(),
512+
));
513+
}
497514
}
498-
499-
// TODO: translate
500-
let ctx = ErrorContext::new()
501-
.with_code("User exists")
502-
.with_description(format!(
503-
r"Upstream account provider returned {localpart:?} as username,
504-
which is not linked to that upstream account"
505-
))
506-
.with_language(&locale);
507-
508-
return Ok((
509-
cookie_jar,
510-
Html(templates.render_error(&ctx)?).into_response(),
511-
));
512515
}
513516

514517
if !is_available {
@@ -651,7 +654,7 @@ pub(crate) async fn post(
651654

652655
(None, None, FormData::Link) => {
653656
// User already exists, but it is not linked, neither logged in
654-
// Proceed by associating the link to the user and log in the user
657+
// Proceed by associating the link and log in the user
655658
// Upstream_session is used to re-render the username as it is the only source
656659
// of truth
657660

@@ -663,7 +666,6 @@ pub(crate) async fn post(
663666
.await?
664667
.ok_or(RouteError::ProviderNotFound(link.provider_id))?;
665668

666-
// Let's import the username from the localpart claim
667669
let env = environment();
668670

669671
let mut context = AttributeMappingContext::new();
@@ -679,30 +681,52 @@ pub(crate) async fn post(
679681
}
680682
let context = context.build();
681683

682-
//Claims import must be `require` or `force` at this stage
684+
if !provider.claims_imports.localpart.is_forced()
685+
|| !provider.claims_imports.localpart.is_required()
686+
{
687+
//Claims import for `localpart` should be `require` or `force` at this stage
688+
return Err(RouteError::InvalidFormAction);
689+
}
690+
683691
let template = provider
684692
.claims_imports
685693
.localpart
686694
.template
687695
.as_deref()
688696
.unwrap_or(DEFAULT_LOCALPART_TEMPLATE);
689697

690-
let username = render_attribute_template(&env, template, &context, true)?;
698+
let localpart = render_attribute_template(&env, template, &context, true)?;
691699

692-
let maybe_user = repo.user().find_by_username(&username.unwrap()).await?;
700+
let maybe_user = repo.user().find_by_username(&localpart.unwrap()).await?;
693701

694-
if maybe_user.is_some() {
695-
let user = maybe_user.unwrap();
702+
if maybe_user.is_none() {
703+
//user can not be None at this stage
704+
return Err(RouteError::InvalidFormAction);
705+
}
696706

697-
repo.upstream_oauth_link()
698-
.associate_to_user(&link, &user)
699-
.await?;
707+
let user = maybe_user.unwrap();
700708

701-
repo.browser_session()
702-
.add(&mut rng, &clock, &user, user_agent)
703-
.await?
704-
} else {
705-
return Err(RouteError::InvalidFormAction);
709+
let on_conflict = provider
710+
.claims_imports
711+
.localpart
712+
.on_conflict
713+
.unwrap_or_default();
714+
715+
match on_conflict {
716+
UpstreamOAuthProviderOnConflict::Fail => {
717+
//OnConflict can not be equals to Fail at this stage
718+
return Err(RouteError::InvalidFormAction);
719+
}
720+
UpstreamOAuthProviderOnConflict::Add => {
721+
//add link to the user
722+
repo.upstream_oauth_link()
723+
.associate_to_user(&link, &user)
724+
.await?;
725+
726+
repo.browser_session()
727+
.add(&mut rng, &clock, &user, user_agent)
728+
.await?
729+
}
706730
}
707731
}
708732

0 commit comments

Comments
 (0)