|
| 1 | +// Copyright 2025 New Vector Ltd. |
| 2 | +// |
| 3 | +// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial |
| 4 | +// Please see LICENSE files in the repository root for full details. |
| 5 | + |
| 6 | +use std::collections::HashMap; |
| 7 | + |
| 8 | +use axum::{ |
| 9 | + Form, Json, |
| 10 | + extract::{Path, State, rejection::FormRejection}, |
| 11 | + response::IntoResponse, |
| 12 | +}; |
| 13 | +use hyper::StatusCode; |
| 14 | +use mas_axum_utils::record_error; |
| 15 | +use mas_data_model::UpstreamOAuthProvider; |
| 16 | +use mas_jose::{ |
| 17 | + claims::{self, Claim, TimeOptions}, |
| 18 | + jwt::JwtDecodeError, |
| 19 | +}; |
| 20 | +use mas_oidc_client::{ |
| 21 | + error::JwtVerificationError, |
| 22 | + requests::jose::{JwtVerificationData, verify_signed_jwt}, |
| 23 | +}; |
| 24 | +use mas_storage::{ |
| 25 | + BoxClock, BoxRepository, Pagination, upstream_oauth2::UpstreamOAuthSessionFilter, |
| 26 | +}; |
| 27 | +use oauth2_types::errors::{ClientError, ClientErrorCode}; |
| 28 | +use serde::Deserialize; |
| 29 | +use serde_json::Value; |
| 30 | +use thiserror::Error; |
| 31 | +use ulid::Ulid; |
| 32 | + |
| 33 | +use crate::{MetadataCache, impl_from_error_for_route, upstream_oauth2::cache::LazyProviderInfos}; |
| 34 | + |
| 35 | +#[derive(Debug, Error)] |
| 36 | +pub enum RouteError { |
| 37 | + /// An internal error occurred. |
| 38 | + #[error(transparent)] |
| 39 | + Internal(Box<dyn std::error::Error + Send + Sync + 'static>), |
| 40 | + |
| 41 | + /// Invalid request body |
| 42 | + #[error(transparent)] |
| 43 | + InvalidRequestBody(#[from] FormRejection), |
| 44 | + |
| 45 | + /// Logout token is not a JWT |
| 46 | + #[error("failed to decode logout token")] |
| 47 | + InvalidLogoutToken(#[from] JwtDecodeError), |
| 48 | + |
| 49 | + /// Logout token failed to be verified |
| 50 | + #[error("failed to verify logout token")] |
| 51 | + LogoutTokenVerification(#[from] JwtVerificationError), |
| 52 | + |
| 53 | + /// Logout token had invalid claims |
| 54 | + #[error("invalid claims in logout token")] |
| 55 | + InvalidLogoutTokenClaims(#[from] claims::ClaimError), |
| 56 | + |
| 57 | + /// Logout token has neither a sub nor a sid claim |
| 58 | + #[error("logout token has neither a sub nor a sid claim")] |
| 59 | + NoSubOrSidClaim, |
| 60 | + |
| 61 | + /// Provider not found |
| 62 | + #[error("provider not found")] |
| 63 | + ProviderNotFound, |
| 64 | +} |
| 65 | + |
| 66 | +impl IntoResponse for RouteError { |
| 67 | + fn into_response(self) -> axum::response::Response { |
| 68 | + let sentry_event_id = record_error!(self, Self::Internal(_)); |
| 69 | + |
| 70 | + let response = match self { |
| 71 | + e @ Self::Internal(_) => ( |
| 72 | + StatusCode::INTERNAL_SERVER_ERROR, |
| 73 | + Json( |
| 74 | + ClientError::from(ClientErrorCode::ServerError).with_description(e.to_string()), |
| 75 | + ), |
| 76 | + ) |
| 77 | + .into_response(), |
| 78 | + |
| 79 | + e @ (Self::InvalidLogoutToken(_) |
| 80 | + | Self::LogoutTokenVerification(_) |
| 81 | + | Self::InvalidRequestBody(_) |
| 82 | + | Self::InvalidLogoutTokenClaims(_) |
| 83 | + | Self::NoSubOrSidClaim) => ( |
| 84 | + StatusCode::BAD_REQUEST, |
| 85 | + Json( |
| 86 | + ClientError::from(ClientErrorCode::InvalidRequest) |
| 87 | + .with_description(e.to_string()), |
| 88 | + ), |
| 89 | + ) |
| 90 | + .into_response(), |
| 91 | + |
| 92 | + Self::ProviderNotFound => ( |
| 93 | + StatusCode::NOT_FOUND, |
| 94 | + Json( |
| 95 | + ClientError::from(ClientErrorCode::InvalidRequest).with_description( |
| 96 | + "Upstream OAuth provider not found, is the backchannel logout URI right?" |
| 97 | + .to_owned(), |
| 98 | + ), |
| 99 | + ), |
| 100 | + ) |
| 101 | + .into_response(), |
| 102 | + }; |
| 103 | + |
| 104 | + (sentry_event_id, response).into_response() |
| 105 | + } |
| 106 | +} |
| 107 | + |
| 108 | +impl_from_error_for_route!(mas_storage::RepositoryError); |
| 109 | +impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError); |
| 110 | +impl_from_error_for_route!(mas_oidc_client::error::JwksError); |
| 111 | + |
| 112 | +#[derive(Deserialize)] |
| 113 | +pub(crate) struct BackchannelLogoutRequest { |
| 114 | + logout_token: String, |
| 115 | +} |
| 116 | + |
| 117 | +#[derive(Deserialize)] |
| 118 | +struct LogoutTokenEvents { |
| 119 | + #[allow(dead_code)] // We just want to check it deserializes |
| 120 | + #[serde(rename = "http://schemas.openid.net/event/backchannel-logout")] |
| 121 | + backchannel_logout: HashMap<String, Value>, |
| 122 | +} |
| 123 | + |
| 124 | +const EVENTS: Claim<LogoutTokenEvents> = Claim::new("events"); |
| 125 | + |
| 126 | +#[tracing::instrument( |
| 127 | + name = "handlers.upstream_oauth2.backchannel_logout.post", |
| 128 | + fields(upstream_oauth_provider.id = %provider_id), |
| 129 | + skip_all, |
| 130 | +)] |
| 131 | +pub(crate) async fn post( |
| 132 | + clock: BoxClock, |
| 133 | + mut repo: BoxRepository, |
| 134 | + State(metadata_cache): State<MetadataCache>, |
| 135 | + State(client): State<reqwest::Client>, |
| 136 | + Path(provider_id): Path<Ulid>, |
| 137 | + request: Result<Form<BackchannelLogoutRequest>, FormRejection>, |
| 138 | +) -> Result<impl IntoResponse, RouteError> { |
| 139 | + let Form(request) = request?; |
| 140 | + let provider = repo |
| 141 | + .upstream_oauth_provider() |
| 142 | + .lookup(provider_id) |
| 143 | + .await? |
| 144 | + .filter(UpstreamOAuthProvider::enabled) |
| 145 | + .ok_or(RouteError::ProviderNotFound)?; |
| 146 | + |
| 147 | + let mut lazy_metadata = LazyProviderInfos::new(&metadata_cache, &provider, &client); |
| 148 | + |
| 149 | + let jwks = |
| 150 | + mas_oidc_client::requests::jose::fetch_jwks(&client, lazy_metadata.jwks_uri().await?) |
| 151 | + .await?; |
| 152 | + |
| 153 | + // Validate the logout token. The rules are defined in |
| 154 | + // <https://openid.net/specs/openid-connect-backchannel-1_0.html#Validation> |
| 155 | + // |
| 156 | + // Upon receiving a logout request at the back-channel logout URI, the RP MUST |
| 157 | + // validate the Logout Token as follows: |
| 158 | + // |
| 159 | + // 1. If the Logout Token is encrypted, decrypt it using the keys and |
| 160 | + // algorithms that the Client specified during Registration that the OP was |
| 161 | + // to use to encrypt ID Tokens. If ID Token encryption was negotiated with |
| 162 | + // the OP at Registration time and the Logout Token is not encrypted, the RP |
| 163 | + // SHOULD reject it. |
| 164 | + // 2. Validate the Logout Token signature in the same way that an ID Token |
| 165 | + // signature is validated, with the following refinements. |
| 166 | + // 3. Validate the alg (algorithm) Header Parameter in the same way it is |
| 167 | + // validated for ID Tokens. Like ID Tokens, selection of the algorithm used |
| 168 | + // is governed by the id_token_signing_alg_values_supported Discovery |
| 169 | + // parameter and the id_token_signed_response_alg Registration parameter |
| 170 | + // when they are used; otherwise, the value SHOULD be the default of RS256. |
| 171 | + // Additionally, an alg with the value none MUST NOT be used for Logout |
| 172 | + // Tokens. |
| 173 | + // 4. Validate the iss, aud, iat, and exp Claims in the same way they are |
| 174 | + // validated in ID Tokens. |
| 175 | + // 5. Verify that the Logout Token contains a sub Claim, a sid Claim, or both. |
| 176 | + // 6. Verify that the Logout Token contains an events Claim whose value is JSON |
| 177 | + // object containing the member name http://schemas.openid.net/event/backchannel-logout. |
| 178 | + // 7. Verify that the Logout Token does not contain a nonce Claim. |
| 179 | + // 8. Optionally verify that another Logout Token with the same jti value has |
| 180 | + // not been recently received. |
| 181 | + // 9. Optionally verify that the iss Logout Token Claim matches the iss Claim |
| 182 | + // in an ID Token issued for the current session or a recent session of this |
| 183 | + // RP with the OP. |
| 184 | + // 10. Optionally verify that any sub Logout Token Claim matches the sub Claim |
| 185 | + // in an ID Token issued for the current session or a recent session of |
| 186 | + // this RP with the OP. |
| 187 | + // 11. Optionally verify that any sid Logout Token Claim matches the sid Claim |
| 188 | + // in an ID Token issued for the current session or a recent session of |
| 189 | + // this RP with the OP. |
| 190 | + // |
| 191 | + // If any of the validation steps fails, reject the Logout Token and return an |
| 192 | + // HTTP 400 Bad Request error. Otherwise, proceed to perform the logout actions. |
| 193 | + // |
| 194 | + // The ISS and AUD claims are already checked by the verify_signed_jwt() |
| 195 | + // function. |
| 196 | + |
| 197 | + // This verifies (1), (2), (3) and the iss and aud claims for (4) |
| 198 | + let token = verify_signed_jwt( |
| 199 | + &request.logout_token, |
| 200 | + JwtVerificationData { |
| 201 | + issuer: provider.issuer.as_deref(), |
| 202 | + jwks: &jwks, |
| 203 | + client_id: &provider.client_id, |
| 204 | + signing_algorithm: &provider.id_token_signed_response_alg, |
| 205 | + }, |
| 206 | + )?; |
| 207 | + |
| 208 | + let (_header, mut claims) = token.into_parts(); |
| 209 | + |
| 210 | + let time_options = TimeOptions::new(clock.now()); |
| 211 | + claims::EXP.extract_required_with_options(&mut claims, &time_options)?; // (4) |
| 212 | + claims::IAT.extract_required_with_options(&mut claims, &time_options)?; // (4) |
| 213 | + |
| 214 | + let sub = claims::SUB.extract_optional(&mut claims)?; // (5) |
| 215 | + let sid = claims::SID.extract_optional(&mut claims)?; // (5) |
| 216 | + if sub.is_none() && sid.is_none() { |
| 217 | + return Err(RouteError::NoSubOrSidClaim); |
| 218 | + } |
| 219 | + |
| 220 | + EVENTS.extract_required(&mut claims)?; // (6) |
| 221 | + claims::NONCE.assert_absent(&claims)?; // (7) |
| 222 | + |
| 223 | + // Find the corresponding upstream OAuth 2.0 sessions |
| 224 | + let mut filter = UpstreamOAuthSessionFilter::new().for_provider(&provider); |
| 225 | + if let Some(sub) = &sub { |
| 226 | + filter = filter.with_sub_claim(sub); |
| 227 | + } |
| 228 | + if let Some(sid) = &sid { |
| 229 | + filter = filter.with_sid_claim(sid); |
| 230 | + } |
| 231 | + |
| 232 | + let mut cursor = Pagination::first(100); |
| 233 | + let mut sessions = Vec::new(); |
| 234 | + loop { |
| 235 | + let page = repo.upstream_oauth_session().list(filter, cursor).await?; |
| 236 | + |
| 237 | + for session in page.edges { |
| 238 | + cursor = cursor.after(session.id); |
| 239 | + sessions.push(session); |
| 240 | + } |
| 241 | + |
| 242 | + if !page.has_next_page { |
| 243 | + break; |
| 244 | + } |
| 245 | + } |
| 246 | + |
| 247 | + tracing::info!(sub, sid, %provider.id, "Backchannel logout received, found {} corresponding sessions", sessions.len()); |
| 248 | + |
| 249 | + Ok(()) |
| 250 | +} |
0 commit comments