Skip to content

Commit 9f3ac54

Browse files
committed
Add user_profile_method to upstream SSO provider
1 parent 6d17582 commit 9f3ac54

29 files changed

+501
-134
lines changed

crates/cli/src/sync.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,15 @@ pub async fn config_sync(
231231
}
232232
};
233233

234+
let user_profile_method = match provider.user_profile_method {
235+
mas_config::UpstreamOAuth2UserProfileMethod::Auto => {
236+
mas_data_model::UpstreamOAuthProviderUserProfileMethod::Auto
237+
}
238+
mas_config::UpstreamOAuth2UserProfileMethod::UserinfoEndpoint => {
239+
mas_data_model::UpstreamOAuthProviderUserProfileMethod::UserinfoEndpoint
240+
}
241+
};
242+
234243
repo.upstream_oauth_provider()
235244
.upsert(
236245
clock,
@@ -241,13 +250,15 @@ pub async fn config_sync(
241250
brand_name: provider.brand_name,
242251
scope: provider.scope.parse()?,
243252
token_endpoint_auth_method: provider.token_endpoint_auth_method.into(),
253+
user_profile_method,
244254
token_endpoint_signing_alg: provider
245255
.token_endpoint_auth_signing_alg
246256
.clone(),
247257
client_id: provider.client_id,
248258
encrypted_client_secret,
249259
claims_imports: map_claims_imports(&provider.claims_imports),
250260
token_endpoint_override: provider.token_endpoint,
261+
userinfo_endpoint_override: provider.userinfo_endpoint,
251262
authorization_endpoint_override: provider.authorization_endpoint,
252263
jwks_uri_override: provider.jwks_uri,
253264
discovery_mode,

crates/config/src/sections/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ pub use self::{
5252
EmailImportPreference as UpstreamOAuth2EmailImportPreference,
5353
ImportAction as UpstreamOAuth2ImportAction, PkceMethod as UpstreamOAuth2PkceMethod,
5454
SetEmailVerification as UpstreamOAuth2SetEmailVerification, UpstreamOAuth2Config,
55+
UserProfileMethod as UpstreamOAuth2UserProfileMethod,
5556
},
5657
};
5758
use crate::util::ConfigurationSection;

crates/config/src/sections/upstream_oauth2.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,26 @@ impl From<TokenAuthMethod> for OAuthClientAuthenticationMethod {
124124
}
125125
}
126126

