diff --git a/crates/cli/src/sync.rs b/crates/cli/src/sync.rs index 363c2a0f8..ff33106af 100644 --- a/crates/cli/src/sync.rs +++ b/crates/cli/src/sync.rs @@ -37,6 +37,19 @@ fn map_import_action( } } +fn map_import_on_conflict( + config: mas_config::UpstreamOAuth2OnConflict, +) -> mas_data_model::UpstreamOAuthProviderOnConflict { + match config { + mas_config::UpstreamOAuth2OnConflict::Add => { + mas_data_model::UpstreamOAuthProviderOnConflict::Add + } + mas_config::UpstreamOAuth2OnConflict::Fail => { + mas_data_model::UpstreamOAuthProviderOnConflict::Fail + } + } +} + fn map_claims_imports( config: &mas_config::UpstreamOAuth2ClaimsImports, ) -> mas_data_model::UpstreamOAuthProviderClaimsImports { @@ -44,9 +57,10 @@ fn map_claims_imports( subject: mas_data_model::UpstreamOAuthProviderSubjectPreference { template: config.subject.template.clone(), }, - localpart: mas_data_model::UpstreamOAuthProviderImportPreference { + localpart: mas_data_model::UpstreamOAuthProviderLocalpartPreference { action: map_import_action(config.localpart.action), template: config.localpart.template.clone(), + on_conflict: map_import_on_conflict(config.localpart.on_conflict), }, displayname: mas_data_model::UpstreamOAuthProviderImportPreference { action: map_import_action(config.displayname.action), diff --git a/crates/config/src/sections/mod.rs b/crates/config/src/sections/mod.rs index ed38fa9b6..f992d8698 100644 --- a/crates/config/src/sections/mod.rs +++ b/crates/config/src/sections/mod.rs @@ -54,8 +54,8 @@ pub use self::{ EmailImportPreference as UpstreamOAuth2EmailImportPreference, ImportAction as UpstreamOAuth2ImportAction, OnBackchannelLogout as UpstreamOAuth2OnBackchannelLogout, - PkceMethod as UpstreamOAuth2PkceMethod, Provider as UpstreamOAuth2Provider, - ResponseMode as UpstreamOAuth2ResponseMode, + OnConflict as UpstreamOAuth2OnConflict, PkceMethod as UpstreamOAuth2PkceMethod, + Provider as UpstreamOAuth2Provider, ResponseMode as UpstreamOAuth2ResponseMode, TokenAuthMethod as UpstreamOAuth2TokenAuthMethod, UpstreamOAuth2Config, }, }; diff --git a/crates/config/src/sections/upstream_oauth2.rs b/crates/config/src/sections/upstream_oauth2.rs index 8d6229848..9b2768423 100644 --- a/crates/config/src/sections/upstream_oauth2.rs +++ b/crates/config/src/sections/upstream_oauth2.rs @@ -117,6 +117,18 @@ impl ConfigurationSection for UpstreamOAuth2Config { } } } + + if matches!( + provider.claims_imports.localpart.on_conflict, + OnConflict::Add + ) && !matches!( + provider.claims_imports.localpart.action, + ImportAction::Force | ImportAction::Require + ) { + return Err(annotate(figment::Error::custom( + "The field `action` must be either `force` or `require` when `on_conflict` is set to `add`", + )).into()); + } } Ok(()) @@ -190,6 +202,26 @@ impl ImportAction { } } +/// How to handle an existing localpart claim +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)] +#[serde(rename_all = "lowercase")] +pub enum OnConflict { + /// Fails the sso login on conflict + #[default] + Fail, + + /// Adds the oauth identity link, regardless of whether there is an existing + /// link or not + Add, +} + +impl OnConflict { + #[allow(clippy::trivially_copy_pass_by_ref)] + const fn is_default(&self) -> bool { + matches!(self, OnConflict::Fail) + } +} + /// What should be done for the subject attribute #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)] pub struct SubjectImportPreference { @@ -218,6 +250,10 @@ pub struct LocalpartImportPreference { /// If not provided, the default template is `{{ user.preferred_username }}` #[serde(default, skip_serializing_if = "Option::is_none")] pub template: Option, + + /// How to handle conflicts on the claim, default value is `Fail` + #[serde(default, skip_serializing_if = "OnConflict::is_default")] + pub on_conflict: OnConflict, } impl LocalpartImportPreference { diff --git a/crates/data-model/src/lib.rs b/crates/data-model/src/lib.rs index 1ed15adcc..f1b551891 100644 --- a/crates/data-model/src/lib.rs +++ b/crates/data-model/src/lib.rs @@ -42,7 +42,8 @@ pub use self::{ UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink, UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode, UpstreamOAuthProviderImportAction, - UpstreamOAuthProviderImportPreference, UpstreamOAuthProviderOnBackchannelLogout, + UpstreamOAuthProviderImportPreference, UpstreamOAuthProviderLocalpartPreference, + UpstreamOAuthProviderOnBackchannelLogout, UpstreamOAuthProviderOnConflict, UpstreamOAuthProviderPkceMode, UpstreamOAuthProviderResponseMode, UpstreamOAuthProviderSubjectPreference, UpstreamOAuthProviderTokenAuthMethod, }, diff --git a/crates/data-model/src/upstream_oauth2/mod.rs b/crates/data-model/src/upstream_oauth2/mod.rs index 8e2638b9b..563716568 100644 --- a/crates/data-model/src/upstream_oauth2/mod.rs +++ b/crates/data-model/src/upstream_oauth2/mod.rs @@ -15,8 +15,9 @@ pub use self::{ DiscoveryMode as UpstreamOAuthProviderDiscoveryMode, ImportAction as UpstreamOAuthProviderImportAction, ImportPreference as UpstreamOAuthProviderImportPreference, + LocalpartPreference as UpstreamOAuthProviderLocalpartPreference, OnBackchannelLogout as UpstreamOAuthProviderOnBackchannelLogout, - PkceMode as UpstreamOAuthProviderPkceMode, + OnConflict as UpstreamOAuthProviderOnConflict, PkceMode as UpstreamOAuthProviderPkceMode, ResponseMode as UpstreamOAuthProviderResponseMode, SubjectPreference as UpstreamOAuthProviderSubjectPreference, TokenAuthMethod as UpstreamOAuthProviderTokenAuthMethod, UpstreamOAuthProvider, diff --git a/crates/data-model/src/upstream_oauth2/provider.rs b/crates/data-model/src/upstream_oauth2/provider.rs index 3a71c03c3..c54e40d15 100644 --- a/crates/data-model/src/upstream_oauth2/provider.rs +++ b/crates/data-model/src/upstream_oauth2/provider.rs @@ -313,7 +313,7 @@ pub struct ClaimsImports { pub subject: SubjectPreference, #[serde(default)] - pub localpart: ImportPreference, + pub localpart: LocalpartPreference, #[serde(default)] pub displayname: ImportPreference, @@ -332,6 +332,26 @@ pub struct SubjectPreference { pub template: Option, } +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] +pub struct LocalpartPreference { + #[serde(default)] + pub action: ImportAction, + + #[serde(default)] + pub template: Option, + + #[serde(default)] + pub on_conflict: OnConflict, +} + +impl std::ops::Deref for LocalpartPreference { + type Target = ImportAction; + + fn deref(&self) -> &Self::Target { + &self.action + } +} + #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] pub struct ImportPreference { #[serde(default)] @@ -368,7 +388,7 @@ pub enum ImportAction { impl ImportAction { #[must_use] - pub fn is_forced(&self) -> bool { + pub fn is_forced_or_required(&self) -> bool { matches!(self, Self::Force | Self::Require) } @@ -391,3 +411,15 @@ impl ImportAction { } } } + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "lowercase")] +pub enum OnConflict { + /// Fails the upstream OAuth 2.0 login + #[default] + Fail, + + /// Adds the upstream account link, regardless of whether there is an + /// existing link or not + Add, +} diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 934af3626..0182f0538 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -19,6 +19,7 @@ use mas_axum_utils::{ csrf::{CsrfExt, ProtectedForm}, record_error, }; +use mas_data_model::UpstreamOAuthProviderOnConflict; use mas_jose::jwt::Jwt; use mas_matrix::HomeserverConnection; use mas_policy::Policy; @@ -37,7 +38,6 @@ use minijinja::Environment; use opentelemetry::{Key, KeyValue, metrics::Counter}; use serde::{Deserialize, Serialize}; use thiserror::Error; -use tracing::warn; use ulid::Ulid; use super::{ @@ -420,8 +420,10 @@ pub(crate) async fn get( &context, provider.claims_imports.displayname.is_required(), )? { - Some(value) => ctx - .with_display_name(value, provider.claims_imports.displayname.is_forced()), + Some(value) => ctx.with_display_name( + value, + provider.claims_imports.displayname.is_forced_or_required(), + ), None => ctx, } }; @@ -442,7 +444,9 @@ pub(crate) async fn get( &context, provider.claims_imports.email.is_required(), )? { - Some(value) => ctx.with_email(value, provider.claims_imports.email.is_forced()), + Some(value) => { + ctx.with_email(value, provider.claims_imports.email.is_forced_or_required()) + } None => ctx, } }; @@ -473,19 +477,49 @@ pub(crate) async fn get( .await .map_err(RouteError::HomeserverConnection)?; - if maybe_existing_user.is_some() || !is_available { - if let Some(existing_user) = maybe_existing_user { - // The mapper returned a username which already exists, but isn't - // linked to this upstream user. - warn!(username = %localpart, user_id = %existing_user.id, "Localpart template returned an existing username"); + if let Some(existing_user) = maybe_existing_user { + // The mapper returned a username which already exists, but isn't + // linked to this upstream user. + let on_conflict = provider.claims_imports.localpart.on_conflict; + + match on_conflict { + UpstreamOAuthProviderOnConflict::Fail => { + // TODO: translate + let ctx = ErrorContext::new() + .with_code("User exists") + .with_description(format!( + r"Upstream account provider returned {localpart:?} as username, + which is not linked to that upstream account. Your homeserver does not allow + linking an upstream account to an existing account" + )) + .with_language(&locale); + + return Ok(( + cookie_jar, + Html(templates.render_error(&ctx)?).into_response(), + )); + } + UpstreamOAuthProviderOnConflict::Add => { + // new oauth link is allowed + let ctx = UpstreamExistingLinkContext::new(existing_user) + .with_csrf(csrf_token.form_value()) + .with_language(locale); + + return Ok(( + cookie_jar, + Html(templates.render_upstream_oauth2_login_link(&ctx)?) + .into_response(), + )); + } } + } + if !is_available { // TODO: translate let ctx = ErrorContext::new() - .with_code("User exists") + .with_code("Localpart not available") .with_description(format!( - r"Upstream account provider returned {localpart:?} as username, - which is not linked to that upstream account" + r"Localpart {localpart:?} is not available on this homeserver" )) .with_language(&locale); @@ -511,9 +545,9 @@ pub(crate) async fn get( // The username passes the policy check, add it to the context ctx.with_localpart( localpart, - provider.claims_imports.localpart.is_forced(), + provider.claims_imports.localpart.is_forced_or_required(), ) - } else if provider.claims_imports.localpart.is_forced() { + } else if provider.claims_imports.localpart.is_forced_or_required() { // If the username claim is 'forced' but doesn't pass the policy check, // we display an error message. // TODO: translate @@ -618,6 +652,80 @@ pub(crate) async fn post( session } + (None, None, FormData::Link) => { + // There is an existing user with the same username, but no link. + // If the configuration allows it, the user is prompted to link the + // existing account. Note that we cannot trust the user input here, + // which is why we have to re-calculate the localpart, instead of + // passing it through form data. + + let id_token = upstream_session.id_token().map(Jwt::try_from).transpose()?; + + let provider = repo + .upstream_oauth_provider() + .lookup(link.provider_id) + .await? + .ok_or(RouteError::ProviderNotFound(link.provider_id))?; + + let env = environment(); + + let mut context = AttributeMappingContext::new(); + if let Some(id_token) = id_token { + let (_, payload) = id_token.into_parts(); + context = context.with_id_token_claims(payload); + } + if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() { + context = context.with_extra_callback_parameters(extra_callback_parameters.clone()); + } + if let Some(userinfo) = upstream_session.userinfo() { + context = context.with_userinfo_claims(userinfo.clone()); + } + let context = context.build(); + + if !provider.claims_imports.localpart.is_forced_or_required() { + //Claims import for `localpart` should be `require` or `force` at this stage + return Err(RouteError::InvalidFormAction); + } + + let template = provider + .claims_imports + .localpart + .template + .as_deref() + .unwrap_or(DEFAULT_LOCALPART_TEMPLATE); + + let Some(localpart) = render_attribute_template(&env, template, &context, true)? else { + // This should never be the case at this point + return Err(RouteError::InvalidFormAction); + }; + + let maybe_user = repo.user().find_by_username(&localpart).await?; + + let Some(user) = maybe_user else { + // user cannot be None at this stage + return Err(RouteError::InvalidFormAction); + }; + + let on_conflict = provider.claims_imports.localpart.on_conflict; + + match on_conflict { + UpstreamOAuthProviderOnConflict::Fail => { + //OnConflict can not be equals to Fail at this stage + return Err(RouteError::InvalidFormAction); + } + UpstreamOAuthProviderOnConflict::Add => { + //add link to the user + repo.upstream_oauth_link() + .associate_to_user(&link, &user) + .await?; + + repo.browser_session() + .add(&mut rng, &clock, &user, user_agent) + .await? + } + } + } + ( None, None, @@ -690,7 +798,7 @@ pub(crate) async fn post( let ctx = if let Some(ref display_name) = display_name { ctx.with_display_name( display_name.clone(), - provider.claims_imports.email.is_forced(), + provider.claims_imports.email.is_forced_or_required(), ) } else { ctx @@ -715,12 +823,15 @@ pub(crate) async fn post( }; let ctx = if let Some(ref email) = email { - ctx.with_email(email.clone(), provider.claims_imports.email.is_forced()) + ctx.with_email( + email.clone(), + provider.claims_imports.email.is_forced_or_required(), + ) } else { ctx }; - let username = if provider.claims_imports.localpart.is_forced() { + let username = if provider.claims_imports.localpart.is_forced_or_required() { let template = provider .claims_imports .localpart @@ -737,7 +848,7 @@ pub(crate) async fn post( let ctx = ctx.with_localpart( username.clone(), - provider.claims_imports.localpart.is_forced(), + provider.claims_imports.localpart.is_forced_or_required(), ); // Validate the form @@ -900,16 +1011,21 @@ pub(crate) async fn post( mod tests { use hyper::{Request, StatusCode, header::CONTENT_TYPE}; use mas_data_model::{ - UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderImportPreference, + UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProviderClaimsImports, + UpstreamOAuthProviderImportPreference, UpstreamOAuthProviderLocalpartPreference, UpstreamOAuthProviderTokenAuthMethod, }; use mas_iana::jose::JsonWebSignatureAlg; use mas_jose::jwt::{JsonWebSignatureHeader, Jwt}; + use mas_keystore::Keystore; use mas_router::Route; use mas_storage::{ - Pagination, upstream_oauth2::UpstreamOAuthProviderParams, user::UserEmailFilter, + Pagination, Repository, RepositoryError, upstream_oauth2::UpstreamOAuthProviderParams, + user::UserEmailFilter, }; use oauth2_types::scope::{OPENID, Scope}; + use rand_chacha::ChaChaRng; + use serde_json::Value; use sqlx::PgPool; use super::UpstreamSessionsCookie; @@ -923,9 +1039,10 @@ mod tests { let cookies = CookieHelper::new(); let claims_imports = UpstreamOAuthProviderClaimsImports { - localpart: UpstreamOAuthProviderImportPreference { + localpart: UpstreamOAuthProviderLocalpartPreference { action: mas_data_model::UpstreamOAuthProviderImportAction::Force, template: None, + on_conflict: mas_data_model::UpstreamOAuthProviderOnConflict::default(), }, email: UpstreamOAuthProviderImportPreference { action: mas_data_model::UpstreamOAuthProviderImportAction::Force, @@ -1099,4 +1216,310 @@ mod tests { assert_eq!(email.email, "john@example.com"); } + + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_link_existing_account(pool: PgPool) { + let existing_username = "john"; + let subject = "subject"; + + setup(); + let state = TestState::from_pool(pool).await.unwrap(); + let mut rng = state.rng(); + let cookies = CookieHelper::new(); + + let claims_imports = UpstreamOAuthProviderClaimsImports { + localpart: UpstreamOAuthProviderLocalpartPreference { + action: mas_data_model::UpstreamOAuthProviderImportAction::Require, + template: None, + on_conflict: mas_data_model::UpstreamOAuthProviderOnConflict::Add, + }, + email: UpstreamOAuthProviderImportPreference { + action: mas_data_model::UpstreamOAuthProviderImportAction::Require, + template: None, + }, + ..UpstreamOAuthProviderClaimsImports::default() + }; + + //`preferred_username` matches an existing user's username + let id_token_claims = serde_json::json!({ + "preferred_username": existing_username, + "email": "any@example.com", + "email_verified": true, + }); + + let id_token = sign_token(&mut rng, &state.key_store, id_token_claims.clone()).unwrap(); + + // Provision a provider and a link + let mut repo = state.repository().await.unwrap(); + let provider = repo + .upstream_oauth_provider() + .add( + &mut rng, + &state.clock, + UpstreamOAuthProviderParams { + issuer: Some("https://example.com/".to_owned()), + human_name: Some("Example Ltd.".to_owned()), + brand_name: None, + scope: Scope::from_iter([OPENID]), + token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None, + token_endpoint_signing_alg: None, + id_token_signed_response_alg: JsonWebSignatureAlg::Rs256, + client_id: "client".to_owned(), + encrypted_client_secret: None, + claims_imports, + authorization_endpoint_override: None, + token_endpoint_override: None, + userinfo_endpoint_override: None, + fetch_userinfo: false, + userinfo_signed_response_alg: None, + jwks_uri_override: None, + discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc, + pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto, + response_mode: None, + additional_authorization_parameters: Vec::new(), + forward_login_hint: false, + on_backchannel_logout: + mas_data_model::UpstreamOAuthProviderOnBackchannelLogout::DoNothing, + ui_order: 0, + }, + ) + .await + .unwrap(); + + //provision upstream authorization session to setup cookies + let (link, session) = add_linked_upstream_session( + &mut rng, + &state.clock, + &mut repo, + &provider, + subject, + &id_token.into_string(), + id_token_claims, + ) + .await + .unwrap(); + + let cookie_jar = state.cookie_jar(); + let upstream_sessions = UpstreamSessionsCookie::default() + .add(session.id, provider.id, "state".to_owned(), None) + .add_link_to_session(session.id, link.id) + .unwrap(); + let cookie_jar = upstream_sessions.save(cookie_jar, &state.clock); + cookies.import(cookie_jar); + + let user = repo + .user() + .add(&mut rng, &state.clock, existing_username.to_owned()) + .await + .unwrap(); + + repo.save().await.unwrap(); + + let request = Request::get(&*mas_router::UpstreamOAuth2Link::new(link.id).path()).empty(); + let request = cookies.with_cookies(request); + let response = state.request(request).await; + cookies.save_cookies(&response); + response.assert_status(StatusCode::OK); + response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8"); + + // Extract the CSRF token from the response body + let csrf_token = response + .body() + .split("name=\"csrf\" value=\"") + .nth(1) + .unwrap() + .split('\"') + .next() + .unwrap(); + + let request = Request::post(&*mas_router::UpstreamOAuth2Link::new(link.id).path()).form( + serde_json::json!({ + "csrf": csrf_token, + "action": "link" + }), + ); + let request = cookies.with_cookies(request); + let response = state.request(request).await; + cookies.save_cookies(&response); + response.assert_status(StatusCode::SEE_OTHER); + + // Check that the existing user has the oidc link + let mut repo = state.repository().await.unwrap(); + + let link = repo + .upstream_oauth_link() + .find_by_subject(&provider, subject) + .await + .unwrap() + .expect("link exists"); + + assert_eq!(link.user_id, Some(user.id)); + } + + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_link_existing_account_when_not_allowed_by_default(pool: PgPool) { + let existing_username = "john"; + + setup(); + let state = TestState::from_pool(pool).await.unwrap(); + let mut rng = state.rng(); + let cookies = CookieHelper::new(); + + let claims_imports = UpstreamOAuthProviderClaimsImports { + localpart: UpstreamOAuthProviderLocalpartPreference { + action: mas_data_model::UpstreamOAuthProviderImportAction::Require, + template: None, + on_conflict: mas_data_model::UpstreamOAuthProviderOnConflict::default(), + }, + email: UpstreamOAuthProviderImportPreference { + action: mas_data_model::UpstreamOAuthProviderImportAction::Require, + template: None, + }, + ..UpstreamOAuthProviderClaimsImports::default() + }; + + // `preferred_username` matches an existing user's username + let id_token_claims = serde_json::json!({ + "preferred_username": existing_username, + "email": "any@example.com", + "email_verified": true, + }); + + let id_token = sign_token(&mut rng, &state.key_store, id_token_claims.clone()).unwrap(); + + // Provision a provider and a link + let mut repo = state.repository().await.unwrap(); + let provider = repo + .upstream_oauth_provider() + .add( + &mut rng, + &state.clock, + UpstreamOAuthProviderParams { + issuer: Some("https://example.com/".to_owned()), + human_name: Some("Example Ltd.".to_owned()), + brand_name: None, + scope: Scope::from_iter([OPENID]), + token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None, + token_endpoint_signing_alg: None, + id_token_signed_response_alg: JsonWebSignatureAlg::Rs256, + client_id: "client".to_owned(), + encrypted_client_secret: None, + claims_imports, + authorization_endpoint_override: None, + token_endpoint_override: None, + userinfo_endpoint_override: None, + fetch_userinfo: false, + userinfo_signed_response_alg: None, + jwks_uri_override: None, + discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc, + pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto, + response_mode: None, + additional_authorization_parameters: Vec::new(), + forward_login_hint: false, + on_backchannel_logout: + mas_data_model::UpstreamOAuthProviderOnBackchannelLogout::DoNothing, + ui_order: 0, + }, + ) + .await + .unwrap(); + + let (link, session) = add_linked_upstream_session( + &mut rng, + &state.clock, + &mut repo, + &provider, + "subject", + &id_token.into_string(), + id_token_claims, + ) + .await + .unwrap(); + + // Provision an user + repo.user() + .add(&mut rng, &state.clock, existing_username.to_owned()) + .await + .unwrap(); + + repo.save().await.unwrap(); + + let cookie_jar = state.cookie_jar(); + let upstream_sessions = UpstreamSessionsCookie::default() + .add(session.id, provider.id, "state".to_owned(), None) + .add_link_to_session(session.id, link.id) + .unwrap(); + let cookie_jar = upstream_sessions.save(cookie_jar, &state.clock); + cookies.import(cookie_jar); + + let request = Request::get(&*mas_router::UpstreamOAuth2Link::new(link.id).path()).empty(); + let request = cookies.with_cookies(request); + let response = state.request(request).await; + cookies.save_cookies(&response); + response.assert_status(StatusCode::OK); + response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8"); + + assert!(response.body().contains("Unexpected error")); + } + + fn sign_token( + rng: &mut ChaChaRng, + keystore: &Keystore, + payload: Value, + ) -> Result, mas_jose::jwt::JwtSignatureError> { + let key = keystore + .signing_key_for_algorithm(&JsonWebSignatureAlg::Rs256) + .unwrap(); + + let signer = key + .params() + .signing_key_for_alg(&JsonWebSignatureAlg::Rs256) + .unwrap(); + + let header = JsonWebSignatureHeader::new(JsonWebSignatureAlg::Rs256); + + Jwt::sign_with_rng(rng, header, payload, &signer) + } + + async fn add_linked_upstream_session( + rng: &mut ChaChaRng, + clock: &impl mas_storage::Clock, + repo: &mut Box + Send + Sync + 'static>, + provider: &mas_data_model::UpstreamOAuthProvider, + subject: &str, + id_token: &str, + id_token_claims: Value, + ) -> Result<(UpstreamOAuthLink, UpstreamOAuthAuthorizationSession), anyhow::Error> { + let session = repo + .upstream_oauth_session() + .add( + rng, + clock, + provider, + "state".to_owned(), + None, + Some("nonce".to_owned()), + ) + .await?; + + let link = repo + .upstream_oauth_link() + .add(rng, clock, provider, subject.to_owned(), None) + .await?; + + let session = repo + .upstream_oauth_session() + .complete_with_link( + clock, + session, + &link, + Some(id_token.to_owned()), + Some(id_token_claims), + None, + None, + ) + .await?; + + Ok((link, session)) + } } diff --git a/crates/templates/src/context.rs b/crates/templates/src/context.rs index 33c973d1e..230dee67f 100644 --- a/crates/templates/src/context.rs +++ b/crates/templates/src/context.rs @@ -1332,7 +1332,7 @@ impl TemplateContext for RecoveryFinishContext { } } -/// Context used by the `pages/upstream_oauth2/{link_mismatch,do_login}.html` +/// Context used by the `pages/upstream_oauth2/{link_mismatch,login_link}.html` /// templates #[derive(Serialize)] pub struct UpstreamExistingLinkContext { diff --git a/crates/templates/src/lib.rs b/crates/templates/src/lib.rs index 8eca3d579..cebe1fdd7 100644 --- a/crates/templates/src/lib.rs +++ b/crates/templates/src/lib.rs @@ -401,6 +401,9 @@ register_templates! { /// Render the upstream link mismatch message pub fn render_upstream_oauth2_link_mismatch(WithLanguage>>) { "pages/upstream_oauth2/link_mismatch.html" } + /// Render the upstream link match + pub fn render_upstream_oauth2_login_link(WithLanguage>) { "pages/upstream_oauth2/login_link.html" } + /// Render the upstream suggest link message pub fn render_upstream_oauth2_suggest_link(WithLanguage>>) { "pages/upstream_oauth2/suggest_link.html" } @@ -468,6 +471,7 @@ impl Templates { check::render_email_verification_html(self, now, rng)?; check::render_email_verification_subject(self, now, rng)?; check::render_upstream_oauth2_link_mismatch(self, now, rng)?; + check::render_upstream_oauth2_login_link(self, now, rng)?; check::render_upstream_oauth2_suggest_link(self, now, rng)?; check::render_upstream_oauth2_do_register(self, now, rng)?; check::render_device_link(self, now, rng)?; diff --git a/docs/config.schema.json b/docs/config.schema.json index 0163e741e..8b726f1e9 100644 --- a/docs/config.schema.json +++ b/docs/config.schema.json @@ -2375,6 +2375,14 @@ "template": { "description": "The Jinja2 template to use for the localpart attribute\n\nIf not provided, the default template is `{{ user.preferred_username }}`", "type": "string" + }, + "on_conflict": { + "description": "How to handle conflicts on the claim, default value is `Fail`", + "allOf": [ + { + "$ref": "#/definitions/OnConflict" + } + ] } } }, @@ -2411,6 +2419,25 @@ } ] }, + "OnConflict": { + "description": "How to handle an existing localpart claim", + "oneOf": [ + { + "description": "Fails the sso login on conflict", + "type": "string", + "enum": [ + "fail" + ] + }, + { + "description": "Adds the oauth identity link, regardless of whether there is an existing link or not", + "type": "string", + "enum": [ + "add" + ] + } + ] + }, "DisplaynameImportPreference": { "description": "What should be done for the displayname attribute", "type": "object", diff --git a/docs/reference/configuration.md b/docs/reference/configuration.md index 4dad3d6a0..3e73f3642 100644 --- a/docs/reference/configuration.md +++ b/docs/reference/configuration.md @@ -776,6 +776,12 @@ upstream_oauth2: localpart: #action: force #template: "{{ user.preferred_username }}" + + # How to handle when localpart already exists. + # Possible values are (default: fail): + # - `add` : Adds the upstream account link to the existing user, regardless of whether there is an existing link or not. + # - `fail` : Fails the upstream OAuth 2.0 login. + #on_conflict: fail # The display name is the user's display name. displayname: diff --git a/docs/setup/sso.md b/docs/setup/sso.md index 4d82bd9a3..3442d06bd 100644 --- a/docs/setup/sso.md +++ b/docs/setup/sso.md @@ -66,6 +66,30 @@ The template has the following variables available: - `user`: an object which contains the claims from both the `id_token` and the `userinfo` endpoint - `extra_callback_parameters`: an object with the additional parameters the provider sent to the redirect URL + +## Allow linking existing user accounts + +The authentication service supports linking external provider identities to existing local user accounts. + +To enable this behavior, the following option must be explicitly set in the provider configuration: + +```yaml +claims_imports: + localpart: + on_conflict: add +``` +`on_conflict` configuration is specific to `localpart` claim_imports, it can be either: +* `add` : when a user authenticates with the provider for the first time, the system checks whether a local user already exists with a `localpart` matching the attribute mapping `localpart` , _by default `{{ user.preferred_username }}`_. If a match is found, the external identity is linked to the existing local account. +* `fail` *(default)* : fails the sso login. + +To enable this option, the `localpart` mapping must be set to either `force` or `require`. + +> ⚠️ **Security Notice** +> Enabling this option can introduce a risk of account takeover. +> +> To mitigate this risk, ensure that this option is only enabled for identity providers where you can guarantee that the attribute mapping `localpart` will reliably and uniquely correspond to the intended local user account. + + ## Multiple providers behaviour Multiple authentication methods can be configured at the same time, in which case the authentication service will let the user choose which one to use. diff --git a/templates/pages/upstream_oauth2/do_register.html b/templates/pages/upstream_oauth2/do_register.html index 6e147600f..b3d8bdee8 100644 --- a/templates/pages/upstream_oauth2/do_register.html +++ b/templates/pages/upstream_oauth2/do_register.html @@ -188,12 +188,6 @@

