diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index aa2978824..f8cde6044 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -15,9 +15,9 @@ use hyper::StatusCode; use mas_axum_utils::{ FancyError, SessionInfoExt, cookies::CookieJar, - csrf::{CsrfExt, CsrfToken, ProtectedForm}, + csrf::{CsrfExt, ProtectedForm}, }; -use mas_data_model::{BrowserSession, UserAgent, oauth2::LoginHint}; +use mas_data_model::{UserAgent, oauth2::LoginHint}; use mas_i18n::DataLocale; use mas_matrix::HomeserverConnection; use mas_router::{UpstreamOAuth2Authorize, UrlBuilder}; @@ -27,10 +27,10 @@ use mas_storage::{ user::{BrowserSessionRepository, UserPasswordRepository, UserRepository}, }; use mas_templates::{ - FieldError, FormError, LoginContext, LoginFormField, PostAuthContext, PostAuthContextInner, - TemplateContext, Templates, ToFormState, + AccountInactiveContext, FieldError, FormError, FormState, LoginContext, LoginFormField, + PostAuthContext, PostAuthContextInner, TemplateContext, Templates, ToFormState, }; -use rand::{CryptoRng, Rng}; +use rand::Rng; use serde::{Deserialize, Serialize}; use zeroize::Zeroizing; @@ -78,8 +78,6 @@ pub(crate) async fn get( SessionOrFallback::Fallback { response } => return Ok(response), }; - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); - if let Some(session) = maybe_session { activity_tracker .record_browser_session(&clock, &session) @@ -105,18 +103,18 @@ pub(crate) async fn get( return Ok((cookie_jar, url_builder.redirect(&destination)).into_response()); }; - let content = render( + render( locale, - LoginContext::default().with_upstream_providers(providers), + cookie_jar, + FormState::default(), query, - csrf_token, &mut repo, + &clock, + &mut rng, &templates, &homeserver, ) - .await?; - - Ok((cookie_jar, Html(content)).into_response()) + .await } #[tracing::instrument(name = "handlers.views.login.post", skip_all, err)] @@ -146,39 +144,30 @@ pub(crate) async fn post( let form = cookie_jar.verify_form(&clock, form)?; - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); - // Validate the form - let state = { - let mut state = form.to_form_state(); - - if form.username.is_empty() { - state.add_error_on_field(LoginFormField::Username, FieldError::Required); - } + let mut form_state = form.to_form_state(); - if form.password.is_empty() { - state.add_error_on_field(LoginFormField::Password, FieldError::Required); - } + if form.username.is_empty() { + form_state.add_error_on_field(LoginFormField::Username, FieldError::Required); + } - state - }; + if form.password.is_empty() { + form_state.add_error_on_field(LoginFormField::Password, FieldError::Required); + } - if !state.is_valid() { - let providers = repo.upstream_oauth_provider().all_enabled().await?; - let content = render( + if !form_state.is_valid() { + return render( locale, - LoginContext::default() - .with_form_state(state) - .with_upstream_providers(providers), + cookie_jar, + form_state, query, - csrf_token, &mut repo, + &clock, + &mut rng, &templates, &homeserver, ) - .await?; - - return Ok((cookie_jar, Html(content)).into_response()); + .await; } // Extract the localpart of the MXID, fallback to the bare username @@ -186,89 +175,64 @@ pub(crate) async fn post( .localpart(&form.username) .unwrap_or(&form.username); - match login( - password_manager, - &mut repo, - rng, - &clock, - limiter, - requester, - username, - &form.password, - user_agent, - ) - .await - { - Ok(session_info) => { - repo.save().await?; - - activity_tracker - .record_browser_session(&clock, &session_info) - .await; - - let cookie_jar = cookie_jar.set_session(&session_info); - let reply = query.go_next(&url_builder); - Ok((cookie_jar, reply).into_response()) - } - Err(e) => { - let state = state.with_error_on_form(e); - - let content = render( - locale, - LoginContext::default().with_form_state(state), - query, - csrf_token, - &mut repo, - &templates, - &homeserver, - ) - .await?; - - Ok((cookie_jar, Html(content)).into_response()) - } - } -} - -// TODO: move that logic elsewhere? -async fn login( - password_manager: PasswordManager, - repo: &mut impl RepositoryAccess, - mut rng: impl Rng + CryptoRng + Send, - clock: &impl Clock, - limiter: Limiter, - requester: RequesterFingerprint, - username: &str, - password: &str, - user_agent: Option, -) -> Result { - // XXX: we're loosing the error context here // First, lookup the user - let user = repo - .user() - .find_by_username(username) - .await - .map_err(|_e| FormError::Internal)? - .filter(mas_data_model::User::is_valid) - .ok_or(FormError::InvalidCredentials)?; + let Some(user) = repo.user().find_by_username(username).await? else { + let form_state = form_state.with_error_on_form(FormError::InvalidCredentials); + return render( + locale, + cookie_jar, + form_state, + query, + &mut repo, + &clock, + &mut rng, + &templates, + &homeserver, + ) + .await; + }; // Check the rate limit - limiter.check_password(requester, &user).map_err(|e| { + if let Err(e) = limiter.check_password(requester, &user) { tracing::warn!(error = &e as &dyn std::error::Error); - FormError::RateLimitExceeded - })?; + let form_state = form_state.with_error_on_form(FormError::RateLimitExceeded); + return render( + locale, + cookie_jar, + form_state, + query, + &mut repo, + &clock, + &mut rng, + &templates, + &homeserver, + ) + .await; + } // And its password - let user_password = repo - .user_password() - .active(&user) - .await - .map_err(|_e| FormError::Internal)? - .ok_or(FormError::InvalidCredentials)?; + let Some(user_password) = repo.user_password().active(&user).await? else { + // There is no password for this user, but we don't want to disclose that. Show + // a generic 'invalid credentials' error instead + let form_state = form_state.with_error_on_form(FormError::InvalidCredentials); + return render( + locale, + cookie_jar, + form_state, + query, + &mut repo, + &clock, + &mut rng, + &templates, + &homeserver, + ) + .await; + }; - let password = Zeroizing::new(password.as_bytes().to_vec()); + let password = Zeroizing::new(form.password.as_bytes().to_vec()); // Verify the password, and upgrade it on-the-fly if needed - let new_password_hash = password_manager + let user_password = match password_manager .verify_and_upgrade( &mut rng, user_password.version, @@ -276,51 +240,94 @@ async fn login( user_password.hashed_password.clone(), ) .await - .map_err(|_| FormError::InvalidCredentials)?; - - let user_password = if let Some((version, new_password_hash)) = new_password_hash { - // Save the upgraded password - repo.user_password() - .add( + { + Ok(Some((version, new_password_hash))) => { + // Save the upgraded password + repo.user_password() + .add( + &mut rng, + &clock, + &user, + version, + new_password_hash, + Some(&user_password), + ) + .await? + } + Ok(None) => user_password, + Err(_) => { + let form_state = form_state.with_error_on_form(FormError::InvalidCredentials); + return render( + locale, + cookie_jar, + form_state, + query, + &mut repo, + &clock, &mut rng, - clock, - &user, - version, - new_password_hash, - Some(&user_password), + &templates, + &homeserver, ) - .await - .map_err(|_| FormError::Internal)? - } else { - user_password + .await; + } }; + // Now that we have checked the user password, we now want to show an error if + // the user is locked or deactivated + if user.deactivated_at.is_some() { + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); + let ctx = AccountInactiveContext::new(user) + .with_csrf(csrf_token.form_value()) + .with_language(locale); + let content = templates.render_account_deactivated(&ctx)?; + return Ok((cookie_jar, Html(content)).into_response()); + } + + if user.locked_at.is_some() { + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); + let ctx = AccountInactiveContext::new(user) + .with_csrf(csrf_token.form_value()) + .with_language(locale); + let content = templates.render_account_locked(&ctx)?; + return Ok((cookie_jar, Html(content)).into_response()); + } + + // At this point, we should have a 'valid' user. In case we missed something, we + // want it to crash in tests/debug builds + debug_assert!(user.is_valid()); + // Start a new session let user_session = repo .browser_session() - .add(&mut rng, clock, &user, user_agent) - .await - .map_err(|_| FormError::Internal)?; + .add(&mut rng, &clock, &user, user_agent) + .await?; // And mark it as authenticated by the password repo.browser_session() - .authenticate_with_password(&mut rng, clock, &user_session, &user_password) - .await - .map_err(|_| FormError::Internal)?; + .authenticate_with_password(&mut rng, &clock, &user_session, &user_password) + .await?; + + repo.save().await?; + + activity_tracker + .record_browser_session(&clock, &user_session) + .await; - Ok(user_session) + let cookie_jar = cookie_jar.set_session(&user_session); + let reply = query.go_next(&url_builder); + Ok((cookie_jar, reply).into_response()) } fn handle_login_hint( - ctx: &mut LoginContext, + mut ctx: LoginContext, next: &PostAuthContext, homeserver: &dyn HomeserverConnection, -) { +) -> LoginContext { let form_state = ctx.form_state_mut(); // Do not override username if coming from a failed login attempt if form_state.has_value(LoginFormField::Username) { - return; + return ctx; } if let PostAuthContextInner::ContinueAuthorizationGrant { ref grant } = next.ctx { @@ -330,21 +337,31 @@ fn handle_login_hint( }; form_state.set_value(LoginFormField::Username, value); } + + ctx } async fn render( locale: DataLocale, - mut ctx: LoginContext, + cookie_jar: CookieJar, + form_state: FormState, action: OptionalPostAuthAction, - csrf_token: CsrfToken, repo: &mut impl RepositoryAccess, + clock: &impl Clock, + rng: impl Rng, templates: &Templates, homeserver: &dyn HomeserverConnection, -) -> Result { +) -> Result { + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock, rng); + let providers = repo.upstream_oauth_provider().all_enabled().await?; + + let ctx = LoginContext::default() + .with_form_state(form_state) + .with_upstream_providers(providers); + let next = action.load_context(repo).await?; let ctx = if let Some(next) = next { - handle_login_hint(&mut ctx, &next, homeserver); - + let ctx = handle_login_hint(ctx, &next, homeserver); ctx.with_post_action(next) } else { ctx @@ -352,7 +369,7 @@ async fn render( let ctx = ctx.with_csrf(csrf_token.form_value()).with_language(locale); let content = templates.render_login(&ctx)?; - Ok(content) + Ok((cookie_jar, Html(content)).into_response()) } #[cfg(test)] @@ -501,7 +518,11 @@ mod test { ); } - async fn user_with_password(state: &TestState, username: &str, password: &str) { + async fn user_with_password( + state: &TestState, + username: &str, + password: &str, + ) -> mas_data_model::User { let mut rng = state.rng(); let mut repo = state.repository().await.unwrap(); let user = repo @@ -519,6 +540,7 @@ mod test { .await .unwrap(); repo.save().await.unwrap(); + user } #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] @@ -723,4 +745,122 @@ mod test { assert!(!body.contains("Invalid credentials")); assert!(body.contains("too many requests")); } + + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_password_login_locked_account(pool: PgPool) { + setup(); + let state = TestState::from_pool(pool).await.unwrap(); + let cookies = CookieHelper::new(); + + // Provision a user with a password + let user = user_with_password(&state, "john", "hunter2").await; + + // Lock the user + let mut repo = state.repository().await.unwrap(); + repo.user().lock(&state.clock, user).await.unwrap(); + repo.save().await.unwrap(); + + // Render the login page to get a CSRF token + let request = Request::get("/login").empty(); + let request = cookies.with_cookies(request); + let response = state.request(request).await; + cookies.save_cookies(&response); + response.assert_status(StatusCode::OK); + response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8"); + // Extract the CSRF token from the response body + let csrf_token = response + .body() + .split("name=\"csrf\" value=\"") + .nth(1) + .unwrap() + .split('\"') + .next() + .unwrap(); + + // Submit the login form + let request = Request::post("/login").form(serde_json::json!({ + "csrf": csrf_token, + "username": "john", + "password": "hunter2", + })); + let request = cookies.with_cookies(request); + let response = state.request(request).await; + cookies.save_cookies(&response); + response.assert_status(StatusCode::OK); + response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8"); + assert!(response.body().contains("Account locked")); + + // A bad password should not disclose that the account is locked + let request = Request::post("/login").form(serde_json::json!({ + "csrf": csrf_token, + "username": "john", + "password": "badpassword", + })); + let request = cookies.with_cookies(request); + let response = state.request(request).await; + cookies.save_cookies(&response); + response.assert_status(StatusCode::OK); + response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8"); + assert!(!response.body().contains("Account locked")); + assert!(response.body().contains("Invalid credentials")); + } + + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_password_login_deactivated_account(pool: PgPool) { + setup(); + let state = TestState::from_pool(pool).await.unwrap(); + let cookies = CookieHelper::new(); + + // Provision a user with a password + let user = user_with_password(&state, "john", "hunter2").await; + + // Deactivate the user + let mut repo = state.repository().await.unwrap(); + repo.user().deactivate(&state.clock, user).await.unwrap(); + repo.save().await.unwrap(); + + // Render the login page to get a CSRF token + let request = Request::get("/login").empty(); + let request = cookies.with_cookies(request); + let response = state.request(request).await; + cookies.save_cookies(&response); + response.assert_status(StatusCode::OK); + response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8"); + // Extract the CSRF token from the response body + let csrf_token = response + .body() + .split("name=\"csrf\" value=\"") + .nth(1) + .unwrap() + .split('\"') + .next() + .unwrap(); + + // Submit the login form + let request = Request::post("/login").form(serde_json::json!({ + "csrf": csrf_token, + "username": "john", + "password": "hunter2", + })); + let request = cookies.with_cookies(request); + let response = state.request(request).await; + cookies.save_cookies(&response); + response.assert_status(StatusCode::OK); + response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8"); + assert!(response.body().contains("Account deleted")); + + // A bad password should not disclose that the account is deleted + let request = Request::post("/login").form(serde_json::json!({ + "csrf": csrf_token, + "username": "john", + "password": "badpassword", + })); + let request = cookies.with_cookies(request); + let response = state.request(request).await; + cookies.save_cookies(&response); + response.assert_status(StatusCode::OK); + response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8"); + assert!(!response.body().contains("Account deleted")); + assert!(response.body().contains("Invalid credentials")); + } }