Skip to content
This repository was archived by the owner on Sep 10, 2024. It is now read-only.

Commit 7315dd9

Browse files
committed
Allow endpoints and discovery mode override for upstream oauth2 providers
This time, at the configuration and database level
1 parent 364093f commit 7315dd9

19 files changed

+764
-233
lines changed

crates/cli/src/commands/config.rs

Lines changed: 63 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@ use std::collections::HashSet;
1717
use clap::Parser;
1818
use mas_config::{ConfigurationSection, RootConfig, SyncConfig};
1919
use mas_storage::{
20-
upstream_oauth2::UpstreamOAuthProviderRepository, RepositoryAccess, SystemClock,
20+
upstream_oauth2::{UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository},
21+
RepositoryAccess, SystemClock,
2122
};
2223
use mas_storage_pg::PgRepository;
2324
use rand::SeedableRng;
2425
use sqlx::{postgres::PgAdvisoryLock, Acquire};
25-
use tracing::{info, info_span, warn};
26+
use tracing::{error, info, info_span, warn};
2627

2728
use crate::util::database_connection_from_config;
2829

@@ -204,10 +205,11 @@ async fn sync(root: &super::Options, prune: bool, dry_run: bool) -> anyhow::Resu
204205
}
205206

206207
for provider in config.upstream_oauth2.providers {
208+
let _span = info_span!("provider", %provider.id).entered();
207209
if existing_ids.contains(&provider.id) {
208-
info!(%provider.id, "Updating provider");
210+
info!("Updating provider");
209211
} else {
210-
info!(%provider.id, "Adding provider");
212+
info!("Adding provider");
211213
}
212214

213215
if dry_run {
@@ -218,20 +220,65 @@ async fn sync(root: &super::Options, prune: bool, dry_run: bool) -> anyhow::Resu
218220
.client_secret()
219221
.map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes()))
220222
.transpose()?;
221-
let client_auth_method = provider.client_auth_method();
222-
let client_auth_signing_alg = provider.client_auth_signing_alg();
223+
let token_endpoint_auth_method = provider.client_auth_method();
224+
let token_endpoint_signing_alg = provider.client_auth_signing_alg();
225+
226+
let discovery_mode = match provider.discovery_mode {
227+
mas_config::UpstreamOAuth2DiscoveryMode::Oidc => {
228+
mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc
229+
}
230+
mas_config::UpstreamOAuth2DiscoveryMode::Insecure => {
231+
mas_data_model::UpstreamOAuthProviderDiscoveryMode::Insecure
232+
}
233+
mas_config::UpstreamOAuth2DiscoveryMode::Disabled => {
234+
mas_data_model::UpstreamOAuthProviderDiscoveryMode::Disabled
235+
}
236+
};
237+
238+
if discovery_mode.is_disabled() {
239+
if provider.authorization_endpoint.is_none() {
240+
error!("Provider has discovery disabled but no authorization endpoint set");
241+
}
242+
243+
if provider.token_endpoint.is_none() {
244+
error!("Provider has discovery disabled but no token endpoint set");
245+
}
246+
247+
if provider.jwks_uri.is_none() {
248+
error!("Provider has discovery disabled but no JWKS URI set");
249+
}
250+
}
251+
252+
let pkce_mode = match provider.pkce_method {
253+
mas_config::UpstreamOAuth2PkceMethod::Auto => {
254+
mas_data_model::UpstreamOAuthProviderPkceMode::Auto
255+
}
256+
mas_config::UpstreamOAuth2PkceMethod::Always => {
257+
mas_data_model::UpstreamOAuthProviderPkceMode::S256
258+
}
259+
mas_config::UpstreamOAuth2PkceMethod::Never => {
260+
mas_data_model::UpstreamOAuthProviderPkceMode::Disabled
261+
}
262+
};
223263

224264
repo.upstream_oauth_provider()
225265
.upsert(
226266
&clock,
227267
provider.id,
228-
provider.issuer,
229-
provider.scope.parse()?,
230-
client_auth_method,
231-
client_auth_signing_alg,
232-
provider.client_id,
233-
encrypted_client_secret,
234-
map_claims_imports(&provider.claims_imports),
268+
UpstreamOAuthProviderParams {
269+
issuer: provider.issuer,
270+
scope: provider.scope.parse()?,
271+
token_endpoint_auth_method,
272+
token_endpoint_signing_alg,
273+
client_id: provider.client_id,
274+
encrypted_client_secret,
275+
claims_imports: map_claims_imports(&provider.claims_imports),
276+
token_endpoint_override: provider.token_endpoint,
277+
authorization_endpoint_override: provider.authorization_endpoint,
278+
jwks_uri_override: provider.jwks_uri,
279+
discovery_mode,
280+
pkce_mode,
281+
},
235282
)
236283
.await?;
237284
}
@@ -268,10 +315,11 @@ async fn sync(root: &super::Options, prune: bool, dry_run: bool) -> anyhow::Resu
268315
}
269316

