Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/data-model/src/users.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ pub struct UserRegistration {
pub email_authentication_id: Option<Ulid>,
pub user_registration_token_id: Option<Ulid>,
pub password: Option<UserRegistrationPassword>,
pub upstream_oauth_authorization_session_id: Option<Ulid>,
pub post_auth_action: Option<serde_json::Value>,
pub ip_address: Option<IpAddr>,
pub user_agent: Option<String>,
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
-- Copyright 2025 Element Creations Ltd.
--
-- SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
-- Please see LICENSE in the repository root for full details.

-- Track what upstream OAuth session to associate during user registration
ALTER TABLE user_registrations
ADD COLUMN upstream_oauth_authorization_session_id UUID
REFERENCES upstream_oauth_authorization_sessions (upstream_oauth_authorization_session_id)
ON DELETE SET NULL;

CREATE INDEX user_registrations_upstream_oauth_session_id_idx
ON user_registrations (upstream_oauth_authorization_session_id);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is user_registrations small enough that we're happy to do this non-CONCURRENTLY?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might, but that's a good point! Fixed in 4c3d2ba

171 changes: 168 additions & 3 deletions crates/storage-pg/src/user/registration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ use std::net::IpAddr;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{
Clock, UserEmailAuthentication, UserRegistration, UserRegistrationPassword,
UserRegistrationToken,
Clock, UpstreamOAuthAuthorizationSession, UserEmailAuthentication, UserRegistration,
UserRegistrationPassword, UserRegistrationToken,
};
use mas_storage::user::UserRegistrationRepository;
use rand::RngCore;
Expand Down Expand Up @@ -46,6 +46,7 @@ struct UserRegistrationLookup {
user_registration_token_id: Option<Uuid>,
hashed_password: Option<String>,
hashed_password_version: Option<i32>,
upstream_oauth_authorization_session_id: Option<Uuid>,
created_at: DateTime<Utc>,
completed_at: Option<DateTime<Utc>>,
}
Expand Down Expand Up @@ -100,6 +101,9 @@ impl TryFrom<UserRegistrationLookup> for UserRegistration {
email_authentication_id: value.email_authentication_id.map(Ulid::from),
user_registration_token_id: value.user_registration_token_id.map(Ulid::from),
password,
upstream_oauth_authorization_session_id: value
.upstream_oauth_authorization_session_id
.map(Ulid::from),
created_at: value.created_at,
completed_at: value.completed_at,
})
Expand Down Expand Up @@ -134,6 +138,7 @@ impl UserRegistrationRepository for PgUserRegistrationRepository<'_> {
, user_registration_token_id
, hashed_password
, hashed_password_version
, upstream_oauth_authorization_session_id
, created_at
, completed_at
FROM user_registrations
Expand Down Expand Up @@ -208,6 +213,7 @@ impl UserRegistrationRepository for PgUserRegistrationRepository<'_> {
email_authentication_id: None,
user_registration_token_id: None,
password: None,
upstream_oauth_authorization_session_id: None,
})
}

Expand Down Expand Up @@ -393,6 +399,42 @@ impl UserRegistrationRepository for PgUserRegistrationRepository<'_> {
Ok(user_registration)
}

#[tracing::instrument(
name = "db.user_registration.set_upstream_oauth_authorization_session",
skip_all,
fields(
db.query.text,
%user_registration.id,
%upstream_oauth_authorization_session.id,
),
err,
)]
async fn set_upstream_oauth_authorization_session(
&mut self,
mut user_registration: UserRegistration,
upstream_oauth_authorization_session: &UpstreamOAuthAuthorizationSession,
) -> Result<UserRegistration, Self::Error> {
let res = sqlx::query!(
r#"
UPDATE user_registrations
SET upstream_oauth_authorization_session_id = $2
WHERE user_registration_id = $1 AND completed_at IS NULL
"#,
Uuid::from(user_registration.id),
Uuid::from(upstream_oauth_authorization_session.id),
)
.traced()
.execute(&mut *self.conn)
.await?;

DatabaseError::ensure_affected_rows(&res, 1)?;

user_registration.upstream_oauth_authorization_session_id =
Some(upstream_oauth_authorization_session.id);

Ok(user_registration)
}

