diff --git a/Cargo.lock b/Cargo.lock index a8a142ef4..9fa308d21 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -262,6 +262,12 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" +[[package]] +name = "as_variant" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f38fa22307249f86fb7fad906fcae77f2564caeb56d7209103c551cd1cf4798f" + [[package]] name = "ascii_utils" version = "0.9.3" @@ -2931,6 +2937,15 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "js_int" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d937f95470b270ce8b8950207715d71aa8e153c0d44c6684d59397ed4949160a" +dependencies = [ + "serde", +] + [[package]] name = "json-patch" version = "2.0.0" @@ -3293,6 +3308,7 @@ dependencies = [ "rand", "rand_chacha", "regex", + "ruma-common", "serde", "thiserror", "ulid", @@ -5095,6 +5111,59 @@ dependencies = [ "zeroize", ] +[[package]] +name = "ruma-common" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ba97203cc4cab8dc10e62fe8156ae5c61d2553f37c3037759fbae601982fb7b" +dependencies = [ + "as_variant", + "base64 0.22.1", + "bytes", + "form_urlencoded", + "indexmap 2.6.0", + "js_int", + "percent-encoding", + "regex", + "ruma-identifiers-validation", + "ruma-macros", + "serde", + "serde_html_form", + "serde_json", + "thiserror", + "time", + "tracing", + "url", + "web-time", + "wildmatch", +] + +[[package]] +name = "ruma-identifiers-validation" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fa38974f5901ed4e00e10aec57b9ad3b4d6d6c1a1ae683c51b88700b9f4ffba" +dependencies = [ + "js_int", + "thiserror", +] + +[[package]] +name = "ruma-macros" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d36857a3350ea611ecc9968dcc4f3d5a03227a6c3fcbb446e8530e3be8852282" +dependencies = [ + "once_cell", + "proc-macro-crate", + "proc-macro2", + "quote", + "ruma-identifiers-validation", + "serde", + "syn", + "toml", +] + [[package]] name = "rust_decimal" version = "1.36.0" @@ -5540,6 +5609,19 @@ dependencies = [ "syn", ] +[[package]] +name = "serde_html_form" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de514ef58196f1fc96dcaef80fe6170a1ce6215df9687a93fe8300e773fefc5" +dependencies = [ + "form_urlencoded", + "indexmap 2.6.0", + "itoa", + "ryu", + "serde", +] + [[package]] name = "serde_json" version = "1.0.132" @@ -5576,6 +5658,15 @@ dependencies = [ "thiserror", ] +[[package]] +name = "serde_spanned" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87607cb1398ed59d48732e575a4c28a7a8ebf2454b964fe3f224f2afc07909e1" +dependencies = [ + "serde", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -6333,11 +6424,26 @@ dependencies = [ "tokio", ] +[[package]] +name = "toml" +version = "0.8.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1ed1f98e3fdc28d6d910e6737ae6ab1a93bf1985935a1193e68f93eeb68d24e" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit", +] + [[package]] name = "toml_datetime" version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" +dependencies = [ + "serde", +] [[package]] name = "toml_edit" @@ -6346,6 +6452,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" dependencies = [ "indexmap 2.6.0", + "serde", + "serde_spanned", "toml_datetime", "winnow", ] @@ -7105,6 +7213,12 @@ dependencies = [ "wasite", ] +[[package]] +name = "wildmatch" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68ce1ab1f8c62655ebe1350f589c61e505cf94d385bc6a12899442d9081e71fd" + [[package]] name = "winapi" version = "0.3.9" diff --git a/crates/data-model/Cargo.toml b/crates/data-model/Cargo.toml index 081e09ae9..7a44bf81f 100644 --- a/crates/data-model/Cargo.toml +++ b/crates/data-model/Cargo.toml @@ -22,6 +22,7 @@ rand.workspace = true rand_chacha = "0.3.1" regex = "1.11.1" woothee = "0.13.0" +ruma-common = "0.13.0" mas-iana.workspace = true mas-jose.workspace = true diff --git a/crates/data-model/src/lib.rs b/crates/data-model/src/lib.rs index 939e904f3..c0b39792a 100644 --- a/crates/data-model/src/lib.rs +++ b/crates/data-model/src/lib.rs @@ -9,7 +9,7 @@ use thiserror::Error; pub(crate) mod compat; -pub(crate) mod oauth2; +pub mod oauth2; mod site_config; pub(crate) mod tokens; pub(crate) mod upstream_oauth2; diff --git a/crates/data-model/src/oauth2/authorization_grant.rs b/crates/data-model/src/oauth2/authorization_grant.rs index c995a3623..d34a431c7 100644 --- a/crates/data-model/src/oauth2/authorization_grant.rs +++ b/crates/data-model/src/oauth2/authorization_grant.rs @@ -17,6 +17,7 @@ use rand::{ distributions::{Alphanumeric, DistString}, RngCore, }; +use ruma_common::{OwnedUserId, UserId}; use serde::Serialize; use ulid::Ulid; use url::Url; @@ -141,6 +142,11 @@ impl AuthorizationGrantStage { } } +pub enum LoginHint { + MXID(OwnedUserId), + None, +} + #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct AuthorizationGrant { pub id: Ulid, @@ -157,6 +163,7 @@ pub struct AuthorizationGrant { pub response_type_id_token: bool, pub created_at: DateTime, pub requires_consent: bool, + pub login_hint: Option, } impl std::ops::Deref for AuthorizationGrant { @@ -179,6 +186,36 @@ impl AuthorizationGrant { self.created_at - max_age } + #[must_use] + pub fn parse_login_hint(&self, homeserver: &str) -> LoginHint { + let Some(login_hint) = &self.login_hint else { + return LoginHint::None; + }; + + // Return none if the format is incorrect + let Some((prefix, value)) = login_hint.split_once(':') else { + return LoginHint::None; + }; + + match prefix { + "mxid" => { + // Instead of erroring just return none + let Ok(mxid) = UserId::parse(value) else { + return LoginHint::None; + }; + + // Only handle MXIDs for current homeserver + if mxid.server_name() != homeserver { + return LoginHint::None; + } + + LoginHint::MXID(mxid) + } + // Unknown hint type, treat as none + _ => LoginHint::None, + } + } + /// Mark the authorization grant as exchanged. /// /// # Errors @@ -242,6 +279,104 @@ impl AuthorizationGrant { response_type_id_token: false, created_at: now, requires_consent: false, + login_hint: Some(String::from("mxid:@example-user:example.com")), } } } + +#[cfg(test)] +mod tests { + use rand::thread_rng; + + use super::*; + + #[test] + fn no_login_hint() { + #[allow(clippy::disallowed_methods)] + let mut rng = thread_rng(); + + #[allow(clippy::disallowed_methods)] + let now = Utc::now(); + + let grant = AuthorizationGrant { + login_hint: None, + ..AuthorizationGrant::sample(now, &mut rng) + }; + + let hint = grant.parse_login_hint("example.com"); + + assert!(matches!(hint, LoginHint::None)); + } + + #[test] + fn valid_login_hint() { + #[allow(clippy::disallowed_methods)] + let mut rng = thread_rng(); + + #[allow(clippy::disallowed_methods)] + let now = Utc::now(); + + let grant = AuthorizationGrant { + login_hint: Some(String::from("mxid:@example-user:example.com")), + ..AuthorizationGrant::sample(now, &mut rng) + }; + + let hint = grant.parse_login_hint("example.com"); + + assert!(matches!(hint, LoginHint::MXID(mxid) if mxid.localpart() == "example-user")); + } + + #[test] + fn invalid_login_hint() { + #[allow(clippy::disallowed_methods)] + let mut rng = thread_rng(); + + #[allow(clippy::disallowed_methods)] + let now = Utc::now(); + + let grant = AuthorizationGrant { + login_hint: Some(String::from("example-user")), + ..AuthorizationGrant::sample(now, &mut rng) + }; + + let hint = grant.parse_login_hint("example.com"); + + assert!(matches!(hint, LoginHint::None)); + } + + #[test] + fn valid_login_hint_for_wrong_homeserver() { + #[allow(clippy::disallowed_methods)] + let mut rng = thread_rng(); + + #[allow(clippy::disallowed_methods)] + let now = Utc::now(); + + let grant = AuthorizationGrant { + login_hint: Some(String::from("mxid:@example-user:matrix.org")), + ..AuthorizationGrant::sample(now, &mut rng) + }; + + let hint = grant.parse_login_hint("example.com"); + + assert!(matches!(hint, LoginHint::None)); + } + + #[test] + fn unknown_login_hint_type() { + #[allow(clippy::disallowed_methods)] + let mut rng = thread_rng(); + + #[allow(clippy::disallowed_methods)] + let now = Utc::now(); + + let grant = AuthorizationGrant { + login_hint: Some(String::from("something:anything")), + ..AuthorizationGrant::sample(now, &mut rng) + }; + + let hint = grant.parse_login_hint("example.com"); + + assert!(matches!(hint, LoginHint::None)); + } +} diff --git a/crates/data-model/src/oauth2/mod.rs b/crates/data-model/src/oauth2/mod.rs index 75fd04126..0126392c1 100644 --- a/crates/data-model/src/oauth2/mod.rs +++ b/crates/data-model/src/oauth2/mod.rs @@ -10,7 +10,9 @@ mod device_code_grant; mod session; pub use self::{ - authorization_grant::{AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Pkce}, + authorization_grant::{ + AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, LoginHint, Pkce, + }, client::{Client, InvalidRedirectUriError, JwksOrJwksUri}, device_code_grant::{DeviceCodeGrant, DeviceCodeGrantState}, session::{Session, SessionState}, diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index fc7b477c9..236e8a795 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -291,6 +291,7 @@ pub(crate) async fn get( response_mode, response_type.has_id_token(), requires_consent, + params.auth.login_hint, ) .await?; let continue_grant = PostAuthAction::continue_grant(grant.id); diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index 6d6d36849..a228f746b 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -854,6 +854,7 @@ mod tests { ResponseMode::Query, false, false, + None, ) .await .unwrap(); @@ -954,6 +955,7 @@ mod tests { ResponseMode::Query, false, false, + None, ) .await .unwrap(); diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index 3a65c3284..d2fba535d 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -15,8 +15,9 @@ use mas_axum_utils::{ csrf::{CsrfExt, CsrfToken, ProtectedForm}, FancyError, SessionInfoExt, }; -use mas_data_model::{BrowserSession, UserAgent}; +use mas_data_model::{oauth2::LoginHint, BrowserSession, UserAgent}; use mas_i18n::DataLocale; +use mas_matrix::BoxHomeserverConnection; use mas_router::{UpstreamOAuth2Authorize, UrlBuilder}; use mas_storage::{ upstream_oauth2::UpstreamOAuthProviderRepository, @@ -24,7 +25,8 @@ use mas_storage::{ BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess, }; use mas_templates::{ - FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState, + FieldError, FormError, LoginContext, LoginFormField, PostAuthContext, PostAuthContextInner, + TemplateContext, Templates, ToFormState, }; use rand::{CryptoRng, Rng}; use serde::{Deserialize, Serialize}; @@ -54,6 +56,7 @@ pub(crate) async fn get( State(templates): State, State(url_builder): State, State(site_config): State, + State(homeserver): State, mut repo: BoxRepository, activity_tracker: BoundActivityTracker, Query(query): Query, @@ -96,6 +99,7 @@ pub(crate) async fn get( csrf_token, &mut repo, &templates, + homeserver, ) .await?; @@ -112,6 +116,7 @@ pub(crate) async fn post( State(templates): State, State(url_builder): State, State(limiter): State, + State(homeserver): State, mut repo: BoxRepository, activity_tracker: BoundActivityTracker, requester: RequesterFingerprint, @@ -156,6 +161,7 @@ pub(crate) async fn post( csrf_token, &mut repo, &templates, + homeserver, ) .await?; @@ -196,6 +202,7 @@ pub(crate) async fn post( csrf_token, &mut repo, &templates, + homeserver, ) .await?; @@ -286,16 +293,40 @@ async fn login( Ok(user_session) } +fn handle_login_hint( + ctx: &mut LoginContext, + next: &PostAuthContext, + homeserver: &BoxHomeserverConnection, +) { + let form_state = ctx.form_state_mut(); + + // Do not override username if coming from a failed login attempt + if form_state.has_value(LoginFormField::Username) { + return; + } + + if let PostAuthContextInner::ContinueAuthorizationGrant { ref grant } = next.ctx { + let value = match grant.parse_login_hint(homeserver.homeserver()) { + LoginHint::MXID(mxid) => Some(mxid.localpart().to_owned()), + LoginHint::None => None, + }; + form_state.set_value(LoginFormField::Username, value); + } +} + async fn render( locale: DataLocale, - ctx: LoginContext, + mut ctx: LoginContext, action: OptionalPostAuthAction, csrf_token: CsrfToken, repo: &mut impl RepositoryAccess, templates: &Templates, + homeserver: BoxHomeserverConnection, ) -> Result { let next = action.load_context(repo).await?; let ctx = if let Some(next) = next { + handle_login_hint(&mut ctx, &next, &homeserver); + ctx.with_post_action(next) } else { ctx diff --git a/crates/storage-pg/.sqlx/query-6a3b543ec53ce242866d1e84de26728e6dd275cae745f9c646e3824d859c5384.json b/crates/storage-pg/.sqlx/query-1d9c478c7a5e3a672610376a290b9a1afaaa6fa2fb137f7307002f058b206dbd.json similarity index 88% rename from crates/storage-pg/.sqlx/query-6a3b543ec53ce242866d1e84de26728e6dd275cae745f9c646e3824d859c5384.json rename to crates/storage-pg/.sqlx/query-1d9c478c7a5e3a672610376a290b9a1afaaa6fa2fb137f7307002f058b206dbd.json index e49f70b02..3f0d2177d 100644 --- a/crates/storage-pg/.sqlx/query-6a3b543ec53ce242866d1e84de26728e6dd275cae745f9c646e3824d859c5384.json +++ b/crates/storage-pg/.sqlx/query-1d9c478c7a5e3a672610376a290b9a1afaaa6fa2fb137f7307002f058b206dbd.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n SELECT oauth2_authorization_grant_id\n , created_at\n , cancelled_at\n , fulfilled_at\n , exchanged_at\n , scope\n , state\n , redirect_uri\n , response_mode\n , nonce\n , max_age\n , oauth2_client_id\n , authorization_code\n , response_type_code\n , response_type_id_token\n , code_challenge\n , code_challenge_method\n , requires_consent\n , oauth2_session_id\n FROM\n oauth2_authorization_grants\n\n WHERE authorization_code = $1\n ", + "query": "\n SELECT oauth2_authorization_grant_id\n , created_at\n , cancelled_at\n , fulfilled_at\n , exchanged_at\n , scope\n , state\n , redirect_uri\n , response_mode\n , nonce\n , max_age\n , oauth2_client_id\n , authorization_code\n , response_type_code\n , response_type_id_token\n , code_challenge\n , code_challenge_method\n , requires_consent\n , login_hint\n , oauth2_session_id\n FROM\n oauth2_authorization_grants\n\n WHERE authorization_code = $1\n ", "describe": { "columns": [ { @@ -95,6 +95,11 @@ }, { "ordinal": 18, + "name": "login_hint", + "type_info": "Text" + }, + { + "ordinal": 19, "name": "oauth2_session_id", "type_info": "Uuid" } @@ -123,8 +128,9 @@ true, true, false, + true, true ] }, - "hash": "6a3b543ec53ce242866d1e84de26728e6dd275cae745f9c646e3824d859c5384" + "hash": "1d9c478c7a5e3a672610376a290b9a1afaaa6fa2fb137f7307002f058b206dbd" } diff --git a/crates/storage-pg/.sqlx/query-c0ed9d70e496433d8686a499055d8a8376459109b6154a2c0c13b28462afa523.json b/crates/storage-pg/.sqlx/query-854cc8cd3c1fc3dbbdf4ce81b561aafadb0f4e98caeaba01597c6f62875ae691.json similarity index 73% rename from crates/storage-pg/.sqlx/query-c0ed9d70e496433d8686a499055d8a8376459109b6154a2c0c13b28462afa523.json rename to crates/storage-pg/.sqlx/query-854cc8cd3c1fc3dbbdf4ce81b561aafadb0f4e98caeaba01597c6f62875ae691.json index 3de6ec93b..0e114d418 100644 --- a/crates/storage-pg/.sqlx/query-c0ed9d70e496433d8686a499055d8a8376459109b6154a2c0c13b28462afa523.json +++ b/crates/storage-pg/.sqlx/query-854cc8cd3c1fc3dbbdf4ce81b561aafadb0f4e98caeaba01597c6f62875ae691.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n INSERT INTO oauth2_authorization_grants (\n oauth2_authorization_grant_id,\n oauth2_client_id,\n redirect_uri,\n scope,\n state,\n nonce,\n max_age,\n response_mode,\n code_challenge,\n code_challenge_method,\n response_type_code,\n response_type_id_token,\n authorization_code,\n requires_consent,\n created_at\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)\n ", + "query": "\n INSERT INTO oauth2_authorization_grants (\n oauth2_authorization_grant_id,\n oauth2_client_id,\n redirect_uri,\n scope,\n state,\n nonce,\n max_age,\n response_mode,\n code_challenge,\n code_challenge_method,\n response_type_code,\n response_type_id_token,\n authorization_code,\n requires_consent,\n login_hint,\n created_at\n )\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)\n ", "describe": { "columns": [], "parameters": { @@ -19,10 +19,11 @@ "Bool", "Text", "Bool", + "Text", "Timestamptz" ] }, "nullable": [] }, - "hash": "c0ed9d70e496433d8686a499055d8a8376459109b6154a2c0c13b28462afa523" + "hash": "854cc8cd3c1fc3dbbdf4ce81b561aafadb0f4e98caeaba01597c6f62875ae691" } diff --git a/crates/storage-pg/.sqlx/query-496813daf6f8486353e7f509a64362626daebb0121c3c9420b96e2d8157f1e07.json b/crates/storage-pg/.sqlx/query-e0d3be7e741581430e3e4719c7e19596837234c94a398570bdac42652c2c4652.json similarity index 88% rename from crates/storage-pg/.sqlx/query-496813daf6f8486353e7f509a64362626daebb0121c3c9420b96e2d8157f1e07.json rename to crates/storage-pg/.sqlx/query-e0d3be7e741581430e3e4719c7e19596837234c94a398570bdac42652c2c4652.json index 391844652..485b7fd9d 100644 --- a/crates/storage-pg/.sqlx/query-496813daf6f8486353e7f509a64362626daebb0121c3c9420b96e2d8157f1e07.json +++ b/crates/storage-pg/.sqlx/query-e0d3be7e741581430e3e4719c7e19596837234c94a398570bdac42652c2c4652.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n SELECT oauth2_authorization_grant_id\n , created_at\n , cancelled_at\n , fulfilled_at\n , exchanged_at\n , scope\n , state\n , redirect_uri\n , response_mode\n , nonce\n , max_age\n , oauth2_client_id\n , authorization_code\n , response_type_code\n , response_type_id_token\n , code_challenge\n , code_challenge_method\n , requires_consent\n , oauth2_session_id\n FROM\n oauth2_authorization_grants\n\n WHERE oauth2_authorization_grant_id = $1\n ", + "query": "\n SELECT oauth2_authorization_grant_id\n , created_at\n , cancelled_at\n , fulfilled_at\n , exchanged_at\n , scope\n , state\n , redirect_uri\n , response_mode\n , nonce\n , max_age\n , oauth2_client_id\n , authorization_code\n , response_type_code\n , response_type_id_token\n , code_challenge\n , code_challenge_method\n , requires_consent\n , login_hint\n , oauth2_session_id\n FROM\n oauth2_authorization_grants\n\n WHERE oauth2_authorization_grant_id = $1\n ", "describe": { "columns": [ { @@ -95,6 +95,11 @@ }, { "ordinal": 18, + "name": "login_hint", + "type_info": "Text" + }, + { + "ordinal": 19, "name": "oauth2_session_id", "type_info": "Uuid" } @@ -123,8 +128,9 @@ true, true, false, + true, true ] }, - "hash": "496813daf6f8486353e7f509a64362626daebb0121c3c9420b96e2d8157f1e07" + "hash": "e0d3be7e741581430e3e4719c7e19596837234c94a398570bdac42652c2c4652" } diff --git a/crates/storage-pg/migrations/20241007160050_oidc_login_hint.sql b/crates/storage-pg/migrations/20241007160050_oidc_login_hint.sql new file mode 100644 index 000000000..b18932ec7 --- /dev/null +++ b/crates/storage-pg/migrations/20241007160050_oidc_login_hint.sql @@ -0,0 +1,3 @@ +-- Add login_hint to oauth2_authorization_grants +ALTER TABLE "oauth2_authorization_grants" + ADD COLUMN "login_hint" TEXT; diff --git a/crates/storage-pg/src/oauth2/authorization_grant.rs b/crates/storage-pg/src/oauth2/authorization_grant.rs index 9034a042d..12450ca4b 100644 --- a/crates/storage-pg/src/oauth2/authorization_grant.rs +++ b/crates/storage-pg/src/oauth2/authorization_grant.rs @@ -55,6 +55,7 @@ struct GrantLookup { code_challenge: Option, code_challenge_method: Option, requires_consent: bool, + login_hint: Option, oauth2_client_id: Uuid, oauth2_session_id: Option, } @@ -185,6 +186,7 @@ impl TryFrom for AuthorizationGrant { created_at: value.created_at, response_type_id_token: value.response_type_id_token, requires_consent: value.requires_consent, + login_hint: value.login_hint, }) } } @@ -218,6 +220,7 @@ impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantReposi response_mode: ResponseMode, response_type_id_token: bool, requires_consent: bool, + login_hint: Option, ) -> Result { let code_challenge = code .as_ref() @@ -252,10 +255,11 @@ impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantReposi response_type_id_token, authorization_code, requires_consent, + login_hint, created_at ) VALUES - ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15) + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) "#, Uuid::from(id), Uuid::from(client.id), @@ -271,6 +275,7 @@ impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantReposi response_type_id_token, code_str, requires_consent, + login_hint, created_at, ) .traced() @@ -291,6 +296,7 @@ impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantReposi created_at, response_type_id_token, requires_consent, + login_hint, }) } @@ -325,6 +331,7 @@ impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantReposi , code_challenge , code_challenge_method , requires_consent + , login_hint , oauth2_session_id FROM oauth2_authorization_grants @@ -375,6 +382,7 @@ impl<'c> OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantReposi , code_challenge , code_challenge_method , requires_consent + , login_hint , oauth2_session_id FROM oauth2_authorization_grants diff --git a/crates/storage-pg/src/oauth2/mod.rs b/crates/storage-pg/src/oauth2/mod.rs index 2225b1c83..c55b4d70e 100644 --- a/crates/storage-pg/src/oauth2/mod.rs +++ b/crates/storage-pg/src/oauth2/mod.rs @@ -138,6 +138,7 @@ mod tests { ResponseMode::Query, true, false, + None, ) .await .unwrap(); diff --git a/crates/storage/src/oauth2/authorization_grant.rs b/crates/storage/src/oauth2/authorization_grant.rs index 3313c7bcd..ea18087ea 100644 --- a/crates/storage/src/oauth2/authorization_grant.rs +++ b/crates/storage/src/oauth2/authorization_grant.rs @@ -43,6 +43,7 @@ pub trait OAuth2AuthorizationGrantRepository: Send + Sync { /// * `response_type_id_token`: Whether the `id_token` `response_type` was /// requested /// * `requires_consent`: Whether the client explicitly requested consent + /// * `login_hint`: The login_hint the client sent, if set /// /// # Errors /// @@ -62,6 +63,7 @@ pub trait OAuth2AuthorizationGrantRepository: Send + Sync { response_mode: ResponseMode, response_type_id_token: bool, requires_consent: bool, + login_hint: Option, ) -> Result; /// Lookup an authorization grant by its ID @@ -162,6 +164,7 @@ repository_impl!(OAuth2AuthorizationGrantRepository: response_mode: ResponseMode, response_type_id_token: bool, requires_consent: bool, + login_hint: Option, ) -> Result; async fn lookup(&mut self, id: Ulid) -> Result, Self::Error>; diff --git a/crates/templates/src/context.rs b/crates/templates/src/context.rs index 045450373..762930b64 100644 --- a/crates/templates/src/context.rs +++ b/crates/templates/src/context.rs @@ -482,6 +482,11 @@ impl LoginContext { Self { form, ..self } } + /// Mutably borrow the form state + pub fn form_state_mut(&mut self) -> &mut FormState { + &mut self.form + } + /// Set the upstream OAuth 2.0 providers #[must_use] pub fn with_upstream_providers(self, providers: Vec) -> Self { diff --git a/crates/templates/src/forms.rs b/crates/templates/src/forms.rs index 2fcb7e3e9..a018633ba 100644 --- a/crates/templates/src/forms.rs +++ b/crates/templates/src/forms.rs @@ -166,6 +166,16 @@ impl FormState { self } + /// Set a value on the form + pub fn set_value(&mut self, field: K, value: Option) { + self.fields.entry(field).or_default().value = value; + } + + /// Checks if a field contains a value + pub fn has_value(&self, field: K) -> bool { + self.fields.get(&field).is_some_and(|f| f.value.is_some()) + } + /// Returns `true` if the form has no error attached to it #[must_use] pub fn is_valid(&self) -> bool {