270317
for client in config.clients.iter() {
318+
let _span = info_span!("client", client.id = %client.client_id).entered();
271319
if existing_ids.contains(&client.client_id) {
272-
info!(client.id = %client.client_id, "Updating client");
320+
info!("Updating client");
273321
} else {
274-
info!(client.id = %client.client_id, "Adding client");
322+
info!("Adding client");
275323
}
276324

277325
if dry_run {

crates/config/src/sections/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,10 @@ pub use self::{
5151
},
5252
templates::TemplatesConfig,
5353
upstream_oauth2::{
54-
ClaimsImports as UpstreamOAuth2ClaimsImports,
54+
ClaimsImports as UpstreamOAuth2ClaimsImports, DiscoveryMode as UpstreamOAuth2DiscoveryMode,
5555
EmailImportPreference as UpstreamOAuth2EmailImportPreference,
5656
ImportAction as UpstreamOAuth2ImportAction,
57-
ImportPreference as UpstreamOAuth2ImportPreference,
57+
ImportPreference as UpstreamOAuth2ImportPreference, PkceMethod as UpstreamOAuth2PkceMethod,
5858
SetEmailVerification as UpstreamOAuth2SetEmailVerification, UpstreamOAuth2Config,
5959
},
6060
};

crates/config/src/sections/upstream_oauth2.rs

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use schemars::JsonSchema;
2121
use serde::{Deserialize, Serialize};
2222
use serde_with::skip_serializing_none;
2323
use ulid::Ulid;
24+
use url::Url;
2425

2526
use crate::ConfigurationSection;
2627

@@ -197,6 +198,39 @@ pub struct ClaimsImports {
197198
pub email: EmailImportPreference,
198199
}
199200

201+
/// How to discover the provider's configuration
202+
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
203+
#[serde(rename_all = "snake_case")]
204+
pub enum DiscoveryMode {
205+
/// Use OIDC discovery with strict metadata verification
206+
#[default]
207+
Oidc,
208+
209+
/// Use OIDC discovery with relaxed metadata verification
210+
Insecure,
211+
212+
/// Use a static configuration
213+
Disabled,
214+
}
215+
216+
/// Whether to use proof key for code exchange (PKCE) when requesting and
217+
/// exchanging the token.
218+
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
219+
#[serde(rename_all = "snake_case")]
220+
pub enum PkceMethod {
221+
/// Use PKCE if the provider supports it
222+
///
223+
/// Defaults to no PKCE if provider discovery is disabled
224+
#[default]
225+
Auto,
226+
227+
/// Always use PKCE with the S256 challenge method
228+
Always,
229+
230+
/// Never use PKCE
231+
Never,
232+
}
233+
200234
#[skip_serializing_none]
201235
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
202236
pub struct Provider {
@@ -220,6 +254,34 @@ pub struct Provider {
220254
#[serde(flatten)]
221255
pub token_auth_method: TokenAuthMethod,
222256

257+
/// How to discover the provider's configuration
258+
///
259+
/// Defaults to use OIDC discovery with strict metadata verification
260+
#[serde(default)]
261+
pub discovery_mode: DiscoveryMode,
262+
263+
/// Whether to use proof key for code exchange (PKCE) when requesting and
264+
/// exchanging the token.
265+
///
266+
/// Defaults to `auto`, which uses PKCE if the provider supports it.
267+
#[serde(default)]
268+
pub pkce_method: PkceMethod,
269+
270+
/// The URL to use for the provider's authorization endpoint
271+
///
272+
/// Defaults to the `authorization_endpoint` provided through discovery
273+
pub authorization_endpoint: Option<Url>,
274+
275+
/// The URL to use for the provider's token endpoint
276+
///
277+
/// Defaults to the `token_endpoint` provided through discovery
278+
pub token_endpoint: Option<Url>,
279+
280+
/// The URL to use for getting the provider's public keys
281+
///
282+
/// Defaults to the `jwks_uri` provided through discovery
283+
pub jwks_uri: Option<Url>,
284+
223285
/// How claims should be imported from the `id_token` provided by the
224286
/// provider
225287
pub claims_imports: ClaimsImports,

crates/data-model/src/upstream_oauth2/provider.rs

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use chrono::{DateTime, Utc};
1616
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
1717
use oauth2_types::scope::Scope;
1818
use serde::{Deserialize, Serialize};
19+
use thiserror::Error;
1920
use ulid::Ulid;
2021
use url::Url;
2122

@@ -33,6 +34,48 @@ pub enum DiscoveryMode {
3334
Disabled,
3435
}
3536

37+
impl DiscoveryMode {
38+
/// Returns `true` if discovery is disabled
39+
#[must_use]
40+
pub fn is_disabled(&self) -> bool {
41+
matches!(self, DiscoveryMode::Disabled)
42+
}
43+
}
44+
45+
#[derive(Debug, Clone, Error)]
46+
#[error("Invalid discovery mode {0:?}")]
47+
pub struct InvalidDiscoveryModeError(String);
48+
49+
impl std::str::FromStr for DiscoveryMode {
50+
type Err = InvalidDiscoveryModeError;
51+
52+
fn from_str(s: &str) -> Result<Self, Self::Err> {
53+
match s {
54+
"oidc" => Ok(Self::Oidc),
55+
"insecure" => Ok(Self::Insecure),
56+
"disabled" => Ok(Self::Disabled),
57+
s => Err(InvalidDiscoveryModeError(s.to_owned())),
58+
}
59+
}
60+
}
61+
62+
impl DiscoveryMode {
63+
#[must_use]
64+
pub fn as_str(self) -> &'static str {
65+
match self {
66+
Self::Oidc => "oidc",
67+
Self::Insecure => "insecure",
68+
Self::Disabled => "disabled",
69+
}
70+
}
71+
}
72+
73+
impl std::fmt::Display for DiscoveryMode {
74+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75+
f.write_str(self.as_str())
76+
}
77+
}
78+
3679
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
3780
#[serde(rename_all = "lowercase")]
3881
pub enum PkceMode {
@@ -47,6 +90,40 @@ pub enum PkceMode {
4790
Disabled,
4891
}
4992

93+
#[derive(Debug, Clone, Error)]
94+
#[error("Invalid PKCE mode {0:?}")]
95+
pub struct InvalidPkceModeError(String);
96+
97+
impl std::str::FromStr for PkceMode {
98+
type Err = InvalidPkceModeError;
99+
100+
fn from_str(s: &str) -> Result<Self, Self::Err> {
101+
match s {
102+
"auto" => Ok(Self::Auto),
103+
"s256" => Ok(Self::S256),
104+
"disabled" => Ok(Self::Disabled),
105+
s => Err(InvalidPkceModeError(s.to_owned())),
106+
}
107+
}
108+
}
109+
110+
impl PkceMode {
111+
#[must_use]
112+
pub fn as_str(self) -> &'static str {
113+
match self {
114+
Self::Auto => "auto",
115+
Self::S256 => "s256",
116+
Self::Disabled => "disabled",
117+
}
118+
}
119+
}
120+
121+
impl std::fmt::Display for PkceMode {
122+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123+
f.write_str(self.as_str())
124+
}
125+
}
126+
50127
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
51128
pub struct UpstreamOAuthProvider {
52129
pub id: Ulid,

crates/handlers/src/upstream_oauth2/cache.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -292,9 +292,8 @@ mod tests {
292292
use tower::BoxError;
293293
use ulid::Ulid;
294294

295-
use crate::test_utils::init_tracing;
296-
297295
use super::*;
296+
use crate::test_utils::init_tracing;
298297

299298
#[tokio::test]
300299
async fn test_metadata_cache() {

crates/handlers/src/upstream_oauth2/link.rs

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -803,6 +803,7 @@ mod tests {
803803
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
804804
use mas_jose::jwt::{JsonWebSignatureHeader, Jwt};
805805
use mas_router::Route;
806+
use mas_storage::upstream_oauth2::UpstreamOAuthProviderParams;
806807
use oauth2_types::scope::{Scope, OPENID};
807808
use sqlx::PgPool;
808809

@@ -858,13 +859,20 @@ mod tests {
858859
.add(
859860
&mut rng,
860861
&state.clock,
861-
"https://example.com/".to_owned(),
862-
Scope::from_iter([OPENID]),
863-
OAuthClientAuthenticationMethod::None,
864-
None,
865-
"client".to_owned(),
866-
None,
867-
claims_imports,
862+
UpstreamOAuthProviderParams {
863+
issuer: "https://example.com/".to_owned(),
864+
scope: Scope::from_iter([OPENID]),
865+
token_endpoint_auth_method: OAuthClientAuthenticationMethod::None,
866+
token_endpoint_signing_alg: None,
867+
client_id: "client".to_owned(),
868+
encrypted_client_secret: None,
869+
claims_imports,
870+
authorization_endpoint_override: None,
871+
token_endpoint_override: None,
872+
jwks_uri_override: None,
873+
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
874+
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
875+
},
868876
)
869877
.await
870878
.unwrap();

0 commit comments

Comments
 (0)