Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/data-model/src/users.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ pub struct UserRegistration {
pub email_authentication_id: Option<Ulid>,
pub user_registration_token_id: Option<Ulid>,
pub password: Option<UserRegistrationPassword>,
pub upstream_oauth_authorization_session_id: Option<Ulid>,
pub post_auth_action: Option<serde_json::Value>,
pub ip_address: Option<IpAddr>,
pub user_agent: Option<String>,
Expand Down
2 changes: 1 addition & 1 deletion crates/handlers/src/graphql/mutations/user_email.rs
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ impl UserEmailMutations {

let authentication = repo
.user_email()
.complete_authentication(&clock, authentication, &code)
.complete_authentication_with_code(&clock, authentication, &code)
.await?;

// Check the email is not already in use by anyone, including the current user
Expand Down
208 changes: 138 additions & 70 deletions crates/handlers/src/upstream_oauth2/link.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ use mas_policy::Policy;
use mas_router::UrlBuilder;
use mas_storage::{
BoxRepository, RepositoryAccess,
queue::{ProvisionUserJob, QueueJobRepositoryExt as _},
upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository},
user::{BrowserSessionRepository, UserEmailRepository, UserRepository},
};
Expand All @@ -46,7 +45,7 @@ use super::{
};
use crate::{
BoundActivityTracker, METER, PreferredLanguage, SiteConfig, impl_from_error_for_route,
views::shared::OptionalPostAuthAction,
views::{register::UserRegistrationSessionsCookie, shared::OptionalPostAuthAction},
};

static LOGIN_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
Expand Down Expand Up @@ -610,10 +609,6 @@ pub(crate) async fn post(
.lookup_link(link_id)
.map_err(|_| RouteError::MissingCookie)?;

let post_auth_action = OptionalPostAuthAction {
post_auth_action: post_auth_action.cloned(),
};

let link = repo
.upstream_oauth_link()
.lookup(link_id)
Expand Down Expand Up @@ -641,15 +636,35 @@ pub(crate) async fn post(
let maybe_user_session = user_session_info.load_active_session(&mut repo).await?;
let form_state = form.to_form_state();

let session = match (maybe_user_session, link.user_id, form) {
match (maybe_user_session, link.user_id, form) {
(Some(session), None, FormData::Link) => {
// The user is already logged in, the link is not linked to any user, and the
// user asked to link their account.
repo.upstream_oauth_link()
.associate_to_user(&link, &session.user)
.await?;

session
let upstream_session = repo
.upstream_oauth_session()
.consume(&clock, upstream_session)
.await?;

repo.browser_session()
.authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
.await?;

let post_auth_action = OptionalPostAuthAction {
post_auth_action: post_auth_action.cloned(),
};

let cookie_jar = sessions_cookie
.consume_link(link_id)?
.save(cookie_jar, &clock);
let cookie_jar = cookie_jar.set_session(&session);

repo.save().await?;

Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response())
}

(None, None, FormData::Link) => {
Expand Down Expand Up @@ -714,14 +729,38 @@ pub(crate) async fn post(
return Err(RouteError::InvalidFormAction);
}
UpstreamOAuthProviderOnConflict::Add => {
//add link to the user
// Add link to the user
repo.upstream_oauth_link()
.associate_to_user(&link, &user)
.await?;

repo.browser_session()
// And sign in the user
let session = repo
.browser_session()
.add(&mut rng, &clock, &user, user_agent)
.await?
.await?;

let upstream_session = repo
.upstream_oauth_session()
.consume(&clock, upstream_session)
.await?;

repo.browser_session()
.authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
.await?;

let post_auth_action = OptionalPostAuthAction {
post_auth_action: post_auth_action.cloned(),
};

let cookie_jar = sessions_cookie
.consume_link(link_id)?
.save(cookie_jar, &clock);
let cookie_jar = cookie_jar.set_session(&session);

repo.save().await?;

Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response())
}
}
}
Expand Down Expand Up @@ -950,61 +989,84 @@ pub(crate) async fn post(

REGISTRATION_COUNTER.add(1, &[KeyValue::new(PROVIDER, provider.id.to_string())]);

// Now we can create the user
let user = repo.user().add(&mut rng, &clock, username).await?;
let mut registration = repo
.user_registration()
.add(
&mut rng,
&clock,
username,
activity_tracker.ip(),
user_agent,
post_auth_action.map(|action| serde_json::json!(action)),
)
.await?;

if let Some(terms_url) = &site_config.tos_uri {
repo.user_terms()
.accept_terms(&mut rng, &clock, &user, terms_url.clone())
registration = repo
.user_registration()
.set_terms_url(registration, terms_url.clone())
.await?;
}

// And schedule the job to provision it
let mut job = ProvisionUserJob::new(&user);
// If we have an email, add an email authentication and complete it
if let Some(email) = email {
let authentication = repo
.user_email()
.add_authentication_for_registration(&mut rng, &clock, email, &registration)
.await?;
let authentication = repo
.user_email()
.complete_authentication_with_upstream(
&clock,
authentication,
&upstream_session,
)
.await?;

// If we have a display name, set it during provisioning
if let Some(name) = display_name {
job = job.set_display_name(name);
registration = repo
.user_registration()
.set_email_authentication(registration, &authentication)
.await?;
}

repo.queue_job().schedule_job(&mut rng, &clock, job).await?;

// If we have an email, add it to the user
if let Some(email) = email {
repo.user_email()
.add(&mut rng, &clock, &user, email)
// If we have a display name, add it to the registration
if let Some(name) = display_name {
registration = repo
.user_registration()
.set_display_name(registration, name)
.await?;
}

repo.upstream_oauth_link()
.associate_to_user(&link, &user)
let registration = repo
.user_registration()
.set_upstream_oauth_authorization_session(registration, &upstream_session)
.await?;

repo.browser_session()
.add(&mut rng, &clock, &user, user_agent)
.await?
}
repo.upstream_oauth_session()
.consume(&clock, upstream_session)
.await?;

_ => return Err(RouteError::InvalidFormAction),
};
let registrations = UserRegistrationSessionsCookie::load(&cookie_jar);

let upstream_session = repo
.upstream_oauth_session()
.consume(&clock, upstream_session)
.await?;
let cookie_jar = sessions_cookie
.consume_link(link_id)?
.save(cookie_jar, &clock);

repo.browser_session()
.authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
.await?;
let cookie_jar = registrations.add(&registration).save(cookie_jar, &clock);

let cookie_jar = sessions_cookie
.consume_link(link_id)?
.save(cookie_jar, &clock);
let cookie_jar = cookie_jar.set_session(&session);
repo.save().await?;

repo.save().await?;
// Redirect to the user registration flow, in case we have any other step to
// finish
Ok((
cookie_jar,
url_builder.redirect(&mas_router::RegisterFinish::new(registration.id)),
)
.into_response())
}

Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response())
_ => Err(RouteError::InvalidFormAction),
}
}