{% endcall %} {% endif %} - {{ button.button(text=_("action.create_account")) }} - - {# Leave this for now as we don't have that fully designed yet - {{ field.separator() }} - {{ button.link_outline(text=_("mas.upstream_oauth2.register.link_existing"), href=login_link) }} - #} {% endblock content %} diff --git a/templates/pages/upstream_oauth2/login_link.html b/templates/pages/upstream_oauth2/login_link.html new file mode 100644 index 000000000..cdde102b2 --- /dev/null +++ b/templates/pages/upstream_oauth2/login_link.html @@ -0,0 +1,31 @@ +{# +Copyright 2025 New Vector Ltd. + +SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +Please see LICENSE in the repository root for full details. +-#} + +{% extends "base.html" %} + +{% block content %} +
+
+ {{ icon.link() }} +
+ +
+

{{ _("mas.upstream_oauth2.login_link.heading") }}

+
+
+
+ {{ _("mas.upstream_oauth2.login_link.description", username=linked_user.username) }} + +
+ + + + {{ button.button(text=_("mas.upstream_oauth2.login_link.action")) }} +
+ +
+{% endblock content %} diff --git a/translations/en.json b/translations/en.json index d17e09338..0d263fb55 100644 --- a/translations/en.json +++ b/translations/en.json @@ -14,7 +14,7 @@ }, "create_account": "Create Account", "@create_account": { - "context": "pages/login.html:94:33-59, pages/upstream_oauth2/do_register.html:192:26-52" + "context": "pages/login.html:94:33-59, pages/upstream_oauth2/do_register.html:191:26-52" }, "sign_in": "Sign in", "@sign_in": { @@ -693,6 +693,20 @@ "description": "Page shown when the user tries to link an upstream account that is already linked to another account" } }, + "login_link": { + "action": "Continue", + "@action": { + "context": "pages/upstream_oauth2/login_link.html:27:28-70" + }, + "description": "An account exists for this username (%(username)s), it will be linked to this upstream account.", + "@description": { + "context": "pages/upstream_oauth2/login_link.html:21:7-85" + }, + "heading": "Link to your existing account", + "@heading": { + "context": "pages/upstream_oauth2/login_link.html:17:27-70" + } + }, "register": { "choose_username": { "description": "This cannot be changed later.",