Skip to content

Commit 4a62a23

Browse files
committed
Add user_profile_method to upstream SSO provider
1 parent 8723e40 commit 4a62a23

24 files changed

+361
-1149
lines changed

crates/cli/src/sync.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,15 @@ pub async fn config_sync(
231231
}
232232
};
233233

234+
let user_profile_method = match provider.user_profile_method {
235+
mas_config::UpstreamOAuth2UserProfileMethod::Auto => {
236+
mas_data_model::UpstreamOAuthProviderUserProfileMethod::Auto
237+
}
238+
mas_config::UpstreamOAuth2UserProfileMethod::UserinfoEndpoint => {
239+
mas_data_model::UpstreamOAuthProviderUserProfileMethod::UserinfoEndpoint
240+
}
241+
};
242+
234243
repo.upstream_oauth_provider()
235244
.upsert(
236245
clock,
@@ -241,13 +250,15 @@ pub async fn config_sync(
241250
brand_name: provider.brand_name,
242251
scope: provider.scope.parse()?,
243252
token_endpoint_auth_method: provider.token_endpoint_auth_method.into(),
253+
user_profile_method,
244254
token_endpoint_signing_alg: provider
245255
.token_endpoint_auth_signing_alg
246256
.clone(),
247257
client_id: provider.client_id,
248258
encrypted_client_secret,
249259
claims_imports: map_claims_imports(&provider.claims_imports),
250260
token_endpoint_override: provider.token_endpoint,
261+
userinfo_endpoint_override: provider.userinfo_endpoint,
251262
authorization_endpoint_override: provider.authorization_endpoint,
252263
jwks_uri_override: provider.jwks_uri,
253264
discovery_mode,

crates/config/src/sections/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ pub use self::{
5252
EmailImportPreference as UpstreamOAuth2EmailImportPreference,
5353
ImportAction as UpstreamOAuth2ImportAction, PkceMethod as UpstreamOAuth2PkceMethod,
5454
SetEmailVerification as UpstreamOAuth2SetEmailVerification, UpstreamOAuth2Config,
55+
UserProfileMethod as UpstreamOAuth2UserProfileMethod,
5556
},
5657
};
5758
use crate::util::ConfigurationSection;

crates/config/src/sections/upstream_oauth2.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,26 @@ impl From<TokenAuthMethod> for OAuthClientAuthenticationMethod {
124124
}
125125
}
126126

