Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions crates/cli/src/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,9 +280,8 @@ pub async fn config_sync(
brand_name: provider.brand_name,
scope: provider.scope.parse()?,
token_endpoint_auth_method,
token_endpoint_signing_alg: provider
.token_endpoint_auth_signing_alg
.clone(),
token_endpoint_signing_alg: provider.token_endpoint_auth_signing_alg,
id_token_signed_response_alg: provider.id_token_signed_response_alg,
client_id: provider.client_id,
encrypted_client_secret,
claims_imports: map_claims_imports(&provider.claims_imports),
Expand All @@ -293,6 +292,7 @@ pub async fn config_sync(
discovery_mode,
pkce_mode,
fetch_userinfo: provider.fetch_userinfo,
userinfo_signed_response_alg: provider.userinfo_signed_response_alg,
response_mode,
additional_authorization_parameters: provider
.additional_authorization_parameters
Expand Down
32 changes: 32 additions & 0 deletions crates/config/src/sections/upstream_oauth2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,16 @@ fn is_default_true(value: &bool) -> bool {
*value
}

#[allow(clippy::ref_option)]
fn is_signed_response_alg_default(signed_response_alg: &Option<JsonWebSignatureAlg>) -> bool {
*signed_response_alg == signed_response_alg_default()
}

#[allow(clippy::unnecessary_wraps)]
fn signed_response_alg_default() -> Option<JsonWebSignatureAlg> {
Some(JsonWebSignatureAlg::Rs256)
}

#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SignInWithApple {
/// The private key used to sign the `id_token`
Expand Down Expand Up @@ -472,6 +482,17 @@ pub struct Provider {
#[serde(skip_serializing_if = "Option::is_none")]
pub token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,

/// Expected signature for the JWT payload returned by the token
/// authentication endpoint.
///
/// If null, the response is expected to be an unsigned JSON payload.
/// Defaults to `RS256`.
#[serde(
default = "signed_response_alg_default",
skip_serializing_if = "is_signed_response_alg_default"
)]
pub id_token_signed_response_alg: Option<JsonWebSignatureAlg>,

/// The scopes to request from the provider
pub scope: String,

Expand All @@ -497,6 +518,17 @@ pub struct Provider {
#[serde(default)]
pub fetch_userinfo: bool,

/// Expected signature for the JWT payload returned by the userinfo
/// endpoint.
///
/// If null, the response is expected to be an unsigned JSON payload.
/// Defaults to `RS256`.
#[serde(
default = "signed_response_alg_default",
skip_serializing_if = "is_signed_response_alg_default"
)]
pub userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,

