diff --git a/Cargo.lock b/Cargo.lock index 76127f916..154c833db 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3311,6 +3311,7 @@ dependencies = [ "regex", "ruma-common", "serde", + "serde_json", "thiserror", "ulid", "url", @@ -3346,6 +3347,7 @@ dependencies = [ "camino", "chrono", "cookie_store", + "elliptic-curve", "futures-util", "governor", "headers", @@ -3377,6 +3379,7 @@ dependencies = [ "opentelemetry", "opentelemetry-semantic-conventions", "pbkdf2", + "pkcs8", "psl", "rand", "rand_chacha", @@ -3612,6 +3615,7 @@ dependencies = [ "base64ct", "bitflags 2.6.0", "chrono", + "elliptic-curve", "form_urlencoded", "headers", "http", @@ -3623,6 +3627,9 @@ dependencies = [ "mas-keystore", "mime", "oauth2-types", + "p256", + "pem-rfc7468", + "pkcs8", "rand", "rand_chacha", "reqwest", diff --git a/Cargo.toml b/Cargo.toml index 27a22f911..fd447eb28 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -102,6 +102,11 @@ features = ["serde", "clock"] version = "4.5.21" features = ["derive"] +# Elliptic curve cryptography +[workspace.dependencies.elliptic-curve] +version = "0.13.8" +features = ["std", "pem", "sec1"] + # Configuration loading [workspace.dependencies.figment] version = "0.10.19" @@ -188,6 +193,36 @@ features = ["pycompat"] [workspace.dependencies.nonzero_ext] version = "0.3.0" +# K256 elliptic curve +[workspace.dependencies.k256] +version = "0.13.4" +features = ["std"] + +# P256 elliptic curve +[workspace.dependencies.p256] +version = "0.13.2" +features = ["std"] + +# P384 elliptic curve +[workspace.dependencies.p384] +version = "0.13.0" +features = ["std"] + +# PEM file decoding +[workspace.dependencies.pem-rfc7468] +version = "0.7.0" +features = ["std"] + +# PKCS#1 encoding +[workspace.dependencies.pkcs1] +version = "0.7.5" +features = ["std"] + +# PKCS#8 encoding +[workspace.dependencies.pkcs8] +version = "0.10.2" +features = ["std", "pkcs5", "encryption"] + # Random values [workspace.dependencies.rand] version = "0.8.5" diff --git a/crates/cli/src/sync.rs b/crates/cli/src/sync.rs index 0a0e01ddc..aefdead14 100644 --- a/crates/cli/src/sync.rs +++ b/crates/cli/src/sync.rs @@ -187,11 +187,17 @@ pub async fn config_sync( continue; } - let encrypted_client_secret = provider - .client_secret - .as_deref() - .map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes())) - .transpose()?; + let encrypted_client_secret = + if let Some(client_secret) = provider.client_secret.as_deref() { + Some(encrypter.encrypt_to_string(client_secret.as_bytes())?) + } else if let Some(siwa) = provider.sign_in_with_apple.as_ref() { + // For SIWA, we JSON-encode the config and encrypt it, reusing the client_secret + // field in the database + let encoded = serde_json::to_vec(siwa)?; + Some(encrypter.encrypt_to_string(&encoded)?) + } else { + None + }; let discovery_mode = match provider.discovery_mode { mas_config::UpstreamOAuth2DiscoveryMode::Oidc => { @@ -205,6 +211,36 @@ pub async fn config_sync( } }; + let token_endpoint_auth_method = match provider.token_endpoint_auth_method { + mas_config::UpstreamOAuth2TokenAuthMethod::None => { + mas_data_model::UpstreamOAuthProviderTokenAuthMethod::None + } + mas_config::UpstreamOAuth2TokenAuthMethod::ClientSecretBasic => { + mas_data_model::UpstreamOAuthProviderTokenAuthMethod::ClientSecretBasic + } + mas_config::UpstreamOAuth2TokenAuthMethod::ClientSecretPost => { + mas_data_model::UpstreamOAuthProviderTokenAuthMethod::ClientSecretPost + } + mas_config::UpstreamOAuth2TokenAuthMethod::ClientSecretJwt => { + mas_data_model::UpstreamOAuthProviderTokenAuthMethod::ClientSecretJwt + } + mas_config::UpstreamOAuth2TokenAuthMethod::PrivateKeyJwt => { + mas_data_model::UpstreamOAuthProviderTokenAuthMethod::PrivateKeyJwt + } + mas_config::UpstreamOAuth2TokenAuthMethod::SignInWithApple => { + mas_data_model::UpstreamOAuthProviderTokenAuthMethod::SignInWithApple + } + }; + + let response_mode = match provider.response_mode { + mas_config::UpstreamOAuth2ResponseMode::Query => { + mas_data_model::UpstreamOAuthProviderResponseMode::Query + } + mas_config::UpstreamOAuth2ResponseMode::FormPost => { + mas_data_model::UpstreamOAuthProviderResponseMode::FormPost + } + }; + if discovery_mode.is_disabled() { if provider.authorization_endpoint.is_none() { error!("Provider has discovery disabled but no authorization endpoint set"); @@ -240,7 +276,7 @@ pub async fn config_sync( human_name: provider.human_name, brand_name: provider.brand_name, scope: provider.scope.parse()?, - token_endpoint_auth_method: provider.token_endpoint_auth_method.into(), + token_endpoint_auth_method, token_endpoint_signing_alg: provider .token_endpoint_auth_signing_alg .clone(), @@ -252,6 +288,7 @@ pub async fn config_sync( jwks_uri_override: provider.jwks_uri, discovery_mode, pkce_mode, + response_mode, additional_authorization_parameters: provider .additional_authorization_parameters .into_iter() diff --git a/crates/config/src/sections/mod.rs b/crates/config/src/sections/mod.rs index b21957aac..faf0b0087 100644 --- a/crates/config/src/sections/mod.rs +++ b/crates/config/src/sections/mod.rs @@ -51,7 +51,9 @@ pub use self::{ ClaimsImports as UpstreamOAuth2ClaimsImports, DiscoveryMode as UpstreamOAuth2DiscoveryMode, EmailImportPreference as UpstreamOAuth2EmailImportPreference, ImportAction as UpstreamOAuth2ImportAction, PkceMethod as UpstreamOAuth2PkceMethod, - SetEmailVerification as UpstreamOAuth2SetEmailVerification, UpstreamOAuth2Config, + ResponseMode as UpstreamOAuth2ResponseMode, + SetEmailVerification as UpstreamOAuth2SetEmailVerification, + TokenAuthMethod as UpstreamOAuth2TokenAuthMethod, UpstreamOAuth2Config, }, }; use crate::util::ConfigurationSection; diff --git a/crates/config/src/sections/upstream_oauth2.rs b/crates/config/src/sections/upstream_oauth2.rs index 4742b2775..d91881d3b 100644 --- a/crates/config/src/sections/upstream_oauth2.rs +++ b/crates/config/src/sections/upstream_oauth2.rs @@ -6,7 +6,7 @@ use std::collections::BTreeMap; -use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; +use mas_iana::jose::JsonWebSignatureAlg; use schemars::JsonSchema; use serde::{de::Error, Deserialize, Serialize}; use serde_with::skip_serializing_none; @@ -48,7 +48,9 @@ impl ConfigurationSection for UpstreamOAuth2Config { }; match provider.token_endpoint_auth_method { - TokenAuthMethod::None | TokenAuthMethod::PrivateKeyJwt => { + TokenAuthMethod::None + | TokenAuthMethod::PrivateKeyJwt + | TokenAuthMethod::SignInWithApple => { if provider.client_secret.is_some() { return annotate(figment::Error::custom("Unexpected field `client_secret` for the selected authentication method")); } @@ -65,7 +67,8 @@ impl ConfigurationSection for UpstreamOAuth2Config { match provider.token_endpoint_auth_method { TokenAuthMethod::None | TokenAuthMethod::ClientSecretBasic - | TokenAuthMethod::ClientSecretPost => { + | TokenAuthMethod::ClientSecretPost + | TokenAuthMethod::SignInWithApple => { if provider.token_endpoint_auth_signing_alg.is_some() { return annotate(figment::Error::custom( "Unexpected field `token_endpoint_auth_signing_alg` for the selected authentication method", @@ -80,12 +83,51 @@ impl ConfigurationSection for UpstreamOAuth2Config { } } } + + match provider.token_endpoint_auth_method { + TokenAuthMethod::SignInWithApple => { + if provider.sign_in_with_apple.is_none() { + return annotate(figment::Error::missing_field("sign_in_with_apple")); + } + } + + _ => { + if provider.sign_in_with_apple.is_some() { + return annotate(figment::Error::custom( + "Unexpected field `sign_in_with_apple` for the selected authentication method", + )); + } + } + } } Ok(()) } } +/// The response mode we ask the provider to use for the callback +#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub enum ResponseMode { + /// `query`: The provider will send the response as a query string in the + /// URL search parameters + #[default] + Query, + + /// `form_post`: The provider will send the response as a POST request with + /// the response parameters in the request body + /// + /// + FormPost, +} + +impl ResponseMode { + #[allow(clippy::trivially_copy_pass_by_ref)] + const fn is_default(&self) -> bool { + matches!(self, ResponseMode::Query) + } +} + /// Authentication methods used against the OAuth 2.0 provider #[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "snake_case")] @@ -108,20 +150,9 @@ pub enum TokenAuthMethod { /// `private_key_jwt`: a `client_assertion` sent in the request body and /// signed by an asymmetric key PrivateKeyJwt, -} -impl From for OAuthClientAuthenticationMethod { - fn from(method: TokenAuthMethod) -> Self { - match method { - TokenAuthMethod::None => OAuthClientAuthenticationMethod::None, - TokenAuthMethod::ClientSecretBasic => { - OAuthClientAuthenticationMethod::ClientSecretBasic - } - TokenAuthMethod::ClientSecretPost => OAuthClientAuthenticationMethod::ClientSecretPost, - TokenAuthMethod::ClientSecretJwt => OAuthClientAuthenticationMethod::ClientSecretJwt, - TokenAuthMethod::PrivateKeyJwt => OAuthClientAuthenticationMethod::PrivateKeyJwt, - } - } + /// `sign_in_with_apple`: a special method for Signin with Apple + SignInWithApple, } /// How to handle a claim @@ -343,6 +374,18 @@ fn is_default_true(value: &bool) -> bool { *value } +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +pub struct SignInWithApple { + /// The private key used to sign the `id_token` + pub private_key: String, + + /// The Team ID of the Apple Developer Portal + pub team_id: String, + + /// The key ID of the Apple Developer Portal + pub key_id: String, +} + #[skip_serializing_none] #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct Provider { @@ -394,6 +437,10 @@ pub struct Provider { /// The method to authenticate the client with the provider pub token_endpoint_auth_method: TokenAuthMethod, + /// Additional parameters for the `sign_in_with_apple` method + #[serde(skip_serializing_if = "Option::is_none")] + pub sign_in_with_apple: Option, + /// The JWS algorithm to use when authenticating the client with the /// provider /// @@ -436,6 +483,10 @@ pub struct Provider { #[serde(skip_serializing_if = "Option::is_none")] pub jwks_uri: Option, + /// The response mode we ask the provider to use for the callback + #[serde(default, skip_serializing_if = "ResponseMode::is_default")] + pub response_mode: ResponseMode, + /// How claims should be imported from the `id_token` provided by the /// provider #[serde(default, skip_serializing_if = "ClaimsImports::is_default")] diff --git a/crates/data-model/Cargo.toml b/crates/data-model/Cargo.toml index 16cab034f..2c648b1ff 100644 --- a/crates/data-model/Cargo.toml +++ b/crates/data-model/Cargo.toml @@ -15,6 +15,7 @@ workspace = true chrono.workspace = true thiserror.workspace = true serde.workspace = true +serde_json.workspace = true url.workspace = true crc = "3.2.1" ulid.workspace = true diff --git a/crates/data-model/src/lib.rs b/crates/data-model/src/lib.rs index c0b39792a..19d7f4469 100644 --- a/crates/data-model/src/lib.rs +++ b/crates/data-model/src/lib.rs @@ -41,7 +41,8 @@ pub use self::{ UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink, UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode, UpstreamOAuthProviderImportAction, UpstreamOAuthProviderImportPreference, - UpstreamOAuthProviderPkceMode, UpstreamOAuthProviderSubjectPreference, + UpstreamOAuthProviderPkceMode, UpstreamOAuthProviderResponseMode, + UpstreamOAuthProviderSubjectPreference, UpstreamOAuthProviderTokenAuthMethod, }, user_agent::{DeviceType, UserAgent}, users::{ diff --git a/crates/data-model/src/upstream_oauth2/mod.rs b/crates/data-model/src/upstream_oauth2/mod.rs index cfa21ea1a..bede13cb6 100644 --- a/crates/data-model/src/upstream_oauth2/mod.rs +++ b/crates/data-model/src/upstream_oauth2/mod.rs @@ -16,8 +16,10 @@ pub use self::{ ImportAction as UpstreamOAuthProviderImportAction, ImportPreference as UpstreamOAuthProviderImportPreference, PkceMode as UpstreamOAuthProviderPkceMode, + ResponseMode as UpstreamOAuthProviderResponseMode, SetEmailVerification as UpsreamOAuthProviderSetEmailVerification, - SubjectPreference as UpstreamOAuthProviderSubjectPreference, UpstreamOAuthProvider, + SubjectPreference as UpstreamOAuthProviderSubjectPreference, + TokenAuthMethod as UpstreamOAuthProviderTokenAuthMethod, UpstreamOAuthProvider, }, session::{UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState}, }; diff --git a/crates/data-model/src/upstream_oauth2/provider.rs b/crates/data-model/src/upstream_oauth2/provider.rs index 0cb976a73..5656fd78f 100644 --- a/crates/data-model/src/upstream_oauth2/provider.rs +++ b/crates/data-model/src/upstream_oauth2/provider.rs @@ -5,7 +5,7 @@ // Please see LICENSE in the repository root for full details. use chrono::{DateTime, Utc}; -use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; +use mas_iana::jose::JsonWebSignatureAlg; use oauth2_types::scope::Scope; use serde::{Deserialize, Serialize}; use thiserror::Error; @@ -116,6 +116,106 @@ impl std::fmt::Display for PkceMode { } } +#[derive(Debug, Clone, Error)] +#[error("Invalid response mode {0:?}")] +pub struct InvalidResponseModeError(String); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "snake_case")] +pub enum ResponseMode { + #[default] + Query, + FormPost, +} + +impl From for oauth2_types::requests::ResponseMode { + fn from(value: ResponseMode) -> Self { + match value { + ResponseMode::Query => oauth2_types::requests::ResponseMode::Query, + ResponseMode::FormPost => oauth2_types::requests::ResponseMode::FormPost, + } + } +} + +impl ResponseMode { + #[must_use] + pub fn as_str(self) -> &'static str { + match self { + Self::Query => "query", + Self::FormPost => "form_post", + } + } +} + +impl std::fmt::Display for ResponseMode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_str()) + } +} + +impl std::str::FromStr for ResponseMode { + type Err = InvalidResponseModeError; + + fn from_str(s: &str) -> Result { + match s { + "query" => Ok(ResponseMode::Query), + "form_post" => Ok(ResponseMode::FormPost), + s => Err(InvalidResponseModeError(s.to_owned())), + } + } +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum TokenAuthMethod { + None, + ClientSecretBasic, + ClientSecretPost, + ClientSecretJwt, + PrivateKeyJwt, + SignInWithApple, +} + +impl TokenAuthMethod { + #[must_use] + pub fn as_str(self) -> &'static str { + match self { + Self::None => "none", + Self::ClientSecretBasic => "client_secret_basic", + Self::ClientSecretPost => "client_secret_post", + Self::ClientSecretJwt => "client_secret_jwt", + Self::PrivateKeyJwt => "private_key_jwt", + Self::SignInWithApple => "sign_in_with_apple", + } + } +} + +impl std::fmt::Display for TokenAuthMethod { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_str()) + } +} + +impl std::str::FromStr for TokenAuthMethod { + type Err = InvalidUpstreamOAuth2TokenAuthMethod; + + fn from_str(s: &str) -> Result { + match s { + "none" => Ok(Self::None), + "client_secret_post" => Ok(Self::ClientSecretPost), + "client_secret_basic" => Ok(Self::ClientSecretBasic), + "client_secret_jwt" => Ok(Self::ClientSecretJwt), + "private_key_jwt" => Ok(Self::PrivateKeyJwt), + "sign_in_with_apple" => Ok(Self::SignInWithApple), + s => Err(InvalidUpstreamOAuth2TokenAuthMethod(s.to_owned())), + } + } +} + +#[derive(Debug, Clone, Error)] +#[error("Invalid upstream OAuth 2.0 token auth method: {0}")] +pub struct InvalidUpstreamOAuth2TokenAuthMethod(String); + #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct UpstreamOAuthProvider { pub id: Ulid, @@ -126,12 +226,13 @@ pub struct UpstreamOAuthProvider { pub pkce_mode: PkceMode, pub jwks_uri_override: Option, pub authorization_endpoint_override: Option, - pub token_endpoint_override: Option, pub scope: Scope, + pub token_endpoint_override: Option, pub client_id: String, pub encrypted_client_secret: Option, pub token_endpoint_signing_alg: Option, - pub token_endpoint_auth_method: OAuthClientAuthenticationMethod, + pub token_endpoint_auth_method: TokenAuthMethod, + pub response_mode: ResponseMode, pub created_at: DateTime, pub disabled_at: Option>, pub claims_imports: ClaimsImports, diff --git a/crates/data-model/src/upstream_oauth2/session.rs b/crates/data-model/src/upstream_oauth2/session.rs index eb59000a0..38b622987 100644 --- a/crates/data-model/src/upstream_oauth2/session.rs +++ b/crates/data-model/src/upstream_oauth2/session.rs @@ -19,12 +19,14 @@ pub enum UpstreamOAuthAuthorizationSessionState { completed_at: DateTime, link_id: Ulid, id_token: Option, + extra_callback_parameters: Option, }, Consumed { completed_at: DateTime, consumed_at: DateTime, link_id: Ulid, id_token: Option, + extra_callback_parameters: Option, }, } @@ -42,12 +44,14 @@ impl UpstreamOAuthAuthorizationSessionState { completed_at: DateTime, link: &UpstreamOAuthLink, id_token: Option, + extra_callback_parameters: Option, ) -> Result { match self { Self::Pending => Ok(Self::Completed { completed_at, link_id: link.id, id_token, + extra_callback_parameters, }), Self::Completed { .. } | Self::Consumed { .. } => Err(InvalidTransitionError), } @@ -67,11 +71,13 @@ impl UpstreamOAuthAuthorizationSessionState { completed_at, link_id, id_token, + extra_callback_parameters, } => Ok(Self::Consumed { completed_at, link_id, consumed_at, id_token, + extra_callback_parameters, }), Self::Pending | Self::Consumed { .. } => Err(InvalidTransitionError), } @@ -124,6 +130,27 @@ impl UpstreamOAuthAuthorizationSessionState { } } + /// Get the extra query parameters that were sent to the upstream provider. + /// + /// Returns `None` if the upstream OAuth 2.0 authorization session state is + /// not [`Pending`]. + /// + /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending + #[must_use] + pub fn extra_callback_parameters(&self) -> Option<&serde_json::Value> { + match self { + Self::Pending => None, + Self::Completed { + extra_callback_parameters, + .. + } + | Self::Consumed { + extra_callback_parameters, + .. + } => extra_callback_parameters.as_ref(), + } + } + /// Get the time at which the upstream OAuth 2.0 authorization session was /// consumed. /// @@ -201,8 +228,11 @@ impl UpstreamOAuthAuthorizationSession { completed_at: DateTime, link: &UpstreamOAuthLink, id_token: Option, + extra_callback_parameters: Option, ) -> Result { - self.state = self.state.complete(completed_at, link, id_token)?; + self.state = + self.state + .complete(completed_at, link, id_token, extra_callback_parameters)?; Ok(self) } diff --git a/crates/handlers/Cargo.toml b/crates/handlers/Cargo.toml index b23743d26..3eecfff1c 100644 --- a/crates/handlers/Cargo.toml +++ b/crates/handlers/Cargo.toml @@ -71,8 +71,10 @@ zeroize = "1.8.1" base64ct = "1.6.0" camino.workspace = true chrono.workspace = true +elliptic-curve.workspace = true governor.workspace = true indexmap = "2.6.0" +pkcs8.workspace = true psl = "2.1.56" time = "0.3.36" url.workspace = true diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index d1c1f0840..9d0abb93f 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -402,7 +402,8 @@ where ) .route( mas_router::UpstreamOAuth2Callback::route(), - get(self::upstream_oauth2::callback::get), + get(self::upstream_oauth2::callback::handler) + .post(self::upstream_oauth2::callback::handler), ) .route( mas_router::UpstreamOAuth2Link::route(), diff --git a/crates/handlers/src/oauth2/authorization/callback.rs b/crates/handlers/src/oauth2/authorization/callback.rs index ba35e3c52..b76722d5a 100644 --- a/crates/handlers/src/oauth2/authorization/callback.rs +++ b/crates/handlers/src/oauth2/authorization/callback.rs @@ -10,6 +10,7 @@ use std::collections::HashMap; use axum::response::{Html, IntoResponse, Redirect, Response}; use mas_data_model::AuthorizationGrant; +use mas_i18n::DataLocale; use mas_templates::{FormPostContext, Templates}; use oauth2_types::requests::ResponseMode; use serde::Serialize; @@ -103,6 +104,7 @@ impl CallbackDestination { pub async fn go( self, templates: &Templates, + locale: &DataLocale, params: T, ) -> Result { #[derive(Serialize)] @@ -155,7 +157,7 @@ impl CallbackDestination { state, params, }; - let ctx = FormPostContext::new(redirect_uri, merged); + let ctx = FormPostContext::new_for_url(redirect_uri, merged).with_language(locale); let rendered = templates.render_form_post(&ctx)?; Ok(Html(rendered).into_response()) } diff --git a/crates/handlers/src/oauth2/authorization/complete.rs b/crates/handlers/src/oauth2/authorization/complete.rs index 3eedf71e6..a9efb2ae2 100644 --- a/crates/handlers/src/oauth2/authorization/complete.rs +++ b/crates/handlers/src/oauth2/authorization/complete.rs @@ -141,7 +141,7 @@ pub(crate) async fn get( .await { Ok(params) => { - let res = callback_destination.go(&templates, params).await?; + let res = callback_destination.go(&templates, &locale, params).await?; Ok((cookie_jar, res).into_response()) } Err(GrantCompletionError::RequiresReauth) => Ok(( diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index 236e8a795..f56f06133 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -170,6 +170,7 @@ pub(crate) async fn get( let res: Result = ({ let templates = templates.clone(); let callback_destination = callback_destination.clone(); + let locale = locale.clone(); async move { let maybe_session = session_info.load_session(&mut repo).await?; let prompt = params.auth.prompt.as_deref().unwrap_or_default(); @@ -180,6 +181,7 @@ pub(crate) async fn get( return Ok(callback_destination .go( &templates, + &locale, ClientError::from(ClientErrorCode::RequestNotSupported), ) .await?); @@ -189,6 +191,7 @@ pub(crate) async fn get( return Ok(callback_destination .go( &templates, + &locale, ClientError::from(ClientErrorCode::RequestUriNotSupported), ) .await?); @@ -200,6 +203,7 @@ pub(crate) async fn get( return Ok(callback_destination .go( &templates, + &locale, ClientError::from(ClientErrorCode::UnsupportedResponseType), ) .await?); @@ -211,6 +215,7 @@ pub(crate) async fn get( return Ok(callback_destination .go( &templates, + &locale, ClientError::from(ClientErrorCode::UnauthorizedClient), ) .await?); @@ -220,6 +225,7 @@ pub(crate) async fn get( return Ok(callback_destination .go( &templates, + &locale, ClientError::from(ClientErrorCode::RegistrationNotSupported), ) .await?); @@ -230,6 +236,7 @@ pub(crate) async fn get( return Ok(callback_destination .go( &templates, + &locale, ClientError::from(ClientErrorCode::LoginRequired), ) .await?); @@ -241,6 +248,7 @@ pub(crate) async fn get( return Ok(callback_destination .go( &templates, + &locale, ClientError::from(ClientErrorCode::UnauthorizedClient), ) .await?); @@ -266,6 +274,7 @@ pub(crate) async fn get( return Ok(callback_destination .go( &templates, + &locale, ClientError::from(ClientErrorCode::InvalidRequest), ) .await?); @@ -350,11 +359,12 @@ pub(crate) async fn get( ) .await { - Ok(params) => callback_destination.go(&templates, params).await?, + Ok(params) => callback_destination.go(&templates, &locale, params).await?, Err(GrantCompletionError::RequiresConsent) => { callback_destination .go( &templates, + &locale, ClientError::from(ClientErrorCode::ConsentRequired), ) .await? @@ -363,13 +373,14 @@ pub(crate) async fn get( callback_destination .go( &templates, + &locale, ClientError::from(ClientErrorCode::InteractionRequired), ) .await? } Err(GrantCompletionError::PolicyViolation(_grant, _res)) => { callback_destination - .go(&templates, ClientError::from(ClientErrorCode::AccessDenied)) + .go(&templates, &locale, ClientError::from(ClientErrorCode::AccessDenied)) .await? } Err(GrantCompletionError::Internal(e)) => { @@ -400,7 +411,7 @@ pub(crate) async fn get( ) .await { - Ok(params) => callback_destination.go(&templates, params).await?, + Ok(params) => callback_destination.go(&templates, &locale, params).await?, Err(GrantCompletionError::RequiresConsent) => { url_builder.redirect(&mas_router::Consent(grant_id)).into_response() } @@ -440,7 +451,11 @@ pub(crate) async fn get( Err(err) => { tracing::error!(%err); callback_destination - .go(&templates, ClientError::from(ClientErrorCode::ServerError)) + .go( + &templates, + &locale, + ClientError::from(ClientErrorCode::ServerError), + ) .await? } }; diff --git a/crates/handlers/src/upstream_oauth2/authorize.rs b/crates/handlers/src/upstream_oauth2/authorize.rs index 61e7e9ac1..67d4d9a7d 100644 --- a/crates/handlers/src/upstream_oauth2/authorize.rs +++ b/crates/handlers/src/upstream_oauth2/authorize.rs @@ -87,7 +87,8 @@ pub(crate) async fn get( provider.client_id.clone(), provider.scope.clone(), redirect_uri, - ); + ) + .with_response_mode(provider.response_mode.into()); let data = if let Some(methods) = lazy_metadata.pkce_methods().await? { data.with_code_challenge_methods_supported(methods) diff --git a/crates/handlers/src/upstream_oauth2/cache.rs b/crates/handlers/src/upstream_oauth2/cache.rs index c2f307e14..248271229 100644 --- a/crates/handlers/src/upstream_oauth2/cache.rs +++ b/crates/handlers/src/upstream_oauth2/cache.rs @@ -274,8 +274,9 @@ mod tests { // XXX: sadly, we can't test HTTPS requests with wiremock, so we can only test // 'insecure' discovery - use mas_data_model::UpstreamOAuthProviderClaimsImports; - use mas_iana::oauth::OAuthClientAuthenticationMethod; + use mas_data_model::{ + UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderTokenAuthMethod, + }; use mas_storage::{clock::MockClock, Clock}; use oauth2_types::scope::{Scope, OPENID}; use ulid::Ulid; @@ -388,12 +389,13 @@ mod tests { pkce_mode: UpstreamOAuthProviderPkceMode::Auto, jwks_uri_override: None, authorization_endpoint_override: None, - token_endpoint_override: None, scope: Scope::from_iter([OPENID]), + token_endpoint_override: None, client_id: "client_id".to_owned(), encrypted_client_secret: None, token_endpoint_signing_alg: None, - token_endpoint_auth_method: OAuthClientAuthenticationMethod::None, + token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None, + response_mode: mas_data_model::UpstreamOAuthProviderResponseMode::Query, created_at: clock.now(), disabled_at: None, claims_imports: UpstreamOAuthProviderClaimsImports::default(), diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index 6b070e6dd..927874a3f 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -6,11 +6,13 @@ use axum::{ extract::{Path, Query, State}, - response::IntoResponse, + response::{IntoResponse, Response}, + Form, }; +use axum_extra::response::Html; use hyper::StatusCode; use mas_axum_utils::{cookies::CookieJar, sentry::SentryEventID}; -use mas_data_model::UpstreamOAuthProvider; +use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderResponseMode}; use mas_keystore::{Encrypter, Keystore}; use mas_oidc_client::requests::{ authorization_code::AuthorizationValidationData, jose::JwtVerificationData, @@ -23,30 +25,36 @@ use mas_storage::{ }, BoxClock, BoxRepository, BoxRng, Clock, }; +use mas_templates::{FormPostContext, Templates}; use oauth2_types::errors::ClientErrorCode; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use thiserror::Error; use ulid::Ulid; use super::{ - cache::LazyProviderInfos, client_credentials_for_provider, template::environment, + cache::LazyProviderInfos, + client_credentials_for_provider, + template::{environment, AttributeMappingContext}, UpstreamSessionsCookie, }; -use crate::{impl_from_error_for_route, upstream_oauth2::cache::MetadataCache}; +use crate::{impl_from_error_for_route, upstream_oauth2::cache::MetadataCache, PreferredLanguage}; -#[derive(Deserialize)] -pub struct QueryParams { +#[derive(Serialize, Deserialize)] +pub struct Params { state: String, #[serde(flatten)] code_or_error: CodeOrError, } -#[derive(Deserialize)] +#[derive(Serialize, Deserialize)] #[serde(untagged)] enum CodeOrError { Code { code: String, + + #[serde(flatten)] + extra_callback_parameters: Option, }, Error { error: ClientErrorCode, @@ -91,10 +99,25 @@ pub(crate) enum RouteError { #[error("Missing session cookie")] MissingCookie, + #[error("Missing query parameters")] + MissingQueryParams, + + #[error("Missing form parameters")] + MissingFormParams, + + #[error("Ambiguous parameters: got both query and form parameters")] + AmbiguousParams, + + #[error("Invalid response mode, expected '{expected}'")] + InvalidParamsMode { + expected: UpstreamOAuthProviderResponseMode, + }, + #[error(transparent)] Internal(Box), } +impl_from_error_for_route!(mas_templates::TemplateError); impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError); impl_from_error_for_route!(mas_oidc_client::error::JwksError); @@ -117,13 +140,13 @@ impl IntoResponse for RouteError { } #[tracing::instrument( - name = "handlers.upstream_oauth2.callback.get", + name = "handlers.upstream_oauth2.callback.handler", fields(upstream_oauth_provider.id = %provider_id), skip_all, err, )] #[allow(clippy::too_many_lines, clippy::too_many_arguments)] -pub(crate) async fn get( +pub(crate) async fn handler( mut rng: BoxRng, clock: BoxClock, State(metadata_cache): State, @@ -132,10 +155,13 @@ pub(crate) async fn get( State(encrypter): State, State(keystore): State, State(client): State, + State(templates): State, + PreferredLanguage(locale): PreferredLanguage, cookie_jar: CookieJar, Path(provider_id): Path, - Query(params): Query, -) -> Result { + query_params: Option>, + form_params: Option>, +) -> Result { let provider = repo .upstream_oauth_provider() .lookup(provider_id) @@ -144,6 +170,33 @@ pub(crate) async fn get( .ok_or(RouteError::ProviderNotFound)?; let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar); + + // Read the parameters from the query or the form, depending on what + // response_mode the provider uses + let params = match (provider.response_mode, query_params, form_params) { + (UpstreamOAuthProviderResponseMode::Query, Some(Query(query_params)), None) => query_params, + (UpstreamOAuthProviderResponseMode::FormPost, None, Some(Form(form_params))) => { + // We got there from a cross-site form POST, so we need to render a form with + // the same values, which posts back to the same URL + if sessions_cookie.is_empty() { + let context = + FormPostContext::new_for_current_url(form_params).with_language(&locale); + let html = templates.render_form_post(&context)?; + return Ok(Html(html).into_response()); + } + + form_params + } + (UpstreamOAuthProviderResponseMode::Query, None, None) => { + return Err(RouteError::MissingQueryParams) + } + (UpstreamOAuthProviderResponseMode::FormPost, None, None) => { + return Err(RouteError::MissingFormParams) + } + (_, Some(_), Some(_)) => return Err(RouteError::AmbiguousParams), + (expected, _, _) => return Err(RouteError::InvalidParamsMode { expected }), + }; + let (session_id, _post_auth_action) = sessions_cookie .find_session(provider_id, ¶ms.state) .map_err(|_| RouteError::MissingCookie)?; @@ -170,7 +223,7 @@ pub(crate) async fn get( } // Let's extract the code from the params, and return if there was an error - let code = match params.code_or_error { + let (code, extra_callback_parameters) = match params.code_or_error { CodeOrError::Error { error, error_description, @@ -181,7 +234,10 @@ pub(crate) async fn get( error_description, }) } - CodeOrError::Code { code } => code, + CodeOrError::Code { + code, + extra_callback_parameters, + } => (code, extra_callback_parameters), }; let mut lazy_metadata = LazyProviderInfos::new(&metadata_cache, &provider, &client); @@ -232,11 +288,13 @@ pub(crate) async fn get( let (_header, id_token) = id_token.ok_or(RouteError::MissingIDToken)?.into_parts(); - let env = { - let mut env = environment(); - env.add_global("user", minijinja::Value::from_serialize(&id_token)); - env - }; + let mut context = AttributeMappingContext::new().with_id_token_claims(id_token); + if let Some(extra_callback_parameters) = extra_callback_parameters.clone() { + context = context.with_extra_callback_parameters(extra_callback_parameters); + } + let context = context.build(); + + let env = environment(); let template = provider .claims_imports @@ -245,7 +303,7 @@ pub(crate) async fn get( .as_deref() .unwrap_or("{{ user.sub }}"); let subject = env - .render_str(template, ()) + .render_str(template, context) .map_err(RouteError::ExtractSubject)?; if subject.is_empty() { @@ -268,7 +326,13 @@ pub(crate) async fn get( let session = repo .upstream_oauth_session() - .complete_with_link(&clock, session, &link, response.id_token) + .complete_with_link( + &clock, + session, + &link, + response.id_token, + extra_callback_parameters, + ) .await?; let cookie_jar = sessions_cookie @@ -280,5 +344,6 @@ pub(crate) async fn get( Ok(( cookie_jar, url_builder.redirect(&mas_router::UpstreamOAuth2Link::new(link.id)), - )) + ) + .into_response()) } diff --git a/crates/handlers/src/upstream_oauth2/cookie.rs b/crates/handlers/src/upstream_oauth2/cookie.rs index 6a9769865..cbcfb5148 100644 --- a/crates/handlers/src/upstream_oauth2/cookie.rs +++ b/crates/handlers/src/upstream_oauth2/cookie.rs @@ -61,6 +61,11 @@ impl UpstreamSessions { } } + /// Returns true if the cookie is empty + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + /// Save the upstreams sessions to the cookie jar pub fn save(self, cookie_jar: CookieJar, clock: &C) -> CookieJar where diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 3f7407e72..dfc42043b 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -38,7 +38,10 @@ use thiserror::Error; use tracing::warn; use ulid::Ulid; -use super::{template::environment, UpstreamSessionsCookie}; +use super::{ + template::{environment, AttributeMappingContext}, + UpstreamSessionsCookie, +}; use crate::{ impl_from_error_for_route, views::shared::OptionalPostAuthAction, PreferredLanguage, SiteConfig, }; @@ -130,9 +133,10 @@ impl IntoResponse for RouteError { fn render_attribute_template( environment: &Environment, template: &str, + context: &minijinja::Value, required: bool, ) -> Result, RouteError> { - match environment.render_str(template, ()) { + match environment.render_str(template, context) { Ok(value) if value.is_empty() => { if required { return Err(RouteError::RequiredAttributeEmpty { @@ -320,10 +324,7 @@ pub(crate) async fn get( (None, None) => { // Session not linked and used not logged in: suggest creating an // account or logging in an existing user - let id_token = upstream_session - .id_token() - .map(Jwt::<'_, minijinja::Value>::try_from) - .transpose()?; + let id_token = upstream_session.id_token().map(Jwt::try_from).transpose()?; let provider = repo .upstream_oauth_provider() @@ -331,17 +332,19 @@ pub(crate) async fn get( .await? .ok_or(RouteError::ProviderNotFound)?; - let payload = id_token - .map(|id_token| id_token.into_parts().1) - .unwrap_or_default(); - let ctx = UpstreamRegister::default(); - let env = { - let mut e = environment(); - e.add_global("user", payload); - e - }; + let env = environment(); + + let mut context = AttributeMappingContext::new(); + if let Some(id_token) = id_token { + let (_, payload) = id_token.into_parts(); + context = context.with_id_token_claims(payload); + } + if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() { + context = context.with_extra_callback_parameters(extra_callback_parameters.clone()); + } + let context = context.build(); let ctx = if provider.claims_imports.displayname.ignore() { ctx @@ -356,6 +359,7 @@ pub(crate) async fn get( match render_attribute_template( &env, template, + &context, provider.claims_imports.displayname.is_required(), )? { Some(value) => ctx @@ -377,6 +381,7 @@ pub(crate) async fn get( match render_attribute_template( &env, template, + &context, provider.claims_imports.email.is_required(), )? { Some(value) => ctx.with_email(value, provider.claims_imports.email.is_forced()), @@ -397,6 +402,7 @@ pub(crate) async fn get( match render_attribute_template( &env, template, + &context, provider.claims_imports.localpart.is_required(), )? { Some(localpart) => { @@ -557,10 +563,7 @@ pub(crate) async fn post( let import_display_name = import_display_name.is_some(); let accept_terms = accept_terms.is_some(); - let id_token = upstream_session - .id_token() - .map(Jwt::<'_, minijinja::Value>::try_from) - .transpose()?; + let id_token = upstream_session.id_token().map(Jwt::try_from).transpose()?; let provider = repo .upstream_oauth_provider() @@ -568,22 +571,23 @@ pub(crate) async fn post( .await? .ok_or(RouteError::ProviderNotFound)?; - let payload = id_token - .map(|id_token| id_token.into_parts().1) - .unwrap_or_default(); + // Let's try to import the claims from the ID token + let env = environment(); - // Is the email verified according to the upstream provider? - let provider_email_verified = payload - .get_item(&minijinja::Value::from("email_verified")) - .map(|v| v.is_true()) - .unwrap_or(false); + let mut context = AttributeMappingContext::new(); + if let Some(id_token) = id_token { + let (_, payload) = id_token.into_parts(); + context = context.with_id_token_claims(payload); + } + if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() { + context = context.with_extra_callback_parameters(extra_callback_parameters.clone()); + } + let context = context.build(); - // Let's try to import the claims from the ID token - let env = { - let mut e = environment(); - e.add_global("user", payload); - e - }; + // Is the email verified according to the upstream provider? + let provider_email_verified = env + .render_str("{{ user.email_verified | string }}", &context) + .map_or(false, |v| v == "true"); // Create a template context in case we need to re-render because of an error let ctx = UpstreamRegister::default(); @@ -603,6 +607,7 @@ pub(crate) async fn post( render_attribute_template( &env, template, + &context, provider.claims_imports.displayname.is_required(), )? } else { @@ -629,6 +634,7 @@ pub(crate) async fn post( render_attribute_template( &env, template, + &context, provider.claims_imports.email.is_required(), )? } else { @@ -652,6 +658,7 @@ pub(crate) async fn post( render_attribute_template( &env, template, + &context, provider.claims_imports.email.is_required(), )? } else { @@ -843,8 +850,9 @@ mod tests { use hyper::{header::CONTENT_TYPE, Request, StatusCode}; use mas_data_model::{ UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderImportPreference, + UpstreamOAuthProviderTokenAuthMethod, }; - use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; + use mas_iana::jose::JsonWebSignatureAlg; use mas_jose::jwt::{JsonWebSignatureHeader, Jwt}; use mas_router::Route; use mas_storage::upstream_oauth2::UpstreamOAuthProviderParams; @@ -906,7 +914,7 @@ mod tests { human_name: Some("Example Ltd.".to_owned()), brand_name: None, scope: Scope::from_iter([OPENID]), - token_endpoint_auth_method: OAuthClientAuthenticationMethod::None, + token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None, token_endpoint_signing_alg: None, client_id: "client".to_owned(), encrypted_client_secret: None, @@ -916,6 +924,7 @@ mod tests { jwks_uri_override: None, discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc, pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto, + response_mode: mas_data_model::UpstreamOAuthProviderResponseMode::Query, additional_authorization_parameters: Vec::new(), }, ) @@ -943,7 +952,13 @@ mod tests { let session = repo .upstream_oauth_session() - .complete_with_link(&state.clock, session, &link, Some(id_token.into_string())) + .complete_with_link( + &state.clock, + session, + &link, + Some(id_token.into_string()), + None, + ) .await .unwrap(); diff --git a/crates/handlers/src/upstream_oauth2/mod.rs b/crates/handlers/src/upstream_oauth2/mod.rs index 19ca0dc48..758202dfc 100644 --- a/crates/handlers/src/upstream_oauth2/mod.rs +++ b/crates/handlers/src/upstream_oauth2/mod.rs @@ -6,10 +6,12 @@ use std::string::FromUtf8Error; -use mas_data_model::UpstreamOAuthProvider; -use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; +use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderTokenAuthMethod}; +use mas_iana::jose::JsonWebSignatureAlg; use mas_keystore::{DecryptError, Encrypter, Keystore}; use mas_oidc_client::types::client_credentials::ClientCredentials; +use pkcs8::DecodePrivateKey; +use serde::Deserialize; use thiserror::Error; use url::Url; @@ -39,6 +41,25 @@ enum ProviderCredentialsError { #[from] inner: FromUtf8Error, }, + + #[error("Invalid JSON in client secret")] + InvalidClientSecretJson { + #[from] + inner: serde_json::Error, + }, + + #[error("Could not parse PEM encoded private key")] + InvalidPrivateKey { + #[from] + inner: pkcs8::Error, + }, +} + +#[derive(Debug, Deserialize)] +pub struct SignInWithApple { + pub private_key: String, + pub team_id: String, + pub key_id: String, } fn client_credentials_for_provider( @@ -61,28 +82,38 @@ fn client_credentials_for_provider( .transpose()?; let client_credentials = match provider.token_endpoint_auth_method { - OAuthClientAuthenticationMethod::None => ClientCredentials::None { client_id }, - OAuthClientAuthenticationMethod::ClientSecretPost => ClientCredentials::ClientSecretPost { - client_id, - client_secret: client_secret.ok_or(ProviderCredentialsError::MissingClientSecret)?, - }, - OAuthClientAuthenticationMethod::ClientSecretBasic => { + UpstreamOAuthProviderTokenAuthMethod::None => ClientCredentials::None { client_id }, + + UpstreamOAuthProviderTokenAuthMethod::ClientSecretPost => { + ClientCredentials::ClientSecretPost { + client_id, + client_secret: client_secret + .ok_or(ProviderCredentialsError::MissingClientSecret)?, + } + } + + UpstreamOAuthProviderTokenAuthMethod::ClientSecretBasic => { ClientCredentials::ClientSecretBasic { client_id, client_secret: client_secret .ok_or(ProviderCredentialsError::MissingClientSecret)?, } } - OAuthClientAuthenticationMethod::ClientSecretJwt => ClientCredentials::ClientSecretJwt { - client_id, - client_secret: client_secret.ok_or(ProviderCredentialsError::MissingClientSecret)?, - signing_algorithm: provider - .token_endpoint_signing_alg - .clone() - .unwrap_or(JsonWebSignatureAlg::Rs256), - token_endpoint: token_endpoint.clone(), - }, - OAuthClientAuthenticationMethod::PrivateKeyJwt => ClientCredentials::PrivateKeyJwt { + + UpstreamOAuthProviderTokenAuthMethod::ClientSecretJwt => { + ClientCredentials::ClientSecretJwt { + client_id, + client_secret: client_secret + .ok_or(ProviderCredentialsError::MissingClientSecret)?, + signing_algorithm: provider + .token_endpoint_signing_alg + .clone() + .unwrap_or(JsonWebSignatureAlg::Rs256), + token_endpoint: token_endpoint.clone(), + } + } + + UpstreamOAuthProviderTokenAuthMethod::PrivateKeyJwt => ClientCredentials::PrivateKeyJwt { client_id, keystore: keystore.clone(), signing_algorithm: provider @@ -91,8 +122,21 @@ fn client_credentials_for_provider( .unwrap_or(JsonWebSignatureAlg::Rs256), token_endpoint: token_endpoint.clone(), }, - // XXX: The database should never have an unsupported method in it - _ => unreachable!(), + + UpstreamOAuthProviderTokenAuthMethod::SignInWithApple => { + let params = client_secret.ok_or(ProviderCredentialsError::MissingClientSecret)?; + let params: SignInWithApple = serde_json::from_str(¶ms)?; + + let key = elliptic_curve::SecretKey::from_pkcs8_pem(¶ms.private_key)?; + + ClientCredentials::SignInWithApple { + client_id, + audience: provider.issuer.clone(), + key, + key_id: params.key_id, + team_id: params.team_id, + } + } }; Ok(client_credentials) diff --git a/crates/handlers/src/upstream_oauth2/template.rs b/crates/handlers/src/upstream_oauth2/template.rs index 740953f89..86942d76a 100644 --- a/crates/handlers/src/upstream_oauth2/template.rs +++ b/crates/handlers/src/upstream_oauth2/template.rs @@ -7,7 +7,76 @@ use std::{collections::HashMap, sync::Arc}; use base64ct::{Base64, Base64Unpadded, Base64Url, Base64UrlUnpadded, Encoding}; -use minijinja::{Environment, Error, ErrorKind, Value}; +use minijinja::{ + value::{Enumerator, Object}, + Environment, Error, ErrorKind, Value, +}; + +/// Context passed to the attribute mapping template +/// +/// The variables available in the template are: +/// - `user`: claims for the user, currently from the ID token. Later, we'll +/// also allow importing from the userinfo endpoint +/// - `id_token_claims`: claims from the ID token +/// - `extra_callback_parameters`: extra parameters passed to the callback +#[derive(Debug, Default)] +pub(crate) struct AttributeMappingContext { + id_token_claims: Option>, + extra_callback_parameters: Option, +} + +impl AttributeMappingContext { + pub fn new() -> Self { + Self::default() + } + + pub fn with_id_token_claims( + mut self, + id_token_claims: HashMap, + ) -> Self { + self.id_token_claims = Some(id_token_claims); + self + } + + pub fn with_extra_callback_parameters( + mut self, + extra_callback_parameters: serde_json::Value, + ) -> Self { + self.extra_callback_parameters = Some(extra_callback_parameters); + self + } + + pub fn build(self) -> Value { + Value::from_object(self) + } +} + +impl Object for AttributeMappingContext { + fn get_value(self: &Arc, name: &Value) -> Option { + match name.as_str()? { + "user" | "id_token_claims" => self.id_token_claims.as_ref().map(Value::from_serialize), + "extra_callback_parameters" => self + .extra_callback_parameters + .as_ref() + .map(Value::from_serialize), + _ => None, + } + } + + fn enumerate(self: &Arc) -> Enumerator { + match ( + self.id_token_claims.is_some(), + self.extra_callback_parameters.is_some(), + ) { + (true, true) => { + Enumerator::Str(&["user", "id_token_claims", "extra_callback_parameters"]) + } + (true, false) => Enumerator::Str(&["user", "id_token_claims"]), + (false, true) => Enumerator::Str(&["extra_callback_parameters"]), + (false, false) => Enumerator::Str(&["user"]), + } + } +} fn b64decode(value: &str) -> Result { // We're not too concerned about the performance of this filter, so we'll just @@ -68,6 +137,18 @@ fn string(value: &Value) -> String { value.to_string() } +fn from_json(value: &str) -> Result { + let value: serde_json::Value = serde_json::from_str(value).map_err(|e| { + minijinja::Error::new( + minijinja::ErrorKind::InvalidOperation, + "Failed to decode JSON", + ) + .with_source(e) + })?; + + Ok(Value::from_serialize(value)) +} + pub fn environment() -> Environment<'static> { let mut env = Environment::new(); @@ -77,6 +158,7 @@ pub fn environment() -> Environment<'static> { env.add_filter("b64encode", b64encode); env.add_filter("tlvdecode", tlvdecode); env.add_filter("string", string); + env.add_filter("from_json", from_json); env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback); diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index d2fba535d..edc27065a 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -343,8 +343,9 @@ mod test { header::{CONTENT_TYPE, LOCATION}, Request, StatusCode, }; - use mas_data_model::UpstreamOAuthProviderClaimsImports; - use mas_iana::oauth::OAuthClientAuthenticationMethod; + use mas_data_model::{ + UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderTokenAuthMethod, + }; use mas_router::Route; use mas_storage::{ upstream_oauth2::{UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository}, @@ -400,7 +401,7 @@ mod test { human_name: Some("First Ltd.".to_owned()), brand_name: None, scope: [OPENID].into_iter().collect(), - token_endpoint_auth_method: OAuthClientAuthenticationMethod::None, + token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None, token_endpoint_signing_alg: None, client_id: "client".to_owned(), encrypted_client_secret: None, @@ -410,6 +411,7 @@ mod test { jwks_uri_override: None, discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc, pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto, + response_mode: mas_data_model::UpstreamOAuthProviderResponseMode::Query, additional_authorization_parameters: Vec::new(), }, ) @@ -435,7 +437,7 @@ mod test { human_name: None, brand_name: None, scope: [OPENID].into_iter().collect(), - token_endpoint_auth_method: OAuthClientAuthenticationMethod::None, + token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None, token_endpoint_signing_alg: None, client_id: "client".to_owned(), encrypted_client_secret: None, @@ -445,6 +447,7 @@ mod test { jwks_uri_override: None, discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc, pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto, + response_mode: mas_data_model::UpstreamOAuthProviderResponseMode::Query, additional_authorization_parameters: Vec::new(), }, ) diff --git a/crates/jose/Cargo.toml b/crates/jose/Cargo.toml index 4a9bf6384..0127e0e20 100644 --- a/crates/jose/Cargo.toml +++ b/crates/jose/Cargo.toml @@ -16,7 +16,7 @@ base64ct = { version = "1.6.0", features = ["std"] } chrono.workspace = true digest = "0.10.7" ecdsa = { version = "0.16.9", features = ["signing", "verifying"] } -elliptic-curve = "0.13.8" +elliptic-curve.workspace = true generic-array = "0.14.7" hmac = "0.12.1" k256 = { version = "0.13.4", features = ["ecdsa"] } diff --git a/crates/keystore/Cargo.toml b/crates/keystore/Cargo.toml index a64bb6853..3d2a914d3 100644 --- a/crates/keystore/Cargo.toml +++ b/crates/keystore/Cargo.toml @@ -16,13 +16,13 @@ aead = { version = "0.5.2", features = ["std"] } const-oid = { version = "0.9.6", features = ["std"] } der = { version = "0.7.9", features = ["std"] } ecdsa = { version = "0.16.9", features = ["std"] } -elliptic-curve = { version = "0.13.8", features = ["std", "pem", "sec1"] } -k256 = { version = "0.13.4", features = ["std"] } -p256 = { version = "0.13.2", features = ["std"] } -p384 = { version = "0.13.0", features = ["std"] } -pem-rfc7468 = { version = "0.7.0", features = ["std"] } -pkcs1 = { version = "0.7.5", features = ["std"] } -pkcs8 = { version = "0.10.2", features = ["std", "pkcs5", "encryption"] } +elliptic-curve.workspace = true +k256.workspace = true +p256.workspace = true +p384.workspace = true +pem-rfc7468.workspace = true +pkcs1.workspace = true +pkcs8.workspace = true rand.workspace = true rsa = { version = "0.9.6", features = ["std", "pem"] } sec1 = { version = "0.7.3", features = ["std"] } diff --git a/crates/oidc-client/Cargo.toml b/crates/oidc-client/Cargo.toml index 224c1b9eb..fd656525c 100644 --- a/crates/oidc-client/Cargo.toml +++ b/crates/oidc-client/Cargo.toml @@ -15,11 +15,15 @@ workspace = true async-trait.workspace = true base64ct = { version = "1.6.0", features = ["std"] } chrono.workspace = true +elliptic-curve.workspace = true form_urlencoded = "1.2.1" headers.workspace = true http.workspace = true language-tags = "0.3.2" mime = "0.3.17" +pem-rfc7468.workspace = true +pkcs8.workspace = true +p256.workspace = true rand.workspace = true reqwest.workspace = true serde.workspace = true diff --git a/crates/oidc-client/src/requests/authorization_code.rs b/crates/oidc-client/src/requests/authorization_code.rs index f3b2f1965..b11773e8a 100644 --- a/crates/oidc-client/src/requests/authorization_code.rs +++ b/crates/oidc-client/src/requests/authorization_code.rs @@ -20,7 +20,7 @@ use oauth2_types::{ prelude::CodeChallengeMethodExt, requests::{ AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, AuthorizationRequest, - Display, Prompt, + Display, Prompt, ResponseMode, }, scope::{Scope, OPENID}, }; @@ -89,6 +89,9 @@ pub struct AuthorizationRequestData { /// Requested Authentication Context Class Reference values. pub acr_values: Option>, + + /// Requested response mode. + pub response_mode: Option, } impl AuthorizationRequestData { @@ -108,6 +111,7 @@ impl AuthorizationRequestData { id_token_hint: None, login_hint: None, acr_values: None, + response_mode: None, } } @@ -170,6 +174,13 @@ impl AuthorizationRequestData { self.acr_values = Some(acr_values); self } + + /// Set the `response_mode` field of this `AuthorizationRequestData`. + #[must_use] + pub fn with_response_mode(mut self, response_mode: ResponseMode) -> Self { + self.response_mode = Some(response_mode); + self + } } /// The data necessary to validate a response from the Token endpoint in the @@ -215,6 +226,7 @@ fn build_authorization_request( id_token_hint, login_hint, acr_values, + response_mode, } = authorization_data; // Generate a random CSRF "state" token and a nonce. @@ -252,7 +264,7 @@ fn build_authorization_request( redirect_uri: Some(redirect_uri.clone()), scope, state: Some(state.clone()), - response_mode: None, + response_mode, nonce: Some(nonce.clone()), display, prompt, diff --git a/crates/oidc-client/src/types/client_credentials.rs b/crates/oidc-client/src/types/client_credentials.rs index b7095b576..b0378d6cc 100644 --- a/crates/oidc-client/src/types/client_credentials.rs +++ b/crates/oidc-client/src/types/client_credentials.rs @@ -14,7 +14,7 @@ use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod use mas_jose::{ claims::{self, ClaimError}, constraints::Constrainable, - jwa::SymmetricKey, + jwa::{AsymmetricSigningKey, SymmetricKey}, jwt::{JsonWebSignatureHeader, Jwt}, }; use mas_keystore::Keystore; @@ -97,6 +97,24 @@ pub enum ClientCredentials { /// The URL of the issuer's Token endpoint. token_endpoint: Url, }, + + /// The client authenticates like Sign in with Apple wants + SignInWithApple { + /// The unique ID for the client. + client_id: String, + + /// The audience to use. Usually `https://appleid.apple.com` + audience: String, + + /// The ECDSA key used to sign + key: elliptic_curve::SecretKey, + + /// The key ID + key_id: String, + + /// The Apple Team ID + team_id: String, + }, } impl ClientCredentials { @@ -108,12 +126,14 @@ impl ClientCredentials { | ClientCredentials::ClientSecretBasic { client_id, .. } | ClientCredentials::ClientSecretPost { client_id, .. } | ClientCredentials::ClientSecretJwt { client_id, .. } - | ClientCredentials::PrivateKeyJwt { client_id, .. } => client_id, + | ClientCredentials::PrivateKeyJwt { client_id, .. } + | ClientCredentials::SignInWithApple { client_id, .. } => client_id, } } /// Apply these [`ClientCredentials`] to the given request with the given /// form. + #[allow(clippy::too_many_lines)] pub(crate) fn authenticated_form( &self, request: reqwest::RequestBuilder, @@ -217,6 +237,39 @@ impl ClientCredentials { client_assertion_type: Some(JwtBearerClientAssertionType), }) } + + ClientCredentials::SignInWithApple { + client_id, + audience, + key, + key_id, + team_id, + } => { + // SIWA expects a signed JWT as client secret + // https://developer.apple.com/documentation/accountorganizationaldatasharing/creating-a-client-secret + let signer = AsymmetricSigningKey::es256(key.clone()); + + let mut claims = HashMap::new(); + + claims::ISS.insert(&mut claims, team_id)?; + claims::SUB.insert(&mut claims, client_id)?; + claims::AUD.insert(&mut claims, audience.clone())?; + claims::IAT.insert(&mut claims, now)?; + claims::EXP.insert(&mut claims, now + Duration::microseconds(60 * 1000 * 1000))?; + + let header = + JsonWebSignatureHeader::new(JsonWebSignatureAlg::Es256).with_kid(key_id); + + let client_secret = Jwt::sign(header, claims, &signer)?; + + request.form(&RequestWithClientCredentials { + body: form, + client_id, + client_secret: Some(client_secret.as_str()), + client_assertion: None, + client_assertion_type: None, + }) + } }; Ok(request) @@ -260,6 +313,17 @@ impl fmt::Debug for ClientCredentials { .field("signing_algorithm", signing_algorithm) .field("token_endpoint", token_endpoint) .finish_non_exhaustive(), + Self::SignInWithApple { + client_id, + key_id, + team_id, + .. + } => f + .debug_struct("SignInWithApple") + .field("client_id", client_id) + .field("key_id", key_id) + .field("team_id", team_id) + .finish_non_exhaustive(), } } } diff --git a/crates/storage-pg/.sqlx/query-b9875a270f7e753e48075ccae233df6e24a91775ceb877735508c1d5b2300d64.json b/crates/storage-pg/.sqlx/query-5516235e0983fb64d18e82dbe3e34f966ed71a0ed59be0d48ec66fedf64e707d.json similarity index 62% rename from crates/storage-pg/.sqlx/query-b9875a270f7e753e48075ccae233df6e24a91775ceb877735508c1d5b2300d64.json rename to crates/storage-pg/.sqlx/query-5516235e0983fb64d18e82dbe3e34f966ed71a0ed59be0d48ec66fedf64e707d.json index 3a4483604..96ad3513a 100644 --- a/crates/storage-pg/.sqlx/query-b9875a270f7e753e48075ccae233df6e24a91775ceb877735508c1d5b2300d64.json +++ b/crates/storage-pg/.sqlx/query-5516235e0983fb64d18e82dbe3e34f966ed71a0ed59be0d48ec66fedf64e707d.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2,\n id_token = $3\n WHERE upstream_oauth_authorization_session_id = $4\n ", + "query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2,\n id_token = $3,\n extra_callback_parameters = $4\n WHERE upstream_oauth_authorization_session_id = $5\n ", "describe": { "columns": [], "parameters": { @@ -8,10 +8,11 @@ "Uuid", "Timestamptz", "Text", + "Jsonb", "Uuid" ] }, "nullable": [] }, - "hash": "b9875a270f7e753e48075ccae233df6e24a91775ceb877735508c1d5b2300d64" + "hash": "5516235e0983fb64d18e82dbe3e34f966ed71a0ed59be0d48ec66fedf64e707d" } diff --git a/crates/storage-pg/.sqlx/query-5d9f3d47ce6164b3f81aa09ef4fd8d5cd070945fd497d209ac1df99abcfb7c5d.json b/crates/storage-pg/.sqlx/query-6b133c3c6bfc3c80a21f6f72d0a6468f748ed59e88d8c904bb0a4bbfee43a67f.json similarity index 87% rename from crates/storage-pg/.sqlx/query-5d9f3d47ce6164b3f81aa09ef4fd8d5cd070945fd497d209ac1df99abcfb7c5d.json rename to crates/storage-pg/.sqlx/query-6b133c3c6bfc3c80a21f6f72d0a6468f748ed59e88d8c904bb0a4bbfee43a67f.json index f6f2b0dbc..1c56ec95b 100644 --- a/crates/storage-pg/.sqlx/query-5d9f3d47ce6164b3f81aa09ef4fd8d5cd070945fd497d209ac1df99abcfb7c5d.json +++ b/crates/storage-pg/.sqlx/query-6b133c3c6bfc3c80a21f6f72d0a6468f748ed59e88d8c904bb0a4bbfee43a67f.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n disabled_at,\n claims_imports as \"claims_imports: Json\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n discovery_mode,\n pkce_mode,\n additional_parameters as \"additional_parameters: Json>\"\n FROM upstream_oauth_providers\n WHERE disabled_at IS NULL\n ", + "query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n disabled_at,\n claims_imports as \"claims_imports: Json\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n discovery_mode,\n pkce_mode,\n response_mode,\n additional_parameters as \"additional_parameters: Json>\"\n FROM upstream_oauth_providers\n WHERE disabled_at IS NULL\n ", "describe": { "columns": [ { @@ -90,6 +90,11 @@ }, { "ordinal": 17, + "name": "response_mode", + "type_info": "Text" + }, + { + "ordinal": 18, "name": "additional_parameters: Json>", "type_info": "Jsonb" } @@ -115,8 +120,9 @@ true, false, false, + false, true ] }, - "hash": "5d9f3d47ce6164b3f81aa09ef4fd8d5cd070945fd497d209ac1df99abcfb7c5d" + "hash": "6b133c3c6bfc3c80a21f6f72d0a6468f748ed59e88d8c904bb0a4bbfee43a67f" } diff --git a/crates/storage-pg/.sqlx/query-51b204376c63671a47b73ee8b3f8e669f90933f7e81ba744dca88d6bb94bf96a.json b/crates/storage-pg/.sqlx/query-73f4e5a724a432f1328c6112185cdc9c7a1ae1de45a2a8c02e7a2b8020b41160.json similarity index 87% rename from crates/storage-pg/.sqlx/query-51b204376c63671a47b73ee8b3f8e669f90933f7e81ba744dca88d6bb94bf96a.json rename to crates/storage-pg/.sqlx/query-73f4e5a724a432f1328c6112185cdc9c7a1ae1de45a2a8c02e7a2b8020b41160.json index ac9f681cd..866790c89 100644 --- a/crates/storage-pg/.sqlx/query-51b204376c63671a47b73ee8b3f8e669f90933f7e81ba744dca88d6bb94bf96a.json +++ b/crates/storage-pg/.sqlx/query-73f4e5a724a432f1328c6112185cdc9c7a1ae1de45a2a8c02e7a2b8020b41160.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n disabled_at,\n claims_imports as \"claims_imports: Json\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n discovery_mode,\n pkce_mode,\n additional_parameters as \"additional_parameters: Json>\"\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n ", + "query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n disabled_at,\n claims_imports as \"claims_imports: Json\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n discovery_mode,\n pkce_mode,\n response_mode,\n additional_parameters as \"additional_parameters: Json>\"\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n ", "describe": { "columns": [ { @@ -90,6 +90,11 @@ }, { "ordinal": 17, + "name": "response_mode", + "type_info": "Text" + }, + { + "ordinal": 18, "name": "additional_parameters: Json>", "type_info": "Jsonb" } @@ -117,8 +122,9 @@ true, false, false, + false, true ] }, - "hash": "51b204376c63671a47b73ee8b3f8e669f90933f7e81ba744dca88d6bb94bf96a" + "hash": "73f4e5a724a432f1328c6112185cdc9c7a1ae1de45a2a8c02e7a2b8020b41160" } diff --git a/crates/storage-pg/.sqlx/query-67ab838035946ddc15b43dd2f79d10b233d07e863b3a5c776c5db97cff263c8c.json b/crates/storage-pg/.sqlx/query-7d329e0c57f36b9ffe2aa7ddf4a21e293522c00009cca0222524b0c73f6eee30.json similarity index 76% rename from crates/storage-pg/.sqlx/query-67ab838035946ddc15b43dd2f79d10b233d07e863b3a5c776c5db97cff263c8c.json rename to crates/storage-pg/.sqlx/query-7d329e0c57f36b9ffe2aa7ddf4a21e293522c00009cca0222524b0c73f6eee30.json index 0b378d382..57fc34e9b 100644 --- a/crates/storage-pg/.sqlx/query-67ab838035946ddc15b43dd2f79d10b233d07e863b3a5c776c5db97cff263c8c.json +++ b/crates/storage-pg/.sqlx/query-7d329e0c57f36b9ffe2aa7ddf4a21e293522c00009cca0222524b0c73f6eee30.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n SELECT\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n upstream_oauth_link_id,\n state,\n code_challenge_verifier,\n nonce,\n id_token,\n created_at,\n completed_at,\n consumed_at\n FROM upstream_oauth_authorization_sessions\n WHERE upstream_oauth_authorization_session_id = $1\n ", + "query": "\n SELECT\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n upstream_oauth_link_id,\n state,\n code_challenge_verifier,\n nonce,\n id_token,\n extra_callback_parameters,\n created_at,\n completed_at,\n consumed_at\n FROM upstream_oauth_authorization_sessions\n WHERE upstream_oauth_authorization_session_id = $1\n ", "describe": { "columns": [ { @@ -40,16 +40,21 @@ }, { "ordinal": 7, + "name": "extra_callback_parameters", + "type_info": "Jsonb" + }, + { + "ordinal": 8, "name": "created_at", "type_info": "Timestamptz" }, { - "ordinal": 8, + "ordinal": 9, "name": "completed_at", "type_info": "Timestamptz" }, { - "ordinal": 9, + "ordinal": 10, "name": "consumed_at", "type_info": "Timestamptz" } @@ -67,10 +72,11 @@ true, false, true, + true, false, true, true ] }, - "hash": "67ab838035946ddc15b43dd2f79d10b233d07e863b3a5c776c5db97cff263c8c" + "hash": "7d329e0c57f36b9ffe2aa7ddf4a21e293522c00009cca0222524b0c73f6eee30" } diff --git a/crates/storage-pg/.sqlx/query-9aa8fa3a6277f67b2bf5a5ea5429a61e7997ff4f3e8d0dc772448a1f97e1e390.json b/crates/storage-pg/.sqlx/query-9aa8fa3a6277f67b2bf5a5ea5429a61e7997ff4f3e8d0dc772448a1f97e1e390.json deleted file mode 100644 index c016eb215..000000000 --- a/crates/storage-pg/.sqlx/query-9aa8fa3a6277f67b2bf5a5ea5429a61e7997ff4f3e8d0dc772448a1f97e1e390.json +++ /dev/null @@ -1,38 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n claims_imports,\n authorization_endpoint_override,\n token_endpoint_override,\n jwks_uri_override,\n discovery_mode,\n pkce_mode,\n additional_parameters,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9,\n $10, $11, $12, $13, $14, $15, $16, $17)\n ON CONFLICT (upstream_oauth_provider_id)\n DO UPDATE\n SET\n issuer = EXCLUDED.issuer,\n human_name = EXCLUDED.human_name,\n brand_name = EXCLUDED.brand_name,\n scope = EXCLUDED.scope,\n token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,\n token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,\n disabled_at = NULL,\n client_id = EXCLUDED.client_id,\n encrypted_client_secret = EXCLUDED.encrypted_client_secret,\n claims_imports = EXCLUDED.claims_imports,\n authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,\n token_endpoint_override = EXCLUDED.token_endpoint_override,\n jwks_uri_override = EXCLUDED.jwks_uri_override,\n discovery_mode = EXCLUDED.discovery_mode,\n pkce_mode = EXCLUDED.pkce_mode,\n additional_parameters = EXCLUDED.additional_parameters\n RETURNING created_at\n ", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "created_at", - "type_info": "Timestamptz" - } - ], - "parameters": { - "Left": [ - "Uuid", - "Text", - "Text", - "Text", - "Text", - "Text", - "Text", - "Text", - "Text", - "Jsonb", - "Text", - "Text", - "Text", - "Text", - "Text", - "Jsonb", - "Timestamptz" - ] - }, - "nullable": [ - false - ] - }, - "hash": "9aa8fa3a6277f67b2bf5a5ea5429a61e7997ff4f3e8d0dc772448a1f97e1e390" -} diff --git a/crates/storage-pg/.sqlx/query-e36ed76d0176edf8c4a029f017b8f368a529b2d32a54c52f6a28b9e615716f4d.json b/crates/storage-pg/.sqlx/query-e36ed76d0176edf8c4a029f017b8f368a529b2d32a54c52f6a28b9e615716f4d.json new file mode 100644 index 000000000..12bfb5e3b --- /dev/null +++ b/crates/storage-pg/.sqlx/query-e36ed76d0176edf8c4a029f017b8f368a529b2d32a54c52f6a28b9e615716f4d.json @@ -0,0 +1,39 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n claims_imports,\n authorization_endpoint_override,\n token_endpoint_override,\n jwks_uri_override,\n discovery_mode,\n pkce_mode,\n response_mode,\n additional_parameters,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9,\n $10, $11, $12, $13, $14, $15, $16, $17, $18)\n ON CONFLICT (upstream_oauth_provider_id)\n DO UPDATE\n SET\n issuer = EXCLUDED.issuer,\n human_name = EXCLUDED.human_name,\n brand_name = EXCLUDED.brand_name,\n scope = EXCLUDED.scope,\n token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,\n token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,\n disabled_at = NULL,\n client_id = EXCLUDED.client_id,\n encrypted_client_secret = EXCLUDED.encrypted_client_secret,\n claims_imports = EXCLUDED.claims_imports,\n authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,\n token_endpoint_override = EXCLUDED.token_endpoint_override,\n jwks_uri_override = EXCLUDED.jwks_uri_override,\n discovery_mode = EXCLUDED.discovery_mode,\n pkce_mode = EXCLUDED.pkce_mode,\n response_mode = EXCLUDED.response_mode,\n additional_parameters = EXCLUDED.additional_parameters\n RETURNING created_at\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "created_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Uuid", + "Text", + "Text", + "Text", + "Text", + "Text", + "Text", + "Text", + "Text", + "Jsonb", + "Text", + "Text", + "Text", + "Text", + "Text", + "Text", + "Jsonb", + "Timestamptz" + ] + }, + "nullable": [ + false + ] + }, + "hash": "e36ed76d0176edf8c4a029f017b8f368a529b2d32a54c52f6a28b9e615716f4d" +} diff --git a/crates/storage-pg/.sqlx/query-1f131aa966a4358d83e7247d3e30451f8bcf5df20faf46a4a4c0d4a36d1ff173.json b/crates/storage-pg/.sqlx/query-ebb1a78003293376a52de830f89f6f526ad1c5c823328463a6525d3c3d0d95c6.json similarity index 75% rename from crates/storage-pg/.sqlx/query-1f131aa966a4358d83e7247d3e30451f8bcf5df20faf46a4a4c0d4a36d1ff173.json rename to crates/storage-pg/.sqlx/query-ebb1a78003293376a52de830f89f6f526ad1c5c823328463a6525d3c3d0d95c6.json index 2a92e950c..aa2a95c1f 100644 --- a/crates/storage-pg/.sqlx/query-1f131aa966a4358d83e7247d3e30451f8bcf5df20faf46a4a4c0d4a36d1ff173.json +++ b/crates/storage-pg/.sqlx/query-ebb1a78003293376a52de830f89f6f526ad1c5c823328463a6525d3c3d0d95c6.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n claims_imports,\n authorization_endpoint_override,\n token_endpoint_override,\n jwks_uri_override,\n discovery_mode,\n pkce_mode,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9,\n $10, $11, $12, $13, $14, $15, $16)\n ", + "query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n claims_imports,\n authorization_endpoint_override,\n token_endpoint_override,\n jwks_uri_override,\n discovery_mode,\n pkce_mode,\n response_mode,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9,\n $10, $11, $12, $13, $14, $15, $16, $17)\n ", "describe": { "columns": [], "parameters": { @@ -20,10 +20,11 @@ "Text", "Text", "Text", + "Text", "Timestamptz" ] }, "nullable": [] }, - "hash": "1f131aa966a4358d83e7247d3e30451f8bcf5df20faf46a4a4c0d4a36d1ff173" + "hash": "ebb1a78003293376a52de830f89f6f526ad1c5c823328463a6525d3c3d0d95c6" } diff --git a/crates/storage-pg/migrations/20241115163340_upstream_oauth2_response_mode.sql b/crates/storage-pg/migrations/20241115163340_upstream_oauth2_response_mode.sql new file mode 100644 index 000000000..8d65bd7f5 --- /dev/null +++ b/crates/storage-pg/migrations/20241115163340_upstream_oauth2_response_mode.sql @@ -0,0 +1,8 @@ +-- Copyright 2024 New Vector Ltd. +-- +-- SPDX-License-Identifier: AGPL-3.0-only +-- Please see LICENSE in the repository root for full details. + +-- Add the response_mode column to the upstream_oauth_providers table +ALTER TABLE "upstream_oauth_providers" + ADD COLUMN "response_mode" text NOT NULL DEFAULT 'query'; diff --git a/crates/storage-pg/migrations/20241118115314_upstream_oauth2_extra_query_params.sql b/crates/storage-pg/migrations/20241118115314_upstream_oauth2_extra_query_params.sql new file mode 100644 index 000000000..0e900e0af --- /dev/null +++ b/crates/storage-pg/migrations/20241118115314_upstream_oauth2_extra_query_params.sql @@ -0,0 +1,9 @@ +-- Copyright 2024 New Vector Ltd. +-- +-- SPDX-License-Identifier: AGPL-3.0-only +-- Please see LICENSE in the repository root for full details. + +-- Add a column to the upstream_oauth_authorization_sessions table to store +-- extra query parameters +ALTER TABLE "upstream_oauth_authorization_sessions" + ADD COLUMN "extra_callback_parameters" JSONB; diff --git a/crates/storage-pg/src/iden.rs b/crates/storage-pg/src/iden.rs index a87bef0db..b9a00333d 100644 --- a/crates/storage-pg/src/iden.rs +++ b/crates/storage-pg/src/iden.rs @@ -103,6 +103,7 @@ pub enum UpstreamOAuthProviders { ClaimsImports, DiscoveryMode, PkceMode, + ResponseMode, AdditionalParameters, JwksUriOverride, TokenEndpointOverride, diff --git a/crates/storage-pg/src/upstream_oauth2/mod.rs b/crates/storage-pg/src/upstream_oauth2/mod.rs index c9ca8afdc..a544c9cd3 100644 --- a/crates/storage-pg/src/upstream_oauth2/mod.rs +++ b/crates/storage-pg/src/upstream_oauth2/mod.rs @@ -19,7 +19,9 @@ pub use self::{ #[cfg(test)] mod tests { use chrono::Duration; - use mas_data_model::UpstreamOAuthProviderClaimsImports; + use mas_data_model::{ + UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderTokenAuthMethod, + }; use mas_storage::{ clock::MockClock, upstream_oauth2::{ @@ -57,8 +59,7 @@ mod tests { human_name: None, brand_name: None, scope: Scope::from_iter([OPENID]), - token_endpoint_auth_method: - mas_iana::oauth::OAuthClientAuthenticationMethod::None, + token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None, token_endpoint_signing_alg: None, client_id: "client-id".to_owned(), encrypted_client_secret: None, @@ -68,6 +69,7 @@ mod tests { jwks_uri_override: None, discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc, pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto, + response_mode: mas_data_model::UpstreamOAuthProviderResponseMode::Query, additional_authorization_parameters: Vec::new(), }, ) @@ -143,7 +145,7 @@ mod tests { let session = repo .upstream_oauth_session() - .complete_with_link(&clock, session, &link, None) + .complete_with_link(&clock, session, &link, None, None) .await .unwrap(); // Reload the session @@ -299,8 +301,7 @@ mod tests { human_name: None, brand_name: None, scope: scope.clone(), - token_endpoint_auth_method: - mas_iana::oauth::OAuthClientAuthenticationMethod::None, + token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None, token_endpoint_signing_alg: None, client_id, encrypted_client_secret: None, @@ -310,6 +311,7 @@ mod tests { jwks_uri_override: None, discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc, pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto, + response_mode: mas_data_model::UpstreamOAuthProviderResponseMode::Query, additional_authorization_parameters: Vec::new(), }, ) diff --git a/crates/storage-pg/src/upstream_oauth2/provider.rs b/crates/storage-pg/src/upstream_oauth2/provider.rs index 0c54992c3..4b384c95a 100644 --- a/crates/storage-pg/src/upstream_oauth2/provider.rs +++ b/crates/storage-pg/src/upstream_oauth2/provider.rs @@ -64,6 +64,7 @@ struct ProviderLookup { token_endpoint_override: Option, discovery_mode: String, pkce_mode: String, + response_mode: String, additional_parameters: Option>>, } @@ -141,6 +142,13 @@ impl TryFrom for UpstreamOAuthProvider { .source(e) })?; + let response_mode = value.response_mode.parse().map_err(|e| { + DatabaseInconsistencyError::on("upstream_oauth_providers") + .column("response_mode") + .row(id) + .source(e) + })?; + let additional_authorization_parameters = value .additional_parameters .map(|Json(x)| x) @@ -164,6 +172,7 @@ impl TryFrom for UpstreamOAuthProvider { jwks_uri_override, discovery_mode, pkce_mode, + response_mode, additional_authorization_parameters, }) } @@ -217,6 +226,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' token_endpoint_override, discovery_mode, pkce_mode, + response_mode, additional_parameters as "additional_parameters: Json>" FROM upstream_oauth_providers WHERE upstream_oauth_provider_id = $1 @@ -274,9 +284,10 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' jwks_uri_override, discovery_mode, pkce_mode, + response_mode, created_at ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, - $10, $11, $12, $13, $14, $15, $16) + $10, $11, $12, $13, $14, $15, $16, $17) "#, Uuid::from(id), ¶ms.issuer, @@ -302,6 +313,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' params.jwks_uri_override.as_ref().map(ToString::to_string), params.discovery_mode.as_str(), params.pkce_mode.as_str(), + params.response_mode.as_str(), created_at, ) .traced() @@ -326,6 +338,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' jwks_uri_override: params.jwks_uri_override, discovery_mode: params.discovery_mode, pkce_mode: params.pkce_mode, + response_mode: params.response_mode, additional_authorization_parameters: params.additional_authorization_parameters, }) } @@ -433,10 +446,11 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' jwks_uri_override, discovery_mode, pkce_mode, + response_mode, additional_parameters, created_at ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, - $10, $11, $12, $13, $14, $15, $16, $17) + $10, $11, $12, $13, $14, $15, $16, $17, $18) ON CONFLICT (upstream_oauth_provider_id) DO UPDATE SET @@ -455,6 +469,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' jwks_uri_override = EXCLUDED.jwks_uri_override, discovery_mode = EXCLUDED.discovery_mode, pkce_mode = EXCLUDED.pkce_mode, + response_mode = EXCLUDED.response_mode, additional_parameters = EXCLUDED.additional_parameters RETURNING created_at "#, @@ -482,6 +497,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' params.jwks_uri_override.as_ref().map(ToString::to_string), params.discovery_mode.as_str(), params.pkce_mode.as_str(), + params.response_mode.as_str(), Json(¶ms.additional_authorization_parameters) as _, created_at, ) @@ -507,6 +523,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' jwks_uri_override: params.jwks_uri_override, discovery_mode: params.discovery_mode, pkce_mode: params.pkce_mode, + response_mode: params.response_mode, additional_authorization_parameters: params.additional_authorization_parameters, }) } @@ -676,6 +693,13 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' )), ProviderLookupIden::PkceMode, ) + .expr_as( + Expr::col(( + UpstreamOAuthProviders::Table, + UpstreamOAuthProviders::ResponseMode, + )), + ProviderLookupIden::ResponseMode, + ) .expr_as( Expr::col(( UpstreamOAuthProviders::Table, @@ -770,6 +794,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' token_endpoint_override, discovery_mode, pkce_mode, + response_mode, additional_parameters as "additional_parameters: Json>" FROM upstream_oauth_providers WHERE disabled_at IS NULL diff --git a/crates/storage-pg/src/upstream_oauth2/session.rs b/crates/storage-pg/src/upstream_oauth2/session.rs index fb27da5f8..e3f28b5f4 100644 --- a/crates/storage-pg/src/upstream_oauth2/session.rs +++ b/crates/storage-pg/src/upstream_oauth2/session.rs @@ -43,6 +43,7 @@ struct SessionLookup { created_at: DateTime, completed_at: Option>, consumed_at: Option>, + extra_callback_parameters: Option, } impl TryFrom for UpstreamOAuthAuthorizationSession { @@ -53,25 +54,32 @@ impl TryFrom for UpstreamOAuthAuthorizationSession { let state = match ( value.upstream_oauth_link_id, value.id_token, + value.extra_callback_parameters, value.completed_at, value.consumed_at, ) { - (None, None, None, None) => UpstreamOAuthAuthorizationSessionState::Pending, - (Some(link_id), id_token, Some(completed_at), None) => { + (None, None, None, None, None) => UpstreamOAuthAuthorizationSessionState::Pending, + (Some(link_id), id_token, extra_callback_parameters, Some(completed_at), None) => { UpstreamOAuthAuthorizationSessionState::Completed { completed_at, link_id: link_id.into(), id_token, + extra_callback_parameters, } } - (Some(link_id), id_token, Some(completed_at), Some(consumed_at)) => { - UpstreamOAuthAuthorizationSessionState::Consumed { - completed_at, - link_id: link_id.into(), - id_token, - consumed_at, - } - } + ( + Some(link_id), + id_token, + extra_callback_parameters, + Some(completed_at), + Some(consumed_at), + ) => UpstreamOAuthAuthorizationSessionState::Consumed { + completed_at, + link_id: link_id.into(), + id_token, + extra_callback_parameters, + consumed_at, + }, _ => { return Err( DatabaseInconsistencyError::on("upstream_oauth_authorization_sessions").row(id), @@ -119,6 +127,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> code_challenge_verifier, nonce, id_token, + extra_callback_parameters, created_at, completed_at, consumed_at @@ -216,6 +225,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, upstream_oauth_link: &UpstreamOAuthLink, id_token: Option, + extra_callback_parameters: Option, ) -> Result { let completed_at = clock.now(); @@ -224,12 +234,14 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> UPDATE upstream_oauth_authorization_sessions SET upstream_oauth_link_id = $1, completed_at = $2, - id_token = $3 - WHERE upstream_oauth_authorization_session_id = $4 + id_token = $3, + extra_callback_parameters = $4 + WHERE upstream_oauth_authorization_session_id = $5 "#, Uuid::from(upstream_oauth_link.id), completed_at, id_token, + extra_callback_parameters, Uuid::from(upstream_oauth_authorization_session.id), ) .traced() @@ -237,7 +249,12 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> .await?; let upstream_oauth_authorization_session = upstream_oauth_authorization_session - .complete(completed_at, upstream_oauth_link, id_token) + .complete( + completed_at, + upstream_oauth_link, + id_token, + extra_callback_parameters, + ) .map_err(DatabaseError::to_invalid_operation)?; Ok(upstream_oauth_authorization_session) diff --git a/crates/storage/src/upstream_oauth2/provider.rs b/crates/storage/src/upstream_oauth2/provider.rs index 7d7d4dbb0..a5309d5d1 100644 --- a/crates/storage/src/upstream_oauth2/provider.rs +++ b/crates/storage/src/upstream_oauth2/provider.rs @@ -9,9 +9,10 @@ use std::marker::PhantomData; use async_trait::async_trait; use mas_data_model::{ UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode, - UpstreamOAuthProviderPkceMode, + UpstreamOAuthProviderPkceMode, UpstreamOAuthProviderResponseMode, + UpstreamOAuthProviderTokenAuthMethod, }; -use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; +use mas_iana::jose::JsonWebSignatureAlg; use oauth2_types::scope::Scope; use rand_core::RngCore; use ulid::Ulid; @@ -35,7 +36,7 @@ pub struct UpstreamOAuthProviderParams { pub scope: Scope, /// The token endpoint authentication method - pub token_endpoint_auth_method: OAuthClientAuthenticationMethod, + pub token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod, /// The JWT signing algorithm to use when then `client_secret_jwt` or /// `private_key_jwt` authentication methods are used @@ -67,6 +68,9 @@ pub struct UpstreamOAuthProviderParams { /// How should PKCE be used pub pkce_mode: UpstreamOAuthProviderPkceMode, + /// What response mode it should ask + pub response_mode: UpstreamOAuthProviderResponseMode, + /// Additional parameters to include in the authorization request pub additional_authorization_parameters: Vec<(String, String)>, } diff --git a/crates/storage/src/upstream_oauth2/session.rs b/crates/storage/src/upstream_oauth2/session.rs index 80da6135a..a9a438a3a 100644 --- a/crates/storage/src/upstream_oauth2/session.rs +++ b/crates/storage/src/upstream_oauth2/session.rs @@ -74,6 +74,8 @@ pub trait UpstreamOAuthSessionRepository: Send + Sync { /// * `upstream_oauth_link`: the link to associate with the session /// * `id_token`: the ID token returned by the upstream OAuth provider, if /// present + /// * `extra_callback_parameters`: the extra query parameters returned in + /// the callback, if any /// /// # Errors /// @@ -84,6 +86,7 @@ pub trait UpstreamOAuthSessionRepository: Send + Sync { upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, upstream_oauth_link: &UpstreamOAuthLink, id_token: Option, + extra_callback_parameters: Option, ) -> Result; /// Mark a session as consumed @@ -127,6 +130,7 @@ repository_impl!(UpstreamOAuthSessionRepository: upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, upstream_oauth_link: &UpstreamOAuthLink, id_token: Option, + extra_callback_parameters: Option, ) -> Result; async fn consume( diff --git a/crates/templates/src/context.rs b/crates/templates/src/context.rs index 762930b64..3f384b303 100644 --- a/crates/templates/src/context.rs +++ b/crates/templates/src/context.rs @@ -1460,7 +1460,7 @@ impl TemplateContext for DeviceConsentContext { /// Context used by the `form_post.html` template #[derive(Serialize)] pub struct FormPostContext { - redirect_uri: Url, + redirect_uri: Option, params: T, } @@ -1473,7 +1473,7 @@ impl TemplateContext for FormPostContext { sample_params .into_iter() .map(|params| FormPostContext { - redirect_uri: "https://example.com/callback".parse().unwrap(), + redirect_uri: "https://example.com/callback".parse().ok(), params, }) .collect() @@ -1481,13 +1481,34 @@ impl TemplateContext for FormPostContext { } impl FormPostContext { - /// Constructs a context for the `form_post` response mode form - pub fn new(redirect_uri: Url, params: T) -> Self { + /// Constructs a context for the `form_post` response mode form for a given + /// URL + pub fn new_for_url(redirect_uri: Url, params: T) -> Self { Self { - redirect_uri, + redirect_uri: Some(redirect_uri), params, } } + + /// Constructs a context for the `form_post` response mode form for the + /// current URL + pub fn new_for_current_url(params: T) -> Self { + Self { + redirect_uri: None, + params, + } + } + + /// Add the language to the context + /// + /// This is usually implemented by the [`TemplateContext`] trait, but it is + /// annoying to make it work because of the generic parameter + pub fn with_language(self, lang: &DataLocale) -> WithLanguage { + WithLanguage { + lang: lang.to_string(), + inner: self, + } + } } /// Context used by the `error.html` template diff --git a/crates/templates/src/lib.rs b/crates/templates/src/lib.rs index 7d668a5ee..9314d4e33 100644 --- a/crates/templates/src/lib.rs +++ b/crates/templates/src/lib.rs @@ -368,7 +368,7 @@ register_templates! { pub fn render_reauth(WithLanguage>>) { "pages/reauth.html" } /// Render the form used by the form_post response mode - pub fn render_form_post(FormPostContext) { "form_post.html" } + pub fn render_form_post(WithLanguage>) { "form_post.html" } /// Render the HTML error page pub fn render_error(ErrorContext) { "pages/error.html" } diff --git a/docs/config.schema.json b/docs/config.schema.json index 701f44013..9324154ca 100644 --- a/docs/config.schema.json +++ b/docs/config.schema.json @@ -1859,6 +1859,14 @@ } ] }, + "sign_in_with_apple": { + "description": "Additional parameters for the `sign_in_with_apple` method", + "allOf": [ + { + "$ref": "#/definitions/SignInWithApple" + } + ] + }, "token_endpoint_auth_signing_alg": { "description": "The JWS algorithm to use when authenticating the client with the provider\n\nUsed by the `client_secret_jwt` and `private_key_jwt` methods", "allOf": [ @@ -1902,6 +1910,14 @@ "type": "string", "format": "uri" }, + "response_mode": { + "description": "The response mode we ask the provider to use for the callback", + "allOf": [ + { + "$ref": "#/definitions/ResponseMode" + } + ] + }, "claims_imports": { "description": "How claims should be imported from the `id_token` provided by the provider", "allOf": [ @@ -1956,9 +1972,38 @@ "enum": [ "private_key_jwt" ] + }, + { + "description": "`sign_in_with_apple`: a special method for Signin with Apple", + "type": "string", + "enum": [ + "sign_in_with_apple" + ] } ] }, + "SignInWithApple": { + "type": "object", + "required": [ + "key_id", + "private_key", + "team_id" + ], + "properties": { + "private_key": { + "description": "The private key used to sign the `id_token`", + "type": "string" + }, + "team_id": { + "description": "The Team ID of the Apple Developer Portal", + "type": "string" + }, + "key_id": { + "description": "The key ID of the Apple Developer Portal", + "type": "string" + } + } + }, "DiscoveryMode": { "description": "How to discover the provider's configuration", "oneOf": [ @@ -2011,6 +2056,25 @@ } ] }, + "ResponseMode": { + "description": "The response mode we ask the provider to use for the callback", + "oneOf": [ + { + "description": "`query`: The provider will send the response as a query string in the URL search parameters", + "type": "string", + "enum": [ + "query" + ] + }, + { + "description": "`form_post`: The provider will send the response as a POST request with the response parameters in the request body\n\n", + "type": "string", + "enum": [ + "form_post" + ] + } + ] + }, "ClaimsImports": { "description": "How claims should be imported", "type": "object", diff --git a/docs/setup/sso.md b/docs/setup/sso.md index cd359aef6..770b46db5 100644 --- a/docs/setup/sso.md +++ b/docs/setup/sso.md @@ -68,6 +68,39 @@ If there is only one upstream provider configured and the local password databas This section contains sample configurations for popular OIDC providers. +### Apple + +Sign-in with Apple uses special non-standard for authenticating clients, which requires a special configuration. + +```yaml +upstream_oauth2: + providers: + - client_id: 01JAYS74TCG3BTWKADN5Q4518C + client_name: "" # TO BE FILLED + scope: "openid name email" + response_mode: "form_post" + + token_endpoint_auth_method: "sign_in_with_apple" + sign_in_with_apple: + private_key: | + # Content of the PEM-encoded private key file, TO BE FILLED + team_id: "" # TO BE FILLED + key_id: "" # TO BE FILLED + + claims_imports: + localpart: + action: ignore + displayname: + action: suggest + # SiWA passes down the user infos as query parameters in the callback + # which is available in the extra_callback_parameters variable + template: | + {%- set user = extra_callback_parameters["user"] | from_json -%} + {{- user.name.firstName }} {{ user.name.lastName -}} + email: + action: suggest +``` + ### Authelia These instructions assume that you have already enabled the OIDC provider support in [Authelia](https://www.authelia.com/). diff --git a/templates/form_post.html b/templates/form_post.html index 87d4e1050..2fb62ead9 100644 --- a/templates/form_post.html +++ b/templates/form_post.html @@ -6,18 +6,27 @@ Please see LICENSE in the repository root for full details. -#} - - - - - Redirecting to client - - - -
- {% for key, value in params|items %} - - {% endfor %} -
- - +{% extends "base.html" %} + +{% block content %} +
+
+

{{ _("common.loading") }}

+
+
+ +
+ {% for key, value in params|items %} + + {% endfor %} + + +
+ + {# Submit the form in JavaScript on the next tick, so that if the browser + wants to display the placeholder instead of a blank page, it can #} + +{% endblock %} diff --git a/translations/en.json b/translations/en.json index 9b1c11a58..811aa9726 100644 --- a/translations/en.json +++ b/translations/en.json @@ -10,7 +10,7 @@ }, "continue": "Continue", "@continue": { - "context": "pages/account/emails/add.html:37:26-46, pages/account/emails/verify.html:52:26-46, pages/consent.html:55:28-48, pages/device_consent.html:121:13-33, pages/device_link.html:40:26-46, pages/login.html:58:30-50, pages/reauth.html:32:28-48, pages/recovery/start.html:38:26-46, pages/register.html:76:28-48, pages/sso.html:37:28-48" + "context": "form_post.html:25:28-48, pages/account/emails/add.html:37:26-46, pages/account/emails/verify.html:52:26-46, pages/consent.html:55:28-48, pages/device_consent.html:121:13-33, pages/device_link.html:40:26-46, pages/login.html:58:30-50, pages/reauth.html:32:28-48, pages/recovery/start.html:38:26-46, pages/register.html:76:28-48, pages/sso.html:37:28-48" }, "create_account": "Create Account", "@create_account": { @@ -77,6 +77,10 @@ "@email_address": { "context": "pages/account/emails/add.html:33:33-58, pages/recovery/start.html:34:33-58, pages/register.html:40:35-60, pages/upstream_oauth2/do_register.html:79:37-62" }, + "loading": "Loading…", + "@loading": { + "context": "form_post.html:14:27-46" + }, "mxid": "Matrix ID", "@mxid": { "context": "pages/upstream_oauth2/do_register.html:58:35-51"