@@ -672,136 +672,116 @@ pub(crate) async fn post(
672672 ctx
673673 } ;
674674
675- let forced_username = if provider. claims_imports . localpart . is_forced ( ) {
675+ let username = if provider. claims_imports . localpart . is_forced ( ) {
676676 let template = provider
677677 . claims_imports
678678 . localpart
679679 . template
680680 . as_deref ( )
681681 . unwrap_or ( DEFAULT_LOCALPART_TEMPLATE ) ;
682682
683- render_attribute_template (
684- & env,
685- template,
686- & context,
687- provider. claims_imports . email . is_required ( ) ,
688- ) ?
683+ render_attribute_template ( & env, template, & context, true ) ?
689684 } else {
690- None
691- } ;
692-
693- // If there is no forced username, we can use the one the user entered
694- let username = forced_username
695- . or ( username)
696- . filter ( |username| !username. is_empty ( ) ) ;
697-
698- let Some ( username) = username else {
699- // We're missing a username, let's re-render the form with an error
700- let form_state = form_state. with_error_on_field (
701- mas_templates:: UpstreamRegisterFormField :: Username ,
702- FieldError :: Required ,
703- ) ;
704-
705- let ctx = ctx
706- . with_form_state ( form_state)
707- . with_csrf ( csrf_token. form_value ( ) )
708- . with_language ( locale) ;
709- return Ok ( (
710- cookie_jar,
711- Html ( templates. render_upstream_oauth2_do_register ( & ctx) ?) ,
712- )
713- . into_response ( ) ) ;
714- } ;
685+ // If there is no forced username, we can use the one the user entered
686+ username
687+ }
688+ . unwrap_or_default ( ) ;
715689
716690 let ctx = ctx. with_localpart (
717691 username. clone ( ) ,
718692 provider. claims_imports . localpart . is_forced ( ) ,
719693 ) ;
720694
721- // Check if there is an existing user
722- let existing_user = repo. user ( ) . find_by_username ( & username) . await ?;
695+ // Validate the form
696+ let form_state = {
697+ let mut form_state = form_state;
698+ let mut homeserver_denied_username = false ;
699+ if username. is_empty ( ) {
700+ form_state. add_error_on_field (
701+ mas_templates:: UpstreamRegisterFormField :: Username ,
702+ FieldError :: Required ,
703+ ) ;
704+ } else if repo. user ( ) . exists ( & username) . await ? {
705+ form_state. add_error_on_field (
706+ mas_templates:: UpstreamRegisterFormField :: Username ,
707+ FieldError :: Exists ,
708+ ) ;
709+ } else if !homeserver
710+ . is_localpart_available ( & username)
711+ . await
712+ . map_err ( RouteError :: HomeserverConnection ) ?
713+ {
714+ // The user already exists on the homeserver
715+ tracing:: warn!(
716+ %username,
717+ "Homeserver denied username provided by user"
718+ ) ;
719+
720+ // We defer adding the error on the field, until we know whether we had another
721+ // error from the policy, to avoid showing both
722+ homeserver_denied_username = true ;
723+ }
723724
724- // Ask the homeserver to make sure the username is valid
725- let is_available = homeserver
726- . is_localpart_available ( & username)
727- . await
728- . map_err ( RouteError :: HomeserverConnection ) ?;
725+ // If we have a TOS in the config, make sure the user has accepted it
726+ if site_config. tos_uri . is_some ( ) && !accept_terms {
727+ form_state. add_error_on_field (
728+ mas_templates:: UpstreamRegisterFormField :: AcceptTerms ,
729+ FieldError :: Required ,
730+ ) ;
731+ }
729732
730- if existing_user. is_some ( ) || !is_available {
731- // If there is an existing user, we can't create a new one
732- // with the same username, show an error
733+ // Policy check
734+ let res = policy
735+ . evaluate_register ( mas_policy:: RegisterInput {
736+ registration_method : mas_policy:: RegistrationMethod :: UpstreamOAuth2 ,
737+ username : & username,
738+ email : email. as_deref ( ) ,
739+ requester : mas_policy:: Requester {
740+ ip_address : activity_tracker. ip ( ) ,
741+ user_agent : user_agent. clone ( ) . map ( |ua| ua. raw ) ,
742+ } ,
743+ } )
744+ . await ?;
733745
734- let form_state = form_state. with_error_on_field (
735- mas_templates:: UpstreamRegisterFormField :: Username ,
736- FieldError :: Exists ,
737- ) ;
746+ for violation in res. violations {
747+ match violation. field . as_deref ( ) {
748+ Some ( "username" ) => {
749+ // If the homeserver denied the username, but we also had an error on
750+ // the policy side, we don't want to show
751+ // both, so we reset the state here
752+ homeserver_denied_username = false ;
753+ form_state. add_error_on_field (
754+ mas_templates:: UpstreamRegisterFormField :: Username ,
755+ FieldError :: Policy {
756+ code : violation. code . map ( |c| c. as_str ( ) ) ,
757+ message : violation. msg ,
758+ } ,
759+ ) ;
760+ }
761+ _ => form_state. add_error_on_form ( FormError :: Policy {
762+ code : violation. code . map ( |c| c. as_str ( ) ) ,
763+ message : violation. msg ,
764+ } ) ,
765+ }
766+ }
738767
739- let ctx = ctx
740- . with_form_state ( form_state)
741- . with_csrf ( csrf_token. form_value ( ) )
742- . with_language ( locale) ;
743- return Ok ( (
744- cookie_jar,
745- Html ( templates. render_upstream_oauth2_do_register ( & ctx) ?) ,
746- )
747- . into_response ( ) ) ;
748- }
768+ if homeserver_denied_username {
769+ // XXX: we may want to return different errors like "this username is reserved"
770+ form_state. add_error_on_field (
771+ mas_templates:: UpstreamRegisterFormField :: Username ,
772+ FieldError :: Exists ,
773+ ) ;
774+ }
749775
750- // If we need have a TOS in the config, make sure the user has accepted it
751- if site_config. tos_uri . is_some ( ) && !accept_terms {
752- let form_state = form_state. with_error_on_field (
753- mas_templates:: UpstreamRegisterFormField :: AcceptTerms ,
754- FieldError :: Required ,
755- ) ;
776+ form_state
777+ } ;
756778
779+ if !form_state. is_valid ( ) {
757780 let ctx = ctx
758781 . with_form_state ( form_state)
759782 . with_csrf ( csrf_token. form_value ( ) )
760783 . with_language ( locale) ;
761- return Ok ( (
762- cookie_jar,
763- Html ( templates. render_upstream_oauth2_do_register ( & ctx) ?) ,
764- )
765- . into_response ( ) ) ;
766- }
767-
768- // Policy check
769- let res = policy
770- . evaluate_register ( mas_policy:: RegisterInput {
771- registration_method : mas_policy:: RegistrationMethod :: UpstreamOAuth2 ,
772- username : & username,
773- email : email. as_deref ( ) ,
774- requester : mas_policy:: Requester {
775- ip_address : activity_tracker. ip ( ) ,
776- user_agent : user_agent. clone ( ) . map ( |ua| ua. raw ) ,
777- } ,
778- } )
779- . await ?;
780-
781- if !res. valid ( ) {
782- let form_state =
783- res. violations
784- . into_iter ( )
785- . fold ( form_state, |form_state, violation| {
786- match violation. field . as_deref ( ) {
787- Some ( "username" ) => form_state. with_error_on_field (
788- mas_templates:: UpstreamRegisterFormField :: Username ,
789- FieldError :: Policy {
790- code : violation. code . map ( |c| c. as_str ( ) ) ,
791- message : violation. msg ,
792- } ,
793- ) ,
794- _ => form_state. with_error_on_form ( FormError :: Policy {
795- code : violation. code . map ( |c| c. as_str ( ) ) ,
796- message : violation. msg ,
797- } ) ,
798- }
799- } ) ;
800784
801- let ctx = ctx
802- . with_form_state ( form_state)
803- . with_csrf ( csrf_token. form_value ( ) )
804- . with_language ( locale) ;
805785 return Ok ( (
806786 cookie_jar,
807787 Html ( templates. render_upstream_oauth2_do_register ( & ctx) ?) ,
0 commit comments