Skip to content

Commit 56edcb4

Browse files
authored
Add fetch_userinfo to upstream SSO provider (#3363)
1 parent 9097516 commit 56edcb4

27 files changed

+408
-136
lines changed

crates/cli/src/sync.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,10 +284,12 @@ pub async fn config_sync(
284284
encrypted_client_secret,
285285
claims_imports: map_claims_imports(&provider.claims_imports),
286286
token_endpoint_override: provider.token_endpoint,
287+
userinfo_endpoint_override: provider.userinfo_endpoint,
287288
authorization_endpoint_override: provider.authorization_endpoint,
288289
jwks_uri_override: provider.jwks_uri,
289290
discovery_mode,
290291
pkce_mode,
292+
fetch_userinfo: provider.fetch_userinfo,
291293
response_mode,
292294
additional_authorization_parameters: provider
293295
.additional_authorization_parameters

crates/config/src/sections/upstream_oauth2.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,12 +465,26 @@ pub struct Provider {
465465
#[serde(default, skip_serializing_if = "PkceMethod::is_default")]
466466
pub pkce_method: PkceMethod,
467467

468+
/// Whether to fetch the user profile from the userinfo endpoint,
469+
/// or to rely on the data returned in the `id_token` from the
470+
/// `token_endpoint`.
471+
///
472+
/// Defaults to `false`.
473+
#[serde(default)]
474+
pub fetch_userinfo: bool,
475+
468476
/// The URL to use for the provider's authorization endpoint
469477
///
470478
/// Defaults to the `authorization_endpoint` provided through discovery
471479
#[serde(skip_serializing_if = "Option::is_none")]
472480
pub authorization_endpoint: Option<Url>,
473481

482+
/// The URL to use for the provider's userinfo endpoint
483+
///
484+
/// Defaults to the `userinfo_endpoint` provided through discovery
485+
#[serde(skip_serializing_if = "Option::is_none")]
486+
pub userinfo_endpoint: Option<Url>,
487+
474488
/// The URL to use for the provider's token endpoint
475489
///
476490
/// Defaults to the `token_endpoint` provided through discovery

crates/data-model/src/upstream_oauth2/provider.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,8 @@ pub struct UpstreamOAuthProvider {
228228
pub authorization_endpoint_override: Option<Url>,
229229
pub scope: Scope,
230230
pub token_endpoint_override: Option<Url>,
231+
pub userinfo_endpoint_override: Option<Url>,
232+
pub fetch_userinfo: bool,
231233
pub client_id: String,
232234
pub encrypted_client_secret: Option<String>,
233235
pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,

crates/data-model/src/upstream_oauth2/session.rs

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,15 @@ pub enum UpstreamOAuthAuthorizationSessionState {
2020
link_id: Ulid,
2121
id_token: Option<String>,
2222
extra_callback_parameters: Option<serde_json::Value>,
23+
userinfo: Option<serde_json::Value>,
2324
},
2425
Consumed {
2526
completed_at: DateTime<Utc>,
2627
consumed_at: DateTime<Utc>,
2728
link_id: Ulid,
2829
id_token: Option<String>,
2930
extra_callback_parameters: Option<serde_json::Value>,
31+
userinfo: Option<serde_json::Value>,
3032
},
3133
}
3234

@@ -45,13 +47,15 @@ impl UpstreamOAuthAuthorizationSessionState {
4547
link: &UpstreamOAuthLink,
4648
id_token: Option<String>,
4749
extra_callback_parameters: Option<serde_json::Value>,
50+
userinfo: Option<serde_json::Value>,
4851
) -> Result<Self, InvalidTransitionError> {
4952
match self {
5053
Self::Pending => Ok(Self::Completed {
5154
completed_at,
5255
link_id: link.id,
5356
id_token,
5457
extra_callback_parameters,
58+
userinfo,
5559
}),
5660
Self::Completed { .. } | Self::Consumed { .. } => Err(InvalidTransitionError),
5761
}
@@ -72,12 +76,14 @@ impl UpstreamOAuthAuthorizationSessionState {
7276
link_id,
7377
id_token,
7478
extra_callback_parameters,
79+
userinfo,
7580
} => Ok(Self::Consumed {
7681
completed_at,
7782
link_id,
7883
consumed_at,
7984
id_token,
8085
extra_callback_parameters,
86+
userinfo,
8187
}),
8288
Self::Pending | Self::Consumed { .. } => Err(InvalidTransitionError),
8389
}
@@ -151,6 +157,14 @@ impl UpstreamOAuthAuthorizationSessionState {
151157
}
152158
}
153159

160+
#[must_use]
161+
pub fn userinfo(&self) -> Option<&serde_json::Value> {
162+
match self {
163+
Self::Pending => None,
164+
Self::Completed { userinfo, .. } | Self::Consumed { userinfo, .. } => userinfo.as_ref(),
165+
}
166+
}
167+
154168
/// Get the time at which the upstream OAuth 2.0 authorization session was
155169
/// consumed.
156170
///
@@ -229,10 +243,15 @@ impl UpstreamOAuthAuthorizationSession {
229243
link: &UpstreamOAuthLink,
230244
id_token: Option<String>,
231245
extra_callback_parameters: Option<serde_json::Value>,
246+
userinfo: Option<serde_json::Value>,
232247
) -> Result<Self, InvalidTransitionError> {
233-
self.state =
234-
self.state
235-
.complete(completed_at, link, id_token, extra_callback_parameters)?;
248+
self.state = self.state.complete(
249+
completed_at,
250+
link,
251+
id_token,
252+
extra_callback_parameters,
253+
userinfo,
254+
)?;
236255
Ok(self)
237256
}
238257

crates/handlers/src/upstream_oauth2/cache.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,18 @@ impl<'a> LazyProviderInfos<'a> {
108108
Ok(self.load().await?.token_endpoint())
109109
}
110110

111+
/// Get the userinfo endpoint for the provider.
112+
///
113+
/// Uses [`UpstreamOAuthProvider.userinfo_endpoint_override`] if set,
114+
/// otherwise uses the one from discovery.
115+
pub async fn userinfo_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
116+
if let Some(userinfo_endpoint) = &self.provider.userinfo_endpoint_override {
117+
return Ok(userinfo_endpoint);
118+
}
119+
120+
Ok(self.load().await?.userinfo_endpoint())
121+
}
122+
111123
/// Get the PKCE methods supported by the provider.
112124
///
113125
/// If the mode is set to auto, it will use the ones from discovery,
@@ -387,9 +399,11 @@ mod tests {
387399
brand_name: None,
388400
discovery_mode: UpstreamOAuthProviderDiscoveryMode::Insecure,
389401
pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
402+
fetch_userinfo: false,
390403
jwks_uri_override: None,
391404
authorization_endpoint_override: None,
392405
scope: Scope::from_iter([OPENID]),
406+
userinfo_endpoint_override: None,
393407
token_endpoint_override: None,
394408
client_id: "client_id".to_owned(),
395409
encrypted_client_secret: None,

crates/handlers/src/upstream_oauth2/callback.rs

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ use mas_storage::{
2929
use mas_templates::{FormPostContext, Templates};
3030
use oauth2_types::errors::ClientErrorCode;
3131
use serde::{Deserialize, Serialize};
32+
use serde_json::json;
3233
use thiserror::Error;
3334
use ulid::Ulid;
3435

@@ -117,14 +118,15 @@ pub(crate) enum RouteError {
117118
},
118119

119120
#[error(transparent)]
120-
Internal(Box<dyn std::error::Error>),
121+
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
121122
}
122123

123124
impl_from_error_for_route!(mas_templates::TemplateError);
124125
impl_from_error_for_route!(mas_storage::RepositoryError);
125126
impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError);
126127
impl_from_error_for_route!(mas_oidc_client::error::JwksError);
127128
impl_from_error_for_route!(mas_oidc_client::error::TokenAuthorizationCodeError);
129+
impl_from_error_for_route!(mas_oidc_client::error::UserInfoError);
128130
impl_from_error_for_route!(super::ProviderCredentialsError);
129131
impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
130132

@@ -274,33 +276,56 @@ pub(crate) async fn handler(
274276
redirect_uri,
275277
};
276278

277-
let id_token_verification_data = JwtVerificationData {
279+
let verification_data = JwtVerificationData {
278280
issuer: &provider.issuer,
279281
jwks: &jwks,
280282
// TODO: make that configurable
281283
signing_algorithm: &mas_iana::jose::JsonWebSignatureAlg::Rs256,
282284
client_id: &provider.client_id,
283285
};
284286

285-
let (response, id_token) =
287+
let (response, id_token_map) =
286288
mas_oidc_client::requests::authorization_code::access_token_with_authorization_code(
287289
&client,
288290
client_credentials,
289291
lazy_metadata.token_endpoint().await?,
290292
code,
291293
validation_data,
292-
Some(id_token_verification_data),
294+
Some(verification_data),
293295
clock.now(),
294296
&mut rng,
295297
)
296298
.await?;
297299

298-
let (_header, id_token) = id_token.ok_or(RouteError::MissingIDToken)?.into_parts();
300+
let (_header, id_token) = id_token_map
301+
.clone()
302+
.ok_or(RouteError::MissingIDToken)?
303+
.into_parts();
299304

300305
let mut context = AttributeMappingContext::new().with_id_token_claims(id_token);
301306
if let Some(extra_callback_parameters) = extra_callback_parameters.clone() {
302307
context = context.with_extra_callback_parameters(extra_callback_parameters);
303308
}
309+
310+
let userinfo = if provider.fetch_userinfo {
311+
Some(json!(
312+
mas_oidc_client::requests::userinfo::fetch_userinfo(
313+
&client,
314+
lazy_metadata.userinfo_endpoint().await?,
315+
response.access_token.as_str(),
316+
Some(verification_data),
317+
&id_token_map.ok_or(RouteError::MissingIDToken)?,
318+
)
319+
.await?
320+
))
321+
} else {
322+
None
323+
};
324+
325+
if let Some(userinfo) = userinfo.clone() {
326+
context = context.with_userinfo_claims(userinfo);
327+
}
328+
304329
let context = context.build();
305330

306331
let env = environment();
@@ -341,6 +366,7 @@ pub(crate) async fn handler(
341366
&link,
342367
response.id_token,
343368
extra_callback_parameters,
369+
userinfo,
344370
)
345371
.await?;
346372

crates/handlers/src/upstream_oauth2/link.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,9 @@ pub(crate) async fn get(
344344
if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() {
345345
context = context.with_extra_callback_parameters(extra_callback_parameters.clone());
346346
}
347+
if let Some(userinfo) = upstream_session.userinfo() {
348+
context = context.with_userinfo_claims(userinfo.clone());
349+
}
347350
let context = context.build();
348351

349352
let ctx = if provider.claims_imports.displayname.ignore() {
@@ -582,6 +585,9 @@ pub(crate) async fn post(
582585
if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() {
583586
context = context.with_extra_callback_parameters(extra_callback_parameters.clone());
584587
}
588+
if let Some(userinfo) = upstream_session.userinfo() {
589+
context = context.with_userinfo_claims(userinfo.clone());
590+
}
585591
let context = context.build();
586592

587593
// Is the email verified according to the upstream provider?
@@ -921,6 +927,8 @@ mod tests {
921927
claims_imports,
922928
authorization_endpoint_override: None,
923929
token_endpoint_override: None,
930+
userinfo_endpoint_override: None,
931+
fetch_userinfo: false,
924932
jwks_uri_override: None,
925933
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
926934
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
@@ -958,6 +966,7 @@ mod tests {
958966
&link,
959967
Some(id_token.into_string()),
960968
None,
969+
None,
961970
)
962971
.await
963972
.unwrap();

crates/handlers/src/upstream_oauth2/template.rs

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use minijinja::{
2323
pub(crate) struct AttributeMappingContext {
2424
id_token_claims: Option<HashMap<String, serde_json::Value>>,
2525
extra_callback_parameters: Option<serde_json::Value>,
26+
userinfo_claims: Option<serde_json::Value>,
2627
}
2728

2829
impl AttributeMappingContext {
@@ -46,6 +47,11 @@ impl AttributeMappingContext {
4647
self
4748
}
4849

50+
pub fn with_userinfo_claims(mut self, userinfo_claims: serde_json::Value) -> Self {
51+
self.userinfo_claims = Some(userinfo_claims);
52+
self
53+
}
54+
4955
pub fn build(self) -> Value {
5056
Value::from_object(self)
5157
}
@@ -54,7 +60,25 @@ impl AttributeMappingContext {
5460
impl Object for AttributeMappingContext {
5561
fn get_value(self: &Arc<Self>, name: &Value) -> Option<Value> {
5662
match name.as_str()? {
57-
"user" | "id_token_claims" => self.id_token_claims.as_ref().map(Value::from_serialize),
63+
"user" => {
64+
if self.id_token_claims.is_none() && self.userinfo_claims.is_none() {
65+
return None;
66+
}
67+
let mut merged_user: HashMap<String, serde_json::Value> = HashMap::new();
68+
if let serde_json::Value::Object(userinfo) = self
69+
.userinfo_claims
70+
.clone()
71+
.unwrap_or(serde_json::Value::Null)
72+
{
73+
merged_user.extend(userinfo);
74+
}
75+
if let Some(id_token) = self.id_token_claims.clone() {
76+
merged_user.extend(id_token);
77+
}
78+
Some(Value::from_serialize(merged_user))
79+
}
80+
"id_token_claims" => self.id_token_claims.as_ref().map(Value::from_serialize),
81+
"userinfo_claims" => self.userinfo_claims.as_ref().map(Value::from_serialize),
5882
"extra_callback_parameters" => self
5983
.extra_callback_parameters
6084
.as_ref()
@@ -64,17 +88,20 @@ impl Object for AttributeMappingContext {
6488
}
6589

6690
fn enumerate(self: &Arc<Self>) -> Enumerator {
67-
match (
68-
self.id_token_claims.is_some(),
69-
self.extra_callback_parameters.is_some(),
70-
) {
71-
(true, true) => {
72-
Enumerator::Str(&["user", "id_token_claims", "extra_callback_parameters"])
73-
}
74-
(true, false) => Enumerator::Str(&["user", "id_token_claims"]),
75-
(false, true) => Enumerator::Str(&["extra_callback_parameters"]),
76-
(false, false) => Enumerator::Str(&["user"]),
91+
let mut attrs = Vec::new();
92+
if self.id_token_claims.is_some() || self.userinfo_claims.is_none() {
93+
attrs.push(minijinja::Value::from("user"));
94+
}
95+
if self.id_token_claims.is_some() {
96+
attrs.push(minijinja::Value::from("id_token_claims"));
97+
}
98+
if self.userinfo_claims.is_some() {
99+
attrs.push(minijinja::Value::from("userinfo_claims"));
100+
}
101+
if self.extra_callback_parameters.is_some() {
102+
attrs.push(minijinja::Value::from("extra_callback_parameters"));
77103
}
104+
Enumerator::Values(attrs)
78105
}
79106
}
80107

crates/handlers/src/views/login.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,11 +403,13 @@ mod test {
403403
scope: [OPENID].into_iter().collect(),
404404
token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
405405
token_endpoint_signing_alg: None,
406+
fetch_userinfo: false,
406407
client_id: "client".to_owned(),
407408
encrypted_client_secret: None,
408409
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
409410
authorization_endpoint_override: None,
410411
token_endpoint_override: None,
412+
userinfo_endpoint_override: None,
411413
jwks_uri_override: None,
412414
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
413415
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
@@ -439,11 +441,13 @@ mod test {
439441
scope: [OPENID].into_iter().collect(),
440442
token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
441443
token_endpoint_signing_alg: None,
444+
fetch_userinfo: false,
442445
client_id: "client".to_owned(),
443446
encrypted_client_secret: None,
444447
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
445448
authorization_endpoint_override: None,
446449
token_endpoint_override: None,
450+
userinfo_endpoint_override: None,
447451
jwks_uri_override: None,
448452
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
449453
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,

0 commit comments

Comments
 (0)