diff --git a/crates/cli/src/sync.rs b/crates/cli/src/sync.rs index aefdead14..737f9cc01 100644 --- a/crates/cli/src/sync.rs +++ b/crates/cli/src/sync.rs @@ -284,10 +284,12 @@ pub async fn config_sync( encrypted_client_secret, claims_imports: map_claims_imports(&provider.claims_imports), token_endpoint_override: provider.token_endpoint, + userinfo_endpoint_override: provider.userinfo_endpoint, authorization_endpoint_override: provider.authorization_endpoint, jwks_uri_override: provider.jwks_uri, discovery_mode, pkce_mode, + fetch_userinfo: provider.fetch_userinfo, response_mode, additional_authorization_parameters: provider .additional_authorization_parameters diff --git a/crates/config/src/sections/upstream_oauth2.rs b/crates/config/src/sections/upstream_oauth2.rs index d91881d3b..ba18f6207 100644 --- a/crates/config/src/sections/upstream_oauth2.rs +++ b/crates/config/src/sections/upstream_oauth2.rs @@ -465,12 +465,26 @@ pub struct Provider { #[serde(default, skip_serializing_if = "PkceMethod::is_default")] pub pkce_method: PkceMethod, + /// Whether to fetch the user profile from the userinfo endpoint, + /// or to rely on the data returned in the `id_token` from the + /// `token_endpoint`. + /// + /// Defaults to `false`. + #[serde(default)] + pub fetch_userinfo: bool, + /// The URL to use for the provider's authorization endpoint /// /// Defaults to the `authorization_endpoint` provided through discovery #[serde(skip_serializing_if = "Option::is_none")] pub authorization_endpoint: Option, + /// The URL to use for the provider's userinfo endpoint + /// + /// Defaults to the `userinfo_endpoint` provided through discovery + #[serde(skip_serializing_if = "Option::is_none")] + pub userinfo_endpoint: Option, + /// The URL to use for the provider's token endpoint /// /// Defaults to the `token_endpoint` provided through discovery diff --git a/crates/data-model/src/upstream_oauth2/provider.rs b/crates/data-model/src/upstream_oauth2/provider.rs index 5656fd78f..bf939f96e 100644 --- a/crates/data-model/src/upstream_oauth2/provider.rs +++ b/crates/data-model/src/upstream_oauth2/provider.rs @@ -228,6 +228,8 @@ pub struct UpstreamOAuthProvider { pub authorization_endpoint_override: Option, pub scope: Scope, pub token_endpoint_override: Option, + pub userinfo_endpoint_override: Option, + pub fetch_userinfo: bool, pub client_id: String, pub encrypted_client_secret: Option, pub token_endpoint_signing_alg: Option, diff --git a/crates/data-model/src/upstream_oauth2/session.rs b/crates/data-model/src/upstream_oauth2/session.rs index 38b622987..23139c929 100644 --- a/crates/data-model/src/upstream_oauth2/session.rs +++ b/crates/data-model/src/upstream_oauth2/session.rs @@ -20,6 +20,7 @@ pub enum UpstreamOAuthAuthorizationSessionState { link_id: Ulid, id_token: Option, extra_callback_parameters: Option, + userinfo: Option, }, Consumed { completed_at: DateTime, @@ -27,6 +28,7 @@ pub enum UpstreamOAuthAuthorizationSessionState { link_id: Ulid, id_token: Option, extra_callback_parameters: Option, + userinfo: Option, }, } @@ -45,6 +47,7 @@ impl UpstreamOAuthAuthorizationSessionState { link: &UpstreamOAuthLink, id_token: Option, extra_callback_parameters: Option, + userinfo: Option, ) -> Result { match self { Self::Pending => Ok(Self::Completed { @@ -52,6 +55,7 @@ impl UpstreamOAuthAuthorizationSessionState { link_id: link.id, id_token, extra_callback_parameters, + userinfo, }), Self::Completed { .. } | Self::Consumed { .. } => Err(InvalidTransitionError), } @@ -72,12 +76,14 @@ impl UpstreamOAuthAuthorizationSessionState { link_id, id_token, extra_callback_parameters, + userinfo, } => Ok(Self::Consumed { completed_at, link_id, consumed_at, id_token, extra_callback_parameters, + userinfo, }), Self::Pending | Self::Consumed { .. } => Err(InvalidTransitionError), } @@ -151,6 +157,14 @@ impl UpstreamOAuthAuthorizationSessionState { } } + #[must_use] + pub fn userinfo(&self) -> Option<&serde_json::Value> { + match self { + Self::Pending => None, + Self::Completed { userinfo, .. } | Self::Consumed { userinfo, .. } => userinfo.as_ref(), + } + } + /// Get the time at which the upstream OAuth 2.0 authorization session was /// consumed. /// @@ -229,10 +243,15 @@ impl UpstreamOAuthAuthorizationSession { link: &UpstreamOAuthLink, id_token: Option, extra_callback_parameters: Option, + userinfo: Option, ) -> Result { - self.state = - self.state - .complete(completed_at, link, id_token, extra_callback_parameters)?; + self.state = self.state.complete( + completed_at, + link, + id_token, + extra_callback_parameters, + userinfo, + )?; Ok(self) } diff --git a/crates/handlers/src/upstream_oauth2/cache.rs b/crates/handlers/src/upstream_oauth2/cache.rs index 248271229..1e5f683e0 100644 --- a/crates/handlers/src/upstream_oauth2/cache.rs +++ b/crates/handlers/src/upstream_oauth2/cache.rs @@ -108,6 +108,18 @@ impl<'a> LazyProviderInfos<'a> { Ok(self.load().await?.token_endpoint()) } + /// Get the userinfo endpoint for the provider. + /// + /// Uses [`UpstreamOAuthProvider.userinfo_endpoint_override`] if set, + /// otherwise uses the one from discovery. + pub async fn userinfo_endpoint(&mut self) -> Result<&Url, DiscoveryError> { + if let Some(userinfo_endpoint) = &self.provider.userinfo_endpoint_override { + return Ok(userinfo_endpoint); + } + + Ok(self.load().await?.userinfo_endpoint()) + } + /// Get the PKCE methods supported by the provider. /// /// If the mode is set to auto, it will use the ones from discovery, @@ -387,9 +399,11 @@ mod tests { brand_name: None, discovery_mode: UpstreamOAuthProviderDiscoveryMode::Insecure, pkce_mode: UpstreamOAuthProviderPkceMode::Auto, + fetch_userinfo: false, jwks_uri_override: None, authorization_endpoint_override: None, scope: Scope::from_iter([OPENID]), + userinfo_endpoint_override: None, token_endpoint_override: None, client_id: "client_id".to_owned(), encrypted_client_secret: None, diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index 3c6472513..846eb8c6b 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -29,6 +29,7 @@ use mas_storage::{ use mas_templates::{FormPostContext, Templates}; use oauth2_types::errors::ClientErrorCode; use serde::{Deserialize, Serialize}; +use serde_json::json; use thiserror::Error; use ulid::Ulid; @@ -117,7 +118,7 @@ pub(crate) enum RouteError { }, #[error(transparent)] - Internal(Box), + Internal(Box), } impl_from_error_for_route!(mas_templates::TemplateError); @@ -125,6 +126,7 @@ impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError); impl_from_error_for_route!(mas_oidc_client::error::JwksError); impl_from_error_for_route!(mas_oidc_client::error::TokenAuthorizationCodeError); +impl_from_error_for_route!(mas_oidc_client::error::UserInfoError); impl_from_error_for_route!(super::ProviderCredentialsError); impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound); @@ -274,7 +276,7 @@ pub(crate) async fn handler( redirect_uri, }; - let id_token_verification_data = JwtVerificationData { + let verification_data = JwtVerificationData { issuer: &provider.issuer, jwks: &jwks, // TODO: make that configurable @@ -282,25 +284,48 @@ pub(crate) async fn handler( client_id: &provider.client_id, }; - let (response, id_token) = + let (response, id_token_map) = mas_oidc_client::requests::authorization_code::access_token_with_authorization_code( &client, client_credentials, lazy_metadata.token_endpoint().await?, code, validation_data, - Some(id_token_verification_data), + Some(verification_data), clock.now(), &mut rng, ) .await?; - let (_header, id_token) = id_token.ok_or(RouteError::MissingIDToken)?.into_parts(); + let (_header, id_token) = id_token_map + .clone() + .ok_or(RouteError::MissingIDToken)? + .into_parts(); let mut context = AttributeMappingContext::new().with_id_token_claims(id_token); if let Some(extra_callback_parameters) = extra_callback_parameters.clone() { context = context.with_extra_callback_parameters(extra_callback_parameters); } + + let userinfo = if provider.fetch_userinfo { + Some(json!( + mas_oidc_client::requests::userinfo::fetch_userinfo( + &client, + lazy_metadata.userinfo_endpoint().await?, + response.access_token.as_str(), + Some(verification_data), + &id_token_map.ok_or(RouteError::MissingIDToken)?, + ) + .await? + )) + } else { + None + }; + + if let Some(userinfo) = userinfo.clone() { + context = context.with_userinfo_claims(userinfo); + } + let context = context.build(); let env = environment(); @@ -341,6 +366,7 @@ pub(crate) async fn handler( &link, response.id_token, extra_callback_parameters, + userinfo, ) .await?; diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index dfc42043b..a4136192f 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -344,6 +344,9 @@ pub(crate) async fn get( if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() { context = context.with_extra_callback_parameters(extra_callback_parameters.clone()); } + if let Some(userinfo) = upstream_session.userinfo() { + context = context.with_userinfo_claims(userinfo.clone()); + } let context = context.build(); let ctx = if provider.claims_imports.displayname.ignore() { @@ -582,6 +585,9 @@ pub(crate) async fn post( if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() { context = context.with_extra_callback_parameters(extra_callback_parameters.clone()); } + if let Some(userinfo) = upstream_session.userinfo() { + context = context.with_userinfo_claims(userinfo.clone()); + } let context = context.build(); // Is the email verified according to the upstream provider? @@ -921,6 +927,8 @@ mod tests { claims_imports, authorization_endpoint_override: None, token_endpoint_override: None, + userinfo_endpoint_override: None, + fetch_userinfo: false, jwks_uri_override: None, discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc, pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto, @@ -958,6 +966,7 @@ mod tests { &link, Some(id_token.into_string()), None, + None, ) .await .unwrap(); diff --git a/crates/handlers/src/upstream_oauth2/template.rs b/crates/handlers/src/upstream_oauth2/template.rs index 86942d76a..1634948a0 100644 --- a/crates/handlers/src/upstream_oauth2/template.rs +++ b/crates/handlers/src/upstream_oauth2/template.rs @@ -23,6 +23,7 @@ use minijinja::{ pub(crate) struct AttributeMappingContext { id_token_claims: Option>, extra_callback_parameters: Option, + userinfo_claims: Option, } impl AttributeMappingContext { @@ -46,6 +47,11 @@ impl AttributeMappingContext { self } + pub fn with_userinfo_claims(mut self, userinfo_claims: serde_json::Value) -> Self { + self.userinfo_claims = Some(userinfo_claims); + self + } + pub fn build(self) -> Value { Value::from_object(self) } @@ -54,7 +60,25 @@ impl AttributeMappingContext { impl Object for AttributeMappingContext { fn get_value(self: &Arc, name: &Value) -> Option { match name.as_str()? { - "user" | "id_token_claims" => self.id_token_claims.as_ref().map(Value::from_serialize), + "user" => { + if self.id_token_claims.is_none() && self.userinfo_claims.is_none() { + return None; + } + let mut merged_user: HashMap = HashMap::new(); + if let serde_json::Value::Object(userinfo) = self + .userinfo_claims + .clone() + .unwrap_or(serde_json::Value::Null) + { + merged_user.extend(userinfo); + } + if let Some(id_token) = self.id_token_claims.clone() { + merged_user.extend(id_token); + } + Some(Value::from_serialize(merged_user)) + } + "id_token_claims" => self.id_token_claims.as_ref().map(Value::from_serialize), + "userinfo_claims" => self.userinfo_claims.as_ref().map(Value::from_serialize), "extra_callback_parameters" => self .extra_callback_parameters .as_ref() @@ -64,17 +88,20 @@ impl Object for AttributeMappingContext { } fn enumerate(self: &Arc) -> Enumerator { - match ( - self.id_token_claims.is_some(), - self.extra_callback_parameters.is_some(), - ) { - (true, true) => { - Enumerator::Str(&["user", "id_token_claims", "extra_callback_parameters"]) - } - (true, false) => Enumerator::Str(&["user", "id_token_claims"]), - (false, true) => Enumerator::Str(&["extra_callback_parameters"]), - (false, false) => Enumerator::Str(&["user"]), + let mut attrs = Vec::new(); + if self.id_token_claims.is_some() || self.userinfo_claims.is_none() { + attrs.push(minijinja::Value::from("user")); + } + if self.id_token_claims.is_some() { + attrs.push(minijinja::Value::from("id_token_claims")); + } + if self.userinfo_claims.is_some() { + attrs.push(minijinja::Value::from("userinfo_claims")); + } + if self.extra_callback_parameters.is_some() { + attrs.push(minijinja::Value::from("extra_callback_parameters")); } + Enumerator::Values(attrs) } } diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index edc27065a..b40b5041c 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -403,11 +403,13 @@ mod test { scope: [OPENID].into_iter().collect(), token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None, token_endpoint_signing_alg: None, + fetch_userinfo: false, client_id: "client".to_owned(), encrypted_client_secret: None, claims_imports: UpstreamOAuthProviderClaimsImports::default(), authorization_endpoint_override: None, token_endpoint_override: None, + userinfo_endpoint_override: None, jwks_uri_override: None, discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc, pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto, @@ -439,11 +441,13 @@ mod test { scope: [OPENID].into_iter().collect(), token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None, token_endpoint_signing_alg: None, + fetch_userinfo: false, client_id: "client".to_owned(), encrypted_client_secret: None, claims_imports: UpstreamOAuthProviderClaimsImports::default(), authorization_endpoint_override: None, token_endpoint_override: None, + userinfo_endpoint_override: None, jwks_uri_override: None, discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc, pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto, diff --git a/crates/oauth2-types/src/oidc.rs b/crates/oauth2-types/src/oidc.rs index c2fa68ecc..4bcf54539 100644 --- a/crates/oauth2-types/src/oidc.rs +++ b/crates/oauth2-types/src/oidc.rs @@ -950,6 +950,15 @@ impl VerifiedProviderMetadata { } } + /// URL of the authorization server's userinfo endpoint. + #[must_use] + pub fn userinfo_endpoint(&self) -> &Url { + match &self.userinfo_endpoint { + Some(u) => u, + None => unreachable!(), + } + } + /// URL of the authorization server's token endpoint. #[must_use] pub fn token_endpoint(&self) -> &Url { diff --git a/crates/storage-pg/.sqlx/query-6b133c3c6bfc3c80a21f6f72d0a6468f748ed59e88d8c904bb0a4bbfee43a67f.json b/crates/storage-pg/.sqlx/query-39657c8064532745c8a8a944b73f650b468a4677eddf671c69c329d361edf00e.json similarity index 73% rename from crates/storage-pg/.sqlx/query-6b133c3c6bfc3c80a21f6f72d0a6468f748ed59e88d8c904bb0a4bbfee43a67f.json rename to crates/storage-pg/.sqlx/query-39657c8064532745c8a8a944b73f650b468a4677eddf671c69c329d361edf00e.json index 1c56ec95b..1f6b0e0f5 100644 --- a/crates/storage-pg/.sqlx/query-6b133c3c6bfc3c80a21f6f72d0a6468f748ed59e88d8c904bb0a4bbfee43a67f.json +++ b/crates/storage-pg/.sqlx/query-39657c8064532745c8a8a944b73f650b468a4677eddf671c69c329d361edf00e.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n disabled_at,\n claims_imports as \"claims_imports: Json\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n discovery_mode,\n pkce_mode,\n response_mode,\n additional_parameters as \"additional_parameters: Json>\"\n FROM upstream_oauth_providers\n WHERE disabled_at IS NULL\n ", + "query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n fetch_userinfo,\n created_at,\n disabled_at,\n claims_imports as \"claims_imports: Json\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n userinfo_endpoint_override,\n discovery_mode,\n pkce_mode,\n response_mode,\n additional_parameters as \"additional_parameters: Json>\"\n FROM upstream_oauth_providers\n WHERE disabled_at IS NULL\n ", "describe": { "columns": [ { @@ -50,51 +50,61 @@ }, { "ordinal": 9, + "name": "fetch_userinfo", + "type_info": "Bool" + }, + { + "ordinal": 10, "name": "created_at", "type_info": "Timestamptz" }, { - "ordinal": 10, + "ordinal": 11, "name": "disabled_at", "type_info": "Timestamptz" }, { - "ordinal": 11, + "ordinal": 12, "name": "claims_imports: Json", "type_info": "Jsonb" }, { - "ordinal": 12, + "ordinal": 13, "name": "jwks_uri_override", "type_info": "Text" }, { - "ordinal": 13, + "ordinal": 14, "name": "authorization_endpoint_override", "type_info": "Text" }, { - "ordinal": 14, + "ordinal": 15, "name": "token_endpoint_override", "type_info": "Text" }, { - "ordinal": 15, + "ordinal": 16, + "name": "userinfo_endpoint_override", + "type_info": "Text" + }, + { + "ordinal": 17, "name": "discovery_mode", "type_info": "Text" }, { - "ordinal": 16, + "ordinal": 18, "name": "pkce_mode", "type_info": "Text" }, { - "ordinal": 17, + "ordinal": 19, "name": "response_mode", "type_info": "Text" }, { - "ordinal": 18, + "ordinal": 20, "name": "additional_parameters: Json>", "type_info": "Jsonb" } @@ -113,16 +123,18 @@ true, false, false, + false, true, false, true, true, true, + true, false, false, false, true ] }, - "hash": "6b133c3c6bfc3c80a21f6f72d0a6468f748ed59e88d8c904bb0a4bbfee43a67f" + "hash": "39657c8064532745c8a8a944b73f650b468a4677eddf671c69c329d361edf00e" } diff --git a/crates/storage-pg/.sqlx/query-5516235e0983fb64d18e82dbe3e34f966ed71a0ed59be0d48ec66fedf64e707d.json b/crates/storage-pg/.sqlx/query-5f5245ace61b896f92be78ab4fef701b37c9e3c2f4a332f418b9fb2625a0fe3f.json similarity index 60% rename from crates/storage-pg/.sqlx/query-5516235e0983fb64d18e82dbe3e34f966ed71a0ed59be0d48ec66fedf64e707d.json rename to crates/storage-pg/.sqlx/query-5f5245ace61b896f92be78ab4fef701b37c9e3c2f4a332f418b9fb2625a0fe3f.json index 96ad3513a..c33da04d8 100644 --- a/crates/storage-pg/.sqlx/query-5516235e0983fb64d18e82dbe3e34f966ed71a0ed59be0d48ec66fedf64e707d.json +++ b/crates/storage-pg/.sqlx/query-5f5245ace61b896f92be78ab4fef701b37c9e3c2f4a332f418b9fb2625a0fe3f.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2,\n id_token = $3,\n extra_callback_parameters = $4\n WHERE upstream_oauth_authorization_session_id = $5\n ", + "query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2,\n id_token = $3,\n extra_callback_parameters = $4,\n userinfo = $5\n WHERE upstream_oauth_authorization_session_id = $6\n ", "describe": { "columns": [], "parameters": { @@ -9,10 +9,11 @@ "Timestamptz", "Text", "Jsonb", + "Jsonb", "Uuid" ] }, "nullable": [] }, - "hash": "5516235e0983fb64d18e82dbe3e34f966ed71a0ed59be0d48ec66fedf64e707d" + "hash": "5f5245ace61b896f92be78ab4fef701b37c9e3c2f4a332f418b9fb2625a0fe3f" } diff --git a/crates/storage-pg/.sqlx/query-73f4e5a724a432f1328c6112185cdc9c7a1ae1de45a2a8c02e7a2b8020b41160.json b/crates/storage-pg/.sqlx/query-887bd597132831c5caab2356f2d935c00a32274161ec5265da91d1c75ad0bb2b.json similarity index 73% rename from crates/storage-pg/.sqlx/query-73f4e5a724a432f1328c6112185cdc9c7a1ae1de45a2a8c02e7a2b8020b41160.json rename to crates/storage-pg/.sqlx/query-887bd597132831c5caab2356f2d935c00a32274161ec5265da91d1c75ad0bb2b.json index 866790c89..e93f49091 100644 --- a/crates/storage-pg/.sqlx/query-73f4e5a724a432f1328c6112185cdc9c7a1ae1de45a2a8c02e7a2b8020b41160.json +++ b/crates/storage-pg/.sqlx/query-887bd597132831c5caab2356f2d935c00a32274161ec5265da91d1c75ad0bb2b.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n disabled_at,\n claims_imports as \"claims_imports: Json\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n discovery_mode,\n pkce_mode,\n response_mode,\n additional_parameters as \"additional_parameters: Json>\"\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n ", + "query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n fetch_userinfo,\n created_at,\n disabled_at,\n claims_imports as \"claims_imports: Json\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n userinfo_endpoint_override,\n discovery_mode,\n pkce_mode,\n response_mode,\n additional_parameters as \"additional_parameters: Json>\"\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n ", "describe": { "columns": [ { @@ -50,51 +50,61 @@ }, { "ordinal": 9, + "name": "fetch_userinfo", + "type_info": "Bool" + }, + { + "ordinal": 10, "name": "created_at", "type_info": "Timestamptz" }, { - "ordinal": 10, + "ordinal": 11, "name": "disabled_at", "type_info": "Timestamptz" }, { - "ordinal": 11, + "ordinal": 12, "name": "claims_imports: Json", "type_info": "Jsonb" }, { - "ordinal": 12, + "ordinal": 13, "name": "jwks_uri_override", "type_info": "Text" }, { - "ordinal": 13, + "ordinal": 14, "name": "authorization_endpoint_override", "type_info": "Text" }, { - "ordinal": 14, + "ordinal": 15, "name": "token_endpoint_override", "type_info": "Text" }, { - "ordinal": 15, + "ordinal": 16, + "name": "userinfo_endpoint_override", + "type_info": "Text" + }, + { + "ordinal": 17, "name": "discovery_mode", "type_info": "Text" }, { - "ordinal": 16, + "ordinal": 18, "name": "pkce_mode", "type_info": "Text" }, { - "ordinal": 17, + "ordinal": 19, "name": "response_mode", "type_info": "Text" }, { - "ordinal": 18, + "ordinal": 20, "name": "additional_parameters: Json>", "type_info": "Jsonb" } @@ -115,16 +125,18 @@ true, false, false, + false, true, false, true, true, true, + true, false, false, false, true ] }, - "hash": "73f4e5a724a432f1328c6112185cdc9c7a1ae1de45a2a8c02e7a2b8020b41160" + "hash": "887bd597132831c5caab2356f2d935c00a32274161ec5265da91d1c75ad0bb2b" } diff --git a/crates/storage-pg/.sqlx/query-8e1c0760c0b652cf62e47779f9d0aef89463cc60eeae2088d0fedf0aeb75718b.json b/crates/storage-pg/.sqlx/query-8e1c0760c0b652cf62e47779f9d0aef89463cc60eeae2088d0fedf0aeb75718b.json new file mode 100644 index 000000000..81e3ec219 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-8e1c0760c0b652cf62e47779f9d0aef89463cc60eeae2088d0fedf0aeb75718b.json @@ -0,0 +1,32 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n token_endpoint_auth_method,\n fetch_userinfo,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n claims_imports,\n authorization_endpoint_override,\n token_endpoint_override,\n userinfo_endpoint_override,\n jwks_uri_override,\n discovery_mode,\n pkce_mode,\n response_mode,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10,\n $11, $12, $13, $14, $15, $16, $17, $18, $19)\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Text", + "Text", + "Text", + "Text", + "Text", + "Bool", + "Text", + "Text", + "Text", + "Jsonb", + "Text", + "Text", + "Text", + "Text", + "Text", + "Text", + "Text", + "Timestamptz" + ] + }, + "nullable": [] + }, + "hash": "8e1c0760c0b652cf62e47779f9d0aef89463cc60eeae2088d0fedf0aeb75718b" +} diff --git a/crates/storage-pg/.sqlx/query-bf7747552fe6f5489dec3c91fe1cb13a737644b94871c28334a29c88977dd84c.json b/crates/storage-pg/.sqlx/query-bf7747552fe6f5489dec3c91fe1cb13a737644b94871c28334a29c88977dd84c.json new file mode 100644 index 000000000..190c6221f --- /dev/null +++ b/crates/storage-pg/.sqlx/query-bf7747552fe6f5489dec3c91fe1cb13a737644b94871c28334a29c88977dd84c.json @@ -0,0 +1,41 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n token_endpoint_auth_method,\n fetch_userinfo,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n claims_imports,\n authorization_endpoint_override,\n token_endpoint_override,\n userinfo_endpoint_override,\n jwks_uri_override,\n discovery_mode,\n pkce_mode,\n response_mode,\n additional_parameters,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11,\n $12, $13, $14, $15, $16, $17, $18, $19, $20)\n ON CONFLICT (upstream_oauth_provider_id)\n DO UPDATE\n SET\n issuer = EXCLUDED.issuer,\n human_name = EXCLUDED.human_name,\n brand_name = EXCLUDED.brand_name,\n scope = EXCLUDED.scope,\n token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,\n fetch_userinfo = EXCLUDED.fetch_userinfo,\n token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,\n disabled_at = NULL,\n client_id = EXCLUDED.client_id,\n encrypted_client_secret = EXCLUDED.encrypted_client_secret,\n claims_imports = EXCLUDED.claims_imports,\n authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,\n token_endpoint_override = EXCLUDED.token_endpoint_override,\n userinfo_endpoint_override = EXCLUDED.userinfo_endpoint_override,\n jwks_uri_override = EXCLUDED.jwks_uri_override,\n discovery_mode = EXCLUDED.discovery_mode,\n pkce_mode = EXCLUDED.pkce_mode,\n response_mode = EXCLUDED.response_mode,\n additional_parameters = EXCLUDED.additional_parameters\n RETURNING created_at\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "created_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Uuid", + "Text", + "Text", + "Text", + "Text", + "Text", + "Bool", + "Text", + "Text", + "Text", + "Jsonb", + "Text", + "Text", + "Text", + "Text", + "Text", + "Text", + "Text", + "Jsonb", + "Timestamptz" + ] + }, + "nullable": [ + false + ] + }, + "hash": "bf7747552fe6f5489dec3c91fe1cb13a737644b94871c28334a29c88977dd84c" +} diff --git a/crates/storage-pg/.sqlx/query-e36ed76d0176edf8c4a029f017b8f368a529b2d32a54c52f6a28b9e615716f4d.json b/crates/storage-pg/.sqlx/query-e36ed76d0176edf8c4a029f017b8f368a529b2d32a54c52f6a28b9e615716f4d.json deleted file mode 100644 index 12bfb5e3b..000000000 --- a/crates/storage-pg/.sqlx/query-e36ed76d0176edf8c4a029f017b8f368a529b2d32a54c52f6a28b9e615716f4d.json +++ /dev/null @@ -1,39 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n claims_imports,\n authorization_endpoint_override,\n token_endpoint_override,\n jwks_uri_override,\n discovery_mode,\n pkce_mode,\n response_mode,\n additional_parameters,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9,\n $10, $11, $12, $13, $14, $15, $16, $17, $18)\n ON CONFLICT (upstream_oauth_provider_id)\n DO UPDATE\n SET\n issuer = EXCLUDED.issuer,\n human_name = EXCLUDED.human_name,\n brand_name = EXCLUDED.brand_name,\n scope = EXCLUDED.scope,\n token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,\n token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,\n disabled_at = NULL,\n client_id = EXCLUDED.client_id,\n encrypted_client_secret = EXCLUDED.encrypted_client_secret,\n claims_imports = EXCLUDED.claims_imports,\n authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,\n token_endpoint_override = EXCLUDED.token_endpoint_override,\n jwks_uri_override = EXCLUDED.jwks_uri_override,\n discovery_mode = EXCLUDED.discovery_mode,\n pkce_mode = EXCLUDED.pkce_mode,\n response_mode = EXCLUDED.response_mode,\n additional_parameters = EXCLUDED.additional_parameters\n RETURNING created_at\n ", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "created_at", - "type_info": "Timestamptz" - } - ], - "parameters": { - "Left": [ - "Uuid", - "Text", - "Text", - "Text", - "Text", - "Text", - "Text", - "Text", - "Text", - "Jsonb", - "Text", - "Text", - "Text", - "Text", - "Text", - "Text", - "Jsonb", - "Timestamptz" - ] - }, - "nullable": [ - false - ] - }, - "hash": "e36ed76d0176edf8c4a029f017b8f368a529b2d32a54c52f6a28b9e615716f4d" -} diff --git a/crates/storage-pg/.sqlx/query-7d329e0c57f36b9ffe2aa7ddf4a21e293522c00009cca0222524b0c73f6eee30.json b/crates/storage-pg/.sqlx/query-ea30b3809fd7c1d4e9983909c0219f343953a89f2a43f6b8c4ab4fbea7645ccc.json similarity index 80% rename from crates/storage-pg/.sqlx/query-7d329e0c57f36b9ffe2aa7ddf4a21e293522c00009cca0222524b0c73f6eee30.json rename to crates/storage-pg/.sqlx/query-ea30b3809fd7c1d4e9983909c0219f343953a89f2a43f6b8c4ab4fbea7645ccc.json index 57fc34e9b..3a49d2a87 100644 --- a/crates/storage-pg/.sqlx/query-7d329e0c57f36b9ffe2aa7ddf4a21e293522c00009cca0222524b0c73f6eee30.json +++ b/crates/storage-pg/.sqlx/query-ea30b3809fd7c1d4e9983909c0219f343953a89f2a43f6b8c4ab4fbea7645ccc.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n SELECT\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n upstream_oauth_link_id,\n state,\n code_challenge_verifier,\n nonce,\n id_token,\n extra_callback_parameters,\n created_at,\n completed_at,\n consumed_at\n FROM upstream_oauth_authorization_sessions\n WHERE upstream_oauth_authorization_session_id = $1\n ", + "query": "\n SELECT\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n upstream_oauth_link_id,\n state,\n code_challenge_verifier,\n nonce,\n id_token,\n extra_callback_parameters,\n userinfo,\n created_at,\n completed_at,\n consumed_at\n FROM upstream_oauth_authorization_sessions\n WHERE upstream_oauth_authorization_session_id = $1\n ", "describe": { "columns": [ { @@ -45,16 +45,21 @@ }, { "ordinal": 8, + "name": "userinfo", + "type_info": "Jsonb" + }, + { + "ordinal": 9, "name": "created_at", "type_info": "Timestamptz" }, { - "ordinal": 9, + "ordinal": 10, "name": "completed_at", "type_info": "Timestamptz" }, { - "ordinal": 10, + "ordinal": 11, "name": "consumed_at", "type_info": "Timestamptz" } @@ -73,10 +78,11 @@ false, true, true, + true, false, true, true ] }, - "hash": "7d329e0c57f36b9ffe2aa7ddf4a21e293522c00009cca0222524b0c73f6eee30" + "hash": "ea30b3809fd7c1d4e9983909c0219f343953a89f2a43f6b8c4ab4fbea7645ccc" } diff --git a/crates/storage-pg/.sqlx/query-ebb1a78003293376a52de830f89f6f526ad1c5c823328463a6525d3c3d0d95c6.json b/crates/storage-pg/.sqlx/query-ebb1a78003293376a52de830f89f6f526ad1c5c823328463a6525d3c3d0d95c6.json deleted file mode 100644 index aa2a95c1f..000000000 --- a/crates/storage-pg/.sqlx/query-ebb1a78003293376a52de830f89f6f526ad1c5c823328463a6525d3c3d0d95c6.json +++ /dev/null @@ -1,30 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n claims_imports,\n authorization_endpoint_override,\n token_endpoint_override,\n jwks_uri_override,\n discovery_mode,\n pkce_mode,\n response_mode,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9,\n $10, $11, $12, $13, $14, $15, $16, $17)\n ", - "describe": { - "columns": [], - "parameters": { - "Left": [ - "Uuid", - "Text", - "Text", - "Text", - "Text", - "Text", - "Text", - "Text", - "Text", - "Jsonb", - "Text", - "Text", - "Text", - "Text", - "Text", - "Text", - "Timestamptz" - ] - }, - "nullable": [] - }, - "hash": "ebb1a78003293376a52de830f89f6f526ad1c5c823328463a6525d3c3d0d95c6" -} diff --git a/crates/storage-pg/.sqlx/query-64e6ea47c2e877c1ebe4338d64d9ad8a6c1c777d1daea024b8ca2e7f0dd75b0f.json b/crates/storage-pg/.sqlx/query-f5c2ec9b7038d7ed36091e670f9bf34f8aa9ea8ed50929731845e32dc3176e39.json similarity index 72% rename from crates/storage-pg/.sqlx/query-64e6ea47c2e877c1ebe4338d64d9ad8a6c1c777d1daea024b8ca2e7f0dd75b0f.json rename to crates/storage-pg/.sqlx/query-f5c2ec9b7038d7ed36091e670f9bf34f8aa9ea8ed50929731845e32dc3176e39.json index 55e3d87f7..71096b9ae 100644 --- a/crates/storage-pg/.sqlx/query-64e6ea47c2e877c1ebe4338d64d9ad8a6c1c777d1daea024b8ca2e7f0dd75b0f.json +++ b/crates/storage-pg/.sqlx/query-f5c2ec9b7038d7ed36091e670f9bf34f8aa9ea8ed50929731845e32dc3176e39.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n INSERT INTO upstream_oauth_authorization_sessions (\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n state,\n code_challenge_verifier,\n nonce,\n created_at,\n completed_at,\n consumed_at,\n id_token\n ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL)\n ", + "query": "\n INSERT INTO upstream_oauth_authorization_sessions (\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n state,\n code_challenge_verifier,\n nonce,\n created_at,\n completed_at,\n consumed_at,\n id_token,\n userinfo\n ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL, NULL)\n ", "describe": { "columns": [], "parameters": { @@ -15,5 +15,5 @@ }, "nullable": [] }, - "hash": "64e6ea47c2e877c1ebe4338d64d9ad8a6c1c777d1daea024b8ca2e7f0dd75b0f" + "hash": "f5c2ec9b7038d7ed36091e670f9bf34f8aa9ea8ed50929731845e32dc3176e39" } diff --git a/crates/storage-pg/migrations/20241124145741_upstream_oauth_userinfo.sql b/crates/storage-pg/migrations/20241124145741_upstream_oauth_userinfo.sql new file mode 100644 index 000000000..7c0168d56 --- /dev/null +++ b/crates/storage-pg/migrations/20241124145741_upstream_oauth_userinfo.sql @@ -0,0 +1,13 @@ +-- Copyright 2024 New Vector Ltd. +-- +-- SPDX-License-Identifier: AGPL-3.0-only +-- Please see LICENSE in the repository root for full details. + +-- Add columms to upstream_oauth_providers and upstream_oauth_authorization_sessions +-- table to handle userinfo endpoint. +ALTER TABLE "upstream_oauth_providers" + ADD COLUMN "fetch_userinfo" BOOLEAN NOT NULL DEFAULT FALSE, + ADD COLUMN "userinfo_endpoint_override" TEXT; + +ALTER TABLE "upstream_oauth_authorization_sessions" + ADD COLUMN "userinfo" JSONB; diff --git a/crates/storage-pg/src/iden.rs b/crates/storage-pg/src/iden.rs index b9a00333d..1d35646c4 100644 --- a/crates/storage-pg/src/iden.rs +++ b/crates/storage-pg/src/iden.rs @@ -98,6 +98,7 @@ pub enum UpstreamOAuthProviders { EncryptedClientSecret, TokenEndpointSigningAlg, TokenEndpointAuthMethod, + FetchUserinfo, CreatedAt, DisabledAt, ClaimsImports, @@ -108,6 +109,7 @@ pub enum UpstreamOAuthProviders { JwksUriOverride, TokenEndpointOverride, AuthorizationEndpointOverride, + UserinfoEndpointOverride, } #[derive(sea_query::Iden)] diff --git a/crates/storage-pg/src/upstream_oauth2/mod.rs b/crates/storage-pg/src/upstream_oauth2/mod.rs index a544c9cd3..48c67fe91 100644 --- a/crates/storage-pg/src/upstream_oauth2/mod.rs +++ b/crates/storage-pg/src/upstream_oauth2/mod.rs @@ -60,12 +60,14 @@ mod tests { brand_name: None, scope: Scope::from_iter([OPENID]), token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None, + fetch_userinfo: false, token_endpoint_signing_alg: None, client_id: "client-id".to_owned(), encrypted_client_secret: None, claims_imports: UpstreamOAuthProviderClaimsImports::default(), token_endpoint_override: None, authorization_endpoint_override: None, + userinfo_endpoint_override: None, jwks_uri_override: None, discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc, pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto, @@ -145,7 +147,7 @@ mod tests { let session = repo .upstream_oauth_session() - .complete_with_link(&clock, session, &link, None, None) + .complete_with_link(&clock, session, &link, None, None, None) .await .unwrap(); // Reload the session @@ -302,12 +304,14 @@ mod tests { brand_name: None, scope: scope.clone(), token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None, + fetch_userinfo: false, token_endpoint_signing_alg: None, client_id, encrypted_client_secret: None, claims_imports: UpstreamOAuthProviderClaimsImports::default(), token_endpoint_override: None, authorization_endpoint_override: None, + userinfo_endpoint_override: None, jwks_uri_override: None, discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc, pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto, diff --git a/crates/storage-pg/src/upstream_oauth2/provider.rs b/crates/storage-pg/src/upstream_oauth2/provider.rs index 4b384c95a..e66088157 100644 --- a/crates/storage-pg/src/upstream_oauth2/provider.rs +++ b/crates/storage-pg/src/upstream_oauth2/provider.rs @@ -56,12 +56,14 @@ struct ProviderLookup { encrypted_client_secret: Option, token_endpoint_signing_alg: Option, token_endpoint_auth_method: String, + fetch_userinfo: bool, created_at: DateTime, disabled_at: Option>, claims_imports: Json, jwks_uri_override: Option, authorization_endpoint_override: Option, token_endpoint_override: Option, + userinfo_endpoint_override: Option, discovery_mode: String, pkce_mode: String, response_mode: String, @@ -70,6 +72,8 @@ struct ProviderLookup { impl TryFrom for UpstreamOAuthProvider { type Error = DatabaseInconsistencyError; + + #[allow(clippy::too_many_lines)] fn try_from(value: ProviderLookup) -> Result { let id = value.upstream_oauth_provider_id.into(); let scope = value.scope.parse().map_err(|e| { @@ -117,6 +121,17 @@ impl TryFrom for UpstreamOAuthProvider { .source(e) })?; + let userinfo_endpoint_override = value + .userinfo_endpoint_override + .map(|x| x.parse()) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("upstream_oauth_providers") + .column("userinfo_endpoint_override") + .row(id) + .source(e) + })?; + let jwks_uri_override = value .jwks_uri_override .map(|x| x.parse()) @@ -163,12 +178,14 @@ impl TryFrom for UpstreamOAuthProvider { client_id: value.client_id, encrypted_client_secret: value.encrypted_client_secret, token_endpoint_auth_method, + fetch_userinfo: value.fetch_userinfo, token_endpoint_signing_alg, created_at: value.created_at, disabled_at: value.disabled_at, claims_imports: value.claims_imports.0, authorization_endpoint_override, token_endpoint_override, + userinfo_endpoint_override, jwks_uri_override, discovery_mode, pkce_mode, @@ -218,12 +235,14 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' encrypted_client_secret, token_endpoint_signing_alg, token_endpoint_auth_method, + fetch_userinfo, created_at, disabled_at, claims_imports as "claims_imports: Json", jwks_uri_override, authorization_endpoint_override, token_endpoint_override, + userinfo_endpoint_override, discovery_mode, pkce_mode, response_mode, @@ -275,19 +294,21 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' brand_name, scope, token_endpoint_auth_method, + fetch_userinfo, token_endpoint_signing_alg, client_id, encrypted_client_secret, claims_imports, authorization_endpoint_override, token_endpoint_override, + userinfo_endpoint_override, jwks_uri_override, discovery_mode, pkce_mode, response_mode, created_at - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, - $10, $11, $12, $13, $14, $15, $16, $17) + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, + $11, $12, $13, $14, $15, $16, $17, $18, $19) "#, Uuid::from(id), ¶ms.issuer, @@ -295,6 +316,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' params.brand_name.as_deref(), params.scope.to_string(), params.token_endpoint_auth_method.to_string(), + params.fetch_userinfo, params .token_endpoint_signing_alg .as_ref() @@ -310,6 +332,10 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' .token_endpoint_override .as_ref() .map(ToString::to_string), + params + .userinfo_endpoint_override + .as_ref() + .map(ToString::to_string), params.jwks_uri_override.as_ref().map(ToString::to_string), params.discovery_mode.as_str(), params.pkce_mode.as_str(), @@ -330,11 +356,13 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' encrypted_client_secret: params.encrypted_client_secret, token_endpoint_signing_alg: params.token_endpoint_signing_alg, token_endpoint_auth_method: params.token_endpoint_auth_method, + fetch_userinfo: params.fetch_userinfo, created_at, disabled_at: None, claims_imports: params.claims_imports, authorization_endpoint_override: params.authorization_endpoint_override, token_endpoint_override: params.token_endpoint_override, + userinfo_endpoint_override: params.userinfo_endpoint_override, jwks_uri_override: params.jwks_uri_override, discovery_mode: params.discovery_mode, pkce_mode: params.pkce_mode, @@ -437,20 +465,22 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' brand_name, scope, token_endpoint_auth_method, + fetch_userinfo, token_endpoint_signing_alg, client_id, encrypted_client_secret, claims_imports, authorization_endpoint_override, token_endpoint_override, + userinfo_endpoint_override, jwks_uri_override, discovery_mode, pkce_mode, response_mode, additional_parameters, created_at - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, - $10, $11, $12, $13, $14, $15, $16, $17, $18) + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, + $12, $13, $14, $15, $16, $17, $18, $19, $20) ON CONFLICT (upstream_oauth_provider_id) DO UPDATE SET @@ -459,6 +489,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' brand_name = EXCLUDED.brand_name, scope = EXCLUDED.scope, token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method, + fetch_userinfo = EXCLUDED.fetch_userinfo, token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg, disabled_at = NULL, client_id = EXCLUDED.client_id, @@ -466,6 +497,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' claims_imports = EXCLUDED.claims_imports, authorization_endpoint_override = EXCLUDED.authorization_endpoint_override, token_endpoint_override = EXCLUDED.token_endpoint_override, + userinfo_endpoint_override = EXCLUDED.userinfo_endpoint_override, jwks_uri_override = EXCLUDED.jwks_uri_override, discovery_mode = EXCLUDED.discovery_mode, pkce_mode = EXCLUDED.pkce_mode, @@ -479,6 +511,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' params.brand_name.as_deref(), params.scope.to_string(), params.token_endpoint_auth_method.to_string(), + params.fetch_userinfo, params .token_endpoint_signing_alg .as_ref() @@ -494,6 +527,10 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' .token_endpoint_override .as_ref() .map(ToString::to_string), + params + .userinfo_endpoint_override + .as_ref() + .map(ToString::to_string), params.jwks_uri_override.as_ref().map(ToString::to_string), params.discovery_mode.as_str(), params.pkce_mode.as_str(), @@ -515,11 +552,13 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' encrypted_client_secret: params.encrypted_client_secret, token_endpoint_signing_alg: params.token_endpoint_signing_alg, token_endpoint_auth_method: params.token_endpoint_auth_method, + fetch_userinfo: params.fetch_userinfo, created_at, disabled_at: None, claims_imports: params.claims_imports, authorization_endpoint_override: params.authorization_endpoint_override, token_endpoint_override: params.token_endpoint_override, + userinfo_endpoint_override: params.userinfo_endpoint_override, jwks_uri_override: params.jwks_uri_override, discovery_mode: params.discovery_mode, pkce_mode: params.pkce_mode, @@ -644,6 +683,13 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' )), ProviderLookupIden::CreatedAt, ) + .expr_as( + Expr::col(( + UpstreamOAuthProviders::Table, + UpstreamOAuthProviders::FetchUserinfo, + )), + ProviderLookupIden::FetchUserinfo, + ) .expr_as( Expr::col(( UpstreamOAuthProviders::Table, @@ -679,6 +725,13 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' )), ProviderLookupIden::AuthorizationEndpointOverride, ) + .expr_as( + Expr::col(( + UpstreamOAuthProviders::Table, + UpstreamOAuthProviders::UserinfoEndpointOverride, + )), + ProviderLookupIden::UserinfoEndpointOverride, + ) .expr_as( Expr::col(( UpstreamOAuthProviders::Table, @@ -786,12 +839,14 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' encrypted_client_secret, token_endpoint_signing_alg, token_endpoint_auth_method, + fetch_userinfo, created_at, disabled_at, claims_imports as "claims_imports: Json", jwks_uri_override, authorization_endpoint_override, token_endpoint_override, + userinfo_endpoint_override, discovery_mode, pkce_mode, response_mode, diff --git a/crates/storage-pg/src/upstream_oauth2/session.rs b/crates/storage-pg/src/upstream_oauth2/session.rs index e3f28b5f4..1cadcfa61 100644 --- a/crates/storage-pg/src/upstream_oauth2/session.rs +++ b/crates/storage-pg/src/upstream_oauth2/session.rs @@ -40,6 +40,7 @@ struct SessionLookup { code_challenge_verifier: Option, nonce: String, id_token: Option, + userinfo: Option, created_at: DateTime, completed_at: Option>, consumed_at: Option>, @@ -55,22 +56,30 @@ impl TryFrom for UpstreamOAuthAuthorizationSession { value.upstream_oauth_link_id, value.id_token, value.extra_callback_parameters, + value.userinfo, value.completed_at, value.consumed_at, ) { - (None, None, None, None, None) => UpstreamOAuthAuthorizationSessionState::Pending, - (Some(link_id), id_token, extra_callback_parameters, Some(completed_at), None) => { - UpstreamOAuthAuthorizationSessionState::Completed { - completed_at, - link_id: link_id.into(), - id_token, - extra_callback_parameters, - } - } + (None, None, None, None, None, None) => UpstreamOAuthAuthorizationSessionState::Pending, ( Some(link_id), id_token, extra_callback_parameters, + userinfo, + Some(completed_at), + None, + ) => UpstreamOAuthAuthorizationSessionState::Completed { + completed_at, + link_id: link_id.into(), + id_token, + extra_callback_parameters, + userinfo, + }, + ( + Some(link_id), + id_token, + extra_callback_parameters, + userinfo, Some(completed_at), Some(consumed_at), ) => UpstreamOAuthAuthorizationSessionState::Consumed { @@ -78,6 +87,7 @@ impl TryFrom for UpstreamOAuthAuthorizationSession { link_id: link_id.into(), id_token, extra_callback_parameters, + userinfo, consumed_at, }, _ => { @@ -128,6 +138,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> nonce, id_token, extra_callback_parameters, + userinfo, created_at, completed_at, consumed_at @@ -184,8 +195,9 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> created_at, completed_at, consumed_at, - id_token - ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL) + id_token, + userinfo + ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL, NULL) "#, Uuid::from(id), Uuid::from(upstream_oauth_provider.id), @@ -226,6 +238,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> upstream_oauth_link: &UpstreamOAuthLink, id_token: Option, extra_callback_parameters: Option, + userinfo: Option, ) -> Result { let completed_at = clock.now(); @@ -235,13 +248,15 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> SET upstream_oauth_link_id = $1, completed_at = $2, id_token = $3, - extra_callback_parameters = $4 - WHERE upstream_oauth_authorization_session_id = $5 + extra_callback_parameters = $4, + userinfo = $5 + WHERE upstream_oauth_authorization_session_id = $6 "#, Uuid::from(upstream_oauth_link.id), completed_at, id_token, extra_callback_parameters, + userinfo, Uuid::from(upstream_oauth_authorization_session.id), ) .traced() @@ -254,6 +269,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> upstream_oauth_link, id_token, extra_callback_parameters, + userinfo, ) .map_err(DatabaseError::to_invalid_operation)?; diff --git a/crates/storage/src/upstream_oauth2/provider.rs b/crates/storage/src/upstream_oauth2/provider.rs index a5309d5d1..10ac08ec5 100644 --- a/crates/storage/src/upstream_oauth2/provider.rs +++ b/crates/storage/src/upstream_oauth2/provider.rs @@ -42,6 +42,11 @@ pub struct UpstreamOAuthProviderParams { /// `private_key_jwt` authentication methods are used pub token_endpoint_signing_alg: Option, + /// Whether to fetch the user profile from the userinfo endpoint, + /// or to rely on the data returned in the `id_token` from the + /// `token_endpoint`. + pub fetch_userinfo: bool, + /// The client ID to use when authenticating to the upstream pub client_id: String, @@ -59,6 +64,10 @@ pub struct UpstreamOAuthProviderParams { /// discovered pub token_endpoint_override: Option, + /// The URL to use as the userinfo endpoint. If `None`, the URL will be + /// discovered + pub userinfo_endpoint_override: Option, + /// The URL to use when fetching JWKS. If `None`, the URL will be discovered pub jwks_uri_override: Option, diff --git a/crates/storage/src/upstream_oauth2/session.rs b/crates/storage/src/upstream_oauth2/session.rs index a9a438a3a..827a0f902 100644 --- a/crates/storage/src/upstream_oauth2/session.rs +++ b/crates/storage/src/upstream_oauth2/session.rs @@ -87,6 +87,7 @@ pub trait UpstreamOAuthSessionRepository: Send + Sync { upstream_oauth_link: &UpstreamOAuthLink, id_token: Option, extra_callback_parameters: Option, + userinfo: Option, ) -> Result; /// Mark a session as consumed @@ -131,6 +132,7 @@ repository_impl!(UpstreamOAuthSessionRepository: upstream_oauth_link: &UpstreamOAuthLink, id_token: Option, extra_callback_parameters: Option, + userinfo: Option, ) -> Result; async fn consume( diff --git a/docs/config.schema.json b/docs/config.schema.json index 9324154ca..b2959c56e 100644 --- a/docs/config.schema.json +++ b/docs/config.schema.json @@ -1895,11 +1895,21 @@ } ] }, + "fetch_userinfo": { + "description": "Whether to fetch the user profile from the userinfo endpoint, or to rely on the data returned in the `id_token` from the `token_endpoint`.\n\nDefaults to `false`.", + "default": false, + "type": "boolean" + }, "authorization_endpoint": { "description": "The URL to use for the provider's authorization endpoint\n\nDefaults to the `authorization_endpoint` provided through discovery", "type": "string", "format": "uri" }, + "userinfo_endpoint": { + "description": "The URL to use for the provider's userinfo endpoint\n\nDefaults to the `userinfo_endpoint` provided through discovery", + "type": "string", + "format": "uri" + }, "token_endpoint": { "description": "The URL to use for the provider's token endpoint\n\nDefaults to the `token_endpoint` provided through discovery", "type": "string",