Skip to content

Commit f1981ab

Browse files
authored
Properly accumulate form errors on the upstream register page (#4173)
2 parents d10dd61 + 9668418 commit f1981ab

File tree

1 file changed

+85
-105
lines changed
  • crates/handlers/src/upstream_oauth2

1 file changed

+85
-105
lines changed

crates/handlers/src/upstream_oauth2/link.rs

Lines changed: 85 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -672,136 +672,116 @@ pub(crate) async fn post(
672672
ctx
673673
};
674674

675-
let forced_username = if provider.claims_imports.localpart.is_forced() {
675+
let username = if provider.claims_imports.localpart.is_forced() {
676676
let template = provider
677677
.claims_imports
678678
.localpart
679679
.template
680680
.as_deref()
681681
.unwrap_or(DEFAULT_LOCALPART_TEMPLATE);
682682

683-
render_attribute_template(
684-
&env,
685-
template,
686-
&context,
687-
provider.claims_imports.email.is_required(),
688-
)?
683+
render_attribute_template(&env, template, &context, true)?
689684
} else {
690-
None
691-
};
692-
693-
// If there is no forced username, we can use the one the user entered
694-
let username = forced_username
695-
.or(username)
696-
.filter(|username| !username.is_empty());
697-
698-
let Some(username) = username else {
699-
// We're missing a username, let's re-render the form with an error
700-
let form_state = form_state.with_error_on_field(
701-
mas_templates::UpstreamRegisterFormField::Username,
702-
FieldError::Required,
703-
);
704-
705-
let ctx = ctx
706-
.with_form_state(form_state)
707-
.with_csrf(csrf_token.form_value())
708-
.with_language(locale);
709-
return Ok((
710-
cookie_jar,
711-
Html(templates.render_upstream_oauth2_do_register(&ctx)?),
712-
)
713-
.into_response());
714-
};
685+
// If there is no forced username, we can use the one the user entered
686+
username
687+
}
688+
.unwrap_or_default();
715689

716690
let ctx = ctx.with_localpart(
717691
username.clone(),
718692
provider.claims_imports.localpart.is_forced(),
719693
);
720694

721-
// Check if there is an existing user
722-
let existing_user = repo.user().find_by_username(&username).await?;
695+
// Validate the form
696+
let form_state = {
697+
let mut form_state = form_state;
698+
let mut homeserver_denied_username = false;
699+
if username.is_empty() {
700+
form_state.add_error_on_field(
701+
mas_templates::UpstreamRegisterFormField::Username,
702+
FieldError::Required,
703+
);
704+
} else if repo.user().exists(&username).await? {
705+
form_state.add_error_on_field(
706+
mas_templates::UpstreamRegisterFormField::Username,
707+
FieldError::Exists,
708+
);
709+
} else if !homeserver
710+
.is_localpart_available(&username)
711+
.await
712+
.map_err(RouteError::HomeserverConnection)?
713+
{
714+
// The user already exists on the homeserver
715+
tracing::warn!(
716+
%username,
717+
"Homeserver denied username provided by user"
718+
);
719+
720+
// We defer adding the error on the field, until we know whether we had another
721+
// error from the policy, to avoid showing both
722+
homeserver_denied_username = true;
723+
}
723724

724-
// Ask the homeserver to make sure the username is valid
725-
let is_available = homeserver
726-
.is_localpart_available(&username)
727-
.await
728-
.map_err(RouteError::HomeserverConnection)?;
725+
// If we have a TOS in the config, make sure the user has accepted it
726+
if site_config.tos_uri.is_some() && !accept_terms {
727+
form_state.add_error_on_field(
728+
mas_templates::UpstreamRegisterFormField::AcceptTerms,
729+
FieldError::Required,
730+
);
731+
}
729732

730-
if existing_user.is_some() || !is_available {
731-
// If there is an existing user, we can't create a new one
732-
// with the same username, show an error
733+
// Policy check
734+
let res = policy
735+
.evaluate_register(mas_policy::RegisterInput {
736+
registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2,
737+
username: &username,
738+
email: email.as_deref(),
739+
requester: mas_policy::Requester {
740+
ip_address: activity_tracker.ip(),
741+
user_agent: user_agent.clone().map(|ua| ua.raw),
742+
},
743+
})
744+
.await?;
733745

734-
let form_state = form_state.with_error_on_field(
735-
mas_templates::UpstreamRegisterFormField::Username,
736-
FieldError::Exists,
737-
);
746+
for violation in res.violations {
747+
match violation.field.as_deref() {
748+
Some("username") => {
749+
// If the homeserver denied the username, but we also had an error on
750+
// the policy side, we don't want to show
751+
// both, so we reset the state here
752+
homeserver_denied_username = false;
753+
form_state.add_error_on_field(
754+
mas_templates::UpstreamRegisterFormField::Username,
755+
FieldError::Policy {
756+
code: violation.code.map(|c| c.as_str()),
757+
message: violation.msg,
758+
},
759+
);
760+
}
761+
_ => form_state.add_error_on_form(FormError::Policy {
762+
code: violation.code.map(|c| c.as_str()),
763+
message: violation.msg,
764+
}),
765+
}
766+
}
738767

739-
let ctx = ctx
740-
.with_form_state(form_state)
741-
.with_csrf(csrf_token.form_value())
742-
.with_language(locale);
743-
return Ok((
744-
cookie_jar,
745-
Html(templates.render_upstream_oauth2_do_register(&ctx)?),
746-
)
747-
.into_response());
748-
}
768+
if homeserver_denied_username {
769+
// XXX: we may want to return different errors like "this username is reserved"
770+
form_state.add_error_on_field(
771+
mas_templates::UpstreamRegisterFormField::Username,
772+
FieldError::Exists,
773+
);
774+
}
749775

750-
// If we need have a TOS in the config, make sure the user has accepted it
751-
if site_config.tos_uri.is_some() && !accept_terms {
752-
let form_state = form_state.with_error_on_field(
753-
mas_templates::UpstreamRegisterFormField::AcceptTerms,
754-
FieldError::Required,
755-
);
776+
form_state
777+
};
756778

779+
if !form_state.is_valid() {
757780
let ctx = ctx
758781
.with_form_state(form_state)
759782
.with_csrf(csrf_token.form_value())
760783
.with_language(locale);
761-
return Ok((
762-
cookie_jar,
763-
Html(templates.render_upstream_oauth2_do_register(&ctx)?),
764-
)
765-
.into_response());
766-
}
767-
768-
// Policy check
769-
let res = policy
770-
.evaluate_register(mas_policy::RegisterInput {
771-
registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2,
772-
username: &username,
773-
email: email.as_deref(),
774-
requester: mas_policy::Requester {
775-
ip_address: activity_tracker.ip(),
776-
user_agent: user_agent.clone().map(|ua| ua.raw),
777-
},
778-
})
779-
.await?;
780-
781-
if !res.valid() {
782-
let form_state =
783-
res.violations
784-
.into_iter()
785-
.fold(form_state, |form_state, violation| {
786-
match violation.field.as_deref() {
787-
Some("username") => form_state.with_error_on_field(
788-
mas_templates::UpstreamRegisterFormField::Username,
789-
FieldError::Policy {
790-
code: violation.code.map(|c| c.as_str()),
791-
message: violation.msg,
792-
},
793-
),
794-
_ => form_state.with_error_on_form(FormError::Policy {
795-
code: violation.code.map(|c| c.as_str()),
796-
message: violation.msg,
797-
}),
798-
}
799-
});
800784

801-
let ctx = ctx
802-
.with_form_state(form_state)
803-
.with_csrf(csrf_token.form_value())
804-
.with_language(locale);
805785
return Ok((
806786
cookie_jar,
807787
Html(templates.render_upstream_oauth2_do_register(&ctx)?),

0 commit comments

Comments
 (0)