Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 85 additions & 105 deletions crates/handlers/src/upstream_oauth2/link.rs
Original file line number Diff line number Diff line change
Expand Up @@ -672,136 +672,116 @@ pub(crate) async fn post(
ctx
};

let forced_username = if provider.claims_imports.localpart.is_forced() {
let username = if provider.claims_imports.localpart.is_forced() {
let template = provider
.claims_imports
.localpart
.template
.as_deref()
.unwrap_or(DEFAULT_LOCALPART_TEMPLATE);

render_attribute_template(
&env,
template,
&context,
provider.claims_imports.email.is_required(),
)?
render_attribute_template(&env, template, &context, true)?
} else {
None
};

// If there is no forced username, we can use the one the user entered
let username = forced_username
.or(username)
.filter(|username| !username.is_empty());

let Some(username) = username else {
// We're missing a username, let's re-render the form with an error
let form_state = form_state.with_error_on_field(
mas_templates::UpstreamRegisterFormField::Username,
FieldError::Required,
);

let ctx = ctx
.with_form_state(form_state)
.with_csrf(csrf_token.form_value())
.with_language(locale);
return Ok((
cookie_jar,
Html(templates.render_upstream_oauth2_do_register(&ctx)?),
)
.into_response());
};
// If there is no forced username, we can use the one the user entered
username
}
.unwrap_or_default();

let ctx = ctx.with_localpart(
username.clone(),
provider.claims_imports.localpart.is_forced(),
);

// Check if there is an existing user
let existing_user = repo.user().find_by_username(&username).await?;
// Validate the form
let form_state = {
let mut form_state = form_state;
let mut homeserver_denied_username = false;
if username.is_empty() {
form_state.add_error_on_field(
mas_templates::UpstreamRegisterFormField::Username,
FieldError::Required,
);
} else if repo.user().exists(&username).await? {
form_state.add_error_on_field(
mas_templates::UpstreamRegisterFormField::Username,
FieldError::Exists,
);
} else if !homeserver
.is_localpart_available(&username)
.await
.map_err(RouteError::HomeserverConnection)?
{
// The user already exists on the homeserver
tracing::warn!(
%username,
"Homeserver denied username provided by user"
);

// We defer adding the error on the field, until we know whether we had another
// error from the policy, to avoid showing both
homeserver_denied_username = true;
}

// Ask the homeserver to make sure the username is valid
let is_available = homeserver
.is_localpart_available(&username)
.await
.map_err(RouteError::HomeserverConnection)?;
// If we have a TOS in the config, make sure the user has accepted it
if site_config.tos_uri.is_some() && !accept_terms {
form_state.add_error_on_field(
mas_templates::UpstreamRegisterFormField::AcceptTerms,
FieldError::Required,
);
}

if existing_user.is_some() || !is_available {
// If there is an existing user, we can't create a new one
// with the same username, show an error
// Policy check
let res = policy
.evaluate_register(mas_policy::RegisterInput {
registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2,
username: &username,
email: email.as_deref(),
requester: mas_policy::Requester {
ip_address: activity_tracker.ip(),
user_agent: user_agent.clone().map(|ua| ua.raw),
},
})
.await?;

let form_state = form_state.with_error_on_field(
mas_templates::UpstreamRegisterFormField::Username,
FieldError::Exists,
);
for violation in res.violations {
match violation.field.as_deref() {
Some("username") => {
// If the homeserver denied the username, but we also had an error on
// the policy side, we don't want to show
// both, so we reset the state here
homeserver_denied_username = false;
form_state.add_error_on_field(
mas_templates::UpstreamRegisterFormField::Username,
FieldError::Policy {
code: violation.code.map(|c| c.as_str()),
message: violation.msg,
},
);
}
_ => form_state.add_error_on_form(FormError::Policy {
code: violation.code.map(|c| c.as_str()),
message: violation.msg,
}),
}
}

let ctx = ctx
.with_form_state(form_state)
.with_csrf(csrf_token.form_value())
.with_language(locale);
return Ok((
cookie_jar,
Html(templates.render_upstream_oauth2_do_register(&ctx)?),
)
.into_response());
}
if homeserver_denied_username {
// XXX: we may want to return different errors like "this username is reserved"
form_state.add_error_on_field(
mas_templates::UpstreamRegisterFormField::Username,
FieldError::Exists,
);
}

// If we need have a TOS in the config, make sure the user has accepted it
if site_config.tos_uri.is_some() && !accept_terms {
let form_state = form_state.with_error_on_field(
mas_templates::UpstreamRegisterFormField::AcceptTerms,
FieldError::Required,
);
form_state
};

if !form_state.is_valid() {
let ctx = ctx
.with_form_state(form_state)
.with_csrf(csrf_token.form_value())
.with_language(locale);
return Ok((
cookie_jar,
Html(templates.render_upstream_oauth2_do_register(&ctx)?),
)
.into_response());
}

// Policy check
let res = policy
.evaluate_register(mas_policy::RegisterInput {
registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2,
username: &username,
email: email.as_deref(),
requester: mas_policy::Requester {
ip_address: activity_tracker.ip(),
user_agent: user_agent.clone().map(|ua| ua.raw),
},
})
.await?;

if !res.valid() {
let form_state =
res.violations
.into_iter()
.fold(form_state, |form_state, violation| {
match violation.field.as_deref() {
Some("username") => form_state.with_error_on_field(
mas_templates::UpstreamRegisterFormField::Username,
FieldError::Policy {
code: violation.code.map(|c| c.as_str()),
message: violation.msg,
},
),
_ => form_state.with_error_on_form(FormError::Policy {
code: violation.code.map(|c| c.as_str()),
message: violation.msg,
}),
}
});

let ctx = ctx
.with_form_state(form_state)
.with_csrf(csrf_token.form_value())
.with_language(locale);
return Ok((
cookie_jar,
Html(templates.render_upstream_oauth2_do_register(&ctx)?),
Expand Down
Loading