/// The URL to use for the provider's authorization endpoint
///
/// Defaults to the `authorization_endpoint` provided through discovery
Expand Down
2 changes: 2 additions & 0 deletions crates/data-model/src/upstream_oauth2/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,12 @@ pub struct UpstreamOAuthProvider {
pub token_endpoint_override: Option<Url>,
pub userinfo_endpoint_override: Option<Url>,
pub fetch_userinfo: bool,
pub userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
pub client_id: String,
pub encrypted_client_secret: Option<String>,
pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
pub token_endpoint_auth_method: TokenAuthMethod,
pub id_token_signed_response_alg: Option<JsonWebSignatureAlg>,
pub response_mode: ResponseMode,
pub created_at: DateTime<Utc>,
pub disabled_at: Option<DateTime<Utc>>,
Expand Down
2 changes: 2 additions & 0 deletions crates/handlers/src/upstream_oauth2/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,7 @@ mod tests {
discovery_mode: UpstreamOAuthProviderDiscoveryMode::Insecure,
pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
fetch_userinfo: false,
userinfo_signed_response_alg: None,
jwks_uri_override: None,
authorization_endpoint_override: None,
scope: Scope::from_iter([OPENID]),
Expand All @@ -410,6 +411,7 @@ mod tests {
token_endpoint_signing_alg: None,
token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
response_mode: mas_data_model::UpstreamOAuthProviderResponseMode::Query,
id_token_signed_response_alg: None,
created_at: clock.now(),
disabled_at: None,
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
Expand Down
151 changes: 94 additions & 57 deletions crates/handlers/src/upstream_oauth2/callback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,72 +274,109 @@ pub(crate) async fn handler(
)
.await?;

let mut jwks = None;

let mut context = AttributeMappingContext::new();
if let Some(id_token) = token_response.id_token.as_ref() {
// Fetch the JWKS
let jwks =
mas_oidc_client::requests::jose::fetch_jwks(&client, lazy_metadata.jwks_uri().await?)
.await?;

let verification_data = JwtVerificationData {
issuer: &provider.issuer,
jwks: &jwks,
// TODO: make that configurable
signing_algorithm: &mas_iana::jose::JsonWebSignatureAlg::Rs256,
client_id: &provider.client_id,
};

// Decode and verify the ID token
let id_token = mas_oidc_client::requests::jose::verify_id_token(
id_token,
verification_data,
None,
clock.now(),
)?;

let (_headers, mut claims) = id_token.into_parts();

// Access token hash must match.
mas_jose::claims::AT_HASH
.extract_optional_with_options(
&mut claims,
TokenHash::new(
verification_data.signing_algorithm,
&token_response.access_token,
),
)
.map_err(mas_oidc_client::error::IdTokenError::from)?;

// Code hash must match.
mas_jose::claims::C_HASH
.extract_optional_with_options(
&mut claims,
TokenHash::new(verification_data.signing_algorithm, &code),
)
.map_err(mas_oidc_client::error::IdTokenError::from)?;

// Nonce must match.
mas_jose::claims::NONCE
.extract_required_with_options(&mut claims, session.nonce.as_str())
.map_err(mas_oidc_client::error::IdTokenError::from)?;

context = context.with_id_token_claims(claims);
if let Some(signed_response_alg) = &provider.id_token_signed_response_alg {
jwks = Some(
mas_oidc_client::requests::jose::fetch_jwks(
&client,
lazy_metadata.jwks_uri().await?,
)
.await?,
);

let id_token_verification_data = JwtVerificationData {
issuer: &provider.issuer,
jwks: &jwks.clone().unwrap(),
signing_algorithm: signed_response_alg,
client_id: &provider.client_id,
};

// Decode and verify the ID token
let id_token = mas_oidc_client::requests::jose::verify_id_token(
id_token,
id_token_verification_data,
None,
clock.now(),
)?;

let (_headers, mut claims) = id_token.into_parts();

// Access token hash must match.
mas_jose::claims::AT_HASH
.extract_optional_with_options(
&mut claims,
TokenHash::new(
id_token_verification_data.signing_algorithm,
&token_response.access_token,
),
)
.map_err(mas_oidc_client::error::IdTokenError::from)?;

// Code hash must match.
mas_jose::claims::C_HASH
.extract_optional_with_options(
&mut claims,
TokenHash::new(id_token_verification_data.signing_algorithm, &code),
)
.map_err(mas_oidc_client::error::IdTokenError::from)?;

// Nonce must match.
mas_jose::claims::NONCE
.extract_required_with_options(&mut claims, session.nonce.as_str())
.map_err(mas_oidc_client::error::IdTokenError::from)?;

context = context.with_id_token_claims(claims);
} else {
let claims = serde_json::from_str(id_token)
.map_err(mas_oidc_client::error::IdTokenError::from)?;
context = context.with_id_token_claims(claims);
}
}

if let Some(extra_callback_parameters) = extra_callback_parameters.clone() {
context = context.with_extra_callback_parameters(extra_callback_parameters);
}

let userinfo = if provider.fetch_userinfo {
Some(json!(
mas_oidc_client::requests::userinfo::fetch_userinfo(
&client,
lazy_metadata.userinfo_endpoint().await?,
token_response.access_token.as_str(),
None,
)
.await?
))
Some(json!(match &provider.userinfo_signed_response_alg {
Some(signing_algorithm) => {
let jwks = match jwks {
Some(jwks) => jwks,
None => {
mas_oidc_client::requests::jose::fetch_jwks(
&client,
lazy_metadata.jwks_uri().await?,
)
.await?
}
};

mas_oidc_client::requests::userinfo::fetch_userinfo(
&client,
lazy_metadata.userinfo_endpoint().await?,
token_response.access_token.as_str(),
Some(JwtVerificationData {
issuer: &provider.issuer,
jwks: &jwks,
signing_algorithm,
client_id: &provider.client_id,
}),
)
.await?
}
None => {
mas_oidc_client::requests::userinfo::fetch_userinfo(
&client,
lazy_metadata.userinfo_endpoint().await?,
token_response.access_token.as_str(),
None,
)
.await?
}
}))
} else {
None
};
Expand Down
2 changes: 2 additions & 0 deletions crates/handlers/src/upstream_oauth2/link.rs
Original file line number Diff line number Diff line change
Expand Up @@ -922,13 +922,15 @@ mod tests {
scope: Scope::from_iter([OPENID]),
token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
token_endpoint_signing_alg: None,
id_token_signed_response_alg: None,
client_id: "client".to_owned(),
encrypted_client_secret: None,
claims_imports,
authorization_endpoint_override: None,
token_endpoint_override: None,
userinfo_endpoint_override: None,
fetch_userinfo: false,
userinfo_signed_response_alg: None,
jwks_uri_override: None,
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
Expand Down
4 changes: 4 additions & 0 deletions crates/handlers/src/views/login.rs
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,9 @@ mod test {
scope: [OPENID].into_iter().collect(),
token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
token_endpoint_signing_alg: None,
id_token_signed_response_alg: None,
fetch_userinfo: false,
userinfo_signed_response_alg: None,
client_id: "client".to_owned(),
encrypted_client_secret: None,
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
Expand Down Expand Up @@ -441,7 +443,9 @@ mod test {
scope: [OPENID].into_iter().collect(),
token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
token_endpoint_signing_alg: None,
id_token_signed_response_alg: None,
fetch_userinfo: false,
userinfo_signed_response_alg: None,
client_id: "client".to_owned(),
encrypted_client_secret: None,
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
Expand Down
4 changes: 4 additions & 0 deletions crates/oidc-client/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,10 @@ pub enum IdTokenError {
/// one we got before.
#[error("wrong authentication time")]
WrongAuthTime,

#[error(transparent)]
/// TODO
Deserialize(#[from] serde_json::Error),
}

/// All errors that can occur when adding client credentials to the request.
Expand Down
Loading
Loading