127+
/// Whether to fetch the user profile from the userinfo endpoint,
128+
/// or to rely on the data returned in the id_token from the token_endpoint
129+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
130+
#[serde(rename_all = "snake_case")]
131+
pub enum UserProfileMethod {
132+
/// Use the userinfo endpoint if `openid` is not included in `scopes`
133+
#[default]
134+
Auto,
135+
136+
/// Always use the userinfo endpoint
137+
UserinfoEndpoint,
138+
}
139+
140+
impl UserProfileMethod {
141+
#[allow(clippy::trivially_copy_pass_by_ref)]
142+
const fn is_default(&self) -> bool {
143+
matches!(self, UserProfileMethod::Auto)
144+
}
145+
}
146+
127147
/// How to handle a claim
128148
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
129149
#[serde(rename_all = "lowercase")]
@@ -401,6 +421,14 @@ pub struct Provider {
401421
#[serde(skip_serializing_if = "Option::is_none")]
402422
pub token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
403423

424+
/// Whether to fetch the user profile from the userinfo endpoint,
425+
/// or to rely on the data returned in the id_token from the token_endpoint.
426+
///
427+
/// Defaults to `auto`, which uses the userinfo endpoint if `openid` is not
428+
/// included in `scopes`, and the ID token otherwise.
429+
#[serde(default, skip_serializing_if = "UserProfileMethod::is_default")]
430+
pub user_profile_method: UserProfileMethod,
431+
404432
/// The scopes to request from the provider
405433
pub scope: String,
406434

@@ -424,6 +452,12 @@ pub struct Provider {
424452
#[serde(skip_serializing_if = "Option::is_none")]
425453
pub authorization_endpoint: Option<Url>,
426454

455+
/// The URL to use for the provider's userinfo endpoint
456+
///
457+
/// Defaults to the `userinfo_endpoint` provided through discovery
458+
#[serde(skip_serializing_if = "Option::is_none")]
459+
pub userinfo_endpoint: Option<Url>,
460+
427461
/// The URL to use for the provider's token endpoint
428462
///
429463
/// Defaults to the `token_endpoint` provided through discovery

crates/data-model/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ pub use self::{
4242
UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
4343
UpstreamOAuthProviderImportAction, UpstreamOAuthProviderImportPreference,
4444
UpstreamOAuthProviderPkceMode, UpstreamOAuthProviderSubjectPreference,
45+
UpstreamOAuthProviderUserProfileMethod,
4546
},
4647
user_agent::{DeviceType, UserAgent},
4748
users::{

crates/data-model/src/upstream_oauth2/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ pub use self::{
1818
PkceMode as UpstreamOAuthProviderPkceMode,
1919
SetEmailVerification as UpsreamOAuthProviderSetEmailVerification,
2020
SubjectPreference as UpstreamOAuthProviderSubjectPreference, UpstreamOAuthProvider,
21+
UserProfileMethod as UpstreamOAuthProviderUserProfileMethod,
2122
},
2223
session::{UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState},
2324
};

crates/data-model/src/upstream_oauth2/provider.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,51 @@ impl std::fmt::Display for PkceMode {
116116
}
117117
}
118118

119+
/// Whether to fetch the user profile from the userinfo endpoint,
120+
/// or to rely on the data returned in the id_token from the token_endpoint
121+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
122+
#[serde(rename_all = "lowercase")]
123+
pub enum UserProfileMethod {
124+
/// Use the userinfo endpoint if `openid` is not included in `scopes`
125+
#[default]
126+
Auto,
127+
128+
/// Always use the userinfo endpoint
129+
UserinfoEndpoint,
130+
}
131+
132+
#[derive(Debug, Clone, Error)]
133+
#[error("Invalid user profile method {0:?}")]
134+
pub struct InvalidUserProfileMethodError(String);
135+
136+
impl std::str::FromStr for UserProfileMethod {
137+
type Err = InvalidUserProfileMethodError;
138+
139+
fn from_str(s: &str) -> Result<Self, Self::Err> {
140+
match s {
141+
"auto" => Ok(Self::Auto),
142+
"userinfo_endpoint" => Ok(Self::UserinfoEndpoint),
143+
s => Err(InvalidUserProfileMethodError(s.to_owned())),
144+
}
145+
}
146+
}
147+
148+
impl UserProfileMethod {
149+
#[must_use]
150+
pub fn as_str(self) -> &'static str {
151+
match self {
152+
Self::Auto => "auto",
153+
Self::UserinfoEndpoint => "userinfo_endpoint",
154+
}
155+
}
156+
}
157+
158+
impl std::fmt::Display for UserProfileMethod {
159+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160+
f.write_str(self.as_str())
161+
}
162+
}
163+
119164
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
120165
pub struct UpstreamOAuthProvider {
121166
pub id: Ulid,
@@ -127,11 +172,13 @@ pub struct UpstreamOAuthProvider {
127172
pub jwks_uri_override: Option<Url>,
128173
pub authorization_endpoint_override: Option<Url>,
129174
pub token_endpoint_override: Option<Url>,
175+
pub userinfo_endpoint_override: Option<Url>,
130176
pub scope: Scope,
131177
pub client_id: String,
132178
pub encrypted_client_secret: Option<String>,
133179
pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
134180
pub token_endpoint_auth_method: OAuthClientAuthenticationMethod,
181+
pub user_profile_method: UserProfileMethod,
135182
pub created_at: DateTime<Utc>,
136183
pub disabled_at: Option<DateTime<Utc>>,
137184
pub claims_imports: ClaimsImports,

crates/data-model/src/upstream_oauth2/session.rs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,14 @@ pub enum UpstreamOAuthAuthorizationSessionState {
1919
completed_at: DateTime<Utc>,
2020
link_id: Ulid,
2121
id_token: Option<String>,
22+
userinfo: Option<String>,
2223
},
2324
Consumed {
2425
completed_at: DateTime<Utc>,
2526
consumed_at: DateTime<Utc>,
2627
link_id: Ulid,
2728
id_token: Option<String>,
29+
userinfo: Option<String>,
2830
},
2931
}
3032

@@ -42,12 +44,14 @@ impl UpstreamOAuthAuthorizationSessionState {
4244
completed_at: DateTime<Utc>,
4345
link: &UpstreamOAuthLink,
4446
id_token: Option<String>,
47+
userinfo: Option<String>,
4548
) -> Result<Self, InvalidTransitionError> {
4649
match self {
4750
Self::Pending => Ok(Self::Completed {
4851
completed_at,
4952
link_id: link.id,
5053
id_token,
54+
userinfo,
5155
}),
5256
Self::Completed { .. } | Self::Consumed { .. } => Err(InvalidTransitionError),
5357
}
@@ -67,11 +71,13 @@ impl UpstreamOAuthAuthorizationSessionState {
6771
completed_at,
6872
link_id,
6973
id_token,
74+
userinfo,
7075
} => Ok(Self::Consumed {
7176
completed_at,
7277
link_id,
7378
consumed_at,
7479
id_token,
80+
userinfo,
7581
}),
7682
Self::Pending | Self::Consumed { .. } => Err(InvalidTransitionError),
7783
}
@@ -124,6 +130,16 @@ impl UpstreamOAuthAuthorizationSessionState {
124130
}
125131
}
126132

