|
| 1 | +use std::time::SystemTime; |
| 2 | + |
| 3 | +use axum::{Json, extract::State, response::{IntoResponse, Redirect}}; |
| 4 | +use axum_extra::{TypedHeader, headers::{Authorization, authorization::Bearer}}; |
| 5 | +use http::StatusCode; |
| 6 | +use ruma::OwnedDeviceId; |
| 7 | +use serde::{Deserialize, Serialize}; |
| 8 | +use tuwunel_core::{Err, Result, err, info, utils}; |
| 9 | +use tuwunel_service::{oauth::oidc_server::{DcrRequest, IdTokenClaims, OidcAuthRequest, OidcServer, ProviderMetadata}, users::device::generate_refresh_token}; |
| 10 | + |
| 11 | +const OIDC_REQ_ID_LENGTH: usize = 32; |
| 12 | + |
| 13 | +#[derive(Serialize)] |
| 14 | +struct AuthIssuerResponse { issuer: String } |
| 15 | + |
| 16 | +pub(crate) async fn auth_issuer_route(State(services): State<crate::State>) -> Result<impl IntoResponse> { |
| 17 | + let issuer = oidc_issuer_url(&services)?; |
| 18 | + Ok(Json(AuthIssuerResponse { issuer })) |
| 19 | +} |
| 20 | + |
| 21 | +pub(crate) async fn openid_configuration_route(State(services): State<crate::State>) -> Result<impl IntoResponse> { |
| 22 | + Ok(Json(oidc_metadata(&services)?)) |
| 23 | +} |
| 24 | + |
| 25 | +fn oidc_metadata(services: &tuwunel_service::Services) -> Result<ProviderMetadata> { |
| 26 | + let issuer = oidc_issuer_url(services)?; |
| 27 | + let base = issuer.trim_end_matches('/').to_owned(); |
| 28 | + |
| 29 | + Ok(ProviderMetadata { |
| 30 | + issuer, |
| 31 | + authorization_endpoint: format!("{base}/_tuwunel/oidc/authorize"), |
| 32 | + token_endpoint: format!("{base}/_tuwunel/oidc/token"), |
| 33 | + registration_endpoint: Some(format!("{base}/_tuwunel/oidc/registration")), |
| 34 | + revocation_endpoint: Some(format!("{base}/_tuwunel/oidc/revoke")), |
| 35 | + jwks_uri: format!("{base}/_tuwunel/oidc/jwks"), |
| 36 | + userinfo_endpoint: Some(format!("{base}/_tuwunel/oidc/userinfo")), |
| 37 | + account_management_uri: Some(format!("{base}/_tuwunel/oidc/account")), |
| 38 | + account_management_actions_supported: Some(vec!["org.matrix.profile".to_owned(), "org.matrix.sessions_list".to_owned(), "org.matrix.session_view".to_owned(), "org.matrix.session_end".to_owned(), "org.matrix.cross_signing_reset".to_owned()]), |
| 39 | + response_types_supported: vec!["code".to_owned()], |
| 40 | + response_modes_supported: Some(vec!["query".to_owned(), "fragment".to_owned()]), |
| 41 | + grant_types_supported: Some(vec!["authorization_code".to_owned(), "refresh_token".to_owned()]), |
| 42 | + code_challenge_methods_supported: Some(vec!["S256".to_owned()]), |
| 43 | + token_endpoint_auth_methods_supported: Some(vec!["none".to_owned(), "client_secret_basic".to_owned(), "client_secret_post".to_owned()]), |
| 44 | + scopes_supported: Some(vec!["openid".to_owned(), "urn:matrix:org.matrix.msc2967.client:api:*".to_owned(), "urn:matrix:org.matrix.msc2967.client:device:*".to_owned()]), |
| 45 | + subject_types_supported: Some(vec!["public".to_owned()]), |
| 46 | + id_token_signing_alg_values_supported: Some(vec!["ES256".to_owned()]), |
| 47 | + prompt_values_supported: Some(vec!["create".to_owned()]), |
| 48 | + claim_types_supported: Some(vec!["normal".to_owned()]), |
| 49 | + claims_supported: Some(vec!["iss".to_owned(), "sub".to_owned(), "aud".to_owned(), "exp".to_owned(), "iat".to_owned(), "nonce".to_owned()]), |
| 50 | + }) |
| 51 | +} |
| 52 | + |
| 53 | +pub(crate) async fn registration_route(State(services): State<crate::State>, Json(body): Json<DcrRequest>) -> Result<impl IntoResponse> { |
| 54 | + let Ok(oidc) = get_oidc_server(&services) else { return Err!(Request(NotFound("OIDC server not configured"))); }; |
| 55 | + |
| 56 | + if body.redirect_uris.is_empty() { return Err!(Request(InvalidParam("redirect_uris must not be empty"))); } |
| 57 | + |
| 58 | + let reg = oidc.register_client(body)?; |
| 59 | + info!("OIDC client registered: {} ({})", reg.client_id, reg.client_name.as_deref().unwrap_or("unnamed")); |
| 60 | + |
| 61 | + Ok((StatusCode::CREATED, Json(serde_json::json!({"client_id": reg.client_id, "client_id_issued_at": reg.registered_at, "redirect_uris": reg.redirect_uris, "client_name": reg.client_name, "client_uri": reg.client_uri, "logo_uri": reg.logo_uri, "contacts": reg.contacts, "token_endpoint_auth_method": reg.token_endpoint_auth_method, "grant_types": reg.grant_types, "response_types": reg.response_types, "application_type": reg.application_type})))) |
| 62 | +} |
| 63 | + |
| 64 | +#[derive(Debug, Deserialize)] |
| 65 | +pub(crate) struct AuthorizeParams { |
| 66 | + client_id: String, redirect_uri: String, response_type: String, scope: String, |
| 67 | + state: Option<String>, nonce: Option<String>, code_challenge: Option<String>, |
| 68 | + code_challenge_method: Option<String>, #[serde(default, rename = "prompt")] _prompt: Option<String>, |
| 69 | +} |
| 70 | + |
| 71 | +pub(crate) async fn authorize_route(State(services): State<crate::State>, request: axum::extract::Request) -> Result<impl IntoResponse> { |
| 72 | + let params: AuthorizeParams = serde_html_form::from_str(request.uri().query().unwrap_or_default())?; |
| 73 | + let Ok(oidc) = get_oidc_server(&services) else { return Err!(Request(NotFound("OIDC server not configured"))); }; |
| 74 | + |
| 75 | + if params.response_type != "code" { return Err!(Request(InvalidParam("Only response_type=code is supported"))); } |
| 76 | + |
| 77 | + oidc.validate_redirect_uri(¶ms.client_id, ¶ms.redirect_uri).await?; |
| 78 | + |
| 79 | + if !scope_contains_token(¶ms.scope, "openid") { return Err!(Request(InvalidParam("openid scope is required"))); } |
| 80 | + |
| 81 | + let req_id = utils::random_string(OIDC_REQ_ID_LENGTH); |
| 82 | + let now = SystemTime::now(); |
| 83 | + |
| 84 | + oidc.store_auth_request(&req_id, &OidcAuthRequest { |
| 85 | + client_id: params.client_id, redirect_uri: params.redirect_uri, scope: params.scope, |
| 86 | + state: params.state, nonce: params.nonce, code_challenge: params.code_challenge, |
| 87 | + code_challenge_method: params.code_challenge_method, created_at: now, |
| 88 | + expires_at: now.checked_add(OidcServer::auth_request_lifetime()).unwrap_or(now), |
| 89 | + }); |
| 90 | + |
| 91 | + let default_idp = services.config.identity_provider.values().find(|idp| idp.default).or_else(|| services.config.identity_provider.values().next()).ok_or_else(|| err!(Config("identity_provider", "No identity provider configured")))?; |
| 92 | + let idp_id = default_idp.id(); |
| 93 | + |
| 94 | + let base = oidc_issuer_url(&services)?; |
| 95 | + let base = base.trim_end_matches('/'); |
| 96 | + |
| 97 | + let mut complete_url = url::Url::parse(&format!("{base}/_tuwunel/oidc/_complete")).map_err(|_| err!(error!("Failed to build complete URL")))?; |
| 98 | + complete_url.query_pairs_mut().append_pair("oidc_req_id", &req_id); |
| 99 | + |
| 100 | + let mut sso_url = url::Url::parse(&format!("{base}/_matrix/client/v3/login/sso/redirect/{idp_id}")).map_err(|_| err!(error!("Failed to build SSO URL")))?; |
| 101 | + sso_url.query_pairs_mut().append_pair("redirectUrl", complete_url.as_str()); |
| 102 | + |
| 103 | + Ok(Redirect::temporary(sso_url.as_str())) |
| 104 | +} |
| 105 | + |
| 106 | +#[derive(Debug, Deserialize)] |
| 107 | +pub(crate) struct CompleteParams { oidc_req_id: String, #[serde(rename = "loginToken")] login_token: String } |
| 108 | + |
| 109 | +pub(crate) async fn complete_route(State(services): State<crate::State>, request: axum::extract::Request) -> Result<impl IntoResponse> { |
| 110 | + let params: CompleteParams = serde_html_form::from_str(request.uri().query().unwrap_or_default())?; |
| 111 | + let Ok(oidc) = get_oidc_server(&services) else { return Err!(Request(NotFound("OIDC server not configured"))); }; |
| 112 | + |
| 113 | + let user_id = services.users.find_from_login_token(¶ms.login_token).await.map_err(|_| err!(Request(Forbidden("Invalid or expired login token"))))?; |
| 114 | + let auth_req = oidc.take_auth_request(¶ms.oidc_req_id).await?; |
| 115 | + let code = oidc.create_auth_code(&auth_req, user_id); |
| 116 | + |
| 117 | + let mut redirect_url = url::Url::parse(&auth_req.redirect_uri).map_err(|_| err!(Request(InvalidParam("Invalid redirect_uri"))))?; |
| 118 | + redirect_url.query_pairs_mut().append_pair("code", &code); |
| 119 | + if let Some(state) = &auth_req.state { redirect_url.query_pairs_mut().append_pair("state", state); } |
| 120 | + |
| 121 | + Ok(Redirect::temporary(redirect_url.as_str())) |
| 122 | +} |
| 123 | + |
| 124 | +#[derive(Debug, Deserialize)] |
| 125 | +pub(crate) struct TokenRequest { |
| 126 | + grant_type: String, code: Option<String>, redirect_uri: Option<String>, client_id: Option<String>, |
| 127 | + code_verifier: Option<String>, refresh_token: Option<String>, #[serde(rename = "scope")] _scope: Option<String>, |
| 128 | +} |
| 129 | + |
| 130 | +pub(crate) async fn token_route(State(services): State<crate::State>, axum::extract::Form(body): axum::extract::Form<TokenRequest>) -> impl IntoResponse { |
| 131 | + match body.grant_type.as_str() { |
| 132 | + | "authorization_code" => token_authorization_code(&services, &body).await.unwrap_or_else(|e| oauth_error(StatusCode::INTERNAL_SERVER_ERROR, "server_error", &e.to_string())), |
| 133 | + | "refresh_token" => token_refresh(&services, &body).await.unwrap_or_else(|e| oauth_error(StatusCode::INTERNAL_SERVER_ERROR, "server_error", &e.to_string())), |
| 134 | + | _ => oauth_error(StatusCode::BAD_REQUEST, "unsupported_grant_type", "Unsupported grant_type"), |
| 135 | + } |
| 136 | +} |
| 137 | + |
| 138 | +async fn token_authorization_code(services: &tuwunel_service::Services, body: &TokenRequest) -> Result<http::Response<axum::body::Body>> { |
| 139 | + let code = body.code.as_deref().ok_or_else(|| err!(Request(InvalidParam("code is required"))))?; |
| 140 | + let redirect_uri = body.redirect_uri.as_deref().ok_or_else(|| err!(Request(InvalidParam("redirect_uri is required"))))?; |
| 141 | + let client_id = body.client_id.as_deref().ok_or_else(|| err!(Request(InvalidParam("client_id is required"))))?; |
| 142 | + |
| 143 | + let oidc = get_oidc_server(services)?; |
| 144 | + let session = oidc.exchange_auth_code(code, client_id, redirect_uri, body.code_verifier.as_deref()).await?; |
| 145 | + |
| 146 | + let user_id = &session.user_id; |
| 147 | + let (access_token, expires_in) = services.users.generate_access_token(true); |
| 148 | + let refresh_token = generate_refresh_token(); |
| 149 | + |
| 150 | + let device_id: Option<OwnedDeviceId> = extract_device_id(&session.scope).map(OwnedDeviceId::from); |
| 151 | + let device_id = services.users.create_device(user_id, device_id.as_deref(), (Some(&access_token), expires_in), Some(&refresh_token), Some("OIDC Client"), None).await?; |
| 152 | + |
| 153 | + info!("{user_id} logged in via OIDC (device {device_id})"); |
| 154 | + |
| 155 | + let id_token = if session.scope.contains("openid") { |
| 156 | + let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap_or_default().as_secs(); |
| 157 | + let issuer = oidc_issuer_url(services)?; |
| 158 | + let claims = IdTokenClaims { iss: issuer, sub: user_id.to_string(), aud: client_id.to_owned(), exp: now.saturating_add(3600), iat: now, nonce: session.nonce, at_hash: Some(OidcServer::at_hash(&access_token)) }; |
| 159 | + Some(oidc.sign_id_token(&claims)?) |
| 160 | + } else { None }; |
| 161 | + |
| 162 | + let mut response = serde_json::json!({"access_token": access_token, "token_type": "Bearer", "scope": session.scope, "refresh_token": refresh_token}); |
| 163 | + if let Some(expires_in) = expires_in { response["expires_in"] = serde_json::json!(expires_in.as_secs()); } |
| 164 | + if let Some(id_token) = id_token { response["id_token"] = serde_json::json!(id_token); } |
| 165 | + |
| 166 | + Ok(Json(response).into_response()) |
| 167 | +} |
| 168 | + |
| 169 | +async fn token_refresh(services: &tuwunel_service::Services, body: &TokenRequest) -> Result<http::Response<axum::body::Body>> { |
| 170 | + let refresh_token = body.refresh_token.as_deref().ok_or_else(|| err!(Request(InvalidParam("refresh_token is required"))))?; |
| 171 | + let (user_id, device_id, _) = services.users.find_from_token(refresh_token).await.map_err(|_| err!(Request(Forbidden("Invalid refresh token"))))?; |
| 172 | + |
| 173 | + let (new_access_token, expires_in) = services.users.generate_access_token(true); |
| 174 | + let new_refresh_token = generate_refresh_token(); |
| 175 | + services.users.set_access_token(&user_id, &device_id, &new_access_token, expires_in, Some(&new_refresh_token)).await?; |
| 176 | + |
| 177 | + let mut response = serde_json::json!({"access_token": new_access_token, "token_type": "Bearer", "refresh_token": new_refresh_token}); |
| 178 | + if let Some(expires_in) = expires_in { response["expires_in"] = serde_json::json!(expires_in.as_secs()); } |
| 179 | + |
| 180 | + Ok(Json(response).into_response()) |
| 181 | +} |
| 182 | + |
| 183 | +#[derive(Debug, Deserialize)] |
| 184 | +pub(crate) struct RevokeRequest { token: String, #[serde(default, rename = "token_type_hint")] _token_type_hint: Option<String> } |
| 185 | + |
| 186 | +pub(crate) async fn revoke_route(State(services): State<crate::State>, axum::extract::Form(body): axum::extract::Form<RevokeRequest>) -> Result<impl IntoResponse> { |
| 187 | + if let Ok((user_id, device_id, _)) = services.users.find_from_token(&body.token).await { services.users.remove_device(&user_id, &device_id).await; } |
| 188 | + Ok(Json(serde_json::json!({}))) |
| 189 | +} |
| 190 | + |
| 191 | +pub(crate) async fn jwks_route(State(services): State<crate::State>) -> Result<impl IntoResponse> { |
| 192 | + let oidc = get_oidc_server(&services)?; |
| 193 | + Ok(Json(oidc.jwks())) |
| 194 | +} |
| 195 | + |
| 196 | +pub(crate) async fn userinfo_route(State(services): State<crate::State>, TypedHeader(Authorization(bearer)): TypedHeader<Authorization<Bearer>>) -> Result<impl IntoResponse> { |
| 197 | + let token = bearer.token(); |
| 198 | + let Ok((user_id, _device_id, _expires)) = services.users.find_from_token(token).await else { return Err!(Request(Unauthorized("Invalid access token"))); }; |
| 199 | + let displayname = services.users.displayname(&user_id).await.ok(); |
| 200 | + let avatar_url = services.users.avatar_url(&user_id).await.ok(); |
| 201 | + Ok(Json(serde_json::json!({"sub": user_id.to_string(), "name": displayname, "picture": avatar_url}))) |
| 202 | +} |
| 203 | + |
| 204 | +pub(crate) async fn account_route() -> impl IntoResponse { |
| 205 | + axum::response::Html("<html><body><h1>Account Management</h1><p>Account management is not yet implemented. Please use your identity provider to manage your account.</p></body></html>") |
| 206 | +} |
| 207 | + |
| 208 | +fn oauth_error(status: StatusCode, error: &str, description: &str) -> http::Response<axum::body::Body> { |
| 209 | + (status, Json(serde_json::json!({"error": error, "error_description": description}))).into_response() |
| 210 | +} |
| 211 | + |
| 212 | +fn scope_contains_token(scope: &str, token: &str) -> bool { scope.split_whitespace().any(|t| t == token) } |
| 213 | + |
| 214 | +fn get_oidc_server(services: &tuwunel_service::Services) -> Result<&OidcServer> { |
| 215 | + services.oauth.oidc_server.as_deref().ok_or_else(|| err!(Request(NotFound("OIDC server not configured")))) |
| 216 | +} |
| 217 | + |
| 218 | +fn oidc_issuer_url(services: &tuwunel_service::Services) -> Result<String> { |
| 219 | + services.config.well_known.client.as_ref().map(|url| { let s = url.to_string(); if s.ends_with('/') { s } else { s + "/" } }).ok_or_else(|| err!(Config("well_known.client", "well_known.client must be set for OIDC server"))) |
| 220 | +} |
| 221 | + |
| 222 | +fn extract_device_id(scope: &str) -> Option<String> { scope.split_whitespace().find_map(|s| s.strip_prefix("urn:matrix:org.matrix.msc2967.client:device:")).map(ToOwned::to_owned) } |
0 commit comments