#[cfg(test)]
Expand All @@ -1013,20 +1075,18 @@ mod tests {
use mas_data_model::{
UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProviderClaimsImports,
UpstreamOAuthProviderImportPreference, UpstreamOAuthProviderLocalpartPreference,
UpstreamOAuthProviderTokenAuthMethod,
UpstreamOAuthProviderTokenAuthMethod, UserEmailAuthentication, UserRegistration,
};
use mas_iana::jose::JsonWebSignatureAlg;
use mas_jose::jwt::{JsonWebSignatureHeader, Jwt};
use mas_keystore::Keystore;
use mas_router::Route;
use mas_storage::{
Pagination, Repository, RepositoryError, upstream_oauth2::UpstreamOAuthProviderParams,
user::UserEmailFilter,
};
use mas_storage::{Repository, RepositoryError, upstream_oauth2::UpstreamOAuthProviderParams};
use oauth2_types::scope::{OPENID, Scope};
use rand_chacha::ChaChaRng;
use serde_json::Value;
use sqlx::PgPool;
use ulid::Ulid;

use super::UpstreamSessionsCookie;
use crate::test_utils::{CookieHelper, RequestBuilderExt, ResponseExt, TestState, setup};
Expand Down Expand Up @@ -1188,33 +1248,41 @@ mod tests {
let response = state.request(request).await;
cookies.save_cookies(&response);
response.assert_status(StatusCode::SEE_OTHER);
let location = response.headers().get(hyper::header::LOCATION).unwrap();
// Grab the registration ID from the redirected URL:
// /register/steps/{id}/finish
let registration_id: Ulid = str::from_utf8(location.as_bytes())
.unwrap()
.rsplit('/')
.nth(1)
.expect("Location to have two slashes")
.parse()
.expect("last segment of location to be a ULID");

// Check that we have a registered user, with the email imported
let mut repo = state.repository().await.unwrap();
let user = repo
.user()
.find_by_username("john")
.await
.unwrap()
.expect("user exists");

let link = repo
.upstream_oauth_link()
.find_by_subject(&provider, "subject")
let registration: UserRegistration = repo
.user_registration()
.lookup(registration_id)
.await
.unwrap()
.expect("link exists");
.expect("user registration exists");

assert_eq!(link.user_id, Some(user.id));
assert_eq!(registration.password, None);
assert_eq!(registration.completed_at, None);
assert_eq!(registration.username, "john");

let page = repo
let email_auth_id = registration
.email_authentication_id
.expect("registration should have an email authentication");
let email_auth: UserEmailAuthentication = repo
.user_email()
.list(UserEmailFilter::new().for_user(&user), Pagination::first(1))
.lookup_authentication(email_auth_id)
.await
.unwrap();
let edge = page.edges.first().expect("email exists");

assert_eq!(edge.node.email, "[email protected]");
.unwrap()
.expect("email authentication should exist");
assert_eq!(email_auth.email, "[email protected]");
assert!(email_auth.completed_at.is_some());
}

#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
Expand Down
2 changes: 2 additions & 0 deletions crates/handlers/src/views/register/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ mod cookie;
pub(crate) mod password;
pub(crate) mod steps;

pub use self::cookie::UserRegistrationSessions as UserRegistrationSessionsCookie;

#[tracing::instrument(name = "handlers.views.register.get", skip_all)]
pub(crate) async fn get(
mut rng: BoxRng,
Expand Down
Loading
Loading