-
Notifications
You must be signed in to change notification settings - Fork 67
Unify registrations for local passwords and upstream OAuth registrations #5281
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
f7c8a28
ac4f669
fe362d4
e712c23
1e69ea8
5fb37d2
4c3d2ba
61ee8da
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| -- Copyright 2025 Element Creations Ltd. | ||
| -- | ||
| -- SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial | ||
| -- Please see LICENSE in the repository root for full details. | ||
|
|
||
| -- Track what upstream OAuth session to associate during user registration | ||
| ALTER TABLE user_registrations | ||
| ADD COLUMN upstream_oauth_authorization_session_id UUID | ||
| REFERENCES upstream_oauth_authorization_sessions (upstream_oauth_authorization_session_id) | ||
| ON DELETE SET NULL; | ||
|
|
||
| CREATE INDEX user_registrations_upstream_oauth_session_id_idx | ||
| ON user_registrations (upstream_oauth_authorization_session_id); | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,8 +8,8 @@ use std::net::IpAddr; | |
| use async_trait::async_trait; | ||
| use chrono::{DateTime, Utc}; | ||
| use mas_data_model::{ | ||
| Clock, UserEmailAuthentication, UserRegistration, UserRegistrationPassword, | ||
| UserRegistrationToken, | ||
| Clock, UpstreamOAuthAuthorizationSession, UserEmailAuthentication, UserRegistration, | ||
| UserRegistrationPassword, UserRegistrationToken, | ||
| }; | ||
| use mas_storage::user::UserRegistrationRepository; | ||
| use rand::RngCore; | ||
|
|
@@ -46,6 +46,7 @@ struct UserRegistrationLookup { | |
| user_registration_token_id: Option<Uuid>, | ||
| hashed_password: Option<String>, | ||
| hashed_password_version: Option<i32>, | ||
| upstream_oauth_authorization_session_id: Option<Uuid>, | ||
| created_at: DateTime<Utc>, | ||
| completed_at: Option<DateTime<Utc>>, | ||
| } | ||
|
|
@@ -100,6 +101,9 @@ impl TryFrom<UserRegistrationLookup> for UserRegistration { | |
| email_authentication_id: value.email_authentication_id.map(Ulid::from), | ||
| user_registration_token_id: value.user_registration_token_id.map(Ulid::from), | ||
| password, | ||
| upstream_oauth_authorization_session_id: value | ||
| .upstream_oauth_authorization_session_id | ||
| .map(Ulid::from), | ||
| created_at: value.created_at, | ||
| completed_at: value.completed_at, | ||
| }) | ||
|
|
@@ -134,6 +138,7 @@ impl UserRegistrationRepository for PgUserRegistrationRepository<'_> { | |
| , user_registration_token_id | ||
| , hashed_password | ||
| , hashed_password_version | ||
| , upstream_oauth_authorization_session_id | ||
| , created_at | ||
| , completed_at | ||
| FROM user_registrations | ||
|
|
@@ -208,6 +213,7 @@ impl UserRegistrationRepository for PgUserRegistrationRepository<'_> { | |
| email_authentication_id: None, | ||
| user_registration_token_id: None, | ||
| password: None, | ||
| upstream_oauth_authorization_session_id: None, | ||
| }) | ||
| } | ||
|
|
||
|
|
@@ -393,6 +399,42 @@ impl UserRegistrationRepository for PgUserRegistrationRepository<'_> { | |
| Ok(user_registration) | ||
| } | ||
|
|
||
| #[tracing::instrument( | ||
| name = "db.user_registration.set_upstream_oauth_authorization_session", | ||
| skip_all, | ||
| fields( | ||
| db.query.text, | ||
| %user_registration.id, | ||
| %upstream_oauth_authorization_session.id, | ||
| ), | ||
| err, | ||
| )] | ||
| async fn set_upstream_oauth_authorization_session( | ||
| &mut self, | ||
| mut user_registration: UserRegistration, | ||
| upstream_oauth_authorization_session: &UpstreamOAuthAuthorizationSession, | ||
| ) -> Result<UserRegistration, Self::Error> { | ||
| let res = sqlx::query!( | ||
| r#" | ||
| UPDATE user_registrations | ||
| SET upstream_oauth_authorization_session_id = $2 | ||
| WHERE user_registration_id = $1 AND completed_at IS NULL | ||
| "#, | ||
| Uuid::from(user_registration.id), | ||
| Uuid::from(upstream_oauth_authorization_session.id), | ||
| ) | ||
| .traced() | ||
| .execute(&mut *self.conn) | ||
| .await?; | ||
|
|
||
| DatabaseError::ensure_affected_rows(&res, 1)?; | ||
|
|
||
| user_registration.upstream_oauth_authorization_session_id = | ||
| Some(upstream_oauth_authorization_session.id); | ||
|
|
||
| Ok(user_registration) | ||
| } | ||
|
|
||
| #[tracing::instrument( | ||
| name = "db.user_registration.complete", | ||
| skip_all, | ||
|
|
@@ -433,7 +475,14 @@ impl UserRegistrationRepository for PgUserRegistrationRepository<'_> { | |
| mod tests { | ||
| use std::net::{IpAddr, Ipv4Addr}; | ||
|
|
||
| use mas_data_model::{Clock, UserRegistrationPassword, clock::MockClock}; | ||
| use mas_data_model::{ | ||
| Clock, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode, | ||
| UpstreamOAuthProviderOnBackchannelLogout, UpstreamOAuthProviderPkceMode, | ||
| UpstreamOAuthProviderTokenAuthMethod, UserRegistrationPassword, clock::MockClock, | ||
| }; | ||
| use mas_iana::jose::JsonWebSignatureAlg; | ||
| use mas_storage::upstream_oauth2::UpstreamOAuthProviderParams; | ||
| use oauth2_types::scope::Scope; | ||
| use rand::SeedableRng; | ||
| use rand_chacha::ChaChaRng; | ||
| use sqlx::PgPool; | ||
|
|
@@ -851,4 +900,120 @@ mod tests { | |
| .await; | ||
| assert!(res.is_err()); | ||
| } | ||
|
|
||
| #[sqlx::test(migrator = "crate::MIGRATOR")] | ||
| async fn test_set_upstream_oauth_link(pool: PgPool) { | ||
|
||
| let mut rng = ChaChaRng::seed_from_u64(42); | ||
| let clock = MockClock::default(); | ||
|
|
||
| let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed(); | ||
|
|
||
| let registration = repo | ||
| .user_registration() | ||
| .add(&mut rng, &clock, "alice".to_owned(), None, None, None) | ||
| .await | ||
| .unwrap(); | ||
|
|
||
| assert_eq!(registration.upstream_oauth_authorization_session_id, None); | ||
|
|
||
| let provider = repo | ||
| .upstream_oauth_provider() | ||
| .add( | ||
| &mut rng, | ||
| &clock, | ||
| UpstreamOAuthProviderParams { | ||
| issuer: Some("https://example.com/".to_owned()), | ||
| human_name: Some("Example Ltd.".to_owned()), | ||
| brand_name: None, | ||
| scope: Scope::from_iter([oauth2_types::scope::OPENID]), | ||
| token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None, | ||
| token_endpoint_signing_alg: None, | ||
| id_token_signed_response_alg: JsonWebSignatureAlg::Rs256, | ||
| client_id: "client".to_owned(), | ||
| encrypted_client_secret: None, | ||
| claims_imports: UpstreamOAuthProviderClaimsImports::default(), | ||
| authorization_endpoint_override: None, | ||
| token_endpoint_override: None, | ||
| userinfo_endpoint_override: None, | ||
| fetch_userinfo: false, | ||
| userinfo_signed_response_alg: None, | ||
| jwks_uri_override: None, | ||
| discovery_mode: UpstreamOAuthProviderDiscoveryMode::Oidc, | ||
| pkce_mode: UpstreamOAuthProviderPkceMode::Auto, | ||
| response_mode: None, | ||
| additional_authorization_parameters: Vec::new(), | ||
| forward_login_hint: false, | ||
| ui_order: 0, | ||
| on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout::DoNothing, | ||
| }, | ||
| ) | ||
| .await | ||
| .unwrap(); | ||
|
|
||
| let session = repo | ||
| .upstream_oauth_session() | ||
| .add(&mut rng, &clock, &provider, "state".to_owned(), None, None) | ||
| .await | ||
| .unwrap(); | ||
|
|
||
| let registration = repo | ||
| .user_registration() | ||
| .set_upstream_oauth_authorization_session(registration, &session) | ||
| .await | ||
| .unwrap(); | ||
|
|
||
| assert_eq!( | ||
| registration.upstream_oauth_authorization_session_id, | ||
| Some(session.id) | ||
| ); | ||
|
|
||
| let lookup = repo | ||
| .user_registration() | ||
| .lookup(registration.id) | ||
| .await | ||
| .unwrap() | ||
| .unwrap(); | ||
|
|
||
| assert_eq!( | ||
| lookup.upstream_oauth_authorization_session_id, | ||
| registration.upstream_oauth_authorization_session_id | ||
| ); | ||
|
|
||
| // Setting it again should work | ||
| let registration = repo | ||
| .user_registration() | ||
| .set_upstream_oauth_authorization_session(registration, &session) | ||
| .await | ||
| .unwrap(); | ||
|
|
||
| assert_eq!( | ||
| registration.upstream_oauth_authorization_session_id, | ||
| Some(session.id) | ||
| ); | ||
|
|
||
| let lookup = repo | ||
| .user_registration() | ||
| .lookup(registration.id) | ||
| .await | ||
| .unwrap() | ||
| .unwrap(); | ||
|
|
||
| assert_eq!( | ||
| lookup.upstream_oauth_authorization_session_id, | ||
| registration.upstream_oauth_authorization_session_id | ||
| ); | ||
|
|
||
| // Can't set it once completed | ||
| let registration = repo | ||
| .user_registration() | ||
| .complete(&clock, registration) | ||
| .await | ||
| .unwrap(); | ||
|
|
||
| let res = repo | ||
| .user_registration() | ||
| .set_upstream_oauth_authorization_session(registration, &session) | ||
| .await; | ||
| assert!(res.is_err()); | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is
user_registrationssmall enough that we're happy to do this non-CONCURRENTLY?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might, but that's a good point! Fixed in 4c3d2ba