127+
/// Whether to fetch the user profile from the userinfo endpoint,
128+
/// or to rely on the data returned in the id_token from the token_endpoint
129+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
130+
#[serde(rename_all = "snake_case")]
131+
pub enum UserProfileMethod {
132+
/// Use the userinfo endpoint if `openid` is not included in `scopes`
133+
#[default]
134+
Auto,
135+
136+
/// Always use the userinfo endpoint
137+
UserinfoEndpoint,
138+
}
139+
140+
impl UserProfileMethod {
141+
#[allow(clippy::trivially_copy_pass_by_ref)]
142+
const fn is_default(&self) -> bool {
143+
matches!(self, UserProfileMethod::Auto)
144+
}
145+
}
146+
127147
/// How to handle a claim
128148
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
129149
#[serde(rename_all = "lowercase")]
@@ -401,6 +421,14 @@ pub struct Provider {
401421
#[serde(skip_serializing_if = "Option::is_none")]
402422
pub token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
403423

424+
/// Whether to fetch the user profile from the userinfo endpoint,
425+
/// or to rely on the data returned in the id_token from the token_endpoint.
426+
///
427+
/// Defaults to `auto`, which uses the userinfo endpoint if `openid` is not
428+
/// included in `scopes`, and the ID token otherwise.
429+
#[serde(default, skip_serializing_if = "UserProfileMethod::is_default")]
430+
pub user_profile_method: UserProfileMethod,
431+
404432
/// The scopes to request from the provider
405433
pub scope: String,
406434

@@ -424,6 +452,12 @@ pub struct Provider {
424452
#[serde(skip_serializing_if = "Option::is_none")]
425453
pub authorization_endpoint: Option<Url>,
426454

455+
/// The URL to use for the provider's userinfo endpoint
456+
///
457+
/// Defaults to the `userinfo_endpoint` provided through discovery
458+
#[serde(skip_serializing_if = "Option::is_none")]
459+
pub userinfo_endpoint: Option<Url>,
460+
427461
/// The URL to use for the provider's token endpoint
428462
///
429463
/// Defaults to the `token_endpoint` provided through discovery

crates/data-model/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ pub use self::{
4242
UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
4343
UpstreamOAuthProviderImportAction, UpstreamOAuthProviderImportPreference,
4444
UpstreamOAuthProviderPkceMode, UpstreamOAuthProviderSubjectPreference,
45+
UpstreamOAuthProviderUserProfileMethod,
4546
},
4647
user_agent::{DeviceType, UserAgent},
4748
users::{

crates/data-model/src/upstream_oauth2/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ pub use self::{
1818
PkceMode as UpstreamOAuthProviderPkceMode,
1919
SetEmailVerification as UpsreamOAuthProviderSetEmailVerification,
2020
SubjectPreference as UpstreamOAuthProviderSubjectPreference, UpstreamOAuthProvider,
21+
UserProfileMethod as UpstreamOAuthProviderUserProfileMethod,
2122
},
2223
session::{UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState},
2324
};

crates/data-model/src/upstream_oauth2/provider.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,51 @@ impl std::fmt::Display for PkceMode {
116116
}
117117
}
118118

119+
/// Whether to fetch the user profile from the userinfo endpoint,
120+
/// or to rely on the data returned in the id_token from the token_endpoint
121+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
122+
#[serde(rename_all = "lowercase")]
123+
pub enum UserProfileMethod {
124+
/// Use the userinfo endpoint if `openid` is not included in `scopes`
125+
#[default]
126+
Auto,
127+
128+
/// Always use the userinfo endpoint
129+
UserinfoEndpoint,
130+
}
131+
132+
#[derive(Debug, Clone, Error)]
133+
#[error("Invalid user profile method {0:?}")]
134+
pub struct InvalidUserProfileMethodError(String);
135+
136+
impl std::str::FromStr for UserProfileMethod {
137+
type Err = InvalidUserProfileMethodError;
138+
139+
fn from_str(s: &str) -> Result<Self, Self::Err> {
140+
match s {
141+
"auto" => Ok(Self::Auto),
142+
"userinfo_endpoint" => Ok(Self::UserinfoEndpoint),
143+
s => Err(InvalidUserProfileMethodError(s.to_owned())),
144+
}
145+
}
146+
}
147+
148+
impl UserProfileMethod {
149+
#[must_use]
150+
pub fn as_str(self) -> &'static str {
151+
match self {
152+
Self::Auto => "auto",
153+
Self::UserinfoEndpoint => "userinfo_endpoint",
154+
}
155+
}
156+
}
157+
158+
impl std::fmt::Display for UserProfileMethod {
159+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160+
f.write_str(self.as_str())
161+
}
162+
}
163+
119164
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
120165
pub struct UpstreamOAuthProvider {
121166
pub id: Ulid,
@@ -127,11 +172,13 @@ pub struct UpstreamOAuthProvider {
127172
pub jwks_uri_override: Option<Url>,
128173
pub authorization_endpoint_override: Option<Url>,
129174
pub token_endpoint_override: Option<Url>,
175+
pub userinfo_endpoint_override: Option<Url>,
130176
pub scope: Scope,
131177
pub client_id: String,
132178
pub encrypted_client_secret: Option<String>,
133179
pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
134180
pub token_endpoint_auth_method: OAuthClientAuthenticationMethod,
181+
pub user_profile_method: UserProfileMethod,
135182
pub created_at: DateTime<Utc>,
136183
pub disabled_at: Option<DateTime<Utc>>,
137184
pub claims_imports: ClaimsImports,

crates/handlers/src/upstream_oauth2/cache.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,18 @@ impl<'a> LazyProviderInfos<'a> {
109109
Ok(self.load().await?.token_endpoint())
110110
}
111111

112+
/// Get the userinfo endpoint for the provider.
113+
///
114+
/// Uses [`UpstreamOAuthProvider.userinfo_endpoint_override`] if set, otherwise
115+
/// uses the one from discovery.
116+
pub async fn userinfo_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
117+
if let Some(userinfo_endpoint) = &self.provider.userinfo_endpoint_override {
118+
return Ok(userinfo_endpoint);
119+
}
120+
121+
Ok(self.load().await?.userinfo_endpoint())
122+
}
123+
112124
/// Get the PKCE methods supported by the provider.
113125
///
114126
/// If the mode is set to auto, it will use the ones from discovery,
@@ -276,7 +288,9 @@ mod tests {
276288
use std::sync::atomic::{AtomicUsize, Ordering};
277289

278290
use hyper::{body::Bytes, Request, Response, StatusCode};
279-
use mas_data_model::UpstreamOAuthProviderClaimsImports;
291+
use mas_data_model::{
292+
UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderUserProfileMethod,
293+
};
280294
use mas_http::BoxCloneSyncService;
281295
use mas_iana::oauth::OAuthClientAuthenticationMethod;
282296
use mas_storage::{clock::MockClock, Clock};
@@ -487,8 +501,10 @@ mod tests {
487501
brand_name: None,
488502
discovery_mode: UpstreamOAuthProviderDiscoveryMode::Oidc,
489503
pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
504+
user_profile_method: UpstreamOAuthProviderUserProfileMethod::Auto,
490505
jwks_uri_override: None,
491506
authorization_endpoint_override: None,
507+
userinfo_endpoint_override: None,
492508
token_endpoint_override: None,
493509
scope: Scope::from_iter([OPENID]),
494510
client_id: "client_id".to_owned(),

crates/handlers/src/upstream_oauth2/callback.rs

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use mas_axum_utils::{
1313
cookies::CookieJar, http_client_factory::HttpClientFactory, sentry::SentryEventID,
1414
};
1515
use mas_data_model::UpstreamOAuthProvider;
16+
use mas_data_model::UpstreamOAuthProviderUserProfileMethod;
1617
use mas_keystore::{Encrypter, Keystore};
1718
use mas_oidc_client::requests::{
1819
authorization_code::AuthorizationValidationData, jose::JwtVerificationData,
@@ -94,13 +95,14 @@ pub(crate) enum RouteError {
9495
MissingCookie,
9596

9697
#[error(transparent)]
97-
Internal(Box<dyn std::error::Error>),
98+
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
9899
}
99100

100101
impl_from_error_for_route!(mas_storage::RepositoryError);
101102
impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError);
102103
impl_from_error_for_route!(mas_oidc_client::error::JwksError);
103104
impl_from_error_for_route!(mas_oidc_client::error::TokenAuthorizationCodeError);
105+
impl_from_error_for_route!(mas_oidc_client::error::UserInfoError);
104106
impl_from_error_for_route!(super::ProviderCredentialsError);
105107
impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
106108

@@ -212,32 +214,54 @@ pub(crate) async fn get(
212214
redirect_uri,
213215
};
214216

215-
let id_token_verification_data = JwtVerificationData {
217+
let verification_data = JwtVerificationData {
216218
issuer: &provider.issuer,
217219
jwks: &jwks,
218220
// TODO: make that configurable
219221
signing_algorithm: &mas_iana::jose::JsonWebSignatureAlg::Rs256,
220222
client_id: &provider.client_id,
221223
};
222224

223-
let (response, id_token) =
225+
let (response, id_token_map) =
224226
mas_oidc_client::requests::authorization_code::access_token_with_authorization_code(
225227
&http_service,
226228
client_credentials,
227229
lazy_metadata.token_endpoint().await?,
228230
code,
229231
validation_data,
230-
Some(id_token_verification_data),
232+
Some(verification_data),
231233
clock.now(),
232234
&mut rng,
233235
)
234236
.await?;
235237

236-
let (_header, id_token) = id_token.ok_or(RouteError::MissingIDToken)?.into_parts();
238+
let (_header, id_token) = id_token_map
239+
.clone()
240+
.ok_or(RouteError::MissingIDToken)?
241+
.into_parts();
242+
243+
let use_userinfo_endpoint = match provider.user_profile_method {
244+
UpstreamOAuthProviderUserProfileMethod::Auto => !provider.scope.contains("openid"),
245+
UpstreamOAuthProviderUserProfileMethod::UserinfoEndpoint => true,
246+
};
247+
248+
let userinfo = if use_userinfo_endpoint {
249+
let user_info_resp = mas_oidc_client::requests::userinfo::fetch_userinfo(
250+
&http_service,
251+
lazy_metadata.userinfo_endpoint().await?,
252+
response.access_token.as_str(),
253+
Some(verification_data),
254+
&id_token_map.ok_or(RouteError::MissingIDToken)?,
255+
)
256+
.await?;
257+
minijinja::Value::from_serialize(&user_info_resp)
258+
} else {
259+
minijinja::Value::from_serialize(&id_token)
260+
};
237261

238262
let env = {
239263
let mut env = environment();
240-
env.add_global("user", minijinja::Value::from_serialize(&id_token));
264+
env.add_global("user", userinfo);
241265
env
242266
};
243267

crates/handlers/src/upstream_oauth2/link.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -907,12 +907,15 @@ mod tests {
907907
brand_name: None,
908908
scope: Scope::from_iter([OPENID]),
909909
token_endpoint_auth_method: OAuthClientAuthenticationMethod::None,
910+
user_profile_method:
911+
mas_data_model::UpstreamOAuthProviderUserProfileMethod::Auto,
910912
token_endpoint_signing_alg: None,
911913
client_id: "client".to_owned(),
912914
encrypted_client_secret: None,
913915
claims_imports,
914916
authorization_endpoint_override: None,
915917
token_endpoint_override: None,
918+
userinfo_endpoint_override: None,
916919
jwks_uri_override: None,
917920
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
918921
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,

crates/handlers/src/views/login.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,11 +371,14 @@ mod test {
371371
scope: [OPENID].into_iter().collect(),
372372
token_endpoint_auth_method: OAuthClientAuthenticationMethod::None,
373373
token_endpoint_signing_alg: None,
374+
user_profile_method:
375+
mas_data_model::UpstreamOAuthProviderUserProfileMethod::Auto,
374376
client_id: "client".to_owned(),
375377
encrypted_client_secret: None,
376378
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
377379
authorization_endpoint_override: None,
378380
token_endpoint_override: None,
381+
userinfo_endpoint_override: None,
379382
jwks_uri_override: None,
380383
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
381384
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
@@ -406,11 +409,14 @@ mod test {
406409
scope: [OPENID].into_iter().collect(),
407410
token_endpoint_auth_method: OAuthClientAuthenticationMethod::None,
408411
token_endpoint_signing_alg: None,
412+
user_profile_method:
413+
mas_data_model::UpstreamOAuthProviderUserProfileMethod::Auto,
409414
client_id: "client".to_owned(),
410415
encrypted_client_secret: None,
411416
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
412417
authorization_endpoint_override: None,
413418
token_endpoint_override: None,
419+
userinfo_endpoint_override: None,
414420
jwks_uri_override: None,
415421
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
416422
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,

0 commit comments

Comments
 (0)