133+
#[must_use]
134+
pub fn userinfo(&self) -> Option<&str> {
135+
match self {
136+
Self::Pending => None,
137+
Self::Completed { userinfo, .. } | Self::Consumed { userinfo, .. } => {
138+
userinfo.as_deref()
139+
}
140+
}
141+
}
142+
127143
/// Get the time at which the upstream OAuth 2.0 authorization session was
128144
/// consumed.
129145
///
@@ -201,8 +217,11 @@ impl UpstreamOAuthAuthorizationSession {
201217
completed_at: DateTime<Utc>,
202218
link: &UpstreamOAuthLink,
203219
id_token: Option<String>,
220+
userinfo: Option<String>,
204221
) -> Result<Self, InvalidTransitionError> {
205-
self.state = self.state.complete(completed_at, link, id_token)?;
222+
self.state = self
223+
.state
224+
.complete(completed_at, link, id_token, userinfo)?;
206225
Ok(self)
207226
}
208227

crates/handlers/src/upstream_oauth2/cache.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,18 @@ impl<'a> LazyProviderInfos<'a> {
108108
Ok(self.load().await?.token_endpoint())
109109
}
110110

111+
/// Get the userinfo endpoint for the provider.
112+
///
113+
/// Uses [`UpstreamOAuthProvider.userinfo_endpoint_override`] if set, otherwise
114+
/// uses the one from discovery.
115+
pub async fn userinfo_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
116+
if let Some(userinfo_endpoint) = &self.provider.userinfo_endpoint_override {
117+
return Ok(userinfo_endpoint);
118+
}
119+
120+
Ok(self.load().await?.userinfo_endpoint())
121+
}
122+
111123
/// Get the PKCE methods supported by the provider.
112124
///
113125
/// If the mode is set to auto, it will use the ones from discovery,
@@ -274,7 +286,9 @@ mod tests {
274286
// XXX: sadly, we can't test HTTPS requests with wiremock, so we can only test
275287
// 'insecure' discovery
276288

277-
use mas_data_model::UpstreamOAuthProviderClaimsImports;
289+
use mas_data_model::{
290+
UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderUserProfileMethod,
291+
};
278292
use mas_iana::oauth::OAuthClientAuthenticationMethod;
279293
use mas_storage::{clock::MockClock, Clock};
280294
use oauth2_types::scope::{Scope, OPENID};
@@ -386,8 +400,10 @@ mod tests {
386400
brand_name: None,
387401
discovery_mode: UpstreamOAuthProviderDiscoveryMode::Insecure,
388402
pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
403+
user_profile_method: UpstreamOAuthProviderUserProfileMethod::Auto,
389404
jwks_uri_override: None,
390405
authorization_endpoint_override: None,
406+
userinfo_endpoint_override: None,
391407
token_endpoint_override: None,
392408
scope: Scope::from_iter([OPENID]),
393409
client_id: "client_id".to_owned(),

0 commit comments

Comments
 (0)