diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index f8631112e..e59a7514d 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -672,7 +672,7 @@ pub(crate) async fn post( ctx }; - let forced_username = if provider.claims_imports.localpart.is_forced() { + let username = if provider.claims_imports.localpart.is_forced() { let template = provider .claims_imports .localpart @@ -680,128 +680,108 @@ pub(crate) async fn post( .as_deref() .unwrap_or(DEFAULT_LOCALPART_TEMPLATE); - render_attribute_template( - &env, - template, - &context, - provider.claims_imports.email.is_required(), - )? + render_attribute_template(&env, template, &context, true)? } else { - None - }; - - // If there is no forced username, we can use the one the user entered - let username = forced_username - .or(username) - .filter(|username| !username.is_empty()); - - let Some(username) = username else { - // We're missing a username, let's re-render the form with an error - let form_state = form_state.with_error_on_field( - mas_templates::UpstreamRegisterFormField::Username, - FieldError::Required, - ); - - let ctx = ctx - .with_form_state(form_state) - .with_csrf(csrf_token.form_value()) - .with_language(locale); - return Ok(( - cookie_jar, - Html(templates.render_upstream_oauth2_do_register(&ctx)?), - ) - .into_response()); - }; + // If there is no forced username, we can use the one the user entered + username + } + .unwrap_or_default(); let ctx = ctx.with_localpart( username.clone(), provider.claims_imports.localpart.is_forced(), ); - // Check if there is an existing user - let existing_user = repo.user().find_by_username(&username).await?; + // Validate the form + let form_state = { + let mut form_state = form_state; + let mut homeserver_denied_username = false; + if username.is_empty() { + form_state.add_error_on_field( + mas_templates::UpstreamRegisterFormField::Username, + FieldError::Required, + ); + } else if repo.user().exists(&username).await? { + form_state.add_error_on_field( + mas_templates::UpstreamRegisterFormField::Username, + FieldError::Exists, + ); + } else if !homeserver + .is_localpart_available(&username) + .await + .map_err(RouteError::HomeserverConnection)? + { + // The user already exists on the homeserver + tracing::warn!( + %username, + "Homeserver denied username provided by user" + ); + + // We defer adding the error on the field, until we know whether we had another + // error from the policy, to avoid showing both + homeserver_denied_username = true; + } - // Ask the homeserver to make sure the username is valid - let is_available = homeserver - .is_localpart_available(&username) - .await - .map_err(RouteError::HomeserverConnection)?; + // If we have a TOS in the config, make sure the user has accepted it + if site_config.tos_uri.is_some() && !accept_terms { + form_state.add_error_on_field( + mas_templates::UpstreamRegisterFormField::AcceptTerms, + FieldError::Required, + ); + } - if existing_user.is_some() || !is_available { - // If there is an existing user, we can't create a new one - // with the same username, show an error + // Policy check + let res = policy + .evaluate_register(mas_policy::RegisterInput { + registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2, + username: &username, + email: email.as_deref(), + requester: mas_policy::Requester { + ip_address: activity_tracker.ip(), + user_agent: user_agent.clone().map(|ua| ua.raw), + }, + }) + .await?; - let form_state = form_state.with_error_on_field( - mas_templates::UpstreamRegisterFormField::Username, - FieldError::Exists, - ); + for violation in res.violations { + match violation.field.as_deref() { + Some("username") => { + // If the homeserver denied the username, but we also had an error on + // the policy side, we don't want to show + // both, so we reset the state here + homeserver_denied_username = false; + form_state.add_error_on_field( + mas_templates::UpstreamRegisterFormField::Username, + FieldError::Policy { + code: violation.code.map(|c| c.as_str()), + message: violation.msg, + }, + ); + } + _ => form_state.add_error_on_form(FormError::Policy { + code: violation.code.map(|c| c.as_str()), + message: violation.msg, + }), + } + } - let ctx = ctx - .with_form_state(form_state) - .with_csrf(csrf_token.form_value()) - .with_language(locale); - return Ok(( - cookie_jar, - Html(templates.render_upstream_oauth2_do_register(&ctx)?), - ) - .into_response()); - } + if homeserver_denied_username { + // XXX: we may want to return different errors like "this username is reserved" + form_state.add_error_on_field( + mas_templates::UpstreamRegisterFormField::Username, + FieldError::Exists, + ); + } - // If we need have a TOS in the config, make sure the user has accepted it - if site_config.tos_uri.is_some() && !accept_terms { - let form_state = form_state.with_error_on_field( - mas_templates::UpstreamRegisterFormField::AcceptTerms, - FieldError::Required, - ); + form_state + }; + if !form_state.is_valid() { let ctx = ctx .with_form_state(form_state) .with_csrf(csrf_token.form_value()) .with_language(locale); - return Ok(( - cookie_jar, - Html(templates.render_upstream_oauth2_do_register(&ctx)?), - ) - .into_response()); - } - - // Policy check - let res = policy - .evaluate_register(mas_policy::RegisterInput { - registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2, - username: &username, - email: email.as_deref(), - requester: mas_policy::Requester { - ip_address: activity_tracker.ip(), - user_agent: user_agent.clone().map(|ua| ua.raw), - }, - }) - .await?; - - if !res.valid() { - let form_state = - res.violations - .into_iter() - .fold(form_state, |form_state, violation| { - match violation.field.as_deref() { - Some("username") => form_state.with_error_on_field( - mas_templates::UpstreamRegisterFormField::Username, - FieldError::Policy { - code: violation.code.map(|c| c.as_str()), - message: violation.msg, - }, - ), - _ => form_state.with_error_on_form(FormError::Policy { - code: violation.code.map(|c| c.as_str()), - message: violation.msg, - }), - } - }); - let ctx = ctx - .with_form_state(form_state) - .with_csrf(csrf_token.form_value()) - .with_language(locale); return Ok(( cookie_jar, Html(templates.render_upstream_oauth2_do_register(&ctx)?),