#[tracing::instrument(
name = "db.user_registration.complete",
skip_all,
Expand Down Expand Up @@ -433,7 +475,14 @@ impl UserRegistrationRepository for PgUserRegistrationRepository<'_> {
mod tests {
use std::net::{IpAddr, Ipv4Addr};

use mas_data_model::{Clock, UserRegistrationPassword, clock::MockClock};
use mas_data_model::{
Clock, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
UpstreamOAuthProviderOnBackchannelLogout, UpstreamOAuthProviderPkceMode,
UpstreamOAuthProviderTokenAuthMethod, UserRegistrationPassword, clock::MockClock,
};
use mas_iana::jose::JsonWebSignatureAlg;
use mas_storage::upstream_oauth2::UpstreamOAuthProviderParams;
use oauth2_types::scope::Scope;
use rand::SeedableRng;
use rand_chacha::ChaChaRng;
use sqlx::PgPool;
Expand Down Expand Up @@ -851,4 +900,120 @@ mod tests {
.await;
assert!(res.is_err());
}

#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_set_upstream_oauth_link(pool: PgPool) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe _link is a slight misnomer / not consistent with the other methods

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a leftover for when it was the link that was attached to registrations, not the session! Fixed in 61ee8da

let mut rng = ChaChaRng::seed_from_u64(42);
let clock = MockClock::default();

let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

let registration = repo
.user_registration()
.add(&mut rng, &clock, "alice".to_owned(), None, None, None)
.await
.unwrap();

assert_eq!(registration.upstream_oauth_authorization_session_id, None);

let provider = repo
.upstream_oauth_provider()
.add(
&mut rng,
&clock,
UpstreamOAuthProviderParams {
issuer: Some("https://example.com/".to_owned()),
human_name: Some("Example Ltd.".to_owned()),
brand_name: None,
scope: Scope::from_iter([oauth2_types::scope::OPENID]),
token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
token_endpoint_signing_alg: None,
id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
client_id: "client".to_owned(),
encrypted_client_secret: None,
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
authorization_endpoint_override: None,
token_endpoint_override: None,
userinfo_endpoint_override: None,
fetch_userinfo: false,
userinfo_signed_response_alg: None,
jwks_uri_override: None,
discovery_mode: UpstreamOAuthProviderDiscoveryMode::Oidc,
pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
response_mode: None,
additional_authorization_parameters: Vec::new(),
forward_login_hint: false,
ui_order: 0,
on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout::DoNothing,
},
)
.await
.unwrap();

let session = repo
.upstream_oauth_session()
.add(&mut rng, &clock, &provider, "state".to_owned(), None, None)
.await
.unwrap();

let registration = repo
.user_registration()
.set_upstream_oauth_authorization_session(registration, &session)
.await
.unwrap();

assert_eq!(
registration.upstream_oauth_authorization_session_id,
Some(session.id)
);

let lookup = repo
.user_registration()
.lookup(registration.id)
.await
.unwrap()
.unwrap();

assert_eq!(
lookup.upstream_oauth_authorization_session_id,
registration.upstream_oauth_authorization_session_id
);

// Setting it again should work
let registration = repo
.user_registration()
.set_upstream_oauth_authorization_session(registration, &session)
.await
.unwrap();

assert_eq!(
registration.upstream_oauth_authorization_session_id,
Some(session.id)
);

let lookup = repo
.user_registration()
.lookup(registration.id)
.await
.unwrap()
.unwrap();

assert_eq!(
lookup.upstream_oauth_authorization_session_id,
registration.upstream_oauth_authorization_session_id
);

// Can't set it once completed
let registration = repo
.user_registration()
.complete(&clock, registration)
.await
.unwrap();

let res = repo
.user_registration()
.set_upstream_oauth_authorization_session(registration, &session)
.await;
assert!(res.is